use std::{ io::ErrorKind, pin::Pin, task::{Context, Poll}, }; use async_trait::async_trait; use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut}; use futures_rustls::{ rustls::{ self, client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, crypto::{verify_tls12_signature, verify_tls13_signature, CryptoProvider}, DigitallySignedStruct, SignatureScheme, }, TlsStream, }; use futures_util::{ready, AsyncRead, AsyncWrite, Future, Stream, StreamExt, TryStreamExt}; use http::{HeaderValue, Uri}; use hyper::{body::Body, rt::Executor}; use js_sys::{Array, ArrayBuffer, JsString, Object, Uint8Array}; use pin_project_lite::pin_project; use rustls_pki_types::{CertificateDer, ServerName, UnixTime}; use send_wrapper::SendWrapper; use wasm_bindgen::{prelude::*, JsCast, JsValue}; use wasm_bindgen_futures::JsFuture; use wasm_streams::{readable::IntoStream, ReadableStream}; use web_sys::WritableStreamDefaultWriter; use wisp_mux::{ ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, WispError, }; use crate::{stream_provider::ProviderUnencryptedAsyncRW, EpoxyError}; #[wasm_bindgen] extern "C" { #[wasm_bindgen(js_namespace = console, js_name = log)] pub fn js_console_log(s: &str); #[wasm_bindgen(js_namespace = console, js_name = error)] pub fn js_console_error(s: &str); } #[macro_export] macro_rules! console_log { ($($expr:expr),*) => { $crate::utils::js_console_log(&format!($($expr),*)); }; } #[macro_export] macro_rules! console_error { ($($expr:expr),*) => { $crate::utils::js_console_error(&format!($($expr),*)); }; } pub trait UriExt { fn get_redirect(&self, location: &HeaderValue) -> Result; } impl UriExt for Uri { fn get_redirect(&self, location: &HeaderValue) -> Result { let new_uri = location.to_str()?.parse::()?; let mut new_parts: http::uri::Parts = new_uri.into(); if new_parts.scheme.is_none() { new_parts.scheme = self.scheme().cloned(); } if new_parts.authority.is_none() { new_parts.authority = self.authority().cloned(); } Ok(Uri::from_parts(new_parts)?) } } #[derive(Clone)] pub struct WasmExecutor; impl Executor for WasmExecutor where F: Future + Send + 'static, F::Output: Send + 'static, { fn execute(&self, future: F) { wasm_bindgen_futures::spawn_local(async move { let _ = future.await; }); } } pin_project! { pub struct IncomingBody { #[pin] incoming: hyper::body::Incoming, } } impl IncomingBody { pub fn new(incoming: hyper::body::Incoming) -> IncomingBody { IncomingBody { incoming } } } impl Stream for IncomingBody { type Item = std::io::Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().incoming.poll_frame(cx).map(|x| { x.map(|x| { x.map_err(std::io::Error::other).and_then(|x| { x.into_data().map_err(|_| { std::io::Error::other("trailer frame recieved; not implemented") }) }) }) }) } } pin_project! { #[derive(Debug)] pub struct ReaderStream { #[pin] reader: Option, buf: BytesMut, capacity: usize, } } impl ReaderStream { pub fn new(reader: R, capacity: usize) -> Self { ReaderStream { reader: Some(reader), buf: BytesMut::new(), capacity, } } } pub fn poll_read_buf( io: Pin<&mut T>, cx: &mut Context<'_>, buf: &mut B, ) -> Poll> { if !buf.has_remaining_mut() { return Poll::Ready(Ok(0)); } let n = { let dst = buf.chunk_mut(); let dst = unsafe { std::mem::transmute::<&mut UninitSlice, &mut [u8]>(dst) }; ready!(io.poll_read(cx, dst)?) }; unsafe { buf.advance_mut(n); } Poll::Ready(Ok(n)) } impl Stream for ReaderStream { type Item = std::io::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.as_mut().project(); let reader = match this.reader.as_pin_mut() { Some(r) => r, None => return Poll::Ready(None), }; if this.buf.capacity() == 0 { this.buf.reserve(*this.capacity); } match poll_read_buf(reader, cx, &mut this.buf) { Poll::Pending => Poll::Pending, Poll::Ready(Err(err)) => { self.project().reader.set(None); Poll::Ready(Some(Err(err))) } Poll::Ready(Ok(0)) => { self.project().reader.set(None); Poll::Ready(None) } Poll::Ready(Ok(_)) => { let chunk = this.buf.split(); Poll::Ready(Some(Ok(chunk.freeze()))) } } } } pub struct WispTransportRead { pub inner: SendWrapper>, } #[async_trait] impl WebSocketRead for WispTransportRead { async fn wisp_read_frame( &mut self, _tx: &LockedWebSocketWrite, ) -> Result, wisp_mux::WispError> { let obj = self.inner.next().await; if let Some(pkt) = obj { let pkt = pkt.map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x))))?; let arr: ArrayBuffer = pkt.dyn_into().map_err(|x| { WispError::WsImplError(Box::new(EpoxyError::InvalidWispTransportPacket(format!( "{:?}", x )))) })?; Ok(Frame::binary(Payload::Bytes( Uint8Array::new(&arr).to_vec().as_slice().into(), ))) } else { Ok(Frame::close(Payload::Borrowed(&[]))) } } } pub struct WispTransportWrite { pub inner: SendWrapper, } #[async_trait] impl WebSocketWrite for WispTransportWrite { async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { SendWrapper::new(async { let chunk = Uint8Array::from(frame.payload.as_ref()).into(); JsFuture::from(self.inner.write_with_chunk(&chunk)) .await .map(|_| ()) .map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x)))) }) .await } async fn wisp_close(&mut self) -> Result<(), WispError> { SendWrapper::new(JsFuture::from(self.inner.abort())) .await .map(|_| ()) .map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x)))) } } fn map_close_notify(x: std::io::Result) -> std::io::Result { match x { Ok(x) => Ok(x), Err(x) => { // hacky way to find if it's actually a rustls close notify error if x.kind() == ErrorKind::UnexpectedEof && format!("{:?}", x).contains("TLS close_notify") { Ok(0) } else { Err(x) } } } } pin_project! { pub struct IgnoreCloseNotify { #[pin] pub inner: TlsStream, } } impl AsyncRead for IgnoreCloseNotify { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { self.project() .inner .poll_read(cx, buf) .map(map_close_notify) } fn poll_read_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [std::io::IoSliceMut<'_>], ) -> Poll> { self.project() .inner .poll_read_vectored(cx, bufs) .map(map_close_notify) } } impl AsyncWrite for IgnoreCloseNotify { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { self.project().inner.poll_write(cx, buf) } fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[std::io::IoSlice<'_>], ) -> Poll> { self.project().inner.poll_write_vectored(cx, bufs) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_flush(cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_close(cx) } } #[derive(Debug)] pub struct NoCertificateVerification(CryptoProvider); impl NoCertificateVerification { pub fn new(provider: CryptoProvider) -> Self { Self(provider) } } impl ServerCertVerifier for NoCertificateVerification { fn verify_server_cert( &self, _end_entity: &CertificateDer<'_>, _intermediates: &[CertificateDer<'_>], _server_name: &ServerName<'_>, _ocsp: &[u8], _now: UnixTime, ) -> Result { Ok(ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, message: &[u8], cert: &CertificateDer<'_>, dss: &DigitallySignedStruct, ) -> Result { verify_tls12_signature( message, cert, dss, &self.0.signature_verification_algorithms, ) } fn verify_tls13_signature( &self, message: &[u8], cert: &CertificateDer<'_>, dss: &DigitallySignedStruct, ) -> Result { verify_tls13_signature( message, cert, dss, &self.0.signature_verification_algorithms, ) } fn supported_verify_schemes(&self) -> Vec { self.0.signature_verification_algorithms.supported_schemes() } } pub fn is_redirect(code: u16) -> bool { [301, 302, 303, 307, 308].contains(&code) } pub fn is_null_body(code: u16) -> bool { [101, 204, 205, 304].contains(&code) } #[wasm_bindgen(inline_js = r#" export function object_get(obj, k) { try { return obj[k] } catch(x) { return undefined } }; export function object_set(obj, k, v) { try { obj[k] = v } catch {} }; export async function convert_body_inner(body) { let req = new Request("", { method: "POST", duplex: "half", body }); let type = req.headers.get("content-type"); return [new Uint8Array(await req.arrayBuffer()), type]; } export function entries_of_object_inner(obj) { return Object.entries(obj).map(x => x.map(String)); } export function define_property(obj, k, v) { Object.defineProperty(obj, k, { value: v, writable: false }); } export function ws_key() { let key = new Uint8Array(16); crypto.getRandomValues(key); return btoa(String.fromCharCode.apply(null, key)); } export function from_entries(entries){ var ret = {}; for(var i = 0; i < entries.length; i++) ret[entries[i][0]] = entries[i][1]; return ret; } "#)] extern "C" { pub fn object_get(obj: &Object, key: &str) -> JsValue; pub fn object_set(obj: &Object, key: &str, val: JsValue); #[wasm_bindgen(catch)] async fn convert_body_inner(val: JsValue) -> Result; fn entries_of_object_inner(obj: &Object) -> Vec; pub fn define_property(obj: &Object, key: &str, val: JsValue); pub fn ws_key() -> String; #[wasm_bindgen(catch)] pub fn from_entries(iterable: &JsValue) -> Result; } pub async fn convert_body(val: JsValue) -> Result<(Uint8Array, Option), JsValue> { let req: Array = convert_body_inner(val).await?.unchecked_into(); let str: Option = object_truthy(req.at(1)).map(|x| x.unchecked_into()); Ok((req.at(0).unchecked_into(), str.map(Into::into))) } pub fn entries_of_object(obj: &Object) -> Vec> { entries_of_object_inner(obj) .into_iter() .map(|x| { x.iter() .map(|x| x.unchecked_into::().into()) .collect() }) .collect() } pub fn asyncread_to_readablestream( read: Pin>, buffer_size: usize, ) -> web_sys::ReadableStream { ReadableStream::from_stream( ReaderStream::new(read, buffer_size) .map_ok(|x| Uint8Array::from(x.as_ref()).into()) .map_err(|x| EpoxyError::from(x).into()), ) .into_raw() } pub fn object_truthy(val: JsValue) -> Option { if val.is_truthy() { Some(val) } else { None } }