get it actually working

This commit is contained in:
Toshit Chawda 2024-06-12 13:42:26 -07:00
parent 177a0d2167
commit 5ec8b3b6de
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 291 additions and 248 deletions

View file

@ -2,245 +2,257 @@ import epoxy from "./pkg/epoxy-module-bundled.js";
import CERTS from "./pkg/certs-module.js"; import CERTS from "./pkg/certs-module.js";
onmessage = async (msg) => { onmessage = async (msg) => {
console.debug("recieved demo:", msg); console.debug("recieved demo:", msg);
let [ let [
should_feature_test, should_feature_test,
should_multiparallel_test, should_multiparallel_test,
should_parallel_test, should_parallel_test,
should_multiperf_test, should_multiperf_test,
should_perf_test, should_perf_test,
should_ws_test, should_ws_test,
should_tls_test, should_tls_test,
should_udp_test, should_udp_test,
should_reconnect_test, should_reconnect_test,
should_perf2_test, should_perf2_test,
] = msg.data; ] = msg.data;
console.log( console.log(
"%cWASM is significantly slower with DevTools open!", "%cWASM is significantly slower with DevTools open!",
"color:red;font-size:3rem;font-weight:bold" "color:red;font-size:3rem;font-weight:bold"
); );
const log = (str) => { const log = (str) => {
console.log(str); console.log(str);
postMessage(str); postMessage(str);
} }
const plog = (str) => { const plog = (str) => {
console.log(str); console.log(str);
postMessage(JSON.stringify(str, null, 4)); postMessage(JSON.stringify(str, null, 4));
} }
const { EpoxyClient } = await epoxy(); const { EpoxyClient, EpoxyClientOptions, EpoxyHandlers } = await epoxy();
console.log("certs:", CERTS); console.log("certs:", CERTS);
const tconn0 = performance.now(); let epoxy_client_options = new EpoxyClientOptions();
// args: websocket url, user agent, redirect limit, certs epoxy_client_options.user_agent = navigator.userAgent;
let epoxy_client = await new EpoxyClient("ws://localhost:4000", navigator.userAgent, 10, CERTS);
const tconn1 = performance.now();
log(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`);
// epoxy classes are inspectable let epoxy_client = new EpoxyClient("ws://localhost:4000", CERTS, epoxy_client_options);
console.log(epoxy_client);
// you can change the user agent and redirect limit in JS
epoxy_client.redirectLimit = 15;
const test_mux = async (url) => { const tconn0 = performance.now();
const t0 = performance.now(); await epoxy_client.replace_stream_provider();
await epoxy_client.fetch(url); const tconn1 = performance.now();
const t1 = performance.now(); log(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`);
return t1 - t0;
};
const test_native = async (url) => { // epoxy classes are inspectable
const t0 = performance.now(); console.log(epoxy_client);
await fetch(url, { cache: "no-store" }); // you can change the user agent and redirect limit in JS
const t1 = performance.now(); epoxy_client.redirect_limit = 15;
return t1 - t0;
};
if (should_feature_test) { const test_mux = async (url) => {
let formdata = new FormData(); const t0 = performance.now();
formdata.append("a", "b"); await epoxy_client.fetch(url);
for (const url of [ const t1 = performance.now();
["https://httpbin.org/get", {}], return t1 - t0;
["https://httpbin.org/gzip", {}], };
["https://httpbin.org/brotli", {}],
["https://httpbin.org/redirect/11", {}],
["https://httpbin.org/redirect/1", { redirect: "manual" }],
["https://httpbin.org/post", { method: "POST", body: new URLSearchParams("a=b") }],
["https://httpbin.org/post", { method: "POST", body: formdata }],
["https://httpbin.org/post", { method: "POST", body: "a" }],
["https://httpbin.org/post", { method: "POST", body: (new TextEncoder()).encode("abc") }],
["https://httpbin.org/get", { headers: {"a": "b", "b": "c"} }],
["https://httpbin.org/get", { headers: new Headers({"a": "b", "b": "c"}) }]
]) {
let resp = await epoxy_client.fetch(url[0], url[1]);
console.warn(url, resp, Object.fromEntries(resp.headers));
log(await resp.text());
}
} else if (should_multiparallel_test) {
const num_tests = 10;
let total_mux_minus_native = 0;
for (const _ of Array(num_tests).keys()) {
let total_mux = 0;
await Promise.all([...Array(num_tests).keys()].map(async i => {
log(`running mux test ${i}`);
return await test_mux("https://httpbin.org/get");
})).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) });
total_mux = total_mux / num_tests;
let total_native = 0; const test_native = async (url) => {
await Promise.all([...Array(num_tests).keys()].map(async i => { const t0 = performance.now();
log(`running native test ${i}`); await fetch(url, { cache: "no-store" });
return await test_native("https://httpbin.org/get"); const t1 = performance.now();
})).then((vals) => { total_native = vals.reduce((acc, x) => acc + x, 0) }); return t1 - t0;
total_native = total_native / num_tests; };
log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); if (should_feature_test) {
log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); let formdata = new FormData();
log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); formdata.append("a", "b");
total_mux_minus_native += total_mux - total_native; for (const url of [
} ["https://httpbin.org/get", {}],
total_mux_minus_native = total_mux_minus_native / num_tests; ["https://httpbin.org/gzip", {}],
log(`total mux - native (${num_tests} tests of ${num_tests} reqs): ${total_mux_minus_native} ms or ${total_mux_minus_native / 1000} s`); ["https://httpbin.org/brotli", {}],
} else if (should_parallel_test) { ["https://httpbin.org/redirect/11", {}],
const num_tests = 10; ["https://httpbin.org/redirect/1", { redirect: "manual" }],
["https://httpbin.org/post", { method: "POST", body: new URLSearchParams("a=b") }],
["https://httpbin.org/post", { method: "POST", body: formdata }],
["https://httpbin.org/post", { method: "POST", body: "a" }],
["https://httpbin.org/post", { method: "POST", body: (new TextEncoder()).encode("abc") }],
["https://httpbin.org/get", { headers: { "a": "b", "b": "c" } }],
["https://httpbin.org/get", { headers: new Headers({ "a": "b", "b": "c" }) }]
]) {
let resp = await epoxy_client.fetch(url[0], url[1]);
console.warn(url, resp, Object.fromEntries(resp.headers));
log(await resp.text());
}
} else if (should_multiparallel_test) {
const num_tests = 10;
let total_mux_minus_native = 0;
for (const _ of Array(num_tests).keys()) {
let total_mux = 0;
await Promise.all([...Array(num_tests).keys()].map(async i => {
log(`running mux test ${i}`);
return await test_mux("https://httpbin.org/get");
})).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) });
total_mux = total_mux / num_tests;
let total_mux = 0; let total_native = 0;
await Promise.all([...Array(num_tests).keys()].map(async i => { await Promise.all([...Array(num_tests).keys()].map(async i => {
log(`running mux test ${i}`); log(`running native test ${i}`);
return await test_mux("https://httpbin.org/get"); return await test_native("https://httpbin.org/get");
})).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) }); })).then((vals) => { total_native = vals.reduce((acc, x) => acc + x, 0) });
total_mux = total_mux / num_tests; total_native = total_native / num_tests;
let total_native = 0; log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`);
await Promise.all([...Array(num_tests).keys()].map(async i => { log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`);
log(`running native test ${i}`); log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
return await test_native("https://httpbin.org/get"); total_mux_minus_native += total_mux - total_native;
})).then((vals) => { total_native = vals.reduce((acc, x) => acc + x, 0) }); }
total_native = total_native / num_tests; total_mux_minus_native = total_mux_minus_native / num_tests;
log(`total mux - native (${num_tests} tests of ${num_tests} reqs): ${total_mux_minus_native} ms or ${total_mux_minus_native / 1000} s`);
} else if (should_parallel_test) {
const num_tests = 10;
log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); let total_mux = 0;
log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); await Promise.all([...Array(num_tests).keys()].map(async i => {
log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); log(`running mux test ${i}`);
} else if (should_multiperf_test) { return await test_mux("https://httpbin.org/get");
const num_tests = 10; })).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) });
let total_mux_minus_native = 0; total_mux = total_mux / num_tests;
for (const _ of Array(num_tests).keys()) {
let total_mux = 0;
for (const i of Array(num_tests).keys()) {
log(`running mux test ${i}`);
total_mux += await test_mux("https://httpbin.org/get");
}
total_mux = total_mux / num_tests;
let total_native = 0; let total_native = 0;
for (const i of Array(num_tests).keys()) { await Promise.all([...Array(num_tests).keys()].map(async i => {
log(`running native test ${i}`); log(`running native test ${i}`);
total_native += await test_native("https://httpbin.org/get"); return await test_native("https://httpbin.org/get");
} })).then((vals) => { total_native = vals.reduce((acc, x) => acc + x, 0) });
total_native = total_native / num_tests; total_native = total_native / num_tests;
log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`);
log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`);
log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
total_mux_minus_native += total_mux - total_native; } else if (should_multiperf_test) {
} const num_tests = 10;
total_mux_minus_native = total_mux_minus_native / num_tests; let total_mux_minus_native = 0;
log(`total mux - native (${num_tests} tests of ${num_tests} reqs): ${total_mux_minus_native} ms or ${total_mux_minus_native / 1000} s`); for (const _ of Array(num_tests).keys()) {
} else if (should_perf_test) { let total_mux = 0;
const num_tests = 10; for (const i of Array(num_tests).keys()) {
log(`running mux test ${i}`);
total_mux += await test_mux("https://httpbin.org/get");
}
total_mux = total_mux / num_tests;
let total_mux = 0; let total_native = 0;
for (const i of Array(num_tests).keys()) { for (const i of Array(num_tests).keys()) {
log(`running mux test ${i}`); log(`running native test ${i}`);
total_mux += await test_mux("https://httpbin.org/get"); total_native += await test_native("https://httpbin.org/get");
} }
total_mux = total_mux / num_tests; total_native = total_native / num_tests;
let total_native = 0; log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`);
for (const i of Array(num_tests).keys()) { log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`);
log(`running native test ${i}`); log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
total_native += await test_native("https://httpbin.org/get"); total_mux_minus_native += total_mux - total_native;
} }
total_native = total_native / num_tests; total_mux_minus_native = total_mux_minus_native / num_tests;
log(`total mux - native (${num_tests} tests of ${num_tests} reqs): ${total_mux_minus_native} ms or ${total_mux_minus_native / 1000} s`);
} else if (should_perf_test) {
const num_tests = 10;
log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); let total_mux = 0;
log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); for (const i of Array(num_tests).keys()) {
log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); log(`running mux test ${i}`);
} else if (should_ws_test) { total_mux += await test_mux("https://httpbin.org/get");
let ws = await epoxy_client.connect_ws( }
() => log("opened"), total_mux = total_mux / num_tests;
() => log("closed"),
err => console.error(err),
msg => log(msg),
"wss://echo.websocket.events",
[],
"localhost"
);
while (true) {
log("sending `data`");
await ws.send_text("data");
await (new Promise((res, _) => setTimeout(res, 50)));
}
} else if (should_tls_test) {
let decoder = new TextDecoder();
let ws = await epoxy_client.connect_tls(
() => log("opened"),
() => log("closed"),
err => console.error(err),
msg => { console.log(msg); log(decoder.decode(msg)) },
"google.com:443",
);
await ws.send("GET / HTTP 1.1\r\nHost: google.com\r\nConnection: close\r\n\r\n");
await (new Promise((res, _) => setTimeout(res, 500)));
await ws.close();
} else if (should_udp_test) {
let decoder = new TextDecoder();
// tokio example: `cargo r --example echo-udp -- 127.0.0.1:5000`
let ws = await epoxy_client.connect_udp(
() => log("opened"),
() => log("closed"),
err => console.error(err),
msg => { console.log(msg); log(decoder.decode(msg)) },
"127.0.0.1:5000",
);
while (true) {
log("sending `data`");
await ws.send("data");
await (new Promise((res, _) => setTimeout(res, 50)));
}
} else if (should_reconnect_test) {
while (true) {
try {
await epoxy_client.fetch("https://httpbin.org/get");
} catch(e) {console.error(e)}
log("sent req");
await (new Promise((res, _) => setTimeout(res, 500)));
}
} else if (should_perf2_test) {
const num_outer_tests = 10;
const num_inner_tests = 50;
let total_mux_multi = 0;
for (const _ of Array(num_outer_tests).keys()) {
let total_mux = 0;
await Promise.all([...Array(num_inner_tests).keys()].map(async i => {
log(`running mux test ${i}`);
return await test_mux("https://httpbin.org/get");
})).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) });
total_mux = total_mux / num_inner_tests;
log(`avg mux (${num_inner_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); let total_native = 0;
total_mux_multi += total_mux; for (const i of Array(num_tests).keys()) {
} log(`running native test ${i}`);
total_mux_multi = total_mux_multi / num_outer_tests; total_native += await test_native("https://httpbin.org/get");
log(`total avg mux (${num_outer_tests} tests of ${num_inner_tests} reqs): ${total_mux_multi} ms or ${total_mux_multi / 1000} s`); }
total_native = total_native / num_tests;
} else { log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`);
let resp = await epoxy_client.fetch("https://www.example.com/"); log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`);
console.log(resp, Object.fromEntries(resp.headers)); log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
log(await resp.text()); } else if (should_ws_test) {
} let handlers = new EpoxyHandlers(
log("done"); () => log("opened"),
() => log("closed"),
err => console.error(err),
msg => log(`got "${msg}"`)
);
let ws = await epoxy_client.connect_websocket(
handlers,
"wss://echo.websocket.events",
[],
);
while (true) {
log("sending `data`");
await ws.send("data");
await (new Promise((res, _) => setTimeout(res, 10)));
}
} else if (should_tls_test) {
let decoder = new TextDecoder();
let handlers = new EpoxyHandlers(
() => log("opened"),
() => log("closed"),
err => console.error(err),
msg => { console.log(msg); console.log(decoder.decode(msg).split("\r\n\r\n")[1].length); log(decoder.decode(msg)) },
);
let ws = await epoxy_client.connect_tls(
handlers,
"google.com:443",
);
await ws.send("GET / HTTP 1.1\r\nHost: google.com\r\nConnection: close\r\n\r\n");
await (new Promise((res, _) => setTimeout(res, 500)));
await ws.close();
} else if (should_udp_test) {
let decoder = new TextDecoder();
let handlers = new EpoxyHandlers(
() => log("opened"),
() => log("closed"),
err => console.error(err),
msg => { console.log(msg); log(decoder.decode(msg)) },
);
// tokio example: `cargo r --example echo-udp -- 127.0.0.1:5000`
let ws = await epoxy_client.connect_udp(
handlers,
"127.0.0.1:5000",
);
while (true) {
log("sending `data`");
await ws.send("data");
await (new Promise((res, _) => setTimeout(res, 10)));
}
} else if (should_reconnect_test) {
while (true) {
try {
await epoxy_client.fetch("https://httpbin.org/get");
} catch (e) { console.error(e) }
log("sent req");
await (new Promise((res, _) => setTimeout(res, 500)));
}
} else if (should_perf2_test) {
const num_outer_tests = 10;
const num_inner_tests = 50;
let total_mux_multi = 0;
for (const _ of Array(num_outer_tests).keys()) {
let total_mux = 0;
await Promise.all([...Array(num_inner_tests).keys()].map(async i => {
log(`running mux test ${i}`);
return await test_mux("https://httpbin.org/get");
})).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) });
total_mux = total_mux / num_inner_tests;
log(`avg mux (${num_inner_tests}) took ${total_mux} ms or ${total_mux / 1000} s`);
total_mux_multi += total_mux;
}
total_mux_multi = total_mux_multi / num_outer_tests;
log(`total avg mux (${num_outer_tests} tests of ${num_inner_tests} reqs): ${total_mux_multi} ms or ${total_mux_multi / 1000} s`);
} else {
let resp = await epoxy_client.fetch("https://www.example.com/");
console.log(resp, Object.fromEntries(resp.headers));
log(await resp.text());
}
log("done");
}; };

View file

@ -1,4 +1,4 @@
use bytes::{BufMut, BytesMut}; use bytes::{buf::UninitSlice, BufMut, BytesMut};
use futures_util::{ use futures_util::{
io::WriteHalf, lock::Mutex, stream::SplitSink, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt, io::WriteHalf, lock::Mutex, stream::SplitSink, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt,
}; };
@ -18,6 +18,7 @@ pub struct EpoxyIoStream {
onerror: Function, onerror: Function,
} }
#[wasm_bindgen]
impl EpoxyIoStream { impl EpoxyIoStream {
pub(crate) fn connect(stream: ProviderAsyncRW, handlers: EpoxyHandlers) -> Self { pub(crate) fn connect(stream: ProviderAsyncRW, handlers: EpoxyHandlers) -> Self {
let (mut rx, tx) = stream.split(); let (mut rx, tx) = stream.split();
@ -32,16 +33,23 @@ impl EpoxyIoStream {
let onerror_cloned = onerror.clone(); let onerror_cloned = onerror.clone();
// similar to tokio::io::ReaderStream // similar to tokio_util::io::ReaderStream
spawn_local(async move { spawn_local(async move {
let mut buf = BytesMut::with_capacity(4096); let mut buf = BytesMut::with_capacity(4096);
loop { loop {
match rx.read(buf.as_mut()).await { match rx
.read(unsafe {
std::mem::transmute::<&mut UninitSlice, &mut [u8]>(buf.chunk_mut())
})
.await
{
Ok(cnt) => { Ok(cnt) => {
unsafe { buf.advance_mut(cnt) }; if cnt > 0 {
unsafe { buf.advance_mut(cnt) };
let _ = onmessage let _ = onmessage
.call1(&JsValue::null(), &Uint8Array::from(buf.split().as_ref())); .call1(&JsValue::null(), &Uint8Array::from(buf.split().as_ref()));
}
} }
Err(err) => { Err(err) => {
let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into()); let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into());
@ -101,6 +109,7 @@ pub struct EpoxyUdpStream {
onerror: Function, onerror: Function,
} }
#[wasm_bindgen]
impl EpoxyUdpStream { impl EpoxyUdpStream {
pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self { pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self {
let (tx, mut rx) = stream.split(); let (tx, mut rx) = stream.split();

View file

@ -37,23 +37,23 @@ type HttpBody = http_body_util::Full<Bytes>;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum EpoxyError { pub enum EpoxyError {
#[error(transparent)] #[error("Invalid DNS name: {0:?}")]
InvalidDnsName(#[from] futures_rustls::rustls::pki_types::InvalidDnsNameError), InvalidDnsName(#[from] futures_rustls::rustls::pki_types::InvalidDnsNameError),
#[error(transparent)] #[error("Wisp: {0:?}")]
Wisp(#[from] wisp_mux::WispError), Wisp(#[from] wisp_mux::WispError),
#[error(transparent)] #[error("IO: {0:?}")]
Io(#[from] std::io::Error), Io(#[from] std::io::Error),
#[error(transparent)] #[error("HTTP: {0:?}")]
Http(#[from] http::Error), Http(#[from] http::Error),
#[error(transparent)] #[error("Hyper client: {0:?}")]
HyperClient(#[from] hyper_util_wasm::client::legacy::Error), HyperClient(#[from] hyper_util_wasm::client::legacy::Error),
#[error(transparent)] #[error("Hyper: {0:?}")]
Hyper(#[from] hyper::Error), Hyper(#[from] hyper::Error),
#[error(transparent)] #[error("HTTP ToStr: {0:?}")]
ToStr(#[from] http::header::ToStrError), ToStr(#[from] http::header::ToStrError),
#[error(transparent)] #[error("Getrandom: {0:?}")]
GetRandom(#[from] getrandom::Error), GetRandom(#[from] getrandom::Error),
#[error(transparent)] #[error("Fastwebsockets: {0:?}")]
FastWebSockets(#[from] fastwebsockets::WebSocketError), FastWebSockets(#[from] fastwebsockets::WebSocketError),
#[error("Invalid URL scheme")] #[error("Invalid URL scheme")]
@ -196,13 +196,14 @@ impl EpoxyHandlers {
} }
} }
#[wasm_bindgen] #[wasm_bindgen(inspectable)]
pub struct EpoxyClient { pub struct EpoxyClient {
stream_provider: Arc<StreamProvider>, stream_provider: Arc<StreamProvider>,
client: Client<StreamProviderService, HttpBody>, client: Client<StreamProviderService, HttpBody>,
redirect_limit: usize, pub redirect_limit: usize,
user_agent: String, #[wasm_bindgen(getter_with_clone)]
pub user_agent: String,
} }
#[wasm_bindgen] #[wasm_bindgen]
@ -235,6 +236,10 @@ impl EpoxyClient {
}) })
} }
pub async fn replace_stream_provider(&self) -> Result<(), EpoxyError> {
self.stream_provider.replace_client().await
}
pub async fn connect_websocket( pub async fn connect_websocket(
&self, &self,
handlers: EpoxyHandlers, handlers: EpoxyHandlers,

View file

@ -4,7 +4,11 @@ use futures_rustls::{
rustls::{ClientConfig, RootCertStore}, rustls::{ClientConfig, RootCertStore},
TlsConnector, TlsStream, TlsConnector, TlsStream,
}; };
use futures_util::{future::Either, lock::Mutex, AsyncRead, AsyncWrite, Future}; use futures_util::{
future::Either,
lock::{Mutex, MutexGuard},
AsyncRead, AsyncWrite, Future,
};
use hyper_util_wasm::client::legacy::connect::{Connected, Connection}; use hyper_util_wasm::client::legacy::connect::{Connected, Connection};
use js_sys::{Array, Reflect, Uint8Array}; use js_sys::{Array, Reflect, Uint8Array};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
@ -14,7 +18,7 @@ use wasm_bindgen::{JsCast, JsValue};
use wasm_bindgen_futures::spawn_local; use wasm_bindgen_futures::spawn_local;
use wisp_mux::{ use wisp_mux::{
extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder}, extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder},
ClientMux, IoStream, MuxStreamIo, StreamType, WispError, ClientMux, IoStream, MuxStreamIo, StreamType,
}; };
use crate::{ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError}; use crate::{ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError};
@ -75,7 +79,10 @@ impl StreamProvider {
}) })
} }
async fn create_client(&self) -> Result<(), EpoxyError> { async fn create_client(
&self,
mut locked: MutexGuard<'_, Option<ClientMux>>,
) -> Result<(), EpoxyError> {
let extensions_vec: Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>> = let extensions_vec: Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>> =
vec![Box::new(UdpProtocolExtensionBuilder())]; vec![Box::new(UdpProtocolExtensionBuilder())];
let extensions = if self.wisp_v2 { let extensions = if self.wisp_v2 {
@ -93,7 +100,7 @@ impl StreamProvider {
} else { } else {
client.with_no_required_extensions() client.with_no_required_extensions()
}; };
self.current_client.lock().await.replace(mux); locked.replace(mux);
let current_client = self.current_client.clone(); let current_client = self.current_client.clone();
spawn_local(async move { spawn_local(async move {
fut.await; fut.await;
@ -102,6 +109,10 @@ impl StreamProvider {
Ok(()) Ok(())
} }
pub async fn replace_client(&self) -> Result<(), EpoxyError> {
self.create_client(self.current_client.lock().await).await
}
pub async fn get_stream( pub async fn get_stream(
&self, &self,
stream_type: StreamType, stream_type: StreamType,
@ -109,13 +120,14 @@ impl StreamProvider {
port: u16, port: u16,
) -> Result<ProviderUnencryptedStream, EpoxyError> { ) -> Result<ProviderUnencryptedStream, EpoxyError> {
Box::pin(async { Box::pin(async {
if let Some(mux) = self.current_client.lock().await.as_ref() { let locked = self.current_client.lock().await;
if let Some(mux) = locked.as_ref() {
Ok(mux Ok(mux
.client_new_stream(stream_type, host, port) .client_new_stream(stream_type, host, port)
.await? .await?
.into_io()) .into_io())
} else { } else {
self.create_client().await?; self.create_client(locked).await?;
self.get_stream(stream_type, host, port).await self.get_stream(stream_type, host, port).await
} }
}) })
@ -231,12 +243,16 @@ impl Service<hyper::Uri> for StreamProviderService {
let provider = self.0.clone(); let provider = self.0.clone();
Box::pin(async move { Box::pin(async move {
let scheme = req.scheme_str().ok_or(EpoxyError::InvalidUrlScheme)?; let scheme = req.scheme_str().ok_or(EpoxyError::InvalidUrlScheme)?;
let host = req.host().ok_or(WispError::UriHasNoHost)?.to_string(); let host = req.host().ok_or(EpoxyError::NoUrlHost)?.to_string();
let port = req.port_u16().ok_or(WispError::UriHasNoPort)?; let port = req.port_u16().map(Ok).unwrap_or_else(|| match scheme {
"https" | "wss" => Ok(443),
"http" | "ws" => Ok(80),
_ => Err(EpoxyError::NoUrlPort),
})?;
Ok(HyperIo { Ok(HyperIo {
inner: match scheme { inner: match scheme {
"https" => Either::Left(provider.get_tls_stream(host, port).await?), "https" | "wss" => Either::Left(provider.get_tls_stream(host, port).await?),
"http" => { "http" | "ws" => {
Either::Right(provider.get_asyncread(StreamType::Tcp, host, port).await?) Either::Right(provider.get_asyncread(StreamType::Tcp, host, port).await?)
} }
_ => return Err(EpoxyError::InvalidUrlScheme), _ => return Err(EpoxyError::InvalidUrlScheme),

View file

@ -30,6 +30,7 @@ pub struct EpoxyWebSocket {
onerror: Function, onerror: Function,
} }
#[wasm_bindgen]
impl EpoxyWebSocket { impl EpoxyWebSocket {
pub(crate) async fn connect( pub(crate) async fn connect(
client: &EpoxyClient, client: &EpoxyClient,