From 9f1561fa765cc073a81515cff26f0ce2a6c4581d Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 3 Feb 2024 16:31:02 -0800 Subject: [PATCH] make server mux more like client mux and fix deadlock --- Cargo.lock | 1 - server/src/main.rs | 24 ++++--- wisp/Cargo.toml | 1 - wisp/src/lib.rs | 152 ++++++++++++++++++++++++++------------------- 4 files changed, 101 insertions(+), 77 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index eec8ba9..eefc704 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1449,7 +1449,6 @@ version = "0.1.0" dependencies = [ "async_io_stream", "bytes", - "dashmap", "fastwebsockets", "futures", "futures-util", diff --git a/server/src/main.rs b/server/src/main.rs index 0aa7194..aeca64e 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -20,7 +20,7 @@ use wisp_mux::{ws, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, W type HttpBody = http_body_util::Empty; -#[tokio::main(flavor = "multi_thread")] +#[tokio::main] async fn main() -> Result<(), Error> { let pem = include_bytes!("./pem.pem"); let key = include_bytes!("./key.pem"); @@ -117,7 +117,6 @@ async fn handle_mux( loop { tokio::select! { event = stream.read() => { - println!("ws rx"); match event { Some(event) => match event { WsEvent::Send(data) => { @@ -129,10 +128,9 @@ async fn handle_mux( } }, event = tcp_stream_framed.next() => { - println!("tcp rx"); match event.and_then(|x| x.ok()) { Some(event) => stream.write(event.into()).await?, - None => return Ok(true), + None => break, } } } @@ -176,10 +174,18 @@ async fn accept_ws( println!("{:?}: connected", addr); - ServerMux::handle(rx, tx, &mut |packet, stream| async move { - let close_err = stream.get_close_handle(); - let close_ok = stream.get_close_handle(); + let (mut mux, fut) = ServerMux::new(rx, tx); + + tokio::spawn(async move { + if let Err(e) = fut.await { + println!("err in mux: {:?}", e); + } + }); + + while let Some((packet, stream)) = mux.server_new_stream().await { tokio::spawn(async move { + let close_err = stream.get_close_handle(); + let close_ok = stream.get_close_handle(); let _ = handle_mux(packet, stream) .or_else(|err| async move { let _ = close_err.close(0x03).await; @@ -194,9 +200,7 @@ async fn accept_ws( }) .await; }); - Ok(()) - }) - .await?; + } println!("{:?}: disconnected", addr); Ok(()) diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index ee3c3d2..fc834d0 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -6,7 +6,6 @@ edition = "2021" [dependencies] async_io_stream = "0.3.3" bytes = "1.5.0" -dashmap = "5.5.3" fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true } futures = "0.3.30" futures-util = "0.3.30" diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 0e75e7d..9326ece 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -9,11 +9,13 @@ mod ws_stream_wasm; pub use crate::packet::*; pub use crate::stream::*; -use dashmap::DashMap; -use futures::{channel::mpsc, Future, FutureExt, StreamExt}; -use std::sync::{ - atomic::{AtomicBool, AtomicU32, Ordering}, - Arc, +use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt}; +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicBool, AtomicU32, Ordering}, + Arc, + }, }; #[derive(Debug, PartialEq)] @@ -68,63 +70,45 @@ impl std::fmt::Display for WispError { impl std::error::Error for WispError {} -pub struct ServerMux +struct ServerMuxInner where - W: ws::WebSocketWrite, + W: ws::WebSocketWrite + Send + 'static, { tx: ws::LockedWebSocketWrite, - stream_map: Arc>>, + stream_map: Arc>>>, close_tx: mpsc::UnboundedSender, } -impl ServerMux { - pub fn handle<'a, FR, R>( - read: R, - write: W, - handler_fn: &'a mut impl Fn(ConnectPacket, MuxStream) -> FR, - ) -> impl Future> + 'a - where - FR: std::future::Future> + 'a, - R: ws::WebSocketRead + 'a, - W: ws::WebSocketWrite + 'a, - { - let (tx, rx) = mpsc::unbounded::(); - let write = ws::LockedWebSocketWrite::new(write); - let map = Arc::new(DashMap::new()); - let inner = ServerMux { - stream_map: map.clone(), - tx: write.clone(), - close_tx: tx, - }; - inner.into_future(read, rx, handler_fn) - } - - async fn into_future( +impl ServerMuxInner { + pub async fn into_future( self, rx: R, close_rx: mpsc::UnboundedReceiver, - handler_fn: &mut impl Fn(ConnectPacket, MuxStream) -> FR, + muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, ) -> Result<(), WispError> where R: ws::WebSocketRead, - FR: std::future::Future>, { - futures::select! { + let ret = futures::select! { x = self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()).fuse() => x, - x = self.server_msg_loop(rx, handler_fn).fuse() => x - } + x = self.server_msg_loop(rx, muxstream_sender).fuse() => x + }; + self.stream_map.lock().await.iter().for_each(|x| { + let _ = x.1.unbounded_send(WsEvent::Close(ClosePacket::new(0x01))); + }); + ret } async fn server_close_loop( &self, mut close_rx: mpsc::UnboundedReceiver, - stream_map: Arc>>, + stream_map: Arc>>>, tx: ws::LockedWebSocketWrite, ) -> Result<(), WispError> { while let Some(msg) = close_rx.next().await { match msg { MuxEvent::Close(stream_id, reason, channel) => { - if stream_map.clone().remove(&stream_id).is_some() { + if stream_map.lock().await.remove(&stream_id).is_some() { let _ = channel.send( tx.write_frame(Packet::new_close(stream_id, reason).into()) .await, @@ -138,14 +122,13 @@ impl ServerMux { Ok(()) } - async fn server_msg_loop( + async fn server_msg_loop( &self, mut rx: R, - handler_fn: &mut impl Fn(ConnectPacket, MuxStream) -> FR, + muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, ) -> Result<(), WispError> where R: ws::WebSocketRead, - FR: std::future::Future>, { self.tx .write_frame(Packet::new_continue(0, u32::MAX).into()) @@ -157,21 +140,22 @@ impl ServerMux { match packet.packet { Connect(inner_packet) => { let (ch_tx, ch_rx) = mpsc::unbounded(); - self.stream_map.clone().insert(packet.stream_id, ch_tx); - let _ = handler_fn( - inner_packet, - MuxStream::new( - packet.stream_id, - ch_rx, - self.tx.clone(), - self.close_tx.clone(), - AtomicBool::new(false).into(), - ), - ) - .await; + self.stream_map.lock().await.insert(packet.stream_id, ch_tx); + muxstream_sender + .unbounded_send(( + inner_packet, + MuxStream::new( + packet.stream_id, + ch_rx, + self.tx.clone(), + self.close_tx.clone(), + AtomicBool::new(false).into(), + ), + )) + .map_err(|x| WispError::Other(Box::new(x)))?; } Data(data) => { - if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { let _ = stream.unbounded_send(WsEvent::Send(data)); self.tx .write_frame( @@ -182,24 +166,59 @@ impl ServerMux { } Continue(_) => unreachable!(), Close(inner_packet) => { - if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { let _ = stream.unbounded_send(WsEvent::Close(inner_packet)); - self.stream_map.clone().remove(&packet.stream_id); + self.stream_map.lock().await.remove(&packet.stream_id); } } } + } else { + break; } } + drop(muxstream_sender); Ok(()) } } +pub struct ServerMux +where + W: ws::WebSocketWrite + Send + 'static, +{ + muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>, +} + +impl ServerMux { + pub fn new(read: R, write: W) -> (Self, impl Future>) + where + R: ws::WebSocketRead, + { + let (close_tx, close_rx) = mpsc::unbounded::(); + let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); + let write = ws::LockedWebSocketWrite::new(write); + let map = Arc::new(Mutex::new(HashMap::new())); + ( + Self { muxstream_recv: rx }, + ServerMuxInner { + tx: write, + close_tx, + stream_map: map.clone(), + } + .into_future(read, close_rx, tx), + ) + } + + pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> { + self.muxstream_recv.next().await + } +} + pub struct ClientMuxInner where W: ws::WebSocketWrite, { tx: ws::LockedWebSocketWrite, - stream_map: Arc>>, + stream_map: Arc>>>, } impl ClientMuxInner { @@ -211,7 +230,10 @@ impl ClientMuxInner { where R: ws::WebSocketRead, { - futures::try_join!(self.client_bg_loop(close_rx), self.client_loop(rx)).map(|_| ()) + futures::select! { + x = self.client_bg_loop(close_rx).fuse() => x, + x = self.client_loop(rx).fuse() => x + } } async fn client_bg_loop( @@ -221,7 +243,7 @@ impl ClientMuxInner { while let Some(msg) = close_rx.next().await { match msg { MuxEvent::Close(stream_id, reason, channel) => { - if self.stream_map.clone().remove(&stream_id).is_some() { + if self.stream_map.lock().await.remove(&stream_id).is_some() { let _ = channel.send( self.tx .write_frame(Packet::new_close(stream_id, reason).into()) @@ -246,15 +268,15 @@ impl ClientMuxInner { match packet.packet { Connect(_) => unreachable!(), Data(data) => { - if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { let _ = stream.unbounded_send(WsEvent::Send(data)); } } Continue(_) => {} Close(inner_packet) => { - if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { let _ = stream.unbounded_send(WsEvent::Close(inner_packet)); - self.stream_map.clone().remove(&packet.stream_id); + self.stream_map.lock().await.remove(&packet.stream_id); } } } @@ -269,7 +291,7 @@ where W: ws::WebSocketWrite, { tx: ws::LockedWebSocketWrite, - stream_map: Arc>>, + stream_map: Arc>>>, next_free_stream_id: AtomicU32, close_tx: mpsc::UnboundedSender, } @@ -280,7 +302,7 @@ impl ClientMux { R: ws::WebSocketRead, { let (tx, rx) = mpsc::unbounded::(); - let map = Arc::new(DashMap::new()); + let map = Arc::new(Mutex::new(HashMap::new())); let write = ws::LockedWebSocketWrite::new(write); ( Self { @@ -314,7 +336,7 @@ impl ClientMux { .ok_or(WispError::MaxStreamCountReached)?, Ordering::Release, ); - self.stream_map.clone().insert(stream_id, ch_tx); + self.stream_map.lock().await.insert(stream_id, ch_tx); Ok(MuxStream::new( stream_id, ch_rx,