From 6c41c54cf93f5e8b5a3cab1de11ffcfa2627c957 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Tue, 16 Apr 2024 21:57:27 -0700 Subject: [PATCH] add ability to send protocol extension packets --- server/src/main.rs | 12 +++---- simple-wisp-client/src/main.rs | 4 +-- wisp/src/lib.rs | 19 ++++++----- wisp/src/packet.rs | 20 +++++++---- wisp/src/stream.rs | 62 ++++++++++++++++++++++++++++------ 5 files changed, 84 insertions(+), 33 deletions(-) diff --git a/server/src/main.rs b/server/src/main.rs index 623ff09..1910a99 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -253,7 +253,7 @@ async fn accept_http( } } -async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result { +async fn handle_mux(packet: ConnectPacket, stream: MuxStream) -> Result { let uri = format!( "{}:{}", packet.destination_hostname, packet.destination_port @@ -318,8 +318,8 @@ async fn accept_ws( println!("{:?}: connected", addr); // to prevent memory ""leaks"" because users are sending in packets way too fast the buffer // size is set to 128 - let (mut mux, fut) = if mux_options.enforce_auth { - let (mut mux, fut) = ServerMux::new(rx, tx, 128, Some(mux_options.auth.as_slice())).await?; + let (mux, fut) = if mux_options.enforce_auth { + let (mux, fut) = ServerMux::new(rx, tx, 128, Some(mux_options.auth.as_slice())).await?; if !mux .supported_extension_ids .iter() @@ -354,7 +354,7 @@ async fn accept_ws( } }); - while let Some((packet, mut stream)) = mux.server_new_stream().await { + while let Some((packet, stream)) = mux.server_new_stream().await { tokio::spawn(async move { if (mux_options.block_non_http && !(packet.destination_port == 80 || packet.destination_port == 443)) @@ -386,8 +386,8 @@ async fn accept_ws( } } } - let mut close_err = stream.get_close_handle(); - let mut close_ok = stream.get_close_handle(); + let close_err = stream.get_close_handle(); + let close_ok = stream.get_close_handle(); let _ = handle_mux(packet, stream) .or_else(|err| async move { let _ = close_err.close(CloseReason::Unexpected).await; diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index 0b3d2c4..fecbf44 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -164,7 +164,7 @@ async fn main() -> Result<(), Box> { extensions.push(Box::new(auth)); } - let (mut mux, fut) = if opts.wisp_v1 { + let (mux, fut) = if opts.wisp_v1 { ClientMux::new(rx, tx, None).await? } else { ClientMux::new(rx, tx, Some(extensions.as_slice())).await? @@ -212,7 +212,7 @@ async fn main() -> Result<(), Box> { let start_time = Instant::now(); for _ in 0..opts.streams { - let (mut cr, mut cw) = mux + let (cr, cw) = mux .client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port) .await? .into_split(); diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index d68edf0..1cf170f 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -272,6 +272,9 @@ impl MuxInner { let _ = channel.send(Err(WispError::InvalidStreamId)); } } + WsEvent::SendBytes(packet, channel) => { + let _ = channel.send(self.tx.write_frame(ws::Frame::binary(packet)).await); + } WsEvent::CreateStream(stream_type, host, port, channel) => { let ret: Result = async { let stream_id = next_free_stream_id; @@ -552,11 +555,11 @@ impl ServerMux { } /// Wait for a stream to be created. - pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> { + pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> { self.muxstream_recv.recv_async().await.ok() } - async fn close_internal(&mut self, reason: Option) -> Result<(), WispError> { + async fn close_internal(&self, reason: Option) -> Result<(), WispError> { self.close_tx .send_async(WsEvent::EndFut(reason)) .await @@ -567,7 +570,7 @@ impl ServerMux { /// /// Also terminates the multiplexor future. Waiting for a new stream will never succeed after /// this function is called. - pub async fn close(&mut self) -> Result<(), WispError> { + pub async fn close(&self) -> Result<(), WispError> { self.close_internal(None).await } @@ -575,7 +578,7 @@ impl ServerMux { /// /// Also terminates the multiplexor future. Waiting for a new stream will never succed after /// this function is called. - pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> { + pub async fn close_extension_incompat(&self) -> Result<(), WispError> { self.close_internal(Some(CloseReason::IncompatibleExtensions)) .await } @@ -696,7 +699,7 @@ impl ClientMux { /// Create a new stream, multiplexed through Wisp. pub async fn client_new_stream( - &mut self, + &self, stream_type: StreamType, host: String, port: u16, @@ -717,7 +720,7 @@ impl ClientMux { rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? } - async fn close_internal(&mut self, reason: Option) -> Result<(), WispError> { + async fn close_internal(&self, reason: Option) -> Result<(), WispError> { self.stream_tx .send_async(WsEvent::EndFut(reason)) .await @@ -728,7 +731,7 @@ impl ClientMux { /// /// Also terminates the multiplexor future. Creating a stream is UB after calling this /// function. - pub async fn close(&mut self) -> Result<(), WispError> { + pub async fn close(&self) -> Result<(), WispError> { self.close_internal(None).await } @@ -736,7 +739,7 @@ impl ClientMux { /// /// Also terminates the multiplexor future. Creating a stream is UB after calling this /// function. - pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> { + pub async fn close_extension_incompat(&self) -> Result<(), WispError> { self.close_internal(Some(CloseReason::IncompatibleExtensions)) .await } diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 2328c87..ec933bd 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -362,6 +362,14 @@ impl Packet { } } + pub(crate) fn raw_encode(packet_type: u8, stream_id: u32, bytes: Bytes) -> Bytes { + let mut encoded = BytesMut::with_capacity(1 + 4 + bytes.len()); + encoded.put_u8(packet_type); + encoded.put_u32_le(stream_id); + encoded.extend(bytes); + encoded.freeze() + } + fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result { use PacketType as P; Ok(Self { @@ -494,13 +502,11 @@ impl TryFrom for Packet { impl From for Bytes { fn from(packet: Packet) -> Self { - let inner_u8 = packet.packet_type.as_u8(); - let inner = Bytes::from(packet.packet_type); - let mut encoded = BytesMut::with_capacity(1 + 4 + inner.len()); - encoded.put_u8(inner_u8); - encoded.put_u32_le(packet.stream_id); - encoded.extend(inner); - encoded.freeze() + Packet::raw_encode( + packet.packet_type.as_u8(), + packet.stream_id, + packet.packet_type.into(), + ) } } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index e980bec..a7099c9 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -21,6 +21,7 @@ use std::{ pub(crate) enum WsEvent { SendPacket(Packet, oneshot::Sender>), + SendBytes(Bytes, oneshot::Sender>), Close(Packet, oneshot::Sender>), CreateStream( StreamType, @@ -49,7 +50,7 @@ pub struct MuxStreamRead { impl MuxStreamRead { /// Read an event from the stream. - pub async fn read(&mut self) -> Option { + pub async fn read(&self) -> Option { if self.is_closed.load(Ordering::Acquire) { return None; } @@ -79,7 +80,7 @@ impl MuxStreamRead { } pub(crate) fn into_stream(self) -> Pin + Send>> { - Box::pin(stream::unfold(self, |mut rx| async move { + Box::pin(stream::unfold(self, |rx| async move { Some((rx.read().await?, rx)) })) } @@ -100,7 +101,7 @@ pub struct MuxStreamWrite { impl MuxStreamWrite { /// Write data to the stream. - pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> { + pub async fn write(&self, data: Bytes) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(WispError::StreamAlreadyClosed); } @@ -147,8 +148,17 @@ impl MuxStreamWrite { } } + /// Get a protocol extension stream to send protocol extension packets. + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + MuxProtocolExtensionStream { + stream_id: self.stream_id, + tx: self.tx.clone(), + is_closed: self.is_closed.clone(), + } + } + /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> { + pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(WispError::StreamAlreadyClosed); } @@ -171,12 +181,12 @@ impl MuxStreamWrite { let handle = self.get_close_handle(); Box::pin(sink_unfold::unfold( self, - |mut tx, data| async move { + |tx, data| async move { tx.write(data).await?; Ok(tx) }, handle, - move |mut handle| async { + move |handle| async { handle.close(CloseReason::Unknown).await?; Ok(handle) }, @@ -246,12 +256,12 @@ impl MuxStream { } /// Read an event from the stream. - pub async fn read(&mut self) -> Option { + pub async fn read(&self) -> Option { self.rx.read().await } /// Write data to the stream. - pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> { + pub async fn write(&self, data: Bytes) -> Result<(), WispError> { self.tx.write(data).await } @@ -270,8 +280,13 @@ impl MuxStream { self.tx.get_close_handle() } + /// Get a protocol extension stream to send protocol extension packets. + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + self.tx.get_protocol_extension_stream() + } + /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> { + pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { self.tx.close(reason).await } @@ -300,7 +315,7 @@ pub struct MuxStreamCloser { impl MuxStreamCloser { /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> { + pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(WispError::StreamAlreadyClosed); } @@ -320,6 +335,33 @@ impl MuxStreamCloser { } } +/// Stream for sending arbitrary protocol extension packets. +pub struct MuxProtocolExtensionStream { + /// ID of the stream. + pub stream_id: u32, + tx: mpsc::Sender, + is_closed: Arc, +} + +impl MuxProtocolExtensionStream { + /// Send a protocol extension packet. + pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(WispError::StreamAlreadyClosed); + } + let (tx, rx) = oneshot::channel::>(); + self.tx + .send_async(WsEvent::SendBytes( + Packet::raw_encode(packet_type, self.stream_id, data), + tx, + )) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; + Ok(()) + } +} + pin_project! { /// Multiplexor stream that implements futures `Stream + Sink`. pub struct MuxStreamIo {