mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
cancel wisp worker tasks a bit more robustly
This commit is contained in:
parent
52da4eb0fb
commit
0f05588b8f
3 changed files with 114 additions and 90 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -557,6 +557,7 @@ dependencies = [
|
||||||
"clap",
|
"clap",
|
||||||
"dashmap",
|
"dashmap",
|
||||||
"env_logger",
|
"env_logger",
|
||||||
|
"event-listener",
|
||||||
"fastwebsockets",
|
"fastwebsockets",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"http-body-util",
|
"http-body-util",
|
||||||
|
|
|
@ -11,6 +11,7 @@ cfg-if = "1.0.0"
|
||||||
clap = { version = "4.5.16", features = ["cargo", "derive"] }
|
clap = { version = "4.5.16", features = ["cargo", "derive"] }
|
||||||
dashmap = "6.0.1"
|
dashmap = "6.0.1"
|
||||||
env_logger = "0.11.5"
|
env_logger = "0.11.5"
|
||||||
|
event-listener = "5.3.1"
|
||||||
fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] }
|
fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] }
|
||||||
futures-util = "0.3.30"
|
futures-util = "0.3.30"
|
||||||
http-body-util = "0.1.2"
|
http-body-util = "0.1.2"
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use cfg_if::cfg_if;
|
use cfg_if::cfg_if;
|
||||||
|
use event_listener::Event;
|
||||||
use futures_util::FutureExt;
|
use futures_util::FutureExt;
|
||||||
use log::{debug, trace};
|
use log::{debug, trace};
|
||||||
use tokio::{
|
use tokio::{
|
||||||
|
@ -11,8 +14,7 @@ use tokio::{
|
||||||
use tokio_util::compat::FuturesAsyncReadCompatExt;
|
use tokio_util::compat::FuturesAsyncReadCompatExt;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
use wisp_mux::{
|
use wisp_mux::{
|
||||||
CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite,
|
CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite, ServerMux,
|
||||||
ServerMux,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -56,6 +58,7 @@ async fn handle_stream(
|
||||||
connect: ConnectPacket,
|
connect: ConnectPacket,
|
||||||
muxstream: MuxStream,
|
muxstream: MuxStream,
|
||||||
id: String,
|
id: String,
|
||||||
|
event: Arc<Event>,
|
||||||
#[cfg(feature = "twisp")] twisp_map: super::twisp::TwispMap,
|
#[cfg(feature = "twisp")] twisp_map: super::twisp::TwispMap,
|
||||||
) {
|
) {
|
||||||
let requested_stream = connect.clone();
|
let requested_stream = connect.clone();
|
||||||
|
@ -91,101 +94,114 @@ async fn handle_stream(
|
||||||
|
|
||||||
let uuid = Uuid::new_v4();
|
let uuid = Uuid::new_v4();
|
||||||
|
|
||||||
trace!("new stream created for client id {:?}: (stream uuid {:?}) {:?} {:?}", id, uuid, requested_stream, resolved_stream);
|
trace!(
|
||||||
|
"new stream created for client id {:?}: (stream uuid {:?}) {:?} {:?}",
|
||||||
|
id,
|
||||||
|
uuid,
|
||||||
|
requested_stream,
|
||||||
|
resolved_stream
|
||||||
|
);
|
||||||
|
|
||||||
CLIENTS
|
if let Some(client) = CLIENTS.get(&id) {
|
||||||
.get(&id)
|
client.0.insert(uuid, (requested_stream, resolved_stream));
|
||||||
.unwrap()
|
}
|
||||||
.0
|
|
||||||
.insert(uuid, (requested_stream, resolved_stream));
|
|
||||||
|
|
||||||
match stream {
|
let forward_fut = async {
|
||||||
ClientStream::Tcp(stream) => {
|
match stream {
|
||||||
let closer = muxstream.get_close_handle();
|
ClientStream::Tcp(stream) => {
|
||||||
|
let closer = muxstream.get_close_handle();
|
||||||
|
|
||||||
let ret: anyhow::Result<()> = async {
|
let ret: anyhow::Result<()> = async {
|
||||||
let (muxread, muxwrite) = muxstream.into_split();
|
let (muxread, muxwrite) = muxstream.into_split();
|
||||||
let muxread = muxread.into_stream().into_asyncread();
|
let muxread = muxread.into_stream().into_asyncread();
|
||||||
let (tcpread, tcpwrite) = stream.into_split();
|
let (tcpread, tcpwrite) = stream.into_split();
|
||||||
select! {
|
|
||||||
x = copy_read_fast(muxread, tcpwrite) => x?,
|
|
||||||
x = copy_write_fast(muxwrite, tcpread) => x?,
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match ret {
|
|
||||||
Ok(()) => {
|
|
||||||
let _ = closer.close(CloseReason::Voluntary).await;
|
|
||||||
}
|
|
||||||
Err(_) => {
|
|
||||||
let _ = closer.close(CloseReason::Unexpected).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ClientStream::Udp(stream) => {
|
|
||||||
let closer = muxstream.get_close_handle();
|
|
||||||
|
|
||||||
let ret: anyhow::Result<()> = async move {
|
|
||||||
let mut data = vec![0u8; 65507];
|
|
||||||
loop {
|
|
||||||
select! {
|
select! {
|
||||||
size = stream.recv(&mut data) => {
|
x = copy_read_fast(muxread, tcpwrite) => x?,
|
||||||
let size = size?;
|
x = copy_write_fast(muxwrite, tcpread) => x?,
|
||||||
muxstream.write(&data[..size]).await?;
|
}
|
||||||
}
|
Ok(())
|
||||||
data = muxstream.read() => {
|
}
|
||||||
if let Some(data) = data {
|
.await;
|
||||||
stream.send(&data).await?;
|
|
||||||
} else {
|
match ret {
|
||||||
break Ok(());
|
Ok(()) => {
|
||||||
|
let _ = closer.close(CloseReason::Voluntary).await;
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
let _ = closer.close(CloseReason::Unexpected).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ClientStream::Udp(stream) => {
|
||||||
|
let closer = muxstream.get_close_handle();
|
||||||
|
|
||||||
|
let ret: anyhow::Result<()> = async move {
|
||||||
|
let mut data = vec![0u8; 65507];
|
||||||
|
loop {
|
||||||
|
select! {
|
||||||
|
size = stream.recv(&mut data) => {
|
||||||
|
let size = size?;
|
||||||
|
muxstream.write(&data[..size]).await?;
|
||||||
|
}
|
||||||
|
data = muxstream.read() => {
|
||||||
|
if let Some(data) = data {
|
||||||
|
stream.send(&data).await?;
|
||||||
|
} else {
|
||||||
|
break Ok(());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match ret {
|
|
||||||
Ok(()) => {
|
|
||||||
let _ = closer.close(CloseReason::Voluntary).await;
|
|
||||||
}
|
|
||||||
Err(_) => {
|
|
||||||
let _ = closer.close(CloseReason::Unexpected).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#[cfg(feature = "twisp")]
|
|
||||||
ClientStream::Pty(cmd, pty) => {
|
|
||||||
let closer = muxstream.get_close_handle();
|
|
||||||
let id = muxstream.stream_id;
|
|
||||||
let (mut rx, mut tx) = muxstream.into_io().into_asyncrw().into_split();
|
|
||||||
|
|
||||||
match super::twisp::handle_twisp(id, &mut rx, &mut tx, twisp_map.clone(), pty, cmd)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(()) => {
|
|
||||||
let _ = closer.close(CloseReason::Voluntary).await;
|
|
||||||
}
|
|
||||||
Err(_) => {
|
|
||||||
let _ = closer.close(CloseReason::Unexpected).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ClientStream::Invalid => {
|
|
||||||
let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await;
|
|
||||||
}
|
|
||||||
ClientStream::Blocked => {
|
|
||||||
let _ = muxstream
|
|
||||||
.close(CloseReason::ServerStreamBlockedAddress)
|
|
||||||
.await;
|
.await;
|
||||||
}
|
|
||||||
|
match ret {
|
||||||
|
Ok(()) => {
|
||||||
|
let _ = closer.close(CloseReason::Voluntary).await;
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
let _ = closer.close(CloseReason::Unexpected).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[cfg(feature = "twisp")]
|
||||||
|
ClientStream::Pty(cmd, pty) => {
|
||||||
|
let closer = muxstream.get_close_handle();
|
||||||
|
let id = muxstream.stream_id;
|
||||||
|
let (mut rx, mut tx) = muxstream.into_io().into_asyncrw().into_split();
|
||||||
|
|
||||||
|
match super::twisp::handle_twisp(id, &mut rx, &mut tx, twisp_map.clone(), pty, cmd)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(()) => {
|
||||||
|
let _ = closer.close(CloseReason::Voluntary).await;
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
let _ = closer.close(CloseReason::Unexpected).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ClientStream::Invalid => {
|
||||||
|
let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await;
|
||||||
|
}
|
||||||
|
ClientStream::Blocked => {
|
||||||
|
let _ = muxstream
|
||||||
|
.close(CloseReason::ServerStreamBlockedAddress)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
select! {
|
||||||
|
x = forward_fut => x,
|
||||||
|
x = event.listen() => x,
|
||||||
};
|
};
|
||||||
|
|
||||||
trace!("stream uuid {:?} disconnected for client id {:?}", uuid, id);
|
trace!("stream uuid {:?} disconnected for client id {:?}", uuid, id);
|
||||||
|
|
||||||
CLIENTS.get(&id).unwrap().0.remove(&uuid);
|
if let Some(client) = CLIENTS.get(&id) {
|
||||||
|
client.0.remove(&uuid);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> {
|
pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> {
|
||||||
|
@ -214,27 +230,33 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> {
|
||||||
.context("failed to create server multiplexor")?
|
.context("failed to create server multiplexor")?
|
||||||
.with_no_required_extensions();
|
.with_no_required_extensions();
|
||||||
|
|
||||||
debug!("new wisp client id {:?} connected with extensions {:?}", id, mux.supported_extension_ids);
|
debug!(
|
||||||
|
"new wisp client id {:?} connected with extensions {:?}",
|
||||||
|
id, mux.supported_extension_ids
|
||||||
|
);
|
||||||
|
|
||||||
let mut set: JoinSet<()> = JoinSet::new();
|
let mut set: JoinSet<()> = JoinSet::new();
|
||||||
|
let event: Arc<Event> = Event::new().into();
|
||||||
|
|
||||||
set.spawn(tokio::task::unconstrained(fut.map(|_| {})));
|
set.spawn(tokio::task::unconstrained(fut.map(|_| {})));
|
||||||
|
|
||||||
while let Some((connect, stream)) = mux.server_new_stream().await {
|
while let Some((connect, stream)) = mux.server_new_stream().await {
|
||||||
set.spawn(tokio::task::unconstrained(handle_stream(
|
set.spawn(handle_stream(
|
||||||
connect,
|
connect,
|
||||||
stream,
|
stream,
|
||||||
id.clone(),
|
id.clone(),
|
||||||
|
event.clone(),
|
||||||
#[cfg(feature = "twisp")]
|
#[cfg(feature = "twisp")]
|
||||||
twisp_map.clone(),
|
twisp_map.clone(),
|
||||||
)));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trace!("shutting down wisp client id {:?}", id);
|
||||||
|
|
||||||
let _ = mux.close().await;
|
let _ = mux.close().await;
|
||||||
|
event.notify(usize::MAX);
|
||||||
|
|
||||||
set.abort_all();
|
while set.join_next().await.is_some() {};
|
||||||
|
|
||||||
while set.join_next().await.is_some() {}
|
|
||||||
|
|
||||||
debug!("wisp client id {:?} disconnected", id);
|
debug!("wisp client id {:?} disconnected", id);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue