From 142067961dc2c982ecd8bdedc04050fd8d0208b0 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 12 Oct 2024 13:27:06 -0700 Subject: [PATCH] Revert "remove websocket support" This reverts commit df33f8134090af0eaa7f1255c2a83dcd541ce5f4. --- Cargo.lock | 2 + client/Cargo.toml | 4 +- client/src/lib.rs | 117 +++++++++++--------- client/src/utils.rs | 63 +---------- client/src/ws_wrapper.rs | 231 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 302 insertions(+), 115 deletions(-) create mode 100644 client/src/ws_wrapper.rs diff --git a/Cargo.lock b/Cargo.lock index b521068..3d4cb44 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -689,7 +689,9 @@ dependencies = [ "async-trait", "bytes", "cfg-if", + "event-listener", "fastwebsockets", + "flume", "futures-rustls", "futures-util", "getrandom", diff --git a/client/Cargo.toml b/client/Cargo.toml index 5ef89ac..b0a6486 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -11,7 +11,9 @@ async-compression = { version = "0.4.12", features = ["futures-io", "gzip", "bro async-trait = "0.1.81" bytes = "1.7.1" cfg-if = "1.0.0" +event-listener = "5.3.1" fastwebsockets = { version = "0.8.0", features = ["unstable-split"], optional = true } +flume = "0.11.0" futures-rustls = { version = "0.26.0", default-features = false, features = ["tls12", "ring"] } futures-util = { version = "0.3.30", features = ["sink"] } http = "1.1.0" @@ -28,7 +30,7 @@ tokio = "1.39.3" wasm-bindgen = "0.2.93" wasm-bindgen-futures = "0.4.43" wasm-streams = "0.4.0" -web-sys = { version = "0.3.70", features = ["Headers", "Request", "RequestInit", "Response", "ResponseInit"] } +web-sys = { version = "0.3.70", features = ["BinaryType", "Headers", "MessageEvent", "Request", "RequestInit", "Response", "ResponseInit", "WebSocket"] } webpki-roots = "0.26.3" wisp-mux = { path = "../wisp", features = ["wasm"], version = "5.1.0", default-features = false } diff --git a/client/src/lib.rs b/client/src/lib.rs index 20349c8..471bcca 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -26,7 +26,7 @@ use send_wrapper::SendWrapper; use stream_provider::{StreamProvider, StreamProviderService}; use thiserror::Error; use utils::{ - asyncread_to_readablestream, bind_ws_connect, convert_body, entries_of_object, + asyncread_to_readablestream, convert_body, entries_of_object, from_entries, is_null_body, is_redirect, object_get, object_set, object_truthy, IncomingBody, UriExt, WasmExecutor, WispTransportRead, WispTransportWrite, }; @@ -41,6 +41,7 @@ use wisp_mux::{ ws::{WebSocketRead, WebSocketWrite}, CloseReason, }; +use ws_wrapper::WebSocketWrapper; #[cfg(feature = "full")] mod io_stream; @@ -49,6 +50,7 @@ mod tokioio; mod utils; #[cfg(feature = "full")] mod websocket; +mod ws_wrapper; type HttpBody = http_body_util::Full; @@ -77,7 +79,10 @@ pub enum EpoxyError { #[error("Webpki: {0:?} ({0})")] Webpki(#[from] webpki::Error), - #[error("Wisp transport: {0}")] + #[error("Wisp WebSocket failed to connect")] + WebSocketConnectFailed, + + #[error("Custom Wisp transport: {0}")] WispTransport(String), #[error("Invalid Wisp transport")] InvalidWispTransport, @@ -271,53 +276,6 @@ impl EpoxyHandlers { } } -fn get_stream_provider( - func: Function, - options: &EpoxyClientOptions, -) -> Result { - let wisp_transport = SendWrapper::new(func); - StreamProvider::new( - Box::new(move || { - let wisp_transport = wisp_transport.clone(); - Box::pin(SendWrapper::new(async move { - let transport = wisp_transport - .call0(&JsValue::NULL) - .map_err(EpoxyError::wisp_transport)?; - - let transport = match transport.dyn_into::() { - Ok(transport) => { - let fut = JsFuture::from(transport); - fut.await.map_err(EpoxyError::wisp_transport)? - } - Err(transport) => transport, - } - .into(); - - let read = WispTransportRead { - inner: SendWrapper::new( - wasm_streams::ReadableStream::from_raw( - object_get(&transport, "read").into(), - ) - .into_stream(), - ), - }; - let write: WritableStream = object_get(&transport, "write").into(); - let write = WispTransportWrite { - inner: SendWrapper::new( - write.get_writer().map_err(EpoxyError::wisp_transport)?, - ), - }; - - Ok(( - Box::new(read) as Box, - Box::new(write) as Box, - )) - })) - }), - options, - ) -} - #[wasm_bindgen(inspectable)] pub struct EpoxyClient { stream_provider: Arc, @@ -340,13 +298,68 @@ impl EpoxyClient { if wisp_uri.scheme_str() != Some("wss") && wisp_uri.scheme_str() != Some("ws") { return Err(EpoxyError::InvalidUrlScheme); } + let ws_protocols = options.websocket_protocols.clone(); - Arc::new(get_stream_provider( - bind_ws_connect(wisp_url, ws_protocols), + Arc::new(StreamProvider::new( + Box::new(move || { + let wisp_url = wisp_url.clone(); + let ws_protocols = ws_protocols.clone(); + + Box::pin(async move { + let (write, read) = WebSocketWrapper::connect(&wisp_url, &ws_protocols)?; + if !write.wait_for_open().await { + return Err(EpoxyError::WebSocketConnectFailed); + } + Ok(( + Box::new(read) as Box, + Box::new(write) as Box, + )) + }) + }), &options, )?) } else if let Ok(wisp_transport) = wisp_url.dyn_into::() { - Arc::new(get_stream_provider(wisp_transport, &options)?) + let wisp_transport = SendWrapper::new(wisp_transport); + Arc::new(StreamProvider::new( + Box::new(move || { + let wisp_transport = wisp_transport.clone(); + Box::pin(SendWrapper::new(async move { + let transport = wisp_transport + .call0(&JsValue::NULL) + .map_err(EpoxyError::wisp_transport)?; + + let transport = match transport.dyn_into::() { + Ok(transport) => { + let fut = JsFuture::from(transport); + fut.await.map_err(EpoxyError::wisp_transport)? + } + Err(transport) => transport, + } + .into(); + + let read = WispTransportRead { + inner: SendWrapper::new( + wasm_streams::ReadableStream::from_raw( + object_get(&transport, "read").into(), + ) + .into_stream(), + ), + }; + let write: WritableStream = object_get(&transport, "write").into(); + let write = WispTransportWrite { + inner: SendWrapper::new( + write.get_writer().map_err(EpoxyError::wisp_transport)?, + ), + }; + + Ok(( + Box::new(read) as Box, + Box::new(write) as Box, + )) + })) + }), + &options, + )?) } else { return Err(EpoxyError::InvalidWispTransport); }; diff --git a/client/src/utils.rs b/client/src/utils.rs index 4a5ad86..0c47b84 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -18,7 +18,7 @@ use futures_rustls::{ use futures_util::{ready, AsyncRead, AsyncWrite, Future, Stream, StreamExt, TryStreamExt}; use http::{HeaderValue, Uri}; use hyper::{body::Body, rt::Executor}; -use js_sys::{Array, ArrayBuffer, Function, JsString, Object, Uint8Array}; +use js_sys::{Array, ArrayBuffer, JsString, Object, Uint8Array}; use pin_project_lite::pin_project; use rustls_pki_types::{CertificateDer, ServerName, UnixTime}; use send_wrapper::SendWrapper; @@ -372,55 +372,6 @@ pub fn is_null_body(code: u16) -> bool { } #[wasm_bindgen(inline_js = r#" -class WebSocketStreamPonyfill { - url; - opened; - closed; - close; - constructor(url, options = {}) { - if (options.signal?.aborted) { - throw new DOMException('This operation was aborted', 'AbortError'); - } - this.url = url; - const ws = new WebSocket(url, options.protocols ?? []); - ws.binaryType = "arraybuffer"; - const closeWithInfo = ({ closeCode: code, reason } = {}) => ws.close(code, reason); - this.opened = new Promise((resolve, reject) => { - ws.onopen = () => { - resolve({ - readable: new ReadableStream({ - start(controller) { - ws.onmessage = ({ data }) => controller.enqueue(data); - ws.onerror = e => controller.error(e); - }, - cancel: closeWithInfo, - }), - writable: new WritableStream({ - write(chunk) { ws.send(chunk); }, - abort() { ws.close(); }, - close: closeWithInfo, - }), - protocol: ws.protocol, - extensions: ws.extensions, - }); - ws.removeEventListener('error', reject); - }; - ws.addEventListener('error', reject); - }); - this.closed = new Promise((resolve, reject) => { - ws.onclose = ({ code, reason }) => { - resolve({ closeCode: code, reason }); - ws.removeEventListener('error', reject); - }; - ws.addEventListener('error', reject); - }); - if (options.signal) { - options.signal.onabort = () => ws.close(); - } - this.close = closeWithInfo; - } -} - export function object_get(obj, k) { try { return obj[k] @@ -457,16 +408,6 @@ export function from_entries(entries){ for(var i = 0; i < entries.length; i++) ret[entries[i][0]] = entries[i][1]; return ret; } - -async function websocket_connect(url, protocols) { - let wss = new (typeof WebSocketStream !== "undefined" ? WebSocketStream : WebSocketStreamPonyfill)(url, { protocols: protocols }); - let {readable, writable} = await wss.opened; - return {read: readable, write: writable}; -} - -export function bind_ws_connect(url, protocols) { - return websocket_connect.bind(undefined, url, protocols); -} "#)] extern "C" { pub fn object_get(obj: &Object, key: &str) -> JsValue; @@ -481,8 +422,6 @@ extern "C" { #[wasm_bindgen(catch)] pub fn from_entries(iterable: &JsValue) -> Result; - - pub fn bind_ws_connect(url: String, protocols: Vec) -> Function; } pub async fn convert_body(val: JsValue) -> Result<(Uint8Array, Option), JsValue> { diff --git a/client/src/ws_wrapper.rs b/client/src/ws_wrapper.rs new file mode 100644 index 0000000..52aeb0f --- /dev/null +++ b/client/src/ws_wrapper.rs @@ -0,0 +1,231 @@ +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; + +use async_trait::async_trait; +use bytes::BytesMut; +use event_listener::Event; +use flume::Receiver; +use futures_util::FutureExt; +use js_sys::{Array, ArrayBuffer, Uint8Array}; +use send_wrapper::SendWrapper; +use wasm_bindgen::{closure::Closure, JsCast}; +use web_sys::{BinaryType, MessageEvent, WebSocket}; +use wisp_mux::{ + ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, + WispError, +}; + +use crate::EpoxyError; + +#[derive(Debug)] +pub enum WebSocketError { + Unknown, + SendFailed, + CloseFailed, +} + +impl std::fmt::Display for WebSocketError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + use WebSocketError::*; + match self { + Unknown => write!(f, "Unknown error"), + SendFailed => write!(f, "Send failed"), + CloseFailed => write!(f, "Close failed"), + } + } +} + +impl std::error::Error for WebSocketError {} + +impl From for WispError { + fn from(err: WebSocketError) -> Self { + Self::WsImplError(Box::new(err)) + } +} + +pub enum WebSocketMessage { + Closed, + Error, + Message(Vec), +} + +pub struct WebSocketWrapper { + inner: SendWrapper, + open_event: Arc, + error_event: Arc, + close_event: Arc, + closed: Arc, + + // used to retain the closures + #[allow(dead_code)] + onopen: SendWrapper>, + #[allow(dead_code)] + onclose: SendWrapper>, + #[allow(dead_code)] + onerror: SendWrapper>, + #[allow(dead_code)] + onmessage: SendWrapper>, +} + +pub struct WebSocketReader { + read_rx: Receiver, + closed: Arc, + close_event: Arc, +} + +#[async_trait] +impl WebSocketRead for WebSocketReader { + async fn wisp_read_frame( + &mut self, + _: &LockedWebSocketWrite, + ) -> Result, WispError> { + use WebSocketMessage::*; + if self.closed.load(Ordering::Acquire) { + return Err(WispError::WsImplSocketClosed); + } + let res = futures_util::select! { + data = self.read_rx.recv_async() => data.ok(), + _ = self.close_event.listen().fuse() => Some(Closed), + }; + match res.ok_or(WispError::WsImplSocketClosed)? { + Message(bin) => Ok(Frame::binary(Payload::Bytes(BytesMut::from( + bin.as_slice(), + )))), + Error => Err(WebSocketError::Unknown.into()), + Closed => Err(WispError::WsImplSocketClosed), + } + } +} + +impl WebSocketWrapper { + pub fn connect(url: &str, protocols: &[String]) -> Result<(Self, WebSocketReader), EpoxyError> { + let (read_tx, read_rx) = flume::unbounded(); + let closed = Arc::new(AtomicBool::new(false)); + + let open_event = Arc::new(Event::new()); + let close_event = Arc::new(Event::new()); + let error_event = Arc::new(Event::new()); + + let onopen_event = open_event.clone(); + let onopen = Closure::wrap( + Box::new(move || while onopen_event.notify(usize::MAX) == 0 {}) as Box, + ); + + let onmessage_tx = read_tx.clone(); + let onmessage = Closure::wrap(Box::new(move |evt: MessageEvent| { + if let Ok(arr) = evt.data().dyn_into::() { + let _ = + onmessage_tx.send(WebSocketMessage::Message(Uint8Array::new(&arr).to_vec())); + } + }) as Box); + + let onclose_closed = closed.clone(); + let onclose_event = close_event.clone(); + let onclose = Closure::wrap(Box::new(move || { + onclose_closed.store(true, Ordering::Release); + onclose_event.notify(usize::MAX); + }) as Box); + + let onerror_tx = read_tx.clone(); + let onerror_closed = closed.clone(); + let onerror_close = close_event.clone(); + let onerror_event = error_event.clone(); + let onerror = Closure::wrap(Box::new(move || { + let _ = onerror_tx.send(WebSocketMessage::Error); + onerror_closed.store(true, Ordering::Release); + onerror_close.notify(usize::MAX); + onerror_event.notify(usize::MAX); + }) as Box); + + let ws = if protocols.is_empty() { + WebSocket::new(url) + } else { + WebSocket::new_with_str_sequence( + url, + &protocols + .iter() + .fold(Array::new(), |acc, x| { + acc.push(&x.into()); + acc + }) + .into(), + ) + } + .map_err(|_| EpoxyError::WebSocketConnectFailed)?; + ws.set_binary_type(BinaryType::Arraybuffer); + ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref())); + ws.set_onopen(Some(onopen.as_ref().unchecked_ref())); + ws.set_onclose(Some(onclose.as_ref().unchecked_ref())); + ws.set_onerror(Some(onerror.as_ref().unchecked_ref())); + + Ok(( + Self { + inner: SendWrapper::new(ws), + open_event, + error_event, + close_event: close_event.clone(), + closed: closed.clone(), + onopen: SendWrapper::new(onopen), + onclose: SendWrapper::new(onclose), + onerror: SendWrapper::new(onerror), + onmessage: SendWrapper::new(onmessage), + }, + WebSocketReader { + read_rx, + closed, + close_event, + }, + )) + } + + pub async fn wait_for_open(&self) -> bool { + if self.closed.load(Ordering::Acquire) { + return false; + } + futures_util::select! { + _ = self.open_event.listen().fuse() => true, + _ = self.error_event.listen().fuse() => false, + } + } +} + +#[async_trait] +impl WebSocketWrite for WebSocketWrapper { + async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { + use wisp_mux::ws::OpCode::*; + if self.closed.load(Ordering::Acquire) { + return Err(WispError::WsImplSocketClosed); + } + match frame.opcode { + Binary | Text => self + .inner + .send_with_u8_array(&frame.payload) + .map_err(|_| WebSocketError::SendFailed.into()), + Close => { + let _ = self.inner.close(); + Ok(()) + } + _ => Err(WispError::WsImplNotSupported), + } + } + + async fn wisp_close(&mut self) -> Result<(), WispError> { + self.inner + .close() + .map_err(|_| WebSocketError::CloseFailed.into()) + } +} + +impl Drop for WebSocketWrapper { + fn drop(&mut self) { + self.inner.set_onopen(None); + self.inner.set_onclose(None); + self.inner.set_onerror(None); + self.inner.set_onmessage(None); + self.closed.store(true, Ordering::Release); + self.close_event.notify(usize::MAX); + let _ = self.inner.close(); + } +}