From 7001ee8fa567b4b691cb671871b6c8cb574e2357 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Tue, 26 Mar 2024 18:55:54 -0700 Subject: [PATCH] remove the mutex in wisp_mux, other improvements --- Cargo.lock | 22 ++- Cargo.toml | 6 +- client/Cargo.toml | 2 +- client/package.json | 2 +- client/src/lib.rs | 2 +- client/src/utils.rs | 2 +- client/src/wrappers.rs | 2 +- server/src/main.rs | 8 +- simple-wisp-client/Cargo.toml | 2 + simple-wisp-client/src/main.rs | 89 +++++++-- wisp/Cargo.toml | 3 +- wisp/src/fastwebsockets.rs | 8 +- wisp/src/lib.rs | 341 ++++++++++++++------------------- wisp/src/packet.rs | 43 +++-- wisp/src/sink_unfold.rs | 51 ++++- wisp/src/stream.rs | 72 +++---- 16 files changed, 346 insertions(+), 309 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6cfaa15..a77d2ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -153,6 +153,12 @@ dependencies = [ "tokio", ] +[[package]] +name = "atomic-counter" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62f447d68cfa5a9ab0c1c862a703da2a65b5ed1b7ce1153c9eb0169506d56019" + [[package]] name = "autocfg" version = "1.1.0" @@ -516,7 +522,7 @@ checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" [[package]] name = "epoxy-client" -version = "1.4.2" +version = "1.5.0" dependencies = [ "async-compression", "async_io_stream", @@ -1727,18 +1733,29 @@ checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" name = "simple-wisp-client" version = "1.0.0" dependencies = [ + "atomic-counter", "bytes", "console-subscriber", "fastwebsockets 0.7.1", "futures", "http-body-util", "hyper 1.2.0", + "simple_moving_average", "tokio", "tokio-native-tls", "tokio-util", "wisp-mux", ] +[[package]] +name = "simple_moving_average" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a4b144ad185430cd033299e2c93e465d5a7e65fbb858593dc57181fa13cd310" +dependencies = [ + "num-traits", +] + [[package]] name = "slab" version = "0.4.9" @@ -2500,10 +2517,11 @@ checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" [[package]] name = "wisp-mux" -version = "2.0.2" +version = "3.0.0" dependencies = [ "async_io_stream", "bytes", + "dashmap", "event-listener", "fastwebsockets 0.7.1", "futures", diff --git a/Cargo.toml b/Cargo.toml index 3f17cef..c6f7363 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,8 +8,10 @@ rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" } [profile.release] lto = true -opt-level = 'z' -strip = true +debug = true panic = "abort" codegen-units = 1 +opt-level = 3 +[profile.release.package.epoxy-client] +opt-level = 'z' diff --git a/client/Cargo.toml b/client/Cargo.toml index 4f0c185..b24b74b 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "epoxy-client" -version = "1.4.2" +version = "1.5.0" edition = "2021" license = "LGPL-3.0-only" diff --git a/client/package.json b/client/package.json index b5b6ffd..61c6af3 100644 --- a/client/package.json +++ b/client/package.json @@ -1,6 +1,6 @@ { "name": "@mercuryworkshop/epoxy-tls", - "version": "1.4.2", + "version": "1.5.0", "description": "A wasm library for using raw encrypted tls/ssl/https/websocket streams on the browser", "scripts": { "build": "./build.sh" diff --git a/client/src/lib.rs b/client/src/lib.rs index 4245876..6a6217c 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -47,8 +47,8 @@ enum EpxCompression { Gzip, } -type EpxIoTlsStream = TlsStream>>; type EpxIoUnencryptedStream = IoStream>; +type EpxIoTlsStream = TlsStream; type EpxIoStream = Either; #[wasm_bindgen(start)] diff --git a/client/src/utils.rs b/client/src/utils.rs index 11d77c9..1fdcf2e 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -231,7 +231,7 @@ pub async fn replace_mux( ) -> Result<(), WispError> { let (mux_replace, fut) = make_mux(url).await?; let mut mux_write = mux.write().await; - mux_write.close().await; + mux_write.close().await?; *mux_write = mux_replace; drop(mux_write); spawn_mux_fut(mux, fut, url.into()); diff --git a/client/src/wrappers.rs b/client/src/wrappers.rs index 3c8caf8..9b16525 100644 --- a/client/src/wrappers.rs +++ b/client/src/wrappers.rs @@ -56,7 +56,7 @@ impl Stream for IncomingBody { pub struct ServiceWrapper(pub Arc>>, pub String); impl tower_service::Service for ServiceWrapper { - type Response = TokioIo>>; + type Response = TokioIo; type Error = WispError; type Future = impl Future>; diff --git a/server/src/main.rs b/server/src/main.rs index 5c90ce3..f76b0e2 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -239,7 +239,7 @@ async fn accept_ws( println!("{:?}: connected", addr); - let (mut mux, fut) = ServerMux::new(rx, tx, 128); + let (mut mux, fut) = ServerMux::new(rx, tx, u32::MAX); tokio::spawn(async move { if let Err(e) = fut.await { @@ -247,7 +247,7 @@ async fn accept_ws( } }); - while let Some((packet, stream)) = mux.server_new_stream().await { + while let Some((packet, mut stream)) = mux.server_new_stream().await { tokio::spawn(async move { if block_local { match lookup_host(format!( @@ -272,8 +272,8 @@ async fn accept_ws( } } } - let close_err = stream.get_close_handle(); - let close_ok = stream.get_close_handle(); + let mut close_err = stream.get_close_handle(); + let mut close_ok = stream.get_close_handle(); let _ = handle_mux(packet, stream) .or_else(|err| async move { let _ = close_err.close(CloseReason::Unexpected).await; diff --git a/simple-wisp-client/Cargo.toml b/simple-wisp-client/Cargo.toml index a6df806..fed1966 100644 --- a/simple-wisp-client/Cargo.toml +++ b/simple-wisp-client/Cargo.toml @@ -4,12 +4,14 @@ version = "1.0.0" edition = "2021" [dependencies] +atomic-counter = "1.0.1" bytes = "1.5.0" console-subscriber = { version = "0.2.0", optional = true } fastwebsockets = { version = "0.7.1", features = ["unstable-split", "upgrade"] } futures = "0.3.30" http-body-util = "0.1.0" hyper = { version = "1.1.0", features = ["http1", "client"] } +simple_moving_average = "1.0.2" tokio = { version = "1.36.0", features = ["full"] } tokio-native-tls = "0.3.1" tokio-util = "0.7.10" diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index 39a5c26..6a78943 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -1,16 +1,25 @@ +use atomic_counter::{AtomicCounter, RelaxedCounter}; use bytes::Bytes; use fastwebsockets::{handshake, FragmentCollectorRead}; -use futures::io::AsyncWriteExt; +use futures::future::select_all; use http_body_util::Empty; use hyper::{ header::{CONNECTION, UPGRADE}, Request, }; -use std::{error::Error, future::Future}; -use tokio::net::TcpStream; +use simple_moving_average::{SingleSumSMA, SMA}; +use std::{ + error::Error, + future::Future, + io::{stdout, IsTerminal, Write}, + sync::Arc, + time::Duration, + usize, +}; +use tokio::{net::TcpStream, time::interval}; use tokio_native_tls::{native_tls, TlsConnector}; -use wisp_mux::{ClientMux, StreamType}; use tokio_util::either::Either; +use wisp_mux::{ClientMux, StreamType, WispError}; #[derive(Debug)] struct StrError(String); @@ -70,6 +79,18 @@ async fn main() -> Result<(), Box> { .nth(6) .ok_or(StrError::new("no should tls"))? .parse()?; + let thread_cnt: usize = std::env::args().nth(7).unwrap_or("10".into()).parse()?; + + println!( + "connecting to {}://{}:{}{} and sending &[0; 1024] to {}:{} with threads {}", + if should_tls { "wss" } else { "ws" }, + addr, + addr_port, + addr_path, + addr_dest, + addr_dest_port, + thread_cnt + ); let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?; let socket = if should_tls { @@ -98,23 +119,59 @@ async fn main() -> Result<(), Box> { let rx = FragmentCollectorRead::new(rx); let (mux, fut) = ClientMux::new(rx, tx).await?; + let mut threads = Vec::with_capacity(thread_cnt + 1); - tokio::task::spawn(async move { println!("err: {:?}", fut.await); }); + threads.push(tokio::spawn(fut)); - let mut hi: u64 = 0; - loop { + let payload = Bytes::from_static(&[0; 1024]); + + let cnt = Arc::new(RelaxedCounter::new(0)); + + for _ in 0..thread_cnt { let mut channel = mux .client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port) - .await? - .into_io() - .into_asyncrw(); - for _ in 0..256 { - channel.write_all(b"hiiiiiiii").await?; - hi += 1; - println!("said hi {}", hi); - } + .await?; + let cnt = cnt.clone(); + let payload = payload.clone(); + threads.push(tokio::spawn(async move { + loop { + channel.write(payload.clone()).await?; + channel.read().await; + cnt.inc(); + } + #[allow(unreachable_code)] + Ok::<(), WispError>(()) + })); } - #[allow(unreachable_code)] + threads.push(tokio::spawn(async move { + let mut interval = interval(Duration::from_millis(100)); + let mut avg: SingleSumSMA = SingleSumSMA::new(); + let mut last_time = 0; + let is_term = stdout().is_terminal(); + loop { + interval.tick().await; + let now = cnt.get(); + let stat = format!( + "sent &[0; 1024] cnt: {:?}, +{:?}, moving average (100): {:?}", + now, + now - last_time, + avg.get_average() + ); + if is_term { + print!("\x1b[2K{}\r", stat); + } else { + println!("{}", stat); + } + stdout().flush().unwrap(); + avg.add_sample(now - last_time); + last_time = now; + } + })); + + let out = select_all(threads.into_iter()).await; + + println!("\n\nout: {:?}", out.0); + Ok(()) } diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 2ed2f1b..473640d 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "wisp-mux" -version = "2.0.2" +version = "3.0.0" license = "LGPL-3.0-only" description = "A library for easily creating Wisp servers and clients." homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp" @@ -11,6 +11,7 @@ edition = "2021" [dependencies] async_io_stream = "0.3.3" bytes = "1.5.0" +dashmap = { version = "5.5.3", features = ["inline"] } event-listener = "5.0.0" fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true } futures = "0.3.30" diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index cba77d2..7a66908 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -8,7 +8,9 @@ impl From for crate::ws::OpCode { fn from(opcode: OpCode) -> Self { use OpCode::*; match opcode { - Continuation => unreachable!("continuation should never be recieved when using a fragmentcollector"), + Continuation => { + unreachable!("continuation should never be recieved when using a fragmentcollector") + } Text => Self::Text, Binary => Self::Binary, Close => Self::Close, @@ -70,8 +72,6 @@ impl crate::ws::WebSocketRead for FragmentCollectorRead impl crate::ws::WebSocketWrite for WebSocketWrite { async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { - self.write_frame(frame.into()) - .await - .map_err(|e| e.into()) + self.write_frame(frame.into()).await.map_err(|e| e.into()) } } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index f657db3..152be13 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -16,14 +16,13 @@ pub use crate::packet::*; pub use crate::stream::*; use bytes::Bytes; +use dashmap::DashMap; use event_listener::Event; -use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt}; -use std::{ - collections::HashMap, - sync::{ - atomic::{AtomicBool, AtomicU32, Ordering}, - Arc, - }, +use futures::SinkExt; +use futures::{channel::mpsc, Future, FutureExt, StreamExt}; +use std::sync::{ + atomic::{AtomicBool, AtomicU32, Ordering}, + Arc, }; /// The role of the multiplexor. @@ -72,6 +71,8 @@ pub enum WispError { Utf8Error(std::str::Utf8Error), /// Other error. Other(Box), + /// Failed to send message to multiplexor task. + MuxMessageFailedToSend, } impl From for WispError { @@ -82,25 +83,29 @@ impl From for WispError { impl std::fmt::Display for WispError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { - use WispError::*; match self { - PacketTooSmall => write!(f, "Packet too small"), - InvalidPacketType => write!(f, "Invalid packet type"), - InvalidStreamType => write!(f, "Invalid stream type"), - InvalidStreamId => write!(f, "Invalid stream id"), - InvalidCloseReason => write!(f, "Invalid close reason"), - InvalidUri => write!(f, "Invalid URI"), - UriHasNoHost => write!(f, "URI has no host"), - UriHasNoPort => write!(f, "URI has no port"), - MaxStreamCountReached => write!(f, "Maximum stream count reached"), - StreamAlreadyClosed => write!(f, "Stream already closed"), - WsFrameInvalidType => write!(f, "Invalid websocket frame type"), - WsFrameNotFinished => write!(f, "Unfinished websocket frame"), - WsImplError(err) => write!(f, "Websocket implementation error: {}", err), - WsImplSocketClosed => write!(f, "Websocket implementation error: websocket closed"), - WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"), - Utf8Error(err) => write!(f, "UTF-8 error: {}", err), - Other(err) => write!(f, "Other error: {}", err), + Self::PacketTooSmall => write!(f, "Packet too small"), + Self::InvalidPacketType => write!(f, "Invalid packet type"), + Self::InvalidStreamType => write!(f, "Invalid stream type"), + Self::InvalidStreamId => write!(f, "Invalid stream id"), + Self::InvalidCloseReason => write!(f, "Invalid close reason"), + Self::InvalidUri => write!(f, "Invalid URI"), + Self::UriHasNoHost => write!(f, "URI has no host"), + Self::UriHasNoPort => write!(f, "URI has no port"), + Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"), + Self::StreamAlreadyClosed => write!(f, "Stream already closed"), + Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"), + Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"), + Self::WsImplError(err) => write!(f, "Websocket implementation error: {}", err), + Self::WsImplSocketClosed => { + write!(f, "Websocket implementation error: websocket closed") + } + Self::WsImplNotSupported => { + write!(f, "Websocket implementation error: unsupported feature") + } + Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err), + Self::Other(err) => write!(f, "Other error: {}", err), + Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"), } } } @@ -115,61 +120,74 @@ struct MuxMapValue { is_closed: Arc, } -struct ServerMuxInner +struct MuxInner where W: ws::WebSocketWrite, { tx: ws::LockedWebSocketWrite, - stream_map: Arc>>, - close_tx: mpsc::UnboundedSender, + stream_map: Arc>, } -impl ServerMuxInner { - pub async fn into_future( +impl MuxInner { + pub async fn server_into_future( self, rx: R, - close_rx: mpsc::UnboundedReceiver, + close_rx: mpsc::Receiver, muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, buffer_size: u32, + close_tx: mpsc::Sender, ) -> Result<(), WispError> where R: ws::WebSocketRead, { + self.into_future( + close_rx, + self.server_loop(rx, muxstream_sender, buffer_size, close_tx), + ) + .await + } + + pub async fn client_into_future( + self, + rx: R, + close_rx: mpsc::Receiver, + ) -> Result<(), WispError> + where + R: ws::WebSocketRead, + { + self.into_future(close_rx, self.client_loop(rx)).await + } + + async fn into_future( + &self, + close_rx: mpsc::Receiver, + wisp_fut: impl Future>, + ) -> Result<(), WispError> { let ret = futures::select! { - x = self.server_bg_loop(close_rx).fuse() => x, - x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x + _ = self.stream_loop(close_rx).fuse() => Ok(()), + x = wisp_fut.fuse() => x, }; - self.stream_map.lock().await.drain().for_each(|mut x| { - x.1.is_closed.store(true, Ordering::Release); - x.1.stream.disconnect(); - x.1.stream.close_channel(); + self.stream_map.iter_mut().for_each(|mut x| { + x.is_closed.store(true, Ordering::Release); + x.stream.disconnect(); + x.stream.close_channel(); }); + self.stream_map.clear(); ret } - async fn server_bg_loop( - &self, - mut close_rx: mpsc::UnboundedReceiver, - ) -> Result<(), WispError> { - while let Some(msg) = close_rx.next().await { + async fn stream_loop(&self, mut stream_rx: mpsc::Receiver) { + while let Some(msg) = stream_rx.next().await { match msg { WsEvent::SendPacket(packet, channel) => { - if self - .stream_map - .lock() - .await - .get(&packet.stream_id) - .is_some() - { + if self.stream_map.get(&packet.stream_id).is_some() { let _ = channel.send(self.tx.write_frame(packet.into()).await); } else { let _ = channel.send(Err(WispError::InvalidStreamId)); } } WsEvent::Close(packet, channel) => { - if let Some(mut stream) = - self.stream_map.lock().await.remove(&packet.stream_id) - { + if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { stream.stream.disconnect(); stream.stream.close_channel(); let _ = channel.send(self.tx.write_frame(packet.into()).await); @@ -180,20 +198,20 @@ impl ServerMuxInner { WsEvent::EndFut => break, } } - Ok(()) } - async fn server_msg_loop( + async fn server_loop( &self, mut rx: R, muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, buffer_size: u32, + close_tx: mpsc::Sender, ) -> Result<(), WispError> where R: ws::WebSocketRead, { // will send continues once flow_control is at 10% of max - let target_buffer_size = buffer_size * 90 / 100; + let target_buffer_size = ((buffer_size as u64 * 90) / 100) as u32; self.tx .write_frame(Packet::new_continue(0, buffer_size).into()) .await?; @@ -214,7 +232,7 @@ impl ServerMuxInner { let flow_control_event: Arc = Event::new().into(); let is_closed: Arc = AtomicBool::new(false).into(); - self.stream_map.lock().await.insert( + self.stream_map.insert( packet.stream_id, MuxMapValue { stream: ch_tx, @@ -232,7 +250,7 @@ impl ServerMuxInner { Role::Server, stream_type, ch_rx, - self.close_tx.clone(), + close_tx.clone(), is_closed, flow_control, flow_control_event, @@ -242,7 +260,7 @@ impl ServerMuxInner { .map_err(|x| WispError::Other(Box::new(x)))?; } Data(data) => { - if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { + if let Some(stream) = self.stream_map.get(&packet.stream_id) { let _ = stream.stream.unbounded_send(data); if stream.stream_type == StreamType::Tcp { stream.flow_control.store( @@ -257,9 +275,47 @@ impl ServerMuxInner { } Continue(_) => break Err(WispError::InvalidPacketType), Close(_) => { - if let Some(mut stream) = - self.stream_map.lock().await.remove(&packet.stream_id) - { + if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { + stream.is_closed.store(true, Ordering::Release); + stream.stream.disconnect(); + stream.stream.close_channel(); + } + } + } + } + } + + async fn client_loop(&self, mut rx: R) -> Result<(), WispError> + where + R: ws::WebSocketRead, + { + loop { + let frame = rx.wisp_read_frame(&self.tx).await?; + if frame.opcode == ws::OpCode::Close { + break Ok(()); + } + let packet = Packet::try_from(frame)?; + + use PacketType::*; + match packet.packet_type { + Connect(_) => break Err(WispError::InvalidPacketType), + Data(data) => { + if let Some(stream) = self.stream_map.get(&packet.stream_id) { + let _ = stream.stream.unbounded_send(data); + } + } + Continue(inner_packet) => { + if let Some(stream) = self.stream_map.get(&packet.stream_id) { + if stream.stream_type == StreamType::Tcp { + stream + .flow_control + .store(inner_packet.buffer_remaining, Ordering::Release); + let _ = stream.flow_control_event.notify(u32::MAX); + } + } + } + Close(_) => { + if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { stream.is_closed.store(true, Ordering::Release); stream.stream.disconnect(); stream.stream.close_channel(); @@ -290,8 +346,7 @@ impl ServerMuxInner { /// } /// ``` pub struct ServerMux { - stream_map: Arc>>, - close_tx: mpsc::UnboundedSender, + close_tx: mpsc::Sender, muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>, } @@ -305,22 +360,19 @@ impl ServerMux { where R: ws::WebSocketRead, { - let (close_tx, close_rx) = mpsc::unbounded::(); + let (close_tx, close_rx) = mpsc::channel::(256); let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); let write = ws::LockedWebSocketWrite::new(write); - let map = Arc::new(Mutex::new(HashMap::new())); ( Self { muxstream_recv: rx, close_tx: close_tx.clone(), - stream_map: map.clone(), }, - ServerMuxInner { + MuxInner { tx: write, - close_tx, - stream_map: map.clone(), + stream_map: DashMap::new().into(), } - .into_future(read, close_rx, tx, buffer_size), + .server_into_future(read, close_rx, tx, buffer_size, close_tx), ) } @@ -333,124 +385,13 @@ impl ServerMux { /// /// Also terminates the multiplexor future. Waiting for a new stream will never succeed after /// this function is called. - pub async fn close(&self) { - self.stream_map.lock().await.drain().for_each(|mut x| { - x.1.is_closed.store(true, Ordering::Release); - x.1.stream.disconnect(); - x.1.stream.close_channel(); - }); - let _ = self.close_tx.unbounded_send(WsEvent::EndFut); + pub async fn close(&mut self) -> Result<(), WispError> { + self.close_tx + .send(WsEvent::EndFut) + .await + .map_err(|_| WispError::MuxMessageFailedToSend) } } - -struct ClientMuxInner -where - W: ws::WebSocketWrite, -{ - tx: ws::LockedWebSocketWrite, - stream_map: Arc>>, -} - -impl ClientMuxInner { - pub(crate) async fn into_future( - self, - rx: R, - close_rx: mpsc::UnboundedReceiver, - ) -> Result<(), WispError> - where - R: ws::WebSocketRead, - { - let ret = futures::select! { - x = self.client_bg_loop(close_rx).fuse() => x, - x = self.client_loop(rx).fuse() => x - }; - self.stream_map.lock().await.drain().for_each(|mut x| { - x.1.is_closed.store(true, Ordering::Release); - x.1.stream.disconnect(); - x.1.stream.close_channel(); - }); - ret - } - - async fn client_bg_loop( - &self, - mut close_rx: mpsc::UnboundedReceiver, - ) -> Result<(), WispError> { - while let Some(msg) = close_rx.next().await { - match msg { - WsEvent::SendPacket(packet, channel) => { - if self - .stream_map - .lock() - .await - .get(&packet.stream_id) - .is_some() - { - let _ = channel.send(self.tx.write_frame(packet.into()).await); - } else { - let _ = channel.send(Err(WispError::InvalidStreamId)); - } - } - WsEvent::Close(packet, channel) => { - if let Some(mut stream) = - self.stream_map.lock().await.remove(&packet.stream_id) - { - stream.stream.disconnect(); - stream.stream.close_channel(); - let _ = channel.send(self.tx.write_frame(packet.into()).await); - } else { - let _ = channel.send(Err(WispError::InvalidStreamId)); - } - } - WsEvent::EndFut => break, - } - } - Ok(()) - } - - async fn client_loop(&self, mut rx: R) -> Result<(), WispError> - where - R: ws::WebSocketRead, - { - loop { - let frame = rx.wisp_read_frame(&self.tx).await?; - if frame.opcode == ws::OpCode::Close { - break Ok(()); - } - let packet = Packet::try_from(frame)?; - - use PacketType::*; - match packet.packet_type { - Connect(_) => break Err(WispError::InvalidPacketType), - Data(data) => { - if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { - let _ = stream.stream.unbounded_send(data); - } - } - Continue(inner_packet) => { - if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { - if stream.stream_type == StreamType::Tcp { - stream - .flow_control - .store(inner_packet.buffer_remaining, Ordering::Release); - let _ = stream.flow_control_event.notify(u32::MAX); - } - } - } - Close(_) => { - if let Some(mut stream) = - self.stream_map.lock().await.remove(&packet.stream_id) - { - stream.is_closed.store(true, Ordering::Release); - stream.stream.disconnect(); - stream.stream.close_channel(); - } - } - } - } - } -} - /// Client side multiplexor. /// /// # Example @@ -470,9 +411,9 @@ where W: ws::WebSocketWrite, { tx: ws::LockedWebSocketWrite, - stream_map: Arc>>, + stream_map: Arc>, next_free_stream_id: AtomicU32, - close_tx: mpsc::UnboundedSender, + close_tx: mpsc::Sender, buf_size: u32, target_buf_size: u32, } @@ -492,23 +433,23 @@ impl ClientMux { return Err(WispError::InvalidStreamId); } if let PacketType::Continue(packet) = first_packet.packet_type { - let (tx, rx) = mpsc::unbounded::(); - let map = Arc::new(Mutex::new(HashMap::new())); + let (tx, rx) = mpsc::channel::(256); + let map = Arc::new(DashMap::new()); Ok(( Self { tx: write.clone(), stream_map: map.clone(), next_free_stream_id: AtomicU32::new(1), - close_tx: tx, + close_tx: tx.clone(), buf_size: packet.buffer_remaining, // server-only target_buf_size: 0, }, - ClientMuxInner { + MuxInner { tx: write.clone(), stream_map: map.clone(), } - .into_future(read, rx), + .client_into_future(read, rx), )) } else { Err(WispError::InvalidPacketType) @@ -540,7 +481,7 @@ impl ClientMux { self.next_free_stream_id .store(next_stream_id, Ordering::Release); - self.stream_map.lock().await.insert( + self.stream_map.insert( stream_id, MuxMapValue { stream: ch_tx, @@ -568,12 +509,10 @@ impl ClientMux { /// /// Also terminates the multiplexor future. Creating a stream is UB after calling this /// function. - pub async fn close(&self) { - self.stream_map.lock().await.drain().for_each(|mut x| { - x.1.is_closed.store(true, Ordering::Release); - x.1.stream.disconnect(); - x.1.stream.close_channel(); - }); - let _ = self.close_tx.unbounded_send(WsEvent::EndFut); + pub async fn close(&mut self) -> Result<(), WispError> { + self.close_tx + .send(WsEvent::EndFut) + .await + .map_err(|_| WispError::MuxMessageFailedToSend) } } diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 7ef129d..d3fb8c7 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -1,5 +1,5 @@ use crate::{ws, WispError}; -use bytes::{Buf, BufMut, Bytes}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; /// Wisp stream type. #[derive(Debug, PartialEq, Copy, Clone)] @@ -115,13 +115,13 @@ impl TryFrom for ConnectPacket { } } -impl From for Vec { +impl From for Bytes { fn from(packet: ConnectPacket) -> Self { - let mut encoded = Self::with_capacity(1 + 2 + packet.destination_hostname.len()); + let mut encoded = BytesMut::with_capacity(1 + 2 + packet.destination_hostname.len()); encoded.put_u8(packet.stream_type as u8); encoded.put_u16_le(packet.destination_port); encoded.extend(packet.destination_hostname.bytes()); - encoded + encoded.freeze() } } @@ -153,11 +153,11 @@ impl TryFrom for ContinuePacket { } } -impl From for Vec { +impl From for Bytes { fn from(packet: ContinuePacket) -> Self { - let mut encoded = Self::with_capacity(4); + let mut encoded = BytesMut::with_capacity(4); encoded.put_u32_le(packet.buffer_remaining); - encoded + encoded.freeze() } } @@ -190,11 +190,11 @@ impl TryFrom for ClosePacket { } } -impl From for Vec { +impl From for Bytes { fn from(packet: ClosePacket) -> Self { - let mut encoded = Self::with_capacity(1); + let mut encoded = BytesMut::with_capacity(1); encoded.put_u8(packet.reason as u8); - encoded + encoded.freeze() } } @@ -224,12 +224,12 @@ impl PacketType { } } -impl From for Vec { +impl From for Bytes { fn from(packet: PacketType) -> Self { use PacketType::*; match packet { Connect(x) => x.into(), - Data(x) => x.to_vec(), + Data(x) => x, Continue(x) => x.into(), Close(x) => x.into(), } @@ -250,7 +250,10 @@ impl Packet { /// /// The helper functions should be used for most use cases. pub fn new(stream_id: u32, packet: PacketType) -> Self { - Self { stream_id, packet_type: packet } + Self { + stream_id, + packet_type: packet, + } } /// Create a new connect packet. @@ -316,13 +319,15 @@ impl TryFrom for Packet { } } -impl From for Vec { +impl From for Bytes { fn from(packet: Packet) -> Self { - let mut encoded = Self::with_capacity(1 + 4); - encoded.push(packet.packet_type.as_u8()); + let inner_u8 = packet.packet_type.as_u8(); + let inner = Bytes::from(packet.packet_type); + let mut encoded = BytesMut::with_capacity(1 + 4 + inner.len()); + encoded.put_u8(inner_u8); encoded.put_u32_le(packet.stream_id); - encoded.extend(Vec::::from(packet.packet_type)); - encoded + encoded.extend(inner); + encoded.freeze() } } @@ -341,6 +346,6 @@ impl TryFrom for Packet { impl From for ws::Frame { fn from(packet: Packet) -> Self { - Self::binary(Vec::::from(packet).into()) + Self::binary(packet.into()) } } diff --git a/wisp/src/sink_unfold.rs b/wisp/src/sink_unfold.rs index ee9e337..c82254a 100644 --- a/wisp/src/sink_unfold.rs +++ b/wisp/src/sink_unfold.rs @@ -45,28 +45,42 @@ pin_project! { /// Sink for the [`unfold`] function. #[derive(Debug)] #[must_use = "sinks do nothing unless polled"] - pub struct Unfold { + pub struct Unfold { function: F, - close_function: FC, + close_function: CF, #[pin] state: UnfoldState, + #[pin] + close_state: UnfoldState } } -pub(crate) fn unfold(init: T, function: F, close_function: FC) -> Unfold +pub(crate) fn unfold( + init: T, + function: F, + close_init: CT, + close_function: CF, +) -> Unfold where F: FnMut(T, Item) -> R, R: Future>, - FC: Fn() -> Result<(), E>, + CF: FnMut(CT) -> CR, + CR: Future>, { - Unfold { function, close_function, state: UnfoldState::Value { value: init } } + Unfold { + function, + close_function, + state: UnfoldState::Value { value: init }, + close_state: UnfoldState::Value { value: close_init }, + } } -impl Sink for Unfold +impl Sink for Unfold where F: FnMut(T, Item) -> R, R: Future>, - FC: Fn() -> Result<(), E>, + CF: FnMut(CT) -> CR, + CR: Future>, { type Error = E; @@ -104,6 +118,27 @@ where fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(self.as_mut().poll_flush(cx))?; - Poll::Ready((self.close_function)()) + let mut this = self.project(); + Poll::Ready( + if let Some(future) = this.close_state.as_mut().project_future() { + match ready!(future.poll(cx)) { + Ok(state) => { + this.close_state.set(UnfoldState::Value { value: state }); + Ok(()) + } + Err(err) => { + this.close_state.set(UnfoldState::Empty); + Err(err) + } + } + } else { + let future = match this.close_state.as_mut().take_value() { + Some(value) => (this.close_function)(value), + None => panic!("start_send called without poll_ready being called first"), + }; + this.close_state.set(UnfoldState::Future { future }); + return Poll::Pending; + }, + ) } } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 524c04d..109b9ab 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -7,7 +7,7 @@ use futures::{ channel::{mpsc, oneshot}, stream, task::{Context, Poll}, - Sink, Stream, StreamExt, + Sink, SinkExt, Stream, StreamExt, }; use pin_project_lite::pin_project; use std::{ @@ -31,7 +31,7 @@ pub struct MuxStreamRead { /// Type of the stream. pub stream_type: StreamType, role: Role, - tx: mpsc::UnboundedSender, + tx: mpsc::Sender, rx: mpsc::UnboundedReceiver, is_closed: Arc, flow_control: Arc, @@ -51,13 +51,14 @@ impl MuxStreamRead { if val > self.target_flow_control { let (tx, rx) = oneshot::channel::>(); self.tx - .unbounded_send(WsEvent::SendPacket( + .send(WsEvent::SendPacket( Packet::new_continue( self.stream_id, self.flow_control.fetch_add(val, Ordering::AcqRel) + val, ), tx, )) + .await .ok()?; rx.await.ok()?.ok()?; self.flow_control_read.store(0, Ordering::Release); @@ -80,7 +81,7 @@ pub struct MuxStreamWrite { /// Type of the stream. pub stream_type: StreamType, role: Role, - tx: mpsc::UnboundedSender, + tx: mpsc::Sender, is_closed: Arc, continue_recieved: Arc, flow_control: Arc, @@ -88,7 +89,7 @@ pub struct MuxStreamWrite { impl MuxStreamWrite { /// Write data to the stream. - pub async fn write(&self, data: Bytes) -> Result<(), WispError> { + pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(WispError::StreamAlreadyClosed); } @@ -100,10 +101,11 @@ impl MuxStreamWrite { } let (tx, rx) = oneshot::channel::>(); self.tx - .unbounded_send(WsEvent::SendPacket( + .send(WsEvent::SendPacket( Packet::new_data(self.stream_id, data), tx, )) + .await .map_err(|x| WispError::Other(Box::new(x)))?; rx.await.map_err(|x| WispError::Other(Box::new(x)))??; if self.role == Role::Client && self.stream_type == StreamType::Tcp { @@ -135,7 +137,7 @@ impl MuxStreamWrite { } /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { + pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(WispError::StreamAlreadyClosed); } @@ -143,10 +145,11 @@ impl MuxStreamWrite { let (tx, rx) = oneshot::channel::>(); self.tx - .unbounded_send(WsEvent::Close( + .send(WsEvent::Close( Packet::new_close(self.stream_id, reason), tx, )) + .await .map_err(|x| WispError::Other(Box::new(x)))?; rx.await.map_err(|x| WispError::Other(Box::new(x)))??; @@ -157,25 +160,19 @@ impl MuxStreamWrite { let handle = self.get_close_handle(); Box::pin(sink_unfold::unfold( self, - |tx, data| async move { + |mut tx, data| async move { tx.write(data).await?; Ok(tx) }, - move || handle.close_sync(CloseReason::Unknown), + handle, + move |mut handle| async { + handle.close(CloseReason::Unknown).await?; + Ok(handle) + }, )) } } -impl Drop for MuxStreamWrite { - fn drop(&mut self) { - let (tx, _) = oneshot::channel::>(); - let _ = self.tx.unbounded_send(WsEvent::Close( - Packet::new_close(self.stream_id, CloseReason::Unknown), - tx, - )); - } -} - /// Multiplexor stream. pub struct MuxStream { /// ID of the stream. @@ -191,7 +188,7 @@ impl MuxStream { role: Role, stream_type: StreamType, rx: mpsc::UnboundedReceiver, - tx: mpsc::UnboundedSender, + tx: mpsc::Sender, is_closed: Arc, flow_control: Arc, continue_recieved: Arc, @@ -228,7 +225,7 @@ impl MuxStream { } /// Write data to the stream. - pub async fn write(&self, data: Bytes) -> Result<(), WispError> { + pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> { self.tx.write(data).await } @@ -248,7 +245,7 @@ impl MuxStream { } /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { + pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> { self.tx.close(reason).await } @@ -271,13 +268,13 @@ impl MuxStream { pub struct MuxStreamCloser { /// ID of the stream. pub stream_id: u32, - close_channel: mpsc::UnboundedSender, + close_channel: mpsc::Sender, is_closed: Arc, } impl MuxStreamCloser { /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { + pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(WispError::StreamAlreadyClosed); } @@ -285,32 +282,16 @@ impl MuxStreamCloser { let (tx, rx) = oneshot::channel::>(); self.close_channel - .unbounded_send(WsEvent::Close( + .send(WsEvent::Close( Packet::new_close(self.stream_id, reason), tx, )) + .await .map_err(|x| WispError::Other(Box::new(x)))?; rx.await.map_err(|x| WispError::Other(Box::new(x)))??; Ok(()) } - - pub(crate) fn close_sync(&self, reason: CloseReason) -> Result<(), WispError> { - if self.is_closed.load(Ordering::Acquire) { - return Err(WispError::StreamAlreadyClosed); - } - self.is_closed.store(true, Ordering::Release); - - let (tx, _) = oneshot::channel::>(); - self.close_channel - .unbounded_send(WsEvent::Close( - Packet::new_close(self.stream_id, reason), - tx, - )) - .map_err(|x| WispError::Other(Box::new(x)))?; - - Ok(()) - } } pin_project! { @@ -336,10 +317,7 @@ impl MuxStreamIo { impl Stream for MuxStreamIo { type Item = Result, std::io::Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .rx - .poll_next(cx) - .map(|x| x.map(|x| Ok(x.to_vec()))) + self.project().rx.poll_next(cx).map(|x| x.map(|x| Ok(x.to_vec()))) } }