diff --git a/Cargo.lock b/Cargo.lock index 1db3161..0b35945 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1530,6 +1530,7 @@ name = "wisp-mux" version = "0.1.0" dependencies = [ "bytes", + "fastwebsockets", "futures", "futures-util", ] diff --git a/server/Cargo.toml b/server/Cargo.toml index 11b0915..75d8355 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -14,4 +14,4 @@ hyper-util = { version = "0.1.2", features = ["tokio"] } tokio = { version = "1.5.1", features = ["rt-multi-thread", "macros"] } tokio-native-tls = "0.3.1" tokio-util = { version = "0.7.10", features = ["codec"] } -wisp-mux = { path = "../wisp" } +wisp-mux = { path = "../wisp", features = ["fastwebsockets"] } diff --git a/server/src/lockedws.rs b/server/src/lockedws.rs index 53c7fb8..7cf1822 100644 --- a/server/src/lockedws.rs +++ b/server/src/lockedws.rs @@ -6,6 +6,7 @@ use tokio::sync::Mutex; type Ws = FragmentCollector>; +#[derive(Clone)] pub struct LockedWebSocket(Arc>); impl LockedWebSocket { diff --git a/server/src/main.rs b/server/src/main.rs index ef58f78..3906cf5 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -12,8 +12,10 @@ use hyper::{ Response, StatusCode, }; use hyper_util::rt::TokioIo; -use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::mpsc; +use tokio::{ + net::{TcpListener, TcpStream}, + sync::mpsc, +}; use tokio_native_tls::{native_tls, TlsAcceptor}; use tokio_util::codec::{BytesCodec, Framed}; @@ -120,35 +122,120 @@ async fn accept_ws( fut: upgrade::UpgradeFut, addr: String, ) -> Result<(), Box> { - let ws_stream = lockedws::LockedWebSocket::new(FragmentCollector::new(fut.await?)); + let ws_stream = FragmentCollector::new(fut.await?); + let ws_stream = lockedws::LockedWebSocket::new(ws_stream); let stream_map = Arc::new(dashmap::DashMap::>::new()); println!("{:?}: connected", addr); - while let Ok(mut frame) = ws_stream.read_frame().await { - use fastwebsockets::OpCode::*; - let frame = match frame.opcode { - Continuation => unreachable!(), - Text => ws::Frame::text(Bytes::copy_from_slice(frame.payload.to_mut())), - Binary => ws::Frame::binary(Bytes::copy_from_slice(frame.payload.to_mut())), - Close => ws::Frame::close(Bytes::copy_from_slice(frame.payload.to_mut())), - Ping => continue, - Pong => continue, - }; - if let Ok(packet) = Packet::try_from(frame) { + ws_stream + .write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)).into()) + .await?; + + while let Ok(frame) = ws_stream.read_frame().await { + if let Ok(packet) = Packet::try_from(ws::Frame::try_from(frame)?) { use PacketType::*; match packet.packet { Connect(inner_packet) => { - let (ch_tx, ch_rx) = mpsc::unbounded_channel::(); + let (ch_tx, mut ch_rx) = mpsc::unbounded_channel::(); stream_map.clone().insert(packet.stream_id, ch_tx); + let ws_stream_cloned = ws_stream.clone(); tokio::spawn(async move { - + let tcp_stream = match TcpStream::connect(format!( + "{}:{}", + inner_packet.destination_hostname, inner_packet.destination_port + )) + .await + { + Ok(stream) => stream, + Err(err) => { + ws_stream_cloned + .write_frame( + ws::Frame::from(Packet::new_close(packet.stream_id, 0x03)) + .into(), + ) + .await + .map_err(std::io::Error::other)?; + return Err(Box::new(err)); + } + }; + println!("muxing"); + let mut tcp_stream = Framed::new(tcp_stream, BytesCodec::new()); + loop { + tokio::select! { + event = tcp_stream.next() => { + println!("recvd"); + if let Some(res) = event { + match res { + Ok(buf) => { + ws_stream_cloned.write_frame( + ws::Frame::from( + Packet::new_data( + packet.stream_id, + buf.to_vec() + ) + ).into() + ).await.map_err(std::io::Error::other)?; + } + Err(err) => { + ws_stream_cloned + .write_frame( + ws::Frame::from(Packet::new_close( + packet.stream_id, + 0x03, + )) + .into(), + ) + .await + .map_err(std::io::Error::other)?; + return Err(Box::new(err)); + } + } + } + } + event = ch_rx.recv() => { + if let Some(event) = event { + match event { + WsEvent::Send(buf) => { + tcp_stream.send(buf).await?; + println!("sending"); + ws_stream_cloned + .write_frame( + ws::Frame::from( + Packet::new_continue( + packet.stream_id, + u32::MAX + ) + ).into() + ).await.map_err(std::io::Error::other)?; + println!("sent"); + } + WsEvent::Close => { + break; + } + } + } else { + break; + } + } + }; + } + Ok(()) }); } - Data(inner_packet) => {} + Data(inner_packet) => { + println!("recieved data for {:?}", packet.stream_id); + if let Some(stream) = stream_map.clone().get(&packet.stream_id) { + let _ = stream.send(WsEvent::Send(inner_packet.into())); + } + } Continue(_) => unreachable!(), - Close(inner_packet) => {} + Close(_) => { + if let Some(stream) = stream_map.clone().get(&packet.stream_id) { + let _ = stream.send(WsEvent::Close); + } + } } } } diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index a660280..693c91e 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -3,9 +3,11 @@ name = "wisp-mux" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] bytes = "1.5.0" +fastwebsockets = { version = "0.6.0", optional = true } futures = "0.3.30" futures-util = "0.3.30" + +[features] +fastwebsockets = ["dep:fastwebsockets"] diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs new file mode 100644 index 0000000..ba499ac --- /dev/null +++ b/wisp/src/fastwebsockets.rs @@ -0,0 +1,40 @@ +use bytes::Bytes; +use fastwebsockets::{Payload, Frame, OpCode}; + +impl TryFrom for crate::ws::OpCode { + type Error = crate::WispError; + fn try_from(opcode: OpCode) -> Result { + use OpCode::*; + match opcode { + Continuation => Err(Self::Error::WsImplNotSupported), + Text => Ok(Self::Text), + Binary => Ok(Self::Binary), + Close => Ok(Self::Close), + Ping => Err(Self::Error::WsImplNotSupported), + Pong => Err(Self::Error::WsImplNotSupported), + } + } +} + +impl TryFrom> for crate::ws::Frame { + type Error = crate::WispError; + fn try_from(mut frame: Frame) -> Result { + let opcode = frame.opcode.try_into()?; + Ok(Self { + finished: frame.fin, + opcode, + payload: Bytes::copy_from_slice(frame.payload.to_mut()), + }) + } +} + +impl From for Frame<'_> { + fn from(frame: crate::ws::Frame) -> Self { + use crate::ws::OpCode::*; + match frame.opcode { + Text => Self::text(Payload::Owned(frame.payload.to_vec())), + Binary => Self::binary(Payload::Owned(frame.payload.to_vec())), + Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())) + } + } +} diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index d8ade1c..897b7de 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "fastwebsockets")] +mod fastwebsockets; mod packet; pub mod ws; @@ -9,12 +11,14 @@ pub enum Role { Server, } +#[derive(Debug)] pub enum WispError { PacketTooSmall, InvalidPacketType, WsFrameInvalidType, WsFrameNotFinished, WsImplError(Box), + WsImplNotSupported, Utf8Error(std::str::Utf8Error), } @@ -23,3 +27,20 @@ impl From for WispError { WispError::Utf8Error(err) } } + +impl std::fmt::Display for WispError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + use WispError::*; + match self { + PacketTooSmall => write!(f, "Packet too small"), + InvalidPacketType => write!(f, "Invalid packet type"), + WsFrameInvalidType => write!(f, "Invalid websocket frame type"), + WsFrameNotFinished => write!(f, "Unfinished websocket frame"), + WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err), + WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"), + Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err), + } + } +} + +impl std::error::Error for WispError {} diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 2a9667f..1c5f177 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -4,9 +4,9 @@ use bytes::{Buf, BufMut, Bytes}; #[derive(Debug)] pub struct ConnectPacket { - stream_type: u8, - destination_port: u16, - destination_hostname: String, + pub stream_type: u8, + pub destination_port: u16, + pub destination_hostname: String, } impl ConnectPacket {