From 16268905fc3179c90b44d19b377b0e518f49b402 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Fri, 16 Aug 2024 23:29:33 -0700 Subject: [PATCH] custom wisp transport support --- Cargo.lock | 13 ++- client/.npmignore | 4 +- client/Cargo.toml | 2 +- client/demo.js | 1 - client/src/io_stream.rs | 2 +- client/src/lib.rs | 108 +++++++++++++++++++++--- client/src/stream_provider.rs | 152 +++++++++------------------------- client/src/utils.rs | 69 ++++++++++++++- wisp/src/stream.rs | 61 +++++++++++++- wisp/src/ws.rs | 36 ++++++++ 10 files changed, 313 insertions(+), 135 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 08292a3..87f2e73 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -535,7 +535,7 @@ dependencies = [ "pin-project-lite", "ring", "rustls-pki-types", - "send_wrapper", + "send_wrapper 0.6.0", "thiserror", "tokio", "wasm-bindgen", @@ -726,7 +726,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" dependencies = [ "gloo-timers", - "send_wrapper", + "send_wrapper 0.4.0", ] [[package]] @@ -1488,6 +1488,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f638d531eccd6e23b980caf34876660d38e265409d8e99b397ab71eb3612fad0" +[[package]] +name = "send_wrapper" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd0b0ec5f1c1ca621c432a25813d8d60c88abe6d3e08a3eb9cf37d97a0fe3d73" +dependencies = [ + "futures-core", +] + [[package]] name = "serde" version = "1.0.204" diff --git a/client/.npmignore b/client/.npmignore index 91b4d78..cefec79 100644 --- a/client/.npmignore +++ b/client/.npmignore @@ -2,6 +2,6 @@ build.sh Cargo.toml serve.py src -tests -test.sh .cargo +index.html +demo.js diff --git a/client/Cargo.toml b/client/Cargo.toml index 25f929f..8554829 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -23,7 +23,7 @@ hyper-util-wasm = { git = "https://github.com/r58Playz/hyper-util-wasm", branch js-sys = "0.3.69" lazy_static = "1.5.0" pin-project-lite = "0.2.14" -send_wrapper = "0.4.0" +send_wrapper = { version = "0.6.0", features = ["futures"] } thiserror = "1.0.61" tokio = "1.38.0" wasm-bindgen = "0.2.92" diff --git a/client/demo.js b/client/demo.js index 1ff9cd1..9c0e467 100644 --- a/client/demo.js +++ b/client/demo.js @@ -254,7 +254,6 @@ import initEpoxy, { EpoxyClient, EpoxyClientOptions, EpoxyHandlers, info as epox } total_mux_multi = total_mux_multi / num_outer_tests; log(`total avg mux (${num_outer_tests} tests of ${num_inner_tests} reqs): ${total_mux_multi} ms or ${total_mux_multi / 1000} s`); - } else { console.time(); let resp = await epoxy_client.fetch("https://www.example.com/"); diff --git a/client/src/io_stream.rs b/client/src/io_stream.rs index d0df66e..86d7be6 100644 --- a/client/src/io_stream.rs +++ b/client/src/io_stream.rs @@ -111,7 +111,7 @@ pub struct EpoxyUdpStream { #[wasm_bindgen] impl EpoxyUdpStream { pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self { - let (mut rx, tx) = stream.into_inner().into_split(); + let (mut rx, tx) = stream.into_split(); let EpoxyHandlers { onopen, diff --git a/client/src/lib.rs b/client/src/lib.rs index b9e7d64..8b9b967 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -18,21 +18,28 @@ use hyper::{body::Incoming, Uri}; use hyper_util_wasm::client::legacy::Client; #[cfg(feature = "full")] use io_stream::{EpoxyIoStream, EpoxyUdpStream}; -use js_sys::{Array, Function, Object}; +use js_sys::{Array, Function, Object, Promise}; +use send_wrapper::SendWrapper; use stream_provider::{StreamProvider, StreamProviderService}; use thiserror::Error; use utils::{ asyncread_to_readablestream_stream, convert_body, entries_of_object, is_null_body, is_redirect, - object_get, object_set, object_truthy, IncomingBody, UriExt, WasmExecutor, + object_get, object_set, object_truthy, IncomingBody, UriExt, WasmExecutor, WispTransportRead, + WispTransportWrite, }; use wasm_bindgen::prelude::*; +use wasm_bindgen_futures::JsFuture; use wasm_streams::ReadableStream; use web_sys::ResponseInit; #[cfg(feature = "full")] use websocket::EpoxyWebSocket; -use wisp_mux::CloseReason; #[cfg(feature = "full")] use wisp_mux::StreamType; +use wisp_mux::{ + ws::{WebSocketRead, WebSocketWrite}, + CloseReason, +}; +use ws_wrapper::WebSocketWrapper; #[cfg(feature = "full")] mod io_stream; @@ -67,6 +74,15 @@ pub enum EpoxyError { #[error("Fastwebsockets: {0:?} ({0})")] FastWebSockets(#[from] fastwebsockets::WebSocketError), + #[error("Custom wisp transport: {0}")] + WispTransport(String), + #[error("Invalid Wisp transport")] + InvalidWispTransport, + #[error("Invalid Wisp transport packet")] + InvalidWispTransportPacket, + #[error("Wisp transport already closed")] + WispTransportClosed, + #[error("Invalid URL scheme")] InvalidUrlScheme, #[error("No URL host found")] @@ -99,6 +115,12 @@ pub enum EpoxyError { ResponseNewFailed, } +impl EpoxyError { + pub fn wisp_transport(value: JsValue) -> Self { + Self::WispTransport(format!("{:?}", value)) + } +} + impl From for JsValue { fn from(value: EpoxyError) -> Self { JsError::from(value).into() @@ -137,7 +159,7 @@ impl From for EpoxyError { impl From for EpoxyError { fn from(value: CloseReason) -> Self { - EpoxyError::WispCloseReason(value) + EpoxyError::WispCloseReason(value) } } @@ -224,13 +246,79 @@ pub struct EpoxyClient { #[wasm_bindgen] impl EpoxyClient { #[wasm_bindgen(constructor)] - pub fn new(wisp_url: String, options: EpoxyClientOptions) -> Result { - let wisp_url: Uri = wisp_url.try_into()?; - if wisp_url.scheme_str() != Some("wss") && wisp_url.scheme_str() != Some("ws") { - return Err(EpoxyError::InvalidUrlScheme); - } + pub fn new(wisp_url: JsValue, options: EpoxyClientOptions) -> Result { + let stream_provider = if let Some(wisp_url) = wisp_url.as_string() { + let wisp_uri: Uri = wisp_url.clone().try_into()?; + if wisp_uri.scheme_str() != Some("wss") && wisp_uri.scheme_str() != Some("ws") { + return Err(EpoxyError::InvalidUrlScheme); + } - let stream_provider = Arc::new(StreamProvider::new(wisp_url.to_string(), &options)?); + let ws_protocols = options.websocket_protocols.clone(); + Arc::new(StreamProvider::new( + Box::new(move || { + let wisp_url = wisp_url.clone(); + let ws_protocols = ws_protocols.clone(); + + Box::pin(async move { + let (write, read) = WebSocketWrapper::connect(&wisp_url, &ws_protocols)?; + if !write.wait_for_open().await { + return Err(EpoxyError::WebSocketConnectFailed); + } + Ok(( + Box::new(read) as Box, + Box::new(write) as Box, + )) + }) + }), + &options, + )?) + } else if let Ok(wisp_transport) = wisp_url.dyn_into::() { + let wisp_transport = SendWrapper::new(wisp_transport); + Arc::new(StreamProvider::new( + Box::new(move || { + let wisp_transport = wisp_transport.clone(); + Box::pin(SendWrapper::new(async move { + let transport = wisp_transport + .call0(&JsValue::NULL) + .map_err(EpoxyError::wisp_transport)?; + + let transport = match transport.dyn_into::() { + Ok(transport) => { + let fut = JsFuture::from(transport); + fut.await.map_err(EpoxyError::wisp_transport)? + } + Err(transport) => transport, + } + .into(); + + let read = WispTransportRead { + inner: SendWrapper::new( + wasm_streams::ReadableStream::from_raw( + object_get(&transport, "read").into(), + ) + .into_stream(), + ), + }; + let write = WispTransportWrite { + inner: Some(SendWrapper::new( + wasm_streams::WritableStream::from_raw( + object_get(&transport, "write").into(), + ) + .into_sink(), + )), + }; + + Ok(( + Box::new(read) as Box, + Box::new(write) as Box, + )) + })) + }), + &options, + )?) + } else { + return Err(EpoxyError::InvalidWispTransport); + }; let service = StreamProviderService(stream_provider.clone()); let client = Client::builder(WasmExecutor) diff --git a/client/src/stream_provider.rs b/client/src/stream_provider.rs index 6f232e6..54ab34a 100644 --- a/client/src/stream_provider.rs +++ b/client/src/stream_provider.rs @@ -1,10 +1,4 @@ -use std::{ - io::ErrorKind, - ops::{Deref, DerefMut}, - pin::Pin, - sync::Arc, - task::Poll, -}; +use std::{io::ErrorKind, pin::Pin, sync::Arc, task::Poll}; use futures_rustls::{ rustls::{ClientConfig, RootCertStore}, @@ -22,10 +16,11 @@ use wasm_bindgen_futures::spawn_local; use webpki_roots::TLS_SERVER_ROOTS; use wisp_mux::{ extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder}, - ClientMux, MuxStreamAsyncRW, MuxStreamCloser, MuxStreamIo, StreamType, + ws::{WebSocketRead, WebSocketWrite}, + ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, }; -use crate::{console_log, ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError}; +use crate::{console_log, EpoxyClientOptions, EpoxyError}; lazy_static! { static ref CLIENT_CONFIG: Arc = { @@ -38,117 +33,45 @@ lazy_static! { }; } -pin_project! { - pub struct CloserWrapper { - #[pin] - pub inner: T, - pub closer: MuxStreamCloser, - } -} - -impl CloserWrapper { - pub fn new(inner: T, closer: MuxStreamCloser) -> Self { - Self { inner, closer } - } - - pub fn into_inner(self) -> T { - self.inner - } -} - -impl Deref for CloserWrapper { - type Target = T; - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl DerefMut for CloserWrapper { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } -} - -impl AsyncRead for CloserWrapper { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut [u8], - ) -> Poll> { - self.project().inner.poll_read(cx, buf) - } - - fn poll_read_vectored( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - bufs: &mut [std::io::IoSliceMut<'_>], - ) -> Poll> { - self.project().inner.poll_read_vectored(cx, bufs) - } -} - -impl AsyncWrite for CloserWrapper { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { - self.project().inner.poll_write(cx, buf) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - self.project().inner.poll_write_vectored(cx, bufs) - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.project().inner.poll_flush(cx) - } - - fn poll_close( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.project().inner.poll_close(cx) - } -} - -impl From> for CloserWrapper { - fn from(value: CloserWrapper) -> Self { - let CloserWrapper { inner, closer } = value; - CloserWrapper::new(inner.into_asyncrw(), closer) - } -} +pub type ProviderUnencryptedStream = MuxStreamIo; +pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW; +pub type ProviderTlsAsyncRW = TlsStream; +pub type ProviderAsyncRW = Either; +pub type ProviderWispTransportGenerator = Box< + dyn Fn() -> Pin< + Box< + dyn Future< + Output = Result< + ( + Box, + Box, + ), + EpoxyError, + >, + > + Sync + Send, + >, + > + Sync + Send, +>; pub struct StreamProvider { - wisp_url: String, + wisp_generator: ProviderWispTransportGenerator, wisp_v2: bool, udp_extension: bool, - websocket_protocols: Vec, current_client: Arc>>, } -pub type ProviderUnencryptedStream = CloserWrapper; -pub type ProviderUnencryptedAsyncRW = CloserWrapper; -pub type ProviderTlsAsyncRW = TlsStream; -pub type ProviderAsyncRW = Either; - impl StreamProvider { - pub fn new(wisp_url: String, options: &EpoxyClientOptions) -> Result { + pub fn new( + wisp_generator: ProviderWispTransportGenerator, + options: &EpoxyClientOptions, + ) -> Result { Ok(Self { - wisp_url, + wisp_generator, current_client: Arc::new(Mutex::new(None)), wisp_v2: options.wisp_v2, udp_extension: options.udp_extension_required, - websocket_protocols: options.websocket_protocols.clone(), }) } @@ -163,10 +86,9 @@ impl StreamProvider { } else { None }; - let (write, read) = WebSocketWrapper::connect(&self.wisp_url, &self.websocket_protocols)?; - if !write.wait_for_open().await { - return Err(EpoxyError::WebSocketConnectFailed); - } + + let (read, write) = (self.wisp_generator)().await?; + let client = ClientMux::create(read, write, extensions).await?; let (mux, fut) = if self.udp_extension { client.with_udp_extension_required().await? @@ -196,8 +118,7 @@ impl StreamProvider { let locked = self.current_client.lock().await; if let Some(mux) = locked.as_ref() { let stream = mux.client_new_stream(stream_type, host, port).await?; - let closer = stream.get_close_handle(); - Ok(CloserWrapper::new(stream.into_io(), closer)) + Ok(stream.into_io()) } else { self.create_client(locked).await?; self.get_stream(stream_type, host, port).await @@ -212,7 +133,10 @@ impl StreamProvider { host: String, port: u16, ) -> Result { - Ok(self.get_stream(stream_type, host, port).await?.into()) + Ok(self + .get_stream(stream_type, host, port) + .await? + .into_asyncrw()) } pub async fn get_tls_stream( @@ -233,7 +157,7 @@ impl StreamProvider { Err((err, stream)) => { if matches!(err.kind(), ErrorKind::UnexpectedEof) { // maybe actually a wisp error? - if let Some(reason) = stream.closer.get_close_reason() { + if let Some(reason) = stream.get_close_reason() { return Err(reason.into()); } } diff --git a/client/src/utils.rs b/client/src/utils.rs index 08c3d17..d7d8935 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -3,13 +3,20 @@ use std::{ task::{Context, Poll}, }; +use async_trait::async_trait; use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut}; -use futures_util::{ready, AsyncRead, Future, Stream, TryStreamExt}; +use futures_util::{ready, AsyncRead, Future, SinkExt, Stream, StreamExt, TryStreamExt}; use http::{HeaderValue, Uri}; use hyper::{body::Body, rt::Executor}; -use js_sys::{Array, JsString, Object, Uint8Array}; +use js_sys::{Array, ArrayBuffer, JsString, Object, Uint8Array}; use pin_project_lite::pin_project; +use send_wrapper::SendWrapper; use wasm_bindgen::{prelude::*, JsCast, JsValue}; +use wasm_streams::{readable::IntoStream, writable::IntoSink}; +use wisp_mux::{ + ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, + WispError, +}; use crate::EpoxyError; @@ -168,6 +175,64 @@ impl Stream for ReaderStream { } } +pub struct WispTransportRead { + pub inner: SendWrapper>, +} + +#[async_trait] +impl WebSocketRead for WispTransportRead { + async fn wisp_read_frame( + &mut self, + _tx: &LockedWebSocketWrite, + ) -> Result, wisp_mux::WispError> { + let obj = self.inner.next().await; + + if let Some(pkt) = obj { + let pkt = + pkt.map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x))))?; + let arr: ArrayBuffer = pkt.dyn_into().map_err(|_| { + WispError::WsImplError(Box::new(EpoxyError::InvalidWispTransportPacket)) + })?; + + Ok(Frame::binary(Payload::Bytes( + Uint8Array::new(&arr).to_vec().as_slice().into(), + ))) + } else { + Ok(Frame::close(Payload::Borrowed(&[]))) + } + } +} + +pub struct WispTransportWrite { + pub inner: Option>>, +} + +#[async_trait] +impl WebSocketWrite for WispTransportWrite { + async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { + SendWrapper::new( + self.inner + .as_mut() + .ok_or_else(|| WispError::WsImplError(Box::new(EpoxyError::WispTransportClosed)))? + .send(Uint8Array::from(frame.payload.as_ref()).into()), + ) + .await + .map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x)))) + } + + async fn wisp_close(&mut self) -> Result<(), WispError> { + SendWrapper::new( + self.inner + .take() + .ok_or_else(|| WispError::WsImplError(Box::new(EpoxyError::WispTransportClosed)))? + .take() + .abort(), + ) + .await + .map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x)))) + } +} + pub fn is_redirect(code: u16) -> bool { [301, 302, 303, 307, 308].contains(&code) } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index eb7c045..bd0982d 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -93,6 +93,8 @@ impl MuxStreamRead { /// Turn the read half into one that implements futures `Stream`, consuming it. pub fn into_stream(self) -> MuxStreamIoStream { MuxStreamIoStream { + close_reason: self.close_reason.clone(), + is_closed: self.is_closed.clone(), rx: self.into_inner_stream(), } } @@ -246,6 +248,8 @@ impl MuxStreamWrite { /// Turn the write half into one that implements futures `Sink`, consuming it. pub fn into_sink(self) -> MuxStreamIoSink { MuxStreamIoSink { + close_reason: self.close_reason.clone(), + is_closed: self.is_closed.clone(), tx: self.into_inner_sink(), } } @@ -352,6 +356,11 @@ impl MuxStream { self.tx.get_protocol_extension_stream() } + /// Get the stream's close reason, if it was closed. + pub fn get_close_reason(&self) -> Option { + self.rx.get_close_reason() + } + /// Close the stream. You will no longer be able to write or read after this has been called. pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { self.tx.close(reason).await @@ -455,6 +464,11 @@ impl MuxStreamIo { } } + /// Get the stream's close reason, if it was closed. + pub fn get_close_reason(&self) -> Option { + self.rx.get_close_reason() + } + /// Split the stream into read and write parts, consuming it. pub fn into_split(self) -> (MuxStreamIoStream, MuxStreamIoSink) { (self.rx, self.tx) @@ -489,6 +503,8 @@ pin_project! { pub struct MuxStreamIoStream { #[pin] rx: Pin + Send>>, + is_closed: Arc, + close_reason: Arc, } } @@ -497,6 +513,15 @@ impl MuxStreamIoStream { pub fn into_asyncread(self) -> MuxStreamAsyncRead { MuxStreamAsyncRead::new(self) } + + /// Get the stream's close reason, if it was closed. + pub fn get_close_reason(&self) -> Option { + if self.is_closed.load(Ordering::Acquire) { + Some(self.close_reason.load(Ordering::Acquire)) + } else { + None + } + } } impl Stream for MuxStreamIoStream { @@ -511,6 +536,8 @@ pin_project! { pub struct MuxStreamIoSink { #[pin] tx: Pin, Error = WispError> + Send>>, + is_closed: Arc, + close_reason: Arc, } } @@ -519,6 +546,15 @@ impl MuxStreamIoSink { pub fn into_asyncwrite(self) -> MuxStreamAsyncWrite { MuxStreamAsyncWrite::new(self) } + + /// Get the stream's close reason, if it was closed. + pub fn get_close_reason(&self) -> Option { + if self.is_closed.load(Ordering::Acquire) { + Some(self.close_reason.load(Ordering::Acquire)) + } else { + None + } + } } impl Sink<&[u8]> for MuxStreamIoSink { @@ -560,6 +596,11 @@ pin_project! { } impl MuxStreamAsyncRW { + /// Get the stream's close reason, if it was closed. + pub fn get_close_reason(&self) -> Option { + self.rx.get_close_reason() + } + /// Split the stream into read and write parts, consuming it. pub fn into_split(self) -> (MuxStreamAsyncRead, MuxStreamAsyncWrite) { (self.rx, self.tx) @@ -617,15 +658,26 @@ pin_project! { pub struct MuxStreamAsyncRead { #[pin] rx: IntoAsyncRead, - // state: Option + is_closed: Arc, + close_reason: Arc, } } impl MuxStreamAsyncRead { pub(crate) fn new(stream: MuxStreamIoStream) -> Self { Self { + is_closed: stream.is_closed.clone(), + close_reason: stream.close_reason.clone(), rx: stream.into_async_read(), - // state: None, + } + } + + /// Get the stream's close reason, if it was closed. + pub fn get_close_reason(&self) -> Option { + if self.is_closed.load(Ordering::Acquire) { + Some(self.close_reason.load(Ordering::Acquire)) + } else { + None } } } @@ -664,6 +716,11 @@ impl MuxStreamAsyncWrite { error: None, } } + + /// Get the stream's close reason, if it was closed. + pub fn get_close_reason(&self) -> Option { + self.tx.get_close_reason() + } } impl AsyncWrite for MuxStreamAsyncWrite { diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index 63f7cb7..fbf63af 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -166,6 +166,23 @@ pub trait WebSocketRead { } } +#[async_trait] +impl WebSocketRead for Box { + async fn wisp_read_frame( + &mut self, + tx: &LockedWebSocketWrite, + ) -> Result, WispError> { + self.as_mut().wisp_read_frame(tx).await + } + + async fn wisp_read_split( + &mut self, + tx: &LockedWebSocketWrite, + ) -> Result<(Frame<'static>, Option>), WispError> { + self.as_mut().wisp_read_split(tx).await + } +} + /// Generic WebSocket write trait. #[async_trait] pub trait WebSocketWrite { @@ -188,6 +205,25 @@ pub trait WebSocketWrite { } } +#[async_trait] +impl WebSocketWrite for Box { + async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { + self.as_mut().wisp_write_frame(frame).await + } + + async fn wisp_close(&mut self) -> Result<(), WispError> { + self.as_mut().wisp_close().await + } + + async fn wisp_write_split( + &mut self, + header: Frame<'_>, + body: Frame<'_>, + ) -> Result<(), WispError> { + self.as_mut().wisp_write_split(header, body).await + } +} + /// Locked WebSocket. #[derive(Clone)] pub struct LockedWebSocketWrite(Arc>>);