diff --git a/client/demo.js b/client/demo.js index 298e2f2..bde2e80 100644 --- a/client/demo.js +++ b/client/demo.js @@ -2,245 +2,257 @@ import epoxy from "./pkg/epoxy-module-bundled.js"; import CERTS from "./pkg/certs-module.js"; onmessage = async (msg) => { - console.debug("recieved demo:", msg); - let [ - should_feature_test, - should_multiparallel_test, - should_parallel_test, - should_multiperf_test, - should_perf_test, - should_ws_test, - should_tls_test, - should_udp_test, - should_reconnect_test, - should_perf2_test, - ] = msg.data; - console.log( - "%cWASM is significantly slower with DevTools open!", - "color:red;font-size:3rem;font-weight:bold" - ); + console.debug("recieved demo:", msg); + let [ + should_feature_test, + should_multiparallel_test, + should_parallel_test, + should_multiperf_test, + should_perf_test, + should_ws_test, + should_tls_test, + should_udp_test, + should_reconnect_test, + should_perf2_test, + ] = msg.data; + console.log( + "%cWASM is significantly slower with DevTools open!", + "color:red;font-size:3rem;font-weight:bold" + ); - const log = (str) => { - console.log(str); - postMessage(str); - } + const log = (str) => { + console.log(str); + postMessage(str); + } - const plog = (str) => { - console.log(str); - postMessage(JSON.stringify(str, null, 4)); - } + const plog = (str) => { + console.log(str); + postMessage(JSON.stringify(str, null, 4)); + } - const { EpoxyClient } = await epoxy(); + const { EpoxyClient, EpoxyClientOptions, EpoxyHandlers } = await epoxy(); - console.log("certs:", CERTS); + console.log("certs:", CERTS); - const tconn0 = performance.now(); - // args: websocket url, user agent, redirect limit, certs - let epoxy_client = await new EpoxyClient("ws://localhost:4000", navigator.userAgent, 10, CERTS); - const tconn1 = performance.now(); - log(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`); + let epoxy_client_options = new EpoxyClientOptions(); + epoxy_client_options.user_agent = navigator.userAgent; - // epoxy classes are inspectable - console.log(epoxy_client); - // you can change the user agent and redirect limit in JS - epoxy_client.redirectLimit = 15; + let epoxy_client = new EpoxyClient("ws://localhost:4000", CERTS, epoxy_client_options); - const test_mux = async (url) => { - const t0 = performance.now(); - await epoxy_client.fetch(url); - const t1 = performance.now(); - return t1 - t0; - }; + const tconn0 = performance.now(); + await epoxy_client.replace_stream_provider(); + const tconn1 = performance.now(); + log(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`); - const test_native = async (url) => { - const t0 = performance.now(); - await fetch(url, { cache: "no-store" }); - const t1 = performance.now(); - return t1 - t0; - }; + // epoxy classes are inspectable + console.log(epoxy_client); + // you can change the user agent and redirect limit in JS + epoxy_client.redirect_limit = 15; - if (should_feature_test) { - let formdata = new FormData(); - formdata.append("a", "b"); - for (const url of [ - ["https://httpbin.org/get", {}], - ["https://httpbin.org/gzip", {}], - ["https://httpbin.org/brotli", {}], - ["https://httpbin.org/redirect/11", {}], - ["https://httpbin.org/redirect/1", { redirect: "manual" }], - ["https://httpbin.org/post", { method: "POST", body: new URLSearchParams("a=b") }], - ["https://httpbin.org/post", { method: "POST", body: formdata }], - ["https://httpbin.org/post", { method: "POST", body: "a" }], - ["https://httpbin.org/post", { method: "POST", body: (new TextEncoder()).encode("abc") }], - ["https://httpbin.org/get", { headers: {"a": "b", "b": "c"} }], - ["https://httpbin.org/get", { headers: new Headers({"a": "b", "b": "c"}) }] - ]) { - let resp = await epoxy_client.fetch(url[0], url[1]); - console.warn(url, resp, Object.fromEntries(resp.headers)); - log(await resp.text()); - } - } else if (should_multiparallel_test) { - const num_tests = 10; - let total_mux_minus_native = 0; - for (const _ of Array(num_tests).keys()) { - let total_mux = 0; - await Promise.all([...Array(num_tests).keys()].map(async i => { - log(`running mux test ${i}`); - return await test_mux("https://httpbin.org/get"); - })).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) }); - total_mux = total_mux / num_tests; + const test_mux = async (url) => { + const t0 = performance.now(); + await epoxy_client.fetch(url); + const t1 = performance.now(); + return t1 - t0; + }; - let total_native = 0; - await Promise.all([...Array(num_tests).keys()].map(async i => { - log(`running native test ${i}`); - return await test_native("https://httpbin.org/get"); - })).then((vals) => { total_native = vals.reduce((acc, x) => acc + x, 0) }); - total_native = total_native / num_tests; + const test_native = async (url) => { + const t0 = performance.now(); + await fetch(url, { cache: "no-store" }); + const t1 = performance.now(); + return t1 - t0; + }; - log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); - log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); - log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); - total_mux_minus_native += total_mux - total_native; - } - total_mux_minus_native = total_mux_minus_native / num_tests; - log(`total mux - native (${num_tests} tests of ${num_tests} reqs): ${total_mux_minus_native} ms or ${total_mux_minus_native / 1000} s`); - } else if (should_parallel_test) { - const num_tests = 10; + if (should_feature_test) { + let formdata = new FormData(); + formdata.append("a", "b"); + for (const url of [ + ["https://httpbin.org/get", {}], + ["https://httpbin.org/gzip", {}], + ["https://httpbin.org/brotli", {}], + ["https://httpbin.org/redirect/11", {}], + ["https://httpbin.org/redirect/1", { redirect: "manual" }], + ["https://httpbin.org/post", { method: "POST", body: new URLSearchParams("a=b") }], + ["https://httpbin.org/post", { method: "POST", body: formdata }], + ["https://httpbin.org/post", { method: "POST", body: "a" }], + ["https://httpbin.org/post", { method: "POST", body: (new TextEncoder()).encode("abc") }], + ["https://httpbin.org/get", { headers: { "a": "b", "b": "c" } }], + ["https://httpbin.org/get", { headers: new Headers({ "a": "b", "b": "c" }) }] + ]) { + let resp = await epoxy_client.fetch(url[0], url[1]); + console.warn(url, resp, Object.fromEntries(resp.headers)); + log(await resp.text()); + } + } else if (should_multiparallel_test) { + const num_tests = 10; + let total_mux_minus_native = 0; + for (const _ of Array(num_tests).keys()) { + let total_mux = 0; + await Promise.all([...Array(num_tests).keys()].map(async i => { + log(`running mux test ${i}`); + return await test_mux("https://httpbin.org/get"); + })).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) }); + total_mux = total_mux / num_tests; - let total_mux = 0; - await Promise.all([...Array(num_tests).keys()].map(async i => { - log(`running mux test ${i}`); - return await test_mux("https://httpbin.org/get"); - })).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) }); - total_mux = total_mux / num_tests; + let total_native = 0; + await Promise.all([...Array(num_tests).keys()].map(async i => { + log(`running native test ${i}`); + return await test_native("https://httpbin.org/get"); + })).then((vals) => { total_native = vals.reduce((acc, x) => acc + x, 0) }); + total_native = total_native / num_tests; - let total_native = 0; - await Promise.all([...Array(num_tests).keys()].map(async i => { - log(`running native test ${i}`); - return await test_native("https://httpbin.org/get"); - })).then((vals) => { total_native = vals.reduce((acc, x) => acc + x, 0) }); - total_native = total_native / num_tests; + log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); + log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); + log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); + total_mux_minus_native += total_mux - total_native; + } + total_mux_minus_native = total_mux_minus_native / num_tests; + log(`total mux - native (${num_tests} tests of ${num_tests} reqs): ${total_mux_minus_native} ms or ${total_mux_minus_native / 1000} s`); + } else if (should_parallel_test) { + const num_tests = 10; - log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); - log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); - log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); - } else if (should_multiperf_test) { - const num_tests = 10; - let total_mux_minus_native = 0; - for (const _ of Array(num_tests).keys()) { - let total_mux = 0; - for (const i of Array(num_tests).keys()) { - log(`running mux test ${i}`); - total_mux += await test_mux("https://httpbin.org/get"); - } - total_mux = total_mux / num_tests; + let total_mux = 0; + await Promise.all([...Array(num_tests).keys()].map(async i => { + log(`running mux test ${i}`); + return await test_mux("https://httpbin.org/get"); + })).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) }); + total_mux = total_mux / num_tests; - let total_native = 0; - for (const i of Array(num_tests).keys()) { - log(`running native test ${i}`); - total_native += await test_native("https://httpbin.org/get"); - } - total_native = total_native / num_tests; + let total_native = 0; + await Promise.all([...Array(num_tests).keys()].map(async i => { + log(`running native test ${i}`); + return await test_native("https://httpbin.org/get"); + })).then((vals) => { total_native = vals.reduce((acc, x) => acc + x, 0) }); + total_native = total_native / num_tests; - log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); - log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); - log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); - total_mux_minus_native += total_mux - total_native; - } - total_mux_minus_native = total_mux_minus_native / num_tests; - log(`total mux - native (${num_tests} tests of ${num_tests} reqs): ${total_mux_minus_native} ms or ${total_mux_minus_native / 1000} s`); - } else if (should_perf_test) { - const num_tests = 10; + log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); + log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); + log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); + } else if (should_multiperf_test) { + const num_tests = 10; + let total_mux_minus_native = 0; + for (const _ of Array(num_tests).keys()) { + let total_mux = 0; + for (const i of Array(num_tests).keys()) { + log(`running mux test ${i}`); + total_mux += await test_mux("https://httpbin.org/get"); + } + total_mux = total_mux / num_tests; - let total_mux = 0; - for (const i of Array(num_tests).keys()) { - log(`running mux test ${i}`); - total_mux += await test_mux("https://httpbin.org/get"); - } - total_mux = total_mux / num_tests; + let total_native = 0; + for (const i of Array(num_tests).keys()) { + log(`running native test ${i}`); + total_native += await test_native("https://httpbin.org/get"); + } + total_native = total_native / num_tests; - let total_native = 0; - for (const i of Array(num_tests).keys()) { - log(`running native test ${i}`); - total_native += await test_native("https://httpbin.org/get"); - } - total_native = total_native / num_tests; + log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); + log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); + log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); + total_mux_minus_native += total_mux - total_native; + } + total_mux_minus_native = total_mux_minus_native / num_tests; + log(`total mux - native (${num_tests} tests of ${num_tests} reqs): ${total_mux_minus_native} ms or ${total_mux_minus_native / 1000} s`); + } else if (should_perf_test) { + const num_tests = 10; - log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); - log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); - log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); - } else if (should_ws_test) { - let ws = await epoxy_client.connect_ws( - () => log("opened"), - () => log("closed"), - err => console.error(err), - msg => log(msg), - "wss://echo.websocket.events", - [], - "localhost" - ); - while (true) { - log("sending `data`"); - await ws.send_text("data"); - await (new Promise((res, _) => setTimeout(res, 50))); - } - } else if (should_tls_test) { - let decoder = new TextDecoder(); - let ws = await epoxy_client.connect_tls( - () => log("opened"), - () => log("closed"), - err => console.error(err), - msg => { console.log(msg); log(decoder.decode(msg)) }, - "google.com:443", - ); - await ws.send("GET / HTTP 1.1\r\nHost: google.com\r\nConnection: close\r\n\r\n"); - await (new Promise((res, _) => setTimeout(res, 500))); - await ws.close(); - } else if (should_udp_test) { - let decoder = new TextDecoder(); - // tokio example: `cargo r --example echo-udp -- 127.0.0.1:5000` - let ws = await epoxy_client.connect_udp( - () => log("opened"), - () => log("closed"), - err => console.error(err), - msg => { console.log(msg); log(decoder.decode(msg)) }, - "127.0.0.1:5000", - ); - while (true) { - log("sending `data`"); - await ws.send("data"); - await (new Promise((res, _) => setTimeout(res, 50))); - } - } else if (should_reconnect_test) { - while (true) { - try { - await epoxy_client.fetch("https://httpbin.org/get"); - } catch(e) {console.error(e)} - log("sent req"); - await (new Promise((res, _) => setTimeout(res, 500))); - } - } else if (should_perf2_test) { - const num_outer_tests = 10; - const num_inner_tests = 50; - let total_mux_multi = 0; - for (const _ of Array(num_outer_tests).keys()) { - let total_mux = 0; - await Promise.all([...Array(num_inner_tests).keys()].map(async i => { - log(`running mux test ${i}`); - return await test_mux("https://httpbin.org/get"); - })).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) }); - total_mux = total_mux / num_inner_tests; + let total_mux = 0; + for (const i of Array(num_tests).keys()) { + log(`running mux test ${i}`); + total_mux += await test_mux("https://httpbin.org/get"); + } + total_mux = total_mux / num_tests; - log(`avg mux (${num_inner_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); - total_mux_multi += total_mux; - } - total_mux_multi = total_mux_multi / num_outer_tests; - log(`total avg mux (${num_outer_tests} tests of ${num_inner_tests} reqs): ${total_mux_multi} ms or ${total_mux_multi / 1000} s`); + let total_native = 0; + for (const i of Array(num_tests).keys()) { + log(`running native test ${i}`); + total_native += await test_native("https://httpbin.org/get"); + } + total_native = total_native / num_tests; - } else { - let resp = await epoxy_client.fetch("https://www.example.com/"); - console.log(resp, Object.fromEntries(resp.headers)); - log(await resp.text()); - } - log("done"); + log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); + log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); + log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); + } else if (should_ws_test) { + let handlers = new EpoxyHandlers( + () => log("opened"), + () => log("closed"), + err => console.error(err), + msg => log(`got "${msg}"`) + ); + let ws = await epoxy_client.connect_websocket( + handlers, + "wss://echo.websocket.events", + [], + ); + while (true) { + log("sending `data`"); + await ws.send("data"); + await (new Promise((res, _) => setTimeout(res, 10))); + } + } else if (should_tls_test) { + let decoder = new TextDecoder(); + let handlers = new EpoxyHandlers( + () => log("opened"), + () => log("closed"), + err => console.error(err), + msg => { console.log(msg); console.log(decoder.decode(msg).split("\r\n\r\n")[1].length); log(decoder.decode(msg)) }, + ); + let ws = await epoxy_client.connect_tls( + handlers, + "google.com:443", + ); + await ws.send("GET / HTTP 1.1\r\nHost: google.com\r\nConnection: close\r\n\r\n"); + await (new Promise((res, _) => setTimeout(res, 500))); + await ws.close(); + } else if (should_udp_test) { + let decoder = new TextDecoder(); + let handlers = new EpoxyHandlers( + () => log("opened"), + () => log("closed"), + err => console.error(err), + msg => { console.log(msg); log(decoder.decode(msg)) }, + ); + // tokio example: `cargo r --example echo-udp -- 127.0.0.1:5000` + let ws = await epoxy_client.connect_udp( + handlers, + "127.0.0.1:5000", + ); + while (true) { + log("sending `data`"); + await ws.send("data"); + await (new Promise((res, _) => setTimeout(res, 10))); + } + } else if (should_reconnect_test) { + while (true) { + try { + await epoxy_client.fetch("https://httpbin.org/get"); + } catch (e) { console.error(e) } + log("sent req"); + await (new Promise((res, _) => setTimeout(res, 500))); + } + } else if (should_perf2_test) { + const num_outer_tests = 10; + const num_inner_tests = 50; + let total_mux_multi = 0; + for (const _ of Array(num_outer_tests).keys()) { + let total_mux = 0; + await Promise.all([...Array(num_inner_tests).keys()].map(async i => { + log(`running mux test ${i}`); + return await test_mux("https://httpbin.org/get"); + })).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) }); + total_mux = total_mux / num_inner_tests; + + log(`avg mux (${num_inner_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); + total_mux_multi += total_mux; + } + total_mux_multi = total_mux_multi / num_outer_tests; + log(`total avg mux (${num_outer_tests} tests of ${num_inner_tests} reqs): ${total_mux_multi} ms or ${total_mux_multi / 1000} s`); + + } else { + let resp = await epoxy_client.fetch("https://www.example.com/"); + console.log(resp, Object.fromEntries(resp.headers)); + log(await resp.text()); + } + log("done"); }; diff --git a/client/src/io_stream.rs b/client/src/io_stream.rs index 02cc975..6b4b7c7 100644 --- a/client/src/io_stream.rs +++ b/client/src/io_stream.rs @@ -1,4 +1,4 @@ -use bytes::{BufMut, BytesMut}; +use bytes::{buf::UninitSlice, BufMut, BytesMut}; use futures_util::{ io::WriteHalf, lock::Mutex, stream::SplitSink, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt, }; @@ -18,6 +18,7 @@ pub struct EpoxyIoStream { onerror: Function, } +#[wasm_bindgen] impl EpoxyIoStream { pub(crate) fn connect(stream: ProviderAsyncRW, handlers: EpoxyHandlers) -> Self { let (mut rx, tx) = stream.split(); @@ -32,16 +33,23 @@ impl EpoxyIoStream { let onerror_cloned = onerror.clone(); - // similar to tokio::io::ReaderStream + // similar to tokio_util::io::ReaderStream spawn_local(async move { let mut buf = BytesMut::with_capacity(4096); loop { - match rx.read(buf.as_mut()).await { + match rx + .read(unsafe { + std::mem::transmute::<&mut UninitSlice, &mut [u8]>(buf.chunk_mut()) + }) + .await + { Ok(cnt) => { - unsafe { buf.advance_mut(cnt) }; + if cnt > 0 { + unsafe { buf.advance_mut(cnt) }; - let _ = onmessage - .call1(&JsValue::null(), &Uint8Array::from(buf.split().as_ref())); + let _ = onmessage + .call1(&JsValue::null(), &Uint8Array::from(buf.split().as_ref())); + } } Err(err) => { let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into()); @@ -101,6 +109,7 @@ pub struct EpoxyUdpStream { onerror: Function, } +#[wasm_bindgen] impl EpoxyUdpStream { pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self { let (tx, mut rx) = stream.split(); diff --git a/client/src/lib.rs b/client/src/lib.rs index decd466..750f3c7 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -37,23 +37,23 @@ type HttpBody = http_body_util::Full; #[derive(Debug, Error)] pub enum EpoxyError { - #[error(transparent)] + #[error("Invalid DNS name: {0:?}")] InvalidDnsName(#[from] futures_rustls::rustls::pki_types::InvalidDnsNameError), - #[error(transparent)] + #[error("Wisp: {0:?}")] Wisp(#[from] wisp_mux::WispError), - #[error(transparent)] + #[error("IO: {0:?}")] Io(#[from] std::io::Error), - #[error(transparent)] + #[error("HTTP: {0:?}")] Http(#[from] http::Error), - #[error(transparent)] + #[error("Hyper client: {0:?}")] HyperClient(#[from] hyper_util_wasm::client::legacy::Error), - #[error(transparent)] + #[error("Hyper: {0:?}")] Hyper(#[from] hyper::Error), - #[error(transparent)] + #[error("HTTP ToStr: {0:?}")] ToStr(#[from] http::header::ToStrError), - #[error(transparent)] + #[error("Getrandom: {0:?}")] GetRandom(#[from] getrandom::Error), - #[error(transparent)] + #[error("Fastwebsockets: {0:?}")] FastWebSockets(#[from] fastwebsockets::WebSocketError), #[error("Invalid URL scheme")] @@ -196,13 +196,14 @@ impl EpoxyHandlers { } } -#[wasm_bindgen] +#[wasm_bindgen(inspectable)] pub struct EpoxyClient { stream_provider: Arc, client: Client, - redirect_limit: usize, - user_agent: String, + pub redirect_limit: usize, + #[wasm_bindgen(getter_with_clone)] + pub user_agent: String, } #[wasm_bindgen] @@ -235,6 +236,10 @@ impl EpoxyClient { }) } + pub async fn replace_stream_provider(&self) -> Result<(), EpoxyError> { + self.stream_provider.replace_client().await + } + pub async fn connect_websocket( &self, handlers: EpoxyHandlers, diff --git a/client/src/stream_provider.rs b/client/src/stream_provider.rs index 4bde16f..7e6050b 100644 --- a/client/src/stream_provider.rs +++ b/client/src/stream_provider.rs @@ -4,7 +4,11 @@ use futures_rustls::{ rustls::{ClientConfig, RootCertStore}, TlsConnector, TlsStream, }; -use futures_util::{future::Either, lock::Mutex, AsyncRead, AsyncWrite, Future}; +use futures_util::{ + future::Either, + lock::{Mutex, MutexGuard}, + AsyncRead, AsyncWrite, Future, +}; use hyper_util_wasm::client::legacy::connect::{Connected, Connection}; use js_sys::{Array, Reflect, Uint8Array}; use pin_project_lite::pin_project; @@ -14,7 +18,7 @@ use wasm_bindgen::{JsCast, JsValue}; use wasm_bindgen_futures::spawn_local; use wisp_mux::{ extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder}, - ClientMux, IoStream, MuxStreamIo, StreamType, WispError, + ClientMux, IoStream, MuxStreamIo, StreamType, }; use crate::{ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError}; @@ -75,7 +79,10 @@ impl StreamProvider { }) } - async fn create_client(&self) -> Result<(), EpoxyError> { + async fn create_client( + &self, + mut locked: MutexGuard<'_, Option>, + ) -> Result<(), EpoxyError> { let extensions_vec: Vec> = vec![Box::new(UdpProtocolExtensionBuilder())]; let extensions = if self.wisp_v2 { @@ -93,7 +100,7 @@ impl StreamProvider { } else { client.with_no_required_extensions() }; - self.current_client.lock().await.replace(mux); + locked.replace(mux); let current_client = self.current_client.clone(); spawn_local(async move { fut.await; @@ -102,6 +109,10 @@ impl StreamProvider { Ok(()) } + pub async fn replace_client(&self) -> Result<(), EpoxyError> { + self.create_client(self.current_client.lock().await).await + } + pub async fn get_stream( &self, stream_type: StreamType, @@ -109,13 +120,14 @@ impl StreamProvider { port: u16, ) -> Result { Box::pin(async { - if let Some(mux) = self.current_client.lock().await.as_ref() { + let locked = self.current_client.lock().await; + if let Some(mux) = locked.as_ref() { Ok(mux .client_new_stream(stream_type, host, port) .await? .into_io()) } else { - self.create_client().await?; + self.create_client(locked).await?; self.get_stream(stream_type, host, port).await } }) @@ -231,12 +243,16 @@ impl Service for StreamProviderService { let provider = self.0.clone(); Box::pin(async move { let scheme = req.scheme_str().ok_or(EpoxyError::InvalidUrlScheme)?; - let host = req.host().ok_or(WispError::UriHasNoHost)?.to_string(); - let port = req.port_u16().ok_or(WispError::UriHasNoPort)?; + let host = req.host().ok_or(EpoxyError::NoUrlHost)?.to_string(); + let port = req.port_u16().map(Ok).unwrap_or_else(|| match scheme { + "https" | "wss" => Ok(443), + "http" | "ws" => Ok(80), + _ => Err(EpoxyError::NoUrlPort), + })?; Ok(HyperIo { inner: match scheme { - "https" => Either::Left(provider.get_tls_stream(host, port).await?), - "http" => { + "https" | "wss" => Either::Left(provider.get_tls_stream(host, port).await?), + "http" | "ws" => { Either::Right(provider.get_asyncread(StreamType::Tcp, host, port).await?) } _ => return Err(EpoxyError::InvalidUrlScheme), diff --git a/client/src/websocket.rs b/client/src/websocket.rs index 677a580..a08a15e 100644 --- a/client/src/websocket.rs +++ b/client/src/websocket.rs @@ -30,6 +30,7 @@ pub struct EpoxyWebSocket { onerror: Function, } +#[wasm_bindgen] impl EpoxyWebSocket { pub(crate) async fn connect( client: &EpoxyClient,