From 1f23c26db6ba43442f5326be03a252532585aadb Mon Sep 17 00:00:00 2001 From: r58Playz Date: Mon, 22 Jan 2024 18:19:39 -0800 Subject: [PATCH] wisp part 1 --- Cargo.lock | 20 ++++++++ server/Cargo.toml | 1 + server/src/lockedws.rs | 23 +++++++++ server/src/main.rs | 113 +++++++++++++++++++++++++---------------- wisp/src/lib.rs | 2 +- wisp/src/packet.rs | 4 +- 6 files changed, 117 insertions(+), 46 deletions(-) create mode 100644 server/src/lockedws.rs diff --git a/Cargo.lock b/Cargo.lock index 3db0195..1db3161 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -214,6 +214,19 @@ dependencies = [ "typenum", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "digest" version = "0.10.7" @@ -266,6 +279,7 @@ name = "epoxy-server" version = "1.0.0" dependencies = [ "bytes", + "dashmap", "fastwebsockets", "futures-util", "http-body-util", @@ -461,6 +475,12 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "hashbrown" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" + [[package]] name = "hermit-abi" version = "0.3.3" diff --git a/server/Cargo.toml b/server/Cargo.toml index c0a4a10..11b0915 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] bytes = "1.5.0" +dashmap = "5.5.3" fastwebsockets = { version = "0.6.0", features = ["upgrade", "simdutf8"] } futures-util = { version = "0.3.30", features = ["sink"] } http-body-util = "0.1.0" diff --git a/server/src/lockedws.rs b/server/src/lockedws.rs new file mode 100644 index 0000000..53c7fb8 --- /dev/null +++ b/server/src/lockedws.rs @@ -0,0 +1,23 @@ +use fastwebsockets::{FragmentCollector, Frame, WebSocketError}; +use hyper::upgrade::Upgraded; +use hyper_util::rt::TokioIo; +use std::sync::Arc; +use tokio::sync::Mutex; + +type Ws = FragmentCollector>; + +pub struct LockedWebSocket(Arc>); + +impl LockedWebSocket { + pub fn new(ws: Ws) -> Self { + Self(Arc::new(Mutex::new(ws))) + } + + pub async fn read_frame(&self) -> Result { + self.0.lock().await.read_frame().await + } + + pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WebSocketError> { + self.0.lock().await.write_frame(frame).await + } +} diff --git a/server/src/main.rs b/server/src/main.rs index 6318929..ef58f78 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,4 +1,6 @@ -use std::io::Error; +mod lockedws; + +use std::{io::Error, sync::Arc}; use bytes::Bytes; use fastwebsockets::{ @@ -11,9 +13,12 @@ use hyper::{ }; use hyper_util::rt::TokioIo; use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::mpsc; use tokio_native_tls::{native_tls, TlsAcceptor}; use tokio_util::codec::{BytesCodec, Framed}; +use wisp_mux::{ws, Packet, PacketType}; + type HttpBody = http_body_util::Empty; #[tokio::main(flavor = "multi_thread")] @@ -47,7 +52,8 @@ async fn main() -> Result<(), Error> { tokio::spawn(async move { let stream = acceptor_cloned.accept(stream).await.expect("not tls"); let io = TokioIo::new(stream); - let service = service_fn(move |res| accept_http(res, addr.to_string(), prefix_cloned.clone())); + let service = + service_fn(move |res| accept_http(res, addr.to_string(), prefix_cloned.clone())); let conn = http1::Builder::new() .serve_connection(io, service) .with_upgrades(); @@ -72,10 +78,13 @@ async fn accept_http( tokio::spawn(async move { if *uri.path() != prefix { if let Err(e) = - accept_wsproxy(fut, uri.path().to_string(), addr.clone(), prefix).await + accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone()) + .await { println!("{:?}: error in ws handling: {:?}", addr, e); } + } else if let Err(e) = accept_ws(fut, addr.clone()).await { + println!("{:?}: error in ws handling: {:?}", addr, e); } }); @@ -102,18 +111,60 @@ async fn accept_http( } } +enum WsEvent { + Send(Bytes), + Close, +} + +async fn accept_ws( + fut: upgrade::UpgradeFut, + addr: String, +) -> Result<(), Box> { + let ws_stream = lockedws::LockedWebSocket::new(FragmentCollector::new(fut.await?)); + + let stream_map = Arc::new(dashmap::DashMap::>::new()); + + println!("{:?}: connected", addr); + + while let Ok(mut frame) = ws_stream.read_frame().await { + use fastwebsockets::OpCode::*; + let frame = match frame.opcode { + Continuation => unreachable!(), + Text => ws::Frame::text(Bytes::copy_from_slice(frame.payload.to_mut())), + Binary => ws::Frame::binary(Bytes::copy_from_slice(frame.payload.to_mut())), + Close => ws::Frame::close(Bytes::copy_from_slice(frame.payload.to_mut())), + Ping => continue, + Pong => continue, + }; + if let Ok(packet) = Packet::try_from(frame) { + use PacketType::*; + match packet.packet { + Connect(inner_packet) => { + let (ch_tx, ch_rx) = mpsc::unbounded_channel::(); + stream_map.clone().insert(packet.stream_id, ch_tx); + tokio::spawn(async move { + + }); + } + Data(inner_packet) => {} + Continue(_) => unreachable!(), + Close(inner_packet) => {} + } + } + } + + println!("{:?}: disconnected", addr); + Ok(()) +} + async fn accept_wsproxy( fut: upgrade::UpgradeFut, - incoming_uri: String, + incoming_uri: &str, addr: String, - prefix: String, ) -> Result<(), Box> { let mut ws_stream = FragmentCollector::new(fut.await?); - // should always have prefix - let incoming_uri = incoming_uri.strip_prefix(&prefix).unwrap(); - - println!("{:?}: connected", addr); + println!("{:?}: connected (wsproxy)", addr); let tcp_stream = match TcpStream::connect(incoming_uri).await { Ok(stream) => stream, @@ -132,56 +183,33 @@ async fn accept_wsproxy( event = ws_stream.read_frame() => { match event { Ok(frame) => { - print!("{:?}: event ws - ", addr); match frame.opcode { OpCode::Text | OpCode::Binary => { - if tcp_stream_framed.send(Bytes::from(frame.payload.to_vec())).await.is_ok() { - println!("sent success"); - } else { - println!("sent FAILED"); - } + let _ = tcp_stream_framed.send(Bytes::from(frame.payload.to_vec())).await; } OpCode::Close => { - if as SinkExt>::close(&mut tcp_stream_framed).await.is_ok() { - println!("closed success"); - } else { - println!("closed FAILED"); - } + // tokio closes the stream for us + drop(tcp_stream_framed); break; } - _ => { - println!("ignored"); - } + _ => {} } }, - Err(err) => { - print!("{:?}: err in ws: {:?} - ", addr, err); - if as SinkExt>::close(&mut tcp_stream_framed).await.is_ok() { - println!("closed tcp success"); - } else { - println!("closed tcp FAILED"); - } + Err(_) => { + // tokio closes the stream for us + drop(tcp_stream_framed); break; } } }, event = tcp_stream_framed.next() => { if let Some(res) = event { - print!("{:?}: event tcp - ", addr); match res { Ok(buf) => { - if ws_stream.write_frame(Frame::binary(Payload::Owned(buf.to_vec()))).await.is_ok() { - println!("sent success"); - } else { - println!("sent FAILED"); - } + let _ = ws_stream.write_frame(Frame::binary(Payload::Borrowed(&buf))).await; } Err(_) => { - if ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"tcp side is going away")).await.is_ok() { - println!("closed success"); - } else { - println!("closed FAILED"); - } + let _ = ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"tcp side is going away")).await; } } } @@ -189,8 +217,7 @@ async fn accept_wsproxy( } } - println!("\"{}\": connection closed", addr); + println!("{:?}: disconnected (wsproxy)", addr); Ok(()) } - diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index c8147ae..d8ade1c 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -1,5 +1,5 @@ mod packet; -mod ws; +pub mod ws; pub use crate::packet::*; diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index b1091b8..2a9667f 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -139,8 +139,8 @@ impl From for Vec { #[derive(Debug)] pub struct Packet { - stream_id: u32, - packet: PacketType, + pub stream_id: u32, + pub packet: PacketType, } impl Packet {