rewrite client

This commit is contained in:
Toshit Chawda 2024-06-12 11:51:06 -07:00
parent 273063ec28
commit 177a0d2167
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
13 changed files with 1338 additions and 1710 deletions

637
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,40 +1,37 @@
[package] [package]
name = "epoxy-client" name = "epoxy-client"
version = "1.5.1" version = "2.0.0"
edition = "2021" edition = "2021"
license = "LGPL-3.0-only"
[lib] [lib]
crate-type = ["cdylib", "rlib"] crate-type = ["cdylib"]
[dependencies] [dependencies]
bytes = "1.5.0" async-compression = { version = "0.4.11", features = ["futures-io", "gzip", "brotli"] }
http = "1.0.0"
http-body-util = "0.1.0"
hyper = { version = "1.1.0", features = ["client", "http1", "http2"] }
pin-project-lite = "0.2.13"
wasm-bindgen = { version = "0.2.91", features = ["enable-interning"] }
wasm-bindgen-futures = "0.4.39"
futures-util = "0.3.30"
js-sys = "0.3.66"
tokio-rustls = { version = "0.26.0", default-features = false, features = ["tls12", "ring"] }
web-sys = { version = "0.3.66", features = ["Request", "RequestInit", "Headers", "Response", "ResponseInit", "WebSocket", "BinaryType", "MessageEvent"] }
wasm-streams = "0.4.0"
tokio-util = { version = "0.7.10", features = ["io"] }
async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] }
fastwebsockets = { version = "0.6.0", features = ["unstable-split"] }
base64 = "0.21.7"
wisp-mux = { path = "../wisp", features = ["tokio_io", "wasm"] }
async_io_stream = { version = "0.3.3", features = ["tokio_io"] }
getrandom = { version = "0.2.12", features = ["js"] }
hyper-util-wasm = { version = "0.1.3", features = ["client", "client-legacy", "http1", "http2"] }
tokio = { version = "1.36.0", default-features = false }
tower-service = "0.3.2"
console_error_panic_hook = "0.1.7"
send_wrapper = "0.6.0"
event-listener = "5.2.0"
wasmtimer = "0.2.0"
async-trait = "0.1.80" async-trait = "0.1.80"
base64 = "0.22.1"
bytes = "1.6.0"
event-listener = "5.3.1"
fastwebsockets = { version = "0.7.2", features = ["unstable-split"] }
flume = "0.11.0"
futures-rustls = { version = "0.26.0", default-features = false, features = ["tls12", "ring"] }
futures-util = { version = "0.3.30", features = ["sink"] }
getrandom = { version = "0.2.15", features = ["js"] }
http = "1.1.0"
http-body-util = "0.1.2"
hyper = "1.3.1"
hyper-util-wasm = { version = "0.1.3", features = ["client-legacy", "http1", "http2"] }
js-sys = "0.3.69"
pin-project-lite = "0.2.14"
send_wrapper = "0.4.0"
thiserror = "1.0.61"
tokio = "1.38.0"
tower-service = "0.3.2"
wasm-bindgen = "0.2.92"
wasm-bindgen-futures = "0.4.42"
wasm-streams = "0.4.0"
web-sys = { version = "0.3.69", features = ["BinaryType", "Headers", "MessageEvent", "Request", "RequestInit", "Response", "ResponseInit", "WebSocket"] }
wisp-mux = { version = "4.0.1", path = "../wisp", features = ["wasm"] }
[dependencies.ring] [dependencies.ring]
# update whenever rustls updates # update whenever rustls updates
@ -45,9 +42,3 @@ features = ["wasm32_unknown_unknown_js"]
# update whenever rustls updates # update whenever rustls updates
version = "1.4.1" version = "1.4.1"
features = ["web"] features = ["web"]
[dev-dependencies]
default-env = "0.1.1"
wasm-bindgen-test = "0.3.42"
web-sys = { version = "0.3.69", features = ["FormData", "UrlSearchParams"] }
webpki-roots = "0.26.0"

174
client/src/io_stream.rs Normal file
View file

@ -0,0 +1,174 @@
use bytes::{BufMut, BytesMut};
use futures_util::{
io::WriteHalf, lock::Mutex, stream::SplitSink, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt,
};
use js_sys::{Function, Uint8Array};
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::spawn_local;
use crate::{
stream_provider::{ProviderAsyncRW, ProviderUnencryptedStream},
utils::convert_body,
EpoxyError, EpoxyHandlers,
};
#[wasm_bindgen]
pub struct EpoxyIoStream {
tx: Mutex<WriteHalf<ProviderAsyncRW>>,
onerror: Function,
}
impl EpoxyIoStream {
pub(crate) fn connect(stream: ProviderAsyncRW, handlers: EpoxyHandlers) -> Self {
let (mut rx, tx) = stream.split();
let tx = Mutex::new(tx);
let EpoxyHandlers {
onopen,
onclose,
onerror,
onmessage,
} = handlers;
let onerror_cloned = onerror.clone();
// similar to tokio::io::ReaderStream
spawn_local(async move {
let mut buf = BytesMut::with_capacity(4096);
loop {
match rx.read(buf.as_mut()).await {
Ok(cnt) => {
unsafe { buf.advance_mut(cnt) };
let _ = onmessage
.call1(&JsValue::null(), &Uint8Array::from(buf.split().as_ref()));
}
Err(err) => {
let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into());
break;
}
}
}
let _ = onclose.call0(&JsValue::null());
});
let _ = onopen.call0(&JsValue::null());
Self {
tx,
onerror: onerror_cloned,
}
}
pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> {
let ret: Result<(), EpoxyError> = async move {
let payload = convert_body(payload)
.await
.map_err(|_| EpoxyError::InvalidPayload)?
.0
.to_vec();
Ok(self.tx.lock().await.write_all(&payload).await?)
}
.await;
match ret {
Ok(ok) => Ok(ok),
Err(err) => {
let _ = self
.onerror
.call1(&JsValue::null(), &err.to_string().into());
Err(err)
}
}
}
pub async fn close(&self) -> Result<(), EpoxyError> {
match self.tx.lock().await.close().await {
Ok(ok) => Ok(ok),
Err(err) => {
let _ = self
.onerror
.call1(&JsValue::null(), &err.to_string().into());
Err(err.into())
}
}
}
}
#[wasm_bindgen]
pub struct EpoxyUdpStream {
tx: Mutex<SplitSink<ProviderUnencryptedStream, Vec<u8>>>,
onerror: Function,
}
impl EpoxyUdpStream {
pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self {
let (tx, mut rx) = stream.split();
let tx = Mutex::new(tx);
let EpoxyHandlers {
onopen,
onclose,
onerror,
onmessage,
} = handlers;
let onerror_cloned = onerror.clone();
spawn_local(async move {
while let Some(packet) = rx.next().await {
match packet {
Ok(buf) => {
let _ = onmessage.call1(&JsValue::null(), &Uint8Array::from(buf.as_ref()));
}
Err(err) => {
let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into());
break;
}
}
}
let _ = onclose.call0(&JsValue::null());
});
let _ = onopen.call0(&JsValue::null());
Self {
tx,
onerror: onerror_cloned,
}
}
pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> {
let ret: Result<(), EpoxyError> = async move {
let payload = convert_body(payload)
.await
.map_err(|_| EpoxyError::InvalidPayload)?
.0
.to_vec();
Ok(self.tx.lock().await.send(payload).await?)
}
.await;
match ret {
Ok(ok) => Ok(ok),
Err(err) => {
let _ = self
.onerror
.call1(&JsValue::null(), &err.to_string().into());
Err(err)
}
}
}
pub async fn close(&self) -> Result<(), EpoxyError> {
match self.tx.lock().await.close().await {
Ok(ok) => Ok(ok),
Err(err) => {
let _ = self
.onerror
.call1(&JsValue::null(), &err.to_string().into());
Err(err.into())
}
}
}
}

View file

