From 0fa2492a56cbe02bfd5b0bbeaa9fae1ba8613d45 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Fri, 5 Jan 2024 13:23:21 -0800 Subject: [PATCH] switch to fastwebsockets --- Cargo.lock | 193 +++++++++++++++++-------------------------------- Cargo.toml | 6 +- src/main.rs | 204 ++++++++++++++++++---------------------------------- 3 files changed, 141 insertions(+), 262 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1c32e09..bada9b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -38,6 +38,12 @@ dependencies = [ "rustc-demangle", ] +[[package]] +name = "base64" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "414dcefbc63d77c526a76b3afcf6fbb9b5e2791c19c3aa2297733208750c6e53" + [[package]] name = "block-buffer" version = "0.10.4" @@ -47,12 +53,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "byteorder" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" - [[package]] name = "bytes" version = "1.5.0" @@ -93,12 +93,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "data-encoding" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5" - [[package]] name = "digest" version = "0.10.7" @@ -109,21 +103,31 @@ dependencies = [ "crypto-common", ] +[[package]] +name = "fastwebsockets" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f63dd7b57f9b33b1741fa631c9522eb35d43e96dcca4a6a91d5e4ca7c93acdc1" +dependencies = [ + "base64", + "http-body-util", + "hyper", + "hyper-util", + "pin-project", + "rand", + "sha1", + "simdutf8", + "thiserror", + "tokio", + "utf-8", +] + [[package]] name = "fnv" version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" -[[package]] -name = "form_urlencoded" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" -dependencies = [ - "percent-encoding", -] - [[package]] name = "futures-channel" version = "0.3.30" @@ -210,17 +214,6 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" -[[package]] -name = "http" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" -dependencies = [ - "bytes", - "fnv", - "itoa", -] - [[package]] name = "http" version = "1.0.0" @@ -239,7 +232,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" dependencies = [ "bytes", - "http 1.0.0", + "http", ] [[package]] @@ -250,7 +243,7 @@ checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840" dependencies = [ "bytes", "futures-util", - "http 1.0.0", + "http", "http-body", "pin-project-lite", ] @@ -276,13 +269,14 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.0.0", + "http", "http-body", "httparse", "httpdate", "itoa", "pin-project-lite", "tokio", + "want", ] [[package]] @@ -294,7 +288,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.0.0", + "http", "http-body", "hyper", "pin-project-lite", @@ -303,16 +297,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "idna" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" -dependencies = [ - "unicode-bidi", - "unicode-normalization", -] - [[package]] name = "itoa" version = "1.0.10" @@ -325,12 +309,6 @@ version = "0.2.151" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" -[[package]] -name = "log" -version = "0.4.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" - [[package]] name = "memchr" version = "2.7.1" @@ -383,10 +361,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] -name = "percent-encoding" -version = "2.3.1" +name = "pin-project" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "pin-project-lite" @@ -471,6 +463,12 @@ dependencies = [ "digest", ] +[[package]] +name = "simdutf8" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" + [[package]] name = "slab" version = "0.4.9" @@ -521,21 +519,6 @@ dependencies = [ "syn", ] -[[package]] -name = "tinyvec" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" -dependencies = [ - "tinyvec_macros", -] - -[[package]] -name = "tinyvec_macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" - [[package]] name = "tokio" version = "1.35.1" @@ -564,18 +547,6 @@ dependencies = [ "syn", ] -[[package]] -name = "tokio-tungstenite" -version = "0.20.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "212d5dcb2a1ce06d81107c3d0ffa3121fe974b73f068c8282cb1c32328113b6c" -dependencies = [ - "futures-util", - "log", - "tokio", - "tungstenite", -] - [[package]] name = "tokio-util" version = "0.7.10" @@ -610,23 +581,10 @@ dependencies = [ ] [[package]] -name = "tungstenite" -version = "0.20.1" +name = "try-lock" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e3dac10fd62eaf6617d3a904ae222845979aec67c615d1c842b4002c7666fb9" -dependencies = [ - "byteorder", - "bytes", - "data-encoding", - "http 0.2.11", - "httparse", - "log", - "rand", - "sha1", - "thiserror", - "url", - "utf-8", -] +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "typenum" @@ -634,38 +592,12 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" -[[package]] -name = "unicode-bidi" -version = "0.3.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f2528f27a9eb2b21e69c95319b30bd0efd85d09c379741b0f78ea1d86be2416" - [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" -[[package]] -name = "unicode-normalization" -version = "0.1.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" -dependencies = [ - "tinyvec", -] - -[[package]] -name = "url" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" -dependencies = [ - "form_urlencoded", - "idna", - "percent-encoding", -] - [[package]] name = "utf-8" version = "0.7.6" @@ -678,6 +610,15 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -751,15 +692,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] -name = "wsproxy-rust" +name = "wsfetch-server" version = "0.1.0" dependencies = [ "bytes", + "fastwebsockets", "futures-util", "http-body-util", "hyper", "hyper-util", "tokio", - "tokio-tungstenite", "tokio-util", ] diff --git a/Cargo.toml b/Cargo.toml index 7539d56..bfb208f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "wsproxy-rust" +name = "wsfetch-server" version = "0.1.0" edition = "2021" @@ -7,10 +7,10 @@ edition = "2021" [dependencies] bytes = "1.5.0" -futures-util = "0.3.30" +fastwebsockets = { version = "0.6.0", features = ["upgrade", "simdutf8"] } +futures-util = { version = "0.3.30", features = ["sink"] } http-body-util = "0.1.0" hyper = { version = "1.1.0", features = ["server", "http1"] } hyper-util = { version = "0.1.2", features = ["tokio"] } tokio = { version = "1.5.1", features = ["rt-multi-thread", "macros"] } -tokio-tungstenite = "0.20.1" tokio-util = { version = "0.7.10", features = ["codec"] } diff --git a/src/main.rs b/src/main.rs index 5894e73..f6568db 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,47 +1,16 @@ -use std::{convert::Infallible, io::Error}; +use std::io::Error; -use bytes::{BufMut, BytesMut}; +use bytes::Bytes; +use fastwebsockets::{upgrade, FragmentCollector, Frame, OpCode, Payload, WebSocketError}; use futures_util::{SinkExt, StreamExt}; use hyper::{ - body::Incoming, - header::{ - HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, - UPGRADE, - }, - server::conn::http1, - service::service_fn, - upgrade::Upgraded, - Method, Request, Response, StatusCode, Version + body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode, }; use hyper_util::rt::TokioIo; use tokio::net::{TcpListener, TcpStream}; -use tokio_tungstenite::{ - tungstenite::{protocol::Role, handshake::derive_accept_key, Message}, - WebSocketStream, -}; -use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite}; +use tokio_util::codec::{BytesCodec, Framed}; -struct NetworkCodec; - -impl Encoder> for NetworkCodec { - type Error = std::io::Error; - - fn encode(&mut self, item: Vec, dst: &mut BytesMut) -> Result<(), Self::Error> { - dst.put_slice(item.as_slice()); - Ok(()) - } -} - -impl Decoder for NetworkCodec { - type Item = Vec; - type Error = std::io::Error; - - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - Ok(Some(src.to_vec())) - } -} - -type HttpBody = http_body_util::Full; +type HttpBody = http_body_util::Empty; #[tokio::main(flavor = "multi_thread", worker_threads = 32)] async fn main() -> Result<(), Error> { @@ -51,15 +20,14 @@ async fn main() -> Result<(), Error> { println!("listening on 0.0.0.0:4000"); while let Ok((stream, addr)) = socket.accept().await { - println!("socket connected: {:?}", addr); tokio::spawn(async move { let io = TokioIo::new(stream); - let service = service_fn(accept_http); + let service = service_fn(move |res| accept_http(res, addr.to_string())); let conn = http1::Builder::new() .serve_connection(io, service) .with_upgrades(); if let Err(err) = conn.await { - println!("failed to serve conn: {:?}", err); + println!("{:?}: failed to serve conn: {:?}", addr, err); } }); } @@ -67,124 +35,91 @@ async fn main() -> Result<(), Error> { Ok(()) } -async fn accept_http(mut req: Request) -> Result, Infallible> { - let incoming_uri = req.uri().clone(); - let req_ver = req.version(); - let req_headers = req.headers().clone(); - let req_key = req_headers.get(SEC_WEBSOCKET_KEY); - let derived_key = req_key.map(|k| derive_accept_key(k.as_bytes())); +async fn accept_http(mut req: Request, addr: String) -> Result, WebSocketError> { + if upgrade::is_upgrade_request(&req) { + let uri = req.uri().clone(); + let (res, fut) = upgrade::upgrade(&mut req)?; - if req.method() != Method::GET - || req.version() < Version::HTTP_11 - || !req_headers - .get(CONNECTION) - .and_then(|h| h.to_str().ok()) - .map(|h| { - h.split(|c| c == ' ' || c == ',') - .any(|p| p.eq_ignore_ascii_case("Upgrade")) - }) - .unwrap_or(false) - || !req_headers - .get(UPGRADE) - .and_then(|h| h.to_str().ok()) - .map(|h| h.eq_ignore_ascii_case("websocket")) - .unwrap_or(false) - || !req_headers.get(SEC_WEBSOCKET_VERSION).map(|h| h == "13").unwrap_or(false) - || req_key.is_none() - { - return Ok(Response::new(HttpBody::from("Hello World!"))); + tokio::spawn(async move { + if let Err(e) = tokio::task::unconstrained(accept_ws(fut, uri.path().to_string(), addr.clone())).await + { + println!("{:?}: error in ws: {:?}", addr, e); + } + }); + + Ok(res) + } else { + Ok(Response::builder() + .status(StatusCode::OK) + .body(HttpBody::new()) + .unwrap()) } - - tokio::spawn(async move { - match hyper::upgrade::on(&mut req).await { - Ok(upgraded) => { - println!("upgraded connection"); - let upgraded_io = TokioIo::new(upgraded); - accept_ws( - WebSocketStream::from_raw_socket(upgraded_io, Role::Server, None).await, - incoming_uri.path().to_string(), - ).await; - } - Err(e) => { - println!("upgrade error! {:?}", e); - } - } - }); - - println!("sending upgrade response"); - - let mut res = Response::new(HttpBody::default()); - *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; - *res.version_mut() = req_ver; - res.headers_mut() - .append(CONNECTION, HeaderValue::from_static("Upgrade")); - res.headers_mut() - .append(UPGRADE, HeaderValue::from_static("websocket")); - res.headers_mut().append(SEC_WEBSOCKET_ACCEPT, derived_key.unwrap().parse().unwrap()); - - Ok(res) } -async fn accept_ws(mut ws_stream: WebSocketStream>, incoming_uri: String) { - println!("new ws connection: {}", incoming_uri); +async fn accept_ws( + fut: upgrade::UpgradeFut, + incoming_uri: String, + addr: String +) -> Result<(), Box> { + let mut ws_stream = FragmentCollector::new(fut.await?); let mut incoming_uri_chars = incoming_uri.chars(); incoming_uri_chars.next(); - let tcp_stream = TcpStream::connect(incoming_uri_chars.as_str()) - .await - .expect("failed to connect to incoming uri"); - let (tcp_read, tcp_write) = tokio::io::split(tcp_stream); - let mut tcp_write = FramedWrite::new(tcp_write, NetworkCodec); - let mut tcp_read = FramedRead::new(tcp_read, NetworkCodec); + let tcp_stream = TcpStream::connect(incoming_uri_chars.as_str()).await?; + let mut tcp_stream_framed = Framed::new(tcp_stream, BytesCodec::new()); loop { tokio::select! { - event = ws_stream.next() => { - if let Some(Ok(payload)) = event { - print!("event ws {:?} - ", payload); - match payload { - Message::Text(txt) => { - if tcp_write.send(txt.into_bytes()).await.is_ok() { - println!("sent success"); - } else { - println!("sent FAILED"); + event = ws_stream.read_frame() => { + match event { + Ok(frame) => { + print!("{:?}: event ws - ", addr); + match frame.opcode { + OpCode::Text | OpCode::Binary => { + if tcp_stream_framed.send(Bytes::from(frame.payload.to_vec())).await.is_ok() { + println!("sent success"); + } else { + println!("sent FAILED"); + } + } + OpCode::Close => { + if as SinkExt>::close(&mut tcp_stream_framed).await.is_ok() { + println!("closed success"); + } else { + println!("closed FAILED"); + } + break; + } + _ => { + println!("ignored"); } } - Message::Binary(bin) => { - if tcp_write.send(bin).await.is_ok() { - println!("sent success"); - } else { - println!("sent FAILED"); - } - } - Message::Close(_) => { - if tcp_write.close().await.is_ok() { - println!("closed success"); - } else { - println!("closed FAILED"); - } - break; - } - _ => { - println!("ignored"); + }, + Err(err) => { + print!("{:?}: err in ws: {:?} - ", addr, err); + if as SinkExt>::close(&mut tcp_stream_framed).await.is_ok() { + println!("closed tcp success"); + } else { + println!("closed tcp FAILED"); } + break; } } }, - event = tcp_read.next() => { + event = tcp_stream_framed.next() => { if let Some(res) = event { - print!("event tcp - "); + print!("{:?}: event tcp - ", addr); match res { Ok(buf) => { - if ws_stream.send(Message::Binary(buf)).await.is_ok() { + if ws_stream.write_frame(Frame::binary(Payload::Owned(buf.to_vec()))).await.is_ok() { println!("sent success"); } else { println!("sent FAILED"); } } Err(_) => { - if ws_stream.close(None).await.is_ok() { + if ws_stream.write_frame(Frame::close(1001, b"tcp side is going away")).await.is_ok() { println!("closed success"); } else { println!("closed FAILED"); @@ -195,5 +130,8 @@ async fn accept_ws(mut ws_stream: WebSocketStream>, incoming_u } } } - println!("connection closed"); + + println!("\"{}\": connection closed", addr); + + Ok(()) }