diff --git a/Cargo.lock b/Cargo.lock index fc2ffc8..ac0baec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1860,7 +1860,7 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "wisp-mux" -version = "1.1.0" +version = "1.1.1" dependencies = [ "async_io_stream", "bytes", diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 8a9240d..6d426d2 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "wisp-mux" -version = "1.1.0" +version = "1.1.1" license = "AGPL-3.0-only" description = "A library for easily creating Wisp servers and clients." homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp" diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 3c46837..93cc870 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -6,8 +6,8 @@ #[cfg(feature = "fastwebsockets")] mod fastwebsockets; -mod sink_unfold; mod packet; +mod sink_unfold; mod stream; #[cfg(feature = "hyper_tower")] pub mod tokioio; @@ -111,12 +111,18 @@ impl std::fmt::Display for WispError { impl std::error::Error for WispError {} +struct MuxMapValue { + stream: mpsc::UnboundedSender, + flow_control: Arc, + flow_control_event: Arc, +} + struct ServerMuxInner where W: ws::WebSocketWrite + Send + 'static, { tx: ws::LockedWebSocketWrite, - stream_map: Arc>>>, + stream_map: Arc>>, close_tx: mpsc::UnboundedSender, } @@ -132,11 +138,13 @@ impl ServerMuxInner { R: ws::WebSocketRead, { let ret = futures::select! { - x = self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()).fuse() => x, + x = self.server_close_loop(close_rx).fuse() => x, x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x }; self.stream_map.lock().await.iter().for_each(|x| { - let _ = x.1.unbounded_send(MuxEvent::Close(ClosePacket::new(CloseReason::Unknown))); + let _ = + x.1.stream + .unbounded_send(MuxEvent::Close(ClosePacket::new(CloseReason::Unknown))); }); ret } @@ -144,15 +152,14 @@ impl ServerMuxInner { async fn server_close_loop( &self, mut close_rx: mpsc::UnboundedReceiver, - stream_map: Arc>>>, - tx: ws::LockedWebSocketWrite, ) -> Result<(), WispError> { while let Some(msg) = close_rx.next().await { match msg { WsEvent::Close(stream_id, reason, channel) => { - if stream_map.lock().await.remove(&stream_id).is_some() { + if self.stream_map.lock().await.remove(&stream_id).is_some() { let _ = channel.send( - tx.write_frame(Packet::new_close(stream_id, reason).into()) + self.tx + .write_frame(Packet::new_close(stream_id, reason).into()) .await, ); } else { @@ -184,7 +191,17 @@ impl ServerMuxInner { Connect(inner_packet) => { let (ch_tx, ch_rx) = mpsc::unbounded(); let stream_type = inner_packet.stream_type; - self.stream_map.lock().await.insert(packet.stream_id, ch_tx); + let flow_control: Arc = AtomicU32::new(buffer_size).into(); + let flow_control_event: Arc = Event::new().into(); + + self.stream_map.lock().await.insert( + packet.stream_id, + MuxMapValue { + stream: ch_tx, + flow_control: flow_control.clone(), + flow_control_event: flow_control_event.clone(), + }, + ); muxstream_sender .unbounded_send(( inner_packet, @@ -196,21 +213,27 @@ impl ServerMuxInner { self.tx.clone(), self.close_tx.clone(), AtomicBool::new(false).into(), - AtomicU32::new(buffer_size).into(), - Event::new().into(), + flow_control, + flow_control_event, ), )) .map_err(|x| WispError::Other(Box::new(x)))?; } Data(data) => { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { - let _ = stream.unbounded_send(MuxEvent::Send(data)); + let _ = stream.stream.unbounded_send(MuxEvent::Send(data)); + stream.flow_control.store( + stream.flow_control + .load(Ordering::Acquire) + .saturating_sub(1), + Ordering::Release, + ); } } Continue(_) => unreachable!(), Close(inner_packet) => { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { - let _ = stream.unbounded_send(MuxEvent::Close(inner_packet)); + let _ = stream.stream.unbounded_send(MuxEvent::Close(inner_packet)); } self.stream_map.lock().await.remove(&packet.stream_id); } @@ -281,18 +304,12 @@ impl ServerMux { } } -struct ClientMuxMapValue { - stream: mpsc::UnboundedSender, - flow_control: Arc, - flow_control_event: Arc, -} - struct ClientMuxInner where W: ws::WebSocketWrite, { tx: ws::LockedWebSocketWrite, - stream_map: Arc>>, + stream_map: Arc>>, } impl ClientMuxInner { @@ -386,7 +403,7 @@ where W: ws::WebSocketWrite, { tx: ws::LockedWebSocketWrite, - stream_map: Arc>>, + stream_map: Arc>>, next_free_stream_id: AtomicU32, close_tx: mpsc::UnboundedSender, buf_size: u32, @@ -450,7 +467,7 @@ impl ClientMux { ); self.stream_map.lock().await.insert( stream_id, - ClientMuxMapValue { + MuxMapValue { stream: ch_tx, flow_control: flow_control.clone(), flow_control_event: evt.clone(), diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 84ac0ce..70ccad4 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -53,7 +53,7 @@ impl MuxStreamRead { match self.rx.next().await? { MuxEvent::Send(bytes) => { if self.role == crate::Role::Server && self.stream_type == crate::StreamType::Tcp { - let old_val = self.flow_control.fetch_add(1, Ordering::SeqCst); + let old_val = self.flow_control.fetch_add(1, Ordering::AcqRel); self.tx .write_frame( crate::Packet::new_continue(self.stream_id, old_val + 1).into(), @@ -115,8 +115,7 @@ impl MuxStreamWrite { self.flow_control.store( self.flow_control .load(Ordering::Acquire) - .checked_add(1) - .unwrap_or(0), + .saturating_sub(1), Ordering::Release, ); }