diff --git a/Cargo.lock b/Cargo.lock index dc862b4..eec8ba9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -133,12 +133,6 @@ version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" -[[package]] -name = "byteorder" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" - [[package]] name = "bytes" version = "1.5.0" @@ -248,6 +242,7 @@ name = "epoxy-client" version = "1.0.0" dependencies = [ "async-compression", + "async_io_stream", "base64", "bytes", "console_error_panic_hook", @@ -255,11 +250,10 @@ dependencies = [ "fastwebsockets", "futures-util", "getrandom", - "http 1.0.0", + "http", "http-body-util", "hyper", "js-sys", - "penguin-mux-wasm", "pin-project-lite", "rand", "ring", @@ -488,17 +482,6 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" -[[package]] -name = "http" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" -dependencies = [ - "bytes", - "fnv", - "itoa", -] - [[package]] name = "http" version = "1.0.0" @@ -517,7 +500,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" dependencies = [ "bytes", - "http 1.0.0", + "http", ] [[package]] @@ -528,7 +511,7 @@ checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840" dependencies = [ "bytes", "futures-util", - "http 1.0.0", + "http", "http-body", "pin-project-lite", ] @@ -554,7 +537,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.0.0", + "http", "http-body", "httparse", "httpdate", @@ -573,7 +556,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.0.0", + "http", "http-body", "hyper", "pin-project-lite", @@ -744,16 +727,6 @@ dependencies = [ "vcpkg", ] -[[package]] -name = "parking_lot" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" -dependencies = [ - "lock_api", - "parking_lot_core", -] - [[package]] name = "parking_lot_core" version = "0.9.9" @@ -767,23 +740,6 @@ dependencies = [ "windows-targets 0.48.5", ] -[[package]] -name = "penguin-mux-wasm" -version = "0.1.0" -source = "git+https://github.com/r58Playz/penguin-mux-wasm#69b413aedb6f50f55eac646fda361abe430eb022" -dependencies = [ - "bytes", - "futures-util", - "http 0.2.11", - "parking_lot", - "rand", - "thiserror", - "tokio", - "tokio-tungstenite", - "tracing", - "wasm-bindgen-futures", -] - [[package]] name = "pharos" version = "0.5.3" @@ -1129,7 +1085,6 @@ dependencies = [ "libc", "mio", "num_cpus", - "parking_lot", "pin-project-lite", "socket2", "tokio-macros", @@ -1168,18 +1123,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-tungstenite" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" -dependencies = [ - "futures-util", - "log", - "tokio", - "tungstenite", -] - [[package]] name = "tokio-util" version = "0.7.10" @@ -1201,21 +1144,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ "pin-project-lite", - "tracing-attributes", "tracing-core", ] -[[package]] -name = "tracing-attributes" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "tracing-core" version = "0.1.32" @@ -1231,20 +1162,6 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" -[[package]] -name = "tungstenite" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" -dependencies = [ - "byteorder", - "bytes", - "log", - "rand", - "thiserror", - "utf-8", -] - [[package]] name = "typenum" version = "1.17.0" diff --git a/client/Cargo.toml b/client/Cargo.toml index 3a1b3bd..48ff0e5 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -16,7 +16,6 @@ http = "1.0.0" http-body-util = "0.1.0" hyper = { version = "1.1.0", features = ["client", "http1"] } pin-project-lite = "0.2.13" -penguin-mux-wasm = { git = "https://github.com/r58Playz/penguin-mux-wasm" } tokio = { version = "1.35.1", default_features = false } wasm-bindgen = "0.2" wasm-bindgen-futures = "0.4.39" @@ -33,7 +32,8 @@ async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] fastwebsockets = { version = "0.6.0", features = ["simdutf8", "unstable-split"] } rand = "0.8.5" base64 = "0.21.7" -wisp-mux = { path = "../wisp", features = ["ws_stream_wasm"] } +wisp-mux = { path = "../wisp", features = ["ws_stream_wasm", "tokio_io"] } +async_io_stream = { version = "0.3.3", features = ["tokio_io"] } [dependencies.getrandom] features = ["js"] diff --git a/client/src/lib.rs b/client/src/lib.rs index 30aabde..80b9c58 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -8,17 +8,20 @@ mod wrappers; use tokioio::TokioIo; use utils::{ReplaceErr, UriExt}; use websocket::EpxWebSocket; -use wrappers::{IncomingBody, WsStreamWrapper}; +use wrappers::IncomingBody; use std::sync::Arc; use async_compression::tokio::bufread as async_comp; +use async_io_stream::IoStream; use bytes::Bytes; -use futures_util::StreamExt; +use futures_util::{ + stream::SplitSink, + StreamExt, +}; use http::{uri, HeaderName, HeaderValue, Request, Response}; use hyper::{body::Incoming, client::conn::http1::Builder, Uri}; use js_sys::{Array, Function, Object, Reflect, Uint8Array}; -use penguin_mux_wasm::{Multiplexor, MuxStream}; use tokio_rustls::{client::TlsStream, rustls, rustls::RootCertStore, TlsConnector}; use tokio_util::{ either::Either, @@ -26,6 +29,8 @@ use tokio_util::{ }; use wasm_bindgen::prelude::*; use web_sys::TextEncoder; +use wisp_mux::{ClientMux, MuxStreamIo, StreamType}; +use ws_stream_wasm::{WsMeta, WsStream, WsMessage}; type HttpBody = http_body_util::Full; @@ -40,8 +45,8 @@ enum EpxCompression { Gzip, } -type EpxTlsStream = TlsStream>; -type EpxUnencryptedStream = MuxStream; +type EpxTlsStream = TlsStream>>; +type EpxUnencryptedStream = IoStream>; type EpxStream = Either; async fn send_req( @@ -113,7 +118,7 @@ async fn start() { #[wasm_bindgen] pub struct EpoxyClient { rustls_config: Arc, - mux: Multiplexor, + mux: ClientMux>, useragent: String, redirect_limit: usize, } @@ -138,11 +143,18 @@ impl EpoxyClient { } debug!("connecting to ws {:?}", ws_url); - let ws = WsStreamWrapper::connect(ws_url, None) + let (_, ws) = WsMeta::connect(ws_url, vec!["wisp-v1"]) .await .replace_err("Failed to connect to websocket")?; debug!("connected!"); - let mux = Multiplexor::new(ws, penguin_mux_wasm::Role::Client, None, None); + let (wtx, wrx) = ws.split(); + let (mux, fut) = ClientMux::new(wrx, wtx); + + wasm_bindgen_futures::spawn_local(async move { + if let Err(err) = fut.await { + error!("epoxy: error in mux future! {:?}", err); + } + }); let mut certstore = RootCertStore::empty(); certstore.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); @@ -161,14 +173,16 @@ impl EpoxyClient { }) } - async fn get_http_io(&self, url: &Uri) -> Result { + async fn get_http_io(&mut self, url: &Uri) -> Result { let url_host = url.host().replace_err("URL must have a host")?; let url_port = utils::get_url_port(url)?; let channel = self .mux - .client_new_stream_channel(url_host.as_bytes(), url_port) + .client_new_stream(StreamType::Tcp, url_host.to_string(), url_port) .await - .replace_err("Failed to create multiplexor channel")?; + .replace_err("Failed to create multiplexor channel")? + .into_io() + .into_asyncrw(); if utils::get_is_secure(url)? { let cloned_uri = url_host.to_string().clone(); @@ -189,7 +203,7 @@ impl EpoxyClient { } async fn send_req( - &self, + &mut self, req: http::Request, should_redirect: bool, ) -> Result<(hyper::Response, Uri, bool), JsError> { @@ -217,7 +231,7 @@ impl EpoxyClient { // shut up #[allow(clippy::too_many_arguments)] pub async fn connect_ws( - &self, + &mut self, onopen: Function, onclose: Function, onerror: Function, @@ -232,7 +246,11 @@ impl EpoxyClient { .await } - pub async fn fetch(&self, url: String, options: Object) -> Result { + pub async fn fetch( + &mut self, + url: String, + options: Object, + ) -> Result { let uri = url.parse::().replace_err("Failed to parse URL")?; let uri_scheme = uri.scheme().replace_err("URL must have a scheme")?; if *uri_scheme != uri::Scheme::HTTP && *uri_scheme != uri::Scheme::HTTPS { diff --git a/client/src/websocket.rs b/client/src/websocket.rs index addae2c..2ce9149 100644 --- a/client/src/websocket.rs +++ b/client/src/websocket.rs @@ -30,7 +30,7 @@ impl EpxWebSocket { // shut up #[allow(clippy::too_many_arguments)] pub async fn connect( - tcp: &EpoxyClient, + tcp: &mut EpoxyClient, onopen: Function, onclose: Function, onerror: Function, diff --git a/client/src/wrappers.rs b/client/src/wrappers.rs index 1ecc702..8526a98 100644 --- a/client/src/wrappers.rs +++ b/client/src/wrappers.rs @@ -4,117 +4,9 @@ use std::{ task::{Context, Poll}, }; -use futures_util::{Sink, Stream}; +use futures_util::Stream; use hyper::body::Body; -use penguin_mux_wasm::ws; use pin_project_lite::pin_project; -use ws_stream_wasm::{WsErr, WsMessage, WsMeta, WsStream}; - -pin_project! { - pub struct WsStreamWrapper { - #[pin] - ws: WsStream, - } -} - -impl WsStreamWrapper { - pub async fn connect( - url: impl AsRef, - protocols: impl Into>>, - ) -> Result { - let (_, wsstream) = WsMeta::connect(url, protocols).await?; - Ok(WsStreamWrapper { ws: wsstream }) - } -} - -impl Stream for WsStreamWrapper { - type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - let ret = this.ws.poll_next(cx); - match ret { - Poll::Ready(item) => Poll::>::Ready(item.map(|x| { - Ok(match x { - WsMessage::Text(txt) => ws::Message::Text(txt), - WsMessage::Binary(bin) => ws::Message::Binary(bin), - }) - })), - Poll::Pending => Poll::>::Pending, - } - } -} - -fn wserr_to_ws_err(err: WsErr) -> ws::Error { - debug!("err: {:?}", err); - match err { - WsErr::ConnectionNotOpen => ws::Error::AlreadyClosed, - _ => ws::Error::Io(std::io::Error::other(format!("{:?}", err))), - } -} - -impl Sink for WsStreamWrapper { - type Error = ws::Error; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - let ret = this.ws.poll_ready(cx); - match ret { - Poll::Ready(item) => Poll::>::Ready(match item { - Ok(_) => Ok(()), - Err(err) => Err(wserr_to_ws_err(err)), - }), - Poll::Pending => Poll::>::Pending, - } - } - - fn start_send(self: Pin<&mut Self>, item: ws::Message) -> Result<(), Self::Error> { - use ws::Message::*; - let item = match item { - Text(txt) => WsMessage::Text(txt), - Binary(bin) => WsMessage::Binary(bin), - Close(_) => { - debug!("closing"); - return match self.ws.wrapped().close() { - Ok(_) => Ok(()), - Err(err) => Err(ws::Error::Io(std::io::Error::other(format!( - "ws close err: {:?}", - err - )))), - }; - } - Ping(_) | Pong(_) | Frame(_) => return Ok(()), - }; - let this = self.project(); - let ret = this.ws.start_send(item); - match ret { - Ok(_) => Ok(()), - Err(err) => Err(wserr_to_ws_err(err)), - } - } - - // no point wrapping this as it's not going to do anything - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Ok(()).into() - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - let ret = this.ws.poll_close(cx); - match ret { - Poll::Ready(item) => Poll::>::Ready(match item { - Ok(_) => Ok(()), - Err(err) => Err(wserr_to_ws_err(err)), - }), - Poll::Pending => Poll::>::Pending, - } - } -} - -impl ws::WebSocketStream for WsStreamWrapper { - fn ping_auto_pong(&self) -> bool { - true - } -} pin_project! { pub struct IncomingBody { diff --git a/server/src/main.rs b/server/src/main.rs index 11f6478..7b0b35c 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -101,7 +101,7 @@ async fn accept_http( async fn handle_mux( packet: ConnectPacket, - mut stream: MuxStream, + mut stream: MuxStream, ) -> Result { let uri = format!( "{}:{}", @@ -174,9 +174,7 @@ async fn accept_ws( println!("{:?}: connected", addr); - let mut mux = ServerMux::new(rx, tx); - - mux.server_loop(&mut |packet, stream| async move { + ServerMux::handle(rx, tx, &mut |packet, stream| async move { let mut close_err = stream.get_close_handle(); let mut close_ok = stream.get_close_handle(); tokio::spawn(async move { diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 9dc0a2d..ee3c3d2 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -17,3 +17,4 @@ ws_stream_wasm = { version = "0.7.4", optional = true } [features] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] ws_stream_wasm = ["dep:ws_stream_wasm"] +tokio_io = ["async_io_stream/tokio_io"] diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index f020bfd..fb31e4a 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -53,10 +53,10 @@ impl From for crate::WispError { } } -impl crate::ws::WebSocketRead for FragmentCollectorRead { +impl crate::ws::WebSocketRead for FragmentCollectorRead { async fn wisp_read_frame( &mut self, - tx: &mut crate::ws::LockedWebSocketWrite, + tx: &crate::ws::LockedWebSocketWrite, ) -> Result { Ok(self .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) @@ -65,7 +65,7 @@ impl crate::ws::WebSocketRead for FragmentCollectorRead } } -impl crate::ws::WebSocketWrite for WebSocketWrite { +impl crate::ws::WebSocketWrite for WebSocketWrite { async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { self.write_frame(frame.try_into()?).await.map_err(|e| e.into()) } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 2eb0594..d4f843e 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -10,7 +10,7 @@ pub use crate::packet::*; pub use crate::stream::*; use dashmap::DashMap; -use futures::{channel::mpsc, StreamExt}; +use futures::{channel::mpsc, Future, StreamExt}; use std::sync::{ atomic::{AtomicBool, AtomicU32, Ordering}, Arc, @@ -68,38 +68,66 @@ impl std::fmt::Display for WispError { impl std::error::Error for WispError {} -pub struct ServerMux +pub struct ServerMux where - R: ws::WebSocketRead, W: ws::WebSocketWrite, { - rx: R, tx: ws::LockedWebSocketWrite, stream_map: Arc>>, - close_rx: mpsc::UnboundedReceiver, close_tx: mpsc::UnboundedSender, } -impl ServerMux { - pub fn new(read: R, write: W) -> Self { +impl ServerMux { + pub fn handle<'a, FR, R>( + read: R, + write: W, + handler_fn: &'a mut impl Fn(ConnectPacket, MuxStream) -> FR, + ) -> impl Future> + 'a + where + FR: std::future::Future> + 'a, + R: ws::WebSocketRead + 'a, + W: ws::WebSocketWrite + 'a, + { let (tx, rx) = mpsc::unbounded::(); - Self { - rx: read, - tx: ws::LockedWebSocketWrite::new(write), - stream_map: Arc::new(DashMap::new()), - close_rx: rx, + let write = ws::LockedWebSocketWrite::new(write); + let map = Arc::new(DashMap::new()); + let inner = ServerMux { + stream_map: map.clone(), + tx: write.clone(), close_tx: tx, - } + }; + inner.into_future(read, rx, handler_fn) } - pub async fn server_bg_loop(&mut self) { - while let Some(msg) = self.close_rx.next().await { + async fn into_future( + self, + rx: R, + close_rx: mpsc::UnboundedReceiver, + handler_fn: &mut impl Fn(ConnectPacket, MuxStream) -> FR, + ) -> Result<(), WispError> + where + R: ws::WebSocketRead, + FR: std::future::Future>, + { + futures::try_join! { + self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()), + self.server_msg_loop(rx, handler_fn) + } + .map(|_| ()) + } + + async fn server_close_loop( + &self, + mut close_rx: mpsc::UnboundedReceiver, + stream_map: Arc>>, + tx: ws::LockedWebSocketWrite, + ) -> Result<(), WispError> { + while let Some(msg) = close_rx.next().await { match msg { MuxEvent::Close(stream_id, reason, channel) => { - if self.stream_map.clone().remove(&stream_id).is_some() { + if stream_map.clone().remove(&stream_id).is_some() { let _ = channel.send( - self.tx - .write_frame(Packet::new_close(stream_id, reason).into()) + tx.write_frame(Packet::new_close(stream_id, reason).into()) .await, ); } else { @@ -108,20 +136,23 @@ impl ServerMux { } } } + Ok(()) } - pub async fn server_loop( - &mut self, + async fn server_msg_loop( + &self, + mut rx: R, handler_fn: &mut impl Fn(ConnectPacket, MuxStream) -> FR, ) -> Result<(), WispError> where - FR: std::future::Future>, + R: ws::WebSocketRead, + FR: std::future::Future>, { self.tx .write_frame(Packet::new_continue(0, u32::MAX).into()) .await?; - while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await { + while let Ok(frame) = rx.wisp_read_frame(&self.tx).await { if let Ok(packet) = Packet::try_from(frame) { use PacketType::*; match packet.packet { @@ -164,34 +195,31 @@ impl ServerMux { } } -pub struct ClientMux +pub struct ClientMuxInner where - R: ws::WebSocketRead, W: ws::WebSocketWrite, { - rx: R, tx: ws::LockedWebSocketWrite, stream_map: Arc>>, - next_free_stream_id: AtomicU32, - close_rx: mpsc::UnboundedReceiver, - close_tx: mpsc::UnboundedSender, } -impl ClientMux { - pub fn new(read: R, write: W) -> Self { - let (tx, rx) = mpsc::unbounded::(); - Self { - rx: read, - tx: ws::LockedWebSocketWrite::new(write), - stream_map: Arc::new(DashMap::new()), - next_free_stream_id: AtomicU32::new(1), - close_rx: rx, - close_tx: tx, - } +impl ClientMuxInner { + pub async fn into_future( + self, + rx: R, + close_rx: mpsc::UnboundedReceiver, + ) -> Result<(), WispError> + where + R: ws::WebSocketRead, + { + futures::try_join!(self.client_bg_loop(close_rx), self.client_loop(rx)).map(|_| ()) } - pub async fn client_bg_loop(&mut self) { - while let Some(msg) = self.close_rx.next().await { + async fn client_bg_loop( + &self, + mut close_rx: mpsc::UnboundedReceiver, + ) -> Result<(), WispError> { + while let Some(msg) = close_rx.next().await { match msg { MuxEvent::Close(stream_id, reason, channel) => { if self.stream_map.clone().remove(&stream_id).is_some() { @@ -206,14 +234,14 @@ impl ClientMux { } } } + Ok(()) } - pub async fn client_loop(&mut self) -> Result<(), WispError> { - self.tx - .write_frame(Packet::new_continue(0, u32::MAX).into()) - .await?; - - while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await { + async fn client_loop(&self, mut rx: R) -> Result<(), WispError> + where + R: ws::WebSocketRead, + { + while let Ok(frame) = rx.wisp_read_frame(&self.tx).await { if let Ok(packet) = Packet::try_from(frame) { use PacketType::*; match packet.packet { @@ -235,12 +263,52 @@ impl ClientMux { } Ok(()) } +} + +pub struct ClientMux +where + W: ws::WebSocketWrite, +{ + tx: ws::LockedWebSocketWrite, + stream_map: Arc>>, + next_free_stream_id: AtomicU32, + close_tx: mpsc::UnboundedSender, +} + +impl ClientMux { + pub fn new(read: R, write: W) -> (Self, impl Future>) + where + R: ws::WebSocketRead, + { + let (tx, rx) = mpsc::unbounded::(); + let map = Arc::new(DashMap::new()); + let write = ws::LockedWebSocketWrite::new(write); + ( + Self { + tx: write.clone(), + stream_map: map.clone(), + next_free_stream_id: AtomicU32::new(1), + close_tx: tx, + }, + ClientMuxInner { + tx: write.clone(), + stream_map: map.clone(), + } + .into_future(read, rx), + ) + } pub async fn client_new_stream( &mut self, + stream_type: StreamType, + host: String, + port: u16, ) -> Result, WispError> { let (ch_tx, ch_rx) = mpsc::unbounded(); let stream_id = self.next_free_stream_id.load(Ordering::Acquire); + self.tx + .write_frame(Packet::new_connect(stream_id, stream_type, port, host).into()) + .await?; self.next_free_stream_id.store( stream_id .checked_add(1) diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index ff86585..cd9daab 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -4,7 +4,7 @@ use futures::{ channel::{mpsc, oneshot}, sink, stream, task::{Context, Poll}, - AsyncRead, AsyncWrite, Sink, Stream, StreamExt, + Sink, Stream, StreamExt, }; use pin_project_lite::pin_project; use std::{ @@ -44,7 +44,7 @@ impl MuxStreamRead { } } - pub(crate) fn into_stream(self) -> Pin>> { + pub(crate) fn into_stream(self) -> Pin + Send>> { Box::pin(stream::unfold(self, |mut rx| async move { let evt = rx.read().await?; Some(( @@ -68,7 +68,7 @@ where is_closed: Arc, } -impl MuxStreamWrite { +impl MuxStreamWrite { pub async fn write(&mut self, data: Bytes) -> Result<(), crate::WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(crate::WispError::StreamAlreadyClosed); @@ -101,10 +101,7 @@ impl MuxStreamWrite { Ok(()) } - pub(crate) fn into_sink<'a>(self) -> Pin + 'a>> - where - W: 'a, - { + pub(crate) fn into_sink(self) -> Pin + Send>> { Box::pin(sink::unfold(self, |mut tx, data| async move { tx.write(data).await?; Ok(tx) @@ -130,7 +127,7 @@ where tx: MuxStreamWrite, } -impl MuxStream { +impl MuxStream { pub(crate) fn new( stream_id: u32, rx: mpsc::UnboundedReceiver, @@ -174,10 +171,7 @@ impl MuxStream { (self.rx, self.tx) } - pub fn into_io<'a>(self) -> MuxStreamIo<'a> - where - W: 'a, - { + pub fn into_io(self) -> MuxStreamIo { MuxStreamIo { rx: self.rx.into_stream(), tx: self.tx.into_sink(), @@ -208,55 +202,54 @@ impl MuxStreamCloser { } pin_project! { - pub struct MuxStreamIo<'a> { + pub struct MuxStreamIo { #[pin] - rx: Pin + 'a>>, + rx: Pin + Send>>, #[pin] - tx: Pin + 'a>>, + tx: Pin + Send>>, } } -impl<'a> MuxStreamIo<'a> { - pub fn into_asyncrw(self) -> impl AsyncRead + AsyncWrite + 'a { - IoStream::new(self.map(|x| Ok::, std::io::Error>(x.to_vec()))) +impl MuxStreamIo { + pub fn into_asyncrw(self) -> IoStream> { + IoStream::new(self) } } -impl Stream for MuxStreamIo<'_> { - type Item = Bytes; +impl Stream for MuxStreamIo { + type Item = Result, std::io::Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().rx.poll_next(cx) + self.project() + .rx + .poll_next(cx) + .map(|x| x.map(|x| Ok(x.to_vec()))) } } -impl Sink for MuxStreamIo<'_> { - type Error = crate::WispError; - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_ready(cx) - } - fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { - self.project().tx.start_send(item) - } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_flush(cx) - } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_close(cx) - } -} - -impl Sink> for MuxStreamIo<'_> { +impl Sink> for MuxStreamIo { type Error = std::io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_ready(cx).map_err(std::io::Error::other) + self.project() + .tx + .poll_ready(cx) + .map_err(std::io::Error::other) } fn start_send(self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { - self.project().tx.start_send(item.into()).map_err(std::io::Error::other) + self.project() + .tx + .start_send(item.into()) + .map_err(std::io::Error::other) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_flush(cx).map_err(std::io::Error::other) + self.project() + .tx + .poll_flush(cx) + .map_err(std::io::Error::other) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_close(cx).map_err(std::io::Error::other) + self.project() + .tx + .poll_close(cx) + .map_err(std::io::Error::other) } } diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index 5b1243e..f75c526 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -46,20 +46,20 @@ impl Frame { pub trait WebSocketRead { fn wisp_read_frame( &mut self, - tx: &mut crate::ws::LockedWebSocketWrite, - ) -> impl std::future::Future>; + tx: &crate::ws::LockedWebSocketWrite, + ) -> impl std::future::Future> + Send; } pub trait WebSocketWrite { fn wisp_write_frame( &mut self, frame: Frame, - ) -> impl std::future::Future>; + ) -> impl std::future::Future> + Send; } pub struct LockedWebSocketWrite(Arc>); -impl LockedWebSocketWrite { +impl LockedWebSocketWrite { pub fn new(ws: S) -> Self { Self(Arc::new(Mutex::new(ws))) } diff --git a/wisp/src/ws_stream_wasm.rs b/wisp/src/ws_stream_wasm.rs index 6e15816..410b537 100644 --- a/wisp/src/ws_stream_wasm.rs +++ b/wisp/src/ws_stream_wasm.rs @@ -1,4 +1,4 @@ -use futures::{SinkExt, StreamExt}; +use futures::{stream::{SplitStream, SplitSink}, SinkExt, StreamExt}; use ws_stream_wasm::{WsErr, WsMessage, WsStream}; impl From for crate::ws::Frame { @@ -37,10 +37,10 @@ impl From for crate::WispError { } } -impl crate::ws::WebSocketRead for WsStream { +impl crate::ws::WebSocketRead for SplitStream { async fn wisp_read_frame( &mut self, - _: &mut crate::ws::LockedWebSocketWrite, + _: &crate::ws::LockedWebSocketWrite, ) -> Result { Ok(self .next() @@ -50,8 +50,11 @@ impl crate::ws::WebSocketRead for WsStream { } } -impl crate::ws::WebSocketWrite for WsStream { +impl crate::ws::WebSocketWrite for SplitSink { async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { - self.send(frame.try_into()?).await.map_err(|e| e.into()) + self + .send(frame.try_into()?) + .await + .map_err(|e| e.into()) } }