diff --git a/Cargo.lock b/Cargo.lock index 2879ef1..0be46bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -143,6 +143,16 @@ dependencies = [ "syn", ] +[[package]] +name = "async_io_stream" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d7b9decdf35d8908a7e3ef02f64c5e9b1695e230154c0e8de3969142d9b94c" +dependencies = [ + "futures", + "rustc_version", +] + [[package]] name = "atomic-counter" version = "1.0.1" @@ -174,7 +184,7 @@ dependencies = [ "futures-util", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.30", "itoa", "matchit", "memchr", @@ -295,9 +305,9 @@ checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" [[package]] name = "cc" -version = "1.0.106" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "066fce287b1d4eafef758e89e09d724a24808a9196fe9756b8ca90e86d0719a2" +checksum = "eaff6f8ce506b9773fa786672d63fc7a191ffea1be33f72bbd4aeacefca9ffc8" [[package]] name = "certs-grabber" @@ -315,9 +325,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.5.8" +version = "4.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84b3edb18336f4df585bc9aa31dd99c036dfa5dc5e9a2939a722a188f3a8970d" +checksum = "64acc1846d54c1fe936a78dc189c34e28d3f5afc348403f28ecf53660b9b8462" dependencies = [ "clap_builder", "clap_derive", @@ -325,9 +335,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.8" +version = "4.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1c09dd5ada6c6c78075d6fd0da3f90d8080651e2d6cc8eb2f1aaa4034ced708" +checksum = "6fb8393d67ba2e7bfaf28a23458e4e2b543cc73a99595511eb207fdb8aede942" dependencies = [ "anstream", "anstyle", @@ -511,7 +521,7 @@ checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "epoxy-client" -version = "2.0.5" +version = "2.0.6" dependencies = [ "async-compression", "async-trait", @@ -526,7 +536,7 @@ dependencies = [ "getrandom", "http 1.1.0", "http-body-util", - "hyper 1.4.0", + "hyper 1.4.1", "hyper-util-wasm", "js-sys", "parking_lot_core", @@ -557,7 +567,7 @@ dependencies = [ "fastwebsockets", "futures-util", "http-body-util", - "hyper 1.4.0", + "hyper 1.4.1", "hyper-util", "tokio", "tokio-util", @@ -599,14 +609,14 @@ checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" [[package]] name = "fastwebsockets" -version = "0.7.2" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93da8b19e29f202ef35ddd20ddea8c86166850fce5ba2a3c3f3e3174cdbb0620" +checksum = "26da0c7b5cef45c521a6f9cdfffdfeb6c9f5804fbac332deb5ae254634c7a6be" dependencies = [ "base64 0.21.7", "bytes", "http-body-util", - "hyper 1.4.0", + "hyper 1.4.1", "hyper-util", "pin-project", "rand", @@ -962,9 +972,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.29" +version = "0.14.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f361cde2f109281a220d4307746cdfd5ee3f410da58a70377762396775634b33" +checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" dependencies = [ "bytes", "futures-channel", @@ -986,9 +996,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4fe55fb7a772d59a5ff1dfbff4fe0258d19b89fec4b233e75d35d5d2316badc" +checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" dependencies = [ "bytes", "futures-channel", @@ -1011,7 +1021,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" dependencies = [ - "hyper 0.14.29", + "hyper 0.14.30", "pin-project-lite", "tokio", "tokio-io-timeout", @@ -1027,7 +1037,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.0", - "hyper 1.4.0", + "hyper 1.4.1", "pin-project-lite", "tokio", ] @@ -1043,7 +1053,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.0", - "hyper 1.4.0", + "hyper 1.4.1", "pin-project-lite", "tower", "tower-service", @@ -1547,6 +1557,15 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.38.34" @@ -1653,6 +1672,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" + [[package]] name = "send_wrapper" version = "0.4.0" @@ -1737,7 +1762,7 @@ dependencies = [ "futures", "http-body-util", "humantime", - "hyper 1.4.0", + "hyper 1.4.1", "simple_moving_average", "tokio", "tokio-native-tls", @@ -1959,7 +1984,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.30", "hyper-timeout", "percent-encoding", "pin-project", @@ -2462,6 +2487,7 @@ name = "wisp-mux" version = "5.0.0" dependencies = [ "async-trait", + "async_io_stream", "bytes", "dashmap", "event-listener", diff --git a/client/Cargo.toml b/client/Cargo.toml index c46fbc3..b2e5d1b 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "epoxy-client" -version = "2.0.5" +version = "2.0.6" edition = "2021" [lib] @@ -13,7 +13,7 @@ base64 = { version = "0.22.1", optional = true } bytes = "1.6.0" cfg-if = "1.0.0" event-listener = "5.3.1" -fastwebsockets = { version = "0.7.2", features = ["unstable-split"], optional = true } +fastwebsockets = { version = "0.8.0", features = ["unstable-split"], optional = true } flume = "0.11.0" futures-rustls = { version = "0.26.0", default-features = false, features = ["tls12", "ring"] } futures-util = { version = "0.3.30", features = ["sink"] } diff --git a/client/demo.js b/client/demo.js index 6212777..56961c2 100644 --- a/client/demo.js +++ b/client/demo.js @@ -196,9 +196,11 @@ onmessage = async (msg) => { [], { "x-header": "abc" }, ); + let i = 0; while (true) { - log("sending `data`"); - await ws.send("data"); + log(`sending \`data${i}\``); + await ws.send("data"+i); + i++; await (new Promise((res, _) => setTimeout(res, 10))); } } else if (should_tls_test) { diff --git a/client/package.json b/client/package.json index 706456f..6e74d8b 100644 --- a/client/package.json +++ b/client/package.json @@ -1,6 +1,6 @@ { "name": "@mercuryworkshop/epoxy-tls", - "version": "2.0.5-1", + "version": "2.0.6-1", "description": "A wasm library for using raw encrypted tls/ssl/https/websocket streams on the browser", "scripts": { "build": "./build.sh" diff --git a/client/src/stream_provider.rs b/client/src/stream_provider.rs index 837e107..a75821c 100644 --- a/client/src/stream_provider.rs +++ b/client/src/stream_provider.rs @@ -1,13 +1,12 @@ use std::{pin::Pin, sync::Arc, task::Poll}; +use bytes::Bytes; use futures_rustls::{ - rustls::{ClientConfig, RootCertStore}, - TlsConnector, TlsStream, + rustls::{ClientConfig, RootCertStore}, + TlsConnector, TlsStream, }; use futures_util::{ - future::Either, - lock::{Mutex, MutexGuard}, - AsyncRead, AsyncWrite, Future, + future::Either, lock::{Mutex, MutexGuard}, AsyncRead, AsyncWrite, Future }; use hyper_util_wasm::client::legacy::connect::{Connected, Connection}; use js_sys::{Array, Reflect, Uint8Array}; @@ -17,246 +16,247 @@ use tower_service::Service; use wasm_bindgen::{JsCast, JsValue}; use wasm_bindgen_futures::spawn_local; use wisp_mux::{ - extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder}, ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType + extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder}, + ClientMux, IoStream, MuxStreamIo, StreamType, }; use crate::{ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError}; fn object_to_trustanchor(obj: JsValue) -> Result, JsValue> { - let subject: Uint8Array = Reflect::get(&obj, &"subject".into())?.dyn_into()?; - let pub_key_info: Uint8Array = - Reflect::get(&obj, &"subject_public_key_info".into())?.dyn_into()?; - let name_constraints: Option = Reflect::get(&obj, &"name_constraints".into()) - .and_then(|x| x.dyn_into()) - .ok(); - Ok(TrustAnchor { - subject: Der::from(subject.to_vec()), - subject_public_key_info: Der::from(pub_key_info.to_vec()), - name_constraints: name_constraints.map(|x| Der::from(x.to_vec())), - }) + let subject: Uint8Array = Reflect::get(&obj, &"subject".into())?.dyn_into()?; + let pub_key_info: Uint8Array = + Reflect::get(&obj, &"subject_public_key_info".into())?.dyn_into()?; + let name_constraints: Option = Reflect::get(&obj, &"name_constraints".into()) + .and_then(|x| x.dyn_into()) + .ok(); + Ok(TrustAnchor { + subject: Der::from(subject.to_vec()), + subject_public_key_info: Der::from(pub_key_info.to_vec()), + name_constraints: name_constraints.map(|x| Der::from(x.to_vec())), + }) } pub struct StreamProvider { - wisp_url: String, + wisp_url: String, - wisp_v2: bool, - udp_extension: bool, - websocket_protocols: Vec, + wisp_v2: bool, + udp_extension: bool, + websocket_protocols: Vec, - client_config: Arc, + client_config: Arc, - current_client: Arc>>, + current_client: Arc>>, } pub type ProviderUnencryptedStream = MuxStreamIo; -pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW; +pub type ProviderUnencryptedAsyncRW = IoStream; pub type ProviderTlsAsyncRW = TlsStream; pub type ProviderAsyncRW = Either; impl StreamProvider { - pub fn new( - wisp_url: String, - certs: Array, - options: &EpoxyClientOptions, - ) -> Result { - let certs: Result, JsValue> = - certs.iter().map(object_to_trustanchor).collect(); - let certstore = RootCertStore::from_iter(certs.map_err(|_| EpoxyError::InvalidCertStore)?); - let client_config = Arc::new( - ClientConfig::builder() - .with_root_certificates(certstore) - .with_no_client_auth(), - ); + pub fn new( + wisp_url: String, + certs: Array, + options: &EpoxyClientOptions, + ) -> Result { + let certs: Result, JsValue> = + certs.iter().map(object_to_trustanchor).collect(); + let certstore = RootCertStore::from_iter(certs.map_err(|_| EpoxyError::InvalidCertStore)?); + let client_config = Arc::new( + ClientConfig::builder() + .with_root_certificates(certstore) + .with_no_client_auth(), + ); - Ok(Self { - wisp_url, - current_client: Arc::new(Mutex::new(None)), - wisp_v2: options.wisp_v2, - udp_extension: options.udp_extension_required, - websocket_protocols: options.websocket_protocols.clone(), - client_config, - }) - } + Ok(Self { + wisp_url, + current_client: Arc::new(Mutex::new(None)), + wisp_v2: options.wisp_v2, + udp_extension: options.udp_extension_required, + websocket_protocols: options.websocket_protocols.clone(), + client_config, + }) + } - async fn create_client( - &self, - mut locked: MutexGuard<'_, Option>, - ) -> Result<(), EpoxyError> { - let extensions_vec: Vec> = - vec![Box::new(UdpProtocolExtensionBuilder())]; - let extensions = if self.wisp_v2 { - Some(extensions_vec.as_slice()) - } else { - None - }; - let (write, read) = WebSocketWrapper::connect(&self.wisp_url, &self.websocket_protocols)?; - if !write.wait_for_open().await { - return Err(EpoxyError::WebSocketConnectFailed); - } - let client = ClientMux::create(read, write, extensions).await?; - let (mux, fut) = if self.udp_extension { - client.with_udp_extension_required().await? - } else { - client.with_no_required_extensions() - }; - locked.replace(mux); - let current_client = self.current_client.clone(); - spawn_local(async move { - fut.await; - current_client.lock().await.take(); - }); - Ok(()) - } + async fn create_client( + &self, + mut locked: MutexGuard<'_, Option>, + ) -> Result<(), EpoxyError> { + let extensions_vec: Vec> = + vec![Box::new(UdpProtocolExtensionBuilder())]; + let extensions = if self.wisp_v2 { + Some(extensions_vec.as_slice()) + } else { + None + }; + let (write, read) = WebSocketWrapper::connect(&self.wisp_url, &self.websocket_protocols)?; + if !write.wait_for_open().await { + return Err(EpoxyError::WebSocketConnectFailed); + } + let client = ClientMux::create(read, write, extensions).await?; + let (mux, fut) = if self.udp_extension { + client.with_udp_extension_required().await? + } else { + client.with_no_required_extensions() + }; + locked.replace(mux); + let current_client = self.current_client.clone(); + spawn_local(async move { + fut.await; + current_client.lock().await.take(); + }); + Ok(()) + } - pub async fn replace_client(&self) -> Result<(), EpoxyError> { - self.create_client(self.current_client.lock().await).await - } + pub async fn replace_client(&self) -> Result<(), EpoxyError> { + self.create_client(self.current_client.lock().await).await + } - pub async fn get_stream( - &self, - stream_type: StreamType, - host: String, - port: u16, - ) -> Result { - Box::pin(async { - let locked = self.current_client.lock().await; - if let Some(mux) = locked.as_ref() { - Ok(mux - .client_new_stream(stream_type, host, port) - .await? - .into_io()) - } else { - self.create_client(locked).await?; - self.get_stream(stream_type, host, port).await - } - }) - .await - } + pub async fn get_stream( + &self, + stream_type: StreamType, + host: String, + port: u16, + ) -> Result { + Box::pin(async { + let locked = self.current_client.lock().await; + if let Some(mux) = locked.as_ref() { + Ok(mux + .client_new_stream(stream_type, host, port) + .await? + .into_io()) + } else { + self.create_client(locked).await?; + self.get_stream(stream_type, host, port).await + } + }) + .await + } - pub async fn get_asyncread( - &self, - stream_type: StreamType, - host: String, - port: u16, - ) -> Result { - Ok(self - .get_stream(stream_type, host, port) - .await? - .into_asyncrw()) - } + pub async fn get_asyncread( + &self, + stream_type: StreamType, + host: String, + port: u16, + ) -> Result { + Ok(self + .get_stream(stream_type, host, port) + .await? + .into_asyncrw()) + } - pub async fn get_tls_stream( - &self, - host: String, - port: u16, - ) -> Result { - let stream = self - .get_asyncread(StreamType::Tcp, host.clone(), port) - .await?; - let connector = TlsConnector::from(self.client_config.clone()); - Ok(connector.connect(host.try_into()?, stream).await?.into()) - } + pub async fn get_tls_stream( + &self, + host: String, + port: u16, + ) -> Result { + let stream = self + .get_asyncread(StreamType::Tcp, host.clone(), port) + .await?; + let connector = TlsConnector::from(self.client_config.clone()); + Ok(connector.connect(host.try_into()?, stream).await?.into()) + } } pin_project! { - pub struct HyperIo { - #[pin] - inner: ProviderAsyncRW, - } + pub struct HyperIo { + #[pin] + inner: ProviderAsyncRW, + } } impl hyper::rt::Read for HyperIo { - fn poll_read( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - mut buf: hyper::rt::ReadBufCursor<'_>, - ) -> Poll> { - let buf_slice: &mut [u8] = unsafe { std::mem::transmute(buf.as_mut()) }; - match self.project().inner.poll_read(cx, buf_slice) { - Poll::Ready(bytes_read) => { - let bytes_read = bytes_read?; - unsafe { - buf.advance(bytes_read); - } - Poll::Ready(Ok(())) - } - Poll::Pending => Poll::Pending, - } - } + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + let buf_slice: &mut [u8] = unsafe { std::mem::transmute(buf.as_mut()) }; + match self.project().inner.poll_read(cx, buf_slice) { + Poll::Ready(bytes_read) => { + let bytes_read = bytes_read?; + unsafe { + buf.advance(bytes_read); + } + Poll::Ready(Ok(())) + } + Poll::Pending => Poll::Pending, + } + } } impl hyper::rt::Write for HyperIo { - fn poll_write( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { - self.project().inner.poll_write(cx, buf) - } + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().inner.poll_write(cx, buf) + } - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.project().inner.poll_flush(cx) - } + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().inner.poll_flush(cx) + } - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.project().inner.poll_close(cx) - } + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().inner.poll_close(cx) + } - fn poll_write_vectored( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - self.project().inner.poll_write_vectored(cx, bufs) - } + fn poll_write_vectored( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + self.project().inner.poll_write_vectored(cx, bufs) + } } impl Connection for HyperIo { - fn connected(&self) -> Connected { - Connected::new() - } + fn connected(&self) -> Connected { + Connected::new() + } } #[derive(Clone)] pub struct StreamProviderService(pub Arc); impl Service for StreamProviderService { - type Response = HyperIo; - type Error = EpoxyError; - type Future = Pin>>>; + type Response = HyperIo; + type Error = EpoxyError; + type Future = Pin>>>; - fn poll_ready( - &mut self, - _: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - Poll::Ready(Ok(())) - } + fn poll_ready( + &mut self, + _: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Poll::Ready(Ok(())) + } - fn call(&mut self, req: hyper::Uri) -> Self::Future { - let provider = self.0.clone(); - Box::pin(async move { - let scheme = req.scheme_str().ok_or(EpoxyError::InvalidUrlScheme)?; - let host = req.host().ok_or(EpoxyError::NoUrlHost)?.to_string(); - let port = req.port_u16().map(Ok).unwrap_or_else(|| match scheme { - "https" | "wss" => Ok(443), - "http" | "ws" => Ok(80), - _ => Err(EpoxyError::NoUrlPort), - })?; - Ok(HyperIo { - inner: match scheme { - "https" | "wss" => Either::Left(provider.get_tls_stream(host, port).await?), - "http" | "ws" => { - Either::Right(provider.get_asyncread(StreamType::Tcp, host, port).await?) - } - _ => return Err(EpoxyError::InvalidUrlScheme), - }, - }) - }) - } + fn call(&mut self, req: hyper::Uri) -> Self::Future { + let provider = self.0.clone(); + Box::pin(async move { + let scheme = req.scheme_str().ok_or(EpoxyError::InvalidUrlScheme)?; + let host = req.host().ok_or(EpoxyError::NoUrlHost)?.to_string(); + let port = req.port_u16().map(Ok).unwrap_or_else(|| match scheme { + "https" | "wss" => Ok(443), + "http" | "ws" => Ok(80), + _ => Err(EpoxyError::NoUrlPort), + })?; + Ok(HyperIo { + inner: match scheme { + "https" | "wss" => Either::Left(provider.get_tls_stream(host, port).await?), + "http" | "ws" => { + Either::Right(provider.get_asyncread(StreamType::Tcp, host, port).await?) + } + _ => return Err(EpoxyError::InvalidUrlScheme), + }, + }) + }) + } } diff --git a/client/src/utils.rs b/client/src/utils.rs index f3dbb77..90a4e61 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -9,11 +9,24 @@ use http::{HeaderValue, Uri}; use hyper::{body::Body, rt::Executor}; use js_sys::{Array, ArrayBuffer, Object, Reflect, Uint8Array}; use pin_project_lite::pin_project; -use wasm_bindgen::{JsCast, JsValue}; +use wasm_bindgen::{prelude::*, JsCast, JsValue}; use wasm_bindgen_futures::JsFuture; use crate::EpoxyError; +#[wasm_bindgen] +extern "C" { + #[wasm_bindgen(js_namespace = console, js_name = log)] + pub fn js_console_log(s: &str); +} + +#[macro_export] +macro_rules! console_log { + ($($expr:expr),*) => { + $crate::utils::js_console_log(&format!($($expr),*)); + }; +} + pub trait UriExt { fn get_redirect(&self, location: &HeaderValue) -> Result; } diff --git a/client/src/websocket.rs b/client/src/websocket.rs index 06e5dff..296ee55 100644 --- a/client/src/websocket.rs +++ b/client/src/websocket.rs @@ -118,7 +118,6 @@ impl EpoxyWebSocket { ); } OpCode::Close => { - let _ = onclose.call0(&JsValue::null()); break; } // ping/pong/continue diff --git a/server/Cargo.toml b/server/Cargo.toml index e354d69..d09c152 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -10,7 +10,7 @@ clap = { version = "4.4.18", features = ["derive", "help", "usage", "color", "wr clio = { version = "0.3.5", features = ["clap-parse"] } console-subscriber = { version = "0.2.0", optional = true } dashmap = "5.5.3" -fastwebsockets = { version = "0.7.1", features = ["upgrade", "simdutf8", "unstable-split"] } +fastwebsockets = { version = "0.8.0", features = ["upgrade", "simdutf8", "unstable-split"] } futures-util = { version = "0.3.30", features = ["sink"] } http-body-util = "0.1.0" hyper = { version = "1.1.0", features = ["server", "http1"] } diff --git a/server/src/main.rs b/server/src/main.rs index 776f41b..ceedff0 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -17,7 +17,7 @@ use hyper_util::rt::TokioIo; #[cfg(unix)] use tokio::net::{UnixListener, UnixStream}; use tokio::{ - io::{copy, AsyncBufReadExt, AsyncWriteExt}, + io::{copy, copy_bidirectional, AsyncBufReadExt, AsyncWriteExt}, net::{lookup_host, TcpListener, TcpStream, UdpSocket}, select, }; @@ -34,7 +34,7 @@ use wisp_mux::{ udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder, }, - CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRW, ServerMux, StreamType, WispError, + CloseReason, ConnectPacket, MuxStream, IoStream, ServerMux, StreamType, WispError, }; type HttpBody = http_body_util::Full; @@ -269,6 +269,8 @@ async fn accept_http( } } +// re-enable once MuxStreamAsyncRW is fixed +/* async fn copy_buf(mux: MuxStreamAsyncRW, tcp: TcpStream) -> std::io::Result<()> { let (muxrx, muxtx) = mux.into_split(); let mut muxrx = muxrx.compat(); @@ -300,6 +302,7 @@ async fn copy_buf(mux: MuxStreamAsyncRW, tcp: TcpStream) -> std::io::Result<()> x = slow_fut => x.map(|_| ()), } } +*/ async fn handle_mux( packet: ConnectPacket, @@ -311,9 +314,9 @@ async fn handle_mux( ); match packet.stream_type { StreamType::Tcp => { - let tcp_stream = TcpStream::connect(uri).await?; - let mux = stream.into_io().into_asyncrw(); - copy_buf(mux, tcp_stream).await?; + let mut tcp_stream = TcpStream::connect(uri).await?; + let mut mux = stream.into_io().into_asyncrw().compat(); + copy_bidirectional(&mut mux, &mut tcp_stream).await?; } StreamType::Udp => { let uri = lookup_host(uri) diff --git a/simple-wisp-client/Cargo.toml b/simple-wisp-client/Cargo.toml index dc09cde..13041a4 100644 --- a/simple-wisp-client/Cargo.toml +++ b/simple-wisp-client/Cargo.toml @@ -8,7 +8,7 @@ atomic-counter = "1.0.1" bytes = "1.5.0" clap = { version = "4.5.4", features = ["cargo", "derive"] } console-subscriber = { version = "0.2.0", optional = true } -fastwebsockets = { version = "0.7.1", features = ["unstable-split", "upgrade"] } +fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] } futures = "0.3.30" http-body-util = "0.1.0" humantime = "2.1.0" diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 0dc2a39..1578b1f 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -10,10 +10,11 @@ edition = "2021" [dependencies] async-trait = "0.1.79" +async_io_stream = "0.3.3" bytes = "1.5.0" dashmap = { version = "5.5.3", features = ["inline"] } event-listener = "5.0.0" -fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true } +fastwebsockets = { version = "0.8.0", features = ["unstable-split"], optional = true } flume = "0.11.0" futures = "0.3.30" futures-timer = "3.0.3" diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index e978029..90a3f35 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -14,6 +14,7 @@ mod stream; pub mod ws; pub use crate::{packet::*, stream::*}; +pub use async_io_stream::IoStream; use bytes::Bytes; use dashmap::DashMap; diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index b468e50..ff65ade 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -1,559 +1,457 @@ use crate::{ - sink_unfold, - ws::{Frame, LockedWebSocketWrite}, - CloseReason, Packet, Role, StreamType, WispError, + sink_unfold, + ws::{Frame, LockedWebSocketWrite}, + CloseReason, Packet, Role, StreamType, WispError, }; +use async_io_stream::IoStream; use bytes::{BufMut, Bytes, BytesMut}; use event_listener::Event; use flume as mpsc; use futures::{ - channel::oneshot, - select, - stream::{self, IntoAsyncRead, SplitSink, SplitStream}, - task::{Context, Poll}, - AsyncBufRead, AsyncRead, AsyncWrite, FutureExt, Sink, Stream, StreamExt, TryStreamExt, + channel::oneshot, + select, stream, + task::{Context, Poll}, + FutureExt, Sink, Stream, }; use pin_project_lite::pin_project; use std::{ - pin::Pin, - sync::{ - atomic::{AtomicBool, AtomicU32, Ordering}, - Arc, - }, - task::ready, + pin::Pin, + sync::{ + atomic::{AtomicBool, AtomicU32, Ordering}, + Arc, + }, }; pub(crate) enum WsEvent { - Close(Packet, oneshot::Sender>), - CreateStream( - StreamType, - String, - u16, - oneshot::Sender>, - ), - EndFut(Option), + Close(Packet, oneshot::Sender>), + CreateStream( + StreamType, + String, + u16, + oneshot::Sender>, + ), + EndFut(Option), } /// Read side of a multiplexor stream. pub struct MuxStreamRead { - /// ID of the stream. - pub stream_id: u32, - /// Type of the stream. - pub stream_type: StreamType, - role: Role, - tx: LockedWebSocketWrite, - rx: mpsc::Receiver, - is_closed: Arc, - is_closed_event: Arc, - flow_control: Arc, - flow_control_read: AtomicU32, - target_flow_control: u32, + /// ID of the stream. + pub stream_id: u32, + /// Type of the stream. + pub stream_type: StreamType, + role: Role, + tx: LockedWebSocketWrite, + rx: mpsc::Receiver, + is_closed: Arc, + is_closed_event: Arc, + flow_control: Arc, + flow_control_read: AtomicU32, + target_flow_control: u32, } impl MuxStreamRead { - /// Read an event from the stream. - pub async fn read(&self) -> Option { - if self.is_closed.load(Ordering::Acquire) { - return None; - } - let bytes = select! { - x = self.rx.recv_async() => x.ok()?, - _ = self.is_closed_event.listen().fuse() => return None - }; - if self.role == Role::Server && self.stream_type == StreamType::Tcp { - let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1; - if val > self.target_flow_control && !self.is_closed.load(Ordering::Acquire) { - self.tx - .write_frame( - Packet::new_continue( - self.stream_id, - self.flow_control.fetch_add(val, Ordering::AcqRel) + val, - ) - .into(), - ) - .await - .ok()?; - self.flow_control_read.store(0, Ordering::Release); - } - } - Some(bytes) - } + /// Read an event from the stream. + pub async fn read(&self) -> Option { + if self.is_closed.load(Ordering::Acquire) { + return None; + } + let bytes = select! { + x = self.rx.recv_async() => x.ok()?, + _ = self.is_closed_event.listen().fuse() => return None + }; + if self.role == Role::Server && self.stream_type == StreamType::Tcp { + let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1; + if val > self.target_flow_control && !self.is_closed.load(Ordering::Acquire) { + self.tx + .write_frame( + Packet::new_continue( + self.stream_id, + self.flow_control.fetch_add(val, Ordering::AcqRel) + val, + ) + .into(), + ) + .await + .ok()?; + self.flow_control_read.store(0, Ordering::Release); + } + } + Some(bytes) + } - pub(crate) fn into_stream(self) -> Pin + Send>> { - Box::pin(stream::unfold(self, |rx| async move { - Some((rx.read().await?, rx)) - })) - } + pub(crate) fn into_stream(self) -> Pin + Send>> { + Box::pin(stream::unfold(self, |rx| async move { + Some((rx.read().await?, rx)) + })) + } } /// Write side of a multiplexor stream. pub struct MuxStreamWrite { - /// ID of the stream. - pub stream_id: u32, - /// Type of the stream. - pub stream_type: StreamType, - role: Role, - mux_tx: mpsc::Sender, - tx: LockedWebSocketWrite, - is_closed: Arc, - continue_recieved: Arc, - flow_control: Arc, + /// ID of the stream. + pub stream_id: u32, + /// Type of the stream. + pub stream_type: StreamType, + role: Role, + mux_tx: mpsc::Sender, + tx: LockedWebSocketWrite, + is_closed: Arc, + continue_recieved: Arc, + flow_control: Arc, } impl MuxStreamWrite { - /// Write data to the stream. - pub async fn write(&self, data: Bytes) -> Result<(), WispError> { - if self.role == Role::Client - && self.stream_type == StreamType::Tcp - && self.flow_control.load(Ordering::Acquire) == 0 - { - self.continue_recieved.listen().await; - } - if self.is_closed.load(Ordering::Acquire) { - return Err(WispError::StreamAlreadyClosed); - } + /// Write data to the stream. + pub async fn write(&self, data: Bytes) -> Result<(), WispError> { + if self.role == Role::Client + && self.stream_type == StreamType::Tcp + && self.flow_control.load(Ordering::Acquire) == 0 + { + self.continue_recieved.listen().await; + } + if self.is_closed.load(Ordering::Acquire) { + return Err(WispError::StreamAlreadyClosed); + } - self.tx - .write_frame(Frame::from(Packet::new_data(self.stream_id, data))) - .await?; + self.tx + .write_frame(Frame::from(Packet::new_data(self.stream_id, data))) + .await?; - if self.role == Role::Client && self.stream_type == StreamType::Tcp { - self.flow_control.store( - self.flow_control.load(Ordering::Acquire).saturating_sub(1), - Ordering::Release, - ); - } - Ok(()) - } + if self.role == Role::Client && self.stream_type == StreamType::Tcp { + self.flow_control.store( + self.flow_control.load(Ordering::Acquire).saturating_sub(1), + Ordering::Release, + ); + } + Ok(()) + } - /// Get a handle to close the connection. - /// - /// Useful to close the connection without having access to the stream. - /// - /// # Example - /// ``` - /// let handle = stream.get_close_handle(); - /// if let Err(error) = handle_stream(stream) { - /// handle.close(0x01); - /// } - /// ``` - pub fn get_close_handle(&self) -> MuxStreamCloser { - MuxStreamCloser { - stream_id: self.stream_id, - close_channel: self.mux_tx.clone(), - is_closed: self.is_closed.clone(), - } - } + /// Get a handle to close the connection. + /// + /// Useful to close the connection without having access to the stream. + /// + /// # Example + /// ``` + /// let handle = stream.get_close_handle(); + /// if let Err(error) = handle_stream(stream) { + /// handle.close(0x01); + /// } + /// ``` + pub fn get_close_handle(&self) -> MuxStreamCloser { + MuxStreamCloser { + stream_id: self.stream_id, + close_channel: self.mux_tx.clone(), + is_closed: self.is_closed.clone(), + } + } - /// Get a protocol extension stream to send protocol extension packets. - pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { - MuxProtocolExtensionStream { - stream_id: self.stream_id, - tx: self.tx.clone(), - is_closed: self.is_closed.clone(), - } - } + /// Get a protocol extension stream to send protocol extension packets. + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + MuxProtocolExtensionStream { + stream_id: self.stream_id, + tx: self.tx.clone(), + is_closed: self.is_closed.clone(), + } + } - /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { - if self.is_closed.load(Ordering::Acquire) { - return Err(WispError::StreamAlreadyClosed); - } - self.is_closed.store(true, Ordering::Release); + /// Close the stream. You will no longer be able to write or read after this has been called. + pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(WispError::StreamAlreadyClosed); + } + self.is_closed.store(true, Ordering::Release); - let (tx, rx) = oneshot::channel::>(); - self.mux_tx - .send_async(WsEvent::Close( - Packet::new_close(self.stream_id, reason), - tx, - )) - .await - .map_err(|_| WispError::MuxMessageFailedToSend)?; - rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; + let (tx, rx) = oneshot::channel::>(); + self.mux_tx + .send_async(WsEvent::Close( + Packet::new_close(self.stream_id, reason), + tx, + )) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; - Ok(()) - } + Ok(()) + } - pub(crate) fn into_sink(self) -> Pin + Send>> { - let handle = self.get_close_handle(); - Box::pin(sink_unfold::unfold( - self, - |tx, data| async move { - tx.write(data).await?; - Ok(tx) - }, - handle, - move |handle| async { - handle.close(CloseReason::Unknown).await?; - Ok(handle) - }, - )) - } + pub(crate) fn into_sink(self) -> Pin + Send>> { + let handle = self.get_close_handle(); + Box::pin(sink_unfold::unfold( + self, + |tx, data| async move { + tx.write(data).await?; + Ok(tx) + }, + handle, + move |handle| async { + handle.close(CloseReason::Unknown).await?; + Ok(handle) + }, + )) + } } impl Drop for MuxStreamWrite { - fn drop(&mut self) { - if !self.is_closed.load(Ordering::Acquire) { - self.is_closed.store(true, Ordering::Release); - let (tx, _) = oneshot::channel(); - let _ = self.mux_tx.send(WsEvent::Close( - Packet::new_close(self.stream_id, CloseReason::Unknown), - tx, - )); - } - } + fn drop(&mut self) { + if !self.is_closed.load(Ordering::Acquire) { + self.is_closed.store(true, Ordering::Release); + let (tx, _) = oneshot::channel(); + let _ = self.mux_tx.send(WsEvent::Close( + Packet::new_close(self.stream_id, CloseReason::Unknown), + tx, + )); + } + } } /// Multiplexor stream. pub struct MuxStream { - /// ID of the stream. - pub stream_id: u32, - rx: MuxStreamRead, - tx: MuxStreamWrite, + /// ID of the stream. + pub stream_id: u32, + rx: MuxStreamRead, + tx: MuxStreamWrite, } impl MuxStream { - #[allow(clippy::too_many_arguments)] - pub(crate) fn new( - stream_id: u32, - role: Role, - stream_type: StreamType, - rx: mpsc::Receiver, - mux_tx: mpsc::Sender, - tx: LockedWebSocketWrite, - is_closed: Arc, - is_closed_event: Arc, - flow_control: Arc, - continue_recieved: Arc, - target_flow_control: u32, - ) -> Self { - Self { - stream_id, - rx: MuxStreamRead { - stream_id, - stream_type, - role, - tx: tx.clone(), - rx, - is_closed: is_closed.clone(), - is_closed_event: is_closed_event.clone(), - flow_control: flow_control.clone(), - flow_control_read: AtomicU32::new(0), - target_flow_control, - }, - tx: MuxStreamWrite { - stream_id, - stream_type, - role, - mux_tx, - tx, - is_closed: is_closed.clone(), - flow_control: flow_control.clone(), - continue_recieved: continue_recieved.clone(), - }, - } - } + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + stream_id: u32, + role: Role, + stream_type: StreamType, + rx: mpsc::Receiver, + mux_tx: mpsc::Sender, + tx: LockedWebSocketWrite, + is_closed: Arc, + is_closed_event: Arc, + flow_control: Arc, + continue_recieved: Arc, + target_flow_control: u32, + ) -> Self { + Self { + stream_id, + rx: MuxStreamRead { + stream_id, + stream_type, + role, + tx: tx.clone(), + rx, + is_closed: is_closed.clone(), + is_closed_event: is_closed_event.clone(), + flow_control: flow_control.clone(), + flow_control_read: AtomicU32::new(0), + target_flow_control, + }, + tx: MuxStreamWrite { + stream_id, + stream_type, + role, + mux_tx, + tx, + is_closed: is_closed.clone(), + flow_control: flow_control.clone(), + continue_recieved: continue_recieved.clone(), + }, + } + } - /// Read an event from the stream. - pub async fn read(&self) -> Option { - self.rx.read().await - } + /// Read an event from the stream. + pub async fn read(&self) -> Option { + self.rx.read().await + } - /// Write data to the stream. - pub async fn write(&self, data: Bytes) -> Result<(), WispError> { - self.tx.write(data).await - } + /// Write data to the stream. + pub async fn write(&self, data: Bytes) -> Result<(), WispError> { + self.tx.write(data).await + } - /// Get a handle to close the connection. - /// - /// Useful to close the connection without having access to the stream. - /// - /// # Example - /// ``` - /// let handle = stream.get_close_handle(); - /// if let Err(error) = handle_stream(stream) { - /// handle.close(0x01); - /// } - /// ``` - pub fn get_close_handle(&self) -> MuxStreamCloser { - self.tx.get_close_handle() - } + /// Get a handle to close the connection. + /// + /// Useful to close the connection without having access to the stream. + /// + /// # Example + /// ``` + /// let handle = stream.get_close_handle(); + /// if let Err(error) = handle_stream(stream) { + /// handle.close(0x01); + /// } + /// ``` + pub fn get_close_handle(&self) -> MuxStreamCloser { + self.tx.get_close_handle() + } - /// Get a protocol extension stream to send protocol extension packets. - pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { - self.tx.get_protocol_extension_stream() - } + /// Get a protocol extension stream to send protocol extension packets. + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + self.tx.get_protocol_extension_stream() + } - /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { - self.tx.close(reason).await - } + /// Close the stream. You will no longer be able to write or read after this has been called. + pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { + self.tx.close(reason).await + } - /// Split the stream into read and write parts, consuming it. - pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) { - (self.rx, self.tx) - } + /// Split the stream into read and write parts, consuming it. + pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) { + (self.rx, self.tx) + } - /// Turn the stream into one that implements futures `Stream + Sink`, consuming it. - pub fn into_io(self) -> MuxStreamIo { - MuxStreamIo { - rx: self.rx.into_stream(), - tx: self.tx.into_sink(), - } - } + /// Turn the stream into one that implements futures `Stream + Sink`, consuming it. + pub fn into_io(self) -> MuxStreamIo { + MuxStreamIo { + rx: MuxStreamIoStream { + rx: self.rx.into_stream(), + }, + tx: MuxStreamIoSink { + tx: self.tx.into_sink(), + }, + } + } } /// Close handle for a multiplexor stream. #[derive(Clone)] pub struct MuxStreamCloser { - /// ID of the stream. - pub stream_id: u32, - close_channel: mpsc::Sender, - is_closed: Arc, + /// ID of the stream. + pub stream_id: u32, + close_channel: mpsc::Sender, + is_closed: Arc, } impl MuxStreamCloser { - /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { - if self.is_closed.load(Ordering::Acquire) { - return Err(WispError::StreamAlreadyClosed); - } - self.is_closed.store(true, Ordering::Release); + /// Close the stream. You will no longer be able to write or read after this has been called. + pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(WispError::StreamAlreadyClosed); + } + self.is_closed.store(true, Ordering::Release); - let (tx, rx) = oneshot::channel::>(); - self.close_channel - .send_async(WsEvent::Close( - Packet::new_close(self.stream_id, reason), - tx, - )) - .await - .map_err(|_| WispError::MuxMessageFailedToSend)?; - rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; + let (tx, rx) = oneshot::channel::>(); + self.close_channel + .send_async(WsEvent::Close( + Packet::new_close(self.stream_id, reason), + tx, + )) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; - Ok(()) - } + Ok(()) + } } /// Stream for sending arbitrary protocol extension packets. pub struct MuxProtocolExtensionStream { - /// ID of the stream. - pub stream_id: u32, - pub(crate) tx: LockedWebSocketWrite, - pub(crate) is_closed: Arc, + /// ID of the stream. + pub stream_id: u32, + pub(crate) tx: LockedWebSocketWrite, + pub(crate) is_closed: Arc, } impl MuxProtocolExtensionStream { - /// Send a protocol extension packet with this stream's ID. - pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> { - if self.is_closed.load(Ordering::Acquire) { - return Err(WispError::StreamAlreadyClosed); - } - let mut encoded = BytesMut::with_capacity(1 + 4 + data.len()); - encoded.put_u8(packet_type); - encoded.put_u32_le(self.stream_id); - encoded.extend(data); - self.tx.write_frame(Frame::binary(encoded)).await - } + /// Send a protocol extension packet with this stream's ID. + pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(WispError::StreamAlreadyClosed); + } + let mut encoded = BytesMut::with_capacity(1 + 4 + data.len()); + encoded.put_u8(packet_type); + encoded.put_u32_le(self.stream_id); + encoded.extend(data); + self.tx.write_frame(Frame::binary(encoded)).await + } } pin_project! { - /// Multiplexor stream that implements futures `Stream + Sink`. - pub struct MuxStreamIo { - #[pin] - rx: Pin + Send>>, - #[pin] - tx: Pin + Send>>, - } + /// Multiplexor stream that implements futures `Stream + Sink`. + pub struct MuxStreamIo { + #[pin] + rx: MuxStreamIoStream, + #[pin] + tx: MuxStreamIoSink, + } } impl MuxStreamIo { - /// Turn the stream into one that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`. - pub fn into_asyncrw(self) -> MuxStreamAsyncRW { - let (tx, rx) = self.split(); - MuxStreamAsyncRW { - rx: MuxStreamAsyncRead::new(rx), - tx: MuxStreamAsyncWrite::new(tx), - } - } + /// Turn the stream into one that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`. + pub fn into_asyncrw(self) -> IoStream { + IoStream::new(self) + } + + /// Split the stream into read and write parts, consuming it. + pub fn into_split(self) -> (MuxStreamIoStream, MuxStreamIoSink) { + (self.rx, self.tx) + } } impl Stream for MuxStreamIo { - type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().rx.poll_next(cx).map(|x| x.map(Ok)) - } + type Item = Result; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().rx.poll_next(cx) + } } impl Sink for MuxStreamIo { - type Error = std::io::Error; - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .tx - .poll_ready(cx) - .map_err(std::io::Error::other) - } - fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { - self.project() - .tx - .start_send(item) - .map_err(std::io::Error::other) - } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .tx - .poll_flush(cx) - .map_err(std::io::Error::other) - } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .tx - .poll_close(cx) - .map_err(std::io::Error::other) - } + type Error = std::io::Error; + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_ready(cx) + } + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.project().tx.start_send(item) + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_flush(cx) + } + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_close(cx) + } } pin_project! { - /// Multiplexor stream that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`. - pub struct MuxStreamAsyncRW { - #[pin] - rx: MuxStreamAsyncRead, - #[pin] - tx: MuxStreamAsyncWrite, - } + /// Read side of a multiplexor stream that implements futures `Stream`. + pub struct MuxStreamIoStream { + #[pin] + rx: Pin + Send>>, + } } -impl MuxStreamAsyncRW { - /// Split the stream into read and write parts, consuming it. - pub fn into_split(self) -> (MuxStreamAsyncRead, MuxStreamAsyncWrite) { - (self.rx, self.tx) - } -} - -impl AsyncRead for MuxStreamAsyncRW { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - self.project().rx.poll_read(cx, buf) - } - - fn poll_read_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &mut [std::io::IoSliceMut<'_>], - ) -> Poll> { - self.project().rx.poll_read_vectored(cx, bufs) - } -} - -impl AsyncBufRead for MuxStreamAsyncRW { - fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().rx.poll_fill_buf(cx) - } - - fn consume(self: Pin<&mut Self>, amt: usize) { - self.project().rx.consume(amt) - } -} - -impl AsyncWrite for MuxStreamAsyncRW { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.project().tx.poll_write(cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_flush(cx) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_close(cx) - } +impl Stream for MuxStreamIoStream { + type Item = Result; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().rx.poll_next(cx).map(|x| x.map(Ok)) + } } pin_project! { - /// Read side of a multiplexor stream that implements futures `AsyncRead + AsyncBufRead`. - pub struct MuxStreamAsyncRead { - #[pin] - rx: IntoAsyncRead>, - } + /// Write side of a multiplexor stream that implements futures `Sink`. + pub struct MuxStreamIoSink { + #[pin] + tx: Pin + Send>>, + } } -impl MuxStreamAsyncRead { - pub(crate) fn new(stream: SplitStream) -> Self { - Self { - rx: stream.into_async_read(), - } - } -} - -impl AsyncRead for MuxStreamAsyncRead { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - self.project().rx.poll_read(cx, buf) - } - - fn poll_read_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &mut [std::io::IoSliceMut<'_>], - ) -> Poll> { - self.project().rx.poll_read_vectored(cx, bufs) - } -} - -impl AsyncBufRead for MuxStreamAsyncRead { - fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().rx.poll_fill_buf(cx) - } - - fn consume(self: Pin<&mut Self>, amt: usize) { - self.project().rx.consume(amt) - } -} - -pin_project! { - /// Write side of a multiplexor stream that implements futures `AsyncWrite`. - pub struct MuxStreamAsyncWrite { - #[pin] - tx: SplitSink, - } -} - -impl MuxStreamAsyncWrite { - pub(crate) fn new(sink: SplitSink) -> Self { - Self { tx: sink } - } -} - -impl AsyncWrite for MuxStreamAsyncWrite { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let mut this = self.project(); - - ready!(this.tx.as_mut().poll_ready(cx))?; - match this.tx.start_send(Bytes::copy_from_slice(buf)) { - Ok(()) => Poll::Ready(Ok(buf.len())), - Err(e) => Poll::Ready(Err(e)), - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_flush(cx) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_close(cx) - } +impl Sink for MuxStreamIoSink { + type Error = std::io::Error; + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .tx + .poll_ready(cx) + .map_err(std::io::Error::other) + } + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.project() + .tx + .start_send(item) + .map_err(std::io::Error::other) + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .tx + .poll_flush(cx) + .map_err(std::io::Error::other) + } + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .tx + .poll_close(cx) + .map_err(std::io::Error::other) + } }