fix the websocket impl

This commit is contained in:
Toshit Chawda 2024-01-14 21:08:35 -08:00
parent fc17bacb9d
commit 1b7181d78f
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
6 changed files with 248 additions and 251 deletions

View file

@ -30,7 +30,7 @@ wasm-streams = "0.4.0"
either = "1.9.0" either = "1.9.0"
tokio-util = { version = "0.7.10", features = ["io"] } tokio-util = { version = "0.7.10", features = ["io"] }
async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] } async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] }
fastwebsockets = { version = "0.6.0", features=[]} fastwebsockets = { version = "0.6.0" }
rand = "0.8.5" rand = "0.8.5"
base64 = "0.21.7" base64 = "0.21.7"

View file

@ -11,7 +11,7 @@ wasm-bindgen --weak-refs --no-typescript --target no-modules --out-dir out/ ../t
echo "[ws] bindgen finished" echo "[ws] bindgen finished"
mv out/wstcp_client_bg.wasm out/wstcp_client_unoptimized.wasm mv out/wstcp_client_bg.wasm out/wstcp_client_unoptimized.wasm
wasm-opt out/wstcp_client_unoptimized.wasm -o out/wstcp_client_bg.wasm time wasm-opt -O4 out/wstcp_client_unoptimized.wasm -o out/wstcp_client_bg.wasm
echo "[ws] optimized" echo "[ws] optimized"
AUTOGENERATED_SOURCE=$(<"out/wstcp_client.js") AUTOGENERATED_SOURCE=$(<"out/wstcp_client.js")

View file

@ -1,16 +0,0 @@
import fs from "fs";
import path from "path";
import binaryen from "binaryen";
import { fileURLToPath } from 'url';
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
let fp = path.resolve(__dirname, './wat.wat');
const originBuffer = fs.readFileSync(fp).toString();
// const wasm = binaryen.readBinary(originBuffer);
const wast = originBuffer
.replace(/\(br_if \$label\$1[\s\n]+?\(i32.eq\n[\s\S\n]+?i32.const -1\)[\s\n]+\)[\s\n]+\)/g, '');
// const distBuffer = binaryen.parseText(wast).emitBinary();
fs.writeFileSync(fp, wast);

View file

