diff --git a/client/Cargo.toml b/client/Cargo.toml index e108a47..e0d6342 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -33,7 +33,7 @@ wasm-bindgen-futures = "0.4.43" wasm-streams = "0.4.0" web-sys = { version = "0.3.70", features = ["BinaryType", "Headers", "MessageEvent", "Request", "RequestInit", "Response", "ResponseInit", "Url", "WebSocket"] } webpki-roots = "0.26.3" -wisp-mux = { version = "*", path = "../wisp", features = ["wasm"], default-features = false } +wisp-mux = { version = "*", path = "../wisp", features = ["wasm", "generic_stream"], default-features = false } [dependencies.getrandom] version = "*" diff --git a/client/src/lib.rs b/client/src/lib.rs index e57c7e3..7baad83 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -3,11 +3,11 @@ use std::{error::Error, str::FromStr, sync::Arc}; #[cfg(feature = "full")] use async_compression::futures::bufread as async_comp; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use cfg_if::cfg_if; #[cfg(feature = "full")] use futures_util::future::Either; -use futures_util::TryStreamExt; +use futures_util::{StreamExt, TryStreamExt}; use http::{ header::{ InvalidHeaderName, InvalidHeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, @@ -23,14 +23,14 @@ use hyper::{body::Incoming, Uri}; use hyper_util_wasm::client::legacy::Client; #[cfg(feature = "full")] use io_stream::{iostream_from_asyncrw, iostream_from_stream}; -use js_sys::{Array, Function, Object, Promise}; +use js_sys::{Array, ArrayBuffer, Function, Object, Promise, Uint8Array}; use send_wrapper::SendWrapper; use stream_provider::{ProviderWispTransportGenerator, StreamProvider, StreamProviderService}; use thiserror::Error; use utils::{ asyncread_to_readablestream, convert_streaming_body, entries_of_object, from_entries, is_null_body, is_redirect, object_get, object_set, object_truthy, websocket_transport, - StreamingBody, UriExt, WasmExecutor, WispTransportRead, WispTransportWrite, + StreamingBody, UriExt, WasmExecutor, WispTransportWrite, }; use wasm_bindgen::prelude::*; use wasm_bindgen_futures::JsFuture; @@ -40,6 +40,7 @@ use websocket::EpoxyWebSocket; #[cfg(feature = "full")] use wisp_mux::StreamType; use wisp_mux::{ + generic::GenericWebSocketRead, ws::{WebSocketRead, WebSocketWrite}, CloseReason, }; @@ -337,12 +338,17 @@ fn create_wisp_transport(function: Function) -> ProviderWispTransportGenerator { } .into(); - let read = WispTransportRead { - inner: SendWrapper::new( - wasm_streams::ReadableStream::from_raw(object_get(&transport, "read").into()) - .into_stream(), - ), - }; + let read = GenericWebSocketRead::new(SendWrapper::new( + wasm_streams::ReadableStream::from_raw(object_get(&transport, "read").into()) + .into_stream() + .map(|x| { + let pkt = x.map_err(EpoxyError::wisp_transport)?; + let arr: ArrayBuffer = pkt.dyn_into().map_err(|x| { + EpoxyError::InvalidWispTransportPacket(format!("{:?}", x)) + })?; + Ok::(BytesMut::from(Uint8Array::new(&arr).to_vec().as_slice())) + }), + )); let write: WritableStream = object_get(&transport, "write").into(); let write = WispTransportWrite { inner: SendWrapper::new(write.get_writer().map_err(EpoxyError::wisp_transport)?), diff --git a/client/src/stream_provider.rs b/client/src/stream_provider.rs index 8e256f7..3124ed1 100644 --- a/client/src/stream_provider.rs +++ b/client/src/stream_provider.rs @@ -21,9 +21,7 @@ use wisp_mux::{ }; use crate::{ - console_log, - utils::{IgnoreCloseNotify, NoCertificateVerification}, - EpoxyClientOptions, EpoxyError, + console_error, console_log, utils::{IgnoreCloseNotify, NoCertificateVerification}, EpoxyClientOptions, EpoxyError }; pub type ProviderUnencryptedStream = MuxStreamIo; @@ -140,7 +138,10 @@ impl StreamProvider { locked.replace(mux); let current_client = self.current_client.clone(); spawn_local(async move { - console_log!("multiplexor future result: {:?}", fut.await); + match fut.await { + Ok(_) => console_log!("epoxy: wisp multiplexor task ended successfully"), + Err(x) => console_error!("epoxy: wisp multiplexor task ended with an error: {} {:?}", x, x), + } current_client.lock().await.take(); }); Ok(()) diff --git a/client/src/utils.rs b/client/src/utils.rs deleted file mode 100644 index 4dc7fa1..0000000 --- a/client/src/utils.rs +++ /dev/null @@ -1,623 +0,0 @@ -use std::{ - io::ErrorKind, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -use async_trait::async_trait; -use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut}; -use futures_rustls::{ - rustls::{ - self, - client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, - crypto::{verify_tls12_signature, verify_tls13_signature, CryptoProvider}, - DigitallySignedStruct, SignatureScheme, - }, - TlsStream, -}; -use futures_util::{ready, AsyncRead, AsyncWrite, Future, Stream, StreamExt, TryStreamExt}; -use http::{HeaderValue, Uri}; -use http_body_util::{Either, Full, StreamBody}; -use hyper::rt::Executor; -use js_sys::{Array, ArrayBuffer, Function, JsString, Object, Uint8Array}; -use pin_project_lite::pin_project; -use rustls_pki_types::{CertificateDer, ServerName, UnixTime}; -use send_wrapper::SendWrapper; -use wasm_bindgen::{prelude::*, JsCast, JsValue}; -use wasm_bindgen_futures::JsFuture; -use wasm_streams::{readable::IntoStream, ReadableStream}; -use web_sys::WritableStreamDefaultWriter; -use wisp_mux::{ - ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, - WispError, -}; - -use crate::{stream_provider::ProviderUnencryptedAsyncRW, EpoxyError}; - -#[wasm_bindgen] -extern "C" { - #[wasm_bindgen(js_namespace = console, js_name = log)] - pub fn js_console_log(s: &str); - - #[wasm_bindgen(js_namespace = console, js_name = warn)] - pub fn js_console_warn(s: &str); - - #[wasm_bindgen(js_namespace = console, js_name = error)] - pub fn js_console_error(s: &str); -} - -#[macro_export] -macro_rules! console_log { - ($($expr:expr),*) => { - $crate::utils::js_console_log(&format!($($expr),*)); - }; -} -#[macro_export] -macro_rules! console_warn { - ($($expr:expr),*) => { - $crate::utils::js_console_warn(&format!($($expr),*)); - }; -} - -#[macro_export] -macro_rules! console_error { - ($($expr:expr),*) => { - $crate::utils::js_console_error(&format!($($expr),*)); - }; -} - -pub trait UriExt { - fn get_redirect(&self, location: &HeaderValue) -> Result; -} - -impl UriExt for Uri { - fn get_redirect(&self, location: &HeaderValue) -> Result { - let new_uri = location.to_str()?.parse::()?; - let mut new_parts: http::uri::Parts = new_uri.into(); - if new_parts.scheme.is_none() { - new_parts.scheme = self.scheme().cloned(); - } - if new_parts.authority.is_none() { - new_parts.authority = self.authority().cloned(); - } - - Ok(Uri::from_parts(new_parts)?) - } -} - -#[derive(Clone)] -pub struct WasmExecutor; - -impl Executor for WasmExecutor -where - F: Future + Send + 'static, - F::Output: Send + 'static, -{ - fn execute(&self, future: F) { - wasm_bindgen_futures::spawn_local(async move { - let _ = future.await; - }); - } -} - -pin_project! { - #[derive(Debug)] - pub struct ReaderStream { - #[pin] - reader: Option, - buf: BytesMut, - capacity: usize, - } -} - -impl ReaderStream { - pub fn new(reader: R, capacity: usize) -> Self { - ReaderStream { - reader: Some(reader), - buf: BytesMut::new(), - capacity, - } - } -} - -pub fn poll_read_buf( - io: Pin<&mut T>, - cx: &mut Context<'_>, - buf: &mut B, -) -> Poll> { - if !buf.has_remaining_mut() { - return Poll::Ready(Ok(0)); - } - - let n = { - let dst = buf.chunk_mut(); - - let dst = unsafe { std::mem::transmute::<&mut UninitSlice, &mut [u8]>(dst) }; - ready!(io.poll_read(cx, dst)?) - }; - - unsafe { - buf.advance_mut(n); - } - - Poll::Ready(Ok(n)) -} - -impl Stream for ReaderStream { - type Item = std::io::Result; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.as_mut().project(); - - let reader = match this.reader.as_pin_mut() { - Some(r) => r, - None => return Poll::Ready(None), - }; - - if this.buf.capacity() == 0 { - this.buf.reserve(*this.capacity); - } - - match poll_read_buf(reader, cx, &mut this.buf) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(err)) => { - self.project().reader.set(None); - Poll::Ready(Some(Err(err))) - } - Poll::Ready(Ok(0)) => { - self.project().reader.set(None); - Poll::Ready(None) - } - Poll::Ready(Ok(_)) => { - let chunk = this.buf.split(); - Poll::Ready(Some(Ok(chunk.freeze()))) - } - } - } -} - -pub struct WispTransportRead { - pub inner: SendWrapper>, -} - -#[async_trait] -impl WebSocketRead for WispTransportRead { - async fn wisp_read_frame( - &mut self, - _tx: &LockedWebSocketWrite, - ) -> Result, wisp_mux::WispError> { - let obj = self.inner.next().await; - - if let Some(pkt) = obj { - let pkt = - pkt.map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x))))?; - let arr: ArrayBuffer = pkt.dyn_into().map_err(|x| { - WispError::WsImplError(Box::new(EpoxyError::InvalidWispTransportPacket(format!( - "{:?}", - x - )))) - })?; - - Ok(Frame::binary(Payload::Bytes( - Uint8Array::new(&arr).to_vec().as_slice().into(), - ))) - } else { - Ok(Frame::close(Payload::Borrowed(&[]))) - } - } -} - -pub struct WispTransportWrite { - pub inner: SendWrapper, -} - -#[async_trait] -impl WebSocketWrite for WispTransportWrite { - async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { - SendWrapper::new(async { - let chunk = Uint8Array::from(frame.payload.as_ref()).into(); - JsFuture::from(self.inner.write_with_chunk(&chunk)) - .await - .map(|_| ()) - .map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x)))) - }) - .await - } - - async fn wisp_close(&mut self) -> Result<(), WispError> { - SendWrapper::new(JsFuture::from(self.inner.abort())) - .await - .map(|_| ()) - .map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x)))) - } -} - -fn map_close_notify(x: std::io::Result) -> std::io::Result { - match x { - Ok(x) => Ok(x), - Err(x) => { - // hacky way to find if it's actually a rustls close notify error - if x.kind() == ErrorKind::UnexpectedEof - && format!("{:?}", x).contains("TLS close_notify") - { - Ok(0) - } else { - Err(x) - } - } - } -} - -pin_project! { - pub struct IgnoreCloseNotify { - #[pin] - pub inner: TlsStream, - pub h2_negotiated: bool, - } -} - -impl AsyncRead for IgnoreCloseNotify { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - self.project() - .inner - .poll_read(cx, buf) - .map(map_close_notify) - } - - fn poll_read_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &mut [std::io::IoSliceMut<'_>], - ) -> Poll> { - self.project() - .inner - .poll_read_vectored(cx, bufs) - .map(map_close_notify) - } -} - -impl AsyncWrite for IgnoreCloseNotify { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.project().inner.poll_write(cx, buf) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - self.project().inner.poll_write_vectored(cx, bufs) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_flush(cx) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_close(cx) - } -} - -#[derive(Debug)] -pub struct NoCertificateVerification(pub Arc); - -impl ServerCertVerifier for NoCertificateVerification { - fn verify_server_cert( - &self, - _end_entity: &CertificateDer<'_>, - _intermediates: &[CertificateDer<'_>], - _server_name: &ServerName<'_>, - _ocsp: &[u8], - _now: UnixTime, - ) -> Result { - Ok(ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &DigitallySignedStruct, - ) -> Result { - verify_tls12_signature( - message, - cert, - dss, - &self.0.signature_verification_algorithms, - ) - } - - fn verify_tls13_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &DigitallySignedStruct, - ) -> Result { - verify_tls13_signature( - message, - cert, - dss, - &self.0.signature_verification_algorithms, - ) - } - - fn supported_verify_schemes(&self) -> Vec { - self.0.signature_verification_algorithms.supported_schemes() - } -} - -pub fn is_redirect(code: u16) -> bool { - [301, 302, 303, 307, 308].contains(&code) -} - -pub fn is_null_body(code: u16) -> bool { - [101, 204, 205, 304].contains(&code) -} - -#[wasm_bindgen(inline_js = r#" -class WebSocketStreamPonyfill { - url; - opened; - closed; - close; - constructor(url, options = {}) { - if (options.signal?.aborted) { - throw new DOMException('This operation was aborted', 'AbortError'); - } - this.url = url; - const ws = new WebSocket(url, options.protocols ?? []); - ws.binaryType = "arraybuffer"; - const closeWithInfo = ({ closeCode: code, reason } = {}) => ws.close(code, reason); - this.opened = new Promise((resolve, reject) => { - const errorHandler = ()=>reject(new Error("WebSocket closed before handshake complete.")); - ws.onopen = () => { - resolve({ - readable: new ReadableStream({ - start(controller) { - ws.onmessage = ({ data }) => controller.enqueue(data); - ws.onerror = e => controller.error(e); - }, - cancel: closeWithInfo, - }), - writable: new WritableStream({ - write(chunk) { ws.send(chunk); }, - abort() { ws.close(); }, - close: closeWithInfo, - }), - protocol: ws.protocol, - extensions: ws.extensions, - }); - ws.removeEventListener('error', errorHandler); - }; - ws.addEventListener('error', errorHandler); - }); - this.closed = new Promise((resolve, reject) => { - ws.onclose = ({ code, reason }) => { - resolve({ closeCode: code, reason }); - ws.removeEventListener('error', reject); - }; - ws.addEventListener('error', reject); - }); - if (options.signal) { - options.signal.onabort = () => ws.close(); - } - this.close = closeWithInfo; - } -} - -function ws_protocol() { - return ( - [1e7]+-1e3+-4e3+-8e3+-1e11).replace(/[018]/g, - c => (c ^ crypto.getRandomValues(new Uint8Array(1))[0] & 15 >> c / 4).toString(16) - ); -} - -export function websocket_transport(url, protocols) { - const ws_impl = typeof WebSocketStream === "undefined" ? WebSocketStreamPonyfill : WebSocketStream; - return async (wisp_v2)=>{ - if (wisp_v2) protocols.push(ws_protocol()); - const ws = new ws_impl(url, { protocols }); - const { readable, writable } = await ws.opened; - return { read: readable, write: writable }; - } -} - -export function object_get(obj, k) { - try { - return obj[k] - } catch(x) { - return undefined - } -}; -export function object_set(obj, k, v) { - try { obj[k] = v } catch {} -}; - -export async function convert_body_inner(body) { - let req = new Request("", { method: "POST", duplex: "half", body }); - let type = req.headers.get("content-type"); - return [new Uint8Array(await req.arrayBuffer()), type]; -} - -export async function convert_streaming_body_inner(body) { - try { - let req = new Request("", { method: "POST", body }); - let type = req.headers.get("content-type"); - return [false, new Uint8Array(await req.arrayBuffer()), type]; - } catch(x) { - let req = new Request("", { method: "POST", duplex: "half", body }); - let type = req.headers.get("content-type"); - return [true, req.body, type]; - } -} - -export function entries_of_object_inner(obj) { - return Object.entries(obj).map(x => x.map(String)); -} - -export function define_property(obj, k, v) { - Object.defineProperty(obj, k, { value: v, writable: false }); -} - -export function ws_key() { - let key = new Uint8Array(16); - crypto.getRandomValues(key); - return btoa(String.fromCharCode.apply(null, key)); -} - -export function from_entries(entries){ - var ret = {}; - for(var i = 0; i < entries.length; i++) ret[entries[i][0]] = entries[i][1]; - return ret; -} -"#)] -extern "C" { - pub fn websocket_transport(url: String, protocols: Vec) -> Function; - - pub fn object_get(obj: &Object, key: &str) -> JsValue; - pub fn object_set(obj: &Object, key: &str, val: JsValue); - - #[wasm_bindgen(catch)] - async fn convert_body_inner(val: JsValue) -> Result; - #[wasm_bindgen(catch)] - async fn convert_streaming_body_inner(val: JsValue) -> Result; - - fn entries_of_object_inner(obj: &Object) -> Vec; - pub fn define_property(obj: &Object, key: &str, val: JsValue); - pub fn ws_key() -> String; - - #[wasm_bindgen(catch)] - pub fn from_entries(iterable: &JsValue) -> Result; -} - -pub async fn convert_body(val: JsValue) -> Result<(Uint8Array, Option), JsValue> { - let req: Array = convert_body_inner(val).await?.unchecked_into(); - let content_type: Option = object_truthy(req.at(1)).map(|x| x.unchecked_into()); - Ok((req.at(0).unchecked_into(), content_type.map(Into::into))) -} - -pub enum MaybeStreamingBody { - Streaming(web_sys::ReadableStream), - Static(Uint8Array), -} - -pub struct StreamingInnerBody( - Pin, std::io::Error>> + Send>>, - SendWrapper, -); -impl StreamingInnerBody { - pub fn from_teed(a: ReadableStream, b: ReadableStream) -> Result { - let reader = a - .try_into_stream() - .map_err(|x| EpoxyError::StreamingBodyConvertFailed(format!("{:?}", x)))?; - let reader = reader - .then(|x| async { - Ok::(Bytes::from(convert_body(x?).await?.0.to_vec())) - }) - .map_ok(http_body::Frame::data); - let reader = reader.map_err(|x| std::io::Error::other(format!("{:?}", x))); - let reader = Box::pin(SendWrapper::new(reader)); - - Ok(Self(reader, SendWrapper::new(b))) - } -} -impl Stream for StreamingInnerBody { - type Item = Result, std::io::Error>; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.0.poll_next_unpin(cx) - } - fn size_hint(&self) -> (usize, Option) { - self.0.size_hint() - } -} -impl Clone for StreamingInnerBody { - fn clone(&self) -> Self { - match ReadableStream::from_raw(self.1.as_raw().clone()) - .try_tee() - .map_err(|x| EpoxyError::StreamingBodyTeeFailed(format!("{:?}", x))) - .and_then(|(a, b)| StreamingInnerBody::from_teed(a, b)) - { - Ok(x) => x, - Err(x) => { - console_error!( - "epoxy internal error: failed to clone streaming body: {:?}", - x - ); - unreachable!("failed to clone streaming body"); - } - } - } -} - -pub type StreamingBody = Either, Full>; - -impl MaybeStreamingBody { - pub fn into_httpbody(self) -> Result { - match self { - Self::Streaming(x) => { - let (a, b) = ReadableStream::from_raw(x) - .try_tee() - .map_err(|x| EpoxyError::StreamingBodyTeeFailed(format!("{:?}", x)))?; - - Ok(Either::Left(StreamBody::new( - StreamingInnerBody::from_teed(a, b)?, - ))) - } - Self::Static(x) => Ok(Either::Right(Full::new(Bytes::from(x.to_vec())))), - } - } -} - -pub async fn convert_streaming_body( - val: JsValue, -) -> Result<(MaybeStreamingBody, Option), JsValue> { - let req: Array = convert_streaming_body_inner(val).await?.unchecked_into(); - let content_type: Option = object_truthy(req.at(2)).map(|x| x.unchecked_into()); - - let body = if req.at(0).is_truthy() { - MaybeStreamingBody::Streaming(req.at(1).unchecked_into()) - } else { - MaybeStreamingBody::Static(req.at(1).unchecked_into()) - }; - - Ok((body, content_type.map(Into::into))) -} - -pub fn entries_of_object(obj: &Object) -> Vec> { - entries_of_object_inner(obj) - .into_iter() - .map(|x| { - x.iter() - .map(|x| x.unchecked_into::().into()) - .collect() - }) - .collect() -} - -pub fn asyncread_to_readablestream( - read: Pin>, - buffer_size: usize, -) -> web_sys::ReadableStream { - ReadableStream::from_stream( - ReaderStream::new(read, buffer_size) - .map_ok(|x| Uint8Array::from(x.as_ref()).into()) - .map_err(|x| EpoxyError::from(x).into()), - ) - .into_raw() -} - -pub fn object_truthy(val: JsValue) -> Option { - if val.is_truthy() { - Some(val) - } else { - None - } -} diff --git a/client/src/utils/js.rs b/client/src/utils/js.rs new file mode 100644 index 0000000..b55b05f --- /dev/null +++ b/client/src/utils/js.rs @@ -0,0 +1,274 @@ +use std::{pin::Pin, task::{Context, Poll}}; + +use bytes::Bytes; +use futures_util::{AsyncRead, Stream, StreamExt, TryStreamExt}; +use http_body_util::{Either, Full, StreamBody}; +use js_sys::{Array, Function, JsString, Object, Uint8Array}; +use send_wrapper::SendWrapper; +use wasm_bindgen::{prelude::*, JsCast, JsValue}; +use wasm_streams::ReadableStream; + +use crate::{console_error, EpoxyError}; + +use super::ReaderStream; + + +#[wasm_bindgen(inline_js = r#" +class WebSocketStreamPonyfill { + url; + opened; + closed; + close; + constructor(url, options = {}) { + if (options.signal?.aborted) { + throw new DOMException('This operation was aborted', 'AbortError'); + } + this.url = url; + const ws = new WebSocket(url, options.protocols ?? []); + ws.binaryType = "arraybuffer"; + const closeWithInfo = ({ closeCode: code, reason } = {}) => ws.close(code, reason); + this.opened = new Promise((resolve, reject) => { + const errorHandler = ()=>reject(new Error("WebSocket closed before handshake complete.")); + ws.onopen = () => { + resolve({ + readable: new ReadableStream({ + start(controller) { + ws.onmessage = ({ data }) => controller.enqueue(data); + ws.onerror = e => controller.error(e); + }, + cancel: closeWithInfo, + }), + writable: new WritableStream({ + write(chunk) { ws.send(chunk); }, + abort() { ws.close(); }, + close: closeWithInfo, + }), + protocol: ws.protocol, + extensions: ws.extensions, + }); + ws.removeEventListener('error', errorHandler); + }; + ws.addEventListener('error', errorHandler); + }); + this.closed = new Promise((resolve, reject) => { + ws.onclose = ({ code, reason }) => { + resolve({ closeCode: code, reason }); + ws.removeEventListener('error', reject); + }; + ws.addEventListener('error', reject); + }); + if (options.signal) { + options.signal.onabort = () => ws.close(); + } + this.close = closeWithInfo; + } +} + +function ws_protocol() { + return ( + [1e7]+-1e3+-4e3+-8e3+-1e11).replace(/[018]/g, + c => (c ^ crypto.getRandomValues(new Uint8Array(1))[0] & 15 >> c / 4).toString(16) + ); +} + +export function websocket_transport(url, protocols) { + const ws_impl = typeof WebSocketStream === "undefined" ? WebSocketStreamPonyfill : WebSocketStream; + return async (wisp_v2)=>{ + if (wisp_v2) protocols.push(ws_protocol()); + const ws = new ws_impl(url, { protocols }); + const { readable, writable } = await ws.opened; + return { read: readable, write: writable }; + } +} + +export function object_get(obj, k) { + try { + return obj[k] + } catch(x) { + return undefined + } +}; +export function object_set(obj, k, v) { + try { obj[k] = v } catch {} +}; + +export async function convert_body_inner(body) { + let req = new Request("", { method: "POST", duplex: "half", body }); + let type = req.headers.get("content-type"); + return [new Uint8Array(await req.arrayBuffer()), type]; +} + +export async function convert_streaming_body_inner(body) { + try { + let req = new Request("", { method: "POST", body }); + let type = req.headers.get("content-type"); + return [false, new Uint8Array(await req.arrayBuffer()), type]; + } catch(x) { + let req = new Request("", { method: "POST", duplex: "half", body }); + let type = req.headers.get("content-type"); + return [true, req.body, type]; + } +} + +export function entries_of_object_inner(obj) { + return Object.entries(obj).map(x => x.map(String)); +} + +export function define_property(obj, k, v) { + Object.defineProperty(obj, k, { value: v, writable: false }); +} + +export function ws_key() { + let key = new Uint8Array(16); + crypto.getRandomValues(key); + return btoa(String.fromCharCode.apply(null, key)); +} + +export function from_entries(entries){ + var ret = {}; + for(var i = 0; i < entries.length; i++) ret[entries[i][0]] = entries[i][1]; + return ret; +} +"#)] +extern "C" { + pub fn websocket_transport(url: String, protocols: Vec) -> Function; + + pub fn object_get(obj: &Object, key: &str) -> JsValue; + pub fn object_set(obj: &Object, key: &str, val: JsValue); + + #[wasm_bindgen(catch)] + async fn convert_body_inner(val: JsValue) -> Result; + #[wasm_bindgen(catch)] + async fn convert_streaming_body_inner(val: JsValue) -> Result; + + fn entries_of_object_inner(obj: &Object) -> Vec; + pub fn define_property(obj: &Object, key: &str, val: JsValue); + pub fn ws_key() -> String; + + #[wasm_bindgen(catch)] + pub fn from_entries(iterable: &JsValue) -> Result; +} + +pub async fn convert_body(val: JsValue) -> Result<(Uint8Array, Option), JsValue> { + let req: Array = convert_body_inner(val).await?.unchecked_into(); + let content_type: Option = object_truthy(req.at(1)).map(|x| x.unchecked_into()); + Ok((req.at(0).unchecked_into(), content_type.map(Into::into))) +} + +pub enum MaybeStreamingBody { + Streaming(web_sys::ReadableStream), + Static(Uint8Array), +} + +pub struct StreamingInnerBody( + Pin, std::io::Error>> + Send>>, + SendWrapper, +); +impl StreamingInnerBody { + pub fn from_teed(a: ReadableStream, b: ReadableStream) -> Result { + let reader = a + .try_into_stream() + .map_err(|x| EpoxyError::StreamingBodyConvertFailed(format!("{:?}", x)))?; + let reader = reader + .then(|x| async { + Ok::(Bytes::from(convert_body(x?).await?.0.to_vec())) + }) + .map_ok(http_body::Frame::data); + let reader = reader.map_err(|x| std::io::Error::other(format!("{:?}", x))); + let reader = Box::pin(SendWrapper::new(reader)); + + Ok(Self(reader, SendWrapper::new(b))) + } +} +impl Stream for StreamingInnerBody { + type Item = Result, std::io::Error>; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.0.poll_next_unpin(cx) + } + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} +impl Clone for StreamingInnerBody { + fn clone(&self) -> Self { + match ReadableStream::from_raw(self.1.as_raw().clone()) + .try_tee() + .map_err(|x| EpoxyError::StreamingBodyTeeFailed(format!("{:?}", x))) + .and_then(|(a, b)| StreamingInnerBody::from_teed(a, b)) + { + Ok(x) => x, + Err(x) => { + console_error!( + "epoxy internal error: failed to clone streaming body: {:?}", + x + ); + unreachable!("failed to clone streaming body"); + } + } + } +} + +pub type StreamingBody = Either, Full>; + +impl MaybeStreamingBody { + pub fn into_httpbody(self) -> Result { + match self { + Self::Streaming(x) => { + let (a, b) = ReadableStream::from_raw(x) + .try_tee() + .map_err(|x| EpoxyError::StreamingBodyTeeFailed(format!("{:?}", x)))?; + + Ok(Either::Left(StreamBody::new( + StreamingInnerBody::from_teed(a, b)?, + ))) + } + Self::Static(x) => Ok(Either::Right(Full::new(Bytes::from(x.to_vec())))), + } + } +} + +pub async fn convert_streaming_body( + val: JsValue, +) -> Result<(MaybeStreamingBody, Option), JsValue> { + let req: Array = convert_streaming_body_inner(val).await?.unchecked_into(); + let content_type: Option = object_truthy(req.at(2)).map(|x| x.unchecked_into()); + + let body = if req.at(0).is_truthy() { + MaybeStreamingBody::Streaming(req.at(1).unchecked_into()) + } else { + MaybeStreamingBody::Static(req.at(1).unchecked_into()) + }; + + Ok((body, content_type.map(Into::into))) +} + +pub fn entries_of_object(obj: &Object) -> Vec> { + entries_of_object_inner(obj) + .into_iter() + .map(|x| { + x.iter() + .map(|x| x.unchecked_into::().into()) + .collect() + }) + .collect() +} + +pub fn asyncread_to_readablestream( + read: Pin>, + buffer_size: usize, +) -> web_sys::ReadableStream { + ReadableStream::from_stream( + ReaderStream::new(read, buffer_size) + .map_ok(|x| Uint8Array::from(x.as_ref()).into()) + .map_err(|x| EpoxyError::from(x).into()), + ) + .into_raw() +} + +pub fn object_truthy(val: JsValue) -> Option { + if val.is_truthy() { + Some(val) + } else { + None + } +} diff --git a/client/src/utils/mod.rs b/client/src/utils/mod.rs new file mode 100644 index 0000000..27badd6 --- /dev/null +++ b/client/src/utils/mod.rs @@ -0,0 +1,201 @@ +mod js; +mod rustls; +pub use js::*; +pub use rustls::*; + +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use async_trait::async_trait; +use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut}; +use futures_util::{ready, AsyncRead, Future, Stream}; +use http::{HeaderValue, Uri}; +use hyper::rt::Executor; +use js_sys::Uint8Array; +use pin_project_lite::pin_project; +use send_wrapper::SendWrapper; +use wasm_bindgen::prelude::*; +use wasm_bindgen_futures::JsFuture; +use web_sys::WritableStreamDefaultWriter; +use wisp_mux::{ + ws::{Frame, WebSocketWrite}, + WispError, +}; + +use crate::EpoxyError; + +#[wasm_bindgen] +extern "C" { + #[wasm_bindgen(js_namespace = console, js_name = log)] + pub fn js_console_log(s: &str); + + #[wasm_bindgen(js_namespace = console, js_name = warn)] + pub fn js_console_warn(s: &str); + + #[wasm_bindgen(js_namespace = console, js_name = error)] + pub fn js_console_error(s: &str); +} + +#[macro_export] +macro_rules! console_log { + ($($expr:expr),*) => { + $crate::utils::js_console_log(&format!($($expr),*)) + }; +} +#[macro_export] +macro_rules! console_warn { + ($($expr:expr),*) => { + $crate::utils::js_console_warn(&format!($($expr),*)) + }; +} + +#[macro_export] +macro_rules! console_error { + ($($expr:expr),*) => { + $crate::utils::js_console_error(&format!($($expr),*)) + }; +} + +pub fn is_redirect(code: u16) -> bool { + [301, 302, 303, 307, 308].contains(&code) +} + +pub fn is_null_body(code: u16) -> bool { + [101, 204, 205, 304].contains(&code) +} + +pub trait UriExt { + fn get_redirect(&self, location: &HeaderValue) -> Result; +} + +impl UriExt for Uri { + fn get_redirect(&self, location: &HeaderValue) -> Result { + let new_uri = location.to_str()?.parse::()?; + let mut new_parts: http::uri::Parts = new_uri.into(); + if new_parts.scheme.is_none() { + new_parts.scheme = self.scheme().cloned(); + } + if new_parts.authority.is_none() { + new_parts.authority = self.authority().cloned(); + } + + Ok(Uri::from_parts(new_parts)?) + } +} + +#[derive(Clone)] +pub struct WasmExecutor; + +impl Executor for WasmExecutor +where + F: Future + Send + 'static, + F::Output: Send + 'static, +{ + fn execute(&self, future: F) { + wasm_bindgen_futures::spawn_local(async move { + let _ = future.await; + }); + } +} + +pin_project! { + #[derive(Debug)] + pub struct ReaderStream { + #[pin] + reader: Option, + buf: BytesMut, + capacity: usize, + } +} + +impl ReaderStream { + pub fn new(reader: R, capacity: usize) -> Self { + ReaderStream { + reader: Some(reader), + buf: BytesMut::new(), + capacity, + } + } +} + +pub fn poll_read_buf( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut B, +) -> Poll> { + if !buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let n = { + let dst = buf.chunk_mut(); + + let dst = unsafe { std::mem::transmute::<&mut UninitSlice, &mut [u8]>(dst) }; + ready!(io.poll_read(cx, dst)?) + }; + + unsafe { + buf.advance_mut(n); + } + + Poll::Ready(Ok(n)) +} + +impl Stream for ReaderStream { + type Item = std::io::Result; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.as_mut().project(); + + let reader = match this.reader.as_pin_mut() { + Some(r) => r, + None => return Poll::Ready(None), + }; + + if this.buf.capacity() == 0 { + this.buf.reserve(*this.capacity); + } + + match poll_read_buf(reader, cx, &mut this.buf) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => { + self.project().reader.set(None); + Poll::Ready(Some(Err(err))) + } + Poll::Ready(Ok(0)) => { + self.project().reader.set(None); + Poll::Ready(None) + } + Poll::Ready(Ok(_)) => { + let chunk = this.buf.split(); + Poll::Ready(Some(Ok(chunk.freeze()))) + } + } + } +} + +pub struct WispTransportWrite { + pub inner: SendWrapper, +} + +#[async_trait] +impl WebSocketWrite for WispTransportWrite { + async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { + SendWrapper::new(async { + let chunk = Uint8Array::from(frame.payload.as_ref()).into(); + JsFuture::from(self.inner.write_with_chunk(&chunk)) + .await + .map(|_| ()) + .map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x)))) + }) + .await + } + + async fn wisp_close(&mut self) -> Result<(), WispError> { + SendWrapper::new(JsFuture::from(self.inner.abort())) + .await + .map(|_| ()) + .map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x)))) + } +} diff --git a/client/src/utils/rustls.rs b/client/src/utils/rustls.rs new file mode 100644 index 0000000..602cba5 --- /dev/null +++ b/client/src/utils/rustls.rs @@ -0,0 +1,140 @@ +use std::{ + io::ErrorKind, pin::Pin, sync::Arc, task::{Context, Poll} +}; + +use futures_rustls::{ + rustls::{ + self, + client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, + crypto::{verify_tls12_signature, verify_tls13_signature, CryptoProvider}, + DigitallySignedStruct, SignatureScheme, + }, + TlsStream, +}; +use futures_util::{AsyncRead, AsyncWrite}; +use pin_project_lite::pin_project; +use rustls_pki_types::{CertificateDer, ServerName, UnixTime}; + +use crate::stream_provider::ProviderUnencryptedAsyncRW; + +fn map_close_notify(x: std::io::Result) -> std::io::Result { + match x { + Ok(x) => Ok(x), + Err(x) => { + // hacky way to find if it's actually a rustls close notify error + if x.kind() == ErrorKind::UnexpectedEof + && format!("{:?}", x).contains("TLS close_notify") + { + Ok(0) + } else { + Err(x) + } + } + } +} + +pin_project! { + pub struct IgnoreCloseNotify { + #[pin] + pub inner: TlsStream, + pub h2_negotiated: bool, + } +} + +impl AsyncRead for IgnoreCloseNotify { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.project() + .inner + .poll_read(cx, buf) + .map(map_close_notify) + } + + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [std::io::IoSliceMut<'_>], + ) -> Poll> { + self.project() + .inner + .poll_read_vectored(cx, bufs) + .map(map_close_notify) + } +} + +impl AsyncWrite for IgnoreCloseNotify { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().inner.poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + self.project().inner.poll_write_vectored(cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_close(cx) + } +} + +#[derive(Debug)] +pub struct NoCertificateVerification(pub Arc); + +impl ServerCertVerifier for NoCertificateVerification { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp: &[u8], + _now: UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + verify_tls12_signature( + message, + cert, + dss, + &self.0.signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + verify_tls13_signature( + message, + cert, + dss, + &self.0.signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec { + self.0.signature_verification_algorithms.supported_schemes() + } +} diff --git a/client/src/websocket.rs b/client/src/websocket.rs index b1f9067..82e4dc0 100644 --- a/client/src/websocket.rs +++ b/client/src/websocket.rs @@ -228,7 +228,7 @@ async fn request( spawn_local(async move { if let Err(err) = conn.with_upgrades().await { - console_error!("websocket connection future failed: {:?}", err); + console_error!("epoxy: websocket connection task failed: {:?}", err); } });