switch to fastwebsockets

This commit is contained in:
Toshit Chawda 2024-01-05 13:23:21 -08:00
parent c4f315ca40
commit 0fa2492a56
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 141 additions and 262 deletions

193
Cargo.lock generated
View file

@ -38,6 +38,12 @@ dependencies = [
"rustc-demangle",
]
[[package]]
name = "base64"
version = "0.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "414dcefbc63d77c526a76b3afcf6fbb9b5e2791c19c3aa2297733208750c6e53"
[[package]]
name = "block-buffer"
version = "0.10.4"
@ -47,12 +53,6 @@ dependencies = [
"generic-array",
]
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "bytes"
version = "1.5.0"
@ -93,12 +93,6 @@ dependencies = [
"typenum",
]
[[package]]
name = "data-encoding"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5"
[[package]]
name = "digest"
version = "0.10.7"
@ -109,21 +103,31 @@ dependencies = [
"crypto-common",
]
[[package]]
name = "fastwebsockets"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f63dd7b57f9b33b1741fa631c9522eb35d43e96dcca4a6a91d5e4ca7c93acdc1"
dependencies = [
"base64",
"http-body-util",
"hyper",
"hyper-util",
"pin-project",
"rand",
"sha1",
"simdutf8",
"thiserror",
"tokio",
"utf-8",
]
[[package]]
name = "fnv"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "form_urlencoded"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456"
dependencies = [
"percent-encoding",
]
[[package]]
name = "futures-channel"
version = "0.3.30"
@ -210,17 +214,6 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7"
[[package]]
name = "http"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]]
name = "http"
version = "1.0.0"
@ -239,7 +232,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643"
dependencies = [
"bytes",
"http 1.0.0",
"http",
]
[[package]]
@ -250,7 +243,7 @@ checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840"
dependencies = [
"bytes",
"futures-util",
"http 1.0.0",
"http",
"http-body",
"pin-project-lite",
]
@ -276,13 +269,14 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"http 1.0.0",
"http",
"http-body",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"tokio",
"want",
]
[[package]]
@ -294,7 +288,7 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"http 1.0.0",
"http",
"http-body",
"hyper",
"pin-project-lite",
@ -303,16 +297,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "idna"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6"
dependencies = [
"unicode-bidi",
"unicode-normalization",
]
[[package]]
name = "itoa"
version = "1.0.10"
@ -325,12 +309,6 @@ version = "0.2.151"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4"
[[package]]
name = "log"
version = "0.4.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
[[package]]
name = "memchr"
version = "2.7.1"
@ -383,10 +361,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "percent-encoding"
version = "2.3.1"
name = "pin-project"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "pin-project-lite"
@ -471,6 +463,12 @@ dependencies = [
"digest",
]
[[package]]
name = "simdutf8"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a"
[[package]]
name = "slab"
version = "0.4.9"
@ -521,21 +519,6 @@ dependencies = [
"syn",
]
[[package]]
name = "tinyvec"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.35.1"
@ -564,18 +547,6 @@ dependencies = [
"syn",
]
[[package]]
name = "tokio-tungstenite"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "212d5dcb2a1ce06d81107c3d0ffa3121fe974b73f068c8282cb1c32328113b6c"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]]
name = "tokio-util"
version = "0.7.10"
@ -610,23 +581,10 @@ dependencies = [
]
[[package]]
name = "tungstenite"
version = "0.20.1"
name = "try-lock"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e3dac10fd62eaf6617d3a904ae222845979aec67c615d1c842b4002c7666fb9"
dependencies = [
"byteorder",
"bytes",
"data-encoding",
"http 0.2.11",
"httparse",
"log",
"rand",
"sha1",
"thiserror",
"url",
"utf-8",
]
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "typenum"
@ -634,38 +592,12 @@ version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]]
name = "unicode-bidi"
version = "0.3.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f2528f27a9eb2b21e69c95319b30bd0efd85d09c379741b0f78ea1d86be2416"
[[package]]
name = "unicode-ident"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]]
name = "unicode-normalization"
version = "0.1.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921"
dependencies = [
"tinyvec",
]
[[package]]
name = "url"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633"
dependencies = [
"form_urlencoded",
"idna",
"percent-encoding",
]
[[package]]
name = "utf-8"
version = "0.7.6"
@ -678,6 +610,15 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "want"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e"
dependencies = [
"try-lock",
]
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
@ -751,15 +692,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
[[package]]
name = "wsproxy-rust"
name = "wsfetch-server"
version = "0.1.0"
dependencies = [
"bytes",
"fastwebsockets",
"futures-util",
"http-body-util",
"hyper",
"hyper-util",
"tokio",
"tokio-tungstenite",
"tokio-util",
]

