From 4f0a36239088ba7c498dd4b17b567bee3ddefec6 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Fri, 5 Jul 2024 16:03:55 -0700 Subject: [PATCH] massive speed improvements --- Cargo.lock | 31 +----- client/Cargo.toml | 2 +- client/src/io_stream.rs | 6 +- client/src/stream_provider.rs | 5 +- server/Cargo.toml | 5 +- server/src/main.rs | 125 ++++++++++++++++++------ wisp/Cargo.toml | 4 +- wisp/src/fastwebsockets.rs | 11 ++- wisp/src/packet.rs | 4 +- wisp/src/stream.rs | 178 ++++++++++++++++++++++++++++++---- 10 files changed, 282 insertions(+), 89 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cf52723..e467dfa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -143,17 +143,6 @@ dependencies = [ "syn", ] -[[package]] -name = "async_io_stream" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6d7b9decdf35d8908a7e3ef02f64c5e9b1695e230154c0e8de3969142d9b94c" -dependencies = [ - "futures", - "rustc_version", - "tokio", -] - [[package]] name = "atomic-counter" version = "1.0.1" @@ -559,6 +548,7 @@ name = "epoxy-server" version = "1.0.0" dependencies = [ "bytes", + "cfg-if", "clap", "clio", "console-subscriber", @@ -1556,15 +1546,6 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" -[[package]] -name = "rustc_version" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" -dependencies = [ - "semver", -] - [[package]] name = "rustix" version = "0.38.34" @@ -1671,12 +1652,6 @@ dependencies = [ "libc", ] -[[package]] -name = "semver" -version = "1.0.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" - [[package]] name = "send_wrapper" version = "0.4.0" @@ -1963,6 +1938,7 @@ checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" dependencies = [ "bytes", "futures-core", + "futures-io", "futures-sink", "pin-project-lite", "tokio", @@ -2482,10 +2458,9 @@ checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" [[package]] name = "wisp-mux" -version = "4.0.1" +version = "5.0.0" dependencies = [ "async-trait", - "async_io_stream", "bytes", "dashmap", "event-listener", diff --git a/client/Cargo.toml b/client/Cargo.toml index e483e48..55ea77e 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -32,7 +32,7 @@ wasm-bindgen = "0.2.92" wasm-bindgen-futures = "0.4.42" wasm-streams = "0.4.0" web-sys = { version = "0.3.69", features = ["BinaryType", "Headers", "MessageEvent", "Request", "RequestInit", "Response", "ResponseInit", "WebSocket"] } -wisp-mux = { version = "4.0.1", path = "../wisp", features = ["wasm"] } +wisp-mux = { path = "../wisp", features = ["wasm"] } [dependencies.ring] # update whenever rustls updates diff --git a/client/src/io_stream.rs b/client/src/io_stream.rs index 6b4b7c7..70c6a39 100644 --- a/client/src/io_stream.rs +++ b/client/src/io_stream.rs @@ -1,4 +1,4 @@ -use bytes::{buf::UninitSlice, BufMut, BytesMut}; +use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut}; use futures_util::{ io::WriteHalf, lock::Mutex, stream::SplitSink, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt, }; @@ -105,7 +105,7 @@ impl EpoxyIoStream { #[wasm_bindgen] pub struct EpoxyUdpStream { - tx: Mutex>>, + tx: Mutex>, onerror: Function, } @@ -154,7 +154,7 @@ impl EpoxyUdpStream { .map_err(|_| EpoxyError::InvalidPayload)? .0 .to_vec(); - Ok(self.tx.lock().await.send(payload).await?) + Ok(self.tx.lock().await.send(payload.into()).await?) } .await; diff --git a/client/src/stream_provider.rs b/client/src/stream_provider.rs index 7e6050b..837e107 100644 --- a/client/src/stream_provider.rs +++ b/client/src/stream_provider.rs @@ -17,8 +17,7 @@ use tower_service::Service; use wasm_bindgen::{JsCast, JsValue}; use wasm_bindgen_futures::spawn_local; use wisp_mux::{ - extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder}, - ClientMux, IoStream, MuxStreamIo, StreamType, + extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder}, ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType }; use crate::{ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError}; @@ -50,7 +49,7 @@ pub struct StreamProvider { } pub type ProviderUnencryptedStream = MuxStreamIo; -pub type ProviderUnencryptedAsyncRW = IoStream>; +pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW; pub type ProviderTlsAsyncRW = TlsStream; pub type ProviderAsyncRW = Either; diff --git a/server/Cargo.toml b/server/Cargo.toml index 1cecd66..e354d69 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] bytes = "1.5.0" +cfg-if = "1.0.0" clap = { version = "4.4.18", features = ["derive", "help", "usage", "color", "wrap_help", "cargo"] } clio = { version = "0.3.5", features = ["clap-parse"] } console-subscriber = { version = "0.2.0", optional = true } @@ -15,8 +16,8 @@ http-body-util = "0.1.0" hyper = { version = "1.1.0", features = ["server", "http1"] } hyper-util = { version = "0.1.2", features = ["tokio"] } tokio = { version = "1.5.1", features = ["rt-multi-thread", "macros"] } -tokio-util = { version = "0.7.10", features = ["codec"] } -wisp-mux = { path = "../wisp", features = ["fastwebsockets", "tokio_io"] } +tokio-util = { version = "0.7.10", features = ["codec", "compat"] } +wisp-mux = { path = "../wisp", features = ["fastwebsockets"] } [features] tokio-console = ["tokio/tracing", "dep:console-subscriber"] diff --git a/server/src/main.rs b/server/src/main.rs index be3a5b4..776f41b 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -2,6 +2,7 @@ use std::{collections::HashMap, io::Error, path::PathBuf, sync::Arc}; use bytes::Bytes; +use cfg_if::cfg_if; use clap::Parser; use fastwebsockets::{ upgrade::{self, UpgradeFut}, @@ -9,18 +10,23 @@ use fastwebsockets::{ }; use futures_util::{SinkExt, StreamExt, TryFutureExt}; use hyper::{ - body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode, + body::Incoming, server::conn::http1, service::service_fn, upgrade::Parts, Request, Response, + StatusCode, }; use hyper_util::rt::TokioIo; #[cfg(unix)] use tokio::net::{UnixListener, UnixStream}; use tokio::{ - io::copy_bidirectional, + io::{copy, AsyncBufReadExt, AsyncWriteExt}, net::{lookup_host, TcpListener, TcpStream, UdpSocket}, + select, }; -use tokio_util::codec::{BytesCodec, Framed}; #[cfg(unix)] use tokio_util::either::Either; +use tokio_util::{ + codec::{BytesCodec, Framed}, + compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt}, +}; use wisp_mux::{ extensions::{ @@ -28,7 +34,7 @@ use wisp_mux::{ udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder, }, - CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, + CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRW, ServerMux, StreamType, WispError, }; type HttpBody = http_body_util::Full; @@ -83,10 +89,13 @@ struct MuxOptions { pub wisp_v1: bool, } -#[cfg(not(unix))] -type ListenerStream = TcpStream; -#[cfg(unix)] -type ListenerStream = Either; +cfg_if! { + if #[cfg(unix)] { + type ListenerStream = Either; + } else { + type ListenerStream = TcpStream; + } +} enum Listener { Tcp(TcpListener), @@ -99,13 +108,12 @@ impl Listener { Ok(match self { Listener::Tcp(listener) => { let (stream, addr) = listener.accept().await?; - #[cfg(not(unix))] - { - (stream, addr.to_string()) - } - #[cfg(unix)] - { - (Either::Left(stream), addr.to_string()) + cfg_if! { + if #[cfg(unix)] { + (Either::Left(stream), addr.to_string()) + } else { + (stream, addr.to_string()) + } } } #[cfg(unix)] @@ -123,17 +131,20 @@ impl Listener { } async fn bind(addr: &str, unix: bool) -> Result { - #[cfg(unix)] - if unix { - if std::fs::metadata(addr).is_ok() { - println!("attempting to remove old socket {:?}", addr); - std::fs::remove_file(addr)?; + cfg_if! { + if #[cfg(unix)] { + if unix { + if std::fs::metadata(addr).is_ok() { + println!("attempting to remove old socket {:?}", addr); + std::fs::remove_file(addr)?; + } + return Ok(Listener::Unix(UnixListener::bind(addr)?)); + } + } else { + if unix { + panic!("Unix sockets are only supported on Unix."); + } } - return Ok(Listener::Unix(UnixListener::bind(addr)?)); - } - #[cfg(not(unix))] - if unix { - panic!("Unix sockets are only supported on Unix."); } Ok(Listener::Tcp(TcpListener::bind(addr).await?)) @@ -258,6 +269,38 @@ async fn accept_http( } } +async fn copy_buf(mux: MuxStreamAsyncRW, tcp: TcpStream) -> std::io::Result<()> { + let (muxrx, muxtx) = mux.into_split(); + let mut muxrx = muxrx.compat(); + let mut muxtx = muxtx.compat_write(); + + let (mut tcprx, mut tcptx) = tcp.into_split(); + + let fast_fut = async { + loop { + let buf = muxrx.fill_buf().await?; + if buf.is_empty() { + tcptx.flush().await?; + return Ok(()); + } + + let i = tcptx.write(buf).await?; + if i == 0 { + return Err(std::io::ErrorKind::WriteZero.into()); + } + + muxrx.consume(i); + } + }; + + let slow_fut = copy(&mut tcprx, &mut muxtx); + + select! { + x = fast_fut => x, + x = slow_fut => x.map(|_| ()), + } +} + async fn handle_mux( packet: ConnectPacket, stream: MuxStream, @@ -268,9 +311,9 @@ async fn handle_mux( ); match packet.stream_type { StreamType::Tcp => { - let mut tcp_stream = TcpStream::connect(uri).await?; - let mut mux_stream = stream.into_io().into_asyncrw(); - copy_bidirectional(&mut mux_stream, &mut tcp_stream).await?; + let tcp_stream = TcpStream::connect(uri).await?; + let mux = stream.into_io().into_asyncrw(); + copy_buf(mux, tcp_stream).await?; } StreamType::Udp => { let uri = lookup_host(uri) @@ -315,7 +358,31 @@ async fn accept_ws( // to prevent memory ""leaks"" because users are sending in packets way too fast the message // size is set to 1M ws.set_max_message_size(1024 * 1024); - let (rx, tx) = ws.split(tokio::io::split); + let (rx, tx) = ws.split(|x| { + let Parts { + io, read_buf: buf, .. + } = x + .into_inner() + .downcast::>() + .unwrap(); + assert_eq!(buf.len(), 0); + cfg_if! { + if #[cfg(unix)] { + match io.into_inner() { + Either::Left(x) => { + let (rx, tx) = x.into_split(); + (Either::Left(rx), Either::Left(tx)) + } + Either::Right(x) => { + let (rx, tx) = x.into_split(); + (Either::Right(rx), Either::Right(tx)) + } + } + } else { + io.into_inner().into_split() + } + } + }); let rx = FragmentCollectorRead::new(rx); println!("{:?}: connected", addr); diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index cfb28fe..0dc2a39 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "wisp-mux" -version = "4.0.1" +version = "5.0.0" license = "LGPL-3.0-only" description = "A library for easily creating Wisp servers and clients." homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp" @@ -10,7 +10,6 @@ edition = "2021" [dependencies] async-trait = "0.1.79" -async_io_stream = "0.3.3" bytes = "1.5.0" dashmap = { version = "5.5.3", features = ["inline"] } event-listener = "5.0.0" @@ -23,7 +22,6 @@ tokio = { version = "1.35.1", optional = true, default-features = false } [features] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] -tokio_io = ["async_io_stream/tokio_io"] wasm = ["futures-timer/wasm-bindgen"] [package.metadata.docs.rs] diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index f21cf75..2525662 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -9,6 +9,15 @@ use tokio::io::{AsyncRead, AsyncWrite}; use crate::{ws::LockedWebSocketWrite, WispError}; +fn match_payload(payload: Payload) -> BytesMut { + match payload { + Payload::Bytes(x) => x, + Payload::Owned(x) => BytesMut::from(x.deref()), + Payload::BorrowedMut(x) => BytesMut::from(x.deref()), + Payload::Borrowed(x) => BytesMut::from(x), + } +} + impl From for crate::ws::OpCode { fn from(opcode: OpCode) -> Self { use OpCode::*; @@ -30,7 +39,7 @@ impl From> for crate::ws::Frame { Self { finished: frame.fin, opcode: frame.opcode.into(), - payload: BytesMut::from(frame.payload.deref()), + payload: match_payload(frame.payload), } } } diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 919b443..138e0c5 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -240,7 +240,7 @@ impl Encode for InfoPacket { bytes.put_u8(self.version.major); bytes.put_u8(self.version.minor); for extension in self.extensions { - bytes.extend(Bytes::from(extension)); + bytes.extend_from_slice(&Bytes::from(extension)); } } } @@ -290,7 +290,7 @@ impl Encode for PacketType { use PacketType as P; match self { P::Connect(x) => x.encode(bytes), - P::Data(x) => bytes.extend(x), + P::Data(x) => bytes.extend_from_slice(&x), P::Continue(x) => x.encode(bytes), P::Close(x) => x.encode(bytes), P::Info(x) => x.encode(bytes), diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 1dc792b..b468e50 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -4,15 +4,15 @@ use crate::{ CloseReason, Packet, Role, StreamType, WispError, }; -pub use async_io_stream::IoStream; use bytes::{BufMut, Bytes, BytesMut}; use event_listener::Event; use flume as mpsc; use futures::{ channel::oneshot, - select, stream, + select, + stream::{self, IntoAsyncRead, SplitSink, SplitStream}, task::{Context, Poll}, - FutureExt, Sink, Stream, + AsyncBufRead, AsyncRead, AsyncWrite, FutureExt, Sink, Stream, StreamExt, TryStreamExt, }; use pin_project_lite::pin_project; use std::{ @@ -21,6 +21,7 @@ use std::{ atomic::{AtomicBool, AtomicU32, Ordering}, Arc, }, + task::ready, }; pub(crate) enum WsEvent { @@ -367,26 +368,24 @@ pin_project! { } impl MuxStreamIo { - /// Turn the stream into one that implements futures `AsyncRead + AsyncWrite`. - /// - /// Enable the `tokio_io` feature to implement the tokio version of `AsyncRead` and - /// `AsyncWrite`. - pub fn into_asyncrw(self) -> IoStream> { - IoStream::new(self) + /// Turn the stream into one that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`. + pub fn into_asyncrw(self) -> MuxStreamAsyncRW { + let (tx, rx) = self.split(); + MuxStreamAsyncRW { + rx: MuxStreamAsyncRead::new(rx), + tx: MuxStreamAsyncWrite::new(tx), + } } } impl Stream for MuxStreamIo { - type Item = Result, std::io::Error>; + type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .rx - .poll_next(cx) - .map(|x| x.map(|x| Ok(x.to_vec()))) + self.project().rx.poll_next(cx).map(|x| x.map(Ok)) } } -impl Sink> for MuxStreamIo { +impl Sink for MuxStreamIo { type Error = std::io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project() @@ -394,10 +393,10 @@ impl Sink> for MuxStreamIo { .poll_ready(cx) .map_err(std::io::Error::other) } - fn start_send(self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { self.project() .tx - .start_send(item.into()) + .start_send(item) .map_err(std::io::Error::other) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -413,3 +412,148 @@ impl Sink> for MuxStreamIo { .map_err(std::io::Error::other) } } + +pin_project! { + /// Multiplexor stream that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`. + pub struct MuxStreamAsyncRW { + #[pin] + rx: MuxStreamAsyncRead, + #[pin] + tx: MuxStreamAsyncWrite, + } +} + +impl MuxStreamAsyncRW { + /// Split the stream into read and write parts, consuming it. + pub fn into_split(self) -> (MuxStreamAsyncRead, MuxStreamAsyncWrite) { + (self.rx, self.tx) + } +} + +impl AsyncRead for MuxStreamAsyncRW { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.project().rx.poll_read(cx, buf) + } + + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [std::io::IoSliceMut<'_>], + ) -> Poll> { + self.project().rx.poll_read_vectored(cx, bufs) + } +} + +impl AsyncBufRead for MuxStreamAsyncRW { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().rx.poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.project().rx.consume(amt) + } +} + +impl AsyncWrite for MuxStreamAsyncRW { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().tx.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_close(cx) + } +} + +pin_project! { + /// Read side of a multiplexor stream that implements futures `AsyncRead + AsyncBufRead`. + pub struct MuxStreamAsyncRead { + #[pin] + rx: IntoAsyncRead>, + } +} + +impl MuxStreamAsyncRead { + pub(crate) fn new(stream: SplitStream) -> Self { + Self { + rx: stream.into_async_read(), + } + } +} + +impl AsyncRead for MuxStreamAsyncRead { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.project().rx.poll_read(cx, buf) + } + + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [std::io::IoSliceMut<'_>], + ) -> Poll> { + self.project().rx.poll_read_vectored(cx, bufs) + } +} + +impl AsyncBufRead for MuxStreamAsyncRead { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().rx.poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.project().rx.consume(amt) + } +} + +pin_project! { + /// Write side of a multiplexor stream that implements futures `AsyncWrite`. + pub struct MuxStreamAsyncWrite { + #[pin] + tx: SplitSink, + } +} + +impl MuxStreamAsyncWrite { + pub(crate) fn new(sink: SplitSink) -> Self { + Self { tx: sink } + } +} + +impl AsyncWrite for MuxStreamAsyncWrite { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let mut this = self.project(); + + ready!(this.tx.as_mut().poll_ready(cx))?; + match this.tx.start_send(Bytes::copy_from_slice(buf)) { + Ok(()) => Poll::Ready(Ok(buf.len())), + Err(e) => Poll::Ready(Err(e)), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_close(cx) + } +}