implement get_protocol_extension_stream for non 0

This commit is contained in:
Toshit Chawda 2025-02-19 16:41:05 -08:00
parent faf59fa74f
commit 6593ba5783
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
2 changed files with 42 additions and 2 deletions

View file

@ -1,10 +1,12 @@
use std::sync::Arc;
use futures::channel::oneshot;
use bytes::BufMut;
use futures::{channel::oneshot, SinkExt};
use crate::{
locked_sink::LockedWebSocketWrite,
packet::{ClosePacket, CloseReason},
ws::TransportWrite,
ws::{PayloadMut, PayloadRef, TransportWrite},
WispError,
};
@ -41,3 +43,29 @@ impl<W: TransportWrite + 'static> MuxStreamCloser<W> {
self.inner.is_disconnected().then(|| self.info.get_reason())
}
}
/// Stream for sending arbitrary protocol extension packets.
pub struct MuxProtocolExtensionStream<W: TransportWrite> {
pub(crate) info: Arc<StreamInfo>,
pub(crate) tx: LockedWebSocketWrite<W>,
pub(crate) inner: flume::Sender<WsEvent<W>>,
}
impl<W: TransportWrite> MuxProtocolExtensionStream<W> {
/// Send a protocol extension packet with this stream's ID.
pub async fn send(&mut self, packet_type: u8, data: PayloadRef<'_>) -> Result<(), WispError> {
if self.inner.is_disconnected() {
return Err(WispError::StreamAlreadyClosed);
}
let mut encoded = PayloadMut::with_capacity(1 + 4 + data.len());
encoded.put_u8(packet_type);
encoded.put_u32_le(self.info.id);
encoded.extend(data.as_ref());
self.tx.lock().await;
let ret = self.tx.get().send(encoded.into()).await;
self.tx.unlock();
ret
}
}

View file

@ -166,6 +166,14 @@ impl<W: TransportWrite> MuxStreamWrite<W> {
}
}
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream<W> {
MuxProtocolExtensionStream {
info: self.info.clone(),
tx: self.write.clone(),
inner: self.inner.sender().clone(),
}
}
/// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
if self.inner.is_disconnected() {
@ -307,6 +315,10 @@ impl<W: TransportWrite> MuxStream<W> {
self.write.get_close_handle()
}
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream<W> {
self.write.get_protocol_extension_stream()
}
/// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
self.write.close(reason).await