wisp_mux v1.1.1: fix continue packets and flow control

This commit is contained in:
Toshit Chawda 2024-02-08 09:43:21 -08:00
parent 9ebb24b088
commit 8b2a8a3eb3
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
4 changed files with 43 additions and 27 deletions

2
Cargo.lock generated
View file

@ -1860,7 +1860,7 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04"
[[package]] [[package]]
name = "wisp-mux" name = "wisp-mux"
version = "1.1.0" version = "1.1.1"
dependencies = [ dependencies = [
"async_io_stream", "async_io_stream",
"bytes", "bytes",

View file

@ -1,6 +1,6 @@
[package] [package]
name = "wisp-mux" name = "wisp-mux"
version = "1.1.0" version = "1.1.1"
license = "AGPL-3.0-only" license = "AGPL-3.0-only"
description = "A library for easily creating Wisp servers and clients." description = "A library for easily creating Wisp servers and clients."
homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp" homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp"

View file

@ -6,8 +6,8 @@
#[cfg(feature = "fastwebsockets")] #[cfg(feature = "fastwebsockets")]
mod fastwebsockets; mod fastwebsockets;
mod sink_unfold;
mod packet; mod packet;
mod sink_unfold;
mod stream; mod stream;
#[cfg(feature = "hyper_tower")] #[cfg(feature = "hyper_tower")]
pub mod tokioio; pub mod tokioio;
@ -111,12 +111,18 @@ impl std::fmt::Display for WispError {
impl std::error::Error for WispError {} impl std::error::Error for WispError {}
struct MuxMapValue {
stream: mpsc::UnboundedSender<MuxEvent>,
flow_control: Arc<AtomicU32>,
flow_control_event: Arc<Event>,
}
struct ServerMuxInner<W> struct ServerMuxInner<W>
where where
W: ws::WebSocketWrite + Send + 'static, W: ws::WebSocketWrite + Send + 'static,
{ {
tx: ws::LockedWebSocketWrite<W>, tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<MuxEvent>>>>, stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
close_tx: mpsc::UnboundedSender<WsEvent>, close_tx: mpsc::UnboundedSender<WsEvent>,
} }
@ -132,11 +138,13 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
R: ws::WebSocketRead, R: ws::WebSocketRead,
{ {
let ret = futures::select! { let ret = futures::select! {
x = self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()).fuse() => x, x = self.server_close_loop(close_rx).fuse() => x,
x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x
}; };
self.stream_map.lock().await.iter().for_each(|x| { self.stream_map.lock().await.iter().for_each(|x| {
let _ = x.1.unbounded_send(MuxEvent::Close(ClosePacket::new(CloseReason::Unknown))); let _ =
x.1.stream
.unbounded_send(MuxEvent::Close(ClosePacket::new(CloseReason::Unknown)));
}); });
ret ret
} }
@ -144,15 +152,14 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
async fn server_close_loop( async fn server_close_loop(
&self, &self,
mut close_rx: mpsc::UnboundedReceiver<WsEvent>, mut close_rx: mpsc::UnboundedReceiver<WsEvent>,
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<MuxEvent>>>>,
tx: ws::LockedWebSocketWrite<W>,
) -> Result<(), WispError> { ) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await { while let Some(msg) = close_rx.next().await {
match msg { match msg {
WsEvent::Close(stream_id, reason, channel) => { WsEvent::Close(stream_id, reason, channel) => {
if stream_map.lock().await.remove(&stream_id).is_some() { if self.stream_map.lock().await.remove(&stream_id).is_some() {
let _ = channel.send( let _ = channel.send(
tx.write_frame(Packet::new_close(stream_id, reason).into()) self.tx
.write_frame(Packet::new_close(stream_id, reason).into())
.await, .await,
); );
} else { } else {
@ -184,7 +191,17 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
Connect(inner_packet) => { Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::unbounded(); let (ch_tx, ch_rx) = mpsc::unbounded();
let stream_type = inner_packet.stream_type; let stream_type = inner_packet.stream_type;
self.stream_map.lock().await.insert(packet.stream_id, ch_tx); let flow_control: Arc<AtomicU32> = AtomicU32::new(buffer_size).into();
let flow_control_event: Arc<Event> = Event::new().into();
self.stream_map.lock().await.insert(
packet.stream_id,
MuxMapValue {
stream: ch_tx,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
},
);
muxstream_sender muxstream_sender
.unbounded_send(( .unbounded_send((
inner_packet, inner_packet,
@ -196,21 +213,27 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
self.tx.clone(), self.tx.clone(),
self.close_tx.clone(), self.close_tx.clone(),
AtomicBool::new(false).into(), AtomicBool::new(false).into(),
AtomicU32::new(buffer_size).into(), flow_control,
Event::new().into(), flow_control_event,
), ),
)) ))
.map_err(|x| WispError::Other(Box::new(x)))?; .map_err(|x| WispError::Other(Box::new(x)))?;
} }
Data(data) => { Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.unbounded_send(MuxEvent::Send(data)); let _ = stream.stream.unbounded_send(MuxEvent::Send(data));
stream.flow_control.store(
stream.flow_control
.load(Ordering::Acquire)
.saturating_sub(1),
Ordering::Release,
);
} }
} }
Continue(_) => unreachable!(), Continue(_) => unreachable!(),
Close(inner_packet) => { Close(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.unbounded_send(MuxEvent::Close(inner_packet)); let _ = stream.stream.unbounded_send(MuxEvent::Close(inner_packet));
} }
self.stream_map.lock().await.remove(&packet.stream_id); self.stream_map.lock().await.remove(&packet.stream_id);
} }
@ -281,18 +304,12 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
} }
} }
struct ClientMuxMapValue {
stream: mpsc::UnboundedSender<MuxEvent>,
flow_control: Arc<AtomicU32>,
flow_control_event: Arc<Event>,
}
struct ClientMuxInner<W> struct ClientMuxInner<W>
where where
W: ws::WebSocketWrite, W: ws::WebSocketWrite,
{ {
tx: ws::LockedWebSocketWrite<W>, tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, ClientMuxMapValue>>>, stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
} }
impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> { impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
@ -386,7 +403,7 @@ where
W: ws::WebSocketWrite, W: ws::WebSocketWrite,
{ {
tx: ws::LockedWebSocketWrite<W>, tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, ClientMuxMapValue>>>, stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
next_free_stream_id: AtomicU32, next_free_stream_id: AtomicU32,
close_tx: mpsc::UnboundedSender<WsEvent>, close_tx: mpsc::UnboundedSender<WsEvent>,
buf_size: u32, buf_size: u32,
@ -450,7 +467,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
); );
self.stream_map.lock().await.insert( self.stream_map.lock().await.insert(
stream_id, stream_id,
ClientMuxMapValue { MuxMapValue {
stream: ch_tx, stream: ch_tx,
flow_control: flow_control.clone(), flow_control: flow_control.clone(),
flow_control_event: evt.clone(), flow_control_event: evt.clone(),

View file

@ -53,7 +53,7 @@ impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamRead<W> {
match self.rx.next().await? { match self.rx.next().await? {
MuxEvent::Send(bytes) => { MuxEvent::Send(bytes) => {
if self.role == crate::Role::Server && self.stream_type == crate::StreamType::Tcp { if self.role == crate::Role::Server && self.stream_type == crate::StreamType::Tcp {
let old_val = self.flow_control.fetch_add(1, Ordering::SeqCst); let old_val = self.flow_control.fetch_add(1, Ordering::AcqRel);
self.tx self.tx
.write_frame( .write_frame(
crate::Packet::new_continue(self.stream_id, old_val + 1).into(), crate::Packet::new_continue(self.stream_id, old_val + 1).into(),
@ -115,8 +115,7 @@ impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamWrite<W> {
self.flow_control.store( self.flow_control.store(
self.flow_control self.flow_control
.load(Ordering::Acquire) .load(Ordering::Acquire)
.checked_add(1) .saturating_sub(1),
.unwrap_or(0),
Ordering::Release, Ordering::Release,
); );
} }