From f5b50bcc988d05a90c600f993d139fa23a854f55 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 7 Sep 2024 10:41:49 -0700 Subject: [PATCH] congestion stream types --- client/src/utils.rs | 21 ++++----- server/src/handle/twisp.rs | 4 ++ server/src/handle/wisp.rs | 5 +- wisp/src/extensions/mod.rs | 4 ++ wisp/src/extensions/password.rs | 4 ++ wisp/src/extensions/udp.rs | 4 ++ wisp/src/inner.rs | 82 +++++++++++++++++++++------------ wisp/src/lib.rs | 32 ++++++------- wisp/src/stream.rs | 9 +++- 9 files changed, 103 insertions(+), 62 deletions(-) diff --git a/client/src/utils.rs b/client/src/utils.rs index 306132a..8d53ce4 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -87,20 +87,15 @@ impl IncomingBody { impl Stream for IncomingBody { type Item = std::io::Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - let ret = this.incoming.poll_frame(cx); - match ret { - Poll::Ready(item) => Poll::>::Ready(match item { - Some(frame) => frame - .map(|x| { - x.into_data() - .map_err(|_| std::io::Error::other("not data frame")) + self.project().incoming.poll_frame(cx).map(|x| { + x.map(|x| { + x.map_err(std::io::Error::other).and_then(|x| { + x.into_data().map_err(|_| { + std::io::Error::other("trailer frame recieved; not implemented") }) - .ok(), - None => None, - }), - Poll::Pending => Poll::>::Pending, - } + }) + }) + }) } } diff --git a/server/src/handle/twisp.rs b/server/src/handle/twisp.rs index 33e1271..aa34b57 100644 --- a/server/src/handle/twisp.rs +++ b/server/src/handle/twisp.rs @@ -37,6 +37,10 @@ impl ProtocolExtension for TWispServerProtocolExtension { &[0xF0] } + fn get_congestion_stream_types(&self) -> &'static [u8] { + &[0x03] + } + fn encode(&self) -> Bytes { Bytes::new() } diff --git a/server/src/handle/wisp.rs b/server/src/handle/wisp.rs index 425aeda..6b8af0b 100644 --- a/server/src/handle/wisp.rs +++ b/server/src/handle/wisp.rs @@ -243,7 +243,10 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> { let mut set: JoinSet<()> = JoinSet::new(); let event: Arc = Event::new().into(); - set.spawn(tokio::task::unconstrained(fut.map(|_| {}))); + let mux_id = id.clone(); + set.spawn(tokio::task::unconstrained(fut.map(move |x| { + trace!("wisp client id {:?} multiplexor result {:?}", mux_id, x) + }))); while let Some((connect, stream)) = mux.server_new_stream().await { set.spawn(handle_stream( diff --git a/wisp/src/extensions/mod.rs b/wisp/src/extensions/mod.rs index 7de097f..141d45d 100644 --- a/wisp/src/extensions/mod.rs +++ b/wisp/src/extensions/mod.rs @@ -65,6 +65,10 @@ pub trait ProtocolExtension: std::fmt::Debug { /// /// Used to decide whether to call the protocol extension's packet handler. fn get_supported_packets(&self) -> &'static [u8]; + /// Get stream types that should be treated as TCP. + /// + /// Used to decide whether to handle congestion control for that stream type. + fn get_congestion_stream_types(&self) -> &'static [u8]; /// Encode self into Bytes. fn encode(&self) -> Bytes; diff --git a/wisp/src/extensions/password.rs b/wisp/src/extensions/password.rs index 05bd489..6246c6c 100644 --- a/wisp/src/extensions/password.rs +++ b/wisp/src/extensions/password.rs @@ -99,6 +99,10 @@ impl ProtocolExtension for PasswordProtocolExtension { &[] } + fn get_congestion_stream_types(&self) -> &'static [u8] { + &[] + } + fn encode(&self) -> Bytes { match self.role { Role::Server => Bytes::new(), diff --git a/wisp/src/extensions/udp.rs b/wisp/src/extensions/udp.rs index 8510277..b2b5150 100644 --- a/wisp/src/extensions/udp.rs +++ b/wisp/src/extensions/udp.rs @@ -39,6 +39,10 @@ impl ProtocolExtension for UdpProtocolExtension { &[] } + fn get_congestion_stream_types(&self) -> &'static [u8] { + &[] + } + fn encode(&self) -> Bytes { Bytes::new() } diff --git a/wisp/src/inner.rs b/wisp/src/inner.rs index 58a2338..142714c 100644 --- a/wisp/src/inner.rs +++ b/wisp/src/inner.rs @@ -31,6 +31,7 @@ struct MuxMapValue { stream: mpsc::Sender, stream_type: StreamType, + should_flow_control: bool, flow_control: Arc, flow_control_event: Arc, @@ -44,11 +45,12 @@ pub struct MuxInner { rx: Option, tx: LockedWebSocketWrite, extensions: Vec, + tcp_extensions: Vec, role: Role, // gets taken by the mux task - fut_rx: Option>, - fut_tx: mpsc::Sender, + actor_rx: Option>, + actor_tx: mpsc::Sender, fut_exited: Arc, stream_map: IntMap, @@ -59,16 +61,29 @@ pub struct MuxInner { server_tx: mpsc::Sender<(ConnectPacket, MuxStream)>, } +pub struct MuxInnerResult { + pub mux: MuxInner, + pub actor_exited: Arc, + pub actor_tx: mpsc::Sender, +} + impl MuxInner { + fn get_tcp_extensions(extensions: &[AnyProtocolExtension]) -> Vec { + extensions + .iter() + .flat_map(|x| x.get_congestion_stream_types()) + .copied() + .chain(std::iter::once(StreamType::Tcp.into())) + .collect() + } + pub fn new_server( rx: R, tx: LockedWebSocketWrite, extensions: Vec, buffer_size: u32, ) -> ( - Self, - Arc, - mpsc::Sender, + MuxInnerResult, mpsc::Receiver<(ConnectPacket, MuxStream)>, ) { let (fut_tx, fut_rx) = mpsc::bounded::(256); @@ -77,26 +92,29 @@ impl MuxInner { let fut_exited = Arc::new(AtomicBool::new(false)); ( - Self { - rx: Some(rx), - tx, + MuxInnerResult { + mux: Self { + rx: Some(rx), + tx, - fut_rx: Some(fut_rx), - fut_tx, - fut_exited: fut_exited.clone(), + actor_rx: Some(fut_rx), + actor_tx: fut_tx, + fut_exited: fut_exited.clone(), - extensions, - buffer_size, - target_buffer_size: ((buffer_size as u64 * 90) / 100) as u32, + tcp_extensions: Self::get_tcp_extensions(&extensions), + extensions, + buffer_size, + target_buffer_size: ((buffer_size as u64 * 90) / 100) as u32, - role: Role::Server, + role: Role::Server, - stream_map: IntMap::default(), + stream_map: IntMap::default(), - server_tx, + server_tx, + }, + actor_exited: fut_exited, + actor_tx: ret_fut_tx, }, - fut_exited, - ret_fut_tx, server_rx, ) } @@ -106,21 +124,22 @@ impl MuxInner { tx: LockedWebSocketWrite, extensions: Vec, buffer_size: u32, - ) -> (Self, Arc, mpsc::Sender) { + ) -> MuxInnerResult { let (fut_tx, fut_rx) = mpsc::bounded::(256); let (server_tx, _) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); let ret_fut_tx = fut_tx.clone(); let fut_exited = Arc::new(AtomicBool::new(false)); - ( - Self { + MuxInnerResult { + mux: Self { rx: Some(rx), tx, - fut_rx: Some(fut_rx), - fut_tx, + actor_rx: Some(fut_rx), + actor_tx: fut_tx, fut_exited: fut_exited.clone(), + tcp_extensions: Self::get_tcp_extensions(&extensions), extensions, buffer_size, target_buffer_size: 0, @@ -131,9 +150,9 @@ impl MuxInner { server_tx, }, - fut_exited, - ret_fut_tx, - ) + actor_exited: fut_exited, + actor_tx: ret_fut_tx, + } } pub async fn into_future(mut self) -> Result<(), WispError> { @@ -157,6 +176,7 @@ impl MuxInner { ) -> Result<(MuxMapValue, MuxStream), WispError> { let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize); + let should_flow_control = self.tcp_extensions.contains(&stream_type.into()); let flow_control_event: Arc = Event::new().into(); let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); @@ -170,6 +190,7 @@ impl MuxInner { stream: ch_tx, stream_type, + should_flow_control, flow_control: flow_control.clone(), flow_control_event: flow_control_event.clone(), @@ -182,11 +203,12 @@ impl MuxInner { self.role, stream_type, ch_rx, - self.fut_tx.clone(), + self.actor_tx.clone(), self.tx.clone(), is_closed, is_closed_event, close_reason, + should_flow_control, flow_control, flow_control_event, self.target_buffer_size, @@ -233,7 +255,7 @@ impl MuxInner { let mut rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?; let tx = self.tx.clone(); - let fut_rx = self.fut_rx.take().ok_or(WispError::MuxTaskStarted)?; + let fut_rx = self.actor_rx.take().ok_or(WispError::MuxTaskStarted)?; let mut recv_fut = fut_rx.recv_async().fuse(); let mut read_fut = rx.wisp_read_split(&tx).fuse(); @@ -349,7 +371,7 @@ impl MuxInner { } } let _ = stream.stream.try_send(data.freeze()); - if self.role == Role::Server && stream.stream_type == StreamType::Tcp { + if self.role == Role::Server && stream.should_flow_control { stream.flow_control.store( stream .flow_control diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 407df26..6ccd8ec 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -224,7 +224,7 @@ pub struct ServerMux { actor_tx: mpsc::Sender, muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>, tx: ws::LockedWebSocketWrite, - fut_exited: Arc, + actor_exited: Arc, } impl ServerMux { @@ -267,7 +267,7 @@ impl ServerMux { let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect(); - let (mux_inner, fut_exited, actor_tx, muxstream_recv) = MuxInner::new_server( + let (mux_result, muxstream_recv) = MuxInner::new_server( AppendingWebSocketRead(extra_packet, rx), tx.clone(), supported_extensions, @@ -277,26 +277,26 @@ impl ServerMux { Ok(ServerMuxResult( Self { muxstream_recv, - actor_tx, + actor_tx: mux_result.actor_tx, downgraded, supported_extension_ids, tx, - fut_exited: fut_exited.clone(), + actor_exited: mux_result.actor_exited, }, - mux_inner.into_future(), + mux_result.mux.into_future(), )) } /// Wait for a stream to be created. pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> { - if self.fut_exited.load(Ordering::Acquire) { + if self.actor_exited.load(Ordering::Acquire) { return None; } self.muxstream_recv.recv_async().await.ok() } async fn close_internal(&self, reason: Option) -> Result<(), WispError> { - if self.fut_exited.load(Ordering::Acquire) { + if self.actor_exited.load(Ordering::Acquire) { return Err(WispError::MuxTaskEnded); } self.actor_tx @@ -325,7 +325,7 @@ impl ServerMux { MuxProtocolExtensionStream { stream_id: 0, tx: self.tx.clone(), - is_closed: self.fut_exited.clone(), + is_closed: self.actor_exited.clone(), } } } @@ -401,7 +401,7 @@ pub struct ClientMux { pub supported_extension_ids: Vec, actor_tx: mpsc::Sender, tx: ws::LockedWebSocketWrite, - fut_exited: Arc, + actor_exited: Arc, } impl ClientMux { @@ -450,7 +450,7 @@ impl ClientMux { let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect(); - let (mux_inner, fut_exited, actor_tx) = MuxInner::new_client( + let mux_result = MuxInner::new_client( AppendingWebSocketRead(extra_packet, rx), tx.clone(), supported_extensions, @@ -459,13 +459,13 @@ impl ClientMux { Ok(ClientMuxResult( Self { - actor_tx, + actor_tx: mux_result.actor_tx, downgraded, supported_extension_ids, tx, - fut_exited, + actor_exited: mux_result.actor_exited, }, - mux_inner.into_future(), + mux_result.mux.into_future(), )) } else { Err(WispError::InvalidPacketType) @@ -479,7 +479,7 @@ impl ClientMux { host: String, port: u16, ) -> Result { - if self.fut_exited.load(Ordering::Acquire) { + if self.actor_exited.load(Ordering::Acquire) { return Err(WispError::MuxTaskEnded); } if stream_type == StreamType::Udp @@ -501,7 +501,7 @@ impl ClientMux { } async fn close_internal(&self, reason: Option) -> Result<(), WispError> { - if self.fut_exited.load(Ordering::Acquire) { + if self.actor_exited.load(Ordering::Acquire) { return Err(WispError::MuxTaskEnded); } self.actor_tx @@ -530,7 +530,7 @@ impl ClientMux { MuxProtocolExtensionStream { stream_id: 0, tx: self.tx.clone(), - is_closed: self.fut_exited.clone(), + is_closed: self.actor_exited.clone(), } } } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 4972a5b..1755db6 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -40,6 +40,7 @@ pub struct MuxStreamRead { is_closed_event: Arc, close_reason: Arc, + should_flow_control: bool, flow_control: Arc, flow_control_read: AtomicU32, target_flow_control: u32, @@ -55,7 +56,7 @@ impl MuxStreamRead { x = self.rx.recv_async() => x.ok()?, _ = self.is_closed_event.listen().fuse() => return None }; - if self.role == Role::Server && self.stream_type == StreamType::Tcp { + if self.role == Role::Server && self.should_flow_control { let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1; if val > self.target_flow_control && !self.is_closed.load(Ordering::Acquire) { self.tx @@ -114,6 +115,7 @@ pub struct MuxStreamWrite { close_reason: Arc, continue_recieved: Arc, + should_flow_control: bool, flow_control: Arc, } @@ -124,7 +126,7 @@ impl MuxStreamWrite { body: Frame<'a>, ) -> Result<(), WispError> { if self.role == Role::Client - && self.stream_type == StreamType::Tcp + && self.should_flow_control && self.flow_control.load(Ordering::Acquire) == 0 { self.continue_recieved.listen().await; @@ -278,6 +280,7 @@ impl MuxStream { is_closed: Arc, is_closed_event: Arc, close_reason: Arc, + should_flow_control: bool, flow_control: Arc, continue_recieved: Arc, target_flow_control: u32, @@ -293,6 +296,7 @@ impl MuxStream { is_closed: is_closed.clone(), is_closed_event: is_closed_event.clone(), close_reason: close_reason.clone(), + should_flow_control, flow_control: flow_control.clone(), flow_control_read: AtomicU32::new(0), target_flow_control, @@ -305,6 +309,7 @@ impl MuxStream { tx, is_closed: is_closed.clone(), close_reason: close_reason.clone(), + should_flow_control, flow_control: flow_control.clone(), continue_recieved: continue_recieved.clone(), },