@ -3,15 +3,13 @@
mod utils; mod utils;
mod tokioio; mod tokioio;
mod wrappers; mod wrappers;
mod websocket;
use base64::{engine::general_purpose::STANDARD, Engine};
use fastwebsockets::{Frame, OpCode, Payload, Role, WebSocket};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokioio::TokioIo; use tokioio::TokioIo;
use utils::{ReplaceErr, UriExt}; use utils::{ReplaceErr, UriExt};
use wrappers::{IncomingBody, WsStreamWrapper}; use wrappers::{IncomingBody, WsStreamWrapper};
use std::{io::Read, ptr::null_mut, str::from_utf8, sync::Arc}; use std::sync::Arc;
use async_compression::tokio::bufread as async_comp; use async_compression::tokio::bufread as async_comp;
use bytes::Bytes; use bytes::Bytes;
@ -19,10 +17,10 @@ use futures_util::StreamExt;
use http::{uri, HeaderName, HeaderValue, Request, Response}; use http::{uri, HeaderName, HeaderValue, Request, Response};
use hyper::{ use hyper::{
body::Incoming, body::Incoming,
client::conn::http1::{handshake, Builder}, client::conn::http1::Builder,
Uri, Uri,
}; };
use js_sys::{Array, Function, Object, Reflect, Uint8Array}; use js_sys::{Array, Object, Reflect, Uint8Array};
use penguin_mux_wasm::{Multiplexor, MuxStream}; use penguin_mux_wasm::{Multiplexor, MuxStream};
use tokio_rustls::{client::TlsStream, rustls, rustls::RootCertStore, TlsConnector}; use tokio_rustls::{client::TlsStream, rustls, rustls::RootCertStore, TlsConnector};
use tokio_util::{ use tokio_util::{
@ -47,7 +45,7 @@ enum EpxCompression {
type EpxTlsStream = TlsStream<MuxStream<WsStreamWrapper>>; type EpxTlsStream = TlsStream<MuxStream<WsStreamWrapper>>;
type EpxUnencryptedStream = MuxStream<WsStreamWrapper>; type EpxUnencryptedStream = MuxStream<WsStreamWrapper>;
type EpxStream = Either<WsTcpTlsStream, WsTcpUnencryptedStream>; type EpxStream = Either<EpxTlsStream, EpxUnencryptedStream>;
async fn send_req( async fn send_req(
req: http::Request<HttpBody>, req: http::Request<HttpBody>,
@ -115,171 +113,6 @@ async fn start() {
utils::set_panic_hook(); utils::set_panic_hook();
} }
#[wasm_bindgen]
pub struct WsWebSocket {
onopen: Function,
onclose: Function,
onerror: Function,
onmessage: Function,
ws: Option<WebSocket<EpxStream>>,
}
async fn wtf(iop: *mut EpxStream) {
let mut t = false;
unsafe {
let io = &mut *iop;
let mut v = vec![];
loop {
let r = io.read_u8().await;
if let Ok(u) = r {
v.push(u);
if t && u as char == '\r' {
let r = io.read_u8().await;
break;
}
if u as char == '\n' {
t = true;
} else {
t = false;
}
} else {
break;
}
}
log!("{}", &from_utf8(&v).unwrap().to_string());
}
}
#[wasm_bindgen]
impl WsWebSocket {
#[wasm_bindgen(constructor)]
pub fn new(
onopen: Function,
onclose: Function,
onmessage: Function,
onerror: Function,
) -> Result<WsWebSocket, JsError> {
Ok(Self {
onopen,
onclose,
onerror,
onmessage,
ws: None,
})
}
#[wasm_bindgen]
pub async fn connect(
&mut self,
tcp: &mut EpoxyClient,
url: String,
protocols: Vec<String>,
origin: String,
) -> Result<(), JsError> {
self.onopen.call0(&Object::default());
let uri = url.parse::<uri::Uri>().replace_err("Failed to parse URL")?;
let mut io = tcp.get_http_io(&uri).await?;
let r: [u8; 16] = rand::random();
let key = STANDARD.encode(&r);
let pathstr = if let Some(p) = uri.path_and_query() {
p.to_string()
} else {
uri.path().to_string()
};
io.write(format!("GET {} HTTP/1.1\r\n", pathstr).as_bytes())
.await;
io.write(b"Sec-WebSocket-Version: 13\r\n").await;
io.write(format!("Sec-WebSocket-Key: {}\r\n", key).as_bytes())
.await;
io.write(b"Connection: Upgrade\r\n").await;
io.write(b"Upgrade: websocket\r\n").await;
io.write(format!("Origin: {}\r\n", origin).as_bytes()).await;
io.write(format!("Host: {}\r\n", uri.host().unwrap()).as_bytes())
.await;
io.write(b"\r\n").await;
let iop: *mut EpxStream = &mut io;
wtf(iop).await;
let mut ws = WebSocket::after_handshake(io, fastwebsockets::Role::Client);
ws.set_writev(false);
ws.set_auto_close(true);
ws.set_auto_pong(true);
self.ws = Some(ws);
Ok(())
}
#[wasm_bindgen]
pub fn ptr(&mut self) -> *mut WsWebSocket {
self
}
#[wasm_bindgen]
pub async fn send(&mut self, payload: String) -> Result<(), JsError> {
let Some(ws) = self.ws.as_mut() else {
return Err(JsError::new("Tried to send() before handshake!"));
};
ws.write_frame(Frame::new(
true,
OpCode::Text,
None,
Payload::Owned(payload.as_bytes().to_vec()),
))
.await
.unwrap();
// .replace_err("Failed to send WsWebSocket payload")?;
Ok(())
}
#[wasm_bindgen]
pub async fn recv(&mut self) -> Result<(), JsError> {
let Some(ws) = self.ws.as_mut() else {
return Err(JsError::new("Tried to recv() before handshake!"));
};
loop {
let Ok(frame) = ws.read_frame().await else {
break;
};
match frame.opcode {
OpCode::Text => {
if let Ok(str) = from_utf8(&frame.payload) {
self.onmessage
.call1(&JsValue::null(), &jval!(str))
.replace_err("missing onmessage handler")?;
}
}
OpCode::Binary => {
self.onmessage
.call1(
&JsValue::null(),
&jval!(Uint8Array::from(frame.payload.to_vec().as_slice())),
)
.replace_err("missing onmessage handler")?;
}
_ => panic!("unknown opcode {:?}", frame.opcode),
}
}
self.onclose
.call0(&JsValue::null())
.replace_err("missing onclose handler")?;
Ok(())
}
}
#[wasm_bindgen]
pub async fn send(pointer: *mut WsWebSocket, payload: String) -> Result<(), JsError> {
let tcp = unsafe { &mut *pointer };
tcp.send(payload).await
}
#[wasm_bindgen] #[wasm_bindgen]
pub struct EpoxyClient { pub struct EpoxyClient {
rustls_config: Arc<rustls::ClientConfig>, rustls_config: Arc<rustls::ClientConfig>,

View file

@ -1,69 +1,73 @@
(async () => { (async () => {
console.log( console.log(
"%cWASM is significantly slower with DevTools open!", "%cWASM is significantly slower with DevTools open!",
"color:red;font-size:2rem;font-weight:bold" "color:red;font-size:2rem;font-weight:bold"
); );
const should_feature_test = (new URL(window.location.href)).searchParams.has("feature_test"); const should_feature_test = (new URL(window.location.href)).searchParams.has("feature_test");
const should_perf_test = (new URL(window.location.href)).searchParams.has("perf_test"); const should_perf_test = (new URL(window.location.href)).searchParams.has("perf_test");
const should_ws_test = (new URL(window.location.href)).searchParams.has("ws_test");
await wasm_bindgen("./wstcp_client_bg.wasm"); await wasm_bindgen("./wstcp_client_bg.wasm");
const tconn0 = performance.now(); const tconn0 = performance.now();
// args: websocket url, user agent, redirect limit // args: websocket url, user agent, redirect limit
let wstcp = await new wasm_bindgen.WsTcp("wss://localhost:4000", navigator.userAgent, 10); let wstcp = await new wasm_bindgen.WsTcp("wss://localhost:4000", navigator.userAgent, 10);
const tconn1 = performance.now(); const tconn1 = performance.now();
console.warn(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`); console.warn(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`);
if (should_feature_test) { if (should_feature_test) {
for (const url of [ for (const url of [
["https://httpbin.org/get", {}], ["https://httpbin.org/get", {}],
["https://httpbin.org/gzip", {}], ["https://httpbin.org/gzip", {}],
["https://httpbin.org/brotli", {}], ["https://httpbin.org/brotli", {}],
["https://httpbin.org/redirect/11", {}], ["https://httpbin.org/redirect/11", {}],
["https://httpbin.org/redirect/1", { redirect: "manual" }] ["https://httpbin.org/redirect/1", { redirect: "manual" }]
]) { ]) {
let resp = await wstcp.fetch(url[0], url[1]); let resp = await wstcp.fetch(url[0], url[1]);
console.warn(url, resp, Object.fromEntries(resp.headers)); console.warn(url, resp, Object.fromEntries(resp.headers));
console.warn(await resp.text()); console.warn(await resp.text());
}
} else if (should_perf_test) {
const test_mux = async (url) => {
const t0 = performance.now();
await wstcp.fetch(url);
const t1 = performance.now();
return t1 - t0;
};
const test_native = async (url) => {
const t0 = performance.now();
await fetch(url);
const t1 = performance.now();
return t1 - t0;
};
const num_tests = 10;
let total_mux = 0;
for (const _ of Array(num_tests).keys()) {
total_mux += await test_mux("https://httpbin.org/get");
}
total_mux = total_mux / num_tests;
let total_native = 0;
for (const _ of Array(num_tests).keys()) {
total_native += await test_native("https://httpbin.org/get");
}
total_native = total_native / num_tests;
console.warn(`avg mux (10) took ${total_mux} ms or ${total_mux / 1000} s`);
console.warn(`avg native (10) took ${total_native} ms or ${total_native / 1000} s`);
console.warn(`mux - native: ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
} else if (should_ws_test) {
let ws = await new wasm_bindgen.WsWebSocket(() => console.log("opened"), () => console.log("closed"), msg => console.log(msg), wstcp, "ws://localhost:9000", [], "localhost");
await ws.send("data");
} else {
let resp = await wstcp.fetch("https://httpbin.org/get");
console.warn(resp, Object.fromEntries(resp.headers));
console.warn(await resp.text());
} }
} else if (should_perf_test) { if (!should_ws_test) alert("you can open console now");
const test_mux = async (url) => {
const t0 = performance.now();
await wstcp.fetch(url);
const t1 = performance.now();
return t1 - t0;
};
const test_native = async (url) => {
const t0 = performance.now();
await fetch(url);
const t1 = performance.now();
return t1 - t0;
};
const num_tests = 10;
let total_mux = 0;
for (const _ of Array(num_tests).keys()) {
total_mux += await test_mux("https://httpbin.org/get");
}
total_mux = total_mux / num_tests;
let total_native = 0;
for (const _ of Array(num_tests).keys()) {
total_native += await test_native("https://httpbin.org/get");
}
total_native = total_native / num_tests;
console.warn(`avg mux (10) took ${total_mux} ms or ${total_mux / 1000} s`);
console.warn(`avg native (10) took ${total_native} ms or ${total_native / 1000} s`);
console.warn(`mux - native: ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
} else {
let resp = await wstcp.fetch("https://httpbin.org/get");
console.warn(resp, Object.fromEntries(resp.headers));
console.warn(await resp.text());
}
alert("you can open console now");
})(); })();

176
client/src/websocket.rs Normal file
View file

@ -0,0 +1,176 @@
use crate::*;
use base64::{engine::general_purpose::STANDARD, Engine};
use fastwebsockets::{CloseCode, Frame, OpCode, Payload, Role, WebSocket, WebSocketError};
use http_body_util::Empty;
use hyper::{
client::conn::http1 as hyper_conn,
header::{CONNECTION, UPGRADE},
StatusCode,
};
use js_sys::Function;
use std::str::from_utf8;
use tokio::sync::{mpsc, oneshot};
enum EpxMsg {
SendText(String, oneshot::Sender<Result<(), WebSocketError>>),
Close,
}
#[wasm_bindgen]
pub struct EpxWebSocket {
msg_sender: mpsc::Sender<EpxMsg>,
}
#[wasm_bindgen]
impl EpxWebSocket {
#[wasm_bindgen(constructor)]
pub async fn connect(
onopen: Function,
onclose: Function,
onmessage: Function,
tcp: &EpoxyClient,
url: String,
protocols: Vec<String>,
origin: String,
) -> Result<EpxWebSocket, JsError> {
let url = Uri::try_from(url).replace_err("Failed to parse URL")?;
let host = url.host().replace_err("URL must have a host")?;
let rand: [u8; 16] = rand::random();
let key = STANDARD.encode(rand);
let mut builder = Request::builder()
.method("GET")
.uri(url.clone())
.header("Host", host)
.header("Origin", origin)
.header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade")
.header("Sec-WebSocket-Key", key)
.header("Sec-WebSocket-Version", "13");
if !protocols.is_empty() {
builder = builder.header("Sec-WebSocket-Protocol", protocols.join(", "));
}
let req = builder.body(Empty::<Bytes>::new())?;
let stream = tcp.get_http_io(&url).await?;
let (mut sender, conn) =
hyper_conn::handshake::<TokioIo<EpxStream>, Empty<Bytes>>(TokioIo::new(stream))
.await?;
wasm_bindgen_futures::spawn_local(async move {
if let Err(e) = conn.with_upgrades().await {
error!("wstcp: error in muxed hyper connection (ws)! {:?}", e);
}
});
let mut response = sender.send_request(req).await?;
verify(&response)?;
let mut ws = WebSocket::after_handshake(
TokioIo::new(hyper::upgrade::on(&mut response).await?),
Role::Client,
);
let (msg_sender, mut rx) = mpsc::channel(1);
wasm_bindgen_futures::spawn_local(async move {
loop {
tokio::select! {
frame = ws.read_frame() => {
if let Ok(frame) = frame {
error!("hiii");
match frame.opcode {
OpCode::Text => {
if let Ok(str) = from_utf8(&frame.payload) {
let _ = onmessage.call1(&JsValue::null(), &jval!(str));
}
}
OpCode::Binary => {
let _ = onmessage.call1(
&JsValue::null(),
&jval!(Uint8Array::from(frame.payload.to_vec().as_slice())),
);
}
OpCode::Close => {
let _ = onclose.call0(&JsValue::null());
break;
}
_ => panic!("unknown opcode {:?}", frame.opcode),
}
}
}
msg = rx.recv() => {
if let Some(msg) = msg {
match msg {
EpxMsg::SendText(payload, err) => {
let _ = err.send(ws.write_frame(Frame::text(
Payload::Owned(payload.as_bytes().to_vec()),
))
.await);
}
EpxMsg::Close => break,
}
} else {
break;
}
}
}
}
let _ = ws.write_frame(Frame::close(CloseCode::Normal.into(), b""))
.await;
});
onopen
.call0(&Object::default())
.replace_err("Failed to call onopen")?;
Ok(Self { msg_sender })
}
#[wasm_bindgen]
pub async fn send(&mut self, payload: String) -> Result<(), JsError> {
let (tx, rx) = oneshot::channel();
self.msg_sender.send(EpxMsg::SendText(payload, tx)).await?;
Ok(rx.await??)
}
#[wasm_bindgen]
pub async fn close(&mut self) -> Result<(), JsError> {
self.msg_sender.send(EpxMsg::Close).await?;
Ok(())
}
}
// https://github.com/snapview/tungstenite-rs/blob/314feea3055a93e585882fb769854a912a7e6dae/src/handshake/client.rs#L189
fn verify(response: &Response<Incoming>) -> Result<(), JsError> {
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(jerr!("wstcpws connect: Invalid status code"));
}
let headers = response.headers();
if !headers
.get("Upgrade")
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(jerr!("wstcpws connect: Invalid upgrade header"));
}
if !headers
.get("Connection")
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false)
{
return Err(jerr!("wstcpws connect: Invalid upgrade header"));
}
Ok(())
}