From 0f05588b8fbf47f50fb7d536e472ba9e2fd58b1e Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 24 Aug 2024 21:24:34 -0700 Subject: [PATCH] cancel wisp worker tasks a bit more robustly --- Cargo.lock | 1 + server/Cargo.toml | 1 + server/src/handle/wisp.rs | 202 +++++++++++++++++++++----------------- 3 files changed, 114 insertions(+), 90 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8453220..1434e61 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -557,6 +557,7 @@ dependencies = [ "clap", "dashmap", "env_logger", + "event-listener", "fastwebsockets", "futures-util", "http-body-util", diff --git a/server/Cargo.toml b/server/Cargo.toml index a809bb0..1c2b00f 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -11,6 +11,7 @@ cfg-if = "1.0.0" clap = { version = "4.5.16", features = ["cargo", "derive"] } dashmap = "6.0.1" env_logger = "0.11.5" +event-listener = "5.3.1" fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] } futures-util = "0.3.30" http-body-util = "0.1.2" diff --git a/server/src/handle/wisp.rs b/server/src/handle/wisp.rs index 428fb4a..f9f9958 100644 --- a/server/src/handle/wisp.rs +++ b/server/src/handle/wisp.rs @@ -1,5 +1,8 @@ +use std::sync::Arc; + use anyhow::Context; use cfg_if::cfg_if; +use event_listener::Event; use futures_util::FutureExt; use log::{debug, trace}; use tokio::{ @@ -11,8 +14,7 @@ use tokio::{ use tokio_util::compat::FuturesAsyncReadCompatExt; use uuid::Uuid; use wisp_mux::{ - CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite, - ServerMux, + CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite, ServerMux, }; use crate::{ @@ -56,6 +58,7 @@ async fn handle_stream( connect: ConnectPacket, muxstream: MuxStream, id: String, + event: Arc, #[cfg(feature = "twisp")] twisp_map: super::twisp::TwispMap, ) { let requested_stream = connect.clone(); @@ -91,101 +94,114 @@ async fn handle_stream( let uuid = Uuid::new_v4(); - trace!("new stream created for client id {:?}: (stream uuid {:?}) {:?} {:?}", id, uuid, requested_stream, resolved_stream); + trace!( + "new stream created for client id {:?}: (stream uuid {:?}) {:?} {:?}", + id, + uuid, + requested_stream, + resolved_stream + ); - CLIENTS - .get(&id) - .unwrap() - .0 - .insert(uuid, (requested_stream, resolved_stream)); + if let Some(client) = CLIENTS.get(&id) { + client.0.insert(uuid, (requested_stream, resolved_stream)); + } - match stream { - ClientStream::Tcp(stream) => { - let closer = muxstream.get_close_handle(); + let forward_fut = async { + match stream { + ClientStream::Tcp(stream) => { + let closer = muxstream.get_close_handle(); - let ret: anyhow::Result<()> = async { - let (muxread, muxwrite) = muxstream.into_split(); - let muxread = muxread.into_stream().into_asyncread(); - let (tcpread, tcpwrite) = stream.into_split(); - select! { - x = copy_read_fast(muxread, tcpwrite) => x?, - x = copy_write_fast(muxwrite, tcpread) => x?, - } - Ok(()) - } - .await; - - match ret { - Ok(()) => { - let _ = closer.close(CloseReason::Voluntary).await; - } - Err(_) => { - let _ = closer.close(CloseReason::Unexpected).await; - } - } - } - ClientStream::Udp(stream) => { - let closer = muxstream.get_close_handle(); - - let ret: anyhow::Result<()> = async move { - let mut data = vec![0u8; 65507]; - loop { + let ret: anyhow::Result<()> = async { + let (muxread, muxwrite) = muxstream.into_split(); + let muxread = muxread.into_stream().into_asyncread(); + let (tcpread, tcpwrite) = stream.into_split(); select! { - size = stream.recv(&mut data) => { - let size = size?; - muxstream.write(&data[..size]).await?; - } - data = muxstream.read() => { - if let Some(data) = data { - stream.send(&data).await?; - } else { - break Ok(()); + x = copy_read_fast(muxread, tcpwrite) => x?, + x = copy_write_fast(muxwrite, tcpread) => x?, + } + Ok(()) + } + .await; + + match ret { + Ok(()) => { + let _ = closer.close(CloseReason::Voluntary).await; + } + Err(_) => { + let _ = closer.close(CloseReason::Unexpected).await; + } + } + } + ClientStream::Udp(stream) => { + let closer = muxstream.get_close_handle(); + + let ret: anyhow::Result<()> = async move { + let mut data = vec![0u8; 65507]; + loop { + select! { + size = stream.recv(&mut data) => { + let size = size?; + muxstream.write(&data[..size]).await?; + } + data = muxstream.read() => { + if let Some(data) = data { + stream.send(&data).await?; + } else { + break Ok(()); + } } } } } - } - .await; - - match ret { - Ok(()) => { - let _ = closer.close(CloseReason::Voluntary).await; - } - Err(_) => { - let _ = closer.close(CloseReason::Unexpected).await; - } - } - } - #[cfg(feature = "twisp")] - ClientStream::Pty(cmd, pty) => { - let closer = muxstream.get_close_handle(); - let id = muxstream.stream_id; - let (mut rx, mut tx) = muxstream.into_io().into_asyncrw().into_split(); - - match super::twisp::handle_twisp(id, &mut rx, &mut tx, twisp_map.clone(), pty, cmd) - .await - { - Ok(()) => { - let _ = closer.close(CloseReason::Voluntary).await; - } - Err(_) => { - let _ = closer.close(CloseReason::Unexpected).await; - } - } - } - ClientStream::Invalid => { - let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await; - } - ClientStream::Blocked => { - let _ = muxstream - .close(CloseReason::ServerStreamBlockedAddress) .await; - } + + match ret { + Ok(()) => { + let _ = closer.close(CloseReason::Voluntary).await; + } + Err(_) => { + let _ = closer.close(CloseReason::Unexpected).await; + } + } + } + #[cfg(feature = "twisp")] + ClientStream::Pty(cmd, pty) => { + let closer = muxstream.get_close_handle(); + let id = muxstream.stream_id; + let (mut rx, mut tx) = muxstream.into_io().into_asyncrw().into_split(); + + match super::twisp::handle_twisp(id, &mut rx, &mut tx, twisp_map.clone(), pty, cmd) + .await + { + Ok(()) => { + let _ = closer.close(CloseReason::Voluntary).await; + } + Err(_) => { + let _ = closer.close(CloseReason::Unexpected).await; + } + } + } + ClientStream::Invalid => { + let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await; + } + ClientStream::Blocked => { + let _ = muxstream + .close(CloseReason::ServerStreamBlockedAddress) + .await; + } + }; + }; + + select! { + x = forward_fut => x, + x = event.listen() => x, }; trace!("stream uuid {:?} disconnected for client id {:?}", uuid, id); - CLIENTS.get(&id).unwrap().0.remove(&uuid); + if let Some(client) = CLIENTS.get(&id) { + client.0.remove(&uuid); + } } pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> { @@ -214,27 +230,33 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> { .context("failed to create server multiplexor")? .with_no_required_extensions(); - debug!("new wisp client id {:?} connected with extensions {:?}", id, mux.supported_extension_ids); + debug!( + "new wisp client id {:?} connected with extensions {:?}", + id, mux.supported_extension_ids + ); let mut set: JoinSet<()> = JoinSet::new(); + let event: Arc = Event::new().into(); set.spawn(tokio::task::unconstrained(fut.map(|_| {}))); while let Some((connect, stream)) = mux.server_new_stream().await { - set.spawn(tokio::task::unconstrained(handle_stream( + set.spawn(handle_stream( connect, stream, id.clone(), + event.clone(), #[cfg(feature = "twisp")] twisp_map.clone(), - ))); + )); } + trace!("shutting down wisp client id {:?}", id); + let _ = mux.close().await; + event.notify(usize::MAX); - set.abort_all(); - - while set.join_next().await.is_some() {} + while set.join_next().await.is_some() {}; debug!("wisp client id {:?} disconnected", id);