diff --git a/server/Cargo.toml b/server/Cargo.toml index 75d8355..a0a64c1 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] bytes = "1.5.0" dashmap = "5.5.3" -fastwebsockets = { version = "0.6.0", features = ["upgrade", "simdutf8"] } +fastwebsockets = { version = "0.6.0", features = ["upgrade", "simdutf8", "unstable-split"] } futures-util = { version = "0.3.30", features = ["sink"] } http-body-util = "0.1.0" hyper = { version = "1.1.0", features = ["server", "http1"] } diff --git a/server/src/lockedws.rs b/server/src/lockedws.rs index 7cf1822..ffc8e13 100644 --- a/server/src/lockedws.rs +++ b/server/src/lockedws.rs @@ -1,23 +1,19 @@ -use fastwebsockets::{FragmentCollector, Frame, WebSocketError}; +use fastwebsockets::{WebSocketWrite, Frame, WebSocketError}; use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; use std::sync::Arc; -use tokio::sync::Mutex; +use tokio::{io::WriteHalf, sync::Mutex}; -type Ws = FragmentCollector>; +type Ws = WebSocketWrite>>; #[derive(Clone)] -pub struct LockedWebSocket(Arc>); +pub struct LockedWebSocketWrite(Arc>); -impl LockedWebSocket { +impl LockedWebSocketWrite { pub fn new(ws: Ws) -> Self { Self(Arc::new(Mutex::new(ws))) } - pub async fn read_frame(&self) -> Result { - self.0.lock().await.read_frame().await - } - pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WebSocketError> { self.0.lock().await.write_frame(frame).await } diff --git a/server/src/main.rs b/server/src/main.rs index 3906cf5..96e73c3 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -122,25 +122,24 @@ async fn accept_ws( fut: upgrade::UpgradeFut, addr: String, ) -> Result<(), Box> { - let ws_stream = FragmentCollector::new(fut.await?); - let ws_stream = lockedws::LockedWebSocket::new(ws_stream); + let (mut rx, tx) = fut.await?.split(tokio::io::split); + let tx = lockedws::LockedWebSocketWrite::new(tx); let stream_map = Arc::new(dashmap::DashMap::>::new()); println!("{:?}: connected", addr); - ws_stream - .write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)).into()) + tx.write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)).into()) .await?; - while let Ok(frame) = ws_stream.read_frame().await { + while let Ok(frame) = rx.read_frame(&mut |x| tx.write_frame(x)).await { if let Ok(packet) = Packet::try_from(ws::Frame::try_from(frame)?) { use PacketType::*; match packet.packet { Connect(inner_packet) => { let (ch_tx, mut ch_rx) = mpsc::unbounded_channel::(); stream_map.clone().insert(packet.stream_id, ch_tx); - let ws_stream_cloned = ws_stream.clone(); + let tx_cloned = tx.clone(); tokio::spawn(async move { let tcp_stream = match TcpStream::connect(format!( "{}:{}", @@ -150,7 +149,7 @@ async fn accept_ws( { Ok(stream) => stream, Err(err) => { - ws_stream_cloned + tx_cloned .write_frame( ws::Frame::from(Packet::new_close(packet.stream_id, 0x03)) .into(), @@ -160,16 +159,14 @@ async fn accept_ws( return Err(Box::new(err)); } }; - println!("muxing"); let mut tcp_stream = Framed::new(tcp_stream, BytesCodec::new()); loop { tokio::select! { event = tcp_stream.next() => { - println!("recvd"); if let Some(res) = event { match res { Ok(buf) => { - ws_stream_cloned.write_frame( + tx_cloned.write_frame( ws::Frame::from( Packet::new_data( packet.stream_id, @@ -179,7 +176,7 @@ async fn accept_ws( ).await.map_err(std::io::Error::other)?; } Err(err) => { - ws_stream_cloned + tx_cloned .write_frame( ws::Frame::from(Packet::new_close( packet.stream_id, @@ -199,8 +196,7 @@ async fn accept_ws( match event { WsEvent::Send(buf) => { tcp_stream.send(buf).await?; - println!("sending"); - ws_stream_cloned + tx_cloned .write_frame( ws::Frame::from( Packet::new_continue( @@ -209,7 +205,6 @@ async fn accept_ws( ) ).into() ).await.map_err(std::io::Error::other)?; - println!("sent"); } WsEvent::Close => { break;