View file

@ -1,5 +1,5 @@
[package]
name = "wsproxy-rust"
name = "wsfetch-server"
version = "0.1.0"
edition = "2021"
@ -7,10 +7,10 @@ edition = "2021"
[dependencies]
bytes = "1.5.0"
futures-util = "0.3.30"
fastwebsockets = { version = "0.6.0", features = ["upgrade", "simdutf8"] }
futures-util = { version = "0.3.30", features = ["sink"] }
http-body-util = "0.1.0"
hyper = { version = "1.1.0", features = ["server", "http1"] }
hyper-util = { version = "0.1.2", features = ["tokio"] }
tokio = { version = "1.5.1", features = ["rt-multi-thread", "macros"] }
tokio-tungstenite = "0.20.1"
tokio-util = { version = "0.7.10", features = ["codec"] }

View file

@ -1,47 +1,16 @@
use std::{convert::Infallible, io::Error};
use std::io::Error;
use bytes::{BufMut, BytesMut};
use bytes::Bytes;
use fastwebsockets::{upgrade, FragmentCollector, Frame, OpCode, Payload, WebSocketError};
use futures_util::{SinkExt, StreamExt};
use hyper::{
body::Incoming,
header::{
HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION,
UPGRADE,
},
server::conn::http1,
service::service_fn,
upgrade::Upgraded,
Method, Request, Response, StatusCode, Version
body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode,
};
use hyper_util::rt::TokioIo;
use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::{
tungstenite::{protocol::Role, handshake::derive_accept_key, Message},
WebSocketStream,
};
use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite};
use tokio_util::codec::{BytesCodec, Framed};
struct NetworkCodec;
impl Encoder<Vec<u8>> for NetworkCodec {
type Error = std::io::Error;
fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), Self::Error> {
dst.put_slice(item.as_slice());
Ok(())
}
}
impl Decoder for NetworkCodec {
type Item = Vec<u8>;
type Error = std::io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
Ok(Some(src.to_vec()))
}
}
type HttpBody = http_body_util::Full<hyper::body::Bytes>;
type HttpBody = http_body_util::Empty<hyper::body::Bytes>;
#[tokio::main(flavor = "multi_thread", worker_threads = 32)]
async fn main() -> Result<(), Error> {
@ -51,15 +20,14 @@ async fn main() -> Result<(), Error> {
println!("listening on 0.0.0.0:4000");
while let Ok((stream, addr)) = socket.accept().await {
println!("socket connected: {:?}", addr);
tokio::spawn(async move {
let io = TokioIo::new(stream);
let service = service_fn(accept_http);
let service = service_fn(move |res| accept_http(res, addr.to_string()));
let conn = http1::Builder::new()
.serve_connection(io, service)
.with_upgrades();
if let Err(err) = conn.await {
println!("failed to serve conn: {:?}", err);
println!("{:?}: failed to serve conn: {:?}", addr, err);
}
});
}
@ -67,124 +35,91 @@ async fn main() -> Result<(), Error> {
Ok(())
}
async fn accept_http(mut req: Request<Incoming>) -> Result<Response<HttpBody>, Infallible> {
let incoming_uri = req.uri().clone();
let req_ver = req.version();
let req_headers = req.headers().clone();
let req_key = req_headers.get(SEC_WEBSOCKET_KEY);
let derived_key = req_key.map(|k| derive_accept_key(k.as_bytes()));
async fn accept_http(mut req: Request<Incoming>, addr: String) -> Result<Response<HttpBody>, WebSocketError> {
if upgrade::is_upgrade_request(&req) {
let uri = req.uri().clone();
let (res, fut) = upgrade::upgrade(&mut req)?;
if req.method() != Method::GET
|| req.version() < Version::HTTP_11
|| !req_headers
.get(CONNECTION)
.and_then(|h| h.to_str().ok())
.map(|h| {
h.split(|c| c == ' ' || c == ',')
.any(|p| p.eq_ignore_ascii_case("Upgrade"))
})
.unwrap_or(false)
|| !req_headers
.get(UPGRADE)
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
|| !req_headers.get(SEC_WEBSOCKET_VERSION).map(|h| h == "13").unwrap_or(false)
|| req_key.is_none()
{
return Ok(Response::new(HttpBody::from("Hello World!")));
tokio::spawn(async move {
if let Err(e) = tokio::task::unconstrained(accept_ws(fut, uri.path().to_string(), addr.clone())).await
{
println!("{:?}: error in ws: {:?}", addr, e);
}
});
Ok(res)
} else {
Ok(Response::builder()
.status(StatusCode::OK)
.body(HttpBody::new())
.unwrap())
}
tokio::spawn(async move {
match hyper::upgrade::on(&mut req).await {
Ok(upgraded) => {
println!("upgraded connection");
let upgraded_io = TokioIo::new(upgraded);
accept_ws(
WebSocketStream::from_raw_socket(upgraded_io, Role::Server, None).await,
incoming_uri.path().to_string(),
).await;
}
Err(e) => {
println!("upgrade error! {:?}", e);
}
}
});
println!("sending upgrade response");
let mut res = Response::new(HttpBody::default());
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
*res.version_mut() = req_ver;
res.headers_mut()
.append(CONNECTION, HeaderValue::from_static("Upgrade"));
res.headers_mut()
.append(UPGRADE, HeaderValue::from_static("websocket"));
res.headers_mut().append(SEC_WEBSOCKET_ACCEPT, derived_key.unwrap().parse().unwrap());
Ok(res)
}
async fn accept_ws(mut ws_stream: WebSocketStream<TokioIo<Upgraded>>, incoming_uri: String) {
println!("new ws connection: {}", incoming_uri);
async fn accept_ws(
fut: upgrade::UpgradeFut,
incoming_uri: String,
addr: String
) -> Result<(), Box<dyn std::error::Error>> {
let mut ws_stream = FragmentCollector::new(fut.await?);
let mut incoming_uri_chars = incoming_uri.chars();
incoming_uri_chars.next();
let tcp_stream = TcpStream::connect(incoming_uri_chars.as_str())
.await
.expect("failed to connect to incoming uri");
let (tcp_read, tcp_write) = tokio::io::split(tcp_stream);
let mut tcp_write = FramedWrite::new(tcp_write, NetworkCodec);
let mut tcp_read = FramedRead::new(tcp_read, NetworkCodec);
let tcp_stream = TcpStream::connect(incoming_uri_chars.as_str()).await?;
let mut tcp_stream_framed = Framed::new(tcp_stream, BytesCodec::new());
loop {
tokio::select! {
event = ws_stream.next() => {
if let Some(Ok(payload)) = event {
print!("event ws {:?} - ", payload);
match payload {
Message::Text(txt) => {
if tcp_write.send(txt.into_bytes()).await.is_ok() {
println!("sent success");
} else {
println!("sent FAILED");
event = ws_stream.read_frame() => {
match event {
Ok(frame) => {
print!("{:?}: event ws - ", addr);
match frame.opcode {
OpCode::Text | OpCode::Binary => {
if tcp_stream_framed.send(Bytes::from(frame.payload.to_vec())).await.is_ok() {
println!("sent success");
} else {
println!("sent FAILED");
}
}
OpCode::Close => {
if <Framed<tokio::net::TcpStream, BytesCodec> as SinkExt<Bytes>>::close(&mut tcp_stream_framed).await.is_ok() {
println!("closed success");
} else {
println!("closed FAILED");
}
break;
}
_ => {
println!("ignored");
}
}
Message::Binary(bin) => {
if tcp_write.send(bin).await.is_ok() {
println!("sent success");
} else {
println!("sent FAILED");
}
}
Message::Close(_) => {
if tcp_write.close().await.is_ok() {
println!("closed success");
} else {
println!("closed FAILED");
}
break;
}
_ => {
println!("ignored");
},
Err(err) => {
print!("{:?}: err in ws: {:?} - ", addr, err);
if <Framed<tokio::net::TcpStream, BytesCodec> as SinkExt<Bytes>>::close(&mut tcp_stream_framed).await.is_ok() {
println!("closed tcp success");
} else {
println!("closed tcp FAILED");
}
break;
}
}
},
event = tcp_read.next() => {
event = tcp_stream_framed.next() => {
if let Some(res) = event {
print!("event tcp - ");
print!("{:?}: event tcp - ", addr);
match res {
Ok(buf) => {
if ws_stream.send(Message::Binary(buf)).await.is_ok() {
if ws_stream.write_frame(Frame::binary(Payload::Owned(buf.to_vec()))).await.is_ok() {
println!("sent success");
} else {
println!("sent FAILED");
}
}
Err(_) => {
if ws_stream.close(None).await.is_ok() {
if ws_stream.write_frame(Frame::close(1001, b"tcp side is going away")).await.is_ok() {
println!("closed success");
} else {
println!("closed FAILED");
@ -195,5 +130,8 @@ async fn accept_ws(mut ws_stream: WebSocketStream<TokioIo<Upgraded>>, incoming_u
}
}
}
println!("connection closed");
println!("\"{}\": connection closed", addr);
Ok(())
}