From 53a399856ff9d3697b0e9222f878e6f81b0a879c Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sun, 10 Mar 2024 07:57:28 -0700 Subject: [PATCH] fix udp --- wisp/Cargo.toml | 2 +- wisp/src/lib.rs | 30 ++++++----- wisp/src/packet.rs | 3 +- wisp/src/stream.rs | 126 ++++++++++++++++++++++++--------------------- 4 files changed, 86 insertions(+), 75 deletions(-) diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index b41e6ae..aa56b57 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "wisp-mux" -version = "1.2.1" +version = "1.2.2" license = "LGPL-3.0-only" description = "A library for easily creating Wisp servers and clients." homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp" diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 59796d9..bf07ea2 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -1,5 +1,4 @@ #![deny(missing_docs)] -#![feature(impl_trait_in_assoc_type)] #![cfg_attr(docsrs, feature(doc_cfg))] //! A library for easily creating [Wisp] clients and servers. //! @@ -109,6 +108,7 @@ impl std::error::Error for WispError {} struct MuxMapValue { stream: mpsc::UnboundedSender, + stream_type: StreamType, flow_control: Arc, flow_control_event: Arc, } @@ -200,6 +200,7 @@ impl ServerMuxInner { packet.stream_id, MuxMapValue { stream: ch_tx, + stream_type, flow_control: flow_control.clone(), flow_control_event: flow_control_event.clone(), }, @@ -224,13 +225,15 @@ impl ServerMuxInner { Data(data) => { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { let _ = stream.stream.unbounded_send(MuxEvent::Send(data)); - stream.flow_control.store( - stream - .flow_control - .load(Ordering::Acquire) - .saturating_sub(1), - Ordering::Release, - ); + if stream.stream_type == StreamType::Tcp { + stream.flow_control.store( + stream + .flow_control + .load(Ordering::Acquire) + .saturating_sub(1), + Ordering::Release, + ); + } } } Continue(_) => unreachable!(), @@ -388,10 +391,12 @@ impl ClientMuxInner { } Continue(inner_packet) => { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { - stream - .flow_control - .store(inner_packet.buffer_remaining, Ordering::Release); - let _ = stream.flow_control_event.notify(u32::MAX); + if stream.stream_type == StreamType::Tcp { + stream + .flow_control + .store(inner_packet.buffer_remaining, Ordering::Release); + let _ = stream.flow_control_event.notify(u32::MAX); + } } } Close(inner_packet) => { @@ -490,6 +495,7 @@ impl ClientMux { stream_id, MuxMapValue { stream: ch_tx, + stream_type, flow_control: flow_control.clone(), flow_control_event: evt.clone(), }, diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index fd59a4c..7ef129d 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -1,5 +1,4 @@ -use crate::ws; -use crate::WispError; +use crate::{ws, WispError}; use bytes::{Buf, BufMut, Bytes}; /// Wisp stream type. diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index be6ef87..3364101 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -1,3 +1,5 @@ +use crate::{sink_unfold, ws, ClosePacket, CloseReason, Packet, Role, StreamType, WispError}; + use async_io_stream::IoStream; use bytes::Bytes; use event_listener::Event; @@ -21,31 +23,31 @@ pub enum MuxEvent { /// The other side has sent data. Send(Bytes), /// The other side has closed. - Close(crate::ClosePacket), + Close(ClosePacket), } pub(crate) enum WsEvent { - Close(u32, crate::CloseReason, oneshot::Sender>), + Close(u32, CloseReason, oneshot::Sender>), EndFut, } /// Read side of a multiplexor stream. pub struct MuxStreamRead where - W: crate::ws::WebSocketWrite, + W: ws::WebSocketWrite, { /// ID of the stream. pub stream_id: u32, /// Type of the stream. - pub stream_type: crate::StreamType, - role: crate::Role, - tx: crate::ws::LockedWebSocketWrite, + pub stream_type: StreamType, + role: Role, + tx: ws::LockedWebSocketWrite, rx: mpsc::UnboundedReceiver, is_closed: Arc, flow_control: Arc, } -impl MuxStreamRead { +impl MuxStreamRead { /// Read an event from the stream. pub async fn read(&mut self) -> Option { if self.is_closed.load(Ordering::Acquire) { @@ -53,12 +55,10 @@ impl MuxStreamRead { } match self.rx.next().await? { MuxEvent::Send(bytes) => { - if self.role == crate::Role::Server && self.stream_type == crate::StreamType::Tcp { + if self.role == Role::Server && self.stream_type == StreamType::Tcp { let old_val = self.flow_control.fetch_add(1, Ordering::AcqRel); self.tx - .write_frame( - crate::Packet::new_continue(self.stream_id, old_val + 1).into(), - ) + .write_frame(Packet::new_continue(self.stream_id, old_val + 1).into()) .await .ok()?; } @@ -88,35 +88,38 @@ impl MuxStreamRead { /// Write side of a multiplexor stream. pub struct MuxStreamWrite where - W: crate::ws::WebSocketWrite, + W: ws::WebSocketWrite, { /// ID of the stream. pub stream_id: u32, - role: crate::Role, - tx: crate::ws::LockedWebSocketWrite, + /// Type of the stream. + pub stream_type: StreamType, + role: Role, + tx: ws::LockedWebSocketWrite, close_channel: mpsc::UnboundedSender, is_closed: Arc, continue_recieved: Arc, flow_control: Arc, } -impl MuxStreamWrite { +impl MuxStreamWrite { /// Write data to the stream. - pub async fn write(&self, data: Bytes) -> Result<(), crate::WispError> { + pub async fn write(&self, data: Bytes) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { - return Err(crate::WispError::StreamAlreadyClosed); + return Err(WispError::StreamAlreadyClosed); } - if self.role == crate::Role::Client && self.flow_control.load(Ordering::Acquire) == 0 { + if self.role == Role::Client + && self.stream_type == StreamType::Tcp + && self.flow_control.load(Ordering::Acquire) == 0 + { self.continue_recieved.listen().await; } self.tx - .write_frame(crate::Packet::new_data(self.stream_id, data).into()) + .write_frame(Packet::new_data(self.stream_id, data).into()) .await?; - if self.role == crate::Role::Client { + if self.role == Role::Client && self.stream_type == StreamType::Tcp { self.flow_control.store( - self.flow_control - .load(Ordering::Acquire) - .saturating_sub(1), + self.flow_control.load(Ordering::Acquire).saturating_sub(1), Ordering::Release, ); } @@ -143,45 +146,48 @@ impl MuxStreamWrite { } /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&self, reason: crate::CloseReason) -> Result<(), crate::WispError> { + pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { - return Err(crate::WispError::StreamAlreadyClosed); + return Err(WispError::StreamAlreadyClosed); } - let (tx, rx) = oneshot::channel::>(); + let (tx, rx) = oneshot::channel::>(); self.close_channel .unbounded_send(WsEvent::Close(self.stream_id, reason, tx)) - .map_err(|x| crate::WispError::Other(Box::new(x)))?; - rx.await - .map_err(|x| crate::WispError::Other(Box::new(x)))??; + .map_err(|x| WispError::Other(Box::new(x)))?; + rx.await.map_err(|x| WispError::Other(Box::new(x)))??; self.is_closed.store(true, Ordering::Release); Ok(()) } - pub(crate) fn into_sink(self) -> Pin + Send>> { + pub(crate) fn into_sink(self) -> Pin + Send>> { let handle = self.get_close_handle(); - Box::pin(crate::sink_unfold::unfold(self, |tx, data| async move { - tx.write(data).await?; - Ok(tx) - }, move || { - handle.close_sync(crate::CloseReason::Unknown) - })) + Box::pin(sink_unfold::unfold( + self, + |tx, data| async move { + tx.write(data).await?; + Ok(tx) + }, + move || handle.close_sync(CloseReason::Unknown), + )) } } -impl Drop for MuxStreamWrite { +impl Drop for MuxStreamWrite { fn drop(&mut self) { - let (tx, _) = oneshot::channel::>(); - let _ = self - .close_channel - .unbounded_send(WsEvent::Close(self.stream_id, crate::CloseReason::Unknown, tx)); + let (tx, _) = oneshot::channel::>(); + let _ = self.close_channel.unbounded_send(WsEvent::Close( + self.stream_id, + CloseReason::Unknown, + tx, + )); } } /// Multiplexor stream. pub struct MuxStream where - W: crate::ws::WebSocketWrite, + W: ws::WebSocketWrite, { /// ID of the stream. pub stream_id: u32, @@ -189,18 +195,18 @@ where tx: MuxStreamWrite, } -impl MuxStream { +impl MuxStream { #[allow(clippy::too_many_arguments)] pub(crate) fn new( stream_id: u32, - role: crate::Role, - stream_type: crate::StreamType, + role: Role, + stream_type: StreamType, rx: mpsc::UnboundedReceiver, - tx: crate::ws::LockedWebSocketWrite, + tx: ws::LockedWebSocketWrite, close_channel: mpsc::UnboundedSender, is_closed: Arc, flow_control: Arc, - continue_recieved: Arc + continue_recieved: Arc, ) -> Self { Self { stream_id, @@ -215,6 +221,7 @@ impl MuxStream { }, tx: MuxStreamWrite { stream_id, + stream_type, role, tx, close_channel, @@ -231,7 +238,7 @@ impl MuxStream { } /// Write data to the stream. - pub async fn write(&self, data: Bytes) -> Result<(), crate::WispError> { + pub async fn write(&self, data: Bytes) -> Result<(), WispError> { self.tx.write(data).await } @@ -251,7 +258,7 @@ impl MuxStream { } /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&self, reason: crate::CloseReason) -> Result<(), crate::WispError> { + pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { self.tx.close(reason).await } @@ -280,29 +287,28 @@ pub struct MuxStreamCloser { impl MuxStreamCloser { /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&self, reason: crate::CloseReason) -> Result<(), crate::WispError> { + pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { - return Err(crate::WispError::StreamAlreadyClosed); + return Err(WispError::StreamAlreadyClosed); } - let (tx, rx) = oneshot::channel::>(); + let (tx, rx) = oneshot::channel::>(); self.close_channel .unbounded_send(WsEvent::Close(self.stream_id, reason, tx)) - .map_err(|x| crate::WispError::Other(Box::new(x)))?; - rx.await - .map_err(|x| crate::WispError::Other(Box::new(x)))??; + .map_err(|x| WispError::Other(Box::new(x)))?; + rx.await.map_err(|x| WispError::Other(Box::new(x)))??; self.is_closed.store(true, Ordering::Release); Ok(()) } /// Close the stream. This function does not check if it was actually closed. - pub(crate) fn close_sync(&self, reason: crate::CloseReason) -> Result<(), crate::WispError> { + pub(crate) fn close_sync(&self, reason: CloseReason) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { - return Err(crate::WispError::StreamAlreadyClosed); + return Err(WispError::StreamAlreadyClosed); } - let (tx, _) = oneshot::channel::>(); + let (tx, _) = oneshot::channel::>(); self.close_channel .unbounded_send(WsEvent::Close(self.stream_id, reason, tx)) - .map_err(|x| crate::WispError::Other(Box::new(x)))?; + .map_err(|x| WispError::Other(Box::new(x)))?; self.is_closed.store(true, Ordering::Release); Ok(()) } @@ -314,7 +320,7 @@ pin_project! { #[pin] rx: Pin + Send>>, #[pin] - tx: Pin + Send>>, + tx: Pin + Send>>, } }