diff --git a/Cargo.lock b/Cargo.lock index 563e924..e7df3e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -698,7 +698,6 @@ name = "epoxy-client" version = "2.1.15" dependencies = [ "async-compression", - "async-trait", "bytes", "cfg-if", "event-listener", @@ -755,6 +754,7 @@ dependencies = [ "libc", "log", "nix", + "pin-project-lite", "pty-process", "regex", "rustls-pemfile", @@ -1850,6 +1850,12 @@ dependencies = [ "quick-error", ] +[[package]] +name = "reusable-box-future" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e0e61cd21fbddd85fbd9367b775660a01d388c08a61c6d2824af480b0309bb9" + [[package]] name = "ring" version = "0.17.8" @@ -3015,6 +3021,7 @@ dependencies = [ "getrandom", "nohash-hasher", "pin-project-lite", + "reusable-box-future", "thiserror", "tokio", ] diff --git a/client/Cargo.toml b/client/Cargo.toml index e0d6342..c98a68a 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -8,7 +8,6 @@ crate-type = ["cdylib"] [dependencies] async-compression = { version = "0.4.12", features = ["futures-io", "gzip", "brotli"], optional = true } -async-trait = "0.1.81" bytes = "1.7.1" cfg-if = "1.0.0" event-listener = "5.3.1" diff --git a/client/src/lib.rs b/client/src/lib.rs index c455faf..949839e 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -1,5 +1,5 @@ #![feature(let_chains, impl_trait_in_assoc_type)] -use std::{error::Error, str::FromStr, sync::Arc}; +use std::{error::Error, pin::Pin, str::FromStr, sync::Arc}; #[cfg(feature = "full")] use async_compression::futures::bufread as async_comp; @@ -7,7 +7,7 @@ use bytes::{Bytes, BytesMut}; use cfg_if::cfg_if; #[cfg(feature = "full")] use futures_util::future::Either; -use futures_util::{StreamExt, TryStreamExt}; +use futures_util::{Stream, StreamExt, TryStreamExt}; use http::{ header::{ InvalidHeaderName, InvalidHeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, @@ -41,7 +41,7 @@ use websocket::EpoxyWebSocket; use wisp_mux::StreamType; use wisp_mux::{ generic::GenericWebSocketRead, - ws::{WebSocketRead, WebSocketWrite}, + ws::{EitherWebSocketRead, EitherWebSocketWrite}, CloseReason, }; use ws_wrapper::WebSocketWrapper; @@ -343,7 +343,7 @@ fn create_wisp_transport(function: Function) -> ProviderWispTransportGenerator { } .into(); - let read = GenericWebSocketRead::new(SendWrapper::new( + let read = GenericWebSocketRead::new(Box::pin(SendWrapper::new( wasm_streams::ReadableStream::from_raw(object_get(&transport, "read").into()) .into_stream() .map(|x| { @@ -355,15 +355,16 @@ fn create_wisp_transport(function: Function) -> ProviderWispTransportGenerator { Uint8Array::new(&arr).to_vec().as_slice(), )) }), - )); + )) + as Pin> + Send>>); let write: WritableStream = object_get(&transport, "write").into(); let write = WispTransportWrite { inner: SendWrapper::new(write.get_writer().map_err(EpoxyError::wisp_transport)?), }; Ok(( - Box::new(read) as Box, - Box::new(write) as Box, + EitherWebSocketRead::Right(read), + EitherWebSocketWrite::Right(write), )) })) }) @@ -421,8 +422,8 @@ impl EpoxyClient { } } Ok(( - Box::new(read) as Box, - Box::new(write) as Box, + EitherWebSocketRead::Left(read), + EitherWebSocketWrite::Left(write), )) }) }), diff --git a/client/src/stream_provider.rs b/client/src/stream_provider.rs index 3124ed1..c5046d1 100644 --- a/client/src/stream_provider.rs +++ b/client/src/stream_provider.rs @@ -1,5 +1,6 @@ use std::{io::ErrorKind, pin::Pin, sync::Arc, task::Poll}; +use bytes::BytesMut; use cfg_if::cfg_if; use futures_rustls::{ rustls::{ClientConfig, RootCertStore}, @@ -8,7 +9,7 @@ use futures_rustls::{ use futures_util::{ future::Either, lock::{Mutex, MutexGuard}, - AsyncRead, AsyncWrite, Future, + AsyncRead, AsyncWrite, Future, Stream, }; use hyper_util_wasm::client::legacy::connect::{ConnectSvc, Connected, Connection}; use pin_project_lite::pin_project; @@ -16,18 +17,30 @@ use wasm_bindgen_futures::spawn_local; use webpki_roots::TLS_SERVER_ROOTS; use wisp_mux::{ extensions::{udp::UdpProtocolExtensionBuilder, AnyProtocolExtensionBuilder}, - ws::{WebSocketRead, WebSocketWrite}, + generic::GenericWebSocketRead, + ws::{EitherWebSocketRead, EitherWebSocketWrite}, ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, WispV2Handshake, }; use crate::{ - console_error, console_log, utils::{IgnoreCloseNotify, NoCertificateVerification}, EpoxyClientOptions, EpoxyError + console_error, console_log, + utils::{IgnoreCloseNotify, NoCertificateVerification, WispTransportWrite}, + ws_wrapper::{WebSocketReader, WebSocketWrapper}, + EpoxyClientOptions, EpoxyError, }; pub type ProviderUnencryptedStream = MuxStreamIo; pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW; pub type ProviderTlsAsyncRW = IgnoreCloseNotify; pub type ProviderAsyncRW = Either; +pub type ProviderWispTransportRead = EitherWebSocketRead< + WebSocketReader, + GenericWebSocketRead< + Pin> + Send>>, + EpoxyError, + >, +>; +pub type ProviderWispTransportWrite = EitherWebSocketWrite; pub type ProviderWispTransportGenerator = Box< dyn Fn( bool, @@ -35,10 +48,7 @@ pub type ProviderWispTransportGenerator = Box< Box< dyn Future< Output = Result< - ( - Box, - Box, - ), + (ProviderWispTransportRead, ProviderWispTransportWrite), EpoxyError, >, > + Sync @@ -54,7 +64,7 @@ pub struct StreamProvider { wisp_v2: bool, udp_extension: bool, - current_client: Arc>>, + current_client: Arc>>>, h2_config: Arc, client_config: Arc, @@ -115,7 +125,7 @@ impl StreamProvider { async fn create_client( &self, - mut locked: MutexGuard<'_, Option>, + mut locked: MutexGuard<'_, Option>>, ) -> Result<(), EpoxyError> { let extensions_vec: Vec = vec![AnyProtocolExtensionBuilder::new( @@ -140,7 +150,11 @@ impl StreamProvider { spawn_local(async move { match fut.await { Ok(_) => console_log!("epoxy: wisp multiplexor task ended successfully"), - Err(x) => console_error!("epoxy: wisp multiplexor task ended with an error: {} {:?}", x, x), + Err(x) => console_error!( + "epoxy: wisp multiplexor task ended with an error: {} {:?}", + x, + x + ), } current_client.lock().await.take(); }); diff --git a/client/src/utils/js.rs b/client/src/utils/js.rs index a467624..ec43c3a 100644 --- a/client/src/utils/js.rs +++ b/client/src/utils/js.rs @@ -1,4 +1,7 @@ -use std::{pin::Pin, task::{Context, Poll}}; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; use bytes::Bytes; use futures_util::{AsyncRead, Stream, StreamExt, TryStreamExt}; @@ -12,7 +15,6 @@ use crate::{console_error, EpoxyError}; use super::ReaderStream; - #[wasm_bindgen(inline_js = r#" export function ws_protocol() { return ( diff --git a/client/src/utils/mod.rs b/client/src/utils/mod.rs index 27badd6..dc39ef0 100644 --- a/client/src/utils/mod.rs +++ b/client/src/utils/mod.rs @@ -8,7 +8,6 @@ use std::{ task::{Context, Poll}, }; -use async_trait::async_trait; use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut}; use futures_util::{ready, AsyncRead, Future, Stream}; use http::{HeaderValue, Uri}; @@ -179,7 +178,6 @@ pub struct WispTransportWrite { pub inner: SendWrapper, } -#[async_trait] impl WebSocketWrite for WispTransportWrite { async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { SendWrapper::new(async { diff --git a/client/src/utils/rustls.rs b/client/src/utils/rustls.rs index 602cba5..a5e3101 100644 --- a/client/src/utils/rustls.rs +++ b/client/src/utils/rustls.rs @@ -1,5 +1,8 @@ use std::{ - io::ErrorKind, pin::Pin, sync::Arc, task::{Context, Poll} + io::ErrorKind, + pin::Pin, + sync::Arc, + task::{Context, Poll}, }; use futures_rustls::{ diff --git a/client/src/ws_wrapper.rs b/client/src/ws_wrapper.rs index 85a0668..c86141c 100644 --- a/client/src/ws_wrapper.rs +++ b/client/src/ws_wrapper.rs @@ -3,7 +3,6 @@ use std::sync::{ Arc, }; -use async_trait::async_trait; use bytes::BytesMut; use event_listener::Event; use flume::Receiver; @@ -14,7 +13,7 @@ use thiserror::Error; use wasm_bindgen::{closure::Closure, JsCast, JsValue}; use web_sys::{BinaryType, MessageEvent, WebSocket}; use wisp_mux::{ - ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, + ws::{Frame, LockingWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, WispError, }; @@ -66,11 +65,10 @@ pub struct WebSocketReader { close_event: Arc, } -#[async_trait] impl WebSocketRead for WebSocketReader { async fn wisp_read_frame( &mut self, - _: &LockedWebSocketWrite, + _: &dyn LockingWebSocketWrite, ) -> Result, WispError> { use WebSocketMessage as M; if self.closed.load(Ordering::Acquire) { @@ -185,7 +183,6 @@ impl WebSocketWrapper { } } -#[async_trait] impl WebSocketWrite for WebSocketWrapper { async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { use wisp_mux::ws::OpCode::*; diff --git a/server/Cargo.toml b/server/Cargo.toml index 5328d8e..98e0388 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -24,6 +24,7 @@ lazy_static = "1.5.0" libc = { version = "0.2.158", optional = true } log = { version = "0.4.22", features = ["serde", "std"] } nix = { version = "0.29.0", features = ["term"] } +pin-project-lite = "0.2.15" pty-process = { version = "0.4.0", features = ["async", "tokio"], optional = true } regex = "1.10.6" rustls-pemfile = "2.1.3" diff --git a/server/src/handle/wisp/mod.rs b/server/src/handle/wisp/mod.rs index b49129e..b7b5f3c 100644 --- a/server/src/handle/wisp/mod.rs +++ b/server/src/handle/wisp/mod.rs @@ -25,7 +25,7 @@ use wisp_mux::{ }; use crate::{ - route::WispResult, + route::{WispResult, WispStreamWrite}, stream::{ClientStream, ResolvedPacket}, CLIENTS, CONFIG, }; @@ -58,7 +58,7 @@ async fn copy_read_fast( } async fn copy_write_fast( - muxtx: MuxStreamWrite, + muxtx: MuxStreamWrite, tcprx: OwnedReadHalf, #[cfg(feature = "speed-limit")] limiter: async_speed_limit::Limiter< async_speed_limit::clock::StandardClock, @@ -83,7 +83,7 @@ async fn copy_write_fast( async fn handle_stream( connect: ConnectPacket, - muxstream: MuxStream, + muxstream: MuxStream, id: String, event: Arc, #[cfg(feature = "twisp")] twisp_map: twisp::TwispMap, diff --git a/server/src/main.rs b/server/src/main.rs index 61328fd..3393bba 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -37,6 +37,8 @@ mod route; mod stats; #[doc(hidden)] mod stream; +#[doc(hidden)] +mod util_chain; #[doc(hidden)] type Client = (DashMap, bool); diff --git a/server/src/route.rs b/server/src/route.rs index c6fb601..8642177 100644 --- a/server/src/route.rs +++ b/server/src/route.rs @@ -2,7 +2,7 @@ use std::{fmt::Display, future::Future, io::Cursor}; use anyhow::Context; use bytes::Bytes; -use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector}; +use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector, WebSocketRead, WebSocketWrite}; use http_body_util::Full; use hyper::{ body::Incoming, header::SEC_WEBSOCKET_PROTOCOL, server::conn::http1::Builder, @@ -10,25 +10,30 @@ use hyper::{ }; use hyper_util::rt::TokioIo; use log::{debug, error, trace}; -use tokio::io::AsyncReadExt; use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; use wisp_mux::{ generic::{GenericWebSocketRead, GenericWebSocketWrite}, - ws::{WebSocketRead, WebSocketWrite}, + ws::{EitherWebSocketRead, EitherWebSocketWrite}, }; use crate::{ config::SocketTransport, generate_stats, - listener::{ServerStream, ServerStreamExt}, + listener::{ServerStream, ServerStreamExt, ServerStreamRead, ServerStreamWrite}, stream::WebSocketStreamWrapper, + util_chain::{chain, Chain}, CONFIG, }; -pub type WispResult = ( - Box, - Box, -); +pub type WispStreamRead = EitherWebSocketRead< + WebSocketRead, ServerStreamRead>>, + GenericWebSocketRead, std::io::Error>, +>; +pub type WispStreamWrite = EitherWebSocketWrite< + WebSocketWrite, + GenericWebSocketWrite, std::io::Error>, +>; +pub type WispResult = (WispStreamRead, WispStreamWrite); pub enum ServerRouteResult { Wisp(WispResult, bool), @@ -190,12 +195,15 @@ pub async fn route( .downcast::>() .unwrap(); let (r, w) = parts.io.into_inner().split(); - (Cursor::new(parts.read_buf).chain(r), w) + (chain(Cursor::new(parts.read_buf), r), w) }); (callback)( ServerRouteResult::Wisp( - (Box::new(read), Box::new(write)), + ( + EitherWebSocketRead::Left(read), + EitherWebSocketWrite::Left(write), + ), is_v2, ), maybe_ip, @@ -229,7 +237,13 @@ pub async fn route( let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec)); (callback)( - ServerRouteResult::Wisp((Box::new(read), Box::new(write)), true), + ServerRouteResult::Wisp( + ( + EitherWebSocketRead::Right(read), + EitherWebSocketWrite::Right(write), + ), + true, + ), None, ); } diff --git a/server/src/stream.rs b/server/src/stream.rs index 49d498e..4da30e0 100644 --- a/server/src/stream.rs +++ b/server/src/stream.rs @@ -44,7 +44,6 @@ pub enum ClientStream { Invalid, } - // taken from rust 1.82.0 fn ipv4_is_global(addr: &Ipv4Addr) -> bool { !(addr.octets()[0] == 0 // "This network" diff --git a/server/src/util_chain.rs b/server/src/util_chain.rs new file mode 100644 index 0000000..a1b5b33 --- /dev/null +++ b/server/src/util_chain.rs @@ -0,0 +1,100 @@ +// taken from tokio io util + +use std::{ + fmt, io, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_util::ready; +use pin_project_lite::pin_project; +use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; + +pin_project! { + pub struct Chain { + #[pin] + first: T, + #[pin] + second: U, + done_first: bool, + } +} + +pub fn chain(first: T, second: U) -> Chain +where + T: AsyncRead, + U: AsyncRead, +{ + Chain { + first, + second, + done_first: false, + } +} + +impl fmt::Debug for Chain +where + T: fmt::Debug, + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Chain") + .field("t", &self.first) + .field("u", &self.second) + .finish() + } +} + +impl AsyncRead for Chain +where + T: AsyncRead, + U: AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let me = self.project(); + + if !*me.done_first { + let rem = buf.remaining(); + ready!(me.first.poll_read(cx, buf))?; + if buf.remaining() == rem { + *me.done_first = true; + } else { + return Poll::Ready(Ok(())); + } + } + me.second.poll_read(cx, buf) + } +} + +impl AsyncBufRead for Chain +where + T: AsyncBufRead, + U: AsyncBufRead, +{ + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + + if !*me.done_first { + match ready!(me.first.poll_fill_buf(cx)?) { + [] => { + *me.done_first = true; + } + buf => return Poll::Ready(Ok(buf)), + } + } + me.second.poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + let me = self.project(); + if !*me.done_first { + me.first.consume(amt) + } else { + me.second.consume(amt) + } + } +} diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index d89a0ab..58ae617 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -245,7 +245,7 @@ async fn main() -> Result<(), Box> { })); threads.push(tokio::spawn(async move { loop { - cr.read().await; + let _ = cr.read().await; } })); } diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 539f253..a2a1159 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -23,6 +23,7 @@ futures = "0.3.30" getrandom = { version = "0.2.15", features = ["std"], optional = true } nohash-hasher = "0.2.0" pin-project-lite = "0.2.14" +reusable-box-future = "0.2.0" thiserror = "1.0.65" tokio = { version = "1.39.3", optional = true, default-features = false } diff --git a/wisp/src/extensions/cert.rs b/wisp/src/extensions/cert.rs index 3bdee97..1112510 100644 --- a/wisp/src/extensions/cert.rs +++ b/wisp/src/extensions/cert.rs @@ -11,7 +11,7 @@ use ed25519::{ }; use crate::{ - ws::{LockedWebSocketWrite, WebSocketRead}, + ws::{DynWebSocketRead, LockingWebSocketWrite}, Role, WispError, }; @@ -183,8 +183,8 @@ impl ProtocolExtension for CertAuthProtocolExtension { async fn handle_handshake( &mut self, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, + _: &mut DynWebSocketRead, + _: &dyn LockingWebSocketWrite, ) -> Result<(), WispError> { Ok(()) } @@ -192,8 +192,8 @@ impl ProtocolExtension for CertAuthProtocolExtension { async fn handle_packet( &mut self, _: Bytes, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, + _: &mut DynWebSocketRead, + _: &dyn LockingWebSocketWrite, ) -> Result<(), WispError> { Ok(()) } diff --git a/wisp/src/extensions/mod.rs b/wisp/src/extensions/mod.rs index 0db95be..b4eaf5f 100644 --- a/wisp/src/extensions/mod.rs +++ b/wisp/src/extensions/mod.rs @@ -14,7 +14,7 @@ use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; use crate::{ - ws::{LockedWebSocketWrite, WebSocketRead}, + ws::{DynWebSocketRead, LockingWebSocketWrite}, Role, WispError, }; @@ -105,16 +105,16 @@ pub trait ProtocolExtension: std::fmt::Debug + Sync + Send + 'static { /// This should be used to send or receive data before any streams are created. async fn handle_handshake( &mut self, - read: &mut dyn WebSocketRead, - write: &LockedWebSocketWrite, + read: &mut DynWebSocketRead, + write: &dyn LockingWebSocketWrite, ) -> Result<(), WispError>; /// Handle receiving a packet. async fn handle_packet( &mut self, packet: Bytes, - read: &mut dyn WebSocketRead, - write: &LockedWebSocketWrite, + read: &mut DynWebSocketRead, + write: &dyn LockingWebSocketWrite, ) -> Result<(), WispError>; /// Clone the protocol extension. diff --git a/wisp/src/extensions/motd.rs b/wisp/src/extensions/motd.rs index d93349c..0747df9 100644 --- a/wisp/src/extensions/motd.rs +++ b/wisp/src/extensions/motd.rs @@ -6,7 +6,7 @@ use async_trait::async_trait; use bytes::Bytes; use crate::{ - ws::{LockedWebSocketWrite, WebSocketRead}, + ws::{DynWebSocketRead, LockingWebSocketWrite}, Role, WispError, }; @@ -48,8 +48,8 @@ impl ProtocolExtension for MotdProtocolExtension { async fn handle_handshake( &mut self, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, + _: &mut DynWebSocketRead, + _: &dyn LockingWebSocketWrite, ) -> Result<(), WispError> { Ok(()) } @@ -57,8 +57,8 @@ impl ProtocolExtension for MotdProtocolExtension { async fn handle_packet( &mut self, _: Bytes, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, + _: &mut DynWebSocketRead, + _: &dyn LockingWebSocketWrite, ) -> Result<(), WispError> { Ok(()) } diff --git a/wisp/src/extensions/password.rs b/wisp/src/extensions/password.rs index ad5f71f..08b26f6 100644 --- a/wisp/src/extensions/password.rs +++ b/wisp/src/extensions/password.rs @@ -9,7 +9,7 @@ use async_trait::async_trait; use bytes::{Buf, BufMut, Bytes, BytesMut}; use crate::{ - ws::{LockedWebSocketWrite, WebSocketRead}, + ws::{DynWebSocketRead, LockingWebSocketWrite}, Role, WispError, }; @@ -94,17 +94,17 @@ impl ProtocolExtension for PasswordProtocolExtension { async fn handle_handshake( &mut self, - _read: &mut dyn WebSocketRead, - _write: &LockedWebSocketWrite, + _: &mut DynWebSocketRead, + _: &dyn LockingWebSocketWrite, ) -> Result<(), WispError> { Ok(()) } async fn handle_packet( &mut self, - _packet: Bytes, - _read: &mut dyn WebSocketRead, - _write: &LockedWebSocketWrite, + _: Bytes, + _: &mut DynWebSocketRead, + _: &dyn LockingWebSocketWrite, ) -> Result<(), WispError> { Err(WispError::ExtensionImplNotSupported) } diff --git a/wisp/src/extensions/udp.rs b/wisp/src/extensions/udp.rs index 33bea07..1bb32c0 100644 --- a/wisp/src/extensions/udp.rs +++ b/wisp/src/extensions/udp.rs @@ -5,7 +5,7 @@ use async_trait::async_trait; use bytes::Bytes; use crate::{ - ws::{LockedWebSocketWrite, WebSocketRead}, + ws::{DynWebSocketRead, LockingWebSocketWrite}, WispError, }; @@ -40,8 +40,8 @@ impl ProtocolExtension for UdpProtocolExtension { async fn handle_handshake( &mut self, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, + _: &mut DynWebSocketRead, + _: &dyn LockingWebSocketWrite, ) -> Result<(), WispError> { Ok(()) } @@ -49,8 +49,8 @@ impl ProtocolExtension for UdpProtocolExtension { async fn handle_packet( &mut self, _: Bytes, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, + _: &mut DynWebSocketRead, + _: &dyn LockingWebSocketWrite, ) -> Result<(), WispError> { Ok(()) } diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index 63a463a..9ce1d9c 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -2,7 +2,6 @@ use std::ops::Deref; -use async_trait::async_trait; use bytes::BytesMut; use fastwebsockets::{ CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketRead, @@ -10,7 +9,7 @@ use fastwebsockets::{ }; use tokio::io::{AsyncRead, AsyncWrite}; -use crate::{ws::LockedWebSocketWrite, WispError}; +use crate::{ws::LockingWebSocketWrite, WispError}; fn match_payload(payload: Payload<'_>) -> crate::ws::Payload<'_> { match payload { @@ -87,27 +86,25 @@ impl From for crate::WispError { } } -#[async_trait] impl crate::ws::WebSocketRead for FragmentCollectorRead { async fn wisp_read_frame( &mut self, - tx: &LockedWebSocketWrite, + tx: &dyn LockingWebSocketWrite, ) -> Result, WispError> { Ok(self - .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) + .read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await }) .await? .into()) } } -#[async_trait] impl crate::ws::WebSocketRead for WebSocketRead { async fn wisp_read_frame( &mut self, - tx: &LockedWebSocketWrite, + tx: &dyn LockingWebSocketWrite, ) -> Result, WispError> { let mut frame = self - .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) + .read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await }) .await?; if frame.opcode == OpCode::Continuation { @@ -121,7 +118,7 @@ impl crate::ws::WebSocketRead for WebSocketRead while !frame.fin { frame = self - .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) + .read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await }) .await?; if frame.opcode != OpCode::Continuation { @@ -142,11 +139,11 @@ impl crate::ws::WebSocketRead for WebSocketRead async fn wisp_read_split( &mut self, - tx: &LockedWebSocketWrite, + tx: &dyn LockingWebSocketWrite, ) -> Result<(crate::ws::Frame<'static>, Option>), WispError> { let mut frame_cnt = 1; let mut frame = self - .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) + .read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await }) .await?; let mut extra_frame = None; @@ -161,7 +158,7 @@ impl crate::ws::WebSocketRead for WebSocketRead while !frame.fin { frame = self - .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) + .read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await }) .await?; if frame.opcode != OpCode::Continuation { @@ -197,7 +194,6 @@ impl crate::ws::WebSocketRead for WebSocketRead } } -#[async_trait] impl crate::ws::WebSocketWrite for WebSocketWrite { async fn wisp_write_frame(&mut self, frame: crate::ws::Frame<'_>) -> Result<(), WispError> { self.write_frame(frame.into()).await.map_err(|e| e.into()) diff --git a/wisp/src/generic.rs b/wisp/src/generic.rs index 5589623..45a8e5e 100644 --- a/wisp/src/generic.rs +++ b/wisp/src/generic.rs @@ -1,12 +1,11 @@ //! WebSocketRead + WebSocketWrite implementation for generic `Stream + Sink`s. -use async_trait::async_trait; use bytes::{Bytes, BytesMut}; use futures::{Sink, SinkExt, Stream, StreamExt}; use std::error::Error; use crate::{ - ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead, WebSocketWrite}, + ws::{Frame, LockingWebSocketWrite, OpCode, Payload, WebSocketRead, WebSocketWrite}, WispError, }; @@ -30,13 +29,12 @@ impl> + Send + Unpin, E: Error + Sync + Sen } } -#[async_trait] impl> + Send + Unpin, E: Error + Sync + Send + 'static> WebSocketRead for GenericWebSocketRead { async fn wisp_read_frame( &mut self, - _tx: &LockedWebSocketWrite, + _tx: &dyn LockingWebSocketWrite, ) -> Result, WispError> { match self.0.next().await { Some(data) => Ok(Frame::binary(Payload::Bytes( @@ -67,7 +65,6 @@ impl + Send + Unpin, E: Error + Sync + Send + 'static> } } -#[async_trait] impl + Send + Unpin, E: Error + Sync + Send + 'static> WebSocketWrite for GenericWebSocketWrite { diff --git a/wisp/src/mux/client.rs b/wisp/src/mux/client.rs index 80a8031..56d611a 100644 --- a/wisp/src/mux/client.rs +++ b/wisp/src/mux/client.rs @@ -12,7 +12,7 @@ use futures::channel::oneshot; use crate::{ extensions::{udp::UdpProtocolExtension, AnyProtocolExtension}, mux::send_info_packet, - ws::{LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, + ws::{DynWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType, WispError, }; @@ -24,9 +24,9 @@ use super::{ WispV2Handshake, }; -async fn handshake( +async fn handshake( rx: &mut R, - tx: &LockedWebSocketWrite, + tx: &LockedWebSocketWrite, v2_info: Option, ) -> Result<(WispHandshakeResult, u32), WispError> { if let Some(WispV2Handshake { @@ -47,7 +47,9 @@ async fn handshake( let mut supported_extensions = get_supported_extensions(info.extensions, &mut builders); for extension in supported_extensions.iter_mut() { - extension.handle_handshake(rx, tx).await?; + extension + .handle_handshake(DynWebSocketRead::from_mut(rx), tx) + .await?; } Ok(( @@ -86,34 +88,36 @@ async fn handshake( } /// Client side multiplexor. -pub struct ClientMux { +pub struct ClientMux { /// Whether the connection was downgraded to Wisp v1. /// /// If this variable is true you must assume no extensions are supported. pub downgraded: bool, /// Extensions that are supported by both sides. pub supported_extensions: Vec, - actor_tx: mpsc::Sender, - tx: LockedWebSocketWrite, + actor_tx: mpsc::Sender>, + tx: LockedWebSocketWrite, actor_exited: Arc, } -impl ClientMux { +impl ClientMux { /// Create a new client side multiplexor. /// /// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. /// **It is not guaranteed that all extensions you specify are available.** You must manually check /// if the extensions you need are available after the multiplexor has been created. - pub async fn create( + pub async fn create( mut rx: R, tx: W, wisp_v2: Option, - ) -> Result> + Send>, WispError> + ) -> Result< + MuxResult, impl Future> + Send>, + WispError, + > where - R: WebSocketRead + Send, - W: WebSocketWrite + Send + 'static, + R: WebSocketRead + 'static, { - let tx = LockedWebSocketWrite::new(Box::new(tx)); + let tx = LockedWebSocketWrite::new(tx); let (handshake_result, buffer_size) = handshake(&mut rx, &tx, wisp_v2).await?; let (extensions, extra_packet) = handshake_result.kind.into_parts(); @@ -146,7 +150,7 @@ impl ClientMux { stream_type: StreamType, host: String, port: u16, - ) -> Result { + ) -> Result, WispError> { if self.actor_exited.load(Ordering::Acquire) { return Err(WispError::MuxTaskEnded); } @@ -206,7 +210,7 @@ impl ClientMux { } /// Get a protocol extension stream for sending packets with stream id 0. - pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { MuxProtocolExtensionStream { stream_id: 0, tx: self.tx.clone(), @@ -215,13 +219,13 @@ impl ClientMux { } } -impl Drop for ClientMux { +impl Drop for ClientMux { fn drop(&mut self) { let _ = self.actor_tx.send(WsEvent::EndFut(None)); } } -impl Multiplexor for ClientMux { +impl Multiplexor for ClientMux { fn has_extension(&self, extension_id: u8) -> bool { self.supported_extensions .iter() diff --git a/wisp/src/mux/inner.rs b/wisp/src/mux/inner.rs index 1b41386..a7afa2f 100644 --- a/wisp/src/mux/inner.rs +++ b/wisp/src/mux/inner.rs @@ -5,23 +5,23 @@ use std::sync::{ use crate::{ extensions::AnyProtocolExtension, - ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead}, + ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead, WebSocketWrite}, AtomicCloseReason, ClosePacket, CloseReason, ConnectPacket, MuxStream, Packet, PacketType, Role, StreamType, WispError, }; use bytes::{Bytes, BytesMut}; use event_listener::Event; use flume as mpsc; -use futures::{channel::oneshot, select, FutureExt}; +use futures::{channel::oneshot, select, stream::unfold, FutureExt, StreamExt}; use nohash_hasher::IntMap; -pub(crate) enum WsEvent { +pub(crate) enum WsEvent { Close(Packet<'static>, oneshot::Sender>), CreateStream( StreamType, String, u16, - oneshot::Sender>, + oneshot::Sender, WispError>>, ), SendPing(Payload<'static>, oneshot::Sender>), SendPong(Payload<'static>), @@ -43,20 +43,21 @@ struct MuxMapValue { is_closed_event: Arc, } -pub struct MuxInner { +pub struct MuxInner { // gets taken by the mux task rx: Option, // gets taken by the mux task maybe_downgrade_packet: Option>, - tx: LockedWebSocketWrite, - extensions: Vec, + tx: LockedWebSocketWrite, + // gets taken by the mux task + extensions: Option>, tcp_extensions: Vec, role: Role, // gets taken by the mux task - actor_rx: Option>, - actor_tx: mpsc::Sender, + actor_rx: Option>>, + actor_tx: mpsc::Sender>, fut_exited: Arc, stream_map: IntMap, @@ -64,16 +65,16 @@ pub struct MuxInner { buffer_size: u32, target_buffer_size: u32, - server_tx: mpsc::Sender<(ConnectPacket, MuxStream)>, + server_tx: mpsc::Sender<(ConnectPacket, MuxStream)>, } -pub struct MuxInnerResult { - pub mux: MuxInner, +pub struct MuxInnerResult { + pub mux: MuxInner, pub actor_exited: Arc, - pub actor_tx: mpsc::Sender, + pub actor_tx: mpsc::Sender>, } -impl MuxInner { +impl MuxInner { fn get_tcp_extensions(extensions: &[AnyProtocolExtension]) -> Vec { extensions .iter() @@ -83,18 +84,19 @@ impl MuxInner { .collect() } + #[allow(clippy::type_complexity)] pub fn new_server( rx: R, maybe_downgrade_packet: Option>, - tx: LockedWebSocketWrite, + tx: LockedWebSocketWrite, extensions: Vec, buffer_size: u32, ) -> ( - MuxInnerResult, - mpsc::Receiver<(ConnectPacket, MuxStream)>, + MuxInnerResult, + mpsc::Receiver<(ConnectPacket, MuxStream)>, ) { - let (fut_tx, fut_rx) = mpsc::bounded::(256); - let (server_tx, server_rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); + let (fut_tx, fut_rx) = mpsc::bounded::>(256); + let (server_tx, server_rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); let ret_fut_tx = fut_tx.clone(); let fut_exited = Arc::new(AtomicBool::new(false)); @@ -110,7 +112,7 @@ impl MuxInner { fut_exited: fut_exited.clone(), tcp_extensions: Self::get_tcp_extensions(&extensions), - extensions, + extensions: Some(extensions), buffer_size, target_buffer_size: ((buffer_size as u64 * 90) / 100) as u32, @@ -130,12 +132,12 @@ impl MuxInner { pub fn new_client( rx: R, maybe_downgrade_packet: Option>, - tx: LockedWebSocketWrite, + tx: LockedWebSocketWrite, extensions: Vec, buffer_size: u32, - ) -> MuxInnerResult { - let (fut_tx, fut_rx) = mpsc::bounded::(256); - let (server_tx, _) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); + ) -> MuxInnerResult { + let (fut_tx, fut_rx) = mpsc::bounded::>(256); + let (server_tx, _) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); let ret_fut_tx = fut_tx.clone(); let fut_exited = Arc::new(AtomicBool::new(false)); @@ -150,7 +152,7 @@ impl MuxInner { fut_exited: fut_exited.clone(), tcp_extensions: Self::get_tcp_extensions(&extensions), - extensions, + extensions: Some(extensions), buffer_size, target_buffer_size: 0, @@ -183,7 +185,7 @@ impl MuxInner { &mut self, stream_id: u32, stream_type: StreamType, - ) -> Result<(MuxMapValue, MuxStream), WispError> { + ) -> Result<(MuxMapValue, MuxStream), WispError> { let (ch_tx, ch_rx) = mpsc::bounded(if self.role == Role::Server { self.buffer_size as usize } else { @@ -241,11 +243,12 @@ impl MuxInner { } async fn process_wisp_message( - &mut self, rx: &mut R, - msg: Result<(Frame<'static>, Option>), WispError>, - ) -> Result, WispError> { - let (mut frame, optional_frame) = msg?; + tx: &LockedWebSocketWrite, + extensions: &mut [AnyProtocolExtension], + msg: (Frame<'static>, Option>), + ) -> Result>, WispError> { + let (mut frame, optional_frame) = msg; if frame.opcode == OpCode::Close { return Ok(None); } else if frame.opcode == OpCode::Ping { @@ -262,8 +265,7 @@ impl MuxInner { } } - let packet = - Packet::maybe_handle_extension(frame, &mut self.extensions, rx, &self.tx).await?; + let packet = Packet::maybe_handle_extension(frame, extensions, rx, tx).await?; Ok(Some(WsEvent::WispMessage(packet, optional_frame))) } @@ -271,36 +273,47 @@ impl MuxInner { async fn stream_loop(&mut self) -> Result<(), WispError> { let mut next_free_stream_id: u32 = 1; - let mut rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?; + let rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?; let maybe_downgrade_packet = self.maybe_downgrade_packet.take(); let tx = self.tx.clone(); let fut_rx = self.actor_rx.take().ok_or(WispError::MuxTaskStarted)?; + let extensions = self.extensions.take().ok_or(WispError::MuxTaskStarted)?; + if let Some(downgrade_packet) = maybe_downgrade_packet { if self.handle_packet(downgrade_packet, None).await? { return Ok(()); } } + let mut read_stream = Box::pin(unfold( + (rx, tx.clone(), extensions), + |(mut rx, tx, mut extensions)| async { + let ret = async { + let msg = rx.wisp_read_split(&tx).await?; + Self::process_wisp_message(&mut rx, &tx, &mut extensions, msg).await + } + .await; + ret.transpose().map(|x| (x, (rx, tx, extensions))) + }, + )) + .fuse(); + let mut recv_fut = fut_rx.recv_async().fuse(); - let mut read_fut = rx.wisp_read_split(&tx).fuse(); while let Some(msg) = select! { x = recv_fut => { drop(recv_fut); recv_fut = fut_rx.recv_async().fuse(); Ok(x.ok()) }, - x = read_fut => { - drop(read_fut); - let ret = self.process_wisp_message(&mut rx, x).await; - read_fut = rx.wisp_read_split(&tx).fuse(); - ret + x = read_stream.next() => { + x.transpose() } }? { match msg { WsEvent::CreateStream(stream_type, host, port, channel) => { - let ret: Result = async { + let ret: Result, WispError> = async { let stream_id = next_free_stream_id; let next_stream_id = next_free_stream_id .checked_add(1) diff --git a/wisp/src/mux/mod.rs b/wisp/src/mux/mod.rs index b784704..5b75866 100644 --- a/wisp/src/mux/mod.rs +++ b/wisp/src/mux/mod.rs @@ -8,7 +8,7 @@ pub use server::ServerMux; use crate::{ extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder}, - ws::LockedWebSocketWrite, + ws::{LockedWebSocketWrite, WebSocketWrite}, CloseReason, Packet, PacketType, Role, WispError, }; @@ -35,8 +35,8 @@ impl WispHandshakeResultKind { } } -async fn send_info_packet( - write: &LockedWebSocketWrite, +async fn send_info_packet( + write: &LockedWebSocketWrite, builders: &mut [AnyProtocolExtensionBuilder], ) -> Result<(), WispError> { write diff --git a/wisp/src/mux/server.rs b/wisp/src/mux/server.rs index 31a3f56..688c044 100644 --- a/wisp/src/mux/server.rs +++ b/wisp/src/mux/server.rs @@ -11,7 +11,7 @@ use futures::channel::oneshot; use crate::{ extensions::AnyProtocolExtension, - ws::{LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, + ws::{DynWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, WispError, }; @@ -23,9 +23,9 @@ use super::{ WispV2Handshake, }; -async fn handshake( +async fn handshake( rx: &mut R, - tx: &LockedWebSocketWrite, + tx: &LockedWebSocketWrite, buffer_size: u32, v2_info: Option, ) -> Result { @@ -47,7 +47,9 @@ async fn handshake( let mut supported_extensions = get_supported_extensions(info.extensions, &mut builders); for extension in supported_extensions.iter_mut() { - extension.handle_handshake(rx, tx).await?; + extension + .handle_handshake(DynWebSocketRead::from_mut(rx), tx) + .await?; } // v2 client @@ -79,36 +81,38 @@ async fn handshake( } /// Server-side multiplexor. -pub struct ServerMux { +pub struct ServerMux { /// Whether the connection was downgraded to Wisp v1. /// /// If this variable is true you must assume no extensions are supported. pub downgraded: bool, /// Extensions that are supported by both sides. pub supported_extensions: Vec, - actor_tx: mpsc::Sender, - muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>, - tx: LockedWebSocketWrite, + actor_tx: mpsc::Sender>, + muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>, + tx: LockedWebSocketWrite, actor_exited: Arc, } -impl ServerMux { +impl ServerMux { /// Create a new server-side multiplexor. /// /// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. /// **It is not guaranteed that all extensions you specify are available.** You must manually check /// if the extensions you need are available after the multiplexor has been created. - pub async fn create( + pub async fn create( mut rx: R, tx: W, buffer_size: u32, wisp_v2: Option, - ) -> Result> + Send>, WispError> + ) -> Result< + MuxResult, impl Future> + Send>, + WispError, + > where - R: WebSocketRead + Send, - W: WebSocketWrite + Send + 'static, + R: WebSocketRead + Send + 'static, { - let tx = LockedWebSocketWrite::new(Box::new(tx)); + let tx = LockedWebSocketWrite::new(tx); let ret_tx = tx.clone(); let ret = async { let handshake_result = handshake(&mut rx, &tx, buffer_size, wisp_v2).await?; @@ -165,7 +169,7 @@ impl ServerMux { } /// Wait for a stream to be created. - pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> { + pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> { if self.actor_exited.load(Ordering::Acquire) { return None; } @@ -210,7 +214,7 @@ impl ServerMux { } /// Get a protocol extension stream for sending packets with stream id 0. - pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { MuxProtocolExtensionStream { stream_id: 0, tx: self.tx.clone(), @@ -219,13 +223,13 @@ impl ServerMux { } } -impl Drop for ServerMux { +impl Drop for ServerMux { fn drop(&mut self) { let _ = self.actor_tx.send(WsEvent::EndFut(None)); } } -impl Multiplexor for ServerMux { +impl Multiplexor for ServerMux { fn has_extension(&self, extension_id: u8) -> bool { self.supported_extensions .iter() diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index b17db4e..f3b6463 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -2,7 +2,10 @@ use std::fmt::Display; use crate::{ extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder}, - ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead}, + ws::{ + self, DynWebSocketRead, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead, + WebSocketWrite, + }, Role, WispError, WISP_VERSION, }; use bytes::{Buf, BufMut, Bytes, BytesMut}; @@ -527,11 +530,11 @@ impl<'a> Packet<'a> { } } - pub(crate) async fn maybe_handle_extension( + pub(crate) async fn maybe_handle_extension( frame: Frame<'a>, extensions: &mut [AnyProtocolExtension], - read: &mut (dyn WebSocketRead + Send), - write: &LockedWebSocketWrite, + read: &mut R, + write: &LockedWebSocketWrite, ) -> Result, WispError> { if !frame.finished { return Err(WispError::WsFrameNotFinished); @@ -568,7 +571,11 @@ impl<'a> Packet<'a> { .find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type)) { extension - .handle_packet(BytesMut::from(bytes).freeze(), read, write) + .handle_packet( + BytesMut::from(bytes).freeze(), + DynWebSocketRead::from_mut(read), + write, + ) .await?; Ok(None) } else { diff --git a/wisp/src/stream/compat.rs b/wisp/src/stream/compat.rs index dc5c7e5..6a69954 100644 --- a/wisp/src/stream/compat.rs +++ b/wisp/src/stream/compat.rs @@ -98,7 +98,10 @@ impl MuxStreamIoStream { impl Stream for MuxStreamIoStream { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().rx.poll_next(cx).map_err(std::io::Error::other) + self.project() + .rx + .poll_next(cx) + .map_err(std::io::Error::other) } } diff --git a/wisp/src/stream/mod.rs b/wisp/src/stream/mod.rs index 38e9f4c..317ddf7 100644 --- a/wisp/src/stream/mod.rs +++ b/wisp/src/stream/mod.rs @@ -4,7 +4,7 @@ pub use compat::*; use crate::{ inner::WsEvent, - ws::{Frame, LockedWebSocketWrite, Payload}, + ws::{Frame, LockedWebSocketWrite, Payload, WebSocketWrite}, AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError, }; @@ -21,7 +21,7 @@ use std::{ }; /// Read side of a multiplexor stream. -pub struct MuxStreamRead { +pub struct MuxStreamRead { /// ID of the stream. pub stream_id: u32, /// Type of the stream. @@ -29,7 +29,7 @@ pub struct MuxStreamRead { role: Role, - tx: LockedWebSocketWrite, + tx: LockedWebSocketWrite, rx: mpsc::Receiver, is_closed: Arc, @@ -42,7 +42,7 @@ pub struct MuxStreamRead { target_flow_control: u32, } -impl MuxStreamRead { +impl MuxStreamRead { /// Read an event from the stream. pub async fn read(&self) -> Result, WispError> { if self.rx.is_empty() && self.is_closed.load(Ordering::Acquire) { @@ -98,15 +98,15 @@ impl MuxStreamRead { } /// Write side of a multiplexor stream. -pub struct MuxStreamWrite { +pub struct MuxStreamWrite { /// ID of the stream. pub stream_id: u32, /// Type of the stream. pub stream_type: StreamType, role: Role, - mux_tx: mpsc::Sender, - tx: LockedWebSocketWrite, + mux_tx: mpsc::Sender>, + tx: LockedWebSocketWrite, is_closed: Arc, close_reason: Arc, @@ -116,7 +116,7 @@ pub struct MuxStreamWrite { flow_control: Arc, } -impl MuxStreamWrite { +impl MuxStreamWrite { pub(crate) async fn write_payload_internal<'a>( &self, header: Frame<'static>, @@ -169,7 +169,7 @@ impl MuxStreamWrite { /// handle.close(0x01); /// } /// ``` - pub fn get_close_handle(&self) -> MuxStreamCloser { + pub fn get_close_handle(&self) -> MuxStreamCloser { MuxStreamCloser { stream_id: self.stream_id, close_channel: self.mux_tx.clone(), @@ -179,7 +179,7 @@ impl MuxStreamWrite { } /// Get a protocol extension stream to send protocol extension packets. - pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { MuxProtocolExtensionStream { stream_id: self.stream_id, tx: self.tx.clone(), @@ -244,7 +244,7 @@ impl MuxStreamWrite { } } -impl Drop for MuxStreamWrite { +impl Drop for MuxStreamWrite { fn drop(&mut self) { if !self.is_closed.load(Ordering::Acquire) { self.is_closed.store(true, Ordering::Release); @@ -258,22 +258,22 @@ impl Drop for MuxStreamWrite { } /// Multiplexor stream. -pub struct MuxStream { +pub struct MuxStream { /// ID of the stream. pub stream_id: u32, - rx: MuxStreamRead, - tx: MuxStreamWrite, + rx: MuxStreamRead, + tx: MuxStreamWrite, } -impl MuxStream { +impl MuxStream { #[allow(clippy::too_many_arguments)] pub(crate) fn new( stream_id: u32, role: Role, stream_type: StreamType, rx: mpsc::Receiver, - mux_tx: mpsc::Sender, - tx: LockedWebSocketWrite, + mux_tx: mpsc::Sender>, + tx: LockedWebSocketWrite, is_closed: Arc, is_closed_event: Arc, close_reason: Arc, @@ -339,12 +339,12 @@ impl MuxStream { /// handle.close(0x01); /// } /// ``` - pub fn get_close_handle(&self) -> MuxStreamCloser { + pub fn get_close_handle(&self) -> MuxStreamCloser { self.tx.get_close_handle() } /// Get a protocol extension stream to send protocol extension packets. - pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { self.tx.get_protocol_extension_stream() } @@ -359,7 +359,7 @@ impl MuxStream { } /// Split the stream into read and write parts, consuming it. - pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) { + pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) { (self.rx, self.tx) } @@ -374,15 +374,15 @@ impl MuxStream { /// Close handle for a multiplexor stream. #[derive(Clone)] -pub struct MuxStreamCloser { +pub struct MuxStreamCloser { /// ID of the stream. pub stream_id: u32, - close_channel: mpsc::Sender, + close_channel: mpsc::Sender>, is_closed: Arc, close_reason: Arc, } -impl MuxStreamCloser { +impl MuxStreamCloser { /// Close the stream. You will no longer be able to write or read after this has been called. pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { @@ -414,14 +414,14 @@ impl MuxStreamCloser { } /// Stream for sending arbitrary protocol extension packets. -pub struct MuxProtocolExtensionStream { +pub struct MuxProtocolExtensionStream { /// ID of the stream. pub stream_id: u32, - pub(crate) tx: LockedWebSocketWrite, + pub(crate) tx: LockedWebSocketWrite, pub(crate) is_closed: Arc, } -impl MuxProtocolExtensionStream { +impl MuxProtocolExtensionStream { /// Send a protocol extension packet with this stream's ID. pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index 1d1ca78..738c970 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -4,12 +4,11 @@ //! for other WebSocket implementations. //! //! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs -use std::{ops::Deref, sync::Arc}; +use std::{future::Future, ops::Deref, pin::Pin, sync::Arc}; use crate::WispError; -use async_trait::async_trait; use bytes::{Buf, BytesMut}; -use futures::lock::Mutex; +use futures::{lock::Mutex, TryFutureExt}; /// Payload of the websocket frame. #[derive(Debug)] @@ -158,101 +157,286 @@ impl<'a> Frame<'a> { } /// Generic WebSocket read trait. -#[async_trait] -pub trait WebSocketRead { +pub trait WebSocketRead: Send { /// Read a frame from the socket. - async fn wisp_read_frame( + fn wisp_read_frame( &mut self, - tx: &LockedWebSocketWrite, - ) -> Result, WispError>; + tx: &dyn LockingWebSocketWrite, + ) -> impl Future, WispError>> + Send; /// Read a split frame from the socket. - async fn wisp_read_split( + fn wisp_read_split( &mut self, - tx: &LockedWebSocketWrite, - ) -> Result<(Frame<'static>, Option>), WispError> { - self.wisp_read_frame(tx).await.map(|x| (x, None)) + tx: &dyn LockingWebSocketWrite, + ) -> impl Future, Option>), WispError>> + Send { + self.wisp_read_frame(tx).map_ok(|x| (x, None)) } } -#[async_trait] -impl WebSocketRead for Box { - async fn wisp_read_frame( - &mut self, - tx: &LockedWebSocketWrite, - ) -> Result, WispError> { - self.as_mut().wisp_read_frame(tx).await +// similar to what dynosaur does +mod wsr_inner { + use std::{future::Future, pin::Pin}; + + use crate::WispError; + + use super::{Frame, LockingWebSocketWrite, WebSocketRead}; + + trait ErasedWebSocketRead: Send { + fn wisp_read_frame<'a>( + &'a mut self, + tx: &'a dyn LockingWebSocketWrite, + ) -> Pin, WispError>> + Send + 'a>>; + + #[allow(clippy::type_complexity)] + fn wisp_read_split<'a>( + &'a mut self, + tx: &'a dyn LockingWebSocketWrite, + ) -> Pin< + Box< + dyn Future, Option>), WispError>> + + Send + + 'a, + >, + >; } - async fn wisp_read_split( - &mut self, - tx: &LockedWebSocketWrite, - ) -> Result<(Frame<'static>, Option>), WispError> { - self.as_mut().wisp_read_split(tx).await + impl ErasedWebSocketRead for T { + fn wisp_read_frame<'a>( + &'a mut self, + tx: &'a dyn LockingWebSocketWrite, + ) -> Pin, WispError>> + Send + 'a>> { + Box::pin(self.wisp_read_frame(tx)) + } + + fn wisp_read_split<'a>( + &'a mut self, + tx: &'a dyn LockingWebSocketWrite, + ) -> Pin< + Box< + dyn Future, Option>), WispError>> + + Send + + 'a, + >, + > { + Box::pin(self.wisp_read_split(tx)) + } + } + + /// WebSocketRead trait object. + #[repr(transparent)] + pub struct DynWebSocketRead { + ptr: dyn ErasedWebSocketRead + 'static, + } + impl WebSocketRead for DynWebSocketRead { + async fn wisp_read_frame( + &mut self, + tx: &dyn LockingWebSocketWrite, + ) -> Result, WispError> { + self.ptr.wisp_read_frame(tx).await + } + + async fn wisp_read_split( + &mut self, + tx: &dyn LockingWebSocketWrite, + ) -> Result<(Frame<'static>, Option>), WispError> { + self.ptr.wisp_read_split(tx).await + } + } + impl DynWebSocketRead { + /// Create a WebSocketRead trait object from a boxed WebSocketRead. + pub fn new(val: Box) -> Box { + let val: Box = val; + unsafe { std::mem::transmute(val) } + } + /// Create a WebSocketRead trait object from a WebSocketRead. + pub fn boxed(val: impl WebSocketRead + 'static) -> Box { + Self::new(Box::new(val)) + } + /// Create a WebSocketRead trait object from a WebSocketRead reference. + pub fn from_ref(val: &(impl WebSocketRead + 'static)) -> &Self { + let val: &(dyn ErasedWebSocketRead + 'static) = val; + unsafe { std::mem::transmute(val) } + } + /// Create a WebSocketRead trait object from a mutable WebSocketRead reference. + pub fn from_mut(val: &mut (impl WebSocketRead + 'static)) -> &mut Self { + let val: &mut (dyn ErasedWebSocketRead + 'static) = &mut *val; + unsafe { std::mem::transmute(val) } + } } } +pub use wsr_inner::DynWebSocketRead; /// Generic WebSocket write trait. -#[async_trait] -pub trait WebSocketWrite { +pub trait WebSocketWrite: Send { /// Write a frame to the socket. - async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError>; - - /// Close the socket. - async fn wisp_close(&mut self) -> Result<(), WispError>; + fn wisp_write_frame( + &mut self, + frame: Frame<'_>, + ) -> impl Future> + Send; /// Write a split frame to the socket. - async fn wisp_write_split( + fn wisp_write_split( &mut self, header: Frame<'_>, body: Frame<'_>, - ) -> Result<(), WispError> { - let mut payload = BytesMut::from(header.payload); - payload.extend_from_slice(&body.payload); - self.wisp_write_frame(Frame::binary(Payload::Bytes(payload))) - .await + ) -> impl Future> + Send { + async move { + let mut payload = BytesMut::from(header.payload); + payload.extend_from_slice(&body.payload); + self.wisp_write_frame(Frame::binary(Payload::Bytes(payload))) + .await + } } + + /// Close the socket. + fn wisp_close(&mut self) -> impl Future> + Send; } -#[async_trait] -impl WebSocketWrite for Box { - async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { - self.as_mut().wisp_write_frame(frame).await +// similar to what dynosaur does +mod wsw_inner { + use std::{future::Future, pin::Pin}; + + use crate::WispError; + + use super::{Frame, WebSocketWrite}; + + trait ErasedWebSocketWrite: Send { + fn wisp_write_frame<'a>( + &'a mut self, + frame: Frame<'a>, + ) -> Pin> + Send + 'a>>; + + fn wisp_write_split<'a>( + &'a mut self, + header: Frame<'a>, + body: Frame<'a>, + ) -> Pin> + Send + 'a>>; + + fn wisp_close<'a>( + &'a mut self, + ) -> Pin> + Send + 'a>>; } - async fn wisp_close(&mut self) -> Result<(), WispError> { - self.as_mut().wisp_close().await + impl ErasedWebSocketWrite for T { + fn wisp_write_frame<'a>( + &'a mut self, + frame: Frame<'a>, + ) -> Pin> + Send + 'a>> { + Box::pin(self.wisp_write_frame(frame)) + } + + fn wisp_write_split<'a>( + &'a mut self, + header: Frame<'a>, + body: Frame<'a>, + ) -> Pin> + Send + 'a>> { + Box::pin(self.wisp_write_split(header, body)) + } + + fn wisp_close<'a>( + &'a mut self, + ) -> Pin> + Send + 'a>> { + Box::pin(self.wisp_close()) + } } - async fn wisp_write_split( - &mut self, - header: Frame<'_>, - body: Frame<'_>, - ) -> Result<(), WispError> { - self.as_mut().wisp_write_split(header, body).await + /// WebSocketWrite trait object. + #[repr(transparent)] + pub struct DynWebSocketWrite { + ptr: dyn ErasedWebSocketWrite + 'static, } + impl WebSocketWrite for DynWebSocketWrite { + async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { + self.ptr.wisp_write_frame(frame).await + } + + async fn wisp_write_split( + &mut self, + header: Frame<'_>, + body: Frame<'_>, + ) -> Result<(), WispError> { + self.ptr.wisp_write_split(header, body).await + } + + async fn wisp_close(&mut self) -> Result<(), WispError> { + self.ptr.wisp_close().await + } + } + impl DynWebSocketWrite { + /// Create a new WebSocketWrite trait object from a boxed WebSocketWrite. + pub fn new(val: Box) -> Box { + let val: Box = val; + unsafe { std::mem::transmute(val) } + } + /// Create a new WebSocketWrite trait object from a WebSocketWrite. + pub fn boxed(val: impl WebSocketWrite + 'static) -> Box { + Self::new(Box::new(val)) + } + /// Create a new WebSocketWrite trait object from a WebSocketWrite reference. + pub fn from_ref(val: &(impl WebSocketWrite + 'static)) -> &Self { + let val: &(dyn ErasedWebSocketWrite + 'static) = val; + unsafe { std::mem::transmute(val) } + } + /// Create a new WebSocketWrite trait object from a mutable WebSocketWrite reference. + pub fn from_mut(val: &mut (impl WebSocketWrite + 'static)) -> &mut Self { + let val: &mut (dyn ErasedWebSocketWrite + 'static) = &mut *val; + unsafe { std::mem::transmute(val) } + } + } +} +pub use wsw_inner::DynWebSocketWrite; + +mod private { + pub trait Sealed {} +} + +/// Helper trait object for LockedWebSocketWrite. +pub trait LockingWebSocketWrite: private::Sealed + Sync { + /// Write a frame to the websocket. + fn wisp_write_frame<'a>( + &'a self, + frame: Frame<'a>, + ) -> Pin> + Send + 'a>>; + + /// Write a split frame to the websocket. + fn wisp_write_split<'a>( + &'a self, + header: Frame<'a>, + body: Frame<'a>, + ) -> Pin> + Send + 'a>>; + + /// Close the websocket. + fn wisp_close<'a>(&'a self) + -> Pin> + Send + 'a>>; } /// Locked WebSocket. -#[derive(Clone)] -pub struct LockedWebSocketWrite(Arc>>); +pub struct LockedWebSocketWrite(Arc>); -impl LockedWebSocketWrite { +impl Clone for LockedWebSocketWrite { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl LockedWebSocketWrite { /// Create a new locked websocket. - pub fn new(ws: Box) -> Self { + pub fn new(ws: T) -> Self { Self(Mutex::new(ws).into()) } + /// Create a new locked websocket from an existing mutex. + pub fn from_locked(locked: Arc>) -> Self { + Self(locked) + } + /// Write a frame to the websocket. pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WispError> { self.0.lock().await.wisp_write_frame(frame).await } - pub(crate) async fn write_split( - &self, - header: Frame<'_>, - body: Frame<'_>, - ) -> Result<(), WispError> { + /// Write a split frame to the websocket. + pub async fn write_split(&self, header: Frame<'_>, body: Frame<'_>) -> Result<(), WispError> { self.0.lock().await.wisp_write_split(header, body).await } @@ -261,3 +445,91 @@ impl LockedWebSocketWrite { self.0.lock().await.wisp_close().await } } + +impl private::Sealed for LockedWebSocketWrite {} + +impl LockingWebSocketWrite for LockedWebSocketWrite { + fn wisp_write_frame<'a>( + &'a self, + frame: Frame<'a>, + ) -> Pin> + Send + 'a>> { + Box::pin(self.write_frame(frame)) + } + + fn wisp_write_split<'a>( + &'a self, + header: Frame<'a>, + body: Frame<'a>, + ) -> Pin> + Send + 'a>> { + Box::pin(self.write_split(header, body)) + } + + fn wisp_close<'a>( + &'a self, + ) -> Pin> + Send + 'a>> { + Box::pin(self.close()) + } +} + +/// Combines two different WebSocketReads together. +pub enum EitherWebSocketRead { + /// First WebSocketRead variant. + Left(A), + /// Second WebSocketRead variant. + Right(B), +} +impl WebSocketRead for EitherWebSocketRead { + async fn wisp_read_frame( + &mut self, + tx: &dyn LockingWebSocketWrite, + ) -> Result, WispError> { + match self { + Self::Left(x) => x.wisp_read_frame(tx).await, + Self::Right(x) => x.wisp_read_frame(tx).await, + } + } + + async fn wisp_read_split( + &mut self, + tx: &dyn LockingWebSocketWrite, + ) -> Result<(Frame<'static>, Option>), WispError> { + match self { + Self::Left(x) => x.wisp_read_split(tx).await, + Self::Right(x) => x.wisp_read_split(tx).await, + } + } +} + +/// Combines two different WebSocketWrites together. +pub enum EitherWebSocketWrite { + /// First WebSocketWrite variant. + Left(A), + /// Second WebSocketWrite variant. + Right(B), +} +impl WebSocketWrite for EitherWebSocketWrite { + async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { + match self { + Self::Left(x) => x.wisp_write_frame(frame).await, + Self::Right(x) => x.wisp_write_frame(frame).await, + } + } + + async fn wisp_write_split( + &mut self, + header: Frame<'_>, + body: Frame<'_>, + ) -> Result<(), WispError> { + match self { + Self::Left(x) => x.wisp_write_split(header, body).await, + Self::Right(x) => x.wisp_write_split(header, body).await, + } + } + + async fn wisp_close(&mut self) -> Result<(), WispError> { + match self { + Self::Left(x) => x.wisp_close().await, + Self::Right(x) => x.wisp_close().await, + } + } +}