make server mux more like client mux and fix deadlock

This commit is contained in:
Toshit Chawda 2024-02-03 16:31:02 -08:00
parent fa2b84d646
commit 9f1561fa76
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
4 changed files with 101 additions and 77 deletions

1
Cargo.lock generated
View file

@ -1449,7 +1449,6 @@ version = "0.1.0"
dependencies = [
"async_io_stream",
"bytes",
"dashmap",
"fastwebsockets",
"futures",
"futures-util",

View file

@ -20,7 +20,7 @@ use wisp_mux::{ws, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, W
type HttpBody = http_body_util::Empty<hyper::body::Bytes>;
#[tokio::main(flavor = "multi_thread")]
#[tokio::main]
async fn main() -> Result<(), Error> {
let pem = include_bytes!("./pem.pem");
let key = include_bytes!("./key.pem");
@ -117,7 +117,6 @@ async fn handle_mux(
loop {
tokio::select! {
event = stream.read() => {
println!("ws rx");
match event {
Some(event) => match event {
WsEvent::Send(data) => {
@ -129,10 +128,9 @@ async fn handle_mux(
}
},
event = tcp_stream_framed.next() => {
println!("tcp rx");
match event.and_then(|x| x.ok()) {
Some(event) => stream.write(event.into()).await?,
None => return Ok(true),
None => break,
}
}
}
@ -176,10 +174,18 @@ async fn accept_ws(
println!("{:?}: connected", addr);
ServerMux::handle(rx, tx, &mut |packet, stream| async move {
let close_err = stream.get_close_handle();
let close_ok = stream.get_close_handle();
let (mut mux, fut) = ServerMux::new(rx, tx);
tokio::spawn(async move {
if let Err(e) = fut.await {
println!("err in mux: {:?}", e);
}
});
while let Some((packet, stream)) = mux.server_new_stream().await {
tokio::spawn(async move {
let close_err = stream.get_close_handle();
let close_ok = stream.get_close_handle();
let _ = handle_mux(packet, stream)
.or_else(|err| async move {
let _ = close_err.close(0x03).await;
@ -194,9 +200,7 @@ async fn accept_ws(
})
.await;
});
Ok(())
})
.await?;
}
println!("{:?}: disconnected", addr);
Ok(())

View file

@ -6,7 +6,6 @@ edition = "2021"
[dependencies]
async_io_stream = "0.3.3"
bytes = "1.5.0"
dashmap = "5.5.3"
fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true }
futures = "0.3.30"
futures-util = "0.3.30"

View file

@ -9,11 +9,13 @@ mod ws_stream_wasm;
pub use crate::packet::*;
pub use crate::stream::*;
use dashmap::DashMap;
use futures::{channel::mpsc, Future, FutureExt, StreamExt};
use std::sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt};
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
},
};
#[derive(Debug, PartialEq)]
@ -68,63 +70,45 @@ impl std::fmt::Display for WispError {
impl std::error::Error for WispError {}
pub struct ServerMux<W>
struct ServerMuxInner<W>
where
W: ws::WebSocketWrite,
W: ws::WebSocketWrite + Send + 'static,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>,
close_tx: mpsc::UnboundedSender<MuxEvent>,
}
impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
pub fn handle<'a, FR, R>(
read: R,
write: W,
handler_fn: &'a mut impl Fn(ConnectPacket, MuxStream<W>) -> FR,
) -> impl Future<Output = Result<(), WispError>> + 'a
where
FR: std::future::Future<Output = Result<(), WispError>> + 'a,
R: ws::WebSocketRead + 'a,
W: ws::WebSocketWrite + 'a,
{
let (tx, rx) = mpsc::unbounded::<MuxEvent>();
let write = ws::LockedWebSocketWrite::new(write);
let map = Arc::new(DashMap::new());
let inner = ServerMux {
stream_map: map.clone(),
tx: write.clone(),
close_tx: tx,
};
inner.into_future(read, rx, handler_fn)
}
async fn into_future<R, FR>(
impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
pub async fn into_future<R>(
self,
rx: R,
close_rx: mpsc::UnboundedReceiver<MuxEvent>,
handler_fn: &mut impl Fn(ConnectPacket, MuxStream<W>) -> FR,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
FR: std::future::Future<Output = Result<(), WispError>>,
{
futures::select! {
let ret = futures::select! {
x = self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()).fuse() => x,
x = self.server_msg_loop(rx, handler_fn).fuse() => x
}
x = self.server_msg_loop(rx, muxstream_sender).fuse() => x
};
self.stream_map.lock().await.iter().for_each(|x| {
let _ = x.1.unbounded_send(WsEvent::Close(ClosePacket::new(0x01)));
});
ret
}
async fn server_close_loop(
&self,
mut close_rx: mpsc::UnboundedReceiver<MuxEvent>,
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>,
tx: ws::LockedWebSocketWrite<W>,
) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await {
match msg {
MuxEvent::Close(stream_id, reason, channel) => {
if stream_map.clone().remove(&stream_id).is_some() {
if stream_map.lock().await.remove(&stream_id).is_some() {
let _ = channel.send(
tx.write_frame(Packet::new_close(stream_id, reason).into())
.await,
@ -138,14 +122,13 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
Ok(())
}
async fn server_msg_loop<R, FR>(
async fn server_msg_loop<R>(
&self,
mut rx: R,
handler_fn: &mut impl Fn(ConnectPacket, MuxStream<W>) -> FR,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
FR: std::future::Future<Output = Result<(), WispError>>,
{
self.tx
.write_frame(Packet::new_continue(0, u32::MAX).into())
@ -157,21 +140,22 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
match packet.packet {
Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::unbounded();
self.stream_map.clone().insert(packet.stream_id, ch_tx);
let _ = handler_fn(
inner_packet,
MuxStream::new(
packet.stream_id,
ch_rx,
self.tx.clone(),
self.close_tx.clone(),
AtomicBool::new(false).into(),
),
)
.await;
self.stream_map.lock().await.insert(packet.stream_id, ch_tx);
muxstream_sender
.unbounded_send((
inner_packet,
MuxStream::new(
packet.stream_id,
ch_rx,
self.tx.clone(),
self.close_tx.clone(),
AtomicBool::new(false).into(),
),
))
.map_err(|x| WispError::Other(Box::new(x)))?;
}
Data(data) => {
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Send(data));
self.tx
.write_frame(
@ -182,24 +166,59 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
}
Continue(_) => unreachable!(),
Close(inner_packet) => {
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Close(inner_packet));
self.stream_map.clone().remove(&packet.stream_id);
self.stream_map.lock().await.remove(&packet.stream_id);
}
}
}
} else {
break;
}
}
drop(muxstream_sender);
Ok(())
}
}
pub struct ServerMux<W>
where
W: ws::WebSocketWrite + Send + 'static,
{
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream<W>)>,
}
impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
pub fn new<R>(read: R, write: W) -> (Self, impl Future<Output = Result<(), WispError>>)
where
R: ws::WebSocketRead,
{
let (close_tx, close_rx) = mpsc::unbounded::<MuxEvent>();
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>();
let write = ws::LockedWebSocketWrite::new(write);
let map = Arc::new(Mutex::new(HashMap::new()));
(
Self { muxstream_recv: rx },
ServerMuxInner {
tx: write,
close_tx,
stream_map: map.clone(),
}
.into_future(read, close_rx, tx),
)
}
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream<W>)> {
self.muxstream_recv.next().await
}
}
pub struct ClientMuxInner<W>
where
W: ws::WebSocketWrite,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>,
}
impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
@ -211,7 +230,10 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
where
R: ws::WebSocketRead,
{
futures::try_join!(self.client_bg_loop(close_rx), self.client_loop(rx)).map(|_| ())
futures::select! {
x = self.client_bg_loop(close_rx).fuse() => x,
x = self.client_loop(rx).fuse() => x
}
}
async fn client_bg_loop(
@ -221,7 +243,7 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
while let Some(msg) = close_rx.next().await {
match msg {
MuxEvent::Close(stream_id, reason, channel) => {
if self.stream_map.clone().remove(&stream_id).is_some() {
if self.stream_map.lock().await.remove(&stream_id).is_some() {
let _ = channel.send(
self.tx
.write_frame(Packet::new_close(stream_id, reason).into())
@ -246,15 +268,15 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
match packet.packet {
Connect(_) => unreachable!(),
Data(data) => {
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Send(data));
}
}
Continue(_) => {}
Close(inner_packet) => {
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Close(inner_packet));
self.stream_map.clone().remove(&packet.stream_id);
self.stream_map.lock().await.remove(&packet.stream_id);
}
}
}
@ -269,7 +291,7 @@ where
W: ws::WebSocketWrite,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>,
next_free_stream_id: AtomicU32,
close_tx: mpsc::UnboundedSender<MuxEvent>,
}
@ -280,7 +302,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
R: ws::WebSocketRead,
{
let (tx, rx) = mpsc::unbounded::<MuxEvent>();
let map = Arc::new(DashMap::new());
let map = Arc::new(Mutex::new(HashMap::new()));
let write = ws::LockedWebSocketWrite::new(write);
(
Self {
@ -314,7 +336,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
.ok_or(WispError::MaxStreamCountReached)?,
Ordering::Release,
);
self.stream_map.clone().insert(stream_id, ch_tx);
self.stream_map.lock().await.insert(stream_id, ch_tx);
Ok(MuxStream::new(
stream_id,
ch_rx,