From 6593ba57834e581e2a32279acaee53cdcd176e9f Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Wed, 19 Feb 2025 16:41:05 -0800 Subject: [PATCH] implement get_protocol_extension_stream for non 0 --- wisp/src/stream/handles.rs | 32 ++++++++++++++++++++++++++++++-- wisp/src/stream/mod.rs | 12 ++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/wisp/src/stream/handles.rs b/wisp/src/stream/handles.rs index c5e5741..a003e26 100644 --- a/wisp/src/stream/handles.rs +++ b/wisp/src/stream/handles.rs @@ -1,10 +1,12 @@ use std::sync::Arc; -use futures::channel::oneshot; +use bytes::BufMut; +use futures::{channel::oneshot, SinkExt}; use crate::{ + locked_sink::LockedWebSocketWrite, packet::{ClosePacket, CloseReason}, - ws::TransportWrite, + ws::{PayloadMut, PayloadRef, TransportWrite}, WispError, }; @@ -41,3 +43,29 @@ impl MuxStreamCloser { self.inner.is_disconnected().then(|| self.info.get_reason()) } } + +/// Stream for sending arbitrary protocol extension packets. +pub struct MuxProtocolExtensionStream { + pub(crate) info: Arc, + pub(crate) tx: LockedWebSocketWrite, + pub(crate) inner: flume::Sender>, +} + +impl MuxProtocolExtensionStream { + /// Send a protocol extension packet with this stream's ID. + pub async fn send(&mut self, packet_type: u8, data: PayloadRef<'_>) -> Result<(), WispError> { + if self.inner.is_disconnected() { + return Err(WispError::StreamAlreadyClosed); + } + + let mut encoded = PayloadMut::with_capacity(1 + 4 + data.len()); + encoded.put_u8(packet_type); + encoded.put_u32_le(self.info.id); + encoded.extend(data.as_ref()); + + self.tx.lock().await; + let ret = self.tx.get().send(encoded.into()).await; + self.tx.unlock(); + ret + } +} diff --git a/wisp/src/stream/mod.rs b/wisp/src/stream/mod.rs index 1497ff8..51d498c 100644 --- a/wisp/src/stream/mod.rs +++ b/wisp/src/stream/mod.rs @@ -166,6 +166,14 @@ impl MuxStreamWrite { } } + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + MuxProtocolExtensionStream { + info: self.info.clone(), + tx: self.write.clone(), + inner: self.inner.sender().clone(), + } + } + /// Close the stream. You will no longer be able to write or read after this has been called. pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { if self.inner.is_disconnected() { @@ -307,6 +315,10 @@ impl MuxStream { self.write.get_close_handle() } + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + self.write.get_protocol_extension_stream() + } + /// Close the stream. You will no longer be able to write or read after this has been called. pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { self.write.close(reason).await