@ -1,185 +1,333 @@
#![feature(let_chains, impl_trait_in_assoc_type)] #![feature(let_chains, impl_trait_in_assoc_type)]
#[macro_use] use std::{str::FromStr, sync::Arc};
mod utils;
mod tls_stream;
mod tokioio;
mod udp_stream;
mod websocket;
mod wrappers;
use tls_stream::EpxTlsStream; use async_compression::futures::bufread as async_comp;
use tokioio::TokioIo;
use udp_stream::EpxUdpStream;
use utils::object_to_trustanchor;
pub use utils::{Boolinator, ReplaceErr, UriExt};
use websocket::EpxWebSocket;
use wrappers::{IncomingBody, ServiceWrapper, TlsWispService, WebSocketWrapper};
use std::sync::Arc;
use async_compression::tokio::bufread as async_comp;
use async_io_stream::IoStream;
use bytes::Bytes; use bytes::Bytes;
use futures_util::StreamExt; use futures_util::{future::Either, TryStreamExt};
use http::{uri, HeaderName, HeaderValue, Request, Response}; use http::{
header::{InvalidHeaderName, InvalidHeaderValue},
method::InvalidMethod,
uri::{InvalidUri, InvalidUriParts},
HeaderName, HeaderValue, Method, Request, Response,
};
use hyper::{body::Incoming, Uri}; use hyper::{body::Incoming, Uri};
use hyper_util_wasm::client::legacy::Client; use hyper_util_wasm::client::legacy::Client;
use js_sys::{Array, Function, Object, Reflect, Uint8Array}; use io_stream::{EpoxyIoStream, EpoxyUdpStream};
use rustls::pki_types::TrustAnchor; use js_sys::{Array, Function, Object, Reflect};
use tokio::sync::RwLock; use stream_provider::{StreamProvider, StreamProviderService};
use tokio_rustls::{client::TlsStream, rustls, rustls::RootCertStore, TlsConnector}; use thiserror::Error;
use tokio_util::{ use utils::{
either::Either, convert_body, entries_of_object, is_null_body, is_redirect, object_get, object_set,
io::{ReaderStream, StreamReader}, IncomingBody, UriExt, WasmExecutor,
}; };
use wasm_bindgen::{intern, prelude::*}; use wasm_bindgen::prelude::*;
use wisp_mux::{ClientMux, MuxStreamIo, StreamType}; use wasm_streams::ReadableStream;
use web_sys::ResponseInit;
use websocket::EpoxyWebSocket;
use wisp_mux::StreamType;
mod io_stream;
mod stream_provider;
mod tokioio;
mod utils;
mod websocket;
mod ws_wrapper;
type HttpBody = http_body_util::Full<Bytes>; type HttpBody = http_body_util::Full<Bytes>;
#[derive(Debug, Error)]
pub enum EpoxyError {
#[error(transparent)]
InvalidDnsName(#[from] futures_rustls::rustls::pki_types::InvalidDnsNameError),
#[error(transparent)]
Wisp(#[from] wisp_mux::WispError),
#[error(transparent)]
Io(#[from] std::io::Error),
#[error(transparent)]
Http(#[from] http::Error),
#[error(transparent)]
HyperClient(#[from] hyper_util_wasm::client::legacy::Error),
#[error(transparent)]
Hyper(#[from] hyper::Error),
#[error(transparent)]
ToStr(#[from] http::header::ToStrError),
#[error(transparent)]
GetRandom(#[from] getrandom::Error),
#[error(transparent)]
FastWebSockets(#[from] fastwebsockets::WebSocketError),
#[error("Invalid URL scheme")]
InvalidUrlScheme,
#[error("No URL host found")]
NoUrlHost,
#[error("No URL port found")]
NoUrlPort,
#[error("Invalid request body")]
InvalidRequestBody,
#[error("Invalid request")]
InvalidRequest,
#[error("Invalid websocket response status code")]
WsInvalidStatusCode,
#[error("Invalid websocket upgrade header")]
WsInvalidUpgradeHeader,
#[error("Invalid websocket connection header")]
WsInvalidConnectionHeader,
#[error("Invalid websocket payload")]
WsInvalidPayload,
#[error("Invalid payload")]
InvalidPayload,
#[error("Invalid certificate store")]
InvalidCertStore,
#[error("WebSocket failed to connect")]
WebSocketConnectFailed,
#[error("Failed to construct response headers object")]
ResponseHeadersFromEntriesFailed,
#[error("Failed to construct response object")]
ResponseNewFailed,
#[error("Failed to construct define_property object")]
DefinePropertyObjFailed,
#[error("Failed to set raw header item")]
RawHeaderSetFailed,
}
impl From<EpoxyError> for JsValue {
fn from(value: EpoxyError) -> Self {
JsError::from(value).into()
}
}
impl From<InvalidUri> for EpoxyError {
fn from(value: InvalidUri) -> Self {
http::Error::from(value).into()
}
}
impl From<InvalidUriParts> for EpoxyError {
fn from(value: InvalidUriParts) -> Self {
http::Error::from(value).into()
}
}
impl From<InvalidHeaderName> for EpoxyError {
fn from(value: InvalidHeaderName) -> Self {
http::Error::from(value).into()
}
}
impl From<InvalidHeaderValue> for EpoxyError {
fn from(value: InvalidHeaderValue) -> Self {
http::Error::from(value).into()
}
}
impl From<InvalidMethod> for EpoxyError {
fn from(value: InvalidMethod) -> Self {
http::Error::from(value).into()
}
}
#[derive(Debug)] #[derive(Debug)]
enum EpxResponse { enum EpoxyResponse {
Success(Response<Incoming>), Success(Response<Incoming>),
Redirect((Response<Incoming>, http::Request<HttpBody>)), Redirect((Response<Incoming>, http::Request<HttpBody>)),
} }
enum EpxCompression { enum EpoxyCompression {
Brotli, Brotli,
Gzip, Gzip,
} }
type EpxIoUnencryptedStream = IoStream<MuxStreamIo, Vec<u8>>; #[wasm_bindgen]
type EpxIoTlsStream = TlsStream<EpxIoUnencryptedStream>; pub struct EpoxyClientOptions {
type EpxIoStream = Either<EpxIoTlsStream, EpxIoUnencryptedStream>; pub wisp_v2: bool,
pub udp_extension_required: bool,
#[wasm_bindgen(start)] #[wasm_bindgen(getter_with_clone)]
fn init() { pub websocket_protocols: Vec<String>,
console_error_panic_hook::set_once(); pub redirect_limit: usize,
// utils.rs #[wasm_bindgen(getter_with_clone)]
intern("value"); pub user_agent: String,
intern("writable");
intern("POST");
// main.rs
intern("method");
intern("redirect");
intern("body");
intern("headers");
intern("url");
intern("redirected");
intern("rawHeaders");
intern("Content-Type");
} }
#[wasm_bindgen(inspectable)] #[wasm_bindgen]
impl EpoxyClientOptions {
#[wasm_bindgen(constructor)]
pub fn new_default() -> Self {
Self::default()
}
}
impl Default for EpoxyClientOptions {
fn default() -> Self {
Self {
wisp_v2: true,
udp_extension_required: true,
websocket_protocols: Vec::new(),
redirect_limit: 10,
user_agent: "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36".to_string(),
}
}
}
#[wasm_bindgen(getter_with_clone)]
pub struct EpoxyHandlers {
pub onopen: Function,
pub onclose: Function,
pub onerror: Function,
pub onmessage: Function,
}
#[wasm_bindgen]
impl EpoxyHandlers {
#[wasm_bindgen(constructor)]
pub fn new(
onopen: Function,
onclose: Function,
onerror: Function,
onmessage: Function,
) -> Self {
Self {
onopen,
onclose,
onerror,
onmessage,
}
}
}
#[wasm_bindgen]
pub struct EpoxyClient { pub struct EpoxyClient {
rustls_config: Arc<rustls::ClientConfig>, stream_provider: Arc<StreamProvider>,
mux: Arc<RwLock<ClientMux>>, client: Client<StreamProviderService, HttpBody>,
hyper_client: Client<TlsWispService, HttpBody>,
#[wasm_bindgen(getter_with_clone)] redirect_limit: usize,
pub useragent: String, user_agent: String,
#[wasm_bindgen(js_name = "redirectLimit")]
pub redirect_limit: usize,
} }
#[wasm_bindgen] #[wasm_bindgen]
impl EpoxyClient { impl EpoxyClient {
#[wasm_bindgen(constructor)] #[wasm_bindgen(constructor)]
pub async fn new( pub fn new(
ws_url: String, wisp_url: String,
useragent: String,
redirect_limit: usize,
certs: Array, certs: Array,
) -> Result<EpoxyClient, JsError> { options: EpoxyClientOptions,
let ws_uri = ws_url ) -> Result<EpoxyClient, EpoxyError> {
.parse::<uri::Uri>() let wisp_url: Uri = wisp_url.try_into()?;
.replace_err("Failed to parse websocket url")?; if wisp_url.scheme_str() != Some("wss") && wisp_url.scheme_str() != Some("ws") {
return Err(EpoxyError::InvalidUrlScheme);
let ws_uri_scheme = ws_uri
.scheme_str()
.replace_err("Websocket URL must have a scheme")?;
if ws_uri_scheme != "ws" && ws_uri_scheme != "wss" {
return Err(JsError::new("Scheme must be either `ws` or `wss`"));
} }
let (mux, fut) = utils::make_mux(&ws_url).await?; let stream_provider = Arc::new(StreamProvider::new(wisp_url.to_string(), certs, &options)?);
let mux = Arc::new(RwLock::new(mux));
utils::spawn_mux_fut(mux.clone(), fut, ws_url.clone());
let mut certstore = RootCertStore::empty(); let service = StreamProviderService(stream_provider.clone());
let certs: Result<Vec<TrustAnchor>, JsValue> = let client = Client::builder(WasmExecutor)
certs.iter().map(object_to_trustanchor).collect();
certstore.extend(
certs
.replace_err("Failed to get certificates from cert store")?
.into_iter(),
);
let rustls_config = Arc::new(
rustls::ClientConfig::builder()
.with_root_certificates(certstore)
.with_no_client_auth(),
);
Ok(EpoxyClient {
mux: mux.clone(),
hyper_client: Client::builder(utils::WasmExecutor {})
.http09_responses(true) .http09_responses(true)
.http1_title_case_headers(true) .http1_title_case_headers(true)
.http1_preserve_header_case(true) .http1_preserve_header_case(true)
.build(TlsWispService { .build(service);
rustls_config: rustls_config.clone(),
service: ServiceWrapper(mux, ws_url), Ok(Self {
}), stream_provider,
rustls_config, client,
useragent, redirect_limit: options.redirect_limit,
redirect_limit, user_agent: options.user_agent,
}) })
} }
async fn get_tls_io(&self, url_host: &str, url_port: u16) -> Result<EpxIoTlsStream, JsError> { pub async fn connect_websocket(
let channel = self &self,
.mux handlers: EpoxyHandlers,
.write() url: String,
protocols: Vec<String>,
) -> Result<EpoxyWebSocket, EpoxyError> {
EpoxyWebSocket::connect(self, handlers, url, protocols).await
}
pub async fn connect_tcp(
&self,
handlers: EpoxyHandlers,
url: String,
) -> Result<EpoxyIoStream, EpoxyError> {
let url: Uri = url.try_into()?;
let host = url.host().ok_or(EpoxyError::NoUrlHost)?;
let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?;
match self
.stream_provider
.get_asyncread(StreamType::Tcp, host.to_string(), port)
.await .await
.client_new_stream(StreamType::Tcp, url_host.to_string(), url_port) {
Ok(stream) => Ok(EpoxyIoStream::connect(Either::Right(stream), handlers)),
Err(err) => {
let _ = handlers
.onerror
.call1(&JsValue::null(), &err.to_string().into());
Err(err)
}
}
}
pub async fn connect_tls(
&self,
handlers: EpoxyHandlers,
url: String,
) -> Result<EpoxyIoStream, EpoxyError> {
let url: Uri = url.try_into()?;
let host = url.host().ok_or(EpoxyError::NoUrlHost)?;
let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?;
match self
.stream_provider
.get_tls_stream(host.to_string(), port)
.await .await
.replace_err("Failed to create multiplexor channel")? {
.into_io() Ok(stream) => Ok(EpoxyIoStream::connect(Either::Left(stream), handlers)),
.into_asyncrw(); Err(err) => {
let connector = TlsConnector::from(self.rustls_config.clone()); let _ = handlers
let io = connector .onerror
.connect( .call1(&JsValue::null(), &err.to_string().into());
url_host Err(err)
.to_string() }
.try_into() }
.replace_err("Failed to parse URL (rustls)")?, }
channel,
) pub async fn connect_udp(
&self,
handlers: EpoxyHandlers,
url: String,
) -> Result<EpoxyUdpStream, EpoxyError> {
let url: Uri = url.try_into()?;
let host = url.host().ok_or(EpoxyError::NoUrlHost)?;
let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?;
match self
.stream_provider
.get_stream(StreamType::Udp, host.to_string(), port)
.await .await
.replace_err("Failed to perform TLS handshake")?; {
Ok(io) Ok(stream) => Ok(EpoxyUdpStream::connect(stream, handlers)),
Err(err) => {
let _ = handlers
.onerror
.call1(&JsValue::null(), &err.to_string().into());
Err(err)
}
}
} }
async fn send_req_inner( async fn send_req_inner(
&self, &self,
req: http::Request<HttpBody>, req: http::Request<HttpBody>,
should_redirect: bool, should_redirect: bool,
) -> Result<EpxResponse, JsError> { ) -> Result<EpoxyResponse, EpoxyError> {
let new_req = if should_redirect { let new_req = if should_redirect {
Some(req.clone()) Some(req.clone())
} else { } else {
None None
}; };
let res = self let res = self.client.request(req).await;
.hyper_client
.request(req)
.await
.replace_err("Failed to send request");
match res { match res {
Ok(res) => { Ok(res) => {
if utils::is_redirect(res.status().as_u16()) if is_redirect(res.status().as_u16())
&& let Some(mut new_req) = new_req && let Some(mut new_req) = new_req
&& let Some(location) = res.headers().get("Location") && let Some(location) = res.headers().get("Location")
&& let Ok(redirect_url) = new_req.uri().get_redirect(location) && let Ok(redirect_url) = new_req.uri().get_redirect(location)
@ -190,12 +338,12 @@ impl EpoxyClient {
"Host", "Host",
HeaderValue::from_str(redirect_url_authority.as_str())?, HeaderValue::from_str(redirect_url_authority.as_str())?,
); );
Ok(EpxResponse::Redirect((res, new_req))) Ok(EpoxyResponse::Redirect((res, new_req)))
} else { } else {
Ok(EpxResponse::Success(res)) Ok(EpoxyResponse::Success(res))
} }
} }
Err(err) => Err(err), Err(err) => Err(err.into()),
} }
} }
@ -203,14 +351,14 @@ impl EpoxyClient {
&self, &self,
req: http::Request<HttpBody>, req: http::Request<HttpBody>,
should_redirect: bool, should_redirect: bool,
) -> Result<(hyper::Response<Incoming>, Uri, bool), JsError> { ) -> Result<(hyper::Response<Incoming>, Uri, bool), EpoxyError> {
let mut redirected = false; let mut redirected = false;
let mut current_url = req.uri().clone(); let mut current_url = req.uri().clone();
let mut current_resp: EpxResponse = self.send_req_inner(req, should_redirect).await?; let mut current_resp: EpoxyResponse = self.send_req_inner(req, should_redirect).await?;
for _ in 0..self.redirect_limit { for _ in 0..self.redirect_limit {
match current_resp { match current_resp {
EpxResponse::Success(_) => break, EpoxyResponse::Success(_) => break,
EpxResponse::Redirect((_, req)) => { EpoxyResponse::Redirect((_, req)) => {
redirected = true; redirected = true;
current_url = req.uri().clone(); current_url = req.uri().clone();
current_resp = self.send_req_inner(req, should_redirect).await? current_resp = self.send_req_inner(req, should_redirect).await?
@ -219,109 +367,75 @@ impl EpoxyClient {
} }
match current_resp { match current_resp {
EpxResponse::Success(resp) => Ok((resp, current_url, redirected)), EpoxyResponse::Success(resp) => Ok((resp, current_url, redirected)),
EpxResponse::Redirect((resp, _)) => Ok((resp, current_url, redirected)), EpoxyResponse::Redirect((resp, _)) => Ok((resp, current_url, redirected)),
} }
} }
// shut up pub async fn fetch(
#[allow(clippy::too_many_arguments)]
pub async fn connect_ws(
&self, &self,
onopen: Function,
onclose: Function,
onerror: Function,
onmessage: Function,
url: String, url: String,
protocols: Vec<String>, options: Object,
origin: String, ) -> Result<web_sys::Response, EpoxyError> {
) -> Result<EpxWebSocket, JsError> { let url: Uri = url.try_into()?;
EpxWebSocket::connect( // only valid `Scheme`s are HTTP and HTTPS, which are the ones we support
self, onopen, onclose, onerror, onmessage, url, protocols, origin, url.scheme().ok_or(EpoxyError::InvalidUrlScheme)?;
)
.await
}
pub async fn connect_tls( let host = url.host().ok_or(EpoxyError::NoUrlHost)?;
&self,
onopen: Function,
onclose: Function,
onerror: Function,
onmessage: Function,
url: String,
) -> Result<EpxTlsStream, JsError> {
EpxTlsStream::connect(self, onopen, onclose, onerror, onmessage, url).await
}
pub async fn connect_udp( let request_method = object_get(&options, "method")
&self, .and_then(|x| x.as_string())
onopen: Function, .unwrap_or_else(|| "GET".to_string());
onclose: Function, let request_method: Method = Method::from_str(&request_method)?;
onerror: Function,
onmessage: Function,
url: String,
) -> Result<EpxUdpStream, JsError> {
EpxUdpStream::connect(self, onopen, onclose, onerror, onmessage, url).await
}
pub async fn fetch(&self, url: String, options: Object) -> Result<web_sys::Response, JsError> { let request_redirect = object_get(&options, "redirect")
let uri = url.parse::<uri::Uri>().replace_err("Failed to parse URL")?; .map(|x| {
let uri_scheme = uri.scheme().replace_err("URL must have a scheme")?; !matches!(
if *uri_scheme != uri::Scheme::HTTP && *uri_scheme != uri::Scheme::HTTPS { x.as_string().unwrap_or_default().as_str(),
return Err(jerr!("Scheme must be either `http` or `https`"));
}
let uri_host = uri.host().replace_err("URL must have a host")?;
let req_method_string: String = match Reflect::get(&options, &jval!("method")) {
Ok(val) => val.as_string().unwrap_or("GET".to_string()),
Err(_) => "GET".to_string(),
};
let req_method: http::Method = http::Method::try_from(req_method_string.as_str())
.replace_err("Invalid http method")?;
let req_should_redirect = match Reflect::get(&options, &jval!("redirect")) {
Ok(val) => !matches!(
val.as_string().unwrap_or_default().as_str(),
"error" | "manual" "error" | "manual"
), )
Err(_) => true, })
}; .unwrap_or(true);
let mut body_content_type: Option<String> = None; let mut body_content_type: Option<String> = None;
let body_jsvalue: Option<JsValue> = Reflect::get(&options, &jval!("body")).ok(); let body = match object_get(&options, "body") {
let body_bytes: Bytes = match body_jsvalue {
Some(buf) => { Some(buf) => {
let (body, req) = utils::jval_to_u8_array_req(buf) let (body, req) = convert_body(buf)
.await .await
.replace_err("Invalid body")?; .map_err(|_| EpoxyError::InvalidRequestBody)?;
body_content_type = req.headers().get("Content-Type").ok().flatten(); body_content_type = req.headers().get("Content-Type").ok().flatten();
Bytes::from(body.to_vec()) Bytes::from(body.to_vec())
} }
None => Bytes::new(), None => Bytes::new(),
}; };
let headers = Reflect::get(&options, &jval!("headers")) let headers = object_get(&options, "headers").and_then(|val| {
.map(|val| {
if web_sys::Headers::instanceof(&val) { if web_sys::Headers::instanceof(&val) {
Some(utils::entries_of_object(&Object::from_entries(&val).ok()?)) Some(entries_of_object(&Object::from_entries(&val).ok()?))
} else if val.is_truthy() { } else if val.is_truthy() {
Some(utils::entries_of_object(&Object::from(val))) Some(entries_of_object(&Object::from(val)))
} else { } else {
None None
} }
}) });
.unwrap_or(None);
let mut builder = Request::builder().uri(uri.clone()).method(req_method); let mut request_builder = Request::builder().uri(url.clone()).method(request_method);
let headers_map = builder.headers_mut().replace_err("Failed to get headers")?; // Generic InvalidRequest because this only returns None if the builder has some error
headers_map.insert("Accept-Encoding", HeaderValue::from_static("gzip, br")); // which we don't know
let headers_map = request_builder
.headers_mut()
.ok_or(EpoxyError::InvalidRequest)?;
headers_map.insert("Accept-Encoding", HeaderValue::from_static("identity"));
headers_map.insert("Connection", HeaderValue::from_static("keep-alive")); headers_map.insert("Connection", HeaderValue::from_static("keep-alive"));
headers_map.insert("User-Agent", HeaderValue::from_str(&self.useragent)?); headers_map.insert("User-Agent", HeaderValue::from_str(&self.user_agent)?);
headers_map.insert("Host", HeaderValue::from_str(uri_host)?); headers_map.insert("Host", HeaderValue::from_str(host)?);
if body_bytes.is_empty() {
if body.is_empty() {
headers_map.insert("Content-Length", HeaderValue::from_static("0")); headers_map.insert("Content-Length", HeaderValue::from_static("0"));
} }
if let Some(content_type) = body_content_type { if let Some(content_type) = body_content_type {
headers_map.insert("Content-Type", HeaderValue::from_str(&content_type)?); headers_map.insert("Content-Type", HeaderValue::from_str(&content_type)?);
} }
@ -329,122 +443,107 @@ impl EpoxyClient {
if let Some(headers) = headers { if let Some(headers) = headers {
for hdr in headers { for hdr in headers {
headers_map.insert( headers_map.insert(
HeaderName::from_bytes(hdr[0].as_bytes()) HeaderName::from_str(&hdr[0])?,
.replace_err("Failed to get hdr name")?, HeaderValue::from_str(&hdr[1])?,
HeaderValue::from_bytes(hdr[1].as_bytes())
.replace_err("Failed to get hdr value")?,
); );
} }
} }
let request = builder let (response, response_uri, redirected) = self
.body(HttpBody::new(body_bytes)) .send_req(request_builder.body(HttpBody::new(body))?, request_redirect)
.replace_err("Failed to make request")?; .await?;
let (resp, resp_uri, req_redirected) = self.send_req(request, req_should_redirect).await?; let response_headers: Array = response
let resp_headers_raw = resp.headers().clone();
let resp_headers_jsarray = resp
.headers() .headers()
.iter() .iter()
.filter_map(|val| { .filter_map(|val| {
Some(Array::of2( Some(Array::of2(
&jval!(val.0.as_str()), &val.0.as_str().into(),
&jval!(val.1.to_str().ok()?), &val.1.to_str().ok()?.into(),
)) ))
}) })
.collect::<Array>(); .collect();
let response_headers = Object::from_entries(&response_headers)
.map_err(|_| EpoxyError::ResponseHeadersFromEntriesFailed)?;
let resp_headers = Object::from_entries(&resp_headers_jsarray) let response_headers_raw = response.headers().clone();
.replace_err("Failed to create response headers object")?;
let mut respinit = web_sys::ResponseInit::new(); let mut response_builder = ResponseInit::new();
respinit response_builder
.headers(&resp_headers) .headers(&response_headers)
.status(resp.status().as_u16()) .status(response.status().as_u16())
.status_text(resp.status().canonical_reason().unwrap_or_default()); .status_text(response.status().canonical_reason().unwrap_or_default());
let stream = if !utils::is_null_body(resp.status().as_u16()) { let response_stream = if !is_null_body(response.status().as_u16()) {
let compression = match resp let compression = match response
.headers() .headers()
.get("Content-Encoding") .get("Content-Encoding")
.and_then(|val| val.to_str().ok()) .and_then(|val| val.to_str().ok())
.unwrap_or_default() .unwrap_or_default()
{ {
"gzip" => Some(EpxCompression::Gzip), "gzip" => Some(EpoxyCompression::Gzip),
"br" => Some(EpxCompression::Brotli), "br" => Some(EpoxyCompression::Brotli),
_ => None, _ => None,
}; };
let incoming_body = IncomingBody::new(resp.into_body()); let response_body = IncomingBody::new(response.into_body()).into_async_read();
let decompressed_body = match compression { let decompressed_body = match compression {
Some(alg) => match alg { Some(alg) => match alg {
EpxCompression::Gzip => Either::Left(Either::Left(ReaderStream::new( EpoxyCompression::Gzip => {
async_comp::GzipDecoder::new(StreamReader::new(incoming_body)), Either::Left(Either::Left(async_comp::GzipDecoder::new(response_body)))
))), }
EpxCompression::Brotli => Either::Left(Either::Right(ReaderStream::new( EpoxyCompression::Brotli => {
async_comp::BrotliDecoder::new(StreamReader::new(incoming_body)), Either::Left(Either::Right(async_comp::BrotliDecoder::new(response_body)))
))), }
}, },
None => Either::Right(incoming_body), None => Either::Right(response_body),
}; };
Some( Some(ReadableStream::from_async_read(decompressed_body, 1024).into_raw())
wasm_streams::ReadableStream::from_stream(decompressed_body.map(|x| {
Ok(Uint8Array::from(
x.replace_err_jv("Failed to get frame from response")?
.as_ref(),
)
.into())
}))
.into_raw(),
)
} else { } else {
None None
}; };
let resp = let resp = web_sys::Response::new_with_opt_readable_stream_and_init(
web_sys::Response::new_with_opt_readable_stream_and_init(stream.as_ref(), &respinit) response_stream.as_ref(),
.replace_err("Failed to make response")?; &response_builder,
)
.map_err(|_| EpoxyError::ResponseNewFailed)?;
Object::define_property( Object::define_property(
&resp, &resp,
&jval!("url"), &"url".into(),
&utils::define_property_obj(jval!(resp_uri.to_string()), false) &utils::define_property_obj(response_uri.to_string().into(), false)
.replace_err("Failed to make define_property object for url")?, .map_err(|_| EpoxyError::DefinePropertyObjFailed)?,
); );
Object::define_property( Object::define_property(
&resp, &resp,
&jval!("redirected"), &"redirected".into(),
&utils::define_property_obj(jval!(req_redirected), false) &utils::define_property_obj(redirected.into(), false)
.replace_err("Failed to make define_property object for redirected")?, .map_err(|_| EpoxyError::DefinePropertyObjFailed)?,
); );
let raw_headers = Object::new(); let raw_headers = Object::new();
for (k, v) in resp_headers_raw.iter() { for (k, v) in response_headers_raw.iter() {
let k = jval!(k.to_string()); let k: JsValue = k.to_string().into();
let v = jval!(v.to_str()?.to_string()); let v: JsValue = v.to_str()?.to_string().into();
if let Ok(jv) = Reflect::get(&raw_headers, &k) { if let Ok(jv) = Reflect::get(&raw_headers, &k) {
if jv.is_array() { if jv.is_array() {
let arr = Array::from(&jv); let arr = Array::from(&jv);
arr.push(&v); arr.push(&v);
Reflect::set(&raw_headers, &k, &arr).flatten("Failed to set rawHeader")?; object_set(&raw_headers, &k, &arr)?;
} else if jv.is_truthy() { } else if jv.is_truthy() {
Reflect::set(&raw_headers, &k, &Array::of2(&jv, &v)) object_set(&raw_headers, &k, &Array::of2(&jv, &v))?;
.flatten("Failed to set rawHeader")?;
} else { } else {
Reflect::set(&raw_headers, &k, &v).flatten("Failed to set rawHeader")?; object_set(&raw_headers, &k, &v)?;
} }
} }
} }
Object::define_property( Object::define_property(
&resp, &resp,
&jval!("rawHeaders"), &"rawHeaders".into(),
&utils::define_property_obj(jval!(&raw_headers), false) &utils::define_property_obj(raw_headers.into(), false)
.replace_err("Failed to make define_property object for rawHeaders")?, .map_err(|_| EpoxyError::DefinePropertyObjFailed)?,
); );
Ok(resp) Ok(resp)

View file

@ -0,0 +1,247 @@
use std::{pin::Pin, sync::Arc, task::Poll};
use futures_rustls::{
rustls::{ClientConfig, RootCertStore},
TlsConnector, TlsStream,
};
use futures_util::{future::Either, lock::Mutex, AsyncRead, AsyncWrite, Future};
use hyper_util_wasm::client::legacy::connect::{Connected, Connection};
use js_sys::{Array, Reflect, Uint8Array};
use pin_project_lite::pin_project;
use rustls_pki_types::{Der, TrustAnchor};
use tower_service::Service;
use wasm_bindgen::{JsCast, JsValue};
use wasm_bindgen_futures::spawn_local;
use wisp_mux::{
extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder},
ClientMux, IoStream, MuxStreamIo, StreamType, WispError,
};
use crate::{ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError};
fn object_to_trustanchor(obj: JsValue) -> Result<TrustAnchor<'static>, JsValue> {
let subject: Uint8Array = Reflect::get(&obj, &"subject".into())?.dyn_into()?;
let pub_key_info: Uint8Array =
Reflect::get(&obj, &"subject_public_key_info".into())?.dyn_into()?;
let name_constraints: Option<Uint8Array> = Reflect::get(&obj, &"name_constraints".into())
.and_then(|x| x.dyn_into())
.ok();
Ok(TrustAnchor {
subject: Der::from(subject.to_vec()),
subject_public_key_info: Der::from(pub_key_info.to_vec()),
name_constraints: name_constraints.map(|x| Der::from(x.to_vec())),
})
}
pub struct StreamProvider {
wisp_url: String,
wisp_v2: bool,
udp_extension: bool,
websocket_protocols: Vec<String>,
client_config: Arc<ClientConfig>,
current_client: Arc<Mutex<Option<ClientMux>>>,
}
pub type ProviderUnencryptedStream = MuxStreamIo;
pub type ProviderUnencryptedAsyncRW = IoStream<ProviderUnencryptedStream, Vec<u8>>;
pub type ProviderTlsAsyncRW = TlsStream<ProviderUnencryptedAsyncRW>;
pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>;
impl StreamProvider {
pub fn new(
wisp_url: String,
certs: Array,
options: &EpoxyClientOptions,
) -> Result<Self, EpoxyError> {
let certs: Result<Vec<TrustAnchor>, JsValue> =
certs.iter().map(object_to_trustanchor).collect();
let certstore = RootCertStore::from_iter(certs.map_err(|_| EpoxyError::InvalidCertStore)?);
let client_config = Arc::new(
ClientConfig::builder()
.with_root_certificates(certstore)
.with_no_client_auth(),
);
Ok(Self {
wisp_url,
current_client: Arc::new(Mutex::new(None)),
wisp_v2: options.wisp_v2,
udp_extension: options.udp_extension_required,
websocket_protocols: options.websocket_protocols.clone(),
client_config,
})
}
async fn create_client(&self) -> Result<(), EpoxyError> {
let extensions_vec: Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>> =
vec![Box::new(UdpProtocolExtensionBuilder())];
let extensions = if self.wisp_v2 {
Some(extensions_vec.as_slice())
} else {
None
};
let (write, read) = WebSocketWrapper::connect(&self.wisp_url, &self.websocket_protocols)?;
if !write.wait_for_open().await {
return Err(EpoxyError::WebSocketConnectFailed);
}
let client = ClientMux::create(read, write, extensions).await?;
let (mux, fut) = if self.udp_extension {
client.with_udp_extension_required().await?
} else {
client.with_no_required_extensions()
};
self.current_client.lock().await.replace(mux);
let current_client = self.current_client.clone();
spawn_local(async move {
fut.await;
current_client.lock().await.take();
});
Ok(())
}
pub async fn get_stream(
&self,
stream_type: StreamType,
host: String,
port: u16,
) -> Result<ProviderUnencryptedStream, EpoxyError> {
Box::pin(async {
if let Some(mux) = self.current_client.lock().await.as_ref() {
Ok(mux
.client_new_stream(stream_type, host, port)
.await?
.into_io())
} else {
self.create_client().await?;
self.get_stream(stream_type, host, port).await
}
})
.await
}
pub async fn get_asyncread(
&self,
stream_type: StreamType,
host: String,
port: u16,
) -> Result<ProviderUnencryptedAsyncRW, EpoxyError> {
Ok(self
.get_stream(stream_type, host, port)
.await?
.into_asyncrw())
}
pub async fn get_tls_stream(
&self,
host: String,
port: u16,
) -> Result<ProviderTlsAsyncRW, EpoxyError> {
let stream = self
.get_asyncread(StreamType::Tcp, host.clone(), port)
.await?;
let connector = TlsConnector::from(self.client_config.clone());
Ok(connector.connect(host.try_into()?, stream).await?.into())
}
}
pin_project! {
pub struct HyperIo {
#[pin]
inner: ProviderAsyncRW,
}
}
impl hyper::rt::Read for HyperIo {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
mut buf: hyper::rt::ReadBufCursor<'_>,
) -> Poll<Result<(), std::io::Error>> {
let buf_slice: &mut [u8] = unsafe { std::mem::transmute(buf.as_mut()) };
match self.project().inner.poll_read(cx, buf_slice) {
Poll::Ready(bytes_read) => {
let bytes_read = bytes_read?;
unsafe {
buf.advance(bytes_read);
}
Poll::Ready(Ok(()))
}
Poll::Pending => Poll::Pending,
}
}
}
impl hyper::rt::Write for HyperIo {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
self.project().inner.poll_write(cx, buf)
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
self.project().inner.poll_flush(cx)
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
self.project().inner.poll_close(cx)
}
fn poll_write_vectored(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
self.project().inner.poll_write_vectored(cx, bufs)
}
}
impl Connection for HyperIo {
fn connected(&self) -> Connected {
Connected::new()
}
}
#[derive(Clone)]
pub struct StreamProviderService(pub Arc<StreamProvider>);
impl Service<hyper::Uri> for StreamProviderService {
type Response = HyperIo;
type Error = EpoxyError;
type Future = Pin<Box<impl Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(
&mut self,
_: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: hyper::Uri) -> Self::Future {
let provider = self.0.clone();
Box::pin(async move {
let scheme = req.scheme_str().ok_or(EpoxyError::InvalidUrlScheme)?;
let host = req.host().ok_or(WispError::UriHasNoHost)?.to_string();
let port = req.port_u16().ok_or(WispError::UriHasNoPort)?;
Ok(HyperIo {
inner: match scheme {
"https" => Either::Left(provider.get_tls_stream(host, port).await?),
"http" => {
Either::Right(provider.get_asyncread(StreamType::Tcp, host, port).await?)
}
_ => return Err(EpoxyError::InvalidUrlScheme),
},
})
})
}
}

View file

@ -1,92 +0,0 @@
use crate::*;
use tokio::io::{split, AsyncWriteExt, WriteHalf};
#[wasm_bindgen(inspectable)]
pub struct EpxTlsStream {
tx: WriteHalf<EpxIoTlsStream>,
onerror: Function,
#[wasm_bindgen(readonly, getter_with_clone)]
pub url: String,
}
#[wasm_bindgen]
impl EpxTlsStream {
#[wasm_bindgen(constructor)]
pub fn new() -> Result<EpxTlsStream, JsError> {
Err(jerr!("Use EpoxyClient.connect_tls() instead."))
}
pub async fn connect(
tcp: &EpoxyClient,
onopen: Function,
onclose: Function,
onerror: Function,
onmessage: Function,
url: String,
) -> Result<EpxTlsStream, JsError> {
let onerr = onerror.clone();
let ret: Result<EpxTlsStream, JsError> = async move {
let url = Uri::try_from(url).replace_err("Failed to parse URL")?;
let url_host = url.host().replace_err("URL must have a host")?;
let url_port = url.port().replace_err("URL must have a port")?.into();
let io = tcp.get_tls_io(url_host, url_port).await?;
let (rx, tx) = split(io);
let mut rx = ReaderStream::new(rx);
wasm_bindgen_futures::spawn_local(async move {
while let Some(Ok(data)) = rx.next().await {
let _ = onmessage.call1(
&JsValue::null(),
&jval!(Uint8Array::from(data.to_vec().as_slice())),
);
}
let _ = onclose.call0(&JsValue::null());
});
onopen
.call0(&Object::default())
.replace_err("Failed to call onopen")?;
Ok(Self {
tx,
onerror,
url: url.to_string(),
})
}
.await;
if let Err(ret) = ret {
let _ = onerr.call1(&JsValue::null(), &jval!(ret.clone()));
Err(ret)
} else {
ret
}
}
#[wasm_bindgen]
pub async fn send(&mut self, payload: JsValue) -> Result<(), JsError> {
let onerr = self.onerror.clone();
let ret = self
.tx
.write_all(
&utils::jval_to_u8_array(payload)
.await
.replace_err("Invalid payload")?
.to_vec(),
)
.await;
if let Err(ret) = ret {
let _ = onerr.call1(&JsValue::null(), &jval!(format!("{}", ret)));
Err(ret.into())
} else {
Ok(ret?)
}
}
#[wasm_bindgen]
pub async fn close(&mut self) -> Result<(), JsError> {
self.tx.shutdown().await?;
Ok(())
}
}

View file

@ -167,9 +167,3 @@ where
hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
} }
} }
impl<T> hyper_util_wasm::client::legacy::connect::Connection for TokioIo<T> {
fn connected(&self) -> hyper_util_wasm::client::legacy::connect::Connected {
hyper_util_wasm::client::legacy::connect::Connected::new()
}
}

View file

@ -1,98 +0,0 @@
use crate::*;
use futures_util::{stream::SplitSink, SinkExt};
#[wasm_bindgen(inspectable)]
pub struct EpxUdpStream {
tx: SplitSink<MuxStreamIo, Vec<u8>>,
onerror: Function,
#[wasm_bindgen(readonly, getter_with_clone)]
pub url: String,
}
#[wasm_bindgen]
impl EpxUdpStream {
#[wasm_bindgen(constructor)]
pub fn new() -> Result<EpxUdpStream, JsError> {
Err(jerr!("Use EpoxyClient.connect_udp() instead."))
}
pub async fn connect(
tcp: &EpoxyClient,
onopen: Function,
onclose: Function,
onerror: Function,
onmessage: Function,
url: String,
) -> Result<EpxUdpStream, JsError> {
let onerr = onerror.clone();
let ret: Result<EpxUdpStream, JsError> = async move {
let url = Uri::try_from(url).replace_err("Failed to parse URL")?;
let url_host = url.host().replace_err("URL must have a host")?;
let url_port = url.port().replace_err("URL must have a port")?.into();
let io = tcp
.mux
.write()
.await
.client_new_stream(StreamType::Udp, url_host.to_string(), url_port)
.await
.replace_err("Failed to open multiplexor channel")?
.into_io();
let (tx, mut rx) = io.split();
wasm_bindgen_futures::spawn_local(async move {
while let Some(Ok(data)) = rx.next().await {
let _ = onmessage.call1(
&JsValue::null(),
&jval!(Uint8Array::from(data.to_vec().as_slice())),
);
}
let _ = onclose.call0(&JsValue::null());
});
onopen
.call0(&Object::default())
.replace_err("Failed to call onopen")?;
Ok(Self {
tx,
onerror,
url: url.to_string(),
})
}
.await;
if let Err(ret) = ret {
let _ = onerr.call1(&JsValue::null(), &jval!(ret.clone()));
Err(ret)
} else {
ret
}
}
#[wasm_bindgen]
pub async fn send(&mut self, payload: JsValue) -> Result<(), JsError> {
let onerr = self.onerror.clone();
let ret = self
.tx
.send(
utils::jval_to_u8_array(payload)
.await
.replace_err("Invalid payload")?
.to_vec(),
)
.await;
if let Err(ret) = ret {
let _ = onerr.call1(&JsValue::null(), &jval!(format!("{}", ret)));
Err(ret.into())
} else {
Ok(ret?)
}
}
#[wasm_bindgen]
pub async fn close(&mut self) -> Result<(), JsError> {
self.tx.close().await?;
Ok(())
}
}

View file

@ -1,118 +1,25 @@
use crate::*; use std::{
pin::Pin,
task::{Context, Poll},
};
use rustls_pki_types::Der; use bytes::Bytes;
use wasm_bindgen::prelude::*; use futures_util::{Future, Stream};
use http::{HeaderValue, Uri};
use hyper::{body::Body, rt::Executor};
use js_sys::{Array, ArrayBuffer, Object, Reflect, Uint8Array};
use pin_project_lite::pin_project;
use wasm_bindgen::{JsCast, JsValue};
use wasm_bindgen_futures::JsFuture; use wasm_bindgen_futures::JsFuture;
use hyper::rt::Executor; use crate::EpoxyError;
use js_sys::ArrayBuffer;
use std::future::Future;
use wisp_mux::{extensions::udp::UdpProtocolExtensionBuilder, WispError};
#[wasm_bindgen]
extern "C" {
#[wasm_bindgen(js_namespace = console, js_name = debug)]
pub fn console_debug(s: &str);
#[wasm_bindgen(js_namespace = console, js_name = log)]
pub fn console_log(s: &str);
#[wasm_bindgen(js_namespace = console, js_name = error)]
pub fn console_error(s: &str);
}
macro_rules! debug {
($($t:tt)*) => (utils::console_debug(&format_args!($($t)*).to_string()))
}
macro_rules! log {
($($t:tt)*) => (utils::console_log(&format_args!($($t)*).to_string()))
}
#[allow(unused_macros)]
macro_rules! error {
($($t:tt)*) => (utils::console_error(&format_args!($($t)*).to_string()))
}
macro_rules! jerr {
($expr:expr) => {
JsError::new($expr)
};
}
macro_rules! jval {
($expr:expr) => {
JsValue::from($expr)
};
}
pub trait ReplaceErr {
type Ok;
fn replace_err(self, err: &str) -> Result<Self::Ok, JsError>;
fn replace_err_jv(self, err: &str) -> Result<Self::Ok, JsValue>;
}
impl<T, E: std::fmt::Debug> ReplaceErr for Result<T, E> {
type Ok = T;
fn replace_err(self, err: &str) -> Result<<Self as ReplaceErr>::Ok, JsError> {
self.map_err(|x| jerr!(&format!("{}, original error: {:?}", err, x)))
}
fn replace_err_jv(self, err: &str) -> Result<<Self as ReplaceErr>::Ok, JsValue> {
self.map_err(|x| jval!(&format!("{}, original error: {:?}", err, x)))
}
}
impl<T> ReplaceErr for Option<T> {
type Ok = T;
fn replace_err(self, err: &str) -> Result<<Self as ReplaceErr>::Ok, JsError> {
self.ok_or_else(|| jerr!(err))
}
fn replace_err_jv(self, err: &str) -> Result<<Self as ReplaceErr>::Ok, JsValue> {
self.ok_or_else(|| jval!(err))
}
}
// the... BOOLINATOR!
impl ReplaceErr for bool {
type Ok = ();
fn replace_err(self, err: &str) -> Result<(), JsError> {
if !self {
Err(jerr!(err))
} else {
Ok(())
}
}
fn replace_err_jv(self, err: &str) -> Result<(), JsValue> {
if !self {
Err(jval!(err))
} else {
Ok(())
}
}
}
// the... BOOLINATOR!
pub trait Boolinator {
fn flatten(self, err: &str) -> Result<(), JsError>;
}
impl Boolinator for Result<bool, JsValue> {
fn flatten(self, err: &str) -> Result<(), JsError> {
self.replace_err(err)?.replace_err(err)
}
}
pub trait UriExt { pub trait UriExt {
fn get_redirect(&self, location: &HeaderValue) -> Result<Uri, JsError>; fn get_redirect(&self, location: &HeaderValue) -> Result<Uri, EpoxyError>;
} }
impl UriExt for Uri { impl UriExt for Uri {
fn get_redirect(&self, location: &HeaderValue) -> Result<Uri, JsError> { fn get_redirect(&self, location: &HeaderValue) -> Result<Uri, EpoxyError> {
let new_uri = location.to_str()?.parse::<hyper::Uri>()?; let new_uri = location.to_str()?.parse::<hyper::Uri>()?;
let mut new_parts: http::uri::Parts = new_uri.into(); let mut new_parts: http::uri::Parts = new_uri.into();
if new_parts.scheme.is_none() { if new_parts.scheme.is_none() {
@ -141,8 +48,75 @@ where
} }
} }
pin_project! {
pub struct IncomingBody {
#[pin]
incoming: hyper::body::Incoming,
}
}
impl IncomingBody {
pub fn new(incoming: hyper::body::Incoming) -> IncomingBody {
IncomingBody { incoming }
}
}
impl Stream for IncomingBody {
type Item = std::io::Result<Bytes>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
let ret = this.incoming.poll_frame(cx);
match ret {
Poll::Ready(item) => Poll::<Option<Self::Item>>::Ready(match item {
Some(frame) => frame
.map(|x| {
x.into_data()
.map_err(|_| std::io::Error::other("not data frame"))
})
.ok(),
None => None,
}),
Poll::Pending => Poll::<Option<Self::Item>>::Pending,
}
}
}
pub fn is_redirect(code: u16) -> bool {
[301, 302, 303, 307, 308].contains(&code)
}
pub fn is_null_body(code: u16) -> bool {
[101, 204, 205, 304].contains(&code)
}
pub fn object_get(obj: &Object, key: &str) -> Option<JsValue> {
Reflect::get(obj, &key.into()).ok()
}
pub fn object_set(obj: &Object, key: &JsValue, value: &JsValue) -> Result<(), EpoxyError> {
if Reflect::set(obj, key, value).map_err(|_| EpoxyError::RawHeaderSetFailed)? {
Ok(())
} else {
Err(EpoxyError::RawHeaderSetFailed)
}
}
pub async fn convert_body(val: JsValue) -> Result<(Uint8Array, web_sys::Request), JsValue> {
let req = web_sys::Request::new_with_str_and_init(
"/",
web_sys::RequestInit::new().method("POST").body(Some(&val)),
)?;
Ok((
JsFuture::from(req.array_buffer()?)
.await?
.dyn_into::<ArrayBuffer>()
.map(|x| Uint8Array::new(&x))?,
req,
))
}
pub fn entries_of_object(obj: &Object) -> Vec<Vec<String>> { pub fn entries_of_object(obj: &Object) -> Vec<Vec<String>> {
js_sys::Object::entries(obj) Object::entries(obj)
.to_vec() .to_vec()
.iter() .iter()
.filter_map(|val| { .filter_map(|val| {
@ -157,124 +131,10 @@ pub fn entries_of_object(obj: &Object) -> Vec<Vec<String>> {
pub fn define_property_obj(value: JsValue, writable: bool) -> Result<Object, JsValue> { pub fn define_property_obj(value: JsValue, writable: bool) -> Result<Object, JsValue> {
let entries: Array = [ let entries: Array = [
Array::of2(&jval!("value"), &value), Array::of2(&"value".into(), &value),
Array::of2(&jval!("writable"), &jval!(writable)), Array::of2(&"writable".into(), &writable.into()),
] ]
.iter() .iter()
.collect::<Array>(); .collect::<Array>();
Object::from_entries(&entries) Object::from_entries(&entries)
} }
pub fn is_redirect(code: u16) -> bool {
[301, 302, 303, 307, 308].contains(&code)
}
pub fn is_null_body(code: u16) -> bool {
[101, 204, 205, 304].contains(&code)
}
pub fn get_is_secure(url: &Uri) -> Result<bool, JsError> {
let url_scheme_str = url.scheme_str().replace_err("URL must have a scheme")?;
match url_scheme_str {
"https" | "wss" => Ok(true),
_ => Ok(false),
}
}
pub fn get_url_port(url: &Uri) -> Result<u16, JsError> {
if let Some(port) = url.port() {
Ok(port.as_u16())
} else if get_is_secure(url)? {
Ok(443)
} else {
Ok(80)
}
}
pub async fn make_mux(
url: &str,
) -> Result<
(
ClientMux,
impl Future<Output = Result<(), WispError>> + Send,
),
WispError,
> {
let (wtx, wrx) =
WebSocketWrapper::connect(url, vec![]).map_err(|_| WispError::WsImplSocketClosed)?;
wtx.wait_for_open().await;
Ok(
ClientMux::create(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())]))
.await?
.with_no_required_extensions(),
)
}
pub fn spawn_mux_fut(
mux: Arc<RwLock<ClientMux>>,
fut: impl Future<Output = Result<(), WispError>> + Send + 'static,
url: String,
) {
wasm_bindgen_futures::spawn_local(async move {
debug!("epoxy: mux future started");
if let Err(e) = fut.await {
log!("epoxy: error in mux future, restarting: {:?}", e);
while let Err(e) = replace_mux(mux.clone(), &url).await {
log!("epoxy: failed to restart mux future: {:?}", e);
wasmtimer::tokio::sleep(std::time::Duration::from_millis(500)).await;
}
}
debug!("epoxy: mux future exited");
});
}
pub async fn replace_mux(mux: Arc<RwLock<ClientMux>>, url: &str) -> Result<(), WispError> {
let (mux_replace, fut) = make_mux(url).await?;
let mut mux_write = mux.write().await;
let _ = mux_write.close().await;
*mux_write = mux_replace;
drop(mux_write);
spawn_mux_fut(mux, fut, url.into());
Ok(())
}
pub async fn jval_to_u8_array(val: JsValue) -> Result<Uint8Array, JsValue> {
JsFuture::from(
web_sys::Request::new_with_str_and_init(
"/",
web_sys::RequestInit::new().method("POST").body(Some(&val)),
)?
.array_buffer()?,
)
.await?
.dyn_into::<ArrayBuffer>()
.map(|x| Uint8Array::new(&x))
}
pub async fn jval_to_u8_array_req(val: JsValue) -> Result<(Uint8Array, web_sys::Request), JsValue> {
let req = web_sys::Request::new_with_str_and_init(
"/",
web_sys::RequestInit::new().method("POST").body(Some(&val)),
)?;
Ok((
JsFuture::from(req.array_buffer()?)
.await?
.dyn_into::<ArrayBuffer>()
.map(|x| Uint8Array::new(&x))?,
req,
))
}
pub fn object_to_trustanchor(obj: JsValue) -> Result<TrustAnchor<'static>, JsValue> {
let subject: Uint8Array = Reflect::get(&obj, &jval!("subject"))?.dyn_into()?;
let pub_key_info: Uint8Array =
Reflect::get(&obj, &jval!("subject_public_key_info"))?.dyn_into()?;
let name_constraints: Option<Uint8Array> = Reflect::get(&obj, &jval!("name_constraints"))
.and_then(|x| x.dyn_into())
.ok();
Ok(TrustAnchor {
subject: Der::from(subject.to_vec()),
subject_public_key_info: Der::from(pub_key_info.to_vec()),
name_constraints: name_constraints.map(|x| Der::from(x.to_vec())),
})
}

View file

@ -1,104 +1,106 @@
use crate::*; use std::{str::from_utf8, sync::Arc};
use base64::{engine::general_purpose::STANDARD, Engine}; use base64::{prelude::BASE64_STANDARD, Engine};
use bytes::Bytes;
use fastwebsockets::{ use fastwebsockets::{
CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, Role, WebSocket, WebSocketWrite, CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, Role, WebSocket, WebSocketWrite,
}; };
use futures_util::lock::Mutex; use futures_util::lock::Mutex;
use http_body_util::Full; use getrandom::getrandom;
use hyper::{ use http::{
header::{CONNECTION, UPGRADE}, header::{
upgrade::Upgraded, CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE,
StatusCode, },
Method, Request, Response, StatusCode, Uri,
}; };
use std::str::from_utf8; use hyper::{
body::Incoming,
upgrade::{self, Upgraded},
};
use js_sys::{ArrayBuffer, Function, Uint8Array};
use tokio::io::WriteHalf; use tokio::io::WriteHalf;
use wasm_bindgen::{prelude::*, JsError, JsValue};
use wasm_bindgen_futures::spawn_local;
#[wasm_bindgen(inspectable)] use crate::{tokioio::TokioIo, EpoxyClient, EpoxyError, EpoxyHandlers, HttpBody};
pub struct EpxWebSocket {
tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>,
onerror: Function,
#[wasm_bindgen(readonly, getter_with_clone)]
pub url: String,
#[wasm_bindgen(readonly, getter_with_clone)]
pub protocols: Vec<String>,
#[wasm_bindgen(readonly, getter_with_clone)]
pub origin: String,
}
#[wasm_bindgen] #[wasm_bindgen]
impl EpxWebSocket { pub struct EpoxyWebSocket {
#[wasm_bindgen(constructor)] tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>,
pub fn new() -> Result<EpxWebSocket, JsError> {
Err(jerr!("Use EpoxyClient.connect_ws() instead."))
}
// shut up
#[allow(clippy::too_many_arguments)]
pub async fn connect(
tcp: &EpoxyClient,
onopen: Function,
onclose: Function,
onerror: Function, onerror: Function,
onmessage: Function, }
impl EpoxyWebSocket {
pub(crate) async fn connect(
client: &EpoxyClient,
handlers: EpoxyHandlers,
url: String, url: String,
protocols: Vec<String>, protocols: Vec<String>,
origin: String, ) -> Result<Self, EpoxyError> {
) -> Result<EpxWebSocket, JsError> { let EpoxyHandlers {
let onerr = onerror.clone(); onopen,
let ret: Result<EpxWebSocket, JsError> = async move { onclose,
let url = Uri::try_from(url).replace_err("Failed to parse URL")?; onerror,
let host = url.host().replace_err("URL must have a host")?; onmessage,
} = handlers;
let onerror_cloned = onerror.clone();
let ret: Result<EpoxyWebSocket, EpoxyError> = async move {
let url: Uri = url.try_into()?;
let host = url.host().ok_or(EpoxyError::NoUrlHost)?;
let mut rand: [u8; 16] = [0; 16]; let mut rand = [0u8; 16];
getrandom::getrandom(&mut rand)?; getrandom(&mut rand)?;
let key = STANDARD.encode(rand); let key = BASE64_STANDARD.encode(rand);
let mut builder = Request::builder() let mut request = Request::builder()
.method("GET") .method(Method::GET)
.uri(url.clone()) .uri(url.clone())
.header("Host", host) .header(HOST, host)
.header("Origin", origin.clone())
.header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade") .header(CONNECTION, "upgrade")
.header("Sec-WebSocket-Key", key) .header(UPGRADE, "websocket")
.header("Sec-WebSocket-Version", "13"); .header(SEC_WEBSOCKET_KEY, key)
.header(SEC_WEBSOCKET_VERSION, "13");
if !protocols.is_empty() { if !protocols.is_empty() {
builder = builder.header("Sec-WebSocket-Protocol", protocols.join(", ")); request = request.header(SEC_WEBSOCKET_PROTOCOL, protocols.join(","));
} }
let req = builder.body(Full::<Bytes>::new(Bytes::new()))?; let request = request.body(HttpBody::new(Bytes::new()))?;
let mut response = tcp.hyper_client.request(req).await?; let mut response = client.client.request(request).await?;
verify(&response)?; verify(&response)?;
let ws = WebSocket::after_handshake( let websocket = WebSocket::after_handshake(
TokioIo::new(hyper::upgrade::on(&mut response).await?), TokioIo::new(upgrade::on(&mut response).await?),
Role::Client, Role::Client,
); );
let (rx, tx) = ws.split(tokio::io::split); let (rx, tx) = websocket.split(tokio::io::split);
let mut rx = FragmentCollectorRead::new(rx); let mut rx = FragmentCollectorRead::new(rx);
let tx = Arc::new(Mutex::new(tx)); let tx = Arc::new(Mutex::new(tx));
let tx_cloned = tx.clone();
wasm_bindgen_futures::spawn_local(async move { let read_tx = tx.clone();
while let Ok(frame) = rx let onerror_cloned = onerror.clone();
.read_frame(&mut |arg| async { tx_cloned.lock().await.write_frame(arg).await })
spawn_local(async move {
loop {
match rx
.read_frame(&mut |arg| async {
read_tx.lock().await.write_frame(arg).await
})
.await .await
{ {
match frame.opcode { Ok(frame) => match frame.opcode {
OpCode::Text => { OpCode::Text => {
if let Ok(str) = from_utf8(&frame.payload) { if let Ok(str) = from_utf8(&frame.payload) {
let _ = onmessage.call1(&JsValue::null(), &jval!(str)); let _ = onmessage.call1(&JsValue::null(), &str.into());
} }
} }
OpCode::Binary => { OpCode::Binary => {
let _ = onmessage.call1( let _ = onmessage.call1(
&JsValue::null(), &JsValue::null(),
&jval!(Uint8Array::from(frame.payload.to_vec().as_slice())), &Uint8Array::from(frame.payload.to_vec().as_slice()).into(),
); );
} }
OpCode::Close => { OpCode::Close => {
@ -107,100 +109,109 @@ impl EpxWebSocket {
} }
// ping/pong/continue // ping/pong/continue
_ => {} _ => {}
},
Err(err) => {
let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into());
break;
} }
} }
}
let _ = onclose.call0(&JsValue::null());
}); });
onopen let _ = onopen.call0(&JsValue::null());
.call0(&Object::default())
.replace_err("Failed to call onopen")?;
Ok(Self { Ok(Self {
tx, tx,
onerror, onerror: onerror_cloned,
origin,
protocols,
url: url.to_string(),
}) })
} }
.await; .await;
if let Err(ret) = ret {
let _ = onerr.call1(&JsValue::null(), &jval!(ret.clone())); match ret {
Err(ret) Ok(ok) => Ok(ok),
} else { Err(err) => {
ret let _ = onerror_cloned.call1(&JsValue::null(), &err.to_string().into());
Err(err)
}
} }
} }
#[wasm_bindgen] pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> {
pub async fn send_text(&self, payload: String) -> Result<(), JsError> { let ret = if let Some(str) = payload.as_string() {
let onerr = self.onerror.clone();
let ret = self
.tx
.lock()
.await
.write_frame(Frame::text(Payload::Owned(payload.as_bytes().to_vec())))
.await;
if let Err(ret) = ret {
let _ = onerr.call1(&JsValue::null(), &jval!(ret.to_string()));
Err(ret.into())
} else {
Ok(ret?)
}
}
#[wasm_bindgen]
pub async fn send_binary(&self, payload: Uint8Array) -> Result<(), JsError> {
let onerr = self.onerror.clone();
let ret = self
.tx
.lock()
.await
.write_frame(Frame::binary(Payload::Owned(payload.to_vec())))
.await;
if let Err(ret) = ret {
let _ = onerr.call1(&JsValue::null(), &jval!(ret.to_string()));
Err(ret.into())
} else {
Ok(ret?)
}
}
#[wasm_bindgen]
pub async fn close(&self) -> Result<(), JsError> {
self.tx self.tx
.lock()
.await
.write_frame(Frame::text(Payload::Owned(str.as_bytes().to_vec())))
.await
.map_err(EpoxyError::from)
} else if let Ok(binary) = payload.dyn_into::<ArrayBuffer>() {
self.tx
.lock()
.await
.write_frame(Frame::binary(Payload::Owned(
Uint8Array::new(&binary).to_vec(),
)))
.await
.map_err(EpoxyError::from)
} else {
Err(EpoxyError::WsInvalidPayload)
};
match ret {
Ok(ok) => Ok(ok),
Err(err) => {
let _ = self
.onerror
.call1(&JsValue::null(), &err.to_string().into());
Err(err)
}
}
}
pub async fn close(&self) -> Result<(), EpoxyError> {
let ret = self
.tx
.lock() .lock()
.await .await
.write_frame(Frame::close(CloseCode::Normal.into(), b"")) .write_frame(Frame::close(CloseCode::Normal.into(), b""))
.await?; .await;
Ok(()) match ret {
Ok(ok) => Ok(ok),
Err(err) => {
let _ = self
.onerror
.call1(&JsValue::null(), &err.to_string().into());
Err(err.into())
}
}
} }
} }
// https://github.com/snapview/tungstenite-rs/blob/314feea3055a93e585882fb769854a912a7e6dae/src/handshake/client.rs#L189 // https://github.com/snapview/tungstenite-rs/blob/314feea3055a93e585882fb769854a912a7e6dae/src/handshake/client.rs#L189
fn verify(response: &Response<Incoming>) -> Result<(), JsError> { fn verify(response: &Response<Incoming>) -> Result<(), EpoxyError> {
if response.status() != StatusCode::SWITCHING_PROTOCOLS { if response.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(jerr!("epoxy ws connect: Invalid status code")); return Err(EpoxyError::WsInvalidStatusCode);
} }
let headers = response.headers(); let headers = response.headers();
if !headers if !headers
.get("Upgrade") .get(UPGRADE)
.and_then(|h| h.to_str().ok()) .and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket")) .map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(jerr!("epoxy ws connect: Invalid upgrade header")); return Err(EpoxyError::WsInvalidUpgradeHeader);
} }
if !headers if !headers
.get("Connection") .get(CONNECTION)
.and_then(|h| h.to_str().ok()) .and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("Upgrade")) .map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(jerr!("epoxy ws connect: Invalid upgrade header")); return Err(EpoxyError::WsInvalidConnectionHeader);
} }
Ok(()) Ok(())

View file

@ -1,142 +1,23 @@
use crate::*; use std::sync::{
use std::{ atomic::{AtomicBool, Ordering},
ops::Deref, pin::Pin, sync::atomic::{AtomicBool, Ordering}, task::{Context, Poll} Arc,
}; };
use async_trait::async_trait;
use bytes::BytesMut; use bytes::BytesMut;
use event_listener::Event; use event_listener::Event;
use futures_util::{FutureExt, Stream}; use flume::Receiver;
use hyper::body::Body; use futures_util::FutureExt;
use js_sys::ArrayBuffer; use js_sys::{Array, ArrayBuffer, Uint8Array};
use pin_project_lite::pin_project;
use send_wrapper::SendWrapper; use send_wrapper::SendWrapper;
use std::future::Future; use wasm_bindgen::{closure::Closure, JsCast};
use tokio::sync::mpsc;
use web_sys::{BinaryType, MessageEvent, WebSocket}; use web_sys::{BinaryType, MessageEvent, WebSocket};
use wisp_mux::{ use wisp_mux::{
ws::{Frame, LockedWebSocketWrite, WebSocketRead, WebSocketWrite}, ws::{Frame, LockedWebSocketWrite, WebSocketRead, WebSocketWrite},
WispError, WispError,
}; };
pin_project! { use crate::EpoxyError;
pub struct IncomingBody {
#[pin]
incoming: hyper::body::Incoming,
}
}
impl IncomingBody {
pub fn new(incoming: hyper::body::Incoming) -> IncomingBody {
IncomingBody { incoming }
}
}
impl Stream for IncomingBody {
type Item = std::io::Result<Bytes>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
let ret = this.incoming.poll_frame(cx);
match ret {
Poll::Ready(item) => Poll::<Option<Self::Item>>::Ready(match item {
Some(frame) => frame
.map(|x| {
x.into_data()
.map_err(|_| std::io::Error::other("not data frame"))
})
.ok(),
None => None,
}),
Poll::Pending => Poll::<Option<Self::Item>>::Pending,
}
}
}
#[derive(Clone)]
pub struct ServiceWrapper(pub Arc<RwLock<ClientMux>>, pub String);
impl tower_service::Service<hyper::Uri> for ServiceWrapper {
type Response = TokioIo<EpxIoUnencryptedStream>;
type Error = WispError;
type Future = impl Future<Output = Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: hyper::Uri) -> Self::Future {
let mux = self.0.clone();
let mux_url = self.1.clone();
async move {
let stream = mux
.write()
.await
.client_new_stream(
StreamType::Tcp,
req.host().ok_or(WispError::UriHasNoHost)?.to_string(),
req.port().ok_or(WispError::UriHasNoPort)?.into(),
)
.await;
if stream
.as_ref()
.is_err_and(|e| matches!(e, WispError::WsImplSocketClosed))
{
utils::replace_mux(mux, &mux_url).await?;
}
Ok(TokioIo::new(stream?.into_io().into_asyncrw()))
}
}
}
#[derive(Clone)]
pub struct TlsWispService {
pub service: ServiceWrapper,
pub rustls_config: Arc<rustls::ClientConfig>,
}
impl tower_service::Service<hyper::Uri> for TlsWispService {
type Response = TokioIo<EpxIoStream>;
type Error = WispError;
type Future = Pin<Box<impl Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&mut self, req: http::Uri) -> Self::Future {
let mut service = self.service.clone();
let rustls_config = self.rustls_config.clone();
Box::pin(async move {
let uri_host = req
.host()
.ok_or(WispError::UriHasNoHost)?
.to_string()
.clone();
let uri_parsed = Uri::builder()
.authority(format!(
"{}:{}",
uri_host,
utils::get_url_port(&req).map_err(|_| WispError::UriHasNoPort)?
))
.build()
.map_err(|x| WispError::Other(Box::new(x)))?;
let stream = service.call(uri_parsed).await?.into_inner();
if utils::get_is_secure(&req).map_err(|_| WispError::InvalidUri)? {
let connector = TlsConnector::from(rustls_config);
Ok(TokioIo::new(Either::Left(
connector
.connect(
uri_host.try_into().map_err(|_| WispError::InvalidUri)?,
stream,
)
.await
.map_err(|x| WispError::Other(Box::new(x)))?,
)))
} else {
Ok(TokioIo::new(Either::Right(stream)))
}
})
}
}
#[derive(Debug)] #[derive(Debug)]
pub enum WebSocketError { pub enum WebSocketError {
@ -189,12 +70,12 @@ pub struct WebSocketWrapper {
} }
pub struct WebSocketReader { pub struct WebSocketReader {
read_rx: mpsc::UnboundedReceiver<WebSocketMessage>, read_rx: Receiver<WebSocketMessage>,
closed: Arc<AtomicBool>, closed: Arc<AtomicBool>,
close_event: Arc<Event>, close_event: Arc<Event>,
} }
#[async_trait::async_trait] #[async_trait]
impl WebSocketRead for WebSocketReader { impl WebSocketRead for WebSocketReader {
async fn wisp_read_frame(&mut self, _: &LockedWebSocketWrite) -> Result<Frame, WispError> { async fn wisp_read_frame(&mut self, _: &LockedWebSocketWrite) -> Result<Frame, WispError> {
use WebSocketMessage::*; use WebSocketMessage::*;
@ -202,11 +83,11 @@ impl WebSocketRead for WebSocketReader {
return Err(WispError::WsImplSocketClosed); return Err(WispError::WsImplSocketClosed);
} }
let res = futures_util::select! { let res = futures_util::select! {
data = self.read_rx.recv().fuse() => data, data = self.read_rx.recv_async() => data.ok(),
_ = self.close_event.listen().fuse() => Some(Closed), _ = self.close_event.listen().fuse() => Some(Closed),
}; };
match res.ok_or(WispError::WsImplSocketClosed)? { match res.ok_or(WispError::WsImplSocketClosed)? {
Message(bin) => Ok(Frame::binary(BytesMut::from(bin.deref()))), Message(bin) => Ok(Frame::binary(BytesMut::from(bin.as_slice()))),
Error => Err(WebSocketError::Unknown.into()), Error => Err(WebSocketError::Unknown.into()),
Closed => Err(WispError::WsImplSocketClosed), Closed => Err(WispError::WsImplSocketClosed),
} }
@ -214,8 +95,8 @@ impl WebSocketRead for WebSocketReader {
} }
impl WebSocketWrapper { impl WebSocketWrapper {
pub fn connect(url: &str, protocols: Vec<String>) -> Result<(Self, WebSocketReader), JsValue> { pub fn connect(url: &str, protocols: &[String]) -> Result<(Self, WebSocketReader), EpoxyError> {
let (read_tx, read_rx) = mpsc::unbounded_channel(); let (read_tx, read_rx) = flume::unbounded();
let closed = Arc::new(AtomicBool::new(false)); let closed = Arc::new(AtomicBool::new(false));
let open_event = Arc::new(Event::new()); let open_event = Arc::new(Event::new());
@ -261,13 +142,13 @@ impl WebSocketWrapper {
&protocols &protocols
.iter() .iter()
.fold(Array::new(), |acc, x| { .fold(Array::new(), |acc, x| {
acc.push(&jval!(x)); acc.push(&x.into());
acc acc
}) })
.into(), .into(),
) )
} }
.replace_err("Failed to make websocket")?; .map_err(|_| EpoxyError::WebSocketConnectFailed)?;
ws.set_binary_type(BinaryType::Arraybuffer); ws.set_binary_type(BinaryType::Arraybuffer);
ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref())); ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
ws.set_onopen(Some(onopen.as_ref().unchecked_ref())); ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
@ -294,15 +175,18 @@ impl WebSocketWrapper {
)) ))
} }
pub async fn wait_for_open(&self) { pub async fn wait_for_open(&self) -> bool {
if self.closed.load(Ordering::Acquire) {
return false;
}
futures_util::select! { futures_util::select! {
_ = self.open_event.listen().fuse() => (), _ = self.open_event.listen().fuse() => true,
_ = self.error_event.listen().fuse() => (), _ = self.error_event.listen().fuse() => false,
} }
} }
} }
#[async_trait::async_trait] #[async_trait]
impl WebSocketWrite for WebSocketWrapper { impl WebSocketWrite for WebSocketWrapper {
async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError> { async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError> {
use wisp_mux::ws::OpCode::*; use wisp_mux::ws::OpCode::*;

View file

@ -1,25 +0,0 @@
#!/usr/bin/env bash
# https://aweirdimagination.net/2020/06/28/kill-child-jobs-on-script-exit/
cleanup() {
pkill -P $$
}
for sig in INT QUIT HUP TERM; do
trap "
cleanup
trap - $sig EXIT
kill -s $sig "'"$$"' "$sig"
done
trap cleanup EXIT
set -euo pipefail
shopt -s inherit_errexit
(cd ..; cargo b --bin epoxy-server)
../target/debug/epoxy-server &
server_pid=$!
sleep 1
echo "server_pid: $server_pid"
GECKODRIVER=$(which geckodriver) cargo test --target wasm32-unknown-unknown
CHROMEDRIVER=$(which chromedriver) cargo test --target wasm32-unknown-unknown

View file

@ -1,300 +0,0 @@
use default_env::default_env;
use epoxy_client::EpoxyClient;
use js_sys::{Array, JsString, Object, Reflect, Uint8Array, JSON};
use rustls_pki_types::TrustAnchor;
use tokio::sync::OnceCell;
use wasm_bindgen::JsValue;
use wasm_bindgen_futures::JsFuture;
use wasm_bindgen_test::*;
use web_sys::{FormData, Headers, Response, UrlSearchParams};
wasm_bindgen_test_configure!(run_in_dedicated_worker);
static USER_AGENT: &str = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36";
static EPOXY_CLIENT: OnceCell<EpoxyClient> = OnceCell::const_new();
pub fn trustanchor_to_object(cert: &TrustAnchor) -> Result<JsValue, JsValue> {
let val = Object::new();
Reflect::set(
&val,
&JsValue::from("subject"),
&Uint8Array::from(cert.subject.as_ref()),
)?;
Reflect::set(
&val,
&JsValue::from("subject_public_key_info"),
&Uint8Array::from(cert.subject_public_key_info.as_ref()),
)?;
Reflect::set(
&val,
&JsValue::from("name_constraints"),
&JsValue::from(
cert.name_constraints
.as_ref()
.map(|x| Uint8Array::from(x.as_ref())),
),
)?;
Ok(val.into())
}
async fn get_client_w_ua(useragent: &str, redirect_limit: usize) -> EpoxyClient {
EpoxyClient::new(
"ws://localhost:4000".into(),
useragent.into(),
redirect_limit,
webpki_roots::TLS_SERVER_ROOTS
.iter()
.map(trustanchor_to_object)
.collect::<Result<Array, JsValue>>()
.expect("Failed to create certs"),
)
.await
.ok()
.expect("Failed to create client")
}
macro_rules! fetch {
($url:expr, $opts:expr) => {
EPOXY_CLIENT
.get_or_init(|| get_client_w_ua(USER_AGENT, 10))
.await
.fetch($url, $opts)
.await
.ok()
.expect("Failed to fetch")
};
}
macro_rules! httpbin {
($url:literal) => {
concat!(default_env!("HTTPBIN_URL", "https://httpbin.org/"), $url)
};
}
async fn get_body_json(resp: &Response) -> JsValue {
JsFuture::from(resp.json().unwrap()).await.unwrap()
}
async fn get_body_text(resp: &Response) -> JsValue {
JsFuture::from(resp.text().unwrap()).await.unwrap()
}
fn get_header(body: &JsValue, header: &str) -> Result<JsValue, JsValue> {
Reflect::get(body, &JsValue::from("headers"))
.and_then(|x| Reflect::get(&x, &JsValue::from(header)))
}
fn get_resp_body(body: &JsValue) -> Result<JsValue, JsValue> {
Reflect::get(body, &JsValue::from("data"))
}
fn get_resp_form(body: &JsValue) -> Result<JsValue, JsValue> {
Reflect::get(body, &JsValue::from("form"))
}
fn check_resp(resp: &Response, url: &str, status: u16, status_text: &str) {
assert_eq!(resp.url(), url);
assert_eq!(resp.status(), status);
assert_eq!(resp.status_text(), status_text);
}
#[wasm_bindgen_test]
async fn get() {
let url = httpbin!("get");
let resp = fetch!(url.into(), Object::new());
check_resp(&resp, url, 200, "OK");
let body: Object = get_body_json(&resp).await.into();
assert_eq!(
get_header(&body, "User-Agent"),
Ok(JsValue::from(USER_AGENT))
);
}
#[wasm_bindgen_test]
async fn gzip() {
let url = httpbin!("gzip");
let resp = fetch!(url.into(), Object::new());
check_resp(&resp, url, 200, "OK");
let body: Object = get_body_json(&resp).await.into();
assert_eq!(
get_header(&body, "Accept-Encoding"),
Ok(JsValue::from("gzip, br"))
);
}
#[wasm_bindgen_test]
async fn brotli() {
let url = httpbin!("brotli");
let resp = fetch!(url.into(), Object::new());
check_resp(&resp, url, 200, "OK");
let body: Object = get_body_json(&resp).await.into();
assert_eq!(
get_header(&body, "Accept-Encoding"),
Ok(JsValue::from("gzip, br"))
);
}
#[wasm_bindgen_test]
async fn redirect() {
let url = httpbin!("redirect/2");
let resp = fetch!(url.into(), Object::new());
check_resp(&resp, httpbin!("get"), 200, "OK");
get_body_json(&resp).await;
}
#[wasm_bindgen_test]
async fn redirect_limit() {
// new client created due to redirect limit difference
let client = get_client_w_ua(USER_AGENT, 2).await;
let url = httpbin!("redirect/3");
let resp = client
.fetch(url.into(), Object::new())
.await
.ok()
.expect("Failed to fetch");
check_resp(&resp, httpbin!("relative-redirect/1"), 302, "Found");
assert_eq!(get_body_text(&resp).await, JsValue::from(""));
}
#[wasm_bindgen_test]
async fn redirect_manual() {
let url = httpbin!("redirect/2");
let obj = Object::new();
Reflect::set(&obj, &JsValue::from("redirect"), &JsValue::from("manual")).unwrap();
let resp = fetch!(url.into(), obj);
check_resp(&resp, url, 302, "Found");
get_body_text(&resp).await;
}
#[wasm_bindgen_test]
async fn post_string() {
let url = httpbin!("post");
let obj = Object::new();
Reflect::set(&obj, &JsValue::from("method"), &JsValue::from("POST")).unwrap();
Reflect::set(&obj, &JsValue::from("body"), &JsValue::from("epoxy body")).unwrap();
let resp = fetch!(url.into(), obj);
check_resp(&resp, url, 200, "OK");
let body: Object = get_body_json(&resp).await.into();
assert_eq!(get_resp_body(&body), Ok(JsValue::from("epoxy body")));
}
#[wasm_bindgen_test]
async fn post_arraybuffer() {
let url = httpbin!("post");
let obj = Object::new();
Reflect::set(&obj, &JsValue::from("method"), &JsValue::from("POST")).unwrap();
let req_body = b"epoxy body";
let u8array = Uint8Array::new_with_length(req_body.len().try_into().unwrap());
u8array.copy_from(req_body);
Reflect::set(&obj, &JsValue::from("body"), &u8array).unwrap();
let resp = fetch!(url.into(), obj);
check_resp(&resp, url, 200, "OK");
let body: Object = get_body_json(&resp).await.into();
assert_eq!(get_resp_body(&body), Ok(JsValue::from("epoxy body")));
}
#[wasm_bindgen_test]
async fn post_formdata() {
let url = httpbin!("post");
let obj = Object::new();
Reflect::set(&obj, &JsValue::from("method"), &JsValue::from("POST")).unwrap();
let req_body = FormData::new().unwrap();
req_body.set_with_str("a", "b").unwrap();
Reflect::set(&obj, &JsValue::from("body"), &req_body).unwrap();
let resp = fetch!(url.into(), obj);
check_resp(&resp, url, 200, "OK");
let body: Object = get_body_json(&resp).await.into();
assert_eq!(
get_resp_form(&body).and_then(|x| JSON::stringify(&x)),
Ok(JsString::from(r#"{"a":"b"}"#))
);
assert!(JsString::from(get_header(&body, "Content-Type").unwrap())
.includes("multipart/form-data", 0));
}
#[wasm_bindgen_test]
async fn post_urlsearchparams() {
let url = httpbin!("post");
let obj = Object::new();
Reflect::set(&obj, &JsValue::from("method"), &JsValue::from("POST")).unwrap();
let req_body = UrlSearchParams::new_with_str("a=b").unwrap();
Reflect::set(&obj, &JsValue::from("body"), &req_body).unwrap();
let resp = fetch!(url.into(), obj);
check_resp(&resp, url, 200, "OK");
let body: Object = get_body_json(&resp).await.into();
assert_eq!(
get_resp_form(&body).and_then(|x| JSON::stringify(&x)),
Ok(JsString::from(r#"{"a":"b"}"#))
);
assert!(JsString::from(get_header(&body, "Content-Type").unwrap())
.includes("application/x-www-form-urlencoded", 0));
}
#[wasm_bindgen_test]
async fn headers_obj() {
let url = httpbin!("get");
let obj = Object::new();
let headers = Object::new();
Reflect::set(
&headers,
&JsValue::from("x-header-one"),
&JsValue::from("value"),
)
.unwrap();
Reflect::set(&obj, &JsValue::from("headers"), &headers).unwrap();
let resp = fetch!(url.into(), obj);
check_resp(&resp, url, 200, "OK");
let body: Object = get_body_json(&resp).await.into();
assert_eq!(
get_header(&body, "X-Header-One"),
Ok(JsValue::from("value"))
);
}
#[wasm_bindgen_test]
async fn headers_headers() {
let url = httpbin!("get");
let obj = Object::new();
let headers = Headers::new().unwrap();
headers.set("x-header-one", "value").unwrap();
Reflect::set(&obj, &JsValue::from("headers"), &headers).unwrap();
let resp = fetch!(url.into(), obj);
check_resp(&resp, url, 200, "OK");
let body: Object = get_body_json(&resp).await.into();
assert_eq!(
get_header(&body, "X-Header-One"),
Ok(JsValue::from("value"))
);
}