From ad7a34e86d681863deef85d9ca206d5ffda7c03e Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Mon, 22 Jan 2024 08:59:53 -0800 Subject: [PATCH 01/26] add wisp lib --- Cargo.lock | 137 ++++++------------- Cargo.toml | 2 +- server/Cargo.toml | 9 +- server/src/main.rs | 330 ++++++++++++++++++++++++--------------------- wisp/.gitignore | 1 + wisp/Cargo.lock | 320 +++++++++++++++++++++++++++++++++++++++++++ wisp/Cargo.toml | 11 ++ wisp/src/lib.rs | 25 ++++ wisp/src/packet.rs | 237 ++++++++++++++++++++++++++++++++ wisp/src/ws.rs | 40 ++++++ 10 files changed, 857 insertions(+), 255 deletions(-) create mode 100644 wisp/.gitignore create mode 100644 wisp/Cargo.lock create mode 100644 wisp/Cargo.toml create mode 100644 wisp/src/lib.rs create mode 100644 wisp/src/packet.rs create mode 100644 wisp/src/ws.rs diff --git a/Cargo.lock b/Cargo.lock index 4f61515..3db0195 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -214,12 +214,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "data-encoding" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5" - [[package]] name = "digest" version = "0.10.7" @@ -271,13 +265,16 @@ dependencies = [ name = "epoxy-server" version = "1.0.0" dependencies = [ + "bytes", + "fastwebsockets", + "futures-util", "http-body-util", "hyper", "hyper-util", - "rusty-penguin", "tokio", "tokio-native-tls", - "tokio-tungstenite", + "tokio-util", + "wisp-mux", ] [[package]] @@ -302,7 +299,13 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f63dd7b57f9b33b1741fa631c9522eb35d43e96dcca4a6a91d5e4ca7c93acdc1" dependencies = [ + "base64", + "http-body-util", + "hyper", + "hyper-util", + "pin-project", "rand", + "sha1", "simdutf8", "thiserror", "tokio", @@ -340,15 +343,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" -[[package]] -name = "form_urlencoded" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" -dependencies = [ - "percent-encoding", -] - [[package]] name = "futures" version = "0.3.30" @@ -567,16 +561,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "idna" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" -dependencies = [ - "unicode-bidi", - "unicode-normalization", -] - [[package]] name = "itoa" version = "1.0.10" @@ -779,12 +763,6 @@ dependencies = [ "wasm-bindgen-futures", ] -[[package]] -name = "percent-encoding" -version = "2.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" - [[package]] name = "pharos" version = "0.5.3" @@ -795,6 +773,26 @@ dependencies = [ "rustc_version", ] +[[package]] +name = "pin-project" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.13" @@ -951,23 +949,6 @@ dependencies = [ "untrusted", ] -[[package]] -name = "rusty-penguin" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aefd4b85c815cf35675640924e0e73d9847bbdec8aa2e7daa8703fc5161f11d9" -dependencies = [ - "bytes", - "futures-util", - "http 0.2.11", - "parking_lot", - "rand", - "thiserror", - "tokio", - "tokio-tungstenite", - "tracing", -] - [[package]] name = "schannel" version = "0.1.23" @@ -1116,21 +1097,6 @@ dependencies = [ "syn", ] -[[package]] -name = "tinyvec" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" -dependencies = [ - "tinyvec_macros", -] - -[[package]] -name = "tinyvec_macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" - [[package]] name = "tokio" version = "1.35.1" @@ -1204,6 +1170,7 @@ dependencies = [ "futures-sink", "pin-project-lite", "tokio", + "tracing", ] [[package]] @@ -1251,14 +1218,9 @@ checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" dependencies = [ "byteorder", "bytes", - "data-encoding", - "http 1.0.0", - "httparse", "log", "rand", - "sha1", "thiserror", - "url", "utf-8", ] @@ -1268,44 +1230,18 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" -[[package]] -name = "unicode-bidi" -version = "0.3.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f2528f27a9eb2b21e69c95319b30bd0efd85d09c379741b0f78ea1d86be2416" - [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" -[[package]] -name = "unicode-normalization" -version = "0.1.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" -dependencies = [ - "tinyvec", -] - [[package]] name = "untrusted" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" -[[package]] -name = "url" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" -dependencies = [ - "form_urlencoded", - "idna", - "percent-encoding", -] - [[package]] name = "utf-8" version = "0.7.6" @@ -1569,6 +1505,15 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" +[[package]] +name = "wisp-mux" +version = "0.1.0" +dependencies = [ + "bytes", + "futures", + "futures-util", +] + [[package]] name = "ws_stream_wasm" version = "0.7.4" diff --git a/Cargo.toml b/Cargo.toml index 0d2e374..1927a61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["server", "client"] +members = ["server", "client", "wisp"] [patch.crates-io] rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" } diff --git a/server/Cargo.toml b/server/Cargo.toml index 6abb7ca..c0a4a10 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -4,10 +4,13 @@ version = "1.0.0" edition = "2021" [dependencies] +bytes = "1.5.0" +fastwebsockets = { version = "0.6.0", features = ["upgrade", "simdutf8"] } +futures-util = { version = "0.3.30", features = ["sink"] } http-body-util = "0.1.0" hyper = { version = "1.1.0", features = ["server", "http1"] } hyper-util = { version = "0.1.2", features = ["tokio"] } -rusty-penguin = { version = "0.5.3", default-features = false } -tokio = { version = "1.35.1", features = ["rt-multi-thread", "net", "macros"] } +tokio = { version = "1.5.1", features = ["rt-multi-thread", "macros"] } tokio-native-tls = "0.3.1" -tokio-tungstenite = "0.21.0" +tokio-util = { version = "0.7.10", features = ["codec"] } +wisp-mux = { path = "../wisp" } diff --git a/server/src/main.rs b/server/src/main.rs index fc56579..6318929 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,176 +1,196 @@ -use std::{convert::Infallible, env, net::SocketAddr, sync::Arc}; +use std::io::Error; +use bytes::Bytes; +use fastwebsockets::{ + upgrade, CloseCode, FragmentCollector, Frame, OpCode, Payload, WebSocketError, +}; +use futures_util::{SinkExt, StreamExt}; use hyper::{ - body::Incoming, - header::{ - HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, - SEC_WEBSOCKET_VERSION, UPGRADE, - }, - server::conn::http1, - service::service_fn, - upgrade::Upgraded, - Method, Request, Response, StatusCode, Version, + body::Incoming, header::HeaderValue, server::conn::http1, service::service_fn, Request, + Response, StatusCode, }; use hyper_util::rt::TokioIo; -use penguin_mux::{Multiplexor, MuxStream}; -use tokio::{ - net::{TcpListener, TcpStream}, - task::{JoinError, JoinSet}, -}; +use tokio::net::{TcpListener, TcpStream}; use tokio_native_tls::{native_tls, TlsAcceptor}; -use tokio_tungstenite::{ - tungstenite::{handshake::derive_accept_key, protocol::Role}, - WebSocketStream, -}; +use tokio_util::codec::{BytesCodec, Framed}; -type Body = http_body_util::Empty; +type HttpBody = http_body_util::Empty; -type MultiplexorStream = MuxStream>>; - -async fn forward(mut stream: MultiplexorStream) -> Result<(), JoinError> { - println!("forwarding"); - let host = std::str::from_utf8(&stream.dest_host).unwrap(); - let mut tcp_stream = TcpStream::connect((host, stream.dest_port)).await.unwrap(); - println!("connected to {:?}", tcp_stream.peer_addr().unwrap()); - tokio::io::copy_bidirectional(&mut stream, &mut tcp_stream) - .await - .unwrap(); - println!("finished"); - Ok(()) -} - -async fn handle_connection(ws_stream: WebSocketStream>, addr: SocketAddr) { - println!("WebSocket connection established: {}", addr); - let mux = Multiplexor::new(ws_stream, penguin_mux::Role::Server, None, None); - let mut jobs = JoinSet::new(); - println!("muxing"); - loop { - tokio::select! { - Some(result) = jobs.join_next() => { - match result { - Ok(Ok(())) => {} - Ok(Err(err)) | Err(err) => eprintln!("failed to forward: {:?}", err), - } - } - Ok(result) = mux.server_new_stream_channel() => { - jobs.spawn(forward(result)); - } - else => { - break; - } - } - } - println!("{} disconnected", &addr); -} - -async fn handle_request( - mut req: Request, - addr: SocketAddr, -) -> Result, Infallible> { - let headers = req.headers(); - let derived = headers - .get(SEC_WEBSOCKET_KEY) - .map(|k| derive_accept_key(k.as_bytes())); - - let mut negotiated_protocol: Option = None; - if let Some(protocols) = headers - .get(SEC_WEBSOCKET_PROTOCOL) - .and_then(|h| h.to_str().ok()) - { - negotiated_protocol = protocols.split(',').next().map(|h| h.trim().to_string()); - } - - if req.method() != Method::GET - || req.version() < Version::HTTP_11 - || !headers - .get(CONNECTION) - .and_then(|h| h.to_str().ok()) - .map(|h| { - h.split(|c| c == ' ' || c == ',') - .any(|p| p.eq_ignore_ascii_case("upgrade")) - }) - .unwrap_or(false) - || !headers - .get(UPGRADE) - .and_then(|h| h.to_str().ok()) - .map(|h| h.eq_ignore_ascii_case("websocket")) - .unwrap_or(false) - || !headers - .get(SEC_WEBSOCKET_VERSION) - .map(|h| h == "13") - .unwrap_or(false) - || derived.is_none() - { - return Ok(Response::new(Body::default())); - } - - let ver = req.version(); - tokio::task::spawn(async move { - match hyper::upgrade::on(&mut req).await { - Ok(upgraded) => { - let upgraded = TokioIo::new(upgraded); - handle_connection( - WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await, - addr, - ) - .await; - } - Err(e) => eprintln!("upgrade error: {}", e), - } - }); - - let mut res = Response::new(Body::default()); - *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; - *res.version_mut() = ver; - res.headers_mut() - .append(CONNECTION, HeaderValue::from_static("Upgrade")); - res.headers_mut() - .append(UPGRADE, HeaderValue::from_static("websocket")); - res.headers_mut() - .append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap()); - if let Some(protocol) = negotiated_protocol { - res.headers_mut() - .append(SEC_WEBSOCKET_PROTOCOL, protocol.parse().unwrap()); - } - - Ok(res) -} - -#[tokio::main] -async fn main() -> Result<(), Box> { - let addr = env::args() - .nth(1) - .unwrap_or_else(|| "0.0.0.0:4000".to_string()) - .parse::()?; +#[tokio::main(flavor = "multi_thread")] +async fn main() -> Result<(), Error> { let pem = include_bytes!("./pem.pem"); let key = include_bytes!("./key.pem"); + let identity = native_tls::Identity::from_pkcs8(pem, key).expect("failed to make identity"); + let prefix = if let Some(prefix) = std::env::args().nth(1) { + prefix + } else { + "/".to_string() + }; + let port = if let Some(prefix) = std::env::args().nth(1) { + prefix + } else { + "4000".to_string() + }; - let identity = native_tls::Identity::from_pkcs8(pem, key).expect("invalid pem/key"); - - let acceptor = TlsAcceptor::from(native_tls::TlsAcceptor::new(identity).unwrap()); - let acceptor = Arc::new(acceptor); - - let listener = TcpListener::bind(addr).await?; - - println!("listening on {}", addr); - - loop { - let (stream, remote_addr) = listener.accept().await?; - let acceptor = acceptor.clone(); + let socket = TcpListener::bind(format!("0.0.0.0:{}", port)) + .await + .expect("failed to bind"); + let acceptor = TlsAcceptor::from( + native_tls::TlsAcceptor::new(identity).expect("failed to make tls acceptor"), + ); + let acceptor = std::sync::Arc::new(acceptor); + println!("listening on 0.0.0.0:4000"); + while let Ok((stream, addr)) = socket.accept().await { + let acceptor_cloned = acceptor.clone(); + let prefix_cloned = prefix.clone(); tokio::spawn(async move { - let stream = acceptor.accept(stream).await.expect("not tls"); + let stream = acceptor_cloned.accept(stream).await.expect("not tls"); let io = TokioIo::new(stream); - - let service = service_fn(move |req| handle_request(req, remote_addr)); - + let service = service_fn(move |res| accept_http(res, addr.to_string(), prefix_cloned.clone())); let conn = http1::Builder::new() .serve_connection(io, service) .with_upgrades(); - if let Err(err) = conn.await { - eprintln!("failed to serve connection: {:?}", err); + println!("{:?}: failed to serve conn: {:?}", addr, err); } }); } + + Ok(()) } + +async fn accept_http( + mut req: Request, + addr: String, + prefix: String, +) -> Result, WebSocketError> { + if upgrade::is_upgrade_request(&req) && req.uri().path().to_string().starts_with(&prefix) { + let uri = req.uri().clone(); + let (mut res, fut) = upgrade::upgrade(&mut req)?; + + tokio::spawn(async move { + if *uri.path() != prefix { + if let Err(e) = + accept_wsproxy(fut, uri.path().to_string(), addr.clone(), prefix).await + { + println!("{:?}: error in ws handling: {:?}", addr, e); + } + } + }); + + if let Some(protocol) = req.headers().get("Sec-Websocket-Protocol") { + let first_protocol = protocol + .to_str() + .expect("failed to get protocol") + .split(',') + .next() + .expect("failed to get first protocol") + .trim(); + res.headers_mut().insert( + "Sec-Websocket-Protocol", + HeaderValue::from_str(first_protocol).unwrap(), + ); + } + + Ok(res) + } else { + Ok(Response::builder() + .status(StatusCode::OK) + .body(HttpBody::new()) + .unwrap()) + } +} + +async fn accept_wsproxy( + fut: upgrade::UpgradeFut, + incoming_uri: String, + addr: String, + prefix: String, +) -> Result<(), Box> { + let mut ws_stream = FragmentCollector::new(fut.await?); + + // should always have prefix + let incoming_uri = incoming_uri.strip_prefix(&prefix).unwrap(); + + println!("{:?}: connected", addr); + + let tcp_stream = match TcpStream::connect(incoming_uri).await { + Ok(stream) => stream, + Err(err) => { + ws_stream + .write_frame(Frame::close(CloseCode::Away.into(), b"failed to connect")) + .await + .unwrap(); + return Err(Box::new(err)); + } + }; + let mut tcp_stream_framed = Framed::new(tcp_stream, BytesCodec::new()); + + loop { + tokio::select! { + event = ws_stream.read_frame() => { + match event { + Ok(frame) => { + print!("{:?}: event ws - ", addr); + match frame.opcode { + OpCode::Text | OpCode::Binary => { + if tcp_stream_framed.send(Bytes::from(frame.payload.to_vec())).await.is_ok() { + println!("sent success"); + } else { + println!("sent FAILED"); + } + } + OpCode::Close => { + if as SinkExt>::close(&mut tcp_stream_framed).await.is_ok() { + println!("closed success"); + } else { + println!("closed FAILED"); + } + break; + } + _ => { + println!("ignored"); + } + } + }, + Err(err) => { + print!("{:?}: err in ws: {:?} - ", addr, err); + if as SinkExt>::close(&mut tcp_stream_framed).await.is_ok() { + println!("closed tcp success"); + } else { + println!("closed tcp FAILED"); + } + break; + } + } + }, + event = tcp_stream_framed.next() => { + if let Some(res) = event { + print!("{:?}: event tcp - ", addr); + match res { + Ok(buf) => { + if ws_stream.write_frame(Frame::binary(Payload::Owned(buf.to_vec()))).await.is_ok() { + println!("sent success"); + } else { + println!("sent FAILED"); + } + } + Err(_) => { + if ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"tcp side is going away")).await.is_ok() { + println!("closed success"); + } else { + println!("closed FAILED"); + } + } + } + } + } + } + } + + println!("\"{}\": connection closed", addr); + + Ok(()) +} + diff --git a/wisp/.gitignore b/wisp/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/wisp/.gitignore @@ -0,0 +1 @@ +/target diff --git a/wisp/Cargo.lock b/wisp/Cargo.lock new file mode 100644 index 0000000..19bc2ba --- /dev/null +++ b/wisp/Cargo.lock @@ -0,0 +1,320 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bytes" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" + +[[package]] +name = "futures-task" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "hashbrown" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" + +[[package]] +name = "libc" +version = "0.2.152" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" + +[[package]] +name = "lock_api" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "memchr" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "parking_lot_core" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "proc-macro2" +version = "1.0.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95fc56cda0b5c3325f5fbbd7ff9fda9e02bb00bb3dac51252d2f1bfa1cb8cc8c" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + +[[package]] +name = "smallvec" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" + +[[package]] +name = "syn" +version = "2.0.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "wisp-mux" +version = "0.1.0" +dependencies = [ + "bytes", + "dashmap", + "futures", + "futures-util", +] diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml new file mode 100644 index 0000000..a660280 --- /dev/null +++ b/wisp/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "wisp-mux" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +bytes = "1.5.0" +futures = "0.3.30" +futures-util = "0.3.30" diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs new file mode 100644 index 0000000..c8147ae --- /dev/null +++ b/wisp/src/lib.rs @@ -0,0 +1,25 @@ +mod packet; +mod ws; + +pub use crate::packet::*; + +#[derive(Debug, PartialEq)] +pub enum Role { + Client, + Server, +} + +pub enum WispError { + PacketTooSmall, + InvalidPacketType, + WsFrameInvalidType, + WsFrameNotFinished, + WsImplError(Box), + Utf8Error(std::str::Utf8Error), +} + +impl From for WispError { + fn from(err: std::str::Utf8Error) -> WispError { + WispError::Utf8Error(err) + } +} diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs new file mode 100644 index 0000000..b1091b8 --- /dev/null +++ b/wisp/src/packet.rs @@ -0,0 +1,237 @@ +use crate::ws; +use crate::WispError; +use bytes::{Buf, BufMut, Bytes}; + +#[derive(Debug)] +pub struct ConnectPacket { + stream_type: u8, + destination_port: u16, + destination_hostname: String, +} + +impl ConnectPacket { + pub fn new(stream_type: u8, destination_port: u16, destination_hostname: String) -> Self { + Self { + stream_type, + destination_port, + destination_hostname, + } + } +} + +impl TryFrom for ConnectPacket { + type Error = WispError; + fn try_from(mut bytes: Bytes) -> Result { + if bytes.remaining() < (1 + 2) { + return Err(Self::Error::PacketTooSmall); + } + Ok(Self { + stream_type: bytes.get_u8(), + destination_port: bytes.get_u16_le(), + destination_hostname: std::str::from_utf8(&bytes)?.to_string(), + }) + } +} + +impl From for Vec { + fn from(packet: ConnectPacket) -> Self { + let mut encoded = Self::with_capacity(1 + 2 + packet.destination_hostname.len()); + encoded.put_u8(packet.stream_type); + encoded.put_u16_le(packet.destination_port); + encoded.extend(packet.destination_hostname.bytes()); + encoded + } +} + +#[derive(Debug)] +pub struct ContinuePacket { + buffer_remaining: u32, +} + +impl ContinuePacket { + pub fn new(buffer_remaining: u32) -> Self { + Self { buffer_remaining } + } +} + +impl TryFrom for ContinuePacket { + type Error = WispError; + fn try_from(mut bytes: Bytes) -> Result { + if bytes.remaining() < 4 { + return Err(Self::Error::PacketTooSmall); + } + Ok(Self { + buffer_remaining: bytes.get_u32_le(), + }) + } +} + +impl From for Vec { + fn from(packet: ContinuePacket) -> Self { + let mut encoded = Self::with_capacity(4); + encoded.put_u32_le(packet.buffer_remaining); + encoded + } +} + +#[derive(Debug)] +pub struct ClosePacket { + reason: u8, +} + +impl ClosePacket { + pub fn new(reason: u8) -> Self { + Self { reason } + } +} + +impl TryFrom for ClosePacket { + type Error = WispError; + fn try_from(mut bytes: Bytes) -> Result { + if bytes.remaining() < 1 { + return Err(Self::Error::PacketTooSmall); + } + Ok(Self { + reason: bytes.get_u8(), + }) + } +} + +impl From for Vec { + fn from(packet: ClosePacket) -> Self { + let mut encoded = Self::with_capacity(1); + encoded.put_u8(packet.reason); + encoded + } +} + +#[derive(Debug)] +pub enum PacketType { + Connect(ConnectPacket), + Data(Vec), + Continue(ContinuePacket), + Close(ClosePacket), +} + +impl PacketType { + pub fn as_u8(&self) -> u8 { + use PacketType::*; + match self { + Connect(_) => 0x01, + Data(_) => 0x02, + Continue(_) => 0x03, + Close(_) => 0x04, + } + } +} + +impl From for Vec { + fn from(packet: PacketType) -> Self { + use PacketType::*; + match packet { + Connect(x) => x.into(), + Data(x) => x, + Continue(x) => x.into(), + Close(x) => x.into(), + } + } +} + +#[derive(Debug)] +pub struct Packet { + stream_id: u32, + packet: PacketType, +} + +impl Packet { + pub fn new(stream_id: u32, packet: PacketType) -> Self { + Self { stream_id, packet } + } + + pub fn new_connect( + stream_id: u32, + stream_type: u8, + destination_port: u16, + destination_hostname: String, + ) -> Self { + Self { + stream_id, + packet: PacketType::Connect(ConnectPacket::new( + stream_type, + destination_port, + destination_hostname, + )), + } + } + + pub fn new_data(stream_id: u32, data: Vec) -> Self { + Self { + stream_id, + packet: PacketType::Data(data), + } + } + + pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self { + Self { + stream_id, + packet: PacketType::Continue(ContinuePacket::new(buffer_remaining)), + } + } + + pub fn new_close(stream_id: u32, reason: u8) -> Self { + Self { + stream_id, + packet: PacketType::Close(ClosePacket::new(reason)), + } + } +} + +impl TryFrom for Packet { + type Error = WispError; + fn try_from(mut bytes: Bytes) -> Result { + if bytes.remaining() < 5 { + return Err(Self::Error::PacketTooSmall); + } + let packet_type = bytes.get_u8(); + use PacketType::*; + Ok(Self { + stream_id: bytes.get_u32_le(), + packet: match packet_type { + 0x01 => Connect(ConnectPacket::try_from(bytes)?), + 0x02 => Data(bytes.to_vec()), + 0x03 => Continue(ContinuePacket::try_from(bytes)?), + 0x04 => Close(ClosePacket::try_from(bytes)?), + _ => return Err(Self::Error::InvalidPacketType), + }, + }) + } +} + +impl From for Vec { + fn from(packet: Packet) -> Self { + let mut encoded = Self::with_capacity(1 + 4); + encoded.push(packet.packet.as_u8()); + encoded.put_u32_le(packet.stream_id); + encoded.extend(Vec::::from(packet.packet)); + encoded + } +} + +impl TryFrom for Packet { + type Error = WispError; + fn try_from(frame: ws::Frame) -> Result { + if !frame.finished { + return Err(Self::Error::WsFrameNotFinished); + } + if frame.opcode != ws::OpCode::Binary { + return Err(Self::Error::WsFrameInvalidType); + } + frame.payload.try_into() + } +} + +impl From for ws::Frame { + fn from(packet: Packet) -> Self { + Self::binary(Vec::::from(packet).into()) + } +} diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs new file mode 100644 index 0000000..fbb1e56 --- /dev/null +++ b/wisp/src/ws.rs @@ -0,0 +1,40 @@ +use bytes::Bytes; + +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum OpCode { + Text, + Binary, + Close, +} + +pub struct Frame { + pub finished: bool, + pub opcode: OpCode, + pub payload: Bytes, +} + +impl Frame { + pub fn text(payload: Bytes) -> Self { + Self { + finished: true, + opcode: OpCode::Text, + payload, + } + } + + pub fn binary(payload: Bytes) -> Self { + Self { + finished: true, + opcode: OpCode::Binary, + payload, + } + } + + pub fn close(payload: Bytes) -> Self { + Self { + finished: true, + opcode: OpCode::Close, + payload, + } + } +} From 1f23c26db6ba43442f5326be03a252532585aadb Mon Sep 17 00:00:00 2001 From: r58Playz Date: Mon, 22 Jan 2024 18:19:39 -0800 Subject: [PATCH 02/26] wisp part 1 --- Cargo.lock | 20 ++++++++ server/Cargo.toml | 1 + server/src/lockedws.rs | 23 +++++++++ server/src/main.rs | 113 +++++++++++++++++++++++++---------------- wisp/src/lib.rs | 2 +- wisp/src/packet.rs | 4 +- 6 files changed, 117 insertions(+), 46 deletions(-) create mode 100644 server/src/lockedws.rs diff --git a/Cargo.lock b/Cargo.lock index 3db0195..1db3161 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -214,6 +214,19 @@ dependencies = [ "typenum", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "digest" version = "0.10.7" @@ -266,6 +279,7 @@ name = "epoxy-server" version = "1.0.0" dependencies = [ "bytes", + "dashmap", "fastwebsockets", "futures-util", "http-body-util", @@ -461,6 +475,12 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "hashbrown" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" + [[package]] name = "hermit-abi" version = "0.3.3" diff --git a/server/Cargo.toml b/server/Cargo.toml index c0a4a10..11b0915 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] bytes = "1.5.0" +dashmap = "5.5.3" fastwebsockets = { version = "0.6.0", features = ["upgrade", "simdutf8"] } futures-util = { version = "0.3.30", features = ["sink"] } http-body-util = "0.1.0" diff --git a/server/src/lockedws.rs b/server/src/lockedws.rs new file mode 100644 index 0000000..53c7fb8 --- /dev/null +++ b/server/src/lockedws.rs @@ -0,0 +1,23 @@ +use fastwebsockets::{FragmentCollector, Frame, WebSocketError}; +use hyper::upgrade::Upgraded; +use hyper_util::rt::TokioIo; +use std::sync::Arc; +use tokio::sync::Mutex; + +type Ws = FragmentCollector>; + +pub struct LockedWebSocket(Arc>); + +impl LockedWebSocket { + pub fn new(ws: Ws) -> Self { + Self(Arc::new(Mutex::new(ws))) + } + + pub async fn read_frame(&self) -> Result { + self.0.lock().await.read_frame().await + } + + pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WebSocketError> { + self.0.lock().await.write_frame(frame).await + } +} diff --git a/server/src/main.rs b/server/src/main.rs index 6318929..ef58f78 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,4 +1,6 @@ -use std::io::Error; +mod lockedws; + +use std::{io::Error, sync::Arc}; use bytes::Bytes; use fastwebsockets::{ @@ -11,9 +13,12 @@ use hyper::{ }; use hyper_util::rt::TokioIo; use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::mpsc; use tokio_native_tls::{native_tls, TlsAcceptor}; use tokio_util::codec::{BytesCodec, Framed}; +use wisp_mux::{ws, Packet, PacketType}; + type HttpBody = http_body_util::Empty; #[tokio::main(flavor = "multi_thread")] @@ -47,7 +52,8 @@ async fn main() -> Result<(), Error> { tokio::spawn(async move { let stream = acceptor_cloned.accept(stream).await.expect("not tls"); let io = TokioIo::new(stream); - let service = service_fn(move |res| accept_http(res, addr.to_string(), prefix_cloned.clone())); + let service = + service_fn(move |res| accept_http(res, addr.to_string(), prefix_cloned.clone())); let conn = http1::Builder::new() .serve_connection(io, service) .with_upgrades(); @@ -72,10 +78,13 @@ async fn accept_http( tokio::spawn(async move { if *uri.path() != prefix { if let Err(e) = - accept_wsproxy(fut, uri.path().to_string(), addr.clone(), prefix).await + accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone()) + .await { println!("{:?}: error in ws handling: {:?}", addr, e); } + } else if let Err(e) = accept_ws(fut, addr.clone()).await { + println!("{:?}: error in ws handling: {:?}", addr, e); } }); @@ -102,18 +111,60 @@ async fn accept_http( } } +enum WsEvent { + Send(Bytes), + Close, +} + +async fn accept_ws( + fut: upgrade::UpgradeFut, + addr: String, +) -> Result<(), Box> { + let ws_stream = lockedws::LockedWebSocket::new(FragmentCollector::new(fut.await?)); + + let stream_map = Arc::new(dashmap::DashMap::>::new()); + + println!("{:?}: connected", addr); + + while let Ok(mut frame) = ws_stream.read_frame().await { + use fastwebsockets::OpCode::*; + let frame = match frame.opcode { + Continuation => unreachable!(), + Text => ws::Frame::text(Bytes::copy_from_slice(frame.payload.to_mut())), + Binary => ws::Frame::binary(Bytes::copy_from_slice(frame.payload.to_mut())), + Close => ws::Frame::close(Bytes::copy_from_slice(frame.payload.to_mut())), + Ping => continue, + Pong => continue, + }; + if let Ok(packet) = Packet::try_from(frame) { + use PacketType::*; + match packet.packet { + Connect(inner_packet) => { + let (ch_tx, ch_rx) = mpsc::unbounded_channel::(); + stream_map.clone().insert(packet.stream_id, ch_tx); + tokio::spawn(async move { + + }); + } + Data(inner_packet) => {} + Continue(_) => unreachable!(), + Close(inner_packet) => {} + } + } + } + + println!("{:?}: disconnected", addr); + Ok(()) +} + async fn accept_wsproxy( fut: upgrade::UpgradeFut, - incoming_uri: String, + incoming_uri: &str, addr: String, - prefix: String, ) -> Result<(), Box> { let mut ws_stream = FragmentCollector::new(fut.await?); - // should always have prefix - let incoming_uri = incoming_uri.strip_prefix(&prefix).unwrap(); - - println!("{:?}: connected", addr); + println!("{:?}: connected (wsproxy)", addr); let tcp_stream = match TcpStream::connect(incoming_uri).await { Ok(stream) => stream, @@ -132,56 +183,33 @@ async fn accept_wsproxy( event = ws_stream.read_frame() => { match event { Ok(frame) => { - print!("{:?}: event ws - ", addr); match frame.opcode { OpCode::Text | OpCode::Binary => { - if tcp_stream_framed.send(Bytes::from(frame.payload.to_vec())).await.is_ok() { - println!("sent success"); - } else { - println!("sent FAILED"); - } + let _ = tcp_stream_framed.send(Bytes::from(frame.payload.to_vec())).await; } OpCode::Close => { - if as SinkExt>::close(&mut tcp_stream_framed).await.is_ok() { - println!("closed success"); - } else { - println!("closed FAILED"); - } + // tokio closes the stream for us + drop(tcp_stream_framed); break; } - _ => { - println!("ignored"); - } + _ => {} } }, - Err(err) => { - print!("{:?}: err in ws: {:?} - ", addr, err); - if as SinkExt>::close(&mut tcp_stream_framed).await.is_ok() { - println!("closed tcp success"); - } else { - println!("closed tcp FAILED"); - } + Err(_) => { + // tokio closes the stream for us + drop(tcp_stream_framed); break; } } }, event = tcp_stream_framed.next() => { if let Some(res) = event { - print!("{:?}: event tcp - ", addr); match res { Ok(buf) => { - if ws_stream.write_frame(Frame::binary(Payload::Owned(buf.to_vec()))).await.is_ok() { - println!("sent success"); - } else { - println!("sent FAILED"); - } + let _ = ws_stream.write_frame(Frame::binary(Payload::Borrowed(&buf))).await; } Err(_) => { - if ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"tcp side is going away")).await.is_ok() { - println!("closed success"); - } else { - println!("closed FAILED"); - } + let _ = ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"tcp side is going away")).await; } } } @@ -189,8 +217,7 @@ async fn accept_wsproxy( } } - println!("\"{}\": connection closed", addr); + println!("{:?}: disconnected (wsproxy)", addr); Ok(()) } - diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index c8147ae..d8ade1c 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -1,5 +1,5 @@ mod packet; -mod ws; +pub mod ws; pub use crate::packet::*; diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index b1091b8..2a9667f 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -139,8 +139,8 @@ impl From for Vec { #[derive(Debug)] pub struct Packet { - stream_id: u32, - packet: PacketType, + pub stream_id: u32, + pub packet: PacketType, } impl Packet { From 24d145cc6633d9f730ff16e7b21d0a892a708f27 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Mon, 22 Jan 2024 20:11:58 -0800 Subject: [PATCH 03/26] serverside done except it deadlocks --- Cargo.lock | 1 + server/Cargo.toml | 2 +- server/src/lockedws.rs | 1 + server/src/main.rs | 123 +++++++++++++++++++++++++++++++------ wisp/Cargo.toml | 6 +- wisp/src/fastwebsockets.rs | 40 ++++++++++++ wisp/src/lib.rs | 21 +++++++ wisp/src/packet.rs | 6 +- 8 files changed, 176 insertions(+), 24 deletions(-) create mode 100644 wisp/src/fastwebsockets.rs diff --git a/Cargo.lock b/Cargo.lock index 1db3161..0b35945 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1530,6 +1530,7 @@ name = "wisp-mux" version = "0.1.0" dependencies = [ "bytes", + "fastwebsockets", "futures", "futures-util", ] diff --git a/server/Cargo.toml b/server/Cargo.toml index 11b0915..75d8355 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -14,4 +14,4 @@ hyper-util = { version = "0.1.2", features = ["tokio"] } tokio = { version = "1.5.1", features = ["rt-multi-thread", "macros"] } tokio-native-tls = "0.3.1" tokio-util = { version = "0.7.10", features = ["codec"] } -wisp-mux = { path = "../wisp" } +wisp-mux = { path = "../wisp", features = ["fastwebsockets"] } diff --git a/server/src/lockedws.rs b/server/src/lockedws.rs index 53c7fb8..7cf1822 100644 --- a/server/src/lockedws.rs +++ b/server/src/lockedws.rs @@ -6,6 +6,7 @@ use tokio::sync::Mutex; type Ws = FragmentCollector>; +#[derive(Clone)] pub struct LockedWebSocket(Arc>); impl LockedWebSocket { diff --git a/server/src/main.rs b/server/src/main.rs index ef58f78..3906cf5 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -12,8 +12,10 @@ use hyper::{ Response, StatusCode, }; use hyper_util::rt::TokioIo; -use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::mpsc; +use tokio::{ + net::{TcpListener, TcpStream}, + sync::mpsc, +}; use tokio_native_tls::{native_tls, TlsAcceptor}; use tokio_util::codec::{BytesCodec, Framed}; @@ -120,35 +122,120 @@ async fn accept_ws( fut: upgrade::UpgradeFut, addr: String, ) -> Result<(), Box> { - let ws_stream = lockedws::LockedWebSocket::new(FragmentCollector::new(fut.await?)); + let ws_stream = FragmentCollector::new(fut.await?); + let ws_stream = lockedws::LockedWebSocket::new(ws_stream); let stream_map = Arc::new(dashmap::DashMap::>::new()); println!("{:?}: connected", addr); - while let Ok(mut frame) = ws_stream.read_frame().await { - use fastwebsockets::OpCode::*; - let frame = match frame.opcode { - Continuation => unreachable!(), - Text => ws::Frame::text(Bytes::copy_from_slice(frame.payload.to_mut())), - Binary => ws::Frame::binary(Bytes::copy_from_slice(frame.payload.to_mut())), - Close => ws::Frame::close(Bytes::copy_from_slice(frame.payload.to_mut())), - Ping => continue, - Pong => continue, - }; - if let Ok(packet) = Packet::try_from(frame) { + ws_stream + .write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)).into()) + .await?; + + while let Ok(frame) = ws_stream.read_frame().await { + if let Ok(packet) = Packet::try_from(ws::Frame::try_from(frame)?) { use PacketType::*; match packet.packet { Connect(inner_packet) => { - let (ch_tx, ch_rx) = mpsc::unbounded_channel::(); + let (ch_tx, mut ch_rx) = mpsc::unbounded_channel::(); stream_map.clone().insert(packet.stream_id, ch_tx); + let ws_stream_cloned = ws_stream.clone(); tokio::spawn(async move { - + let tcp_stream = match TcpStream::connect(format!( + "{}:{}", + inner_packet.destination_hostname, inner_packet.destination_port + )) + .await + { + Ok(stream) => stream, + Err(err) => { + ws_stream_cloned + .write_frame( + ws::Frame::from(Packet::new_close(packet.stream_id, 0x03)) + .into(), + ) + .await + .map_err(std::io::Error::other)?; + return Err(Box::new(err)); + } + }; + println!("muxing"); + let mut tcp_stream = Framed::new(tcp_stream, BytesCodec::new()); + loop { + tokio::select! { + event = tcp_stream.next() => { + println!("recvd"); + if let Some(res) = event { + match res { + Ok(buf) => { + ws_stream_cloned.write_frame( + ws::Frame::from( + Packet::new_data( + packet.stream_id, + buf.to_vec() + ) + ).into() + ).await.map_err(std::io::Error::other)?; + } + Err(err) => { + ws_stream_cloned + .write_frame( + ws::Frame::from(Packet::new_close( + packet.stream_id, + 0x03, + )) + .into(), + ) + .await + .map_err(std::io::Error::other)?; + return Err(Box::new(err)); + } + } + } + } + event = ch_rx.recv() => { + if let Some(event) = event { + match event { + WsEvent::Send(buf) => { + tcp_stream.send(buf).await?; + println!("sending"); + ws_stream_cloned + .write_frame( + ws::Frame::from( + Packet::new_continue( + packet.stream_id, + u32::MAX + ) + ).into() + ).await.map_err(std::io::Error::other)?; + println!("sent"); + } + WsEvent::Close => { + break; + } + } + } else { + break; + } + } + }; + } + Ok(()) }); } - Data(inner_packet) => {} + Data(inner_packet) => { + println!("recieved data for {:?}", packet.stream_id); + if let Some(stream) = stream_map.clone().get(&packet.stream_id) { + let _ = stream.send(WsEvent::Send(inner_packet.into())); + } + } Continue(_) => unreachable!(), - Close(inner_packet) => {} + Close(_) => { + if let Some(stream) = stream_map.clone().get(&packet.stream_id) { + let _ = stream.send(WsEvent::Close); + } + } } } } diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index a660280..693c91e 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -3,9 +3,11 @@ name = "wisp-mux" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] bytes = "1.5.0" +fastwebsockets = { version = "0.6.0", optional = true } futures = "0.3.30" futures-util = "0.3.30" + +[features] +fastwebsockets = ["dep:fastwebsockets"] diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs new file mode 100644 index 0000000..ba499ac --- /dev/null +++ b/wisp/src/fastwebsockets.rs @@ -0,0 +1,40 @@ +use bytes::Bytes; +use fastwebsockets::{Payload, Frame, OpCode}; + +impl TryFrom for crate::ws::OpCode { + type Error = crate::WispError; + fn try_from(opcode: OpCode) -> Result { + use OpCode::*; + match opcode { + Continuation => Err(Self::Error::WsImplNotSupported), + Text => Ok(Self::Text), + Binary => Ok(Self::Binary), + Close => Ok(Self::Close), + Ping => Err(Self::Error::WsImplNotSupported), + Pong => Err(Self::Error::WsImplNotSupported), + } + } +} + +impl TryFrom> for crate::ws::Frame { + type Error = crate::WispError; + fn try_from(mut frame: Frame) -> Result { + let opcode = frame.opcode.try_into()?; + Ok(Self { + finished: frame.fin, + opcode, + payload: Bytes::copy_from_slice(frame.payload.to_mut()), + }) + } +} + +impl From for Frame<'_> { + fn from(frame: crate::ws::Frame) -> Self { + use crate::ws::OpCode::*; + match frame.opcode { + Text => Self::text(Payload::Owned(frame.payload.to_vec())), + Binary => Self::binary(Payload::Owned(frame.payload.to_vec())), + Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())) + } + } +} diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index d8ade1c..897b7de 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "fastwebsockets")] +mod fastwebsockets; mod packet; pub mod ws; @@ -9,12 +11,14 @@ pub enum Role { Server, } +#[derive(Debug)] pub enum WispError { PacketTooSmall, InvalidPacketType, WsFrameInvalidType, WsFrameNotFinished, WsImplError(Box), + WsImplNotSupported, Utf8Error(std::str::Utf8Error), } @@ -23,3 +27,20 @@ impl From for WispError { WispError::Utf8Error(err) } } + +impl std::fmt::Display for WispError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + use WispError::*; + match self { + PacketTooSmall => write!(f, "Packet too small"), + InvalidPacketType => write!(f, "Invalid packet type"), + WsFrameInvalidType => write!(f, "Invalid websocket frame type"), + WsFrameNotFinished => write!(f, "Unfinished websocket frame"), + WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err), + WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"), + Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err), + } + } +} + +impl std::error::Error for WispError {} diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 2a9667f..1c5f177 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -4,9 +4,9 @@ use bytes::{Buf, BufMut, Bytes}; #[derive(Debug)] pub struct ConnectPacket { - stream_type: u8, - destination_port: u16, - destination_hostname: String, + pub stream_type: u8, + pub destination_port: u16, + pub destination_hostname: String, } impl ConnectPacket { From 379e07d643fbb82bca86a8ac82a1457abe6c9c46 Mon Sep 17 00:00:00 2001 From: r58Playz Date: Wed, 24 Jan 2024 13:19:57 -0800 Subject: [PATCH 04/26] wisp-server-rust --- server/Cargo.toml | 2 +- server/src/lockedws.rs | 14 +++++--------- server/src/main.rs | 23 +++++++++-------------- 3 files changed, 15 insertions(+), 24 deletions(-) diff --git a/server/Cargo.toml b/server/Cargo.toml index 75d8355..a0a64c1 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] bytes = "1.5.0" dashmap = "5.5.3" -fastwebsockets = { version = "0.6.0", features = ["upgrade", "simdutf8"] } +fastwebsockets = { version = "0.6.0", features = ["upgrade", "simdutf8", "unstable-split"] } futures-util = { version = "0.3.30", features = ["sink"] } http-body-util = "0.1.0" hyper = { version = "1.1.0", features = ["server", "http1"] } diff --git a/server/src/lockedws.rs b/server/src/lockedws.rs index 7cf1822..ffc8e13 100644 --- a/server/src/lockedws.rs +++ b/server/src/lockedws.rs @@ -1,23 +1,19 @@ -use fastwebsockets::{FragmentCollector, Frame, WebSocketError}; +use fastwebsockets::{WebSocketWrite, Frame, WebSocketError}; use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; use std::sync::Arc; -use tokio::sync::Mutex; +use tokio::{io::WriteHalf, sync::Mutex}; -type Ws = FragmentCollector>; +type Ws = WebSocketWrite>>; #[derive(Clone)] -pub struct LockedWebSocket(Arc>); +pub struct LockedWebSocketWrite(Arc>); -impl LockedWebSocket { +impl LockedWebSocketWrite { pub fn new(ws: Ws) -> Self { Self(Arc::new(Mutex::new(ws))) } - pub async fn read_frame(&self) -> Result { - self.0.lock().await.read_frame().await - } - pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WebSocketError> { self.0.lock().await.write_frame(frame).await } diff --git a/server/src/main.rs b/server/src/main.rs index 3906cf5..96e73c3 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -122,25 +122,24 @@ async fn accept_ws( fut: upgrade::UpgradeFut, addr: String, ) -> Result<(), Box> { - let ws_stream = FragmentCollector::new(fut.await?); - let ws_stream = lockedws::LockedWebSocket::new(ws_stream); + let (mut rx, tx) = fut.await?.split(tokio::io::split); + let tx = lockedws::LockedWebSocketWrite::new(tx); let stream_map = Arc::new(dashmap::DashMap::>::new()); println!("{:?}: connected", addr); - ws_stream - .write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)).into()) + tx.write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)).into()) .await?; - while let Ok(frame) = ws_stream.read_frame().await { + while let Ok(frame) = rx.read_frame(&mut |x| tx.write_frame(x)).await { if let Ok(packet) = Packet::try_from(ws::Frame::try_from(frame)?) { use PacketType::*; match packet.packet { Connect(inner_packet) => { let (ch_tx, mut ch_rx) = mpsc::unbounded_channel::(); stream_map.clone().insert(packet.stream_id, ch_tx); - let ws_stream_cloned = ws_stream.clone(); + let tx_cloned = tx.clone(); tokio::spawn(async move { let tcp_stream = match TcpStream::connect(format!( "{}:{}", @@ -150,7 +149,7 @@ async fn accept_ws( { Ok(stream) => stream, Err(err) => { - ws_stream_cloned + tx_cloned .write_frame( ws::Frame::from(Packet::new_close(packet.stream_id, 0x03)) .into(), @@ -160,16 +159,14 @@ async fn accept_ws( return Err(Box::new(err)); } }; - println!("muxing"); let mut tcp_stream = Framed::new(tcp_stream, BytesCodec::new()); loop { tokio::select! { event = tcp_stream.next() => { - println!("recvd"); if let Some(res) = event { match res { Ok(buf) => { - ws_stream_cloned.write_frame( + tx_cloned.write_frame( ws::Frame::from( Packet::new_data( packet.stream_id, @@ -179,7 +176,7 @@ async fn accept_ws( ).await.map_err(std::io::Error::other)?; } Err(err) => { - ws_stream_cloned + tx_cloned .write_frame( ws::Frame::from(Packet::new_close( packet.stream_id, @@ -199,8 +196,7 @@ async fn accept_ws( match event { WsEvent::Send(buf) => { tcp_stream.send(buf).await?; - println!("sending"); - ws_stream_cloned + tx_cloned .write_frame( ws::Frame::from( Packet::new_continue( @@ -209,7 +205,6 @@ async fn accept_ws( ) ).into() ).await.map_err(std::io::Error::other)?; - println!("sent"); } WsEvent::Close => { break; From 2a5684192ab8b51fb4bb6797577d3d80f27e3470 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 27 Jan 2024 18:57:04 -0800 Subject: [PATCH 05/26] move the wisp logic into wisp lib --- Cargo.lock | 2 + server/src/lockedws.rs | 20 ---- server/src/main.rs | 235 ++++++++++++++----------------------- wisp/Cargo.toml | 6 +- wisp/src/fastwebsockets.rs | 67 ++++++++--- wisp/src/lib.rs | 108 ++++++++++++++++- wisp/src/packet.rs | 36 ++++-- wisp/src/ws.rs | 38 ++++++ 8 files changed, 314 insertions(+), 198 deletions(-) delete mode 100644 server/src/lockedws.rs diff --git a/Cargo.lock b/Cargo.lock index 0b35945..ad0128c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1530,9 +1530,11 @@ name = "wisp-mux" version = "0.1.0" dependencies = [ "bytes", + "dashmap", "fastwebsockets", "futures", "futures-util", + "tokio", ] [[package]] diff --git a/server/src/lockedws.rs b/server/src/lockedws.rs deleted file mode 100644 index ffc8e13..0000000 --- a/server/src/lockedws.rs +++ /dev/null @@ -1,20 +0,0 @@ -use fastwebsockets::{WebSocketWrite, Frame, WebSocketError}; -use hyper::upgrade::Upgraded; -use hyper_util::rt::TokioIo; -use std::sync::Arc; -use tokio::{io::WriteHalf, sync::Mutex}; - -type Ws = WebSocketWrite>>; - -#[derive(Clone)] -pub struct LockedWebSocketWrite(Arc>); - -impl LockedWebSocketWrite { - pub fn new(ws: Ws) -> Self { - Self(Arc::new(Mutex::new(ws))) - } - - pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WebSocketError> { - self.0.lock().await.write_frame(frame).await - } -} diff --git a/server/src/main.rs b/server/src/main.rs index 96e73c3..4dcbf0f 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,25 +1,22 @@ -mod lockedws; - -use std::{io::Error, sync::Arc}; +#![feature(let_chains)] +use std::io::Error; use bytes::Bytes; use fastwebsockets::{ - upgrade, CloseCode, FragmentCollector, Frame, OpCode, Payload, WebSocketError, + upgrade, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, + WebSocketError, }; -use futures_util::{SinkExt, StreamExt}; +use futures_util::{SinkExt, StreamExt, TryFutureExt}; use hyper::{ body::Incoming, header::HeaderValue, server::conn::http1, service::service_fn, Request, Response, StatusCode, }; use hyper_util::rt::TokioIo; -use tokio::{ - net::{TcpListener, TcpStream}, - sync::mpsc, -}; +use tokio::net::{TcpListener, TcpStream}; use tokio_native_tls::{native_tls, TlsAcceptor}; use tokio_util::codec::{BytesCodec, Framed}; -use wisp_mux::{ws, Packet, PacketType}; +use wisp_mux::{ws, ConnectPacket, MuxStream, Packet, ServerMux, StreamType, WispError, WsEvent}; type HttpBody = http_body_util::Empty; @@ -73,37 +70,26 @@ async fn accept_http( addr: String, prefix: String, ) -> Result, WebSocketError> { - if upgrade::is_upgrade_request(&req) && req.uri().path().to_string().starts_with(&prefix) { + if upgrade::is_upgrade_request(&req) + && req.uri().path().to_string().starts_with(&prefix) + && let Some(protocol) = req.headers().get("Sec-Websocket-Protocol") + && protocol == "wisp-v1" + { let uri = req.uri().clone(); let (mut res, fut) = upgrade::upgrade(&mut req)?; - tokio::spawn(async move { - if *uri.path() != prefix { - if let Err(e) = - accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone()) - .await - { - println!("{:?}: error in ws handling: {:?}", addr, e); - } - } else if let Err(e) = accept_ws(fut, addr.clone()).await { - println!("{:?}: error in ws handling: {:?}", addr, e); - } - }); - - if let Some(protocol) = req.headers().get("Sec-Websocket-Protocol") { - let first_protocol = protocol - .to_str() - .expect("failed to get protocol") - .split(',') - .next() - .expect("failed to get first protocol") - .trim(); - res.headers_mut().insert( - "Sec-Websocket-Protocol", - HeaderValue::from_str(first_protocol).unwrap(), - ); + if *uri.path() != prefix { + tokio::spawn(async move { + accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone()).await + }); + } else { + tokio::spawn(async move { accept_ws(fut, addr.clone()).await }); } + res.headers_mut().insert( + "Sec-Websocket-Protocol", + HeaderValue::from_str("wisp-v1").unwrap(), + ); Ok(res) } else { Ok(Response::builder() @@ -113,127 +99,80 @@ async fn accept_http( } } -enum WsEvent { - Send(Bytes), - Close, +async fn handle_mux( + packet: ConnectPacket, + mut stream: MuxStream, +) -> Result<(), WispError> { + let uri = format!( + "{}:{}", + packet.destination_hostname, packet.destination_port + ); + match packet.stream_type { + StreamType::Tcp => { + let tcp_stream = TcpStream::connect(uri) + .await + .map_err(|x| WispError::Other(Box::new(x)))?; + let mut tcp_stream_framed = Framed::new(tcp_stream, BytesCodec::new()); + + loop { + tokio::select! { + event = stream.read() => { + match event { + Some(event) => match event { + WsEvent::Send(data) => { + tcp_stream_framed.send(data).await.map_err(|x| WispError::Other(Box::new(x)))?; + } + WsEvent::Close(_) => break, + }, + None => break + } + }, + event = tcp_stream_framed.next() => { + match event.and_then(|x| x.ok()) { + Some(event) => stream.write(event.into()).await?, + None => break + } + } + } + } + } + StreamType::Udp => todo!(), + } + Ok(()) } async fn accept_ws( fut: upgrade::UpgradeFut, addr: String, -) -> Result<(), Box> { - let (mut rx, tx) = fut.await?.split(tokio::io::split); - let tx = lockedws::LockedWebSocketWrite::new(tx); - - let stream_map = Arc::new(dashmap::DashMap::>::new()); +) -> Result<(), Box> { + let (rx, tx) = fut.await?.split(tokio::io::split); + let rx = FragmentCollectorRead::new(rx); println!("{:?}: connected", addr); - tx.write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)).into()) - .await?; + let mut mux = ServerMux::new(rx, tx); - while let Ok(frame) = rx.read_frame(&mut |x| tx.write_frame(x)).await { - if let Ok(packet) = Packet::try_from(ws::Frame::try_from(frame)?) { - use PacketType::*; - match packet.packet { - Connect(inner_packet) => { - let (ch_tx, mut ch_rx) = mpsc::unbounded_channel::(); - stream_map.clone().insert(packet.stream_id, ch_tx); - let tx_cloned = tx.clone(); - tokio::spawn(async move { - let tcp_stream = match TcpStream::connect(format!( - "{}:{}", - inner_packet.destination_hostname, inner_packet.destination_port - )) + mux.server_loop(&mut |packet, stream| async move { + let tx_cloned = stream.get_write_half(); + let stream_id = stream.stream_id; + tokio::spawn(async move { + let _ = handle_mux(packet, stream) + .or_else(|err| async { + let _ = tx_cloned + .write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x03))) + .await; + Err(err) + }) + .and_then(|_| async { + tx_cloned + .write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x02))) .await - { - Ok(stream) => stream, - Err(err) => { - tx_cloned - .write_frame( - ws::Frame::from(Packet::new_close(packet.stream_id, 0x03)) - .into(), - ) - .await - .map_err(std::io::Error::other)?; - return Err(Box::new(err)); - } - }; - let mut tcp_stream = Framed::new(tcp_stream, BytesCodec::new()); - loop { - tokio::select! { - event = tcp_stream.next() => { - if let Some(res) = event { - match res { - Ok(buf) => { - tx_cloned.write_frame( - ws::Frame::from( - Packet::new_data( - packet.stream_id, - buf.to_vec() - ) - ).into() - ).await.map_err(std::io::Error::other)?; - } - Err(err) => { - tx_cloned - .write_frame( - ws::Frame::from(Packet::new_close( - packet.stream_id, - 0x03, - )) - .into(), - ) - .await - .map_err(std::io::Error::other)?; - return Err(Box::new(err)); - } - } - } - } - event = ch_rx.recv() => { - if let Some(event) = event { - match event { - WsEvent::Send(buf) => { - tcp_stream.send(buf).await?; - tx_cloned - .write_frame( - ws::Frame::from( - Packet::new_continue( - packet.stream_id, - u32::MAX - ) - ).into() - ).await.map_err(std::io::Error::other)?; - } - WsEvent::Close => { - break; - } - } - } else { - break; - } - } - }; - } - Ok(()) - }); - } - Data(inner_packet) => { - println!("recieved data for {:?}", packet.stream_id); - if let Some(stream) = stream_map.clone().get(&packet.stream_id) { - let _ = stream.send(WsEvent::Send(inner_packet.into())); - } - } - Continue(_) => unreachable!(), - Close(_) => { - if let Some(stream) = stream_map.clone().get(&packet.stream_id) { - let _ = stream.send(WsEvent::Close); - } - } - } - } - } + }) + .await; + }); + Ok(()) + }) + .await?; println!("{:?}: disconnected", addr); Ok(()) @@ -243,7 +182,7 @@ async fn accept_wsproxy( fut: upgrade::UpgradeFut, incoming_uri: &str, addr: String, -) -> Result<(), Box> { +) -> Result<(), Box> { let mut ws_stream = FragmentCollector::new(fut.await?); println!("{:?}: connected (wsproxy)", addr); diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 693c91e..14d3e92 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -5,9 +5,11 @@ edition = "2021" [dependencies] bytes = "1.5.0" -fastwebsockets = { version = "0.6.0", optional = true } +dashmap = "5.5.3" +fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true } futures = "0.3.30" futures-util = "0.3.30" +tokio = { version = "1.35.1", optional = true } [features] -fastwebsockets = ["dep:fastwebsockets"] +fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index ba499ac..6aacb28 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -1,30 +1,30 @@ use bytes::Bytes; -use fastwebsockets::{Payload, Frame, OpCode}; +use fastwebsockets::{ + FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite, +}; +use tokio::io::{AsyncRead, AsyncWrite}; -impl TryFrom for crate::ws::OpCode { - type Error = crate::WispError; - fn try_from(opcode: OpCode) -> Result { +impl From for crate::ws::OpCode { + fn from(opcode: OpCode) -> Self { use OpCode::*; match opcode { - Continuation => Err(Self::Error::WsImplNotSupported), - Text => Ok(Self::Text), - Binary => Ok(Self::Binary), - Close => Ok(Self::Close), - Ping => Err(Self::Error::WsImplNotSupported), - Pong => Err(Self::Error::WsImplNotSupported), + Continuation => unreachable!(), + Text => Self::Text, + Binary => Self::Binary, + Close => Self::Close, + Ping => Self::Ping, + Pong => Self::Pong, } } } -impl TryFrom> for crate::ws::Frame { - type Error = crate::WispError; - fn try_from(mut frame: Frame) -> Result { - let opcode = frame.opcode.try_into()?; - Ok(Self { +impl From> for crate::ws::Frame { + fn from(mut frame: Frame) -> Self { + Self { finished: frame.fin, - opcode, + opcode: frame.opcode.into(), payload: Bytes::copy_from_slice(frame.payload.to_mut()), - }) + } } } @@ -34,7 +34,38 @@ impl From for Frame<'_> { match frame.opcode { Text => Self::text(Payload::Owned(frame.payload.to_vec())), Binary => Self::binary(Payload::Owned(frame.payload.to_vec())), - Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())) + Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())), + Ping => Self::new( + true, + OpCode::Ping, + None, + Payload::Owned(frame.payload.to_vec()), + ), + Pong => Self::pong(Payload::Owned(frame.payload.to_vec())), } } } + +impl From for crate::WispError { + fn from(err: WebSocketError) -> Self { + Self::WsImplError(Box::new(err)) + } +} + +impl crate::ws::WebSocketRead for FragmentCollectorRead { + async fn wisp_read_frame( + &mut self, + tx: &mut crate::ws::LockedWebSocketWrite, + ) -> Result { + Ok(self + .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) + .await? + .into()) + } +} + +impl crate::ws::WebSocketWrite for WebSocketWrite { + async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { + self.write_frame(frame.into()).await.map_err(|e| e.into()) + } +} diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 897b7de..c1318a5 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -5,6 +5,11 @@ pub mod ws; pub use crate::packet::*; +use bytes::Bytes; +use dashmap::DashMap; +use futures::{channel::mpsc, StreamExt}; +use std::sync::Arc; + #[derive(Debug, PartialEq)] pub enum Role { Client, @@ -15,11 +20,13 @@ pub enum Role { pub enum WispError { PacketTooSmall, InvalidPacketType, + InvalidStreamType, WsFrameInvalidType, WsFrameNotFinished, - WsImplError(Box), + WsImplError(Box), WsImplNotSupported, Utf8Error(std::str::Utf8Error), + Other(Box), } impl From for WispError { @@ -34,13 +41,112 @@ impl std::fmt::Display for WispError { match self { PacketTooSmall => write!(f, "Packet too small"), InvalidPacketType => write!(f, "Invalid packet type"), + InvalidStreamType => write!(f, "Invalid stream type"), WsFrameInvalidType => write!(f, "Invalid websocket frame type"), WsFrameNotFinished => write!(f, "Unfinished websocket frame"), WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err), WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"), Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err), + Other(err) => write!(f, "Other error: {:?}", err), } } } impl std::error::Error for WispError {} + +pub enum WsEvent { + Send(Bytes), + Close(ClosePacket), +} + +pub struct MuxStream +where + W: ws::WebSocketWrite, +{ + pub stream_id: u32, + rx: mpsc::UnboundedReceiver, + tx: ws::LockedWebSocketWrite, +} + +impl MuxStream { + pub async fn read(&mut self) -> Option { + self.rx.next().await + } + + pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> { + self.tx + .write_frame(ws::Frame::from(Packet::new_data(self.stream_id, data))) + .await + } + + pub fn get_write_half(&self) -> ws::LockedWebSocketWrite { + self.tx.clone() + } +} + +pub struct ServerMux +where + R: ws::WebSocketRead, + W: ws::WebSocketWrite, +{ + rx: R, + tx: ws::LockedWebSocketWrite, + stream_map: Arc>>, +} + +impl ServerMux { + pub fn new(read: R, write: W) -> Self { + Self { + rx: read, + tx: ws::LockedWebSocketWrite::new(write), + stream_map: Arc::new(DashMap::new()), + } + } + + pub async fn server_loop( + &mut self, + handler_fn: &mut impl Fn(ConnectPacket, MuxStream) -> FR, + ) -> Result<(), WispError> + where + FR: std::future::Future>, + { + self.tx + .write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX))) + .await?; + + while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await { + if let Ok(packet) = Packet::try_from(frame) { + use PacketType::*; + match packet.packet { + Connect(inner_packet) => { + let (ch_tx, ch_rx) = mpsc::unbounded(); + self.stream_map.clone().insert(packet.stream_id, ch_tx); + let _ = handler_fn( + inner_packet, + MuxStream { + stream_id: packet.stream_id, + rx: ch_rx, + tx: self.tx.clone(), + }, + ).await; + } + Data(data) => { + if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + let _ = stream.unbounded_send(WsEvent::Send(data)); + self.tx + .write_frame(ws::Frame::from(Packet::new_continue(packet.stream_id, u32::MAX))) + .await?; + } + } + Continue(_) => unreachable!(), + Close(inner_packet) => { + if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + let _ = stream.unbounded_send(WsEvent::Close(inner_packet)); + } + } + } + } + } + Ok(()) + } +} diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 1c5f177..98eb20e 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -2,15 +2,33 @@ use crate::ws; use crate::WispError; use bytes::{Buf, BufMut, Bytes}; +#[derive(Debug)] +pub enum StreamType { + Tcp = 0x01, + Udp = 0x02, +} + +impl TryFrom for StreamType { + type Error = WispError; + fn try_from(stream_type: u8) -> Result { + use StreamType::*; + match stream_type { + 0x01 => Ok(Tcp), + 0x02 => Ok(Udp), + _ => Err(Self::Error::InvalidStreamType), + } + } +} + #[derive(Debug)] pub struct ConnectPacket { - pub stream_type: u8, + pub stream_type: StreamType, pub destination_port: u16, pub destination_hostname: String, } impl ConnectPacket { - pub fn new(stream_type: u8, destination_port: u16, destination_hostname: String) -> Self { + pub fn new(stream_type: StreamType, destination_port: u16, destination_hostname: String) -> Self { Self { stream_type, destination_port, @@ -26,7 +44,7 @@ impl TryFrom for ConnectPacket { return Err(Self::Error::PacketTooSmall); } Ok(Self { - stream_type: bytes.get_u8(), + stream_type: bytes.get_u8().try_into()?, destination_port: bytes.get_u16_le(), destination_hostname: std::str::from_utf8(&bytes)?.to_string(), }) @@ -36,7 +54,7 @@ impl TryFrom for ConnectPacket { impl From for Vec { fn from(packet: ConnectPacket) -> Self { let mut encoded = Self::with_capacity(1 + 2 + packet.destination_hostname.len()); - encoded.put_u8(packet.stream_type); + encoded.put_u8(packet.stream_type as u8); encoded.put_u16_le(packet.destination_port); encoded.extend(packet.destination_hostname.bytes()); encoded @@ -108,7 +126,7 @@ impl From for Vec { #[derive(Debug)] pub enum PacketType { Connect(ConnectPacket), - Data(Vec), + Data(Bytes), Continue(ContinuePacket), Close(ClosePacket), } @@ -130,7 +148,7 @@ impl From for Vec { use PacketType::*; match packet { Connect(x) => x.into(), - Data(x) => x, + Data(x) => x.to_vec(), Continue(x) => x.into(), Close(x) => x.into(), } @@ -150,7 +168,7 @@ impl Packet { pub fn new_connect( stream_id: u32, - stream_type: u8, + stream_type: StreamType, destination_port: u16, destination_hostname: String, ) -> Self { @@ -164,7 +182,7 @@ impl Packet { } } - pub fn new_data(stream_id: u32, data: Vec) -> Self { + pub fn new_data(stream_id: u32, data: Bytes) -> Self { Self { stream_id, packet: PacketType::Data(data), @@ -198,7 +216,7 @@ impl TryFrom for Packet { stream_id: bytes.get_u32_le(), packet: match packet_type { 0x01 => Connect(ConnectPacket::try_from(bytes)?), - 0x02 => Data(bytes.to_vec()), + 0x02 => Data(bytes), 0x03 => Continue(ContinuePacket::try_from(bytes)?), 0x04 => Close(ClosePacket::try_from(bytes)?), _ => return Err(Self::Error::InvalidPacketType), diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index fbb1e56..dc8bdcc 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -1,10 +1,14 @@ use bytes::Bytes; +use futures::lock::Mutex; +use std::sync::Arc; #[derive(Debug, PartialEq, Clone, Copy)] pub enum OpCode { Text, Binary, Close, + Ping, + Pong, } pub struct Frame { @@ -38,3 +42,37 @@ impl Frame { } } } + +pub trait WebSocketRead { + fn wisp_read_frame( + &mut self, + tx: &mut crate::ws::LockedWebSocketWrite, + ) -> impl std::future::Future>; +} + +pub trait WebSocketWrite { + fn wisp_write_frame( + &mut self, + frame: Frame, + ) -> impl std::future::Future>; +} + +pub struct LockedWebSocketWrite(Arc>) +where + S: WebSocketWrite; + +impl LockedWebSocketWrite { + pub fn new(ws: S) -> Self { + Self(Arc::new(Mutex::new(ws))) + } + + pub async fn write_frame(&self, frame: Frame) -> Result<(), crate::WispError> { + self.0.lock().await.wisp_write_frame(frame).await + } +} + +impl Clone for LockedWebSocketWrite { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} From 29adf77a2e547f7d02104b04f068d33994f5503f Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 27 Jan 2024 19:35:54 -0800 Subject: [PATCH 06/26] untested udp support (example client doesn't support udp) --- server/src/main.rs | 57 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/server/src/main.rs b/server/src/main.rs index 4dcbf0f..f254b7f 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -12,7 +12,7 @@ use hyper::{ Response, StatusCode, }; use hyper_util::rt::TokioIo; -use tokio::net::{TcpListener, TcpStream}; +use tokio::net::{TcpListener, TcpStream, UdpSocket}; use tokio_native_tls::{native_tls, TlsAcceptor}; use tokio_util::codec::{BytesCodec, Framed}; @@ -102,7 +102,7 @@ async fn accept_http( async fn handle_mux( packet: ConnectPacket, mut stream: MuxStream, -) -> Result<(), WispError> { +) -> Result { let uri = format!( "{}:{}", packet.destination_hostname, packet.destination_port @@ -122,23 +122,47 @@ async fn handle_mux( WsEvent::Send(data) => { tcp_stream_framed.send(data).await.map_err(|x| WispError::Other(Box::new(x)))?; } - WsEvent::Close(_) => break, + WsEvent::Close(_) => return Ok(false), }, - None => break + None => break, } }, event = tcp_stream_framed.next() => { match event.and_then(|x| x.ok()) { Some(event) => stream.write(event.into()).await?, - None => break + None => return Ok(true), + } + } + } + } + } + StreamType::Udp => { + let udp_socket = UdpSocket::bind(uri) + .await + .map_err(|x| WispError::Other(Box::new(x)))?; + let mut data = vec![0u8; 65507]; // udp standard max datagram size + loop { + tokio::select! { + size = udp_socket.recv(&mut data).map_err(|x| WispError::Other(Box::new(x))) => { + let size = size?; + stream.write(Bytes::copy_from_slice(&data[..size])).await? + }, + event = stream.read() => { + match event { + Some(event) => match event { + WsEvent::Send(data) => { + udp_socket.send(&data).await.map_err(|x| WispError::Other(Box::new(x)))?; + } + WsEvent::Close(_) => return Ok(false), + }, + None => break, } } } } } - StreamType::Udp => todo!(), } - Ok(()) + Ok(false) } async fn accept_ws( @@ -153,20 +177,25 @@ async fn accept_ws( let mut mux = ServerMux::new(rx, tx); mux.server_loop(&mut |packet, stream| async move { - let tx_cloned = stream.get_write_half(); + let tx_cloned_err = stream.get_write_half(); + let tx_cloned_ok = stream.get_write_half(); let stream_id = stream.stream_id; tokio::spawn(async move { let _ = handle_mux(packet, stream) - .or_else(|err| async { - let _ = tx_cloned + .or_else(|err| async move { + let _ = tx_cloned_err .write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x03))) .await; Err(err) }) - .and_then(|_| async { - tx_cloned - .write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x02))) - .await + .and_then(|should_send| async move { + if should_send { + tx_cloned_ok + .write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x02))) + .await + } else { + Ok(()) + } }) .await; }); From 8f85828e737f66320b4bff38beda997aeb7608d0 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 27 Jan 2024 19:38:49 -0800 Subject: [PATCH 07/26] fix args --- server/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main.rs b/server/src/main.rs index f254b7f..43993fd 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -30,7 +30,7 @@ async fn main() -> Result<(), Error> { } else { "/".to_string() }; - let port = if let Some(prefix) = std::env::args().nth(1) { + let port = if let Some(prefix) = std::env::args().nth(2) { prefix } else { "4000".to_string() From e95d148488dce40fe0e57b0832137737283676e3 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sun, 28 Jan 2024 11:20:41 -0800 Subject: [PATCH 08/26] add wasm ws support --- Cargo.lock | 2 + client/Cargo.toml | 1 + server/src/main.rs | 17 ++- wisp/Cargo.toml | 2 + wisp/src/fastwebsockets.rs | 11 +- wisp/src/lib.rs | 212 +++++++++++++++++++++++++++++++++++-- wisp/src/ws_stream_wasm.rs | 57 ++++++++++ 7 files changed, 277 insertions(+), 25 deletions(-) create mode 100644 wisp/src/ws_stream_wasm.rs diff --git a/Cargo.lock b/Cargo.lock index ad0128c..6326ae1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -271,6 +271,7 @@ dependencies = [ "wasm-streams", "web-sys", "webpki-roots", + "wisp-mux", "ws_stream_wasm", ] @@ -1535,6 +1536,7 @@ dependencies = [ "futures", "futures-util", "tokio", + "ws_stream_wasm", ] [[package]] diff --git a/client/Cargo.toml b/client/Cargo.toml index 9fc6149..3a1b3bd 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -33,6 +33,7 @@ async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] fastwebsockets = { version = "0.6.0", features = ["simdutf8", "unstable-split"] } rand = "0.8.5" base64 = "0.21.7" +wisp-mux = { path = "../wisp", features = ["ws_stream_wasm"] } [dependencies.getrandom] features = ["js"] diff --git a/server/src/main.rs b/server/src/main.rs index 43993fd..11f6478 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -16,7 +16,7 @@ use tokio::net::{TcpListener, TcpStream, UdpSocket}; use tokio_native_tls::{native_tls, TlsAcceptor}; use tokio_util::codec::{BytesCodec, Framed}; -use wisp_mux::{ws, ConnectPacket, MuxStream, Packet, ServerMux, StreamType, WispError, WsEvent}; +use wisp_mux::{ws, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, WsEvent}; type HttpBody = http_body_util::Empty; @@ -162,7 +162,7 @@ async fn handle_mux( } } } - Ok(false) + Ok(true) } async fn accept_ws( @@ -177,22 +177,17 @@ async fn accept_ws( let mut mux = ServerMux::new(rx, tx); mux.server_loop(&mut |packet, stream| async move { - let tx_cloned_err = stream.get_write_half(); - let tx_cloned_ok = stream.get_write_half(); - let stream_id = stream.stream_id; + let mut close_err = stream.get_close_handle(); + let mut close_ok = stream.get_close_handle(); tokio::spawn(async move { let _ = handle_mux(packet, stream) .or_else(|err| async move { - let _ = tx_cloned_err - .write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x03))) - .await; + let _ = close_err.close(0x03).await; Err(err) }) .and_then(|should_send| async move { if should_send { - tx_cloned_ok - .write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x02))) - .await + close_ok.close(0x02).await } else { Ok(()) } diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 14d3e92..ae279ae 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -10,6 +10,8 @@ fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = futures = "0.3.30" futures-util = "0.3.30" tokio = { version = "1.35.1", optional = true } +ws_stream_wasm = { version = "0.7.4", optional = true } [features] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] +ws_stream_wasm = ["dep:ws_stream_wasm"] diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index 6aacb28..f020bfd 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -28,10 +28,11 @@ impl From> for crate::ws::Frame { } } -impl From for Frame<'_> { - fn from(frame: crate::ws::Frame) -> Self { +impl TryFrom for Frame<'_> { + type Error = crate::WispError; + fn try_from(frame: crate::ws::Frame) -> Result { use crate::ws::OpCode::*; - match frame.opcode { + Ok(match frame.opcode { Text => Self::text(Payload::Owned(frame.payload.to_vec())), Binary => Self::binary(Payload::Owned(frame.payload.to_vec())), Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())), @@ -42,7 +43,7 @@ impl From for Frame<'_> { Payload::Owned(frame.payload.to_vec()), ), Pong => Self::pong(Payload::Owned(frame.payload.to_vec())), - } + }) } } @@ -66,6 +67,6 @@ impl crate::ws::WebSocketRead for FragmentCollectorRead impl crate::ws::WebSocketWrite for WebSocketWrite { async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { - self.write_frame(frame.into()).await.map_err(|e| e.into()) + self.write_frame(frame.try_into()?).await.map_err(|e| e.into()) } } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index c1318a5..673e182 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -2,13 +2,18 @@ mod fastwebsockets; mod packet; pub mod ws; +#[cfg(feature = "ws_stream_wasm")] +mod ws_stream_wasm; pub use crate::packet::*; use bytes::Bytes; use dashmap::DashMap; -use futures::{channel::mpsc, StreamExt}; -use std::sync::Arc; +use futures::{channel::mpsc, channel::oneshot, SinkExt, StreamExt}; +use std::sync::{ + atomic::{AtomicBool, AtomicU32, Ordering}, + Arc, +}; #[derive(Debug, PartialEq)] pub enum Role { @@ -21,9 +26,13 @@ pub enum WispError { PacketTooSmall, InvalidPacketType, InvalidStreamType, + InvalidStreamId, + MaxStreamCountReached, + StreamAlreadyClosed, WsFrameInvalidType, WsFrameNotFinished, WsImplError(Box), + WsImplSocketClosed, WsImplNotSupported, Utf8Error(std::str::Utf8Error), Other(Box), @@ -42,9 +51,13 @@ impl std::fmt::Display for WispError { PacketTooSmall => write!(f, "Packet too small"), InvalidPacketType => write!(f, "Invalid packet type"), InvalidStreamType => write!(f, "Invalid stream type"), + InvalidStreamId => write!(f, "Invalid stream id"), + MaxStreamCountReached => write!(f, "Maximum stream count reached"), + StreamAlreadyClosed => write!(f, "Stream already closed"), WsFrameInvalidType => write!(f, "Invalid websocket frame type"), WsFrameNotFinished => write!(f, "Unfinished websocket frame"), WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err), + WsImplSocketClosed => write!(f, "Websocket implementation error: websocket closed"), WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"), Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err), Other(err) => write!(f, "Other error: {:?}", err), @@ -59,6 +72,10 @@ pub enum WsEvent { Close(ClosePacket), } +pub enum MuxEvent { + Close(u32, u8, oneshot::Sender>), +} + pub struct MuxStream where W: ws::WebSocketWrite, @@ -66,21 +83,75 @@ where pub stream_id: u32, rx: mpsc::UnboundedReceiver, tx: ws::LockedWebSocketWrite, + close_channel: mpsc::UnboundedSender, + is_closed: Arc, } impl MuxStream { pub async fn read(&mut self) -> Option { - self.rx.next().await + if self.is_closed.load(Ordering::Acquire) { + return None; + } + match self.rx.next().await? { + WsEvent::Send(bytes) => Some(WsEvent::Send(bytes)), + WsEvent::Close(packet) => { + self.is_closed.store(true, Ordering::Release); + Some(WsEvent::Close(packet)) + } + } } pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(WispError::StreamAlreadyClosed); + } self.tx - .write_frame(ws::Frame::from(Packet::new_data(self.stream_id, data))) + .write_frame(Packet::new_data(self.stream_id, data).into()) .await } - pub fn get_write_half(&self) -> ws::LockedWebSocketWrite { - self.tx.clone() + pub fn get_close_handle(&self) -> MuxStreamCloser { + MuxStreamCloser { + stream_id: self.stream_id, + close_channel: self.close_channel.clone(), + is_closed: self.is_closed.clone(), + } + } + + pub async fn close(&mut self, reason: u8) -> Result<(), WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(WispError::StreamAlreadyClosed); + } + let (tx, rx) = oneshot::channel::>(); + self.close_channel + .send(MuxEvent::Close(self.stream_id, reason, tx)) + .await + .map_err(|x| WispError::Other(Box::new(x)))?; + rx.await.map_err(|x| WispError::Other(Box::new(x)))??; + self.is_closed.store(true, Ordering::Release); + Ok(()) + } +} + +pub struct MuxStreamCloser { + stream_id: u32, + close_channel: mpsc::UnboundedSender, + is_closed: Arc, +} + +impl MuxStreamCloser { + pub async fn close(&mut self, reason: u8) -> Result<(), WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(WispError::StreamAlreadyClosed); + } + let (tx, rx) = oneshot::channel::>(); + self.close_channel + .send(MuxEvent::Close(self.stream_id, reason, tx)) + .await + .map_err(|x| WispError::Other(Box::new(x)))?; + rx.await.map_err(|x| WispError::Other(Box::new(x)))??; + self.is_closed.store(true, Ordering::Release); + Ok(()) } } @@ -92,14 +163,37 @@ where rx: R, tx: ws::LockedWebSocketWrite, stream_map: Arc>>, + close_rx: mpsc::UnboundedReceiver, + close_tx: mpsc::UnboundedSender, } impl ServerMux { pub fn new(read: R, write: W) -> Self { + let (tx, rx) = mpsc::unbounded::(); Self { rx: read, tx: ws::LockedWebSocketWrite::new(write), stream_map: Arc::new(DashMap::new()), + close_rx: rx, + close_tx: tx, + } + } + + pub async fn server_bg_loop(&mut self) { + while let Some(msg) = self.close_rx.next().await { + match msg { + MuxEvent::Close(stream_id, reason, channel) => { + if self.stream_map.clone().remove(&stream_id).is_some() { + let _ = channel.send( + self.tx + .write_frame(Packet::new_close(stream_id, reason).into()) + .await, + ); + } else { + let _ = channel.send(Err(WispError::InvalidStreamId)); + } + } + } } } @@ -111,7 +205,7 @@ impl ServerMux { FR: std::future::Future>, { self.tx - .write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX))) + .write_frame(Packet::new_continue(0, u32::MAX).into()) .await?; while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await { @@ -127,14 +221,19 @@ impl ServerMux { stream_id: packet.stream_id, rx: ch_rx, tx: self.tx.clone(), + close_channel: self.close_tx.clone(), + is_closed: AtomicBool::new(false).into(), }, - ).await; + ) + .await; } Data(data) => { if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { let _ = stream.unbounded_send(WsEvent::Send(data)); self.tx - .write_frame(ws::Frame::from(Packet::new_continue(packet.stream_id, u32::MAX))) + .write_frame( + Packet::new_continue(packet.stream_id, u32::MAX).into(), + ) .await?; } } @@ -142,6 +241,7 @@ impl ServerMux { Close(inner_packet) => { if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { let _ = stream.unbounded_send(WsEvent::Close(inner_packet)); + self.stream_map.clone().remove(&packet.stream_id); } } } @@ -150,3 +250,97 @@ impl ServerMux { Ok(()) } } + +pub struct ClientMux +where + R: ws::WebSocketRead, + W: ws::WebSocketWrite, +{ + rx: R, + tx: ws::LockedWebSocketWrite, + stream_map: Arc>>, + next_free_stream_id: AtomicU32, + close_rx: mpsc::UnboundedReceiver, + close_tx: mpsc::UnboundedSender, +} + +impl ClientMux { + pub fn new(read: R, write: W) -> Self { + let (tx, rx) = mpsc::unbounded::(); + Self { + rx: read, + tx: ws::LockedWebSocketWrite::new(write), + stream_map: Arc::new(DashMap::new()), + next_free_stream_id: AtomicU32::new(1), + close_rx: rx, + close_tx: tx, + } + } + + pub async fn client_bg_loop(&mut self) { + while let Some(msg) = self.close_rx.next().await { + match msg { + MuxEvent::Close(stream_id, reason, channel) => { + if self.stream_map.clone().remove(&stream_id).is_some() { + let _ = channel.send( + self.tx + .write_frame(Packet::new_close(stream_id, reason).into()) + .await, + ); + } else { + let _ = channel.send(Err(WispError::InvalidStreamId)); + } + } + } + } + } + + pub async fn client_loop(&mut self) -> Result<(), WispError> { + self.tx + .write_frame(Packet::new_continue(0, u32::MAX).into()) + .await?; + + while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await { + if let Ok(packet) = Packet::try_from(frame) { + use PacketType::*; + match packet.packet { + Connect(_) => unreachable!(), + Data(data) => { + if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + let _ = stream.unbounded_send(WsEvent::Send(data)); + } + } + Continue(_) => {} + Close(inner_packet) => { + if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + let _ = stream.unbounded_send(WsEvent::Close(inner_packet)); + self.stream_map.clone().remove(&packet.stream_id); + } + } + } + } + } + Ok(()) + } + + pub async fn client_new_stream( + &mut self, + ) -> Result, WispError> { + let (ch_tx, ch_rx) = mpsc::unbounded(); + let stream_id = self.next_free_stream_id.load(Ordering::Acquire); + self.next_free_stream_id.store( + stream_id + .checked_add(1) + .ok_or(WispError::MaxStreamCountReached)?, + Ordering::Release, + ); + self.stream_map.clone().insert(stream_id, ch_tx); + Ok(MuxStream { + stream_id, + rx: ch_rx, + tx: self.tx.clone(), + close_channel: self.close_tx.clone(), + is_closed: AtomicBool::new(false).into(), + }) + } +} diff --git a/wisp/src/ws_stream_wasm.rs b/wisp/src/ws_stream_wasm.rs new file mode 100644 index 0000000..6e15816 --- /dev/null +++ b/wisp/src/ws_stream_wasm.rs @@ -0,0 +1,57 @@ +use futures::{SinkExt, StreamExt}; +use ws_stream_wasm::{WsErr, WsMessage, WsStream}; + +impl From for crate::ws::Frame { + fn from(msg: WsMessage) -> Self { + use crate::ws::OpCode; + match msg { + WsMessage::Text(str) => Self { + finished: true, + opcode: OpCode::Text, + payload: str.into(), + }, + WsMessage::Binary(bin) => Self { + finished: true, + opcode: OpCode::Binary, + payload: bin.into(), + }, + } + } +} + +impl TryFrom for WsMessage { + type Error = crate::WispError; + fn try_from(msg: crate::ws::Frame) -> Result { + use crate::ws::OpCode; + match msg.opcode { + OpCode::Text => Ok(Self::Text(std::str::from_utf8(&msg.payload)?.to_string())), + OpCode::Binary => Ok(Self::Binary(msg.payload.to_vec())), + _ => Err(Self::Error::WsImplNotSupported), + } + } +} + +impl From for crate::WispError { + fn from(err: WsErr) -> Self { + Self::WsImplError(Box::new(err)) + } +} + +impl crate::ws::WebSocketRead for WsStream { + async fn wisp_read_frame( + &mut self, + _: &mut crate::ws::LockedWebSocketWrite, + ) -> Result { + Ok(self + .next() + .await + .ok_or(crate::WispError::WsImplSocketClosed)? + .into()) + } +} + +impl crate::ws::WebSocketWrite for WsStream { + async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { + self.send(frame.try_into()?).await.map_err(|e| e.into()) + } +} From 14ddecf3fded2e4d90742e85a6b1ebbe90eb74fb Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Mon, 29 Jan 2024 09:04:15 -0800 Subject: [PATCH 09/26] split stream --- wisp/src/lib.rs | 119 +++++--------------------------- wisp/src/stream.rs | 168 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 184 insertions(+), 103 deletions(-) create mode 100644 wisp/src/stream.rs diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 673e182..2eb0594 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -1,15 +1,16 @@ #[cfg(feature = "fastwebsockets")] mod fastwebsockets; mod packet; +mod stream; pub mod ws; #[cfg(feature = "ws_stream_wasm")] mod ws_stream_wasm; pub use crate::packet::*; +pub use crate::stream::*; -use bytes::Bytes; use dashmap::DashMap; -use futures::{channel::mpsc, channel::oneshot, SinkExt, StreamExt}; +use futures::{channel::mpsc, StreamExt}; use std::sync::{ atomic::{AtomicBool, AtomicU32, Ordering}, Arc, @@ -67,94 +68,6 @@ impl std::fmt::Display for WispError { impl std::error::Error for WispError {} -pub enum WsEvent { - Send(Bytes), - Close(ClosePacket), -} - -pub enum MuxEvent { - Close(u32, u8, oneshot::Sender>), -} - -pub struct MuxStream -where - W: ws::WebSocketWrite, -{ - pub stream_id: u32, - rx: mpsc::UnboundedReceiver, - tx: ws::LockedWebSocketWrite, - close_channel: mpsc::UnboundedSender, - is_closed: Arc, -} - -impl MuxStream { - pub async fn read(&mut self) -> Option { - if self.is_closed.load(Ordering::Acquire) { - return None; - } - match self.rx.next().await? { - WsEvent::Send(bytes) => Some(WsEvent::Send(bytes)), - WsEvent::Close(packet) => { - self.is_closed.store(true, Ordering::Release); - Some(WsEvent::Close(packet)) - } - } - } - - pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> { - if self.is_closed.load(Ordering::Acquire) { - return Err(WispError::StreamAlreadyClosed); - } - self.tx - .write_frame(Packet::new_data(self.stream_id, data).into()) - .await - } - - pub fn get_close_handle(&self) -> MuxStreamCloser { - MuxStreamCloser { - stream_id: self.stream_id, - close_channel: self.close_channel.clone(), - is_closed: self.is_closed.clone(), - } - } - - pub async fn close(&mut self, reason: u8) -> Result<(), WispError> { - if self.is_closed.load(Ordering::Acquire) { - return Err(WispError::StreamAlreadyClosed); - } - let (tx, rx) = oneshot::channel::>(); - self.close_channel - .send(MuxEvent::Close(self.stream_id, reason, tx)) - .await - .map_err(|x| WispError::Other(Box::new(x)))?; - rx.await.map_err(|x| WispError::Other(Box::new(x)))??; - self.is_closed.store(true, Ordering::Release); - Ok(()) - } -} - -pub struct MuxStreamCloser { - stream_id: u32, - close_channel: mpsc::UnboundedSender, - is_closed: Arc, -} - -impl MuxStreamCloser { - pub async fn close(&mut self, reason: u8) -> Result<(), WispError> { - if self.is_closed.load(Ordering::Acquire) { - return Err(WispError::StreamAlreadyClosed); - } - let (tx, rx) = oneshot::channel::>(); - self.close_channel - .send(MuxEvent::Close(self.stream_id, reason, tx)) - .await - .map_err(|x| WispError::Other(Box::new(x)))?; - rx.await.map_err(|x| WispError::Other(Box::new(x)))??; - self.is_closed.store(true, Ordering::Release); - Ok(()) - } -} - pub struct ServerMux where R: ws::WebSocketRead, @@ -217,13 +130,13 @@ impl ServerMux { self.stream_map.clone().insert(packet.stream_id, ch_tx); let _ = handler_fn( inner_packet, - MuxStream { - stream_id: packet.stream_id, - rx: ch_rx, - tx: self.tx.clone(), - close_channel: self.close_tx.clone(), - is_closed: AtomicBool::new(false).into(), - }, + MuxStream::new( + packet.stream_id, + ch_rx, + self.tx.clone(), + self.close_tx.clone(), + AtomicBool::new(false).into(), + ), ) .await; } @@ -335,12 +248,12 @@ impl ClientMux { Ordering::Release, ); self.stream_map.clone().insert(stream_id, ch_tx); - Ok(MuxStream { + Ok(MuxStream::new( stream_id, - rx: ch_rx, - tx: self.tx.clone(), - close_channel: self.close_tx.clone(), - is_closed: AtomicBool::new(false).into(), - }) + ch_rx, + self.tx.clone(), + self.close_tx.clone(), + AtomicBool::new(false).into(), + )) } } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs new file mode 100644 index 0000000..8c8a76b --- /dev/null +++ b/wisp/src/stream.rs @@ -0,0 +1,168 @@ +use bytes::Bytes; +use futures::{ + channel::{mpsc, oneshot}, + StreamExt, +}; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; + +pub enum WsEvent { + Send(Bytes), + Close(crate::ClosePacket), +} + +pub enum MuxEvent { + Close(u32, u8, oneshot::Sender>), +} + +pub struct MuxStreamRead { + pub stream_id: u32, + rx: mpsc::UnboundedReceiver, + is_closed: Arc, +} + +impl MuxStreamRead { + pub async fn read(&mut self) -> Option { + if self.is_closed.load(Ordering::Acquire) { + return None; + } + match self.rx.next().await? { + WsEvent::Send(bytes) => Some(WsEvent::Send(bytes)), + WsEvent::Close(packet) => { + self.is_closed.store(true, Ordering::Release); + Some(WsEvent::Close(packet)) + } + } + } +} + +pub struct MuxStreamWrite +where + W: crate::ws::WebSocketWrite, +{ + pub stream_id: u32, + tx: crate::ws::LockedWebSocketWrite, + close_channel: mpsc::UnboundedSender, + is_closed: Arc, +} + +impl MuxStreamWrite { + pub async fn write(&mut self, data: Bytes) -> Result<(), crate::WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(crate::WispError::StreamAlreadyClosed); + } + self.tx + .write_frame(crate::Packet::new_data(self.stream_id, data).into()) + .await + } + + pub fn get_close_handle(&self) -> MuxStreamCloser { + MuxStreamCloser { + stream_id: self.stream_id, + close_channel: self.close_channel.clone(), + is_closed: self.is_closed.clone(), + } + } + + pub async fn close(&mut self, reason: u8) -> Result<(), crate::WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(crate::WispError::StreamAlreadyClosed); + } + let (tx, rx) = oneshot::channel::>(); + self.close_channel + .unbounded_send(MuxEvent::Close(self.stream_id, reason, tx)) + .map_err(|x| crate::WispError::Other(Box::new(x)))?; + rx.await + .map_err(|x| crate::WispError::Other(Box::new(x)))??; + + self.is_closed.store(true, Ordering::Release); + Ok(()) + } +} + +impl Drop for MuxStreamWrite { + fn drop(&mut self) { + let (tx, _) = oneshot::channel::>(); + let _ = self + .close_channel + .unbounded_send(MuxEvent::Close(self.stream_id, 0x01, tx)); + } +} + +pub struct MuxStream +where + W: crate::ws::WebSocketWrite, +{ + pub stream_id: u32, + rx: MuxStreamRead, + tx: MuxStreamWrite, +} + +impl MuxStream { + pub(crate) fn new( + stream_id: u32, + rx: mpsc::UnboundedReceiver, + tx: crate::ws::LockedWebSocketWrite, + close_channel: mpsc::UnboundedSender, + is_closed: Arc, + ) -> Self { + Self { + stream_id, + rx: MuxStreamRead { + stream_id, + rx, + is_closed: is_closed.clone(), + }, + tx: MuxStreamWrite { + stream_id, + tx, + close_channel, + is_closed: is_closed.clone(), + }, + } + } + + pub async fn read(&mut self) -> Option { + self.rx.read().await + } + + pub async fn write(&mut self, data: Bytes) -> Result<(), crate::WispError> { + self.tx.write(data).await + } + + pub fn get_close_handle(&self) -> MuxStreamCloser { + self.tx.get_close_handle() + } + + pub async fn close(&mut self, reason: u8) -> Result<(), crate::WispError> { + self.tx.close(reason).await + } + + pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) { + (self.rx, self.tx) + } +} + +pub struct MuxStreamCloser { + stream_id: u32, + close_channel: mpsc::UnboundedSender, + is_closed: Arc, +} + +impl MuxStreamCloser { + pub async fn close(&mut self, reason: u8) -> Result<(), crate::WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(crate::WispError::StreamAlreadyClosed); + } + let (tx, rx) = oneshot::channel::>(); + self.close_channel + .unbounded_send(MuxEvent::Close(self.stream_id, reason, tx)) + .map_err(|x| crate::WispError::Other(Box::new(x)))?; + rx.await + .map_err(|x| crate::WispError::Other(Box::new(x)))??; + self.is_closed.store(true, Ordering::Release); + Ok(()) + } +} From be7d92b4c5d95cb1ef01a844f37975ac7f235de1 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Mon, 29 Jan 2024 19:30:55 -0800 Subject: [PATCH 10/26] finally implement AsyncRead/Write --- Cargo.lock | 2 + wisp/Cargo.toml | 2 + wisp/src/stream.rs | 102 +++++++++++++++++++++++++++++++++++++++++++-- wisp/src/ws.rs | 4 +- 4 files changed, 103 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6326ae1..dc862b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1530,11 +1530,13 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" name = "wisp-mux" version = "0.1.0" dependencies = [ + "async_io_stream", "bytes", "dashmap", "fastwebsockets", "futures", "futures-util", + "pin-project-lite", "tokio", "ws_stream_wasm", ] diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index ae279ae..9dc0a2d 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -4,11 +4,13 @@ version = "0.1.0" edition = "2021" [dependencies] +async_io_stream = "0.3.3" bytes = "1.5.0" dashmap = "5.5.3" fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true } futures = "0.3.30" futures-util = "0.3.30" +pin-project-lite = "0.2.13" tokio = { version = "1.35.1", optional = true } ws_stream_wasm = { version = "0.7.4", optional = true } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 8c8a76b..ff86585 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -1,11 +1,18 @@ +use async_io_stream::IoStream; use bytes::Bytes; use futures::{ channel::{mpsc, oneshot}, - StreamExt, + sink, stream, + task::{Context, Poll}, + AsyncRead, AsyncWrite, Sink, Stream, StreamExt, }; -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, +use pin_project_lite::pin_project; +use std::{ + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, }; pub enum WsEvent { @@ -36,6 +43,19 @@ impl MuxStreamRead { } } } + + pub(crate) fn into_stream(self) -> Pin>> { + Box::pin(stream::unfold(self, |mut rx| async move { + let evt = rx.read().await?; + Some(( + match evt { + WsEvent::Send(bytes) => bytes, + WsEvent::Close(_) => return None, + }, + rx, + )) + })) + } } pub struct MuxStreamWrite @@ -80,6 +100,16 @@ impl MuxStreamWrite { self.is_closed.store(true, Ordering::Release); Ok(()) } + + pub(crate) fn into_sink<'a>(self) -> Pin + 'a>> + where + W: 'a, + { + Box::pin(sink::unfold(self, |mut tx, data| async move { + tx.write(data).await?; + Ok(tx) + })) + } } impl Drop for MuxStreamWrite { @@ -143,6 +173,16 @@ impl MuxStream { pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) { (self.rx, self.tx) } + + pub fn into_io<'a>(self) -> MuxStreamIo<'a> + where + W: 'a, + { + MuxStreamIo { + rx: self.rx.into_stream(), + tx: self.tx.into_sink(), + } + } } pub struct MuxStreamCloser { @@ -166,3 +206,57 @@ impl MuxStreamCloser { Ok(()) } } + +pin_project! { + pub struct MuxStreamIo<'a> { + #[pin] + rx: Pin + 'a>>, + #[pin] + tx: Pin + 'a>>, + } +} + +impl<'a> MuxStreamIo<'a> { + pub fn into_asyncrw(self) -> impl AsyncRead + AsyncWrite + 'a { + IoStream::new(self.map(|x| Ok::, std::io::Error>(x.to_vec()))) + } +} + +impl Stream for MuxStreamIo<'_> { + type Item = Bytes; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().rx.poll_next(cx) + } +} + +impl Sink for MuxStreamIo<'_> { + type Error = crate::WispError; + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_ready(cx) + } + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.project().tx.start_send(item) + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_flush(cx) + } + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_close(cx) + } +} + +impl Sink> for MuxStreamIo<'_> { + type Error = std::io::Error; + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_ready(cx).map_err(std::io::Error::other) + } + fn start_send(self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { + self.project().tx.start_send(item.into()).map_err(std::io::Error::other) + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_flush(cx).map_err(std::io::Error::other) + } + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_close(cx).map_err(std::io::Error::other) + } +} diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index dc8bdcc..5b1243e 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -57,9 +57,7 @@ pub trait WebSocketWrite { ) -> impl std::future::Future>; } -pub struct LockedWebSocketWrite(Arc>) -where - S: WebSocketWrite; +pub struct LockedWebSocketWrite(Arc>); impl LockedWebSocketWrite { pub fn new(ws: S) -> Self { From c5cf95fcb16519552679f59ebdf275082cc3125a Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Tue, 30 Jan 2024 21:15:17 -0800 Subject: [PATCH 11/26] add wisp to client --- Cargo.lock | 95 ++-------------------- client/Cargo.toml | 4 +- client/src/lib.rs | 46 +++++++---- client/src/websocket.rs | 2 +- client/src/wrappers.rs | 110 +------------------------ server/src/main.rs | 6 +- wisp/Cargo.toml | 1 + wisp/src/fastwebsockets.rs | 6 +- wisp/src/lib.rs | 162 ++++++++++++++++++++++++++----------- wisp/src/stream.rs | 77 ++++++++---------- wisp/src/ws.rs | 8 +- wisp/src/ws_stream_wasm.rs | 13 +-- 12 files changed, 210 insertions(+), 320 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dc862b4..eec8ba9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -133,12 +133,6 @@ version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" -[[package]] -name = "byteorder" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" - [[package]] name = "bytes" version = "1.5.0" @@ -248,6 +242,7 @@ name = "epoxy-client" version = "1.0.0" dependencies = [ "async-compression", + "async_io_stream", "base64", "bytes", "console_error_panic_hook", @@ -255,11 +250,10 @@ dependencies = [ "fastwebsockets", "futures-util", "getrandom", - "http 1.0.0", + "http", "http-body-util", "hyper", "js-sys", - "penguin-mux-wasm", "pin-project-lite", "rand", "ring", @@ -488,17 +482,6 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" -[[package]] -name = "http" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" -dependencies = [ - "bytes", - "fnv", - "itoa", -] - [[package]] name = "http" version = "1.0.0" @@ -517,7 +500,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" dependencies = [ "bytes", - "http 1.0.0", + "http", ] [[package]] @@ -528,7 +511,7 @@ checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840" dependencies = [ "bytes", "futures-util", - "http 1.0.0", + "http", "http-body", "pin-project-lite", ] @@ -554,7 +537,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.0.0", + "http", "http-body", "httparse", "httpdate", @@ -573,7 +556,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.0.0", + "http", "http-body", "hyper", "pin-project-lite", @@ -744,16 +727,6 @@ dependencies = [ "vcpkg", ] -[[package]] -name = "parking_lot" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" -dependencies = [ - "lock_api", - "parking_lot_core", -] - [[package]] name = "parking_lot_core" version = "0.9.9" @@ -767,23 +740,6 @@ dependencies = [ "windows-targets 0.48.5", ] -[[package]] -name = "penguin-mux-wasm" -version = "0.1.0" -source = "git+https://github.com/r58Playz/penguin-mux-wasm#69b413aedb6f50f55eac646fda361abe430eb022" -dependencies = [ - "bytes", - "futures-util", - "http 0.2.11", - "parking_lot", - "rand", - "thiserror", - "tokio", - "tokio-tungstenite", - "tracing", - "wasm-bindgen-futures", -] - [[package]] name = "pharos" version = "0.5.3" @@ -1129,7 +1085,6 @@ dependencies = [ "libc", "mio", "num_cpus", - "parking_lot", "pin-project-lite", "socket2", "tokio-macros", @@ -1168,18 +1123,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-tungstenite" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" -dependencies = [ - "futures-util", - "log", - "tokio", - "tungstenite", -] - [[package]] name = "tokio-util" version = "0.7.10" @@ -1201,21 +1144,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ "pin-project-lite", - "tracing-attributes", "tracing-core", ] -[[package]] -name = "tracing-attributes" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "tracing-core" version = "0.1.32" @@ -1231,20 +1162,6 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" -[[package]] -name = "tungstenite" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" -dependencies = [ - "byteorder", - "bytes", - "log", - "rand", - "thiserror", - "utf-8", -] - [[package]] name = "typenum" version = "1.17.0" diff --git a/client/Cargo.toml b/client/Cargo.toml index 3a1b3bd..48ff0e5 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -16,7 +16,6 @@ http = "1.0.0" http-body-util = "0.1.0" hyper = { version = "1.1.0", features = ["client", "http1"] } pin-project-lite = "0.2.13" -penguin-mux-wasm = { git = "https://github.com/r58Playz/penguin-mux-wasm" } tokio = { version = "1.35.1", default_features = false } wasm-bindgen = "0.2" wasm-bindgen-futures = "0.4.39" @@ -33,7 +32,8 @@ async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] fastwebsockets = { version = "0.6.0", features = ["simdutf8", "unstable-split"] } rand = "0.8.5" base64 = "0.21.7" -wisp-mux = { path = "../wisp", features = ["ws_stream_wasm"] } +wisp-mux = { path = "../wisp", features = ["ws_stream_wasm", "tokio_io"] } +async_io_stream = { version = "0.3.3", features = ["tokio_io"] } [dependencies.getrandom] features = ["js"] diff --git a/client/src/lib.rs b/client/src/lib.rs index 30aabde..80b9c58 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -8,17 +8,20 @@ mod wrappers; use tokioio::TokioIo; use utils::{ReplaceErr, UriExt}; use websocket::EpxWebSocket; -use wrappers::{IncomingBody, WsStreamWrapper}; +use wrappers::IncomingBody; use std::sync::Arc; use async_compression::tokio::bufread as async_comp; +use async_io_stream::IoStream; use bytes::Bytes; -use futures_util::StreamExt; +use futures_util::{ + stream::SplitSink, + StreamExt, +}; use http::{uri, HeaderName, HeaderValue, Request, Response}; use hyper::{body::Incoming, client::conn::http1::Builder, Uri}; use js_sys::{Array, Function, Object, Reflect, Uint8Array}; -use penguin_mux_wasm::{Multiplexor, MuxStream}; use tokio_rustls::{client::TlsStream, rustls, rustls::RootCertStore, TlsConnector}; use tokio_util::{ either::Either, @@ -26,6 +29,8 @@ use tokio_util::{ }; use wasm_bindgen::prelude::*; use web_sys::TextEncoder; +use wisp_mux::{ClientMux, MuxStreamIo, StreamType}; +use ws_stream_wasm::{WsMeta, WsStream, WsMessage}; type HttpBody = http_body_util::Full; @@ -40,8 +45,8 @@ enum EpxCompression { Gzip, } -type EpxTlsStream = TlsStream>; -type EpxUnencryptedStream = MuxStream; +type EpxTlsStream = TlsStream>>; +type EpxUnencryptedStream = IoStream>; type EpxStream = Either; async fn send_req( @@ -113,7 +118,7 @@ async fn start() { #[wasm_bindgen] pub struct EpoxyClient { rustls_config: Arc, - mux: Multiplexor, + mux: ClientMux>, useragent: String, redirect_limit: usize, } @@ -138,11 +143,18 @@ impl EpoxyClient { } debug!("connecting to ws {:?}", ws_url); - let ws = WsStreamWrapper::connect(ws_url, None) + let (_, ws) = WsMeta::connect(ws_url, vec!["wisp-v1"]) .await .replace_err("Failed to connect to websocket")?; debug!("connected!"); - let mux = Multiplexor::new(ws, penguin_mux_wasm::Role::Client, None, None); + let (wtx, wrx) = ws.split(); + let (mux, fut) = ClientMux::new(wrx, wtx); + + wasm_bindgen_futures::spawn_local(async move { + if let Err(err) = fut.await { + error!("epoxy: error in mux future! {:?}", err); + } + }); let mut certstore = RootCertStore::empty(); certstore.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); @@ -161,14 +173,16 @@ impl EpoxyClient { }) } - async fn get_http_io(&self, url: &Uri) -> Result { + async fn get_http_io(&mut self, url: &Uri) -> Result { let url_host = url.host().replace_err("URL must have a host")?; let url_port = utils::get_url_port(url)?; let channel = self .mux - .client_new_stream_channel(url_host.as_bytes(), url_port) + .client_new_stream(StreamType::Tcp, url_host.to_string(), url_port) .await - .replace_err("Failed to create multiplexor channel")?; + .replace_err("Failed to create multiplexor channel")? + .into_io() + .into_asyncrw(); if utils::get_is_secure(url)? { let cloned_uri = url_host.to_string().clone(); @@ -189,7 +203,7 @@ impl EpoxyClient { } async fn send_req( - &self, + &mut self, req: http::Request, should_redirect: bool, ) -> Result<(hyper::Response, Uri, bool), JsError> { @@ -217,7 +231,7 @@ impl EpoxyClient { // shut up #[allow(clippy::too_many_arguments)] pub async fn connect_ws( - &self, + &mut self, onopen: Function, onclose: Function, onerror: Function, @@ -232,7 +246,11 @@ impl EpoxyClient { .await } - pub async fn fetch(&self, url: String, options: Object) -> Result { + pub async fn fetch( + &mut self, + url: String, + options: Object, + ) -> Result { let uri = url.parse::().replace_err("Failed to parse URL")?; let uri_scheme = uri.scheme().replace_err("URL must have a scheme")?; if *uri_scheme != uri::Scheme::HTTP && *uri_scheme != uri::Scheme::HTTPS { diff --git a/client/src/websocket.rs b/client/src/websocket.rs index addae2c..2ce9149 100644 --- a/client/src/websocket.rs +++ b/client/src/websocket.rs @@ -30,7 +30,7 @@ impl EpxWebSocket { // shut up #[allow(clippy::too_many_arguments)] pub async fn connect( - tcp: &EpoxyClient, + tcp: &mut EpoxyClient, onopen: Function, onclose: Function, onerror: Function, diff --git a/client/src/wrappers.rs b/client/src/wrappers.rs index 1ecc702..8526a98 100644 --- a/client/src/wrappers.rs +++ b/client/src/wrappers.rs @@ -4,117 +4,9 @@ use std::{ task::{Context, Poll}, }; -use futures_util::{Sink, Stream}; +use futures_util::Stream; use hyper::body::Body; -use penguin_mux_wasm::ws; use pin_project_lite::pin_project; -use ws_stream_wasm::{WsErr, WsMessage, WsMeta, WsStream}; - -pin_project! { - pub struct WsStreamWrapper { - #[pin] - ws: WsStream, - } -} - -impl WsStreamWrapper { - pub async fn connect( - url: impl AsRef, - protocols: impl Into>>, - ) -> Result { - let (_, wsstream) = WsMeta::connect(url, protocols).await?; - Ok(WsStreamWrapper { ws: wsstream }) - } -} - -impl Stream for WsStreamWrapper { - type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - let ret = this.ws.poll_next(cx); - match ret { - Poll::Ready(item) => Poll::>::Ready(item.map(|x| { - Ok(match x { - WsMessage::Text(txt) => ws::Message::Text(txt), - WsMessage::Binary(bin) => ws::Message::Binary(bin), - }) - })), - Poll::Pending => Poll::>::Pending, - } - } -} - -fn wserr_to_ws_err(err: WsErr) -> ws::Error { - debug!("err: {:?}", err); - match err { - WsErr::ConnectionNotOpen => ws::Error::AlreadyClosed, - _ => ws::Error::Io(std::io::Error::other(format!("{:?}", err))), - } -} - -impl Sink for WsStreamWrapper { - type Error = ws::Error; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - let ret = this.ws.poll_ready(cx); - match ret { - Poll::Ready(item) => Poll::>::Ready(match item { - Ok(_) => Ok(()), - Err(err) => Err(wserr_to_ws_err(err)), - }), - Poll::Pending => Poll::>::Pending, - } - } - - fn start_send(self: Pin<&mut Self>, item: ws::Message) -> Result<(), Self::Error> { - use ws::Message::*; - let item = match item { - Text(txt) => WsMessage::Text(txt), - Binary(bin) => WsMessage::Binary(bin), - Close(_) => { - debug!("closing"); - return match self.ws.wrapped().close() { - Ok(_) => Ok(()), - Err(err) => Err(ws::Error::Io(std::io::Error::other(format!( - "ws close err: {:?}", - err - )))), - }; - } - Ping(_) | Pong(_) | Frame(_) => return Ok(()), - }; - let this = self.project(); - let ret = this.ws.start_send(item); - match ret { - Ok(_) => Ok(()), - Err(err) => Err(wserr_to_ws_err(err)), - } - } - - // no point wrapping this as it's not going to do anything - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Ok(()).into() - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - let ret = this.ws.poll_close(cx); - match ret { - Poll::Ready(item) => Poll::>::Ready(match item { - Ok(_) => Ok(()), - Err(err) => Err(wserr_to_ws_err(err)), - }), - Poll::Pending => Poll::>::Pending, - } - } -} - -impl ws::WebSocketStream for WsStreamWrapper { - fn ping_auto_pong(&self) -> bool { - true - } -} pin_project! { pub struct IncomingBody { diff --git a/server/src/main.rs b/server/src/main.rs index 11f6478..7b0b35c 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -101,7 +101,7 @@ async fn accept_http( async fn handle_mux( packet: ConnectPacket, - mut stream: MuxStream, + mut stream: MuxStream, ) -> Result { let uri = format!( "{}:{}", @@ -174,9 +174,7 @@ async fn accept_ws( println!("{:?}: connected", addr); - let mut mux = ServerMux::new(rx, tx); - - mux.server_loop(&mut |packet, stream| async move { + ServerMux::handle(rx, tx, &mut |packet, stream| async move { let mut close_err = stream.get_close_handle(); let mut close_ok = stream.get_close_handle(); tokio::spawn(async move { diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 9dc0a2d..ee3c3d2 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -17,3 +17,4 @@ ws_stream_wasm = { version = "0.7.4", optional = true } [features] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] ws_stream_wasm = ["dep:ws_stream_wasm"] +tokio_io = ["async_io_stream/tokio_io"] diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index f020bfd..fb31e4a 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -53,10 +53,10 @@ impl From for crate::WispError { } } -impl crate::ws::WebSocketRead for FragmentCollectorRead { +impl crate::ws::WebSocketRead for FragmentCollectorRead { async fn wisp_read_frame( &mut self, - tx: &mut crate::ws::LockedWebSocketWrite, + tx: &crate::ws::LockedWebSocketWrite, ) -> Result { Ok(self .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) @@ -65,7 +65,7 @@ impl crate::ws::WebSocketRead for FragmentCollectorRead } } -impl crate::ws::WebSocketWrite for WebSocketWrite { +impl crate::ws::WebSocketWrite for WebSocketWrite { async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { self.write_frame(frame.try_into()?).await.map_err(|e| e.into()) } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 2eb0594..d4f843e 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -10,7 +10,7 @@ pub use crate::packet::*; pub use crate::stream::*; use dashmap::DashMap; -use futures::{channel::mpsc, StreamExt}; +use futures::{channel::mpsc, Future, StreamExt}; use std::sync::{ atomic::{AtomicBool, AtomicU32, Ordering}, Arc, @@ -68,38 +68,66 @@ impl std::fmt::Display for WispError { impl std::error::Error for WispError {} -pub struct ServerMux +pub struct ServerMux where - R: ws::WebSocketRead, W: ws::WebSocketWrite, { - rx: R, tx: ws::LockedWebSocketWrite, stream_map: Arc>>, - close_rx: mpsc::UnboundedReceiver, close_tx: mpsc::UnboundedSender, } -impl ServerMux { - pub fn new(read: R, write: W) -> Self { +impl ServerMux { + pub fn handle<'a, FR, R>( + read: R, + write: W, + handler_fn: &'a mut impl Fn(ConnectPacket, MuxStream) -> FR, + ) -> impl Future> + 'a + where + FR: std::future::Future> + 'a, + R: ws::WebSocketRead + 'a, + W: ws::WebSocketWrite + 'a, + { let (tx, rx) = mpsc::unbounded::(); - Self { - rx: read, - tx: ws::LockedWebSocketWrite::new(write), - stream_map: Arc::new(DashMap::new()), - close_rx: rx, + let write = ws::LockedWebSocketWrite::new(write); + let map = Arc::new(DashMap::new()); + let inner = ServerMux { + stream_map: map.clone(), + tx: write.clone(), close_tx: tx, - } + }; + inner.into_future(read, rx, handler_fn) } - pub async fn server_bg_loop(&mut self) { - while let Some(msg) = self.close_rx.next().await { + async fn into_future( + self, + rx: R, + close_rx: mpsc::UnboundedReceiver, + handler_fn: &mut impl Fn(ConnectPacket, MuxStream) -> FR, + ) -> Result<(), WispError> + where + R: ws::WebSocketRead, + FR: std::future::Future>, + { + futures::try_join! { + self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()), + self.server_msg_loop(rx, handler_fn) + } + .map(|_| ()) + } + + async fn server_close_loop( + &self, + mut close_rx: mpsc::UnboundedReceiver, + stream_map: Arc>>, + tx: ws::LockedWebSocketWrite, + ) -> Result<(), WispError> { + while let Some(msg) = close_rx.next().await { match msg { MuxEvent::Close(stream_id, reason, channel) => { - if self.stream_map.clone().remove(&stream_id).is_some() { + if stream_map.clone().remove(&stream_id).is_some() { let _ = channel.send( - self.tx - .write_frame(Packet::new_close(stream_id, reason).into()) + tx.write_frame(Packet::new_close(stream_id, reason).into()) .await, ); } else { @@ -108,20 +136,23 @@ impl ServerMux { } } } + Ok(()) } - pub async fn server_loop( - &mut self, + async fn server_msg_loop( + &self, + mut rx: R, handler_fn: &mut impl Fn(ConnectPacket, MuxStream) -> FR, ) -> Result<(), WispError> where - FR: std::future::Future>, + R: ws::WebSocketRead, + FR: std::future::Future>, { self.tx .write_frame(Packet::new_continue(0, u32::MAX).into()) .await?; - while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await { + while let Ok(frame) = rx.wisp_read_frame(&self.tx).await { if let Ok(packet) = Packet::try_from(frame) { use PacketType::*; match packet.packet { @@ -164,34 +195,31 @@ impl ServerMux { } } -pub struct ClientMux +pub struct ClientMuxInner where - R: ws::WebSocketRead, W: ws::WebSocketWrite, { - rx: R, tx: ws::LockedWebSocketWrite, stream_map: Arc>>, - next_free_stream_id: AtomicU32, - close_rx: mpsc::UnboundedReceiver, - close_tx: mpsc::UnboundedSender, } -impl ClientMux { - pub fn new(read: R, write: W) -> Self { - let (tx, rx) = mpsc::unbounded::(); - Self { - rx: read, - tx: ws::LockedWebSocketWrite::new(write), - stream_map: Arc::new(DashMap::new()), - next_free_stream_id: AtomicU32::new(1), - close_rx: rx, - close_tx: tx, - } +impl ClientMuxInner { + pub async fn into_future( + self, + rx: R, + close_rx: mpsc::UnboundedReceiver, + ) -> Result<(), WispError> + where + R: ws::WebSocketRead, + { + futures::try_join!(self.client_bg_loop(close_rx), self.client_loop(rx)).map(|_| ()) } - pub async fn client_bg_loop(&mut self) { - while let Some(msg) = self.close_rx.next().await { + async fn client_bg_loop( + &self, + mut close_rx: mpsc::UnboundedReceiver, + ) -> Result<(), WispError> { + while let Some(msg) = close_rx.next().await { match msg { MuxEvent::Close(stream_id, reason, channel) => { if self.stream_map.clone().remove(&stream_id).is_some() { @@ -206,14 +234,14 @@ impl ClientMux { } } } + Ok(()) } - pub async fn client_loop(&mut self) -> Result<(), WispError> { - self.tx - .write_frame(Packet::new_continue(0, u32::MAX).into()) - .await?; - - while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await { + async fn client_loop(&self, mut rx: R) -> Result<(), WispError> + where + R: ws::WebSocketRead, + { + while let Ok(frame) = rx.wisp_read_frame(&self.tx).await { if let Ok(packet) = Packet::try_from(frame) { use PacketType::*; match packet.packet { @@ -235,12 +263,52 @@ impl ClientMux { } Ok(()) } +} + +pub struct ClientMux +where + W: ws::WebSocketWrite, +{ + tx: ws::LockedWebSocketWrite, + stream_map: Arc>>, + next_free_stream_id: AtomicU32, + close_tx: mpsc::UnboundedSender, +} + +impl ClientMux { + pub fn new(read: R, write: W) -> (Self, impl Future>) + where + R: ws::WebSocketRead, + { + let (tx, rx) = mpsc::unbounded::(); + let map = Arc::new(DashMap::new()); + let write = ws::LockedWebSocketWrite::new(write); + ( + Self { + tx: write.clone(), + stream_map: map.clone(), + next_free_stream_id: AtomicU32::new(1), + close_tx: tx, + }, + ClientMuxInner { + tx: write.clone(), + stream_map: map.clone(), + } + .into_future(read, rx), + ) + } pub async fn client_new_stream( &mut self, + stream_type: StreamType, + host: String, + port: u16, ) -> Result, WispError> { let (ch_tx, ch_rx) = mpsc::unbounded(); let stream_id = self.next_free_stream_id.load(Ordering::Acquire); + self.tx + .write_frame(Packet::new_connect(stream_id, stream_type, port, host).into()) + .await?; self.next_free_stream_id.store( stream_id .checked_add(1) diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index ff86585..cd9daab 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -4,7 +4,7 @@ use futures::{ channel::{mpsc, oneshot}, sink, stream, task::{Context, Poll}, - AsyncRead, AsyncWrite, Sink, Stream, StreamExt, + Sink, Stream, StreamExt, }; use pin_project_lite::pin_project; use std::{ @@ -44,7 +44,7 @@ impl MuxStreamRead { } } - pub(crate) fn into_stream(self) -> Pin>> { + pub(crate) fn into_stream(self) -> Pin + Send>> { Box::pin(stream::unfold(self, |mut rx| async move { let evt = rx.read().await?; Some(( @@ -68,7 +68,7 @@ where is_closed: Arc, } -impl MuxStreamWrite { +impl MuxStreamWrite { pub async fn write(&mut self, data: Bytes) -> Result<(), crate::WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(crate::WispError::StreamAlreadyClosed); @@ -101,10 +101,7 @@ impl MuxStreamWrite { Ok(()) } - pub(crate) fn into_sink<'a>(self) -> Pin + 'a>> - where - W: 'a, - { + pub(crate) fn into_sink(self) -> Pin + Send>> { Box::pin(sink::unfold(self, |mut tx, data| async move { tx.write(data).await?; Ok(tx) @@ -130,7 +127,7 @@ where tx: MuxStreamWrite, } -impl MuxStream { +impl MuxStream { pub(crate) fn new( stream_id: u32, rx: mpsc::UnboundedReceiver, @@ -174,10 +171,7 @@ impl MuxStream { (self.rx, self.tx) } - pub fn into_io<'a>(self) -> MuxStreamIo<'a> - where - W: 'a, - { + pub fn into_io(self) -> MuxStreamIo { MuxStreamIo { rx: self.rx.into_stream(), tx: self.tx.into_sink(), @@ -208,55 +202,54 @@ impl MuxStreamCloser { } pin_project! { - pub struct MuxStreamIo<'a> { + pub struct MuxStreamIo { #[pin] - rx: Pin + 'a>>, + rx: Pin + Send>>, #[pin] - tx: Pin + 'a>>, + tx: Pin + Send>>, } } -impl<'a> MuxStreamIo<'a> { - pub fn into_asyncrw(self) -> impl AsyncRead + AsyncWrite + 'a { - IoStream::new(self.map(|x| Ok::, std::io::Error>(x.to_vec()))) +impl MuxStreamIo { + pub fn into_asyncrw(self) -> IoStream> { + IoStream::new(self) } } -impl Stream for MuxStreamIo<'_> { - type Item = Bytes; +impl Stream for MuxStreamIo { + type Item = Result, std::io::Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().rx.poll_next(cx) + self.project() + .rx + .poll_next(cx) + .map(|x| x.map(|x| Ok(x.to_vec()))) } } -impl Sink for MuxStreamIo<'_> { - type Error = crate::WispError; - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_ready(cx) - } - fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { - self.project().tx.start_send(item) - } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_flush(cx) - } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_close(cx) - } -} - -impl Sink> for MuxStreamIo<'_> { +impl Sink> for MuxStreamIo { type Error = std::io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_ready(cx).map_err(std::io::Error::other) + self.project() + .tx + .poll_ready(cx) + .map_err(std::io::Error::other) } fn start_send(self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { - self.project().tx.start_send(item.into()).map_err(std::io::Error::other) + self.project() + .tx + .start_send(item.into()) + .map_err(std::io::Error::other) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_flush(cx).map_err(std::io::Error::other) + self.project() + .tx + .poll_flush(cx) + .map_err(std::io::Error::other) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_close(cx).map_err(std::io::Error::other) + self.project() + .tx + .poll_close(cx) + .map_err(std::io::Error::other) } } diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index 5b1243e..f75c526 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -46,20 +46,20 @@ impl Frame { pub trait WebSocketRead { fn wisp_read_frame( &mut self, - tx: &mut crate::ws::LockedWebSocketWrite, - ) -> impl std::future::Future>; + tx: &crate::ws::LockedWebSocketWrite, + ) -> impl std::future::Future> + Send; } pub trait WebSocketWrite { fn wisp_write_frame( &mut self, frame: Frame, - ) -> impl std::future::Future>; + ) -> impl std::future::Future> + Send; } pub struct LockedWebSocketWrite(Arc>); -impl LockedWebSocketWrite { +impl LockedWebSocketWrite { pub fn new(ws: S) -> Self { Self(Arc::new(Mutex::new(ws))) } diff --git a/wisp/src/ws_stream_wasm.rs b/wisp/src/ws_stream_wasm.rs index 6e15816..410b537 100644 --- a/wisp/src/ws_stream_wasm.rs +++ b/wisp/src/ws_stream_wasm.rs @@ -1,4 +1,4 @@ -use futures::{SinkExt, StreamExt}; +use futures::{stream::{SplitStream, SplitSink}, SinkExt, StreamExt}; use ws_stream_wasm::{WsErr, WsMessage, WsStream}; impl From for crate::ws::Frame { @@ -37,10 +37,10 @@ impl From for crate::WispError { } } -impl crate::ws::WebSocketRead for WsStream { +impl crate::ws::WebSocketRead for SplitStream { async fn wisp_read_frame( &mut self, - _: &mut crate::ws::LockedWebSocketWrite, + _: &crate::ws::LockedWebSocketWrite, ) -> Result { Ok(self .next() @@ -50,8 +50,11 @@ impl crate::ws::WebSocketRead for WsStream { } } -impl crate::ws::WebSocketWrite for WsStream { +impl crate::ws::WebSocketWrite for SplitSink { async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { - self.send(frame.try_into()?).await.map_err(|e| e.into()) + self + .send(frame.try_into()?) + .await + .map_err(|e| e.into()) } } From 619a2a61c7bddb7f87ba0721828e063fd053bc9d Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Wed, 31 Jan 2024 06:47:00 -0800 Subject: [PATCH 12/26] make wasm smaller and update build/test system --- .gitignore | 9 ++++++--- Cargo.toml | 4 ++++ client/build.sh | 16 ++++++++++------ client/demo.js | 23 ++++++++++++++++------- client/index.html | 11 +++++++++-- package.json | 8 ++++---- 6 files changed, 49 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index bfd66de..dcb2577 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,10 @@ server/src/*.pem client/pkg client/out .direnv -client/index.js -client/module.js -client/module.d.ts +client/epoxy-bundled.js +client/epoxy-module-bundled.js +client/epoxy-module-bundled.d.ts +client/epoxy.js +client/epoxy.d.ts +client/epoxy.wasm pnpm-lock.yaml diff --git a/Cargo.toml b/Cargo.toml index 1927a61..2e7eb18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,3 +4,7 @@ members = ["server", "client", "wisp"] [patch.crates-io] rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" } + +[profile.release] +lto = true +opt-level = 'z' diff --git a/client/build.sh b/client/build.sh index 68a1dbe..6b0a324 100755 --- a/client/build.sh +++ b/client/build.sh @@ -11,21 +11,25 @@ wasm-bindgen --weak-refs --target no-modules --no-modules-global epoxy --out-dir echo "[ws] bindgen finished" mv out/epoxy_client_bg.wasm out/epoxy_client_unoptimized.wasm -time wasm-opt -O4 out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm +time wasm-opt -Oz out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm echo "[ws] optimized" AUTOGENERATED_SOURCE=$(<"out/epoxy_client.js") WASM_BASE64=$(base64 -w0 out/epoxy_client_bg.wasm) AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//__wbg_init(input) \{/__wbg_init() \{let input=\'data:application/wasm;base64,$WASM_BASE64\'} AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//return __wbg_finalize_init\(instance\, module\);/__wbg_finalize_init\(instance\, module\); return epoxy} -echo "$AUTOGENERATED_SOURCE" > index.js -cp index.js module.js -echo "module.exports = epoxy" >> module.js +echo "$AUTOGENERATED_SOURCE" > epoxy-bundled.js +cp epoxy-bundled.js epoxy-module-bundled.js +echo "module.exports = epoxy" >> epoxy-module-bundled.js AUTOGENERATED_TYPEDEFS=$(<"out/epoxy_client.d.ts") AUTOGENERATED_TYPEDEFS=${AUTOGENERATED_TYPEDEFS%%export class IntoUnderlyingByteSource*} -echo "$AUTOGENERATED_TYPEDEFS" >"module.d.ts" -echo "} export default function epoxy(): Promise;" >> "module.d.ts" +echo "$AUTOGENERATED_TYPEDEFS" >"epoxy-module-bundled.d.ts" +echo "} export default function epoxy(): Promise;" >> "epoxy-module-bundled.d.ts" + +cp out/epoxy_client.js epoxy.js +cp out/epoxy_client.d.ts epoxy.d.ts +cp out/epoxy_client_bg.wasm epoxy.wasm rm -rf out/ echo "[ws] done!" diff --git a/client/demo.js b/client/demo.js index da2d24d..0a39700 100644 --- a/client/demo.js +++ b/client/demo.js @@ -8,13 +8,20 @@ const should_perf_test = (new URL(window.location.href)).searchParams.has("perf_test"); const should_ws_test = (new URL(window.location.href)).searchParams.has("ws_test"); + const log = (str) => { + let el = document.createElement("div"); + el.innerText = str; + document.getElementById("logs").appendChild(el); + console.warn(str); + } + let { EpoxyClient } = await epoxy(); const tconn0 = performance.now(); // args: websocket url, user agent, redirect limit let epoxy_client = await new EpoxyClient("wss://localhost:4000", navigator.userAgent, 10); const tconn1 = performance.now(); - console.warn(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`); + log(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`); if (should_feature_test) { @@ -47,20 +54,22 @@ const num_tests = 10; let total_mux = 0; - for (const _ of Array(num_tests).keys()) { + for (const i of Array(num_tests).keys()) { + log(`running mux test ${i}`); total_mux += await test_mux("https://httpbin.org/get"); } total_mux = total_mux / num_tests; let total_native = 0; - for (const _ of Array(num_tests).keys()) { + for (const i of Array(num_tests).keys()) { + log(`running native test ${i}`); total_native += await test_native("https://httpbin.org/get"); } total_native = total_native / num_tests; - console.warn(`avg mux (10) took ${total_mux} ms or ${total_mux / 1000} s`); - console.warn(`avg native (10) took ${total_native} ms or ${total_native / 1000} s`); - console.warn(`mux - native: ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); + log(`avg mux (10) took ${total_mux} ms or ${total_mux / 1000} s`); + log(`avg native (10) took ${total_native} ms or ${total_native / 1000} s`); + log(`mux - native: ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); } else if (should_ws_test) { let ws = await epoxy_client.connect_ws( () => console.log("opened"), @@ -80,5 +89,5 @@ console.warn(resp, Object.fromEntries(resp.headers)); console.warn(await resp.text()); } - if (!should_ws_test) alert("you can open console now"); + log("done"); })(); diff --git a/client/index.html b/client/index.html index c8ff7ed..cf37260 100644 --- a/client/index.html +++ b/client/index.html @@ -2,12 +2,19 @@ epoxy - + + - running... (wait for the browser alert if not running ws test) +
+ running... (wait for the browser alert if not running ws test) +
+
diff --git a/package.json b/package.json index 3d6c731..15b432c 100644 --- a/package.json +++ b/package.json @@ -16,8 +16,8 @@ "author": "MercuryWorkshop", "repository": "https://github.com/MercuryWorkshop/epoxy-tls", "license": "MIT", - "browser": "./client/module.js", - "module": "./client/module.js", - "main": "./client/module.js", - "types": "./client/module.d.ts" + "browser": "./client/epoxy-module-bundled.js", + "module": "./client/epoxy-module-bundled.js", + "main": "./client/epoxy-module-bundled.js", + "types": "./client/epoxy-module-bundled.d.ts" } From fa2b84d646c0850c4650022b51de6df5b90a2e2d Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Wed, 31 Jan 2024 08:13:06 -0800 Subject: [PATCH 13/26] remove unnecessary mut self references --- client/demo.js | 35 +++++++++++++++++++++++++++++++++++ client/src/lib.rs | 8 ++++---- client/src/websocket.rs | 2 +- server/src/main.rs | 6 ++++-- wisp/src/lib.rs | 11 +++++------ wisp/src/stream.rs | 12 ++++++------ 6 files changed, 55 insertions(+), 19 deletions(-) diff --git a/client/demo.js b/client/demo.js index 0a39700..8f27fa6 100644 --- a/client/demo.js +++ b/client/demo.js @@ -5,6 +5,7 @@ ); const should_feature_test = (new URL(window.location.href)).searchParams.has("feature_test"); + const should_parallel_test = (new URL(window.location.href)).searchParams.has("parallel_test"); const should_perf_test = (new URL(window.location.href)).searchParams.has("perf_test"); const should_ws_test = (new URL(window.location.href)).searchParams.has("ws_test"); @@ -36,6 +37,40 @@ console.warn(url, resp, Object.fromEntries(resp.headers)); console.warn(await resp.text()); } + } else if (should_parallel_test) { + const test_mux = async (url) => { + const t0 = performance.now(); + await epoxy_client.fetch(url); + const t1 = performance.now(); + return t1 - t0; + }; + + const test_native = async (url) => { + const t0 = performance.now(); + await fetch(url); + const t1 = performance.now(); + return t1 - t0; + }; + + const num_tests = 10; + + let total_mux = 0; + await Promise.all([...Array(num_tests).keys()].map(async i=>{ + log(`running mux test ${i}`); + return await test_mux("https://httpbin.org/get"); + })).then((vals)=>{total_mux = vals.reduce((acc, x) => acc + x, 0)}); + total_mux = total_mux / num_tests; + + let total_native = 0; + await Promise.all([...Array(num_tests).keys()].map(async i=>{ + log(`running native test ${i}`); + return await test_native("https://httpbin.org/get"); + })).then((vals)=>{total_native = vals.reduce((acc, x) => acc + x, 0)}); + total_native = total_native / num_tests; + + log(`avg mux (10) took ${total_mux} ms or ${total_mux / 1000} s`); + log(`avg native (10) took ${total_native} ms or ${total_native / 1000} s`); + log(`mux - native: ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); } else if (should_perf_test) { const test_mux = async (url) => { const t0 = performance.now(); diff --git a/client/src/lib.rs b/client/src/lib.rs index 80b9c58..dabadf2 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -173,7 +173,7 @@ impl EpoxyClient { }) } - async fn get_http_io(&mut self, url: &Uri) -> Result { + async fn get_http_io(&self, url: &Uri) -> Result { let url_host = url.host().replace_err("URL must have a host")?; let url_port = utils::get_url_port(url)?; let channel = self @@ -203,7 +203,7 @@ impl EpoxyClient { } async fn send_req( - &mut self, + &self, req: http::Request, should_redirect: bool, ) -> Result<(hyper::Response, Uri, bool), JsError> { @@ -231,7 +231,7 @@ impl EpoxyClient { // shut up #[allow(clippy::too_many_arguments)] pub async fn connect_ws( - &mut self, + &self, onopen: Function, onclose: Function, onerror: Function, @@ -247,7 +247,7 @@ impl EpoxyClient { } pub async fn fetch( - &mut self, + &self, url: String, options: Object, ) -> Result { diff --git a/client/src/websocket.rs b/client/src/websocket.rs index 2ce9149..addae2c 100644 --- a/client/src/websocket.rs +++ b/client/src/websocket.rs @@ -30,7 +30,7 @@ impl EpxWebSocket { // shut up #[allow(clippy::too_many_arguments)] pub async fn connect( - tcp: &mut EpoxyClient, + tcp: &EpoxyClient, onopen: Function, onclose: Function, onerror: Function, diff --git a/server/src/main.rs b/server/src/main.rs index 7b0b35c..0aa7194 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -117,6 +117,7 @@ async fn handle_mux( loop { tokio::select! { event = stream.read() => { + println!("ws rx"); match event { Some(event) => match event { WsEvent::Send(data) => { @@ -128,6 +129,7 @@ async fn handle_mux( } }, event = tcp_stream_framed.next() => { + println!("tcp rx"); match event.and_then(|x| x.ok()) { Some(event) => stream.write(event.into()).await?, None => return Ok(true), @@ -175,8 +177,8 @@ async fn accept_ws( println!("{:?}: connected", addr); ServerMux::handle(rx, tx, &mut |packet, stream| async move { - let mut close_err = stream.get_close_handle(); - let mut close_ok = stream.get_close_handle(); + let close_err = stream.get_close_handle(); + let close_ok = stream.get_close_handle(); tokio::spawn(async move { let _ = handle_mux(packet, stream) .or_else(|err| async move { diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index d4f843e..0e75e7d 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -10,7 +10,7 @@ pub use crate::packet::*; pub use crate::stream::*; use dashmap::DashMap; -use futures::{channel::mpsc, Future, StreamExt}; +use futures::{channel::mpsc, Future, FutureExt, StreamExt}; use std::sync::{ atomic::{AtomicBool, AtomicU32, Ordering}, Arc, @@ -109,11 +109,10 @@ impl ServerMux { R: ws::WebSocketRead, FR: std::future::Future>, { - futures::try_join! { - self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()), - self.server_msg_loop(rx, handler_fn) + futures::select! { + x = self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()).fuse() => x, + x = self.server_msg_loop(rx, handler_fn).fuse() => x } - .map(|_| ()) } async fn server_close_loop( @@ -299,7 +298,7 @@ impl ClientMux { } pub async fn client_new_stream( - &mut self, + &self, stream_type: StreamType, host: String, port: u16, diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index cd9daab..3998c9d 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -69,7 +69,7 @@ where } impl MuxStreamWrite { - pub async fn write(&mut self, data: Bytes) -> Result<(), crate::WispError> { + pub async fn write(&self, data: Bytes) -> Result<(), crate::WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(crate::WispError::StreamAlreadyClosed); } @@ -86,7 +86,7 @@ impl MuxStreamWrite { } } - pub async fn close(&mut self, reason: u8) -> Result<(), crate::WispError> { + pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(crate::WispError::StreamAlreadyClosed); } @@ -102,7 +102,7 @@ impl MuxStreamWrite { } pub(crate) fn into_sink(self) -> Pin + Send>> { - Box::pin(sink::unfold(self, |mut tx, data| async move { + Box::pin(sink::unfold(self, |tx, data| async move { tx.write(data).await?; Ok(tx) })) @@ -155,7 +155,7 @@ impl MuxStream { self.rx.read().await } - pub async fn write(&mut self, data: Bytes) -> Result<(), crate::WispError> { + pub async fn write(&self, data: Bytes) -> Result<(), crate::WispError> { self.tx.write(data).await } @@ -163,7 +163,7 @@ impl MuxStream { self.tx.get_close_handle() } - pub async fn close(&mut self, reason: u8) -> Result<(), crate::WispError> { + pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> { self.tx.close(reason).await } @@ -186,7 +186,7 @@ pub struct MuxStreamCloser { } impl MuxStreamCloser { - pub async fn close(&mut self, reason: u8) -> Result<(), crate::WispError> { + pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(crate::WispError::StreamAlreadyClosed); } From 9f1561fa765cc073a81515cff26f0ce2a6c4581d Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 3 Feb 2024 16:31:02 -0800 Subject: [PATCH 14/26] make server mux more like client mux and fix deadlock --- Cargo.lock | 1 - server/src/main.rs | 24 ++++--- wisp/Cargo.toml | 1 - wisp/src/lib.rs | 152 ++++++++++++++++++++++++++------------------- 4 files changed, 101 insertions(+), 77 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index eec8ba9..eefc704 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1449,7 +1449,6 @@ version = "0.1.0" dependencies = [ "async_io_stream", "bytes", - "dashmap", "fastwebsockets", "futures", "futures-util", diff --git a/server/src/main.rs b/server/src/main.rs index 0aa7194..aeca64e 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -20,7 +20,7 @@ use wisp_mux::{ws, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, W type HttpBody = http_body_util::Empty; -#[tokio::main(flavor = "multi_thread")] +#[tokio::main] async fn main() -> Result<(), Error> { let pem = include_bytes!("./pem.pem"); let key = include_bytes!("./key.pem"); @@ -117,7 +117,6 @@ async fn handle_mux( loop { tokio::select! { event = stream.read() => { - println!("ws rx"); match event { Some(event) => match event { WsEvent::Send(data) => { @@ -129,10 +128,9 @@ async fn handle_mux( } }, event = tcp_stream_framed.next() => { - println!("tcp rx"); match event.and_then(|x| x.ok()) { Some(event) => stream.write(event.into()).await?, - None => return Ok(true), + None => break, } } } @@ -176,10 +174,18 @@ async fn accept_ws( println!("{:?}: connected", addr); - ServerMux::handle(rx, tx, &mut |packet, stream| async move { - let close_err = stream.get_close_handle(); - let close_ok = stream.get_close_handle(); + let (mut mux, fut) = ServerMux::new(rx, tx); + + tokio::spawn(async move { + if let Err(e) = fut.await { + println!("err in mux: {:?}", e); + } + }); + + while let Some((packet, stream)) = mux.server_new_stream().await { tokio::spawn(async move { + let close_err = stream.get_close_handle(); + let close_ok = stream.get_close_handle(); let _ = handle_mux(packet, stream) .or_else(|err| async move { let _ = close_err.close(0x03).await; @@ -194,9 +200,7 @@ async fn accept_ws( }) .await; }); - Ok(()) - }) - .await?; + } println!("{:?}: disconnected", addr); Ok(()) diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index ee3c3d2..fc834d0 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -6,7 +6,6 @@ edition = "2021" [dependencies] async_io_stream = "0.3.3" bytes = "1.5.0" -dashmap = "5.5.3" fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true } futures = "0.3.30" futures-util = "0.3.30" diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 0e75e7d..9326ece 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -9,11 +9,13 @@ mod ws_stream_wasm; pub use crate::packet::*; pub use crate::stream::*; -use dashmap::DashMap; -use futures::{channel::mpsc, Future, FutureExt, StreamExt}; -use std::sync::{ - atomic::{AtomicBool, AtomicU32, Ordering}, - Arc, +use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt}; +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicBool, AtomicU32, Ordering}, + Arc, + }, }; #[derive(Debug, PartialEq)] @@ -68,63 +70,45 @@ impl std::fmt::Display for WispError { impl std::error::Error for WispError {} -pub struct ServerMux +struct ServerMuxInner where - W: ws::WebSocketWrite, + W: ws::WebSocketWrite + Send + 'static, { tx: ws::LockedWebSocketWrite, - stream_map: Arc>>, + stream_map: Arc>>>, close_tx: mpsc::UnboundedSender, } -impl ServerMux { - pub fn handle<'a, FR, R>( - read: R, - write: W, - handler_fn: &'a mut impl Fn(ConnectPacket, MuxStream) -> FR, - ) -> impl Future> + 'a - where - FR: std::future::Future> + 'a, - R: ws::WebSocketRead + 'a, - W: ws::WebSocketWrite + 'a, - { - let (tx, rx) = mpsc::unbounded::(); - let write = ws::LockedWebSocketWrite::new(write); - let map = Arc::new(DashMap::new()); - let inner = ServerMux { - stream_map: map.clone(), - tx: write.clone(), - close_tx: tx, - }; - inner.into_future(read, rx, handler_fn) - } - - async fn into_future( +impl ServerMuxInner { + pub async fn into_future( self, rx: R, close_rx: mpsc::UnboundedReceiver, - handler_fn: &mut impl Fn(ConnectPacket, MuxStream) -> FR, + muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, ) -> Result<(), WispError> where R: ws::WebSocketRead, - FR: std::future::Future>, { - futures::select! { + let ret = futures::select! { x = self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()).fuse() => x, - x = self.server_msg_loop(rx, handler_fn).fuse() => x - } + x = self.server_msg_loop(rx, muxstream_sender).fuse() => x + }; + self.stream_map.lock().await.iter().for_each(|x| { + let _ = x.1.unbounded_send(WsEvent::Close(ClosePacket::new(0x01))); + }); + ret } async fn server_close_loop( &self, mut close_rx: mpsc::UnboundedReceiver, - stream_map: Arc>>, + stream_map: Arc>>>, tx: ws::LockedWebSocketWrite, ) -> Result<(), WispError> { while let Some(msg) = close_rx.next().await { match msg { MuxEvent::Close(stream_id, reason, channel) => { - if stream_map.clone().remove(&stream_id).is_some() { + if stream_map.lock().await.remove(&stream_id).is_some() { let _ = channel.send( tx.write_frame(Packet::new_close(stream_id, reason).into()) .await, @@ -138,14 +122,13 @@ impl ServerMux { Ok(()) } - async fn server_msg_loop( + async fn server_msg_loop( &self, mut rx: R, - handler_fn: &mut impl Fn(ConnectPacket, MuxStream) -> FR, + muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, ) -> Result<(), WispError> where R: ws::WebSocketRead, - FR: std::future::Future>, { self.tx .write_frame(Packet::new_continue(0, u32::MAX).into()) @@ -157,21 +140,22 @@ impl ServerMux { match packet.packet { Connect(inner_packet) => { let (ch_tx, ch_rx) = mpsc::unbounded(); - self.stream_map.clone().insert(packet.stream_id, ch_tx); - let _ = handler_fn( - inner_packet, - MuxStream::new( - packet.stream_id, - ch_rx, - self.tx.clone(), - self.close_tx.clone(), - AtomicBool::new(false).into(), - ), - ) - .await; + self.stream_map.lock().await.insert(packet.stream_id, ch_tx); + muxstream_sender + .unbounded_send(( + inner_packet, + MuxStream::new( + packet.stream_id, + ch_rx, + self.tx.clone(), + self.close_tx.clone(), + AtomicBool::new(false).into(), + ), + )) + .map_err(|x| WispError::Other(Box::new(x)))?; } Data(data) => { - if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { let _ = stream.unbounded_send(WsEvent::Send(data)); self.tx .write_frame( @@ -182,24 +166,59 @@ impl ServerMux { } Continue(_) => unreachable!(), Close(inner_packet) => { - if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { let _ = stream.unbounded_send(WsEvent::Close(inner_packet)); - self.stream_map.clone().remove(&packet.stream_id); + self.stream_map.lock().await.remove(&packet.stream_id); } } } + } else { + break; } } + drop(muxstream_sender); Ok(()) } } +pub struct ServerMux +where + W: ws::WebSocketWrite + Send + 'static, +{ + muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>, +} + +impl ServerMux { + pub fn new(read: R, write: W) -> (Self, impl Future>) + where + R: ws::WebSocketRead, + { + let (close_tx, close_rx) = mpsc::unbounded::(); + let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); + let write = ws::LockedWebSocketWrite::new(write); + let map = Arc::new(Mutex::new(HashMap::new())); + ( + Self { muxstream_recv: rx }, + ServerMuxInner { + tx: write, + close_tx, + stream_map: map.clone(), + } + .into_future(read, close_rx, tx), + ) + } + + pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> { + self.muxstream_recv.next().await + } +} + pub struct ClientMuxInner where W: ws::WebSocketWrite, { tx: ws::LockedWebSocketWrite, - stream_map: Arc>>, + stream_map: Arc>>>, } impl ClientMuxInner { @@ -211,7 +230,10 @@ impl ClientMuxInner { where R: ws::WebSocketRead, { - futures::try_join!(self.client_bg_loop(close_rx), self.client_loop(rx)).map(|_| ()) + futures::select! { + x = self.client_bg_loop(close_rx).fuse() => x, + x = self.client_loop(rx).fuse() => x + } } async fn client_bg_loop( @@ -221,7 +243,7 @@ impl ClientMuxInner { while let Some(msg) = close_rx.next().await { match msg { MuxEvent::Close(stream_id, reason, channel) => { - if self.stream_map.clone().remove(&stream_id).is_some() { + if self.stream_map.lock().await.remove(&stream_id).is_some() { let _ = channel.send( self.tx .write_frame(Packet::new_close(stream_id, reason).into()) @@ -246,15 +268,15 @@ impl ClientMuxInner { match packet.packet { Connect(_) => unreachable!(), Data(data) => { - if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { let _ = stream.unbounded_send(WsEvent::Send(data)); } } Continue(_) => {} Close(inner_packet) => { - if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { let _ = stream.unbounded_send(WsEvent::Close(inner_packet)); - self.stream_map.clone().remove(&packet.stream_id); + self.stream_map.lock().await.remove(&packet.stream_id); } } } @@ -269,7 +291,7 @@ where W: ws::WebSocketWrite, { tx: ws::LockedWebSocketWrite, - stream_map: Arc>>, + stream_map: Arc>>>, next_free_stream_id: AtomicU32, close_tx: mpsc::UnboundedSender, } @@ -280,7 +302,7 @@ impl ClientMux { R: ws::WebSocketRead, { let (tx, rx) = mpsc::unbounded::(); - let map = Arc::new(DashMap::new()); + let map = Arc::new(Mutex::new(HashMap::new())); let write = ws::LockedWebSocketWrite::new(write); ( Self { @@ -314,7 +336,7 @@ impl ClientMux { .ok_or(WispError::MaxStreamCountReached)?, Ordering::Release, ); - self.stream_map.clone().insert(stream_id, ch_tx); + self.stream_map.lock().await.insert(stream_id, ch_tx); Ok(MuxStream::new( stream_id, ch_rx, From be340c0f82ba70fa6583d2a5e5440d0b314c1414 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 3 Feb 2024 17:06:24 -0800 Subject: [PATCH 15/26] optimize for speed --- Cargo.toml | 2 +- client/build.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2e7eb18..032dadf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,4 +7,4 @@ rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" } [profile.release] lto = true -opt-level = 'z' +opt-level = 3 diff --git a/client/build.sh b/client/build.sh index 6b0a324..5439060 100755 --- a/client/build.sh +++ b/client/build.sh @@ -11,7 +11,7 @@ wasm-bindgen --weak-refs --target no-modules --no-modules-global epoxy --out-dir echo "[ws] bindgen finished" mv out/epoxy_client_bg.wasm out/epoxy_client_unoptimized.wasm -time wasm-opt -Oz out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm +time wasm-opt -O4 out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm echo "[ws] optimized" AUTOGENERATED_SOURCE=$(<"out/epoxy_client.js") From ac39d82a53b8d2ad67e7a4cc7560ba69acb36a78 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 3 Feb 2024 22:46:19 -0800 Subject: [PATCH 16/26] optimizations and more deadlock fixes --- Cargo.lock | 39 +++++++++- Cargo.toml | 5 +- client/build.sh | 2 +- client/demo.js | 125 ++++++++++++++++++++++----------- client/src/lib.rs | 7 ++ simple-wisp-client/Cargo.toml | 15 ++++ simple-wisp-client/src/main.rs | 105 +++++++++++++++++++++++++++ wisp/src/lib.rs | 4 +- 8 files changed, 253 insertions(+), 49 deletions(-) create mode 100644 simple-wisp-client/Cargo.toml create mode 100644 simple-wisp-client/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index eefc704..0d1f3d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -727,6 +727,16 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + [[package]] name = "parking_lot_core" version = "0.9.9" @@ -987,12 +997,35 @@ dependencies = [ "digest", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + [[package]] name = "simdutf8" version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" +[[package]] +name = "simple-wisp-client" +version = "0.1.0" +dependencies = [ + "bytes", + "fastwebsockets", + "futures", + "http-body-util", + "hyper", + "tokio", + "tokio-native-tls", + "wisp-mux", +] + [[package]] name = "slab" version = "0.4.9" @@ -1076,16 +1109,18 @@ dependencies = [ [[package]] name = "tokio" -version = "1.35.1" +version = "1.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" +checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" dependencies = [ "backtrace", "bytes", "libc", "mio", "num_cpus", + "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.48.0", diff --git a/Cargo.toml b/Cargo.toml index 032dadf..2fcf5e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,11 @@ [workspace] resolver = "2" -members = ["server", "client", "wisp"] +members = ["server", "client", "wisp", "simple-wisp-client"] [patch.crates-io] rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" } [profile.release] lto = true -opt-level = 3 +opt-level = 'z' +codegen-units = 1 diff --git a/client/build.sh b/client/build.sh index 5439060..9d5a889 100755 --- a/client/build.sh +++ b/client/build.sh @@ -11,7 +11,7 @@ wasm-bindgen --weak-refs --target no-modules --no-modules-global epoxy --out-dir echo "[ws] bindgen finished" mv out/epoxy_client_bg.wasm out/epoxy_client_unoptimized.wasm -time wasm-opt -O4 out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm +time wasm-opt -Oz --vacuum --dce out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm echo "[ws] optimized" AUTOGENERATED_SOURCE=$(<"out/epoxy_client.js") diff --git a/client/demo.js b/client/demo.js index 8f27fa6..231842a 100644 --- a/client/demo.js +++ b/client/demo.js @@ -4,16 +4,21 @@ "color:red;font-size:3rem;font-weight:bold" ); - const should_feature_test = (new URL(window.location.href)).searchParams.has("feature_test"); - const should_parallel_test = (new URL(window.location.href)).searchParams.has("parallel_test"); - const should_perf_test = (new URL(window.location.href)).searchParams.has("perf_test"); - const should_ws_test = (new URL(window.location.href)).searchParams.has("ws_test"); + const params = (new URL(window.location.href)).searchParams; + + const should_feature_test = params.has("feature_test"); + const should_multiparallel_test = params.has("multi_parallel_test"); + const should_parallel_test = params.has("parallel_test"); + const should_multiperf_test = params.has("multi_perf_test"); + const should_perf_test = params.has("perf_test"); + const should_ws_test = params.has("ws_test"); const log = (str) => { let el = document.createElement("div"); el.innerText = str; document.getElementById("logs").appendChild(el); console.warn(str); + window.scrollTo(0, document.body.scrollHeight); } let { EpoxyClient } = await epoxy(); @@ -24,6 +29,19 @@ const tconn1 = performance.now(); log(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`); + const test_mux = async (url) => { + const t0 = performance.now(); + await epoxy_client.fetch(url); + const t1 = performance.now(); + return t1 - t0; + }; + + const test_native = async (url) => { + const t0 = performance.now(); + await fetch(url, { cache: "no-store" }); + const t1 = performance.now(); + return t1 - t0; + }; if (should_feature_test) { for (const url of [ @@ -37,55 +55,78 @@ console.warn(url, resp, Object.fromEntries(resp.headers)); console.warn(await resp.text()); } + } else if (should_multiparallel_test) { + const num_tests = 10; + let total_mux_minus_native = 0; + for (const _ of Array(num_tests).keys()) { + let total_mux = 0; + await Promise.all([...Array(num_tests).keys()].map(async i => { + log(`running mux test ${i}`); + return await test_mux("https://httpbin.org/get"); + })).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) }); + total_mux = total_mux / num_tests; + + let total_native = 0; + await Promise.all([...Array(num_tests).keys()].map(async i => { + log(`running native test ${i}`); + return await test_native("https://httpbin.org/get"); + })).then((vals) => { total_native = vals.reduce((acc, x) => acc + x, 0) }); + total_native = total_native / num_tests; + + log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); + log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); + log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); + total_mux_minus_native += total_mux - total_native; + } + total_mux_minus_native = total_mux_minus_native / num_tests; + log(`total mux - native (${num_tests} tests of ${num_tests} reqs): ${total_mux_minus_native} ms or ${total_mux_minus_native / 1000} s`); } else if (should_parallel_test) { - const test_mux = async (url) => { - const t0 = performance.now(); - await epoxy_client.fetch(url); - const t1 = performance.now(); - return t1 - t0; - }; - - const test_native = async (url) => { - const t0 = performance.now(); - await fetch(url); - const t1 = performance.now(); - return t1 - t0; - }; - const num_tests = 10; let total_mux = 0; - await Promise.all([...Array(num_tests).keys()].map(async i=>{ + await Promise.all([...Array(num_tests).keys()].map(async i => { log(`running mux test ${i}`); return await test_mux("https://httpbin.org/get"); - })).then((vals)=>{total_mux = vals.reduce((acc, x) => acc + x, 0)}); + })).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) }); total_mux = total_mux / num_tests; let total_native = 0; - await Promise.all([...Array(num_tests).keys()].map(async i=>{ + await Promise.all([...Array(num_tests).keys()].map(async i => { log(`running native test ${i}`); return await test_native("https://httpbin.org/get"); - })).then((vals)=>{total_native = vals.reduce((acc, x) => acc + x, 0)}); + })).then((vals) => { total_native = vals.reduce((acc, x) => acc + x, 0) }); total_native = total_native / num_tests; - log(`avg mux (10) took ${total_mux} ms or ${total_mux / 1000} s`); - log(`avg native (10) took ${total_native} ms or ${total_native / 1000} s`); - log(`mux - native: ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); + log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); + log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); + log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); + } else if (should_multiperf_test) { + const num_tests = 10; + let total_mux_minus_native = 0; + for (const _ of Array(num_tests).keys()) { + let total_mux = 0; + for (const i of Array(num_tests).keys()) { + log(`running mux test ${i}`); + total_mux += await test_mux("https://httpbin.org/get"); + } + total_mux = total_mux / num_tests; + + let total_native = 0; + for (const i of Array(num_tests).keys()) { + log(`running native test ${i}`); + total_native += await test_native("https://httpbin.org/get"); + } + total_native = total_native / num_tests; + + log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); + log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); + log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); + total_mux_minus_native += total_mux - total_native; + } + total_mux_minus_native = total_mux_minus_native / num_tests; + log(`total mux - native (${num_tests} tests of ${num_tests} reqs): ${total_mux_minus_native} ms or ${total_mux_minus_native / 1000} s`); + } else if (should_perf_test) { - const test_mux = async (url) => { - const t0 = performance.now(); - await epoxy_client.fetch(url); - const t1 = performance.now(); - return t1 - t0; - }; - - const test_native = async (url) => { - const t0 = performance.now(); - await fetch(url); - const t1 = performance.now(); - return t1 - t0; - }; - const num_tests = 10; let total_mux = 0; @@ -102,9 +143,9 @@ } total_native = total_native / num_tests; - log(`avg mux (10) took ${total_mux} ms or ${total_mux / 1000} s`); - log(`avg native (10) took ${total_native} ms or ${total_native / 1000} s`); - log(`mux - native: ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); + log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`); + log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`); + log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`); } else if (should_ws_test) { let ws = await epoxy_client.connect_ws( () => console.log("opened"), diff --git a/client/src/lib.rs b/client/src/lib.rs index dabadf2..0ea0467 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -73,10 +73,12 @@ async fn send_req( None }; + debug!("sending req"); let res = req_sender .send_request(req) .await .replace_err("Failed to send request"); + debug!("recieved res"); match res { Ok(res) => { if utils::is_redirect(res.status().as_u16()) @@ -176,6 +178,7 @@ impl EpoxyClient { async fn get_http_io(&self, url: &Uri) -> Result { let url_host = url.host().replace_err("URL must have a host")?; let url_port = utils::get_url_port(url)?; + debug!("making channel"); let channel = self .mux .client_new_stream(StreamType::Tcp, url_host.to_string(), url_port) @@ -187,6 +190,7 @@ impl EpoxyClient { if utils::get_is_secure(url)? { let cloned_uri = url_host.to_string().clone(); let connector = TlsConnector::from(self.rustls_config.clone()); + debug!("connecting channel"); let io = connector .connect( cloned_uri @@ -196,8 +200,11 @@ impl EpoxyClient { ) .await .replace_err("Failed to perform TLS handshake")?; + debug!("connected channel"); Ok(EpxStream::Left(io)) } else { + debug!("connecting channel"); + debug!("connected channel"); Ok(EpxStream::Right(channel)) } } diff --git a/simple-wisp-client/Cargo.toml b/simple-wisp-client/Cargo.toml new file mode 100644 index 0000000..972d5da --- /dev/null +++ b/simple-wisp-client/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "simple-wisp-client" +version = "0.1.0" +edition = "2021" + +[dependencies] +bytes = "1.5.0" +fastwebsockets = { version = "0.6.0", features = ["unstable-split", "upgrade"] } +futures = "0.3.30" +http-body-util = "0.1.0" +hyper = { version = "1.1.0", features = ["http1", "client"] } +tokio = { version = "1.36.0", features = ["full"] } +tokio-native-tls = "0.3.1" +wisp-mux = { path = "../wisp", features = ["fastwebsockets"]} + diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs new file mode 100644 index 0000000..db49784 --- /dev/null +++ b/simple-wisp-client/src/main.rs @@ -0,0 +1,105 @@ +use bytes::Bytes; +use fastwebsockets::{handshake, FragmentCollectorRead}; +use futures::io::AsyncWriteExt; +use http_body_util::Empty; +use hyper::{ + header::{CONNECTION, UPGRADE}, + Request, +}; +use std::{error::Error, future::Future}; +use tokio::net::TcpStream; +use tokio_native_tls::{native_tls, TlsConnector}; +use wisp_mux::{ClientMux, StreamType}; + +#[derive(Debug)] +struct StrError(String); + +impl StrError { + pub fn new(str: &str) -> Self { + Self(str.to_string()) + } +} + +impl std::fmt::Display for StrError { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { + write!(fmt, "{}", self.0) + } +} + +impl Error for StrError {} + +struct SpawnExecutor; + +impl hyper::rt::Executor for SpawnExecutor +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + fn execute(&self, fut: Fut) { + tokio::task::spawn(fut); + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let addr = std::env::args() + .nth(1) + .ok_or(StrError::new("no src addr"))?; + + let addr_port: u16 = std::env::args() + .nth(2) + .ok_or(StrError::new("no src port"))? + .parse()?; + + let addr_dest = std::env::args() + .nth(3) + .ok_or(StrError::new("no dest addr"))?; + + let addr_dest_port: u16 = std::env::args() + .nth(4) + .ok_or(StrError::new("no dest port"))? + .parse()?; + + let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?; + let cx = TlsConnector::from(native_tls::TlsConnector::builder().build()?); + let socket = cx.connect(&addr, socket).await?; + let req = Request::builder() + .method("GET") + .uri(format!("wss://{}:{}/", &addr, addr_port)) + .header("Host", &addr) + .header(UPGRADE, "websocket") + .header(CONNECTION, "upgrade") + .header( + "Sec-WebSocket-Key", + fastwebsockets::handshake::generate_key(), + ) + .header("Sec-WebSocket-Version", "13") + .header("Sec-WebSocket-Protocol", "wisp-v1") + .body(Empty::::new())?; + + let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?; + + let (rx, tx) = ws.split(tokio::io::split); + let rx = FragmentCollectorRead::new(rx); + + let (mux, fut) = ClientMux::new(rx, tx); + + tokio::task::spawn(fut); + + let mut hi: u64 = 0; + loop { + let mut channel = mux + .client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port) + .await? + .into_io() + .into_asyncrw(); + for _ in 0..10 { + channel.write_all(b"hiiiiiiii").await?; + hi += 1; + println!("said hi {}", hi); + } + } + + #[allow(unreachable_code)] + Ok(()) +} diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 9326ece..37b81c5 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -168,8 +168,8 @@ impl ServerMuxInner { Close(inner_packet) => { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { let _ = stream.unbounded_send(WsEvent::Close(inner_packet)); - self.stream_map.lock().await.remove(&packet.stream_id); } + self.stream_map.lock().await.remove(&packet.stream_id); } } } else { @@ -276,8 +276,8 @@ impl ClientMuxInner { Close(inner_packet) => { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { let _ = stream.unbounded_send(WsEvent::Close(inner_packet)); - self.stream_map.lock().await.remove(&packet.stream_id); } + self.stream_map.lock().await.remove(&packet.stream_id); } } } From 54011e1f8ac280232347d4ac7861a7b2ffd48824 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sun, 4 Feb 2024 00:02:21 -0800 Subject: [PATCH 17/26] raw tls sockets --- client/demo.js | 11 ++++++ client/src/lib.rs | 81 ++++++++++++++++++++++++--------------- client/src/tls_stream.rs | 82 ++++++++++++++++++++++++++++++++++++++++ client/src/websocket.rs | 2 +- 4 files changed, 144 insertions(+), 32 deletions(-) create mode 100644 client/src/tls_stream.rs diff --git a/client/demo.js b/client/demo.js index 231842a..30379b8 100644 --- a/client/demo.js +++ b/client/demo.js @@ -12,6 +12,7 @@ const should_multiperf_test = params.has("multi_perf_test"); const should_perf_test = params.has("perf_test"); const should_ws_test = params.has("ws_test"); + const should_tls_test = params.has("rawtls_test"); const log = (str) => { let el = document.createElement("div"); @@ -160,6 +161,16 @@ await ws.send("data"); await (new Promise((res, _) => setTimeout(res, 100))); } + } else if (should_tls_test) { + let decoder = new TextDecoder(); + let ws = await epoxy_client.connect_tls( + () => console.log("opened"), + () => console.log("closed"), + err => console.error(err), + msg => { console.log(msg); console.log(decoder.decode(msg)) }, + "alicesworld.tech:443", + ); + await ws.send("GET / HTTP 1.1\r\nHost: alicesworld.tech\r\nConnection: close\r\n\r\n"); } else { let resp = await epoxy_client.fetch("https://httpbin.org/get"); console.warn(resp, Object.fromEntries(resp.headers)); diff --git a/client/src/lib.rs b/client/src/lib.rs index 0ea0467..afd7b20 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -1,10 +1,12 @@ #![feature(let_chains)] #[macro_use] mod utils; +mod tls_stream; mod tokioio; mod websocket; mod wrappers; +use tls_stream::EpxTlsStream; use tokioio::TokioIo; use utils::{ReplaceErr, UriExt}; use websocket::EpxWebSocket; @@ -15,10 +17,7 @@ use std::sync::Arc; use async_compression::tokio::bufread as async_comp; use async_io_stream::IoStream; use bytes::Bytes; -use futures_util::{ - stream::SplitSink, - StreamExt, -}; +use futures_util::{stream::SplitSink, StreamExt}; use http::{uri, HeaderName, HeaderValue, Request, Response}; use hyper::{body::Incoming, client::conn::http1::Builder, Uri}; use js_sys::{Array, Function, Object, Reflect, Uint8Array}; @@ -30,7 +29,7 @@ use tokio_util::{ use wasm_bindgen::prelude::*; use web_sys::TextEncoder; use wisp_mux::{ClientMux, MuxStreamIo, StreamType}; -use ws_stream_wasm::{WsMeta, WsStream, WsMessage}; +use ws_stream_wasm::{WsMessage, WsMeta, WsStream}; type HttpBody = http_body_util::Full; @@ -45,14 +44,14 @@ enum EpxCompression { Gzip, } -type EpxTlsStream = TlsStream>>; -type EpxUnencryptedStream = IoStream>; -type EpxStream = Either; +type EpxIoTlsStream = TlsStream>>; +type EpxIoUnencryptedStream = IoStream>; +type EpxIoStream = Either; async fn send_req( req: http::Request, should_redirect: bool, - io: EpxStream, + io: EpxIoStream, ) -> Result { let (mut req_sender, conn) = Builder::new() .title_case_headers(true) @@ -175,10 +174,7 @@ impl EpoxyClient { }) } - async fn get_http_io(&self, url: &Uri) -> Result { - let url_host = url.host().replace_err("URL must have a host")?; - let url_port = utils::get_url_port(url)?; - debug!("making channel"); + async fn get_tls_io(&self, url_host: &str, url_port: u16) -> Result { let channel = self .mux .client_new_stream(StreamType::Tcp, url_host.to_string(), url_port) @@ -186,26 +182,42 @@ impl EpoxyClient { .replace_err("Failed to create multiplexor channel")? .into_io() .into_asyncrw(); + let cloned_uri = url_host.to_string().clone(); + let connector = TlsConnector::from(self.rustls_config.clone()); + debug!("connecting channel"); + let io = connector + .connect( + cloned_uri + .try_into() + .replace_err("Failed to parse URL (rustls)")?, + channel, + ) + .await + .replace_err("Failed to perform TLS handshake")?; + debug!("connected channel"); + Ok(io) + } + + async fn get_http_io(&self, url: &Uri) -> Result { + let url_host = url.host().replace_err("URL must have a host")?; + let url_port = utils::get_url_port(url)?; if utils::get_is_secure(url)? { - let cloned_uri = url_host.to_string().clone(); - let connector = TlsConnector::from(self.rustls_config.clone()); - debug!("connecting channel"); - let io = connector - .connect( - cloned_uri - .try_into() - .replace_err("Failed to parse URL (rustls)")?, - channel, - ) - .await - .replace_err("Failed to perform TLS handshake")?; - debug!("connected channel"); - Ok(EpxStream::Left(io)) + Ok(EpxIoStream::Left( + self.get_tls_io(url_host, url_port).await?, + )) } else { + debug!("making channel"); + let channel = self + .mux + .client_new_stream(StreamType::Tcp, url_host.to_string(), url_port) + .await + .replace_err("Failed to create multiplexor channel")? + .into_io() + .into_asyncrw(); debug!("connecting channel"); debug!("connected channel"); - Ok(EpxStream::Right(channel)) + Ok(EpxIoStream::Right(channel)) } } @@ -253,11 +265,18 @@ impl EpoxyClient { .await } - pub async fn fetch( + pub async fn connect_tls( &self, + onopen: Function, + onclose: Function, + onerror: Function, + onmessage: Function, url: String, - options: Object, - ) -> Result { + ) -> Result { + EpxTlsStream::connect(self, onopen, onclose, onerror, onmessage, url).await + } + + pub async fn fetch(&self, url: String, options: Object) -> Result { let uri = url.parse::().replace_err("Failed to parse URL")?; let uri_scheme = uri.scheme().replace_err("URL must have a scheme")?; if *uri_scheme != uri::Scheme::HTTP && *uri_scheme != uri::Scheme::HTTPS { diff --git a/client/src/tls_stream.rs b/client/src/tls_stream.rs new file mode 100644 index 0000000..97e61a7 --- /dev/null +++ b/client/src/tls_stream.rs @@ -0,0 +1,82 @@ +use crate::*; + +use js_sys::Function; +use tokio::io::{split, AsyncWriteExt, WriteHalf}; +use tokio_util::io::ReaderStream; + +#[wasm_bindgen] +pub struct EpxTlsStream { + tx: WriteHalf, + onerror: Function, +} + +#[wasm_bindgen] +impl EpxTlsStream { + #[wasm_bindgen(constructor)] + pub fn new() -> Result { + Err(jerr!("Use EpoxyClient.connect_tls() instead.")) + } + + // shut up + #[allow(clippy::too_many_arguments)] + pub async fn connect( + tcp: &EpoxyClient, + onopen: Function, + onclose: Function, + onerror: Function, + onmessage: Function, + url: String, + ) -> Result { + let onerr = onerror.clone(); + let ret: Result = async move { + let url = Uri::try_from(url).replace_err("Failed to parse URL")?; + let url_host = url.host().replace_err("URL must have a host")?; + let url_port = url.port().replace_err("URL must have a port")?.into(); + + let io = tcp.get_tls_io(url_host, url_port).await?; + let (rx, tx) = split(io); + let mut rx = ReaderStream::new(rx); + + wasm_bindgen_futures::spawn_local(async move { + while let Some(Ok(data)) = rx.next().await { + let _ = onmessage.call1( + &JsValue::null(), + &jval!(Uint8Array::from(data.to_vec().as_slice())), + ); + } + let _ = onclose.call0(&JsValue::null()); + }); + + onopen + .call0(&Object::default()) + .replace_err("Failed to call onopen")?; + + Ok(Self { tx, onerror }) + } + .await; + if let Err(ret) = ret { + let _ = onerr.call1(&JsValue::null(), &jval!(ret.clone())); + Err(ret) + } else { + ret + } + } + + #[wasm_bindgen] + pub async fn send(&mut self, payload: Uint8Array) -> Result<(), JsError> { + let onerr = self.onerror.clone(); + let ret = self.tx.write_all(&payload.to_vec()).await; + if let Err(ret) = ret { + let _ = onerr.call1(&JsValue::null(), &jval!(format!("{}", ret))); + Err(ret.into()) + } else { + Ok(ret?) + } + } + + #[wasm_bindgen] + pub async fn close(&mut self) -> Result<(), JsError> { + self.tx.shutdown().await?; + Ok(()) + } +} diff --git a/client/src/websocket.rs b/client/src/websocket.rs index addae2c..a83b755 100644 --- a/client/src/websocket.rs +++ b/client/src/websocket.rs @@ -68,7 +68,7 @@ impl EpxWebSocket { let (mut sender, conn) = Builder::new() .title_case_headers(true) .preserve_header_case(true) - .handshake::, Empty>(TokioIo::new(stream)) + .handshake::, Empty>(TokioIo::new(stream)) .await?; wasm_bindgen_futures::spawn_local(async move { From 2158b9323edc90320734344ef8e352b485ad9f6c Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sun, 4 Feb 2024 12:06:51 -0800 Subject: [PATCH 18/26] enable atomics and back to optimizing for speed --- Cargo.lock | 24 +++--------------------- Cargo.toml | 2 +- client/Cargo.toml | 14 +++----------- client/build.sh | 14 ++++++++------ client/src/lib.rs | 5 ----- client/src/utils.rs | 11 +++-------- client/src/websocket.rs | 24 +++++++++++++----------- server/src/main.rs | 2 ++ simple-wisp-client/Cargo.toml | 1 + simple-wisp-client/src/main.rs | 13 +++++++++++-- 10 files changed, 45 insertions(+), 65 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0d1f3d3..40758a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -154,16 +154,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" -[[package]] -name = "console_error_panic_hook" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a06aeb73f470f66dcdbf7223caeebb85984942f22f1adb2a088cf9668146bbbc" -dependencies = [ - "cfg-if", - "wasm-bindgen", -] - [[package]] name = "core-foundation" version = "0.9.4" @@ -231,12 +221,6 @@ dependencies = [ "crypto-common", ] -[[package]] -name = "either" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" - [[package]] name = "epoxy-client" version = "1.0.0" @@ -245,8 +229,6 @@ dependencies = [ "async_io_stream", "base64", "bytes", - "console_error_panic_hook", - "either", "fastwebsockets", "futures-util", "getrandom", @@ -255,7 +237,6 @@ dependencies = [ "hyper", "js-sys", "pin-project-lite", - "rand", "ring", "tokio", "tokio-rustls", @@ -453,9 +434,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.11" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", "js-sys", @@ -1023,6 +1004,7 @@ dependencies = [ "hyper", "tokio", "tokio-native-tls", + "tokio-util", "wisp-mux", ] diff --git a/Cargo.toml b/Cargo.toml index 2fcf5e1..2e8971b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,5 +7,5 @@ rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" } [profile.release] lto = true -opt-level = 'z' +opt-level = 3 codegen-units = 1 diff --git a/client/Cargo.toml b/client/Cargo.toml index 48ff0e5..3f8efd3 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -6,12 +6,8 @@ edition = "2021" [lib] crate-type = ["cdylib"] -[features] -default = ["console_error_panic_hook"] - [dependencies] bytes = "1.5.0" -console_error_panic_hook = { version = "0.1.7", optional = true } http = "1.0.0" http-body-util = "0.1.0" hyper = { version = "1.1.0", features = ["client", "http1"] } @@ -24,19 +20,15 @@ futures-util = "0.3.30" js-sys = "0.3.66" webpki-roots = "0.26.0" tokio-rustls = "0.25.0" -web-sys = { version = "0.3.66", features = ["TextEncoder", "Navigator", "Response", "ResponseInit"] } +web-sys = { version = "0.3.66", features = ["TextEncoder", "Response", "ResponseInit"] } wasm-streams = "0.4.0" -either = "1.9.0" tokio-util = { version = "0.7.10", features = ["io"] } async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] } -fastwebsockets = { version = "0.6.0", features = ["simdutf8", "unstable-split"] } -rand = "0.8.5" +fastwebsockets = { version = "0.6.0", features = ["unstable-split"] } base64 = "0.21.7" wisp-mux = { path = "../wisp", features = ["ws_stream_wasm", "tokio_io"] } async_io_stream = { version = "0.3.3", features = ["tokio_io"] } - -[dependencies.getrandom] -features = ["js"] +getrandom = { version = "0.2.12", features = ["js"] } [dependencies.ring] features = ["wasm32_unknown_unknown_js"] diff --git a/client/build.sh b/client/build.sh index 9d5a889..7c2725b 100755 --- a/client/build.sh +++ b/client/build.sh @@ -5,18 +5,20 @@ shopt -s inherit_errexit rm -rf out/ || true mkdir out/ -cargo build --target wasm32-unknown-unknown --release -echo "[ws] built rust" +RUSTFLAGS='-C target-feature=+atomics,+bulk-memory' cargo build --target wasm32-unknown-unknown -Z build-std=panic_abort,std --release +echo "[ws] cargo finished" wasm-bindgen --weak-refs --target no-modules --no-modules-global epoxy --out-dir out/ ../target/wasm32-unknown-unknown/release/epoxy_client.wasm -echo "[ws] bindgen finished" +echo "[ws] wasm-bindgen finished" mv out/epoxy_client_bg.wasm out/epoxy_client_unoptimized.wasm -time wasm-opt -Oz --vacuum --dce out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm -echo "[ws] optimized" +time wasm-opt -O4 --vacuum --dce --enable-threads --enable-bulk-memory --enable-simd out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm +echo "[ws] wasm-opt finished" AUTOGENERATED_SOURCE=$(<"out/epoxy_client.js") +# patch for websocket sharedarraybuffer error +AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//getObject(arg0).send(getArrayU8FromWasm0(arg1, arg2)/getObject(arg0).send(new Uint8Array(getArrayU8FromWasm0(arg1, arg2))} WASM_BASE64=$(base64 -w0 out/epoxy_client_bg.wasm) -AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//__wbg_init(input) \{/__wbg_init() \{let input=\'data:application/wasm;base64,$WASM_BASE64\'} +AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//__wbg_init(input, maybe_memory) \{/__wbg_init(input, maybe_memory) \{$'\n'if (!input) \{input=\'data:application/wasm;base64,$WASM_BASE64\'\}} AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//return __wbg_finalize_init\(instance\, module\);/__wbg_finalize_init\(instance\, module\); return epoxy} echo "$AUTOGENERATED_SOURCE" > epoxy-bundled.js cp epoxy-bundled.js epoxy-module-bundled.js diff --git a/client/src/lib.rs b/client/src/lib.rs index afd7b20..639c786 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -111,11 +111,6 @@ async fn send_req( } } -#[wasm_bindgen(start)] -async fn start() { - utils::set_panic_hook(); -} - #[wasm_bindgen] pub struct EpoxyClient { rustls_config: Arc, diff --git a/client/src/utils.rs b/client/src/utils.rs index 077a120..caf4498 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -4,11 +4,6 @@ use hyper::{header::HeaderValue, Uri}; use http::uri; use js_sys::{Array, Object}; -pub fn set_panic_hook() { - #[cfg(feature = "console_error_panic_hook")] - console_error_panic_hook::set_once(); -} - #[wasm_bindgen] extern "C" { #[wasm_bindgen(js_namespace = console, js_name = debug)] @@ -55,15 +50,15 @@ pub trait ReplaceErr { fn replace_err_jv(self, err: &str) -> Result; } -impl ReplaceErr for Result { +impl ReplaceErr for Result { type Ok = T; fn replace_err(self, err: &str) -> Result<::Ok, JsError> { - self.map_err(|_| jerr!(err)) + self.map_err(|oe| jerr!(&format!("{}, original error: {:?}", err, oe))) } fn replace_err_jv(self, err: &str) -> Result<::Ok, JsValue> { - self.map_err(|_| jval!(err)) + self.map_err(|oe| jval!(&format!("{}, original error: {:?}", err, oe))) } } diff --git a/client/src/websocket.rs b/client/src/websocket.rs index a83b755..f823077 100644 --- a/client/src/websocket.rs +++ b/client/src/websocket.rs @@ -4,6 +4,7 @@ use base64::{engine::general_purpose::STANDARD, Engine}; use fastwebsockets::{ CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, Role, WebSocket, WebSocketWrite, }; +use futures_util::lock::Mutex; use http_body_util::Empty; use hyper::{ header::{CONNECTION, UPGRADE}, @@ -16,7 +17,7 @@ use tokio::io::WriteHalf; #[wasm_bindgen] pub struct EpxWebSocket { - tx: WebSocketWrite>>, + tx: Arc>>>>, onerror: Function, } @@ -44,7 +45,8 @@ impl EpxWebSocket { let url = Uri::try_from(url).replace_err("Failed to parse URL")?; let host = url.host().replace_err("URL must have a host")?; - let rand: [u8; 16] = rand::random(); + let mut rand: [u8; 16] = [0; 16]; + getrandom::getrandom(&mut rand)?; let key = STANDARD.encode(rand); let mut builder = Request::builder() @@ -88,16 +90,12 @@ impl EpxWebSocket { let (rx, tx) = ws.split(tokio::io::split); let mut rx = FragmentCollectorRead::new(rx); + let tx = Arc::new(Mutex::new(tx)); + let tx_cloned = tx.clone(); wasm_bindgen_futures::spawn_local(async move { while let Ok(frame) = rx - .read_frame(&mut |arg| async move { - error!( - "wtf is an obligated write {:?}, {:?}, {:?}", - arg.fin, arg.opcode, arg.payload - ); - Ok::<(), std::io::Error>(()) - }) + .read_frame(&mut |arg| async { tx_cloned.lock().await.write_frame(arg).await }) .await { match frame.opcode { @@ -137,10 +135,12 @@ impl EpxWebSocket { } #[wasm_bindgen] - pub async fn send(&mut self, payload: String) -> Result<(), JsError> { + pub async fn send(&self, payload: String) -> Result<(), JsError> { let onerr = self.onerror.clone(); let ret = self .tx + .lock() + .await .write_frame(Frame::text(Payload::Owned(payload.as_bytes().to_vec()))) .await; if let Err(ret) = ret { @@ -152,8 +152,10 @@ impl EpxWebSocket { } #[wasm_bindgen] - pub async fn close(&mut self) -> Result<(), JsError> { + pub async fn close(&self) -> Result<(), JsError> { self.tx + .lock() + .await .write_frame(Frame::close(CloseCode::Normal.into(), b"")) .await?; Ok(()) diff --git a/server/src/main.rs b/server/src/main.rs index aeca64e..04c39e7 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -78,6 +78,8 @@ async fn accept_http( let uri = req.uri().clone(); let (mut res, fut) = upgrade::upgrade(&mut req)?; + println!("{:?} {:?}", uri.path(), prefix); + if *uri.path() != prefix { tokio::spawn(async move { accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone()).await diff --git a/simple-wisp-client/Cargo.toml b/simple-wisp-client/Cargo.toml index 972d5da..9525805 100644 --- a/simple-wisp-client/Cargo.toml +++ b/simple-wisp-client/Cargo.toml @@ -11,5 +11,6 @@ http-body-util = "0.1.0" hyper = { version = "1.1.0", features = ["http1", "client"] } tokio = { version = "1.36.0", features = ["full"] } tokio-native-tls = "0.3.1" +tokio-util = "0.7.10" wisp-mux = { path = "../wisp", features = ["fastwebsockets"]} diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index db49784..b530f64 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -10,6 +10,7 @@ use std::{error::Error, future::Future}; use tokio::net::TcpStream; use tokio_native_tls::{native_tls, TlsConnector}; use wisp_mux::{ClientMux, StreamType}; +use tokio_util::either::Either; #[derive(Debug)] struct StrError(String); @@ -59,10 +60,18 @@ async fn main() -> Result<(), Box> { .nth(4) .ok_or(StrError::new("no dest port"))? .parse()?; + let should_tls: bool = std::env::args() + .nth(5) + .ok_or(StrError::new("no should tls"))? + .parse()?; let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?; - let cx = TlsConnector::from(native_tls::TlsConnector::builder().build()?); - let socket = cx.connect(&addr, socket).await?; + let socket = if should_tls { + let cx = TlsConnector::from(native_tls::TlsConnector::builder().build()?); + Either::Left(cx.connect(&addr, socket).await?) + } else { + Either::Right(socket) + }; let req = Request::builder() .method("GET") .uri(format!("wss://{}:{}/", &addr, addr_port)) From 6ca14ad26a86513da6a301ccc6b48b44008734b3 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Mon, 5 Feb 2024 09:08:34 -0800 Subject: [PATCH 19/26] partially implement tower trait --- Cargo.lock | 26 +++++++ Cargo.toml | 2 +- client/build.sh | 2 +- wisp/Cargo.toml | 4 ++ wisp/src/lib.rs | 5 ++ wisp/src/tokioio.rs | 171 ++++++++++++++++++++++++++++++++++++++++++++ wisp/src/tower.rs | 13 ++++ 7 files changed, 221 insertions(+), 2 deletions(-) create mode 100644 wisp/src/tokioio.rs create mode 100644 wisp/src/tower.rs diff --git a/Cargo.lock b/Cargo.lock index 40758a6..eed8bfa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1154,12 +1154,36 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" + +[[package]] +name = "tower-service" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" + [[package]] name = "tracing" version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ + "log", "pin-project-lite", "tracing-core", ] @@ -1469,8 +1493,10 @@ dependencies = [ "fastwebsockets", "futures", "futures-util", + "hyper", "pin-project-lite", "tokio", + "tower", "ws_stream_wasm", ] diff --git a/Cargo.toml b/Cargo.toml index 2e8971b..2fcf5e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,5 +7,5 @@ rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" } [profile.release] lto = true -opt-level = 3 +opt-level = 'z' codegen-units = 1 diff --git a/client/build.sh b/client/build.sh index 7c2725b..7f402f0 100755 --- a/client/build.sh +++ b/client/build.sh @@ -11,7 +11,7 @@ wasm-bindgen --weak-refs --target no-modules --no-modules-global epoxy --out-dir echo "[ws] wasm-bindgen finished" mv out/epoxy_client_bg.wasm out/epoxy_client_unoptimized.wasm -time wasm-opt -O4 --vacuum --dce --enable-threads --enable-bulk-memory --enable-simd out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm +time wasm-opt -Oz --vacuum --dce --enable-threads --enable-bulk-memory --enable-simd out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm echo "[ws] wasm-opt finished" AUTOGENERATED_SOURCE=$(<"out/epoxy_client.js") diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index fc834d0..9448613 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -9,11 +9,15 @@ bytes = "1.5.0" fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true } futures = "0.3.30" futures-util = "0.3.30" +hyper = { version = "1.1.0", optional = true } pin-project-lite = "0.2.13" tokio = { version = "1.35.1", optional = true } +tower = { version = "0.4.13", optional = true } ws_stream_wasm = { version = "0.7.4", optional = true } [features] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] ws_stream_wasm = ["dep:ws_stream_wasm"] tokio_io = ["async_io_stream/tokio_io"] +hyper_tower = ["dep:tower", "dep:hyper", "dep:tokio"] + diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 37b81c5..e211f13 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(impl_trait_in_assoc_type)] #[cfg(feature = "fastwebsockets")] mod fastwebsockets; mod packet; @@ -5,6 +6,10 @@ mod stream; pub mod ws; #[cfg(feature = "ws_stream_wasm")] mod ws_stream_wasm; +#[cfg(feature = "hyper_tower")] +pub mod tokioio; +#[cfg(feature = "hyper_tower")] +pub mod tower; pub use crate::packet::*; pub use crate::stream::*; diff --git a/wisp/src/tokioio.rs b/wisp/src/tokioio.rs new file mode 100644 index 0000000..7d6acc0 --- /dev/null +++ b/wisp/src/tokioio.rs @@ -0,0 +1,171 @@ +#![allow(dead_code)] +// Taken from https://github.com/hyperium/hyper-util/blob/master/src/rt/tokio.rs +// hyper-util fails to compile on WASM as it has a dependency on socket2, but I only need +// hyper-util for TokioIo. + +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use pin_project_lite::pin_project; + +pin_project! { + /// A wrapping implementing hyper IO traits for a type that + /// implements Tokio's IO traits. + #[derive(Debug)] + pub struct TokioIo { + #[pin] + inner: T, + } +} + +impl TokioIo { + /// Wrap a type implementing Tokio's IO traits. + pub fn new(inner: T) -> Self { + Self { inner } + } + + /// Borrow the inner type. + pub fn inner(&self) -> &T { + &self.inner + } + + /// Mut borrow the inner type. + pub fn inner_mut(&mut self) -> &mut T { + &mut self.inner + } + + /// Consume this wrapper and get the inner type. + pub fn into_inner(self) -> T { + self.inner + } +} + +impl hyper::rt::Read for TokioIo +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl hyper::rt::Write for TokioIo +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) + } +} + +impl tokio::io::AsyncRead for TokioIo +where + T: hyper::rt::Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + tbuf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + //let init = tbuf.initialized().len(); + let filled = tbuf.filled().len(); + let sub_filled = unsafe { + let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); + + match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) { + Poll::Ready(Ok(())) => buf.filled().len(), + other => return other, + } + }; + + let n_filled = filled + sub_filled; + // At least sub_filled bytes had to have been initialized. + let n_init = sub_filled; + unsafe { + tbuf.assume_init(n_init); + tbuf.set_filled(n_filled); + } + + Poll::Ready(Ok(())) + } +} + +impl tokio::io::AsyncWrite for TokioIo +where + T: hyper::rt::Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + hyper::rt::Write::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + hyper::rt::Write::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + hyper::rt::Write::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + hyper::rt::Write::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) + } +} diff --git a/wisp/src/tower.rs b/wisp/src/tower.rs new file mode 100644 index 0000000..6bf635c --- /dev/null +++ b/wisp/src/tower.rs @@ -0,0 +1,13 @@ +use futures::{Future, task::{Poll, Context}}; + +impl tower::Service for crate::ClientMux { + type Response = crate::tokioio::TokioIo>; + type Error = crate::WispError; + type Future = impl Future>; + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: hyper::Uri) -> Self::Future { + + } +} From b16fb8f654aa764f41ab8cdceb7b847a7f062a28 Mon Sep 17 00:00:00 2001 From: r58Playz Date: Mon, 5 Feb 2024 19:10:40 -0800 Subject: [PATCH 20/26] use hyper client --- Cargo.lock | 72 +++++++++++++++-- client/Cargo.toml | 8 +- client/demo.js | 8 +- client/src/lib.rs | 174 +++++++++++++++++----------------------- client/src/tokioio.rs | 171 --------------------------------------- client/src/utils.rs | 52 ++++++------ client/src/websocket.rs | 20 +---- client/src/wrappers.rs | 70 +++++++++++++++- wisp/Cargo.toml | 7 +- wisp/src/lib.rs | 8 +- wisp/src/tokioio.rs | 9 ++- wisp/src/tower.rs | 40 +++++++-- 12 files changed, 297 insertions(+), 342 deletions(-) delete mode 100644 client/src/tokioio.rs diff --git a/Cargo.lock b/Cargo.lock index eed8bfa..0d6939f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -235,12 +235,14 @@ dependencies = [ "http", "http-body-util", "hyper", + "hyper-util 0.1.3 (git+https://github.com/r58Playz/hyper-util-wasm)", "js-sys", "pin-project-lite", "ring", "tokio", "tokio-rustls", "tokio-util", + "tower-service", "wasm-bindgen", "wasm-bindgen-futures", "wasm-streams", @@ -260,13 +262,19 @@ dependencies = [ "futures-util", "http-body-util", "hyper", - "hyper-util", + "hyper-util 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", "tokio", "tokio-native-tls", "tokio-util", "wisp-mux", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" version = "0.3.8" @@ -292,7 +300,7 @@ dependencies = [ "base64", "http-body-util", "hyper", - "hyper-util", + "hyper-util 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", "pin-project", "rand", "sha1", @@ -451,6 +459,25 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "h2" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "hashbrown" version = "0.14.3" @@ -518,6 +545,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", + "h2", "http", "http-body", "httparse", @@ -530,9 +558,24 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdea9aac0dbe5a9240d68cfd9501e2db94222c6dc06843e06640b9e07f0fdc67" +checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa" +dependencies = [ + "bytes", + "futures-util", + "http", + "http-body", + "hyper", + "pin-project-lite", + "socket2", + "tokio", +] + +[[package]] +name = "hyper-util" +version = "0.1.3" +source = "git+https://github.com/r58Playz/hyper-util-wasm#40813384dc4971677cd2a9aeb90f61b392a5bb70" dependencies = [ "bytes", "futures-channel", @@ -541,11 +584,21 @@ dependencies = [ "http-body", "hyper", "pin-project-lite", - "socket2", - "tokio", + "tower", + "tower-service", "tracing", ] +[[package]] +name = "indexmap" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" +dependencies = [ + "equivalent", + "hashbrown", +] + [[package]] name = "itoa" version = "1.0.10" @@ -1160,6 +1213,10 @@ version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", "tower-layer", "tower-service", "tracing", @@ -1494,9 +1551,10 @@ dependencies = [ "futures", "futures-util", "hyper", + "hyper-util 0.1.3 (git+https://github.com/r58Playz/hyper-util-wasm)", "pin-project-lite", "tokio", - "tower", + "tower-service", "ws_stream_wasm", ] diff --git a/client/Cargo.toml b/client/Cargo.toml index 3f8efd3..6e64d1d 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -10,9 +10,8 @@ crate-type = ["cdylib"] bytes = "1.5.0" http = "1.0.0" http-body-util = "0.1.0" -hyper = { version = "1.1.0", features = ["client", "http1"] } +hyper = { version = "1.1.0", features = ["client", "http1", "http2"] } pin-project-lite = "0.2.13" -tokio = { version = "1.35.1", default_features = false } wasm-bindgen = "0.2" wasm-bindgen-futures = "0.4.39" ws_stream_wasm = { version = "0.7.4", features = ["tokio_io"] } @@ -26,9 +25,12 @@ tokio-util = { version = "0.7.10", features = ["io"] } async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] } fastwebsockets = { version = "0.6.0", features = ["unstable-split"] } base64 = "0.21.7" -wisp-mux = { path = "../wisp", features = ["ws_stream_wasm", "tokio_io"] } +wisp-mux = { path = "../wisp", features = ["ws_stream_wasm", "tokio_io", "hyper_tower"] } async_io_stream = { version = "0.3.3", features = ["tokio_io"] } getrandom = { version = "0.2.12", features = ["js"] } +hyper-util = { git = "https://github.com/r58Playz/hyper-util-wasm", features = ["client", "client-legacy", "http1", "http2"] } +tokio = { version = "1.36.0", default-features = false } +tower-service = "0.3.2" [dependencies.ring] features = ["wasm32_unknown_unknown_js"] diff --git a/client/demo.js b/client/demo.js index 30379b8..6f07fc6 100644 --- a/client/demo.js +++ b/client/demo.js @@ -50,7 +50,13 @@ ["https://httpbin.org/gzip", {}], ["https://httpbin.org/brotli", {}], ["https://httpbin.org/redirect/11", {}], - ["https://httpbin.org/redirect/1", { redirect: "manual" }] + ["https://httpbin.org/redirect/1", { redirect: "manual" }], + ["https://nghttp2.org/httpbin/get", {}], + ["https://nghttp2.org/httpbin/gzip", {}], + ["https://nghttp2.org/httpbin/brotli", {}], + ["https://nghttp2.org/httpbin/redirect/11", {}], + ["https://nghttp2.org/httpbin/redirect/1", { redirect: "manual" }] + ]) { let resp = await epoxy_client.fetch(url[0], url[1]); console.warn(url, resp, Object.fromEntries(resp.headers)); diff --git a/client/src/lib.rs b/client/src/lib.rs index 639c786..07214de 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -1,16 +1,14 @@ -#![feature(let_chains)] +#![feature(let_chains, impl_trait_in_assoc_type)] #[macro_use] mod utils; mod tls_stream; -mod tokioio; mod websocket; mod wrappers; use tls_stream::EpxTlsStream; -use tokioio::TokioIo; use utils::{ReplaceErr, UriExt}; use websocket::EpxWebSocket; -use wrappers::IncomingBody; +use wrappers::{IncomingBody, TlsWispService}; use std::sync::Arc; @@ -19,7 +17,8 @@ use async_io_stream::IoStream; use bytes::Bytes; use futures_util::{stream::SplitSink, StreamExt}; use http::{uri, HeaderName, HeaderValue, Request, Response}; -use hyper::{body::Incoming, client::conn::http1::Builder, Uri}; +use hyper::{body::Incoming, Uri}; +use hyper_util::client::legacy::Client; use js_sys::{Array, Function, Object, Reflect, Uint8Array}; use tokio_rustls::{client::TlsStream, rustls, rustls::RootCertStore, TlsConnector}; use tokio_util::{ @@ -28,7 +27,7 @@ use tokio_util::{ }; use wasm_bindgen::prelude::*; use web_sys::TextEncoder; -use wisp_mux::{ClientMux, MuxStreamIo, StreamType}; +use wisp_mux::{tokioio::TokioIo, tower::ServiceWrapper, ClientMux, MuxStreamIo, StreamType}; use ws_stream_wasm::{WsMessage, WsMeta, WsStream}; type HttpBody = http_body_util::Full; @@ -36,7 +35,7 @@ type HttpBody = http_body_util::Full; #[derive(Debug)] enum EpxResponse { Success(Response), - Redirect((Response, http::Request, Uri)), + Redirect((Response, http::Request)), } enum EpxCompression { @@ -48,73 +47,11 @@ type EpxIoTlsStream = TlsStream>>; type EpxIoUnencryptedStream = IoStream>; type EpxIoStream = Either; -async fn send_req( - req: http::Request, - should_redirect: bool, - io: EpxIoStream, -) -> Result { - let (mut req_sender, conn) = Builder::new() - .title_case_headers(true) - .preserve_header_case(true) - .handshake(TokioIo::new(io)) - .await - .replace_err("Failed to connect to host")?; - - wasm_bindgen_futures::spawn_local(async move { - if let Err(e) = conn.await { - error!("epoxy: error in muxed hyper connection! {:?}", e); - } - }); - - let new_req = if should_redirect { - Some(req.clone()) - } else { - None - }; - - debug!("sending req"); - let res = req_sender - .send_request(req) - .await - .replace_err("Failed to send request"); - debug!("recieved res"); - match res { - Ok(res) => { - if utils::is_redirect(res.status().as_u16()) - && let Some(mut new_req) = new_req - && let Some(location) = res.headers().get("Location") - && let Ok(redirect_url) = new_req.uri().get_redirect(location) - && let Some(redirect_url_authority) = redirect_url - .clone() - .authority() - .replace_err("Redirect URL must have an authority") - .ok() - { - let should_strip = new_req.uri().is_same_host(&redirect_url); - if should_strip { - new_req.headers_mut().remove("authorization"); - new_req.headers_mut().remove("cookie"); - new_req.headers_mut().remove("www-authenticate"); - } - let new_url = redirect_url.clone(); - *new_req.uri_mut() = redirect_url; - new_req.headers_mut().insert( - "Host", - HeaderValue::from_str(redirect_url_authority.as_str())?, - ); - Ok(EpxResponse::Redirect((res, new_req, new_url))) - } else { - Ok(EpxResponse::Success(res)) - } - } - Err(err) => Err(err), - } -} - #[wasm_bindgen] pub struct EpoxyClient { rustls_config: Arc, - mux: ClientMux>, + mux: Arc>>, + hyper_client: Client>, HttpBody>, useragent: String, redirect_limit: usize, } @@ -145,6 +82,7 @@ impl EpoxyClient { debug!("connected!"); let (wtx, wrx) = ws.split(); let (mux, fut) = ClientMux::new(wrx, wtx); + let mux = Arc::new(mux); wasm_bindgen_futures::spawn_local(async move { if let Err(err) = fut.await { @@ -162,7 +100,15 @@ impl EpoxyClient { ); Ok(EpoxyClient { - mux, + mux: mux.clone(), + hyper_client: Client::builder(utils::WasmExecutor {}) + .http09_responses(true) + .http1_title_case_headers(true) + .http1_preserve_header_case(true) + .build(TlsWispService { + rustls_config: rustls_config.clone(), + service: ServiceWrapper(mux), + }), rustls_config, useragent, redirect_limit, @@ -193,26 +139,53 @@ impl EpoxyClient { Ok(io) } - async fn get_http_io(&self, url: &Uri) -> Result { - let url_host = url.host().replace_err("URL must have a host")?; - let url_port = utils::get_url_port(url)?; - - if utils::get_is_secure(url)? { - Ok(EpxIoStream::Left( - self.get_tls_io(url_host, url_port).await?, - )) + async fn send_req_inner( + &self, + req: http::Request, + should_redirect: bool, + ) -> Result { + let new_req = if should_redirect { + Some(req.clone()) } else { - debug!("making channel"); - let channel = self - .mux - .client_new_stream(StreamType::Tcp, url_host.to_string(), url_port) - .await - .replace_err("Failed to create multiplexor channel")? - .into_io() - .into_asyncrw(); - debug!("connecting channel"); - debug!("connected channel"); - Ok(EpxIoStream::Right(channel)) + None + }; + + debug!("sending req"); + let res = self + .hyper_client + .request(req) + .await + .replace_err("Failed to send request"); + debug!("recieved res"); + match res { + Ok(res) => { + if utils::is_redirect(res.status().as_u16()) + && let Some(mut new_req) = new_req + && let Some(location) = res.headers().get("Location") + && let Ok(redirect_url) = new_req.uri().get_redirect(location) + && let Some(redirect_url_authority) = redirect_url + .clone() + .authority() + .replace_err("Redirect URL must have an authority") + .ok() + { + let should_strip = new_req.uri().is_same_host(&redirect_url); + if should_strip { + new_req.headers_mut().remove("authorization"); + new_req.headers_mut().remove("cookie"); + new_req.headers_mut().remove("www-authenticate"); + } + *new_req.uri_mut() = redirect_url; + new_req.headers_mut().insert( + "Host", + HeaderValue::from_str(redirect_url_authority.as_str())?, + ); + Ok(EpxResponse::Redirect((res, new_req))) + } else { + Ok(EpxResponse::Success(res)) + } + } + Err(err) => Err(err), } } @@ -222,23 +195,22 @@ impl EpoxyClient { should_redirect: bool, ) -> Result<(hyper::Response, Uri, bool), JsError> { let mut redirected = false; - let uri = req.uri().clone(); - let mut current_resp: EpxResponse = - send_req(req, should_redirect, self.get_http_io(&uri).await?).await?; + let mut current_url = req.uri().clone(); + let mut current_resp: EpxResponse = self.send_req_inner(req, should_redirect).await?; for _ in 0..self.redirect_limit - 1 { match current_resp { EpxResponse::Success(_) => break, - EpxResponse::Redirect((_, req, new_url)) => { + EpxResponse::Redirect((_, req)) => { redirected = true; - current_resp = - send_req(req, should_redirect, self.get_http_io(&new_url).await?).await? + current_url = req.uri().clone(); + current_resp = self.send_req_inner(req, should_redirect).await? } } } match current_resp { - EpxResponse::Success(resp) => Ok((resp, uri, redirected)), - EpxResponse::Redirect((resp, _, new_url)) => Ok((resp, new_url, redirected)), + EpxResponse::Success(resp) => Ok((resp, current_url, redirected)), + EpxResponse::Redirect((resp, _)) => Ok((resp, current_url, redirected)), } } @@ -353,7 +325,7 @@ impl EpoxyClient { .body(HttpBody::new(body_bytes)) .replace_err("Failed to make request")?; - let (resp, last_url, req_redirected) = self.send_req(request, req_should_redirect).await?; + let (resp, resp_uri, req_redirected) = self.send_req(request, req_should_redirect).await?; let resp_headers_raw = resp.headers().clone(); @@ -417,7 +389,7 @@ impl EpoxyClient { Object::define_property( &resp, &jval!("url"), - &utils::define_property_obj(jval!(last_url.to_string()), false) + &utils::define_property_obj(jval!(resp_uri.to_string()), false) .replace_err("Failed to make define_property object for url")?, ); diff --git a/client/src/tokioio.rs b/client/src/tokioio.rs deleted file mode 100644 index 7d6acc0..0000000 --- a/client/src/tokioio.rs +++ /dev/null @@ -1,171 +0,0 @@ -#![allow(dead_code)] -// Taken from https://github.com/hyperium/hyper-util/blob/master/src/rt/tokio.rs -// hyper-util fails to compile on WASM as it has a dependency on socket2, but I only need -// hyper-util for TokioIo. - -use std::{ - pin::Pin, - task::{Context, Poll}, -}; - -use pin_project_lite::pin_project; - -pin_project! { - /// A wrapping implementing hyper IO traits for a type that - /// implements Tokio's IO traits. - #[derive(Debug)] - pub struct TokioIo { - #[pin] - inner: T, - } -} - -impl TokioIo { - /// Wrap a type implementing Tokio's IO traits. - pub fn new(inner: T) -> Self { - Self { inner } - } - - /// Borrow the inner type. - pub fn inner(&self) -> &T { - &self.inner - } - - /// Mut borrow the inner type. - pub fn inner_mut(&mut self) -> &mut T { - &mut self.inner - } - - /// Consume this wrapper and get the inner type. - pub fn into_inner(self) -> T { - self.inner - } -} - -impl hyper::rt::Read for TokioIo -where - T: tokio::io::AsyncRead, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut buf: hyper::rt::ReadBufCursor<'_>, - ) -> Poll> { - let n = unsafe { - let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); - match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { - Poll::Ready(Ok(())) => tbuf.filled().len(), - other => return other, - } - }; - - unsafe { - buf.advance(n); - } - Poll::Ready(Ok(())) - } -} - -impl hyper::rt::Write for TokioIo -where - T: tokio::io::AsyncWrite, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) - } - - fn is_write_vectored(&self) -> bool { - tokio::io::AsyncWrite::is_write_vectored(&self.inner) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) - } -} - -impl tokio::io::AsyncRead for TokioIo -where - T: hyper::rt::Read, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - tbuf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - //let init = tbuf.initialized().len(); - let filled = tbuf.filled().len(); - let sub_filled = unsafe { - let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); - - match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) { - Poll::Ready(Ok(())) => buf.filled().len(), - other => return other, - } - }; - - let n_filled = filled + sub_filled; - // At least sub_filled bytes had to have been initialized. - let n_init = sub_filled; - unsafe { - tbuf.assume_init(n_init); - tbuf.set_filled(n_filled); - } - - Poll::Ready(Ok(())) - } -} - -impl tokio::io::AsyncWrite for TokioIo -where - T: hyper::rt::Write, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - hyper::rt::Write::poll_write(self.project().inner, cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - hyper::rt::Write::poll_flush(self.project().inner, cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - hyper::rt::Write::poll_shutdown(self.project().inner, cx) - } - - fn is_write_vectored(&self) -> bool { - hyper::rt::Write::is_write_vectored(&self.inner) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) - } -} diff --git a/client/src/utils.rs b/client/src/utils.rs index caf4498..0c71583 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -1,8 +1,9 @@ use wasm_bindgen::prelude::*; +use hyper::rt::Executor; use hyper::{header::HeaderValue, Uri}; -use http::uri; use js_sys::{Array, Object}; +use std::future::Future; #[wasm_bindgen] extern "C" { @@ -97,6 +98,21 @@ impl UriExt for Uri { } } +#[derive(Clone)] +pub struct WasmExecutor; + +impl Executor for WasmExecutor +where + F: Future + Send + 'static, + F::Output: Send + 'static, +{ + fn execute(&self, future: F) { + wasm_bindgen_futures::spawn_local(async move { + let _ = future.await; + }); + } +} + pub fn entries_of_object(obj: &Object) -> Vec> { js_sys::Object::entries(obj) .to_vec() @@ -126,41 +142,19 @@ pub fn is_redirect(code: u16) -> bool { } pub fn get_is_secure(url: &Uri) -> Result { - let url_scheme = url.scheme().replace_err("URL must have a scheme")?; let url_scheme_str = url.scheme_str().replace_err("URL must have a scheme")?; - // can't use match, compiler error - // error: to use a constant of type `Scheme` in a pattern, `Scheme` must be annotated with `#[derive(PartialEq, Eq)]` - if *url_scheme == uri::Scheme::HTTP { - Ok(false) - } else if *url_scheme == uri::Scheme::HTTPS { - Ok(true) - } else if url_scheme_str == "ws" { - Ok(false) - } else if url_scheme_str == "wss" { - Ok(true) - } else { - return Ok(false); + match url_scheme_str { + "https" | "wss" => Ok(true), + _ => Ok(false), } } pub fn get_url_port(url: &Uri) -> Result { - let url_scheme = url.scheme().replace_err("URL must have a scheme")?; - let url_scheme_str = url.scheme_str().replace_err("URL must have a scheme")?; if let Some(port) = url.port() { Ok(port.as_u16()) + } else if get_is_secure(url)? { + Ok(443) } else { - // can't use match, compiler error - // error: to use a constant of type `Scheme` in a pattern, `Scheme` must be annotated with `#[derive(PartialEq, Eq)]` - if *url_scheme == uri::Scheme::HTTP { - Ok(80) - } else if *url_scheme == uri::Scheme::HTTPS { - Ok(443) - } else if url_scheme_str == "ws" { - Ok(80) - } else if url_scheme_str == "wss" { - Ok(443) - } else { - return Err(jerr!("Failed to coerce port from scheme")); - } + Ok(80) } } diff --git a/client/src/websocket.rs b/client/src/websocket.rs index f823077..d186b17 100644 --- a/client/src/websocket.rs +++ b/client/src/websocket.rs @@ -5,7 +5,7 @@ use fastwebsockets::{ CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, Role, WebSocket, WebSocketWrite, }; use futures_util::lock::Mutex; -use http_body_util::Empty; +use http_body_util::Full; use hyper::{ header::{CONNECTION, UPGRADE}, upgrade::Upgraded, @@ -63,23 +63,9 @@ impl EpxWebSocket { builder = builder.header("Sec-WebSocket-Protocol", protocols.join(", ")); } - let req = builder.body(Empty::::new())?; + let req = builder.body(Full::::new(Bytes::new()))?; - let stream = tcp.get_http_io(&url).await?; - - let (mut sender, conn) = Builder::new() - .title_case_headers(true) - .preserve_header_case(true) - .handshake::, Empty>(TokioIo::new(stream)) - .await?; - - wasm_bindgen_futures::spawn_local(async move { - if let Err(e) = conn.with_upgrades().await { - error!("epoxy: error in muxed hyper connection (ws)! {:?}", e); - } - }); - - let mut response = sender.send_request(req).await?; + let mut response = tcp.hyper_client.request(req).await?; verify(&response)?; let ws = WebSocket::after_handshake( diff --git a/client/src/wrappers.rs b/client/src/wrappers.rs index 8526a98..5df0814 100644 --- a/client/src/wrappers.rs +++ b/client/src/wrappers.rs @@ -7,6 +7,8 @@ use std::{ use futures_util::Stream; use hyper::body::Body; use pin_project_lite::pin_project; +use std::future::Future; +use wisp_mux::{tokioio::TokioIo, tower::ServiceWrapper, WispError}; pin_project! { pub struct IncomingBody { @@ -30,7 +32,8 @@ impl Stream for IncomingBody { Poll::Ready(item) => Poll::>::Ready(match item { Some(frame) => frame .map(|x| { - x.into_data().map_err(|_| std::io::Error::other("not data frame")) + x.into_data() + .map_err(|_| std::io::Error::other("not data frame")) }) .ok(), None => None, @@ -39,3 +42,68 @@ impl Stream for IncomingBody { } } } + +pub struct TlsWispService +where + W: wisp_mux::ws::WebSocketWrite + Send + 'static, +{ + pub service: ServiceWrapper, + pub rustls_config: Arc, +} + + +impl tower_service::Service + for TlsWispService +{ + type Response = TokioIo; + type Error = WispError; + type Future = Pin>>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: http::Uri) -> Self::Future { + let mut service = self.service.clone(); + let rustls_config = self.rustls_config.clone(); + Box::pin(async move { + let uri_host = req + .host() + .ok_or(WispError::UriHasNoHost)? + .to_string() + .clone(); + let uri_parsed = Uri::builder() + .authority(format!( + "{}:{}", + uri_host, + utils::get_url_port(&req).map_err(|_| WispError::UriHasNoPort)? + )) + .build() + .map_err(|x| WispError::Other(Box::new(x)))?; + let stream = service.call(uri_parsed).await?.into_inner(); + if utils::get_is_secure(&req).map_err(|_| WispError::InvalidUri)? { + let connector = TlsConnector::from(rustls_config); + Ok(TokioIo::new(Either::Left( + connector + .connect( + uri_host.try_into().map_err(|_| WispError::InvalidUri)?, + stream, + ) + .await + .map_err(|x| WispError::Other(Box::new(x)))?, + ))) + } else { + Ok(TokioIo::new(Either::Right(stream))) + } + }) + } +} + +impl Clone for TlsWispService { + fn clone(&self) -> Self { + Self { + rustls_config: self.rustls_config.clone(), + service: self.service.clone(), + } + } +} diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 9448613..42f65e2 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -10,14 +10,15 @@ fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = futures = "0.3.30" futures-util = "0.3.30" hyper = { version = "1.1.0", optional = true } +hyper-util = { git = "https://github.com/r58Playz/hyper-util-wasm", features = ["client", "client-legacy"], optional = true } pin-project-lite = "0.2.13" -tokio = { version = "1.35.1", optional = true } -tower = { version = "0.4.13", optional = true } +tokio = { version = "1.35.1", optional = true, default-features = false } +tower-service = { version = "0.3.2", optional = true } ws_stream_wasm = { version = "0.7.4", optional = true } [features] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] ws_stream_wasm = ["dep:ws_stream_wasm"] tokio_io = ["async_io_stream/tokio_io"] -hyper_tower = ["dep:tower", "dep:hyper", "dep:tokio"] +hyper_tower = ["dep:tower-service", "dep:hyper", "dep:tokio", "dep:hyper-util"] diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index e211f13..9ee6785 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -35,6 +35,9 @@ pub enum WispError { InvalidPacketType, InvalidStreamType, InvalidStreamId, + InvalidUri, + UriHasNoHost, + UriHasNoPort, MaxStreamCountReached, StreamAlreadyClosed, WsFrameInvalidType, @@ -60,6 +63,9 @@ impl std::fmt::Display for WispError { InvalidPacketType => write!(f, "Invalid packet type"), InvalidStreamType => write!(f, "Invalid stream type"), InvalidStreamId => write!(f, "Invalid stream id"), + InvalidUri => write!(f, "Invalid URI"), + UriHasNoHost => write!(f, "URI has no host"), + UriHasNoPort => write!(f, "URI has no port"), MaxStreamCountReached => write!(f, "Maximum stream count reached"), StreamAlreadyClosed => write!(f, "Stream already closed"), WsFrameInvalidType => write!(f, "Invalid websocket frame type"), @@ -329,7 +335,7 @@ impl ClientMux { stream_type: StreamType, host: String, port: u16, - ) -> Result, WispError> { + ) -> Result, WispError> { let (ch_tx, ch_rx) = mpsc::unbounded(); let stream_id = self.next_free_stream_id.load(Ordering::Acquire); self.tx diff --git a/wisp/src/tokioio.rs b/wisp/src/tokioio.rs index 7d6acc0..a3ca7be 100644 --- a/wisp/src/tokioio.rs +++ b/wisp/src/tokioio.rs @@ -1,7 +1,6 @@ #![allow(dead_code)] // Taken from https://github.com/hyperium/hyper-util/blob/master/src/rt/tokio.rs -// hyper-util fails to compile on WASM as it has a dependency on socket2, but I only need -// hyper-util for TokioIo. +// hyper-util fails to compile on WASM as it has a dependency on socket2 use std::{ pin::Pin, @@ -169,3 +168,9 @@ where hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) } } + +impl hyper_util::client::legacy::connect::Connection for TokioIo { + fn connected(&self) -> hyper_util::client::legacy::connect::Connected { + hyper_util::client::legacy::connect::Connected::new() + } +} diff --git a/wisp/src/tower.rs b/wisp/src/tower.rs index 6bf635c..06f3ebc 100644 --- a/wisp/src/tower.rs +++ b/wisp/src/tower.rs @@ -1,13 +1,41 @@ -use futures::{Future, task::{Poll, Context}}; +use crate::{tokioio::TokioIo, ws::WebSocketWrite, ClientMux, MuxStreamIo, StreamType, WispError}; +use async_io_stream::IoStream; +use futures::{ + task::{Context, Poll}, + Future, +}; +use std::sync::Arc; -impl tower::Service for crate::ClientMux { - type Response = crate::tokioio::TokioIo>; - type Error = crate::WispError; +pub struct ServiceWrapper(pub Arc>); + +impl tower_service::Service for ServiceWrapper { + type Response = TokioIo>>; + type Error = WispError; type Future = impl Future>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, req: hyper::Uri) -> Self::Future { + fn call(&mut self, req: hyper::Uri) -> Self::Future { + let mux = self.0.clone(); + async move { + Ok(TokioIo::new( + mux.client_new_stream( + StreamType::Tcp, + req.host().ok_or(WispError::UriHasNoHost)?.to_string(), + req.port().ok_or(WispError::UriHasNoPort)?.into(), + ) + .await? + .into_io() + .into_asyncrw(), + )) + } + } +} + +impl Clone for ServiceWrapper { + fn clone(&self) -> Self { + Self(self.0.clone()) } } From 28869ef6ee4f861f8215db63261269d8efa7c6df Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Tue, 6 Feb 2024 00:27:58 -0800 Subject: [PATCH 21/26] faster than native? --- Cargo.lock | 144 +++++++++++++++++++++++++++------------------- Cargo.toml | 2 +- client/Cargo.toml | 1 + client/build.sh | 2 +- client/demo.js | 22 ++----- client/index.html | 20 ++++++- client/src/lib.rs | 7 ++- 7 files changed, 118 insertions(+), 80 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0d6939f..6e6dfe3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -34,9 +34,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc2d0cfb2a7388d34f590e76686704c494ed7aaceed62ee1ba35cbf363abc2a5" +checksum = "a116f46a969224200a0a97f29cfd4c50e7534e4b4826bd23ea2c3c533039c82c" dependencies = [ "brotli", "flate2", @@ -93,9 +93,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.1" +version = "2.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" +checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" [[package]] name = "block-buffer" @@ -154,6 +154,16 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "console_error_panic_hook" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06aeb73f470f66dcdbf7223caeebb85984942f22f1adb2a088cf9668146bbbc" +dependencies = [ + "cfg-if", + "wasm-bindgen", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -229,6 +239,7 @@ dependencies = [ "async_io_stream", "base64", "bytes", + "console_error_panic_hook", "fastwebsockets", "futures-util", "getrandom", @@ -486,9 +497,9 @@ checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" [[package]] name = "hermit-abi" -version = "0.3.3" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +checksum = "d0c62115964e08cb8039170eb33c1d0e2388a256930279edca206fff675f82c3" [[package]] name = "http" @@ -575,7 +586,7 @@ dependencies = [ [[package]] name = "hyper-util" version = "0.1.3" -source = "git+https://github.com/r58Playz/hyper-util-wasm#40813384dc4971677cd2a9aeb90f61b392a5bb70" +source = "git+https://github.com/r58Playz/hyper-util-wasm#ea9a5608f3255562d4a647a5c94ff7d3f9c32b53" dependencies = [ "bytes", "futures-channel", @@ -587,13 +598,15 @@ dependencies = [ "tower", "tower-service", "tracing", + "wasm-bindgen", + "wasmtimer", ] [[package]] name = "indexmap" -version = "2.1.0" +version = "2.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" +checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" dependencies = [ "equivalent", "hashbrown", @@ -607,9 +620,9 @@ checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "js-sys" -version = "0.3.66" +version = "0.3.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca" +checksum = "9a1d36f1235bc969acba30b7f5990b864423a6068a10f7c90ae8f0112e3a59d1" dependencies = [ "wasm-bindgen", ] @@ -622,15 +635,15 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.152" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "linux-raw-sys" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" +checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "lock_api" @@ -656,9 +669,9 @@ checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" [[package]] name = "miniz_oxide" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" dependencies = [ "adler", ] @@ -719,11 +732,11 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "openssl" -version = "0.10.62" +version = "0.10.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cde4d2d9200ad5909f8dac647e29482e07c3a35de8a13fce7c9c7747ad9f671" +checksum = "15c9d69dd87a29568d4d017cfe8ec518706046a05184e5aea92d0af890b803c8" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "cfg-if", "foreign-types", "libc", @@ -751,9 +764,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.98" +version = "0.9.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1665caf8ab2dc9aef43d1c0023bd904633a6a05cb30b0ad59bec2ae986e57a7" +checksum = "22e1bf214306098e4832460f797824c05d25aacdf896f64a985fb0fd992454ae" dependencies = [ "cc", "libc", @@ -796,18 +809,18 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422" +checksum = "0302c4a0442c456bd56f841aee5c3bfd17967563f6fadc9ceb9f9c23cf3807e0" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" +checksum = "266c042b60c9c76b8d53061e52b2e0d1116abc57cefc8c5cd671619a56ac3690" dependencies = [ "proc-macro2", "quote", @@ -828,9 +841,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" +checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" [[package]] name = "ppv-lite86" @@ -840,9 +853,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.76" +version = "1.0.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95fc56cda0b5c3325f5fbbd7ff9fda9e02bb00bb3dac51252d2f1bfa1cb8cc8c" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" dependencies = [ "unicode-ident", ] @@ -926,11 +939,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.28" +version = "0.38.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" +checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "errno", "libc", "linux-raw-sys", @@ -953,17 +966,17 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.1.0" -source = "git+https://github.com/r58Playz/rustls-pki-types#685721bb4b819c7da4724f07cffe06173f8cc883" +version = "1.2.0" +source = "git+https://github.com/r58Playz/rustls-pki-types#7bc22404e91ac909ef0e6ac11e6e316aefacde75" dependencies = [ "wasm-bindgen", ] [[package]] name = "rustls-webpki" -version = "0.102.1" +version = "0.102.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef4ca26037c909dedb327b48c3327d0ba91d3dd3c4e05dad328f210ffb68e95b" +checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610" dependencies = [ "ring", "rustls-pki-types", @@ -1072,9 +1085,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.2" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" [[package]] name = "socket2" @@ -1111,13 +1124,12 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.9.0" +version = "3.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01ce4141aa927a6d1bd34a041795abd0db1cccba5d5f24b009f694bdf3a1f3fa" +checksum = "a365e8cd18e44762ef95d87f284f4b5cd04107fec2ff3052bd6a3e6069669e67" dependencies = [ "cfg-if", "fastrand", - "redox_syscall", "rustix", "windows-sys 0.52.0", ] @@ -1313,9 +1325,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" +checksum = "b1223296a201415c7fad14792dbefaace9bd52b62d33453ade1c5b5f07555406" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -1323,9 +1335,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" +checksum = "fcdc935b63408d58a32f8cc9738a0bffd8f05cc7c002086c6ef20b7312ad9dcd" dependencies = [ "bumpalo", "log", @@ -1338,9 +1350,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.39" +version = "0.4.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac36a15a220124ac510204aec1c3e5db8a22ab06fd6706d881dc6149f8ed9a12" +checksum = "bde2032aeb86bdfaecc8b261eef3cba735cc426c1f3a3416d1e0791be95fc461" dependencies = [ "cfg-if", "js-sys", @@ -1350,9 +1362,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" +checksum = "3e4c238561b2d428924c49815533a8b9121c664599558a5d9ec51f8a1740a999" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -1360,9 +1372,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" +checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7" dependencies = [ "proc-macro2", "quote", @@ -1373,9 +1385,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" +checksum = "4d91413b1c31d7539ba5ef2451af3f0b833a005eb27a631cec32bc0635a8602b" [[package]] name = "wasm-streams" @@ -1391,10 +1403,24 @@ dependencies = [ ] [[package]] -name = "web-sys" -version = "0.3.66" +name = "wasmtimer" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f" +checksum = "5f656cd8858a5164932d8a90f936700860976ec21eb00e0fe2aa8cab13f6b4cf" +dependencies = [ + "futures", + "js-sys", + "parking_lot", + "pin-utils", + "slab", + "wasm-bindgen", +] + +[[package]] +name = "web-sys" +version = "0.3.67" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58cd2333b6e0be7a39605f0e255892fd7418a682d8da8fe042fe25128794d2ed" dependencies = [ "js-sys", "wasm-bindgen", @@ -1402,9 +1428,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.0" +version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0de2cfda980f21be5a7ed2eadb3e6fe074d56022bea2cdeb1a62eb220fc04188" +checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009" dependencies = [ "rustls-pki-types", ] diff --git a/Cargo.toml b/Cargo.toml index 2fcf5e1..2e8971b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,5 +7,5 @@ rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" } [profile.release] lto = true -opt-level = 'z' +opt-level = 3 codegen-units = 1 diff --git a/client/Cargo.toml b/client/Cargo.toml index 6e64d1d..cd05441 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -31,6 +31,7 @@ getrandom = { version = "0.2.12", features = ["js"] } hyper-util = { git = "https://github.com/r58Playz/hyper-util-wasm", features = ["client", "client-legacy", "http1", "http2"] } tokio = { version = "1.36.0", default-features = false } tower-service = "0.3.2" +console_error_panic_hook = "0.1.7" [dependencies.ring] features = ["wasm32_unknown_unknown_js"] diff --git a/client/build.sh b/client/build.sh index 7f402f0..7c2725b 100755 --- a/client/build.sh +++ b/client/build.sh @@ -11,7 +11,7 @@ wasm-bindgen --weak-refs --target no-modules --no-modules-global epoxy --out-dir echo "[ws] wasm-bindgen finished" mv out/epoxy_client_bg.wasm out/epoxy_client_unoptimized.wasm -time wasm-opt -Oz --vacuum --dce --enable-threads --enable-bulk-memory --enable-simd out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm +time wasm-opt -O4 --vacuum --dce --enable-threads --enable-bulk-memory --enable-simd out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm echo "[ws] wasm-opt finished" AUTOGENERATED_SOURCE=$(<"out/epoxy_client.js") diff --git a/client/demo.js b/client/demo.js index 6f07fc6..7498e91 100644 --- a/client/demo.js +++ b/client/demo.js @@ -1,25 +1,15 @@ -(async () => { +importScripts("epoxy-bundled.js"); +onmessage = async (msg) => { + console.debug("recieved:", msg); + let [should_feature_test, should_multiparallel_test, should_parallel_test, should_multiperf_test, should_perf_test, should_ws_test, should_tls_test] = msg.data; console.log( "%cWASM is significantly slower with DevTools open!", "color:red;font-size:3rem;font-weight:bold" ); - const params = (new URL(window.location.href)).searchParams; - - const should_feature_test = params.has("feature_test"); - const should_multiparallel_test = params.has("multi_parallel_test"); - const should_parallel_test = params.has("parallel_test"); - const should_multiperf_test = params.has("multi_perf_test"); - const should_perf_test = params.has("perf_test"); - const should_ws_test = params.has("ws_test"); - const should_tls_test = params.has("rawtls_test"); - const log = (str) => { - let el = document.createElement("div"); - el.innerText = str; - document.getElementById("logs").appendChild(el); console.warn(str); - window.scrollTo(0, document.body.scrollHeight); + postMessage(str); } let { EpoxyClient } = await epoxy(); @@ -183,4 +173,4 @@ console.warn(await resp.text()); } log("done"); -})(); +}; diff --git a/client/index.html b/client/index.html index cf37260..718f470 100644 --- a/client/index.html +++ b/client/index.html @@ -2,12 +2,28 @@ epoxy - - + diff --git a/client/src/lib.rs b/client/src/lib.rs index 07214de..2a40a51 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -47,6 +47,11 @@ type EpxIoTlsStream = TlsStream>>; type EpxIoUnencryptedStream = IoStream>; type EpxIoStream = Either; +#[wasm_bindgen(start)] +fn init() { + std::panic::set_hook(Box::new(console_error_panic_hook::hook)); +} + #[wasm_bindgen] pub struct EpoxyClient { rustls_config: Arc, @@ -303,7 +308,7 @@ impl EpoxyClient { let headers_map = builder.headers_mut().replace_err("Failed to get headers")?; headers_map.insert("Accept-Encoding", HeaderValue::from_str("gzip, br")?); - headers_map.insert("Connection", HeaderValue::from_str("close")?); + headers_map.insert("Connection", HeaderValue::from_str("keep-alive")?); headers_map.insert("User-Agent", HeaderValue::from_str(&self.useragent)?); headers_map.insert("Host", HeaderValue::from_str(uri_host)?); if body_bytes.is_empty() { From 1a897ec03a0c3e44bd487705cd6e09667d384190 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Tue, 6 Feb 2024 01:05:33 -0800 Subject: [PATCH 22/26] remove unnecessary tests, add helper script to pacify firefox --- client/demo.js | 7 ------- client/serve.py | 10 ++++++++++ 2 files changed, 10 insertions(+), 7 deletions(-) create mode 100644 client/serve.py diff --git a/client/demo.js b/client/demo.js index 7498e91..b390c98 100644 --- a/client/demo.js +++ b/client/demo.js @@ -41,12 +41,6 @@ onmessage = async (msg) => { ["https://httpbin.org/brotli", {}], ["https://httpbin.org/redirect/11", {}], ["https://httpbin.org/redirect/1", { redirect: "manual" }], - ["https://nghttp2.org/httpbin/get", {}], - ["https://nghttp2.org/httpbin/gzip", {}], - ["https://nghttp2.org/httpbin/brotli", {}], - ["https://nghttp2.org/httpbin/redirect/11", {}], - ["https://nghttp2.org/httpbin/redirect/1", { redirect: "manual" }] - ]) { let resp = await epoxy_client.fetch(url[0], url[1]); console.warn(url, resp, Object.fromEntries(resp.headers)); @@ -122,7 +116,6 @@ onmessage = async (msg) => { } total_mux_minus_native = total_mux_minus_native / num_tests; log(`total mux - native (${num_tests} tests of ${num_tests} reqs): ${total_mux_minus_native} ms or ${total_mux_minus_native / 1000} s`); - } else if (should_perf_test) { const num_tests = 10; diff --git a/client/serve.py b/client/serve.py new file mode 100644 index 0000000..e32b7a0 --- /dev/null +++ b/client/serve.py @@ -0,0 +1,10 @@ +from http.server import HTTPServer, SimpleHTTPRequestHandler, test +import sys + +class RequestHandler (SimpleHTTPRequestHandler): + def end_headers (self): + self.send_header('Cross-Origin-Opener-Policy', 'same-origin') + self.send_header('Cross-Origin-Embedder-Policy', 'require-corp') + SimpleHTTPRequestHandler.end_headers(self) + +test(RequestHandler, HTTPServer, port=int(sys.argv[1]) if len(sys.argv) > 1 else 8000) From 85a30aeec570177c453b483d7d540e092477efa3 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Wed, 7 Feb 2024 08:38:37 -0800 Subject: [PATCH 23/26] more improvements and fix wisp impl --- .gitignore | 2 +- Cargo.lock | 319 ++++++++++++++++++++++++++++++++++++++++++-- Cargo.toml | 2 +- client/build.sh | 2 +- client/src/lib.rs | 10 +- server/Cargo.toml | 2 + server/src/main.rs | 71 ++++++---- wisp/Cargo.lock | 320 --------------------------------------------- wisp/Cargo.toml | 1 + wisp/src/lib.rs | 101 +++++++++----- wisp/src/packet.rs | 4 +- wisp/src/stream.rs | 57 +++++++- 12 files changed, 478 insertions(+), 413 deletions(-) delete mode 100644 wisp/Cargo.lock diff --git a/.gitignore b/.gitignore index dcb2577..1cd2570 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ /target -server/src/*.pem +**/*.pem client/pkg client/out .direnv diff --git a/Cargo.lock b/Cargo.lock index 6e6dfe3..cc4a287 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -32,6 +32,54 @@ dependencies = [ "alloc-no-stdlib", ] +[[package]] +name = "anstream" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e2e1ebcb11de5c03c67de28a7df593d32191b44939c482e97702baaaa6ab6a5" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" + +[[package]] +name = "anstyle-parse" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c75ac65da39e5fe5ab759307499ddad880d724eed2f6ce5b5e8a26f4f387928c" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e28923312444cdd728e4738b3f9c9cac739500909bb3d3c94b43551b16517648" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7" +dependencies = [ + "anstyle", + "windows-sys 0.52.0", +] + [[package]] name = "async-compression" version = "0.4.6" @@ -154,6 +202,77 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "clap" +version = "4.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", + "terminal_size", +] + +[[package]] +name = "clap_derive" +version = "4.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" + +[[package]] +name = "clio" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7fc6734af48458f72f5a3fa7b840903606427d98a710256e808f76a965047d9" +dependencies = [ + "cfg-if", + "clap", + "is-terminal", + "libc", + "tempfile", + "walkdir", + "windows-sys 0.42.0", +] + +[[package]] +name = "colorchoice" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" + +[[package]] +name = "concurrent-queue" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d16048cd947b08fa32c24458a22f5dc5e835264f689f4f5653210c69fd107363" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "console_error_panic_hook" version = "0.1.7" @@ -198,6 +317,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-utils" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + [[package]] name = "crypto-common" version = "0.1.6" @@ -268,6 +393,8 @@ name = "epoxy-server" version = "1.0.0" dependencies = [ "bytes", + "clap", + "clio", "dashmap", "fastwebsockets", "futures-util", @@ -296,6 +423,17 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "event-listener" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b72557800024fabbaa2449dd4bf24e37b93702d457a4d4f2b0dd1f0f039f20c1" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + [[package]] name = "fastrand" version = "2.0.1" @@ -495,6 +633,12 @@ version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "hermit-abi" version = "0.3.5" @@ -612,6 +756,17 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "is-terminal" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" +dependencies = [ + "hermit-abi", + "rustix", + "windows-sys 0.52.0", +] + [[package]] name = "itoa" version = "1.0.10" @@ -620,9 +775,9 @@ checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "js-sys" -version = "0.3.67" +version = "0.3.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a1d36f1235bc969acba30b7f5990b864423a6068a10f7c90ae8f0112e3a59d1" +checksum = "406cda4b368d531c842222cf9d2600a9a4acce8d29423695379c6868a143a9ee" dependencies = [ "wasm-bindgen", ] @@ -774,6 +929,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "parking" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" + [[package]] name = "parking_lot" version = "0.12.1" @@ -983,6 +1144,15 @@ dependencies = [ "untrusted", ] +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.23" @@ -1105,6 +1275,12 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "subtle" version = "2.5.0" @@ -1134,6 +1310,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "terminal_size" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21bebf2b7c9e0a515f6e0f8c51dc0f8e4696391e6f1ff30379559f8365fb0df7" +dependencies = [ + "rustix", + "windows-sys 0.48.0", +] + [[package]] name = "thiserror" version = "1.0.56" @@ -1296,6 +1482,12 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + [[package]] name = "vcpkg" version = "0.2.15" @@ -1308,6 +1500,16 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "walkdir" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -1325,9 +1527,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.90" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1223296a201415c7fad14792dbefaace9bd52b62d33453ade1c5b5f07555406" +checksum = "c1e124130aee3fb58c5bdd6b639a0509486b0338acaaae0c84a5124b0f588b7f" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -1335,9 +1537,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.90" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcdc935b63408d58a32f8cc9738a0bffd8f05cc7c002086c6ef20b7312ad9dcd" +checksum = "c9e7e1900c352b609c8488ad12639a311045f40a35491fb69ba8c12f758af70b" dependencies = [ "bumpalo", "log", @@ -1350,9 +1552,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.40" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bde2032aeb86bdfaecc8b261eef3cba735cc426c1f3a3416d1e0791be95fc461" +checksum = "877b9c3f61ceea0e56331985743b13f3d25c406a7098d45180fb5f09bc19ed97" dependencies = [ "cfg-if", "js-sys", @@ -1362,9 +1564,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.90" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e4c238561b2d428924c49815533a8b9121c664599558a5d9ec51f8a1740a999" +checksum = "b30af9e2d358182b5c7449424f017eba305ed32a7010509ede96cdc4696c46ed" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -1372,9 +1574,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.90" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7" +checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" dependencies = [ "proc-macro2", "quote", @@ -1385,9 +1587,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.90" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d91413b1c31d7539ba5ef2451af3f0b833a005eb27a631cec32bc0635a8602b" +checksum = "4f186bd2dcf04330886ce82d6f33dd75a7bfcf69ecf5763b89fcde53b6ac9838" [[package]] name = "wasm-streams" @@ -1435,6 +1637,52 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +dependencies = [ + "winapi", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -1483,6 +1731,12 @@ dependencies = [ "windows_x86_64_msvc 0.52.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -1495,6 +1749,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -1507,6 +1767,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -1519,6 +1785,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -1531,6 +1803,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -1543,6 +1821,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -1555,6 +1839,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -1573,6 +1863,7 @@ version = "0.1.0" dependencies = [ "async_io_stream", "bytes", + "event-listener", "fastwebsockets", "futures", "futures-util", diff --git a/Cargo.toml b/Cargo.toml index 2e8971b..2fcf5e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,5 +7,5 @@ rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" } [profile.release] lto = true -opt-level = 3 +opt-level = 'z' codegen-units = 1 diff --git a/client/build.sh b/client/build.sh index 7c2725b..7f402f0 100755 --- a/client/build.sh +++ b/client/build.sh @@ -11,7 +11,7 @@ wasm-bindgen --weak-refs --target no-modules --no-modules-global epoxy --out-dir echo "[ws] wasm-bindgen finished" mv out/epoxy_client_bg.wasm out/epoxy_client_unoptimized.wasm -time wasm-opt -O4 --vacuum --dce --enable-threads --enable-bulk-memory --enable-simd out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm +time wasm-opt -Oz --vacuum --dce --enable-threads --enable-bulk-memory --enable-simd out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm echo "[ws] wasm-opt finished" AUTOGENERATED_SOURCE=$(<"out/epoxy_client.js") diff --git a/client/src/lib.rs b/client/src/lib.rs index 2a40a51..4da9fee 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -49,7 +49,7 @@ type EpxIoStream = Either; #[wasm_bindgen(start)] fn init() { - std::panic::set_hook(Box::new(console_error_panic_hook::hook)); + console_error_panic_hook::set_once(); } #[wasm_bindgen] @@ -86,7 +86,7 @@ impl EpoxyClient { .replace_err("Failed to connect to websocket")?; debug!("connected!"); let (wtx, wrx) = ws.split(); - let (mux, fut) = ClientMux::new(wrx, wtx); + let (mux, fut) = ClientMux::new(wrx, wtx).await?; let mux = Arc::new(mux); wasm_bindgen_futures::spawn_local(async move { @@ -174,12 +174,6 @@ impl EpoxyClient { .replace_err("Redirect URL must have an authority") .ok() { - let should_strip = new_req.uri().is_same_host(&redirect_url); - if should_strip { - new_req.headers_mut().remove("authorization"); - new_req.headers_mut().remove("cookie"); - new_req.headers_mut().remove("www-authenticate"); - } *new_req.uri_mut() = redirect_url; new_req.headers_mut().insert( "Host", diff --git a/server/Cargo.toml b/server/Cargo.toml index a0a64c1..5da8ac0 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -5,6 +5,8 @@ edition = "2021" [dependencies] bytes = "1.5.0" +clap = { version = "4.4.18", features = ["derive", "help", "usage", "color", "wrap_help", "cargo"] } +clio = { version = "0.3.5", features = ["clap-parse"] } dashmap = "5.5.3" fastwebsockets = { version = "0.6.0", features = ["upgrade", "simdutf8", "unstable-split"] } futures-util = { version = "0.3.30", features = ["sink"] } diff --git a/server/src/main.rs b/server/src/main.rs index 04c39e7..58ef7c2 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,7 +1,8 @@ #![feature(let_chains)] -use std::io::Error; +use std::io::{Error, Read}; use bytes::Bytes; +use clap::Parser; use fastwebsockets::{ upgrade, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, @@ -18,25 +19,36 @@ use tokio_util::codec::{BytesCodec, Framed}; use wisp_mux::{ws, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, WsEvent}; -type HttpBody = http_body_util::Empty; +type HttpBody = http_body_util::Full; -#[tokio::main] +#[derive(Parser)] +#[command(version = clap::crate_version!(), about = "Implementation of the Wisp protocol in Rust, made for epoxy.")] +struct Cli { + #[arg(long, default_value = "/")] + prefix: String, + #[arg( + long = "port", + short = 'l', + value_name = "PORT", + default_value = "4000" + )] + listen_port: String, + #[arg(long, short, value_parser)] + pubkey: clio::Input, + #[arg(long, short = 'P', value_parser)] + privkey: clio::Input, +} + +#[tokio::main(flavor = "multi_thread")] async fn main() -> Result<(), Error> { - let pem = include_bytes!("./pem.pem"); - let key = include_bytes!("./key.pem"); - let identity = native_tls::Identity::from_pkcs8(pem, key).expect("failed to make identity"); - let prefix = if let Some(prefix) = std::env::args().nth(1) { - prefix - } else { - "/".to_string() - }; - let port = if let Some(prefix) = std::env::args().nth(2) { - prefix - } else { - "4000".to_string() - }; + let mut opt = Cli::parse(); + let mut pem = Vec::new(); + opt.pubkey.read_to_end(&mut pem)?; + let mut key = Vec::new(); + opt.privkey.read_to_end(&mut key)?; + let identity = native_tls::Identity::from_pkcs8(&pem, &key).expect("failed to make identity"); - let socket = TcpListener::bind(format!("0.0.0.0:{}", port)) + let socket = TcpListener::bind(format!("0.0.0.0:{}", opt.listen_port)) .await .expect("failed to bind"); let acceptor = TlsAcceptor::from( @@ -47,7 +59,7 @@ async fn main() -> Result<(), Error> { println!("listening on 0.0.0.0:4000"); while let Ok((stream, addr)) = socket.accept().await { let acceptor_cloned = acceptor.clone(); - let prefix_cloned = prefix.clone(); + let prefix_cloned = opt.prefix.clone(); tokio::spawn(async move { let stream = acceptor_cloned.accept(stream).await.expect("not tls"); let io = TokioIo::new(stream); @@ -72,15 +84,23 @@ async fn accept_http( ) -> Result, WebSocketError> { if upgrade::is_upgrade_request(&req) && req.uri().path().to_string().starts_with(&prefix) - && let Some(protocol) = req.headers().get("Sec-Websocket-Protocol") - && protocol == "wisp-v1" + && let Some(protocols) = req.headers().get("Sec-Websocket-Protocol").and_then(|x| { + Some( + x.to_str() + .ok()? + .split(',') + .map(|x| x.trim()) + .collect::>(), + ) + }) + && protocols.contains(&"wisp-v1") { let uri = req.uri().clone(); let (mut res, fut) = upgrade::upgrade(&mut req)?; println!("{:?} {:?}", uri.path(), prefix); - if *uri.path() != prefix { + if uri.path().starts_with(&prefix) { tokio::spawn(async move { accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone()).await }); @@ -92,11 +112,14 @@ async fn accept_http( "Sec-Websocket-Protocol", HeaderValue::from_str("wisp-v1").unwrap(), ); - Ok(res) + Ok(Response::from_parts( + res.into_parts().0, + HttpBody::new(Bytes::new()), + )) } else { Ok(Response::builder() .status(StatusCode::OK) - .body(HttpBody::new()) + .body(HttpBody::new(":3".to_string().into())) .unwrap()) } } @@ -176,7 +199,7 @@ async fn accept_ws( println!("{:?}: connected", addr); - let (mut mux, fut) = ServerMux::new(rx, tx); + let (mut mux, fut) = ServerMux::new(rx, tx, 128); tokio::spawn(async move { if let Err(e) = fut.await { diff --git a/wisp/Cargo.lock b/wisp/Cargo.lock deleted file mode 100644 index 19bc2ba..0000000 --- a/wisp/Cargo.lock +++ /dev/null @@ -1,320 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "autocfg" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" - -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "bytes" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "dashmap" -version = "5.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" -dependencies = [ - "cfg-if", - "hashbrown", - "lock_api", - "once_cell", - "parking_lot_core", -] - -[[package]] -name = "futures" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" -dependencies = [ - "futures-channel", - "futures-core", - "futures-executor", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-channel" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" -dependencies = [ - "futures-core", - "futures-sink", -] - -[[package]] -name = "futures-core" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" - -[[package]] -name = "futures-executor" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" -dependencies = [ - "futures-core", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-io" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" - -[[package]] -name = "futures-macro" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "futures-sink" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" - -[[package]] -name = "futures-task" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" - -[[package]] -name = "futures-util" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" -dependencies = [ - "futures-channel", - "futures-core", - "futures-io", - "futures-macro", - "futures-sink", - "futures-task", - "memchr", - "pin-project-lite", - "pin-utils", - "slab", -] - -[[package]] -name = "hashbrown" -version = "0.14.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" - -[[package]] -name = "libc" -version = "0.2.152" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" - -[[package]] -name = "lock_api" -version = "0.4.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" -dependencies = [ - "autocfg", - "scopeguard", -] - -[[package]] -name = "memchr" -version = "2.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" - -[[package]] -name = "once_cell" -version = "1.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" - -[[package]] -name = "parking_lot_core" -version = "0.9.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets", -] - -[[package]] -name = "pin-project-lite" -version = "0.2.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" - -[[package]] -name = "pin-utils" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - -[[package]] -name = "proc-macro2" -version = "1.0.76" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95fc56cda0b5c3325f5fbbd7ff9fda9e02bb00bb3dac51252d2f1bfa1cb8cc8c" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "quote" -version = "1.0.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "redox_syscall" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" -dependencies = [ - "bitflags", -] - -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - -[[package]] -name = "slab" -version = "0.4.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" -dependencies = [ - "autocfg", -] - -[[package]] -name = "smallvec" -version = "1.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" - -[[package]] -name = "syn" -version = "2.0.48" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "windows-targets" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" -dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" - -[[package]] -name = "windows_i686_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" - -[[package]] -name = "windows_i686_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" - -[[package]] -name = "wisp-mux" -version = "0.1.0" -dependencies = [ - "bytes", - "dashmap", - "futures", - "futures-util", -] diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 42f65e2..da0a601 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] async_io_stream = "0.3.3" bytes = "1.5.0" +event-listener = "5.0.0" fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true } futures = "0.3.30" futures-util = "0.3.30" diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 9ee6785..7ad9931 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -3,17 +3,18 @@ mod fastwebsockets; mod packet; mod stream; -pub mod ws; -#[cfg(feature = "ws_stream_wasm")] -mod ws_stream_wasm; #[cfg(feature = "hyper_tower")] pub mod tokioio; #[cfg(feature = "hyper_tower")] pub mod tower; +pub mod ws; +#[cfg(feature = "ws_stream_wasm")] +mod ws_stream_wasm; pub use crate::packet::*; pub use crate::stream::*; +use event_listener::Event; use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt}; use std::{ collections::HashMap, @@ -23,7 +24,7 @@ use std::{ }, }; -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Copy, Clone)] pub enum Role { Client, Server, @@ -96,13 +97,14 @@ impl ServerMuxInner { rx: R, close_rx: mpsc::UnboundedReceiver, muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, + buffer_size: u32 ) -> Result<(), WispError> where R: ws::WebSocketRead, { let ret = futures::select! { x = self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()).fuse() => x, - x = self.server_msg_loop(rx, muxstream_sender).fuse() => x + x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x }; self.stream_map.lock().await.iter().for_each(|x| { let _ = x.1.unbounded_send(WsEvent::Close(ClosePacket::new(0x01))); @@ -137,12 +139,13 @@ impl ServerMuxInner { &self, mut rx: R, muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, + buffer_size: u32, ) -> Result<(), WispError> where R: ws::WebSocketRead, { self.tx - .write_frame(Packet::new_continue(0, u32::MAX).into()) + .write_frame(Packet::new_continue(0, buffer_size).into()) .await?; while let Ok(frame) = rx.wisp_read_frame(&self.tx).await { @@ -157,10 +160,13 @@ impl ServerMuxInner { inner_packet, MuxStream::new( packet.stream_id, + Role::Server, ch_rx, self.tx.clone(), self.close_tx.clone(), AtomicBool::new(false).into(), + AtomicU32::new(buffer_size).into(), + Event::new().into(), ), )) .map_err(|x| WispError::Other(Box::new(x)))?; @@ -168,11 +174,6 @@ impl ServerMuxInner { Data(data) => { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { let _ = stream.unbounded_send(WsEvent::Send(data)); - self.tx - .write_frame( - Packet::new_continue(packet.stream_id, u32::MAX).into(), - ) - .await?; } } Continue(_) => unreachable!(), @@ -200,7 +201,7 @@ where } impl ServerMux { - pub fn new(read: R, write: W) -> (Self, impl Future>) + pub fn new(read: R, write: W, buffer_size: u32) -> (Self, impl Future>) where R: ws::WebSocketRead, { @@ -215,7 +216,7 @@ impl ServerMux { close_tx, stream_map: map.clone(), } - .into_future(read, close_rx, tx), + .into_future(read, close_rx, tx, buffer_size), ) } @@ -229,7 +230,8 @@ where W: ws::WebSocketWrite, { tx: ws::LockedWebSocketWrite, - stream_map: Arc>>>, + stream_map: + Arc, Arc, Arc)>>>, } impl ClientMuxInner { @@ -280,13 +282,20 @@ impl ClientMuxInner { Connect(_) => unreachable!(), Data(data) => { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { - let _ = stream.unbounded_send(WsEvent::Send(data)); + let _ = stream.0.unbounded_send(WsEvent::Send(data)); + } + } + Continue(inner_packet) => { + if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { + stream + .1 + .store(inner_packet.buffer_remaining, Ordering::Release); + let _ = stream.2.notify(u32::MAX); } } - Continue(_) => {} Close(inner_packet) => { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { - let _ = stream.unbounded_send(WsEvent::Close(inner_packet)); + let _ = stream.0.unbounded_send(WsEvent::Close(inner_packet)); } self.stream_map.lock().await.remove(&packet.stream_id); } @@ -302,32 +311,46 @@ where W: ws::WebSocketWrite, { tx: ws::LockedWebSocketWrite, - stream_map: Arc>>>, + stream_map: + Arc, Arc, Arc)>>>, next_free_stream_id: AtomicU32, close_tx: mpsc::UnboundedSender, + buf_size: u32, } impl ClientMux { - pub fn new(read: R, write: W) -> (Self, impl Future>) + pub async fn new( + mut read: R, + write: W, + ) -> Result<(Self, impl Future>), WispError> where R: ws::WebSocketRead, { - let (tx, rx) = mpsc::unbounded::(); - let map = Arc::new(Mutex::new(HashMap::new())); let write = ws::LockedWebSocketWrite::new(write); - ( - Self { - tx: write.clone(), - stream_map: map.clone(), - next_free_stream_id: AtomicU32::new(1), - close_tx: tx, - }, - ClientMuxInner { - tx: write.clone(), - stream_map: map.clone(), - } - .into_future(read, rx), - ) + let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?; + if first_packet.stream_id != 0 { + return Err(WispError::InvalidStreamId); + } + if let PacketType::Continue(packet) = first_packet.packet { + let (tx, rx) = mpsc::unbounded::(); + let map = Arc::new(Mutex::new(HashMap::new())); + Ok(( + Self { + tx: write.clone(), + stream_map: map.clone(), + next_free_stream_id: AtomicU32::new(1), + close_tx: tx, + buf_size: packet.buffer_remaining, + }, + ClientMuxInner { + tx: write.clone(), + stream_map: map.clone(), + } + .into_future(read, rx), + )) + } else { + Err(WispError::InvalidPacketType) + } } pub async fn client_new_stream( @@ -337,6 +360,8 @@ impl ClientMux { port: u16, ) -> Result, WispError> { let (ch_tx, ch_rx) = mpsc::unbounded(); + let evt: Arc = Event::new().into(); + let flow_control: Arc = AtomicU32::new(self.buf_size).into(); let stream_id = self.next_free_stream_id.load(Ordering::Acquire); self.tx .write_frame(Packet::new_connect(stream_id, stream_type, port, host).into()) @@ -347,13 +372,19 @@ impl ClientMux { .ok_or(WispError::MaxStreamCountReached)?, Ordering::Release, ); - self.stream_map.lock().await.insert(stream_id, ch_tx); + self.stream_map + .lock() + .await + .insert(stream_id, (ch_tx, flow_control.clone(), evt.clone())); Ok(MuxStream::new( stream_id, + Role::Client, ch_rx, self.tx.clone(), self.close_tx.clone(), AtomicBool::new(false).into(), + flow_control, + evt, )) } } diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 98eb20e..505bbe6 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -63,7 +63,7 @@ impl From for Vec { #[derive(Debug)] pub struct ContinuePacket { - buffer_remaining: u32, + pub buffer_remaining: u32, } impl ContinuePacket { @@ -94,7 +94,7 @@ impl From for Vec { #[derive(Debug)] pub struct ClosePacket { - reason: u8, + pub reason: u8, } impl ClosePacket { diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 3998c9d..f561edb 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -1,5 +1,6 @@ use async_io_stream::IoStream; use bytes::Bytes; +use event_listener::Event; use futures::{ channel::{mpsc, oneshot}, sink, stream, @@ -10,7 +11,7 @@ use pin_project_lite::pin_project; use std::{ pin::Pin, sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicU32, Ordering}, Arc, }, }; @@ -24,19 +25,36 @@ pub enum MuxEvent { Close(u32, u8, oneshot::Sender>), } -pub struct MuxStreamRead { +pub struct MuxStreamRead +where + W: crate::ws::WebSocketWrite, +{ pub stream_id: u32, + role: crate::Role, + tx: crate::ws::LockedWebSocketWrite, rx: mpsc::UnboundedReceiver, is_closed: Arc, + flow_control: Arc, } -impl MuxStreamRead { +impl MuxStreamRead { pub async fn read(&mut self) -> Option { if self.is_closed.load(Ordering::Acquire) { return None; } match self.rx.next().await? { - WsEvent::Send(bytes) => Some(WsEvent::Send(bytes)), + WsEvent::Send(bytes) => { + if self.role == crate::Role::Server { + let old_val = self.flow_control.fetch_add(1, Ordering::SeqCst); + self.tx + .write_frame( + crate::Packet::new_continue(self.stream_id, old_val + 1).into(), + ) + .await + .ok()?; + } + Some(WsEvent::Send(bytes)) + } WsEvent::Close(packet) => { self.is_closed.store(true, Ordering::Release); Some(WsEvent::Close(packet)) @@ -63,9 +81,12 @@ where W: crate::ws::WebSocketWrite, { pub stream_id: u32, + role: crate::Role, tx: crate::ws::LockedWebSocketWrite, close_channel: mpsc::UnboundedSender, is_closed: Arc, + continue_recieved: Arc, + flow_control: Arc, } impl MuxStreamWrite { @@ -73,9 +94,22 @@ impl MuxStreamWrite { if self.is_closed.load(Ordering::Acquire) { return Err(crate::WispError::StreamAlreadyClosed); } + if self.role == crate::Role::Client && self.flow_control.load(Ordering::Acquire) <= 0 { + self.continue_recieved.listen().await; + } self.tx .write_frame(crate::Packet::new_data(self.stream_id, data).into()) - .await + .await?; + if self.role == crate::Role::Client { + self.flow_control.store( + self.flow_control + .load(Ordering::Acquire) + .checked_add(1) + .unwrap_or(0), + Ordering::Release, + ); + } + Ok(()) } pub fn get_close_handle(&self) -> MuxStreamCloser { @@ -123,30 +157,39 @@ where W: crate::ws::WebSocketWrite, { pub stream_id: u32, - rx: MuxStreamRead, + rx: MuxStreamRead, tx: MuxStreamWrite, } impl MuxStream { pub(crate) fn new( stream_id: u32, + role: crate::Role, rx: mpsc::UnboundedReceiver, tx: crate::ws::LockedWebSocketWrite, close_channel: mpsc::UnboundedSender, is_closed: Arc, + flow_control: Arc, + continue_recieved: Arc ) -> Self { Self { stream_id, rx: MuxStreamRead { stream_id, + role, + tx: tx.clone(), rx, is_closed: is_closed.clone(), + flow_control: flow_control.clone(), }, tx: MuxStreamWrite { stream_id, + role, tx, close_channel, is_closed: is_closed.clone(), + flow_control: flow_control.clone(), + continue_recieved: continue_recieved.clone(), }, } } @@ -167,7 +210,7 @@ impl MuxStream { self.tx.close(reason).await } - pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) { + pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) { (self.rx, self.tx) } From 5b1503c28e6ab281077d906970a527a2b69b0e13 Mon Sep 17 00:00:00 2001 From: r58Playz Date: Wed, 7 Feb 2024 14:59:05 -0800 Subject: [PATCH 24/26] fix server --- server/Cargo.toml | 2 +- server/src/main.rs | 82 +++++++++++++++++++--------------------------- 2 files changed, 35 insertions(+), 49 deletions(-) diff --git a/server/Cargo.toml b/server/Cargo.toml index 5da8ac0..9e3231d 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -16,4 +16,4 @@ hyper-util = { version = "0.1.2", features = ["tokio"] } tokio = { version = "1.5.1", features = ["rt-multi-thread", "macros"] } tokio-native-tls = "0.3.1" tokio-util = { version = "0.7.10", features = ["codec"] } -wisp-mux = { path = "../wisp", features = ["fastwebsockets"] } +wisp-mux = { path = "../wisp", features = ["fastwebsockets", "tokio_io"] } diff --git a/server/src/main.rs b/server/src/main.rs index 58ef7c2..7205d18 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -24,7 +24,7 @@ type HttpBody = http_body_util::Full; #[derive(Parser)] #[command(version = clap::crate_version!(), about = "Implementation of the Wisp protocol in Rust, made for epoxy.")] struct Cli { - #[arg(long, default_value = "/")] + #[arg(long, default_value = "")] prefix: String, #[arg( long = "port", @@ -82,9 +82,13 @@ async fn accept_http( addr: String, prefix: String, ) -> Result, WebSocketError> { + let uri = req.uri().clone().path().to_string(); if upgrade::is_upgrade_request(&req) - && req.uri().path().to_string().starts_with(&prefix) - && let Some(protocols) = req.headers().get("Sec-Websocket-Protocol").and_then(|x| { + && let Some(uri) = uri.strip_prefix(&prefix) + { + let (mut res, fut) = upgrade::upgrade(&mut req)?; + + if let Some(protocols) = req.headers().get("Sec-Websocket-Protocol").and_then(|x| { Some( x.to_str() .ok()? @@ -92,31 +96,25 @@ async fn accept_http( .map(|x| x.trim()) .collect::>(), ) - }) - && protocols.contains(&"wisp-v1") - { - let uri = req.uri().clone(); - let (mut res, fut) = upgrade::upgrade(&mut req)?; - - println!("{:?} {:?}", uri.path(), prefix); - - if uri.path().starts_with(&prefix) { - tokio::spawn(async move { - accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone()).await - }); - } else { + }) && protocols.contains(&"wisp-v1") + && (uri == "" || uri == "/") + { tokio::spawn(async move { accept_ws(fut, addr.clone()).await }); + res.headers_mut().insert( + "Sec-Websocket-Protocol", + HeaderValue::from_str("wisp-v1").unwrap(), + ); + } else { + let uri = uri.strip_prefix("/").unwrap_or(uri).to_string(); + tokio::spawn(async move { accept_wsproxy(fut, uri, addr.clone()).await }); } - res.headers_mut().insert( - "Sec-Websocket-Protocol", - HeaderValue::from_str("wisp-v1").unwrap(), - ); Ok(Response::from_parts( res.into_parts().0, HttpBody::new(Bytes::new()), )) } else { + println!("random request to path {:?}", uri); Ok(Response::builder() .status(StatusCode::OK) .body(HttpBody::new(":3".to_string().into())) @@ -134,32 +132,13 @@ async fn handle_mux( ); match packet.stream_type { StreamType::Tcp => { - let tcp_stream = TcpStream::connect(uri) + let mut tcp_stream = TcpStream::connect(uri) + .await + .map_err(|x| WispError::Other(Box::new(x)))?; + let mut mux_stream = stream.into_io().into_asyncrw(); + tokio::io::copy_bidirectional(&mut tcp_stream, &mut mux_stream) .await .map_err(|x| WispError::Other(Box::new(x)))?; - let mut tcp_stream_framed = Framed::new(tcp_stream, BytesCodec::new()); - - loop { - tokio::select! { - event = stream.read() => { - match event { - Some(event) => match event { - WsEvent::Send(data) => { - tcp_stream_framed.send(data).await.map_err(|x| WispError::Other(Box::new(x)))?; - } - WsEvent::Close(_) => return Ok(false), - }, - None => break, - } - }, - event = tcp_stream_framed.next() => { - match event.and_then(|x| x.ok()) { - Some(event) => stream.write(event.into()).await?, - None => break, - } - } - } - } } StreamType::Udp => { let udp_socket = UdpSocket::bind(uri) @@ -233,20 +212,27 @@ async fn accept_ws( async fn accept_wsproxy( fut: upgrade::UpgradeFut, - incoming_uri: &str, + incoming_uri: String, addr: String, ) -> Result<(), Box> { let mut ws_stream = FragmentCollector::new(fut.await?); - println!("{:?}: connected (wsproxy)", addr); + println!("{:?}: connected (wsproxy): {:?}", addr, incoming_uri); + + match hyper::Uri::try_from(incoming_uri.clone()) { + Ok(_) => (), + Err(err) => { + ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"invalid uri")).await?; + return Err(Box::new(err)); + } + } let tcp_stream = match TcpStream::connect(incoming_uri).await { Ok(stream) => stream, Err(err) => { ws_stream .write_frame(Frame::close(CloseCode::Away.into(), b"failed to connect")) - .await - .unwrap(); + .await?; return Err(Box::new(err)); } }; From 7fdee4ecfe9f8a12a88ff5adc4fc67d38356c7af Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Wed, 7 Feb 2024 16:31:35 -0800 Subject: [PATCH 25/26] fix simple wisp client --- simple-wisp-client/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index b530f64..2b3ca0e 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -91,7 +91,7 @@ async fn main() -> Result<(), Box> { let (rx, tx) = ws.split(tokio::io::split); let rx = FragmentCollectorRead::new(rx); - let (mux, fut) = ClientMux::new(rx, tx); + let (mux, fut) = ClientMux::new(rx, tx).await?; tokio::task::spawn(fut); From a41e3eecc53ae6b0e9428e5d5411940784e906e7 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Wed, 7 Feb 2024 16:50:07 -0800 Subject: [PATCH 26/26] better docs --- README.md | 46 ++++++++++++++++++++++++++++++++++++---------- client/demo.js | 2 +- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 8cfb752..5c0eca9 100644 --- a/README.md +++ b/README.md @@ -1,29 +1,55 @@ # epoxy Epoxy is an encrypted proxy for browser javascript. It allows you to make requests that bypass cors without compromising security, by running SSL/TLS inside webassembly. -Simple usage example for making a secure GET request to httpbin.org: +## Using the client +Epoxy must be run from within a web worker and must be served with the [security headers needed for `SharedArrayBuffer`](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/SharedArrayBuffer#security_requirements). Here is a simple usage example: ```javascript -import epoxy from "@mercuryworkshop/epoxy-tls"; +importScripts("epoxy-bundled.js"); const { EpoxyClient } = await epoxy(); let client = await new EpoxyClient("wss://localhost:4000", navigator.userAgent, 10); let response = await client.fetch("https://httpbin.org/get"); await response.text(); - ``` -Epoxy also allows you to make arbitrary end to end encrypted TCP connections safely directly from the browser. +## Using the server +``` +$ cargo r -r --bin epoxy-server -- --help +Implementation of the Wisp protocol in Rust, made for epoxy. + +Usage: epoxy-server [OPTIONS] --pubkey --privkey + +Options: + --prefix [default: ] + -l, --port [default: 4000] + -p, --pubkey + -P, --privkey + -h, --help Print help + -V, --version Print version +``` ## Building +Rust nightly is required. ### Server - -1. Generate certs with `mkcert` and place the public certificate in `./server/src/pem.pem` and private certificate in `./server/src/key.pem` -2. Run `cargo r --bin epoxy-server`, optionally with `-r` flag for release +``` +cargo b -r --bin epoxy-server +``` +The executable will be placed at `target/release/epoxy-server`. ### Client -Note: Building the client is only supported on linux +> [!IMPORTANT] +> Building the client is only supported on Linux. -1. Make sure you have the `wasm32-unknown-unknown` target installed, `wasm-bindgen` and `wasm-opt` executables installed, and `bash`, `python3` packages (`python3` is used for `http.server` module) -2. Run `pnpm build` +Make sure you have the `wasm32-unknown-unknown` rust target, the `rust-std` component, and the `wasm-bindgen`, `wasm-opt`, and `base64` binaries installed. + +In the `client` directory: +``` +bash build.sh +``` + +To host a local server with the required headers: +``` +python3 serve.py +``` diff --git a/client/demo.js b/client/demo.js index b390c98..5b11bd1 100644 --- a/client/demo.js +++ b/client/demo.js @@ -12,7 +12,7 @@ onmessage = async (msg) => { postMessage(str); } - let { EpoxyClient } = await epoxy(); + const { EpoxyClient } = await epoxy(); const tconn0 = performance.now(); // args: websocket url, user agent, redirect limit