diff --git a/client/Cargo.toml b/client/Cargo.toml index acdb819..55f64ff 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -30,7 +30,7 @@ wasm-streams = "0.4.0" either = "1.9.0" tokio-util = { version = "0.7.10", features = ["io"] } async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] } -fastwebsockets = { version = "0.6.0", features=[]} +fastwebsockets = { version = "0.6.0" } rand = "0.8.5" base64 = "0.21.7" diff --git a/client/build.sh b/client/build.sh index e3e9234..03c053b 100755 --- a/client/build.sh +++ b/client/build.sh @@ -11,7 +11,7 @@ wasm-bindgen --weak-refs --no-typescript --target no-modules --out-dir out/ ../t echo "[ws] bindgen finished" mv out/wstcp_client_bg.wasm out/wstcp_client_unoptimized.wasm -wasm-opt out/wstcp_client_unoptimized.wasm -o out/wstcp_client_bg.wasm +time wasm-opt -O4 out/wstcp_client_unoptimized.wasm -o out/wstcp_client_bg.wasm echo "[ws] optimized" AUTOGENERATED_SOURCE=$(<"out/wstcp_client.js") diff --git a/client/disable-the-fucking-borrow-checker.mjs b/client/disable-the-fucking-borrow-checker.mjs deleted file mode 100644 index f93d512..0000000 --- a/client/disable-the-fucking-borrow-checker.mjs +++ /dev/null @@ -1,16 +0,0 @@ -import fs from "fs"; -import path from "path"; -import binaryen from "binaryen"; -import { fileURLToPath } from 'url'; -const __filename = fileURLToPath(import.meta.url); - -const __dirname = path.dirname(__filename); -let fp = path.resolve(__dirname, './wat.wat'); -const originBuffer = fs.readFileSync(fp).toString(); - -// const wasm = binaryen.readBinary(originBuffer); -const wast = originBuffer - .replace(/\(br_if \$label\$1[\s\n]+?\(i32.eq\n[\s\S\n]+?i32.const -1\)[\s\n]+\)[\s\n]+\)/g, ''); -// const distBuffer = binaryen.parseText(wast).emitBinary(); - -fs.writeFileSync(fp, wast); diff --git a/client/src/lib.rs b/client/src/lib.rs index 604d0ac..d1fc2b2 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -3,15 +3,13 @@ mod utils; mod tokioio; mod wrappers; +mod websocket; -use base64::{engine::general_purpose::STANDARD, Engine}; -use fastwebsockets::{Frame, OpCode, Payload, Role, WebSocket}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokioio::TokioIo; use utils::{ReplaceErr, UriExt}; use wrappers::{IncomingBody, WsStreamWrapper}; -use std::{io::Read, ptr::null_mut, str::from_utf8, sync::Arc}; +use std::sync::Arc; use async_compression::tokio::bufread as async_comp; use bytes::Bytes; @@ -19,10 +17,10 @@ use futures_util::StreamExt; use http::{uri, HeaderName, HeaderValue, Request, Response}; use hyper::{ body::Incoming, - client::conn::http1::{handshake, Builder}, + client::conn::http1::Builder, Uri, }; -use js_sys::{Array, Function, Object, Reflect, Uint8Array}; +use js_sys::{Array, Object, Reflect, Uint8Array}; use penguin_mux_wasm::{Multiplexor, MuxStream}; use tokio_rustls::{client::TlsStream, rustls, rustls::RootCertStore, TlsConnector}; use tokio_util::{ @@ -47,7 +45,7 @@ enum EpxCompression { type EpxTlsStream = TlsStream>; type EpxUnencryptedStream = MuxStream; -type EpxStream = Either; +type EpxStream = Either; async fn send_req( req: http::Request, @@ -115,171 +113,6 @@ async fn start() { utils::set_panic_hook(); } -#[wasm_bindgen] -pub struct WsWebSocket { - onopen: Function, - onclose: Function, - onerror: Function, - onmessage: Function, - ws: Option>, -} - -async fn wtf(iop: *mut EpxStream) { - let mut t = false; - unsafe { - let io = &mut *iop; - let mut v = vec![]; - loop { - let r = io.read_u8().await; - - if let Ok(u) = r { - v.push(u); - if t && u as char == '\r' { - let r = io.read_u8().await; - break; - } - if u as char == '\n' { - t = true; - } else { - t = false; - } - } else { - break; - } - } - log!("{}", &from_utf8(&v).unwrap().to_string()); - } -} - -#[wasm_bindgen] -impl WsWebSocket { - #[wasm_bindgen(constructor)] - pub fn new( - onopen: Function, - onclose: Function, - onmessage: Function, - onerror: Function, - ) -> Result { - Ok(Self { - onopen, - onclose, - onerror, - onmessage, - ws: None, - }) - } - - #[wasm_bindgen] - pub async fn connect( - &mut self, - tcp: &mut EpoxyClient, - url: String, - protocols: Vec, - origin: String, - ) -> Result<(), JsError> { - self.onopen.call0(&Object::default()); - let uri = url.parse::().replace_err("Failed to parse URL")?; - let mut io = tcp.get_http_io(&uri).await?; - - let r: [u8; 16] = rand::random(); - let key = STANDARD.encode(&r); - - let pathstr = if let Some(p) = uri.path_and_query() { - p.to_string() - } else { - uri.path().to_string() - }; - - io.write(format!("GET {} HTTP/1.1\r\n", pathstr).as_bytes()) - .await; - io.write(b"Sec-WebSocket-Version: 13\r\n").await; - io.write(format!("Sec-WebSocket-Key: {}\r\n", key).as_bytes()) - .await; - io.write(b"Connection: Upgrade\r\n").await; - io.write(b"Upgrade: websocket\r\n").await; - io.write(format!("Origin: {}\r\n", origin).as_bytes()).await; - io.write(format!("Host: {}\r\n", uri.host().unwrap()).as_bytes()) - .await; - io.write(b"\r\n").await; - - let iop: *mut EpxStream = &mut io; - wtf(iop).await; - - let mut ws = WebSocket::after_handshake(io, fastwebsockets::Role::Client); - ws.set_writev(false); - ws.set_auto_close(true); - ws.set_auto_pong(true); - - self.ws = Some(ws); - - Ok(()) - } - - #[wasm_bindgen] - pub fn ptr(&mut self) -> *mut WsWebSocket { - self - } - - #[wasm_bindgen] - pub async fn send(&mut self, payload: String) -> Result<(), JsError> { - let Some(ws) = self.ws.as_mut() else { - return Err(JsError::new("Tried to send() before handshake!")); - }; - ws.write_frame(Frame::new( - true, - OpCode::Text, - None, - Payload::Owned(payload.as_bytes().to_vec()), - )) - .await - .unwrap(); - // .replace_err("Failed to send WsWebSocket payload")?; - Ok(()) - } - - #[wasm_bindgen] - pub async fn recv(&mut self) -> Result<(), JsError> { - let Some(ws) = self.ws.as_mut() else { - return Err(JsError::new("Tried to recv() before handshake!")); - }; - loop { - let Ok(frame) = ws.read_frame().await else { - break; - }; - - match frame.opcode { - OpCode::Text => { - if let Ok(str) = from_utf8(&frame.payload) { - self.onmessage - .call1(&JsValue::null(), &jval!(str)) - .replace_err("missing onmessage handler")?; - } - } - OpCode::Binary => { - self.onmessage - .call1( - &JsValue::null(), - &jval!(Uint8Array::from(frame.payload.to_vec().as_slice())), - ) - .replace_err("missing onmessage handler")?; - } - - _ => panic!("unknown opcode {:?}", frame.opcode), - } - } - self.onclose - .call0(&JsValue::null()) - .replace_err("missing onclose handler")?; - Ok(()) - } -} - -#[wasm_bindgen] -pub async fn send(pointer: *mut WsWebSocket, payload: String) -> Result<(), JsError> { - let tcp = unsafe { &mut *pointer }; - tcp.send(payload).await -} - #[wasm_bindgen] pub struct EpoxyClient { rustls_config: Arc, diff --git a/client/src/web/index.js b/client/src/web/index.js index e422500..6530f89 100644 --- a/client/src/web/index.js +++ b/client/src/web/index.js @@ -1,69 +1,73 @@ (async () => { - console.log( - "%cWASM is significantly slower with DevTools open!", - "color:red;font-size:2rem;font-weight:bold" - ); + console.log( + "%cWASM is significantly slower with DevTools open!", + "color:red;font-size:2rem;font-weight:bold" + ); - const should_feature_test = (new URL(window.location.href)).searchParams.has("feature_test"); - const should_perf_test = (new URL(window.location.href)).searchParams.has("perf_test"); + const should_feature_test = (new URL(window.location.href)).searchParams.has("feature_test"); + const should_perf_test = (new URL(window.location.href)).searchParams.has("perf_test"); + const should_ws_test = (new URL(window.location.href)).searchParams.has("ws_test"); - await wasm_bindgen("./wstcp_client_bg.wasm"); + await wasm_bindgen("./wstcp_client_bg.wasm"); - const tconn0 = performance.now(); - // args: websocket url, user agent, redirect limit - let wstcp = await new wasm_bindgen.WsTcp("wss://localhost:4000", navigator.userAgent, 10); - const tconn1 = performance.now(); - console.warn(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`); + const tconn0 = performance.now(); + // args: websocket url, user agent, redirect limit + let wstcp = await new wasm_bindgen.WsTcp("wss://localhost:4000", navigator.userAgent, 10); + const tconn1 = performance.now(); + console.warn(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`); - if (should_feature_test) { - for (const url of [ - ["https://httpbin.org/get", {}], - ["https://httpbin.org/gzip", {}], - ["https://httpbin.org/brotli", {}], - ["https://httpbin.org/redirect/11", {}], - ["https://httpbin.org/redirect/1", { redirect: "manual" }] - ]) { - let resp = await wstcp.fetch(url[0], url[1]); - console.warn(url, resp, Object.fromEntries(resp.headers)); - console.warn(await resp.text()); + if (should_feature_test) { + for (const url of [ + ["https://httpbin.org/get", {}], + ["https://httpbin.org/gzip", {}], + ["https://httpbin.org/brotli", {}], + ["https://httpbin.org/redirect/11", {}], + ["https://httpbin.org/redirect/1", { redirect: "manual" }] + ]) { + let resp = await wstcp.fetch(url[0], url[1]); + console.warn(url, resp, Object.fromEntries(resp.headers)); + console.warn(await resp.text()); + } + } else if (should_perf_test) { + const test_mux = async (url) => { + const t0 = performance.now(); + await wstcp.fetch(url); + const t1 = performance.now(); + return t1 - t0; + }; + + const test_native = async (url) => { + const t0 = performance.now(); + await fetch(url); + const t1 = performance.now(); + return t1 - t0; + }; + + const num_tests = 10; + + let total_mux = 0; + for (const _ of Array(num_tests).keys()) { + total_mux += await test_mux("https://httpbin.org/get"); + } + total_mux = total_mux / num_tests; + + let total_native = 0; + for (const _ of Array(num_tests).keys()) { + total_native += await test_native("https://httpbin.org/get"); + } + total_native = total_native / num_tests; + + console.warn(`avg mux (10) took ${total_mux} ms or ${total_mux / 1000} s`); + console.warn(`avg native (10) took ${total_native} ms or ${total_native / 1000} s`); + console.warn(`mux - native: ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); + } else if (should_ws_test) { + let ws = await new wasm_bindgen.WsWebSocket(() => console.log("opened"), () => console.log("closed"), msg => console.log(msg), wstcp, "ws://localhost:9000", [], "localhost"); + await ws.send("data"); + } else { + let resp = await wstcp.fetch("https://httpbin.org/get"); + console.warn(resp, Object.fromEntries(resp.headers)); + console.warn(await resp.text()); } - } else if (should_perf_test) { - const test_mux = async (url) => { - const t0 = performance.now(); - await wstcp.fetch(url); - const t1 = performance.now(); - return t1 - t0; - }; - - const test_native = async (url) => { - const t0 = performance.now(); - await fetch(url); - const t1 = performance.now(); - return t1 - t0; - }; - - const num_tests = 10; - - let total_mux = 0; - for (const _ of Array(num_tests).keys()) { - total_mux += await test_mux("https://httpbin.org/get"); - } - total_mux = total_mux / num_tests; - - let total_native = 0; - for (const _ of Array(num_tests).keys()) { - total_native += await test_native("https://httpbin.org/get"); - } - total_native = total_native / num_tests; - - console.warn(`avg mux (10) took ${total_mux} ms or ${total_mux / 1000} s`); - console.warn(`avg native (10) took ${total_native} ms or ${total_native / 1000} s`); - console.warn(`mux - native: ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); - } else { - let resp = await wstcp.fetch("https://httpbin.org/get"); - console.warn(resp, Object.fromEntries(resp.headers)); - console.warn(await resp.text()); - } - alert("you can open console now"); + if (!should_ws_test) alert("you can open console now"); })(); diff --git a/client/src/websocket.rs b/client/src/websocket.rs new file mode 100644 index 0000000..2329c1d --- /dev/null +++ b/client/src/websocket.rs @@ -0,0 +1,176 @@ +use crate::*; + +use base64::{engine::general_purpose::STANDARD, Engine}; +use fastwebsockets::{CloseCode, Frame, OpCode, Payload, Role, WebSocket, WebSocketError}; +use http_body_util::Empty; +use hyper::{ + client::conn::http1 as hyper_conn, + header::{CONNECTION, UPGRADE}, + StatusCode, +}; +use js_sys::Function; +use std::str::from_utf8; +use tokio::sync::{mpsc, oneshot}; + +enum EpxMsg { + SendText(String, oneshot::Sender>), + Close, +} + +#[wasm_bindgen] +pub struct EpxWebSocket { + msg_sender: mpsc::Sender, +} + +#[wasm_bindgen] +impl EpxWebSocket { + #[wasm_bindgen(constructor)] + pub async fn connect( + onopen: Function, + onclose: Function, + onmessage: Function, + tcp: &EpoxyClient, + url: String, + protocols: Vec, + origin: String, + ) -> Result { + let url = Uri::try_from(url).replace_err("Failed to parse URL")?; + let host = url.host().replace_err("URL must have a host")?; + + let rand: [u8; 16] = rand::random(); + let key = STANDARD.encode(rand); + + let mut builder = Request::builder() + .method("GET") + .uri(url.clone()) + .header("Host", host) + .header("Origin", origin) + .header(UPGRADE, "websocket") + .header(CONNECTION, "upgrade") + .header("Sec-WebSocket-Key", key) + .header("Sec-WebSocket-Version", "13"); + + if !protocols.is_empty() { + builder = builder.header("Sec-WebSocket-Protocol", protocols.join(", ")); + } + + let req = builder.body(Empty::::new())?; + + let stream = tcp.get_http_io(&url).await?; + + let (mut sender, conn) = + hyper_conn::handshake::, Empty>(TokioIo::new(stream)) + .await?; + + wasm_bindgen_futures::spawn_local(async move { + if let Err(e) = conn.with_upgrades().await { + error!("wstcp: error in muxed hyper connection (ws)! {:?}", e); + } + }); + + let mut response = sender.send_request(req).await?; + verify(&response)?; + + let mut ws = WebSocket::after_handshake( + TokioIo::new(hyper::upgrade::on(&mut response).await?), + Role::Client, + ); + + let (msg_sender, mut rx) = mpsc::channel(1); + + wasm_bindgen_futures::spawn_local(async move { + loop { + tokio::select! { + frame = ws.read_frame() => { + if let Ok(frame) = frame { + error!("hiii"); + match frame.opcode { + OpCode::Text => { + if let Ok(str) = from_utf8(&frame.payload) { + let _ = onmessage.call1(&JsValue::null(), &jval!(str)); + } + } + OpCode::Binary => { + let _ = onmessage.call1( + &JsValue::null(), + &jval!(Uint8Array::from(frame.payload.to_vec().as_slice())), + ); + } + OpCode::Close => { + let _ = onclose.call0(&JsValue::null()); + break; + } + _ => panic!("unknown opcode {:?}", frame.opcode), + } + } + } + msg = rx.recv() => { + if let Some(msg) = msg { + match msg { + EpxMsg::SendText(payload, err) => { + let _ = err.send(ws.write_frame(Frame::text( + Payload::Owned(payload.as_bytes().to_vec()), + )) + .await); + } + EpxMsg::Close => break, + } + } else { + break; + } + } + } + } + let _ = ws.write_frame(Frame::close(CloseCode::Normal.into(), b"")) + .await; + }); + + onopen + .call0(&Object::default()) + .replace_err("Failed to call onopen")?; + + Ok(Self { msg_sender }) + } + + #[wasm_bindgen] + pub async fn send(&mut self, payload: String) -> Result<(), JsError> { + let (tx, rx) = oneshot::channel(); + self.msg_sender.send(EpxMsg::SendText(payload, tx)).await?; + Ok(rx.await??) + } + + #[wasm_bindgen] + pub async fn close(&mut self) -> Result<(), JsError> { + self.msg_sender.send(EpxMsg::Close).await?; + Ok(()) + } +} + +// https://github.com/snapview/tungstenite-rs/blob/314feea3055a93e585882fb769854a912a7e6dae/src/handshake/client.rs#L189 +fn verify(response: &Response) -> Result<(), JsError> { + if response.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(jerr!("wstcpws connect: Invalid status code")); + } + + let headers = response.headers(); + + if !headers + .get("Upgrade") + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("websocket")) + .unwrap_or(false) + { + return Err(jerr!("wstcpws connect: Invalid upgrade header")); + } + + if !headers + .get("Connection") + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("Upgrade")) + .unwrap_or(false) + { + return Err(jerr!("wstcpws connect: Invalid upgrade header")); + } + + Ok(()) +}