cancel wisp worker tasks a bit more robustly

This commit is contained in:
Toshit Chawda 2024-08-24 21:24:34 -07:00
parent 52da4eb0fb
commit 0f05588b8f
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 114 additions and 90 deletions

1
Cargo.lock generated
View file

@ -557,6 +557,7 @@ dependencies = [
"clap",
"dashmap",
"env_logger",
"event-listener",
"fastwebsockets",
"futures-util",
"http-body-util",

View file

@ -11,6 +11,7 @@ cfg-if = "1.0.0"
clap = { version = "4.5.16", features = ["cargo", "derive"] }
dashmap = "6.0.1"
env_logger = "0.11.5"
event-listener = "5.3.1"
fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] }
futures-util = "0.3.30"
http-body-util = "0.1.2"

View file

@ -1,5 +1,8 @@
use std::sync::Arc;
use anyhow::Context;
use cfg_if::cfg_if;
use event_listener::Event;
use futures_util::FutureExt;
use log::{debug, trace};
use tokio::{
@ -11,8 +14,7 @@ use tokio::{
use tokio_util::compat::FuturesAsyncReadCompatExt;
use uuid::Uuid;
use wisp_mux::{
CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite,
ServerMux,
CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite, ServerMux,
};
use crate::{
@ -56,6 +58,7 @@ async fn handle_stream(
connect: ConnectPacket,
muxstream: MuxStream,
id: String,
event: Arc<Event>,
#[cfg(feature = "twisp")] twisp_map: super::twisp::TwispMap,
) {
let requested_stream = connect.clone();
@ -91,14 +94,19 @@ async fn handle_stream(
let uuid = Uuid::new_v4();
trace!("new stream created for client id {:?}: (stream uuid {:?}) {:?} {:?}", id, uuid, requested_stream, resolved_stream);
trace!(
"new stream created for client id {:?}: (stream uuid {:?}) {:?} {:?}",
id,
uuid,
requested_stream,
resolved_stream
);
CLIENTS
.get(&id)
.unwrap()
.0
.insert(uuid, (requested_stream, resolved_stream));
if let Some(client) = CLIENTS.get(&id) {
client.0.insert(uuid, (requested_stream, resolved_stream));
}
let forward_fut = async {
match stream {
ClientStream::Tcp(stream) => {
let closer = muxstream.get_close_handle();
@ -182,10 +190,18 @@ async fn handle_stream(
.await;
}
};
};
select! {
x = forward_fut => x,
x = event.listen() => x,
};
trace!("stream uuid {:?} disconnected for client id {:?}", uuid, id);
CLIENTS.get(&id).unwrap().0.remove(&uuid);
if let Some(client) = CLIENTS.get(&id) {
client.0.remove(&uuid);
}
}
pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> {
@ -214,27 +230,33 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> {
.context("failed to create server multiplexor")?
.with_no_required_extensions();
debug!("new wisp client id {:?} connected with extensions {:?}", id, mux.supported_extension_ids);
debug!(
"new wisp client id {:?} connected with extensions {:?}",
id, mux.supported_extension_ids
);
let mut set: JoinSet<()> = JoinSet::new();
let event: Arc<Event> = Event::new().into();
set.spawn(tokio::task::unconstrained(fut.map(|_| {})));
while let Some((connect, stream)) = mux.server_new_stream().await {
set.spawn(tokio::task::unconstrained(handle_stream(
set.spawn(handle_stream(
connect,
stream,
id.clone(),
event.clone(),
#[cfg(feature = "twisp")]
twisp_map.clone(),
)));
));
}
trace!("shutting down wisp client id {:?}", id);
let _ = mux.close().await;
event.notify(usize::MAX);
set.abort_all();
while set.join_next().await.is_some() {}
while set.join_next().await.is_some() {};
debug!("wisp client id {:?} disconnected", id);