use blazingly fast flume channels 🚀

This commit is contained in:
Toshit Chawda 2024-04-15 17:42:49 -07:00
parent 5af56fe582
commit 5e741d3808
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
11 changed files with 225 additions and 135 deletions

25
Cargo.lock generated
View file

@ -861,6 +861,18 @@ dependencies = [
"miniz_oxide", "miniz_oxide",
] ]
[[package]]
name = "flume"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181"
dependencies = [
"futures-core",
"futures-sink",
"nanorand",
"spin",
]
[[package]] [[package]]
name = "fnv" name = "fnv"
version = "1.0.7" version = "1.0.7"
@ -1487,6 +1499,15 @@ dependencies = [
"windows-sys 0.48.0", "windows-sys 0.48.0",
] ]
[[package]]
name = "nanorand"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3"
dependencies = [
"getrandom",
]
[[package]] [[package]]
name = "native-tls" name = "native-tls"
version = "0.2.11" version = "0.2.11"
@ -2273,6 +2294,9 @@ name = "spin"
version = "0.9.8" version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
dependencies = [
"lock_api",
]
[[package]] [[package]]
name = "strsim" name = "strsim"
@ -3203,6 +3227,7 @@ dependencies = [
"dashmap", "dashmap",
"event-listener", "event-listener",
"fastwebsockets 0.7.1", "fastwebsockets 0.7.1",
"flume",
"futures", "futures",
"futures-timer", "futures-timer",
"futures-util", "futures-util",

View file

@ -238,9 +238,9 @@ onmessage = async (msg) => {
log(`total avg mux (${num_outer_tests} tests of ${num_inner_tests} reqs): ${total_mux_multi} ms or ${total_mux_multi / 1000} s`); log(`total avg mux (${num_outer_tests} tests of ${num_inner_tests} reqs): ${total_mux_multi} ms or ${total_mux_multi / 1000} s`);
} else { } else {
let resp = await epoxy_client.fetch("https://httpbin.org/get"); let resp = await epoxy_client.fetch("https://www.example.com/");
console.log(resp, Object.fromEntries(resp.headers)); console.log(resp, Object.fromEntries(resp.headers));
plog(await resp.json()); log(await resp.text());
} }
log("done"); log("done");
}; };

View file

@ -200,13 +200,10 @@ pub async fn make_mux(
), ),
WispError, WispError,
> { > {
let (wtx, wrx) = WebSocketWrapper::connect(url, vec![]) let (wtx, wrx) =
.await WebSocketWrapper::connect(url, vec![]).map_err(|_| WispError::WsImplSocketClosed)?;
.map_err(|_| WispError::WsImplSocketClosed)?;
wtx.wait_for_open().await; wtx.wait_for_open().await;
let mux = ClientMux::new(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await?; ClientMux::new(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await
Ok(mux)
} }
pub fn spawn_mux_fut( pub fn spawn_mux_fut(
@ -215,6 +212,7 @@ pub fn spawn_mux_fut(
url: String, url: String,
) { ) {
wasm_bindgen_futures::spawn_local(async move { wasm_bindgen_futures::spawn_local(async move {
debug!("epoxy: mux future started");
if let Err(e) = fut.await { if let Err(e) = fut.await {
log!("epoxy: error in mux future, restarting: {:?}", e); log!("epoxy: error in mux future, restarting: {:?}", e);
while let Err(e) = replace_mux(mux.clone(), &url).await { while let Err(e) = replace_mux(mux.clone(), &url).await {
@ -229,7 +227,7 @@ pub fn spawn_mux_fut(
pub async fn replace_mux(mux: Arc<RwLock<ClientMux>>, url: &str) -> Result<(), WispError> { pub async fn replace_mux(mux: Arc<RwLock<ClientMux>>, url: &str) -> Result<(), WispError> {
let (mux_replace, fut) = make_mux(url).await?; let (mux_replace, fut) = make_mux(url).await?;
let mut mux_write = mux.write().await; let mut mux_write = mux.write().await;
mux_write.close().await?; let _ = mux_write.close().await;
*mux_write = mux_replace; *mux_write = mux_replace;
drop(mux_write); drop(mux_write);
spawn_mux_fut(mux, fut, url.into()); spawn_mux_fut(mux, fut, url.into());

View file

@ -123,6 +123,7 @@ impl tower_service::Service<hyper::Uri> for TlsWispService {
let stream = service.call(uri_parsed).await?.into_inner(); let stream = service.call(uri_parsed).await?.into_inner();
if utils::get_is_secure(&req).map_err(|_| WispError::InvalidUri)? { if utils::get_is_secure(&req).map_err(|_| WispError::InvalidUri)? {
let connector = TlsConnector::from(rustls_config); let connector = TlsConnector::from(rustls_config);
log!("got stream");
Ok(TokioIo::new(Either::Left( Ok(TokioIo::new(Either::Left(
connector connector
.connect( .connect(
@ -143,6 +144,7 @@ impl tower_service::Service<hyper::Uri> for TlsWispService {
pub enum WebSocketError { pub enum WebSocketError {
Unknown, Unknown,
SendFailed, SendFailed,
CloseFailed,
} }
impl std::fmt::Display for WebSocketError { impl std::fmt::Display for WebSocketError {
@ -151,6 +153,7 @@ impl std::fmt::Display for WebSocketError {
match self { match self {
Unknown => write!(f, "Unknown error"), Unknown => write!(f, "Unknown error"),
SendFailed => write!(f, "Send failed"), SendFailed => write!(f, "Send failed"),
CloseFailed => write!(f, "Close failed"),
} }
} }
} }
@ -213,7 +216,7 @@ impl WebSocketRead for WebSocketReader {
} }
impl WebSocketWrapper { impl WebSocketWrapper {
pub async fn connect( pub fn connect(
url: &str, url: &str,
protocols: Vec<String>, protocols: Vec<String>,
) -> Result<(Self, WebSocketReader), JsValue> { ) -> Result<(Self, WebSocketReader), JsValue> {
@ -327,6 +330,12 @@ impl WebSocketWrite for WebSocketWrapper {
_ => Err(WispError::WsImplNotSupported), _ => Err(WispError::WsImplNotSupported),
} }
} }
async fn wisp_close(&mut self) -> Result<(), WispError> {
self.inner
.close()
.map_err(|_| WebSocketError::CloseFailed.into())
}
} }
impl Drop for WebSocketWrapper { impl Drop for WebSocketWrapper {

View file

@ -12,9 +12,13 @@ use hyper::{
body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode, body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode,
}; };
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use tokio::net::{lookup_host, TcpListener, TcpStream, UdpSocket};
#[cfg(unix)] #[cfg(unix)]
use tokio::net::{UnixListener, UnixStream}; use tokio::net::{UnixListener, UnixStream};
use tokio::{
io::{copy_bidirectional, split, BufReader, BufWriter},
net::{lookup_host, TcpListener, TcpStream, UdpSocket},
select,
};
use tokio_util::codec::{BytesCodec, Framed}; use tokio_util::codec::{BytesCodec, Framed};
#[cfg(unix)] #[cfg(unix)]
use tokio_util::either::Either; use tokio_util::either::Either;
@ -22,9 +26,10 @@ use tokio_util::either::Either;
use wisp_mux::{ use wisp_mux::{
extensions::{ extensions::{
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder, udp::UdpProtocolExtensionBuilder,
ProtocolExtensionBuilder,
}, },
CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, CloseReason, ConnectPacket, IoStream, MuxStream, MuxStreamIo, ServerMux, StreamType, WispError,
}; };
type HttpBody = http_body_util::Full<hyper::body::Bytes>; type HttpBody = http_body_util::Full<hyper::body::Bytes>;
@ -182,7 +187,10 @@ async fn main() -> Result<(), Error> {
block_local: opt.block_local, block_local: opt.block_local,
block_non_http: opt.block_non_http, block_non_http: opt.block_non_http,
block_udp: opt.block_udp, block_udp: opt.block_udp,
auth: Arc::new(vec![Box::new(UdpProtocolExtensionBuilder()), Box::new(pw_ext)]), auth: Arc::new(vec![
Box::new(UdpProtocolExtensionBuilder()),
Box::new(pw_ext),
]),
enforce_auth, enforce_auth,
}; };
@ -257,7 +265,7 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result<bool
.await .await
.map_err(|x| WispError::Other(Box::new(x)))?; .map_err(|x| WispError::Other(Box::new(x)))?;
let mut mux_stream = stream.into_io().into_asyncrw(); let mut mux_stream = stream.into_io().into_asyncrw();
tokio::io::copy_bidirectional(&mut tcp_stream, &mut mux_stream) copy_bidirectional(&mut mux_stream, &mut tcp_stream)
.await .await
.map_err(|x| WispError::Other(Box::new(x)))?; .map_err(|x| WispError::Other(Box::new(x)))?;
} }
@ -312,13 +320,7 @@ async fn accept_ws(
// to prevent memory ""leaks"" because users are sending in packets way too fast the buffer // to prevent memory ""leaks"" because users are sending in packets way too fast the buffer
// size is set to 128 // size is set to 128
let (mut mux, fut) = if mux_options.enforce_auth { let (mut mux, fut) = if mux_options.enforce_auth {
let (mut mux, fut) = ServerMux::new( let (mut mux, fut) = ServerMux::new(rx, tx, 128, Some(mux_options.auth.as_slice())).await?;
rx,
tx,
128,
Some(mux_options.auth.as_slice()),
)
.await?;
if !mux if !mux
.supported_extension_ids .supported_extension_ids
.iter() .iter()
@ -333,7 +335,13 @@ async fn accept_ws(
} }
(mux, fut) (mux, fut)
} else { } else {
ServerMux::new(rx, tx, 128, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await? ServerMux::new(
rx,
tx,
128,
Some(&[Box::new(UdpProtocolExtensionBuilder())]),
)
.await?
}; };
println!( println!(
@ -388,10 +396,9 @@ async fn accept_ws(
}) })
.and_then(|should_send| async move { .and_then(|should_send| async move {
if should_send { if should_send {
close_ok.close(CloseReason::Voluntary).await let _ = close_ok.close(CloseReason::Voluntary).await;
} else {
Ok(())
} }
Ok(())
}) })
.await; .await;
}); });

View file

@ -253,7 +253,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
avg.get_average() * opts.packet_size, avg.get_average() * opts.packet_size,
); );
if is_term { if is_term {
print!("\x1b[2K{}\r", stat); println!("\x1b[1A\x1b[2K{}\r", stat);
} else { } else {
println!("{}", stat); println!("{}", stat);
} }
@ -284,6 +284,8 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let out = select_all(threads.into_iter()).await; let out = select_all(threads.into_iter()).await;
let duration_since = Instant::now().duration_since(start_time);
if let Err(err) = out.0? { if let Err(err) = out.0? {
println!("\n\nerr: {:?}", err); println!("\n\nerr: {:?}", err);
exit(1); exit(1);
@ -291,10 +293,10 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
out.2.into_iter().for_each(|x| x.abort()); out.2.into_iter().for_each(|x| x.abort());
let duration_since = Instant::now().duration_since(start_time); mux.close().await?;
println!( println!(
"\n\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)", "\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)",
cnt.get(), cnt.get(),
opts.packet_size, opts.packet_size,
cnt.get() * opts.packet_size, cnt.get() * opts.packet_size,

View file

@ -15,6 +15,7 @@ bytes = "1.5.0"
dashmap = { version = "5.5.3", features = ["inline"] } dashmap = { version = "5.5.3", features = ["inline"] }
event-listener = "5.0.0" event-listener = "5.0.0"
fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true } fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true }
flume = "0.11.0"
futures = "0.3.30" futures = "0.3.30"
futures-timer = "3.0.3" futures-timer = "3.0.3"
futures-util = "0.3.30" futures-util = "0.3.30"

View file

@ -3,7 +3,7 @@ use std::ops::Deref;
use async_trait::async_trait; use async_trait::async_trait;
use bytes::BytesMut; use bytes::BytesMut;
use fastwebsockets::{ use fastwebsockets::{
FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite, CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite
}; };
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
@ -77,4 +77,8 @@ impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), WispError> { async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), WispError> {
self.write_frame(frame.into()).await.map_err(|e| e.into()) self.write_frame(frame.into()).await.map_err(|e| e.into())
} }
async fn wisp_close(&mut self) -> Result<(), WispError> {
self.write_frame(Frame::close(CloseCode::Normal.into(), b"")).await.map_err(|e| e.into())
}
} }

View file

@ -1,4 +1,4 @@
#![deny(missing_docs)] #![deny(missing_docs, warnings)]
#![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, feature(doc_cfg))]
//! A library for easily creating [Wisp] clients and servers. //! A library for easily creating [Wisp] clients and servers.
//! //!
@ -19,9 +19,8 @@ use bytes::Bytes;
use dashmap::DashMap; use dashmap::DashMap;
use event_listener::Event; use event_listener::Event;
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder}; use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
use futures::{ use flume as mpsc;
channel::{mpsc, oneshot}, lock::Mutex, select, Future, FutureExt, SinkExt, StreamExt use futures::{channel::oneshot, select, Future, FutureExt};
};
use futures_timer::Delay; use futures_timer::Delay;
use std::{ use std::{
sync::{ sync::{
@ -151,11 +150,12 @@ impl std::fmt::Display for WispError {
impl std::error::Error for WispError {} impl std::error::Error for WispError {}
struct MuxMapValue { struct MuxMapValue {
stream: Mutex<mpsc::Sender<Bytes>>, stream: mpsc::Sender<Bytes>,
stream_type: StreamType, stream_type: StreamType,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
flow_control_event: Arc<Event>, flow_control_event: Arc<Event>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
is_closed_event: Arc<Event>,
} }
struct MuxInner { struct MuxInner {
@ -170,7 +170,7 @@ impl MuxInner {
rx: R, rx: R,
extensions: Vec<AnyProtocolExtension>, extensions: Vec<AnyProtocolExtension>,
close_rx: mpsc::Receiver<WsEvent>, close_rx: mpsc::Receiver<WsEvent>,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
close_tx: mpsc::Sender<WsEvent>, close_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError> ) -> Result<(), WispError>
where where
@ -210,20 +210,60 @@ impl MuxInner {
}; };
for x in self.stream_map.iter_mut() { for x in self.stream_map.iter_mut() {
x.is_closed.store(true, Ordering::Release); x.is_closed.store(true, Ordering::Release);
x.stream.lock().await.disconnect(); x.is_closed_event.notify(usize::MAX);
x.stream.lock().await.close_channel();
} }
self.stream_map.clear(); self.stream_map.clear();
let _ = self.tx.close().await;
ret ret
} }
async fn create_new_stream(
&self,
stream_id: u32,
stream_type: StreamType,
role: Role,
stream_tx: mpsc::Sender<WsEvent>,
target_buffer_size: u32,
) -> Result<(MuxMapValue, MuxStream), WispError> {
let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize);
let flow_control_event: Arc<Event> = Event::new().into();
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
let is_closed_event: Arc<Event> = Event::new().into();
Ok((
MuxMapValue {
stream: ch_tx,
stream_type,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
is_closed: is_closed.clone(),
is_closed_event: is_closed_event.clone(),
},
MuxStream::new(
stream_id,
role,
stream_type,
ch_rx,
stream_tx.clone(),
is_closed,
is_closed_event,
flow_control,
flow_control_event,
target_buffer_size,
),
))
}
async fn stream_loop( async fn stream_loop(
&self, &self,
mut stream_rx: mpsc::Receiver<WsEvent>, stream_rx: mpsc::Receiver<WsEvent>,
stream_tx: mpsc::Sender<WsEvent>, stream_tx: mpsc::Sender<WsEvent>,
) { ) {
let mut next_free_stream_id: u32 = 1; let mut next_free_stream_id: u32 = 1;
while let Some(msg) = stream_rx.next().await { while let Ok(msg) = stream_rx.recv_async().await {
match msg { match msg {
WsEvent::SendPacket(packet, channel) => { WsEvent::SendPacket(packet, channel) => {
if self.stream_map.get(&packet.stream_id).is_some() { if self.stream_map.get(&packet.stream_id).is_some() {
@ -234,16 +274,20 @@ impl MuxInner {
} }
WsEvent::CreateStream(stream_type, host, port, channel) => { WsEvent::CreateStream(stream_type, host, port, channel) => {
let ret: Result<MuxStream, WispError> = async { let ret: Result<MuxStream, WispError> = async {
let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize);
let stream_id = next_free_stream_id; let stream_id = next_free_stream_id;
let next_stream_id = next_free_stream_id let next_stream_id = next_free_stream_id
.checked_add(1) .checked_add(1)
.ok_or(WispError::MaxStreamCountReached)?; .ok_or(WispError::MaxStreamCountReached)?;
let flow_control_event: Arc<Event> = Event::new().into(); let (map_value, stream) = self
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into(); .create_new_stream(
stream_id,
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into(); stream_type,
Role::Client,
stream_tx.clone(),
0,
)
.await?;
self.tx self.tx
.write_frame( .write_frame(
@ -251,39 +295,19 @@ impl MuxInner {
) )
.await?; .await?;
self.stream_map.insert(stream_id, map_value);
next_free_stream_id = next_stream_id; next_free_stream_id = next_stream_id;
self.stream_map.insert( Ok(stream)
stream_id,
MuxMapValue {
stream: ch_tx.into(),
stream_type,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
is_closed: is_closed.clone(),
},
);
Ok(MuxStream::new(
stream_id,
Role::Client,
stream_type,
ch_rx,
stream_tx.clone(),
is_closed,
flow_control,
flow_control_event,
0,
))
} }
.await; .await;
let _ = channel.send(ret); let _ = channel.send(ret);
} }
WsEvent::Close(packet, channel) => { WsEvent::Close(packet, channel) => {
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
stream.stream.lock().await.disconnect();
stream.stream.lock().await.close_channel();
let _ = channel.send(self.tx.write_frame(packet.into()).await); let _ = channel.send(self.tx.write_frame(packet.into()).await);
drop(stream.stream)
} else { } else {
let _ = channel.send(Err(WispError::InvalidStreamId)); let _ = channel.send(Err(WispError::InvalidStreamId));
} }
@ -305,8 +329,8 @@ impl MuxInner {
&self, &self,
mut rx: R, mut rx: R,
mut extensions: Vec<AnyProtocolExtension>, mut extensions: Vec<AnyProtocolExtension>,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
close_tx: mpsc::Sender<WsEvent>, stream_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError> ) -> Result<(), WispError>
where where
R: ws::WebSocketRead + Send, R: ws::WebSocketRead + Send,
@ -325,42 +349,24 @@ impl MuxInner {
use PacketType::*; use PacketType::*;
match packet.packet_type { match packet.packet_type {
Connect(inner_packet) => { Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize); let (map_value, stream) = self
let stream_type = inner_packet.stream_type; .create_new_stream(
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into(); packet.stream_id,
let flow_control_event: Arc<Event> = Event::new().into(); inner_packet.stream_type,
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into(); Role::Server,
stream_tx.clone(),
self.stream_map.insert( target_buffer_size,
packet.stream_id, )
MuxMapValue { .await?;
stream: ch_tx.into(),
stream_type,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
is_closed: is_closed.clone(),
},
);
muxstream_sender muxstream_sender
.unbounded_send(( .send_async((inner_packet, stream))
inner_packet, .await
MuxStream::new( .map_err(|_| WispError::MuxMessageFailedToSend)?;
packet.stream_id, self.stream_map.insert(packet.stream_id, map_value);
Role::Server,
stream_type,
ch_rx,
close_tx.clone(),
is_closed,
flow_control,
flow_control_event,
target_buffer_size,
),
))
.map_err(|x| WispError::Other(Box::new(x)))?;
} }
Data(data) => { Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) { if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.lock().await.send(data).await; let _ = stream.stream.send_async(data).await;
if stream.stream_type == StreamType::Tcp { if stream.stream_type == StreamType::Tcp {
stream.flow_control.store( stream.flow_control.store(
stream stream
@ -379,8 +385,8 @@ impl MuxInner {
} }
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release); stream.is_closed.store(true, Ordering::Release);
stream.stream.lock().await.disconnect(); stream.is_closed_event.notify(usize::MAX);
stream.stream.lock().await.close_channel(); drop(stream.stream)
} }
} }
} }
@ -409,7 +415,7 @@ impl MuxInner {
Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
Data(data) => { Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) { if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.lock().await.send(data).await; let _ = stream.stream.send_async(data).await;
} }
} }
Continue(inner_packet) => { Continue(inner_packet) => {
@ -428,8 +434,8 @@ impl MuxInner {
} }
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release); stream.is_closed.store(true, Ordering::Release);
stream.stream.lock().await.disconnect(); stream.is_closed_event.notify(usize::MAX);
stream.stream.lock().await.close_channel(); drop(stream.stream)
} }
} }
} }
@ -465,7 +471,7 @@ pub struct ServerMux {
/// Extensions that are supported by both sides. /// Extensions that are supported by both sides.
pub supported_extension_ids: Vec<u8>, pub supported_extension_ids: Vec<u8>,
close_tx: mpsc::Sender<WsEvent>, close_tx: mpsc::Sender<WsEvent>,
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>, muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
} }
impl ServerMux { impl ServerMux {
@ -484,7 +490,7 @@ impl ServerMux {
R: ws::WebSocketRead + Send, R: ws::WebSocketRead + Send,
W: ws::WebSocketWrite + Send + 'static, W: ws::WebSocketWrite + Send + 'static,
{ {
let (close_tx, close_rx) = mpsc::channel::<WsEvent>(256); let (close_tx, close_rx) = mpsc::bounded::<WsEvent>(256);
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
let write = ws::LockedWebSocketWrite::new(Box::new(write)); let write = ws::LockedWebSocketWrite::new(Box::new(write));
@ -547,12 +553,12 @@ impl ServerMux {
/// Wait for a stream to be created. /// Wait for a stream to be created.
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> { pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> {
self.muxstream_recv.next().await self.muxstream_recv.recv_async().await.ok()
} }
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> { async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.close_tx self.close_tx
.send(WsEvent::EndFut(reason)) .send_async(WsEvent::EndFut(reason))
.await .await
.map_err(|_| WispError::MuxMessageFailedToSend) .map_err(|_| WispError::MuxMessageFailedToSend)
} }
@ -574,6 +580,13 @@ impl ServerMux {
.await .await
} }
} }
impl Drop for ServerMux {
fn drop(&mut self) {
let _ = self.close_tx.send(WsEvent::EndFut(None));
}
}
/// Client side multiplexor. /// Client side multiplexor.
/// ///
/// # Example /// # Example
@ -595,7 +608,7 @@ pub struct ClientMux {
pub downgraded: bool, pub downgraded: bool,
/// Extensions that are supported by both sides. /// Extensions that are supported by both sides.
pub supported_extension_ids: Vec<u8>, pub supported_extension_ids: Vec<u8>,
close_tx: mpsc::Sender<WsEvent>, stream_tx: mpsc::Sender<WsEvent>,
} }
impl ClientMux { impl ClientMux {
@ -654,10 +667,10 @@ impl ClientMux {
extension.handle_handshake(&mut read, &write).await?; extension.handle_handshake(&mut read, &write).await?;
} }
let (tx, rx) = mpsc::channel::<WsEvent>(256); let (tx, rx) = mpsc::bounded::<WsEvent>(256);
Ok(( Ok((
Self { Self {
close_tx: tx.clone(), stream_tx: tx.clone(),
downgraded, downgraded,
supported_extension_ids: supported_extensions supported_extension_ids: supported_extensions
.iter() .iter()
@ -697,16 +710,16 @@ impl ClientMux {
return Err(WispError::UdpExtensionNotSupported); return Err(WispError::UdpExtensionNotSupported);
} }
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
self.close_tx self.stream_tx
.send(WsEvent::CreateStream(stream_type, host, port, tx)) .send_async(WsEvent::CreateStream(stream_type, host, port, tx))
.await .await
.map_err(|_| WispError::MuxMessageFailedToSend)?; .map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
} }
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> { async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.close_tx self.stream_tx
.send(WsEvent::EndFut(reason)) .send_async(WsEvent::EndFut(reason))
.await .await
.map_err(|_| WispError::MuxMessageFailedToSend) .map_err(|_| WispError::MuxMessageFailedToSend)
} }
@ -728,3 +741,9 @@ impl ClientMux {
.await .await
} }
} }
impl Drop for ClientMux {
fn drop(&mut self) {
let _ = self.stream_tx.send(WsEvent::EndFut(None));
}
}

View file

@ -1,13 +1,14 @@
use crate::{sink_unfold, CloseReason, Packet, Role, StreamType, WispError}; use crate::{sink_unfold, CloseReason, Packet, Role, StreamType, WispError};
use async_io_stream::IoStream; pub use async_io_stream::IoStream;
use bytes::Bytes; use bytes::Bytes;
use event_listener::Event; use event_listener::Event;
use flume as mpsc;
use futures::{ use futures::{
channel::{mpsc, oneshot}, channel::oneshot,
stream, select, stream,
task::{Context, Poll}, task::{Context, Poll},
Sink, SinkExt, Stream, StreamExt, FutureExt, Sink, Stream,
}; };
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use std::{ use std::{
@ -40,6 +41,7 @@ pub struct MuxStreamRead {
tx: mpsc::Sender<WsEvent>, tx: mpsc::Sender<WsEvent>,
rx: mpsc::Receiver<Bytes>, rx: mpsc::Receiver<Bytes>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
is_closed_event: Arc<Event>,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
flow_control_read: AtomicU32, flow_control_read: AtomicU32,
target_flow_control: u32, target_flow_control: u32,
@ -51,13 +53,16 @@ impl MuxStreamRead {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return None; return None;
} }
let bytes = self.rx.next().await?; let bytes = select! {
x = self.rx.recv_async() => x.ok()?,
_ = self.is_closed_event.listen().fuse() => return None
};
if self.role == Role::Server && self.stream_type == StreamType::Tcp { if self.role == Role::Server && self.stream_type == StreamType::Tcp {
let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1; let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1;
if val > self.target_flow_control { if val > self.target_flow_control {
let (tx, rx) = oneshot::channel::<Result<(), WispError>>(); let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx self.tx
.send(WsEvent::SendPacket( .send_async(WsEvent::SendPacket(
Packet::new_continue( Packet::new_continue(
self.stream_id, self.stream_id,
self.flow_control.fetch_add(val, Ordering::AcqRel) + val, self.flow_control.fetch_add(val, Ordering::AcqRel) + val,
@ -107,13 +112,13 @@ impl MuxStreamWrite {
} }
let (tx, rx) = oneshot::channel::<Result<(), WispError>>(); let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx self.tx
.send(WsEvent::SendPacket( .send_async(WsEvent::SendPacket(
Packet::new_data(self.stream_id, data), Packet::new_data(self.stream_id, data),
tx, tx,
)) ))
.await .await
.map_err(|x| WispError::Other(Box::new(x)))?; .map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??; rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
if self.role == Role::Client && self.stream_type == StreamType::Tcp { if self.role == Role::Client && self.stream_type == StreamType::Tcp {
self.flow_control.store( self.flow_control.store(
self.flow_control.load(Ordering::Acquire).saturating_sub(1), self.flow_control.load(Ordering::Acquire).saturating_sub(1),
@ -151,13 +156,13 @@ impl MuxStreamWrite {
let (tx, rx) = oneshot::channel::<Result<(), WispError>>(); let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx self.tx
.send(WsEvent::Close( .send_async(WsEvent::Close(
Packet::new_close(self.stream_id, reason), Packet::new_close(self.stream_id, reason),
tx, tx,
)) ))
.await .await
.map_err(|x| WispError::Other(Box::new(x)))?; .map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??; rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
Ok(()) Ok(())
} }
@ -179,6 +184,16 @@ impl MuxStreamWrite {
} }
} }
impl Drop for MuxStreamWrite {
fn drop(&mut self) {
if !self.is_closed.load(Ordering::Acquire) {
self.is_closed.store(true, Ordering::Release);
let (tx, _) = oneshot::channel();
let _ = self.tx.send(WsEvent::Close(Packet::new_close(self.stream_id, CloseReason::Unknown), tx));
}
}
}
/// Multiplexor stream. /// Multiplexor stream.
pub struct MuxStream { pub struct MuxStream {
/// ID of the stream. /// ID of the stream.
@ -196,6 +211,7 @@ impl MuxStream {
rx: mpsc::Receiver<Bytes>, rx: mpsc::Receiver<Bytes>,
tx: mpsc::Sender<WsEvent>, tx: mpsc::Sender<WsEvent>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
is_closed_event: Arc<Event>,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
continue_recieved: Arc<Event>, continue_recieved: Arc<Event>,
target_flow_control: u32, target_flow_control: u32,
@ -209,6 +225,7 @@ impl MuxStream {
tx: tx.clone(), tx: tx.clone(),
rx, rx,
is_closed: is_closed.clone(), is_closed: is_closed.clone(),
is_closed_event: is_closed_event.clone(),
flow_control: flow_control.clone(), flow_control: flow_control.clone(),
flow_control_read: AtomicU32::new(0), flow_control_read: AtomicU32::new(0),
target_flow_control, target_flow_control,
@ -288,13 +305,13 @@ impl MuxStreamCloser {
let (tx, rx) = oneshot::channel::<Result<(), WispError>>(); let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.close_channel self.close_channel
.send(WsEvent::Close( .send_async(WsEvent::Close(
Packet::new_close(self.stream_id, reason), Packet::new_close(self.stream_id, reason),
tx, tx,
)) ))
.await .await
.map_err(|x| WispError::Other(Box::new(x)))?; .map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??; rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
Ok(()) Ok(())
} }

View file

@ -76,6 +76,9 @@ pub trait WebSocketRead {
pub trait WebSocketWrite { pub trait WebSocketWrite {
/// Write a frame to the socket. /// Write a frame to the socket.
async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError>; async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError>;
/// Close the socket.
async fn wisp_close(&mut self) -> Result<(), WispError>;
} }
/// Locked WebSocket. /// Locked WebSocket.
@ -88,9 +91,14 @@ impl LockedWebSocketWrite {
} }
/// Write a frame to the websocket. /// Write a frame to the websocket.
pub async fn write_frame(&self, frame: Frame) -> Result<(), crate::WispError> { pub async fn write_frame(&self, frame: Frame) -> Result<(), WispError> {
self.0.lock().await.wisp_write_frame(frame).await self.0.lock().await.wisp_write_frame(frame).await
} }
/// Close the websocket.
pub async fn close(&self) -> Result<(), WispError> {
self.0.lock().await.wisp_close().await
}
} }
pub(crate) struct AppendingWebSocketRead<R>(pub Vec<Frame>, pub R) pub(crate) struct AppendingWebSocketRead<R>(pub Vec<Frame>, pub R)