diff --git a/Cargo.lock b/Cargo.lock index 0b35945..ad0128c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1530,9 +1530,11 @@ name = "wisp-mux" version = "0.1.0" dependencies = [ "bytes", + "dashmap", "fastwebsockets", "futures", "futures-util", + "tokio", ] [[package]] diff --git a/server/src/lockedws.rs b/server/src/lockedws.rs deleted file mode 100644 index ffc8e13..0000000 --- a/server/src/lockedws.rs +++ /dev/null @@ -1,20 +0,0 @@ -use fastwebsockets::{WebSocketWrite, Frame, WebSocketError}; -use hyper::upgrade::Upgraded; -use hyper_util::rt::TokioIo; -use std::sync::Arc; -use tokio::{io::WriteHalf, sync::Mutex}; - -type Ws = WebSocketWrite>>; - -#[derive(Clone)] -pub struct LockedWebSocketWrite(Arc>); - -impl LockedWebSocketWrite { - pub fn new(ws: Ws) -> Self { - Self(Arc::new(Mutex::new(ws))) - } - - pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WebSocketError> { - self.0.lock().await.write_frame(frame).await - } -} diff --git a/server/src/main.rs b/server/src/main.rs index 96e73c3..4dcbf0f 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,25 +1,22 @@ -mod lockedws; - -use std::{io::Error, sync::Arc}; +#![feature(let_chains)] +use std::io::Error; use bytes::Bytes; use fastwebsockets::{ - upgrade, CloseCode, FragmentCollector, Frame, OpCode, Payload, WebSocketError, + upgrade, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, + WebSocketError, }; -use futures_util::{SinkExt, StreamExt}; +use futures_util::{SinkExt, StreamExt, TryFutureExt}; use hyper::{ body::Incoming, header::HeaderValue, server::conn::http1, service::service_fn, Request, Response, StatusCode, }; use hyper_util::rt::TokioIo; -use tokio::{ - net::{TcpListener, TcpStream}, - sync::mpsc, -}; +use tokio::net::{TcpListener, TcpStream}; use tokio_native_tls::{native_tls, TlsAcceptor}; use tokio_util::codec::{BytesCodec, Framed}; -use wisp_mux::{ws, Packet, PacketType}; +use wisp_mux::{ws, ConnectPacket, MuxStream, Packet, ServerMux, StreamType, WispError, WsEvent}; type HttpBody = http_body_util::Empty; @@ -73,37 +70,26 @@ async fn accept_http( addr: String, prefix: String, ) -> Result, WebSocketError> { - if upgrade::is_upgrade_request(&req) && req.uri().path().to_string().starts_with(&prefix) { + if upgrade::is_upgrade_request(&req) + && req.uri().path().to_string().starts_with(&prefix) + && let Some(protocol) = req.headers().get("Sec-Websocket-Protocol") + && protocol == "wisp-v1" + { let uri = req.uri().clone(); let (mut res, fut) = upgrade::upgrade(&mut req)?; - tokio::spawn(async move { - if *uri.path() != prefix { - if let Err(e) = - accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone()) - .await - { - println!("{:?}: error in ws handling: {:?}", addr, e); - } - } else if let Err(e) = accept_ws(fut, addr.clone()).await { - println!("{:?}: error in ws handling: {:?}", addr, e); - } - }); - - if let Some(protocol) = req.headers().get("Sec-Websocket-Protocol") { - let first_protocol = protocol - .to_str() - .expect("failed to get protocol") - .split(',') - .next() - .expect("failed to get first protocol") - .trim(); - res.headers_mut().insert( - "Sec-Websocket-Protocol", - HeaderValue::from_str(first_protocol).unwrap(), - ); + if *uri.path() != prefix { + tokio::spawn(async move { + accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone()).await + }); + } else { + tokio::spawn(async move { accept_ws(fut, addr.clone()).await }); } + res.headers_mut().insert( + "Sec-Websocket-Protocol", + HeaderValue::from_str("wisp-v1").unwrap(), + ); Ok(res) } else { Ok(Response::builder() @@ -113,127 +99,80 @@ async fn accept_http( } } -enum WsEvent { - Send(Bytes), - Close, +async fn handle_mux( + packet: ConnectPacket, + mut stream: MuxStream, +) -> Result<(), WispError> { + let uri = format!( + "{}:{}", + packet.destination_hostname, packet.destination_port + ); + match packet.stream_type { + StreamType::Tcp => { + let tcp_stream = TcpStream::connect(uri) + .await + .map_err(|x| WispError::Other(Box::new(x)))?; + let mut tcp_stream_framed = Framed::new(tcp_stream, BytesCodec::new()); + + loop { + tokio::select! { + event = stream.read() => { + match event { + Some(event) => match event { + WsEvent::Send(data) => { + tcp_stream_framed.send(data).await.map_err(|x| WispError::Other(Box::new(x)))?; + } + WsEvent::Close(_) => break, + }, + None => break + } + }, + event = tcp_stream_framed.next() => { + match event.and_then(|x| x.ok()) { + Some(event) => stream.write(event.into()).await?, + None => break + } + } + } + } + } + StreamType::Udp => todo!(), + } + Ok(()) } async fn accept_ws( fut: upgrade::UpgradeFut, addr: String, -) -> Result<(), Box> { - let (mut rx, tx) = fut.await?.split(tokio::io::split); - let tx = lockedws::LockedWebSocketWrite::new(tx); - - let stream_map = Arc::new(dashmap::DashMap::>::new()); +) -> Result<(), Box> { + let (rx, tx) = fut.await?.split(tokio::io::split); + let rx = FragmentCollectorRead::new(rx); println!("{:?}: connected", addr); - tx.write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)).into()) - .await?; + let mut mux = ServerMux::new(rx, tx); - while let Ok(frame) = rx.read_frame(&mut |x| tx.write_frame(x)).await { - if let Ok(packet) = Packet::try_from(ws::Frame::try_from(frame)?) { - use PacketType::*; - match packet.packet { - Connect(inner_packet) => { - let (ch_tx, mut ch_rx) = mpsc::unbounded_channel::(); - stream_map.clone().insert(packet.stream_id, ch_tx); - let tx_cloned = tx.clone(); - tokio::spawn(async move { - let tcp_stream = match TcpStream::connect(format!( - "{}:{}", - inner_packet.destination_hostname, inner_packet.destination_port - )) + mux.server_loop(&mut |packet, stream| async move { + let tx_cloned = stream.get_write_half(); + let stream_id = stream.stream_id; + tokio::spawn(async move { + let _ = handle_mux(packet, stream) + .or_else(|err| async { + let _ = tx_cloned + .write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x03))) + .await; + Err(err) + }) + .and_then(|_| async { + tx_cloned + .write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x02))) .await - { - Ok(stream) => stream, - Err(err) => { - tx_cloned - .write_frame( - ws::Frame::from(Packet::new_close(packet.stream_id, 0x03)) - .into(), - ) - .await - .map_err(std::io::Error::other)?; - return Err(Box::new(err)); - } - }; - let mut tcp_stream = Framed::new(tcp_stream, BytesCodec::new()); - loop { - tokio::select! { - event = tcp_stream.next() => { - if let Some(res) = event { - match res { - Ok(buf) => { - tx_cloned.write_frame( - ws::Frame::from( - Packet::new_data( - packet.stream_id, - buf.to_vec() - ) - ).into() - ).await.map_err(std::io::Error::other)?; - } - Err(err) => { - tx_cloned - .write_frame( - ws::Frame::from(Packet::new_close( - packet.stream_id, - 0x03, - )) - .into(), - ) - .await - .map_err(std::io::Error::other)?; - return Err(Box::new(err)); - } - } - } - } - event = ch_rx.recv() => { - if let Some(event) = event { - match event { - WsEvent::Send(buf) => { - tcp_stream.send(buf).await?; - tx_cloned - .write_frame( - ws::Frame::from( - Packet::new_continue( - packet.stream_id, - u32::MAX - ) - ).into() - ).await.map_err(std::io::Error::other)?; - } - WsEvent::Close => { - break; - } - } - } else { - break; - } - } - }; - } - Ok(()) - }); - } - Data(inner_packet) => { - println!("recieved data for {:?}", packet.stream_id); - if let Some(stream) = stream_map.clone().get(&packet.stream_id) { - let _ = stream.send(WsEvent::Send(inner_packet.into())); - } - } - Continue(_) => unreachable!(), - Close(_) => { - if let Some(stream) = stream_map.clone().get(&packet.stream_id) { - let _ = stream.send(WsEvent::Close); - } - } - } - } - } + }) + .await; + }); + Ok(()) + }) + .await?; println!("{:?}: disconnected", addr); Ok(()) @@ -243,7 +182,7 @@ async fn accept_wsproxy( fut: upgrade::UpgradeFut, incoming_uri: &str, addr: String, -) -> Result<(), Box> { +) -> Result<(), Box> { let mut ws_stream = FragmentCollector::new(fut.await?); println!("{:?}: connected (wsproxy)", addr); diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 693c91e..14d3e92 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -5,9 +5,11 @@ edition = "2021" [dependencies] bytes = "1.5.0" -fastwebsockets = { version = "0.6.0", optional = true } +dashmap = "5.5.3" +fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true } futures = "0.3.30" futures-util = "0.3.30" +tokio = { version = "1.35.1", optional = true } [features] -fastwebsockets = ["dep:fastwebsockets"] +fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index ba499ac..6aacb28 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -1,30 +1,30 @@ use bytes::Bytes; -use fastwebsockets::{Payload, Frame, OpCode}; +use fastwebsockets::{ + FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite, +}; +use tokio::io::{AsyncRead, AsyncWrite}; -impl TryFrom for crate::ws::OpCode { - type Error = crate::WispError; - fn try_from(opcode: OpCode) -> Result { +impl From for crate::ws::OpCode { + fn from(opcode: OpCode) -> Self { use OpCode::*; match opcode { - Continuation => Err(Self::Error::WsImplNotSupported), - Text => Ok(Self::Text), - Binary => Ok(Self::Binary), - Close => Ok(Self::Close), - Ping => Err(Self::Error::WsImplNotSupported), - Pong => Err(Self::Error::WsImplNotSupported), + Continuation => unreachable!(), + Text => Self::Text, + Binary => Self::Binary, + Close => Self::Close, + Ping => Self::Ping, + Pong => Self::Pong, } } } -impl TryFrom> for crate::ws::Frame { - type Error = crate::WispError; - fn try_from(mut frame: Frame) -> Result { - let opcode = frame.opcode.try_into()?; - Ok(Self { +impl From> for crate::ws::Frame { + fn from(mut frame: Frame) -> Self { + Self { finished: frame.fin, - opcode, + opcode: frame.opcode.into(), payload: Bytes::copy_from_slice(frame.payload.to_mut()), - }) + } } } @@ -34,7 +34,38 @@ impl From for Frame<'_> { match frame.opcode { Text => Self::text(Payload::Owned(frame.payload.to_vec())), Binary => Self::binary(Payload::Owned(frame.payload.to_vec())), - Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())) + Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())), + Ping => Self::new( + true, + OpCode::Ping, + None, + Payload::Owned(frame.payload.to_vec()), + ), + Pong => Self::pong(Payload::Owned(frame.payload.to_vec())), } } } + +impl From for crate::WispError { + fn from(err: WebSocketError) -> Self { + Self::WsImplError(Box::new(err)) + } +} + +impl crate::ws::WebSocketRead for FragmentCollectorRead { + async fn wisp_read_frame( + &mut self, + tx: &mut crate::ws::LockedWebSocketWrite, + ) -> Result { + Ok(self + .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) + .await? + .into()) + } +} + +impl crate::ws::WebSocketWrite for WebSocketWrite { + async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { + self.write_frame(frame.into()).await.map_err(|e| e.into()) + } +} diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 897b7de..c1318a5 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -5,6 +5,11 @@ pub mod ws; pub use crate::packet::*; +use bytes::Bytes; +use dashmap::DashMap; +use futures::{channel::mpsc, StreamExt}; +use std::sync::Arc; + #[derive(Debug, PartialEq)] pub enum Role { Client, @@ -15,11 +20,13 @@ pub enum Role { pub enum WispError { PacketTooSmall, InvalidPacketType, + InvalidStreamType, WsFrameInvalidType, WsFrameNotFinished, - WsImplError(Box), + WsImplError(Box), WsImplNotSupported, Utf8Error(std::str::Utf8Error), + Other(Box), } impl From for WispError { @@ -34,13 +41,112 @@ impl std::fmt::Display for WispError { match self { PacketTooSmall => write!(f, "Packet too small"), InvalidPacketType => write!(f, "Invalid packet type"), + InvalidStreamType => write!(f, "Invalid stream type"), WsFrameInvalidType => write!(f, "Invalid websocket frame type"), WsFrameNotFinished => write!(f, "Unfinished websocket frame"), WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err), WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"), Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err), + Other(err) => write!(f, "Other error: {:?}", err), } } } impl std::error::Error for WispError {} + +pub enum WsEvent { + Send(Bytes), + Close(ClosePacket), +} + +pub struct MuxStream +where + W: ws::WebSocketWrite, +{ + pub stream_id: u32, + rx: mpsc::UnboundedReceiver, + tx: ws::LockedWebSocketWrite, +} + +impl MuxStream { + pub async fn read(&mut self) -> Option { + self.rx.next().await + } + + pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> { + self.tx + .write_frame(ws::Frame::from(Packet::new_data(self.stream_id, data))) + .await + } + + pub fn get_write_half(&self) -> ws::LockedWebSocketWrite { + self.tx.clone() + } +} + +pub struct ServerMux +where + R: ws::WebSocketRead, + W: ws::WebSocketWrite, +{ + rx: R, + tx: ws::LockedWebSocketWrite, + stream_map: Arc>>, +} + +impl ServerMux { + pub fn new(read: R, write: W) -> Self { + Self { + rx: read, + tx: ws::LockedWebSocketWrite::new(write), + stream_map: Arc::new(DashMap::new()), + } + } + + pub async fn server_loop( + &mut self, + handler_fn: &mut impl Fn(ConnectPacket, MuxStream) -> FR, + ) -> Result<(), WispError> + where + FR: std::future::Future>, + { + self.tx + .write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX))) + .await?; + + while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await { + if let Ok(packet) = Packet::try_from(frame) { + use PacketType::*; + match packet.packet { + Connect(inner_packet) => { + let (ch_tx, ch_rx) = mpsc::unbounded(); + self.stream_map.clone().insert(packet.stream_id, ch_tx); + let _ = handler_fn( + inner_packet, + MuxStream { + stream_id: packet.stream_id, + rx: ch_rx, + tx: self.tx.clone(), + }, + ).await; + } + Data(data) => { + if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + let _ = stream.unbounded_send(WsEvent::Send(data)); + self.tx + .write_frame(ws::Frame::from(Packet::new_continue(packet.stream_id, u32::MAX))) + .await?; + } + } + Continue(_) => unreachable!(), + Close(inner_packet) => { + if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + let _ = stream.unbounded_send(WsEvent::Close(inner_packet)); + } + } + } + } + } + Ok(()) + } +} diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 1c5f177..98eb20e 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -2,15 +2,33 @@ use crate::ws; use crate::WispError; use bytes::{Buf, BufMut, Bytes}; +#[derive(Debug)] +pub enum StreamType { + Tcp = 0x01, + Udp = 0x02, +} + +impl TryFrom for StreamType { + type Error = WispError; + fn try_from(stream_type: u8) -> Result { + use StreamType::*; + match stream_type { + 0x01 => Ok(Tcp), + 0x02 => Ok(Udp), + _ => Err(Self::Error::InvalidStreamType), + } + } +} + #[derive(Debug)] pub struct ConnectPacket { - pub stream_type: u8, + pub stream_type: StreamType, pub destination_port: u16, pub destination_hostname: String, } impl ConnectPacket { - pub fn new(stream_type: u8, destination_port: u16, destination_hostname: String) -> Self { + pub fn new(stream_type: StreamType, destination_port: u16, destination_hostname: String) -> Self { Self { stream_type, destination_port, @@ -26,7 +44,7 @@ impl TryFrom for ConnectPacket { return Err(Self::Error::PacketTooSmall); } Ok(Self { - stream_type: bytes.get_u8(), + stream_type: bytes.get_u8().try_into()?, destination_port: bytes.get_u16_le(), destination_hostname: std::str::from_utf8(&bytes)?.to_string(), }) @@ -36,7 +54,7 @@ impl TryFrom for ConnectPacket { impl From for Vec { fn from(packet: ConnectPacket) -> Self { let mut encoded = Self::with_capacity(1 + 2 + packet.destination_hostname.len()); - encoded.put_u8(packet.stream_type); + encoded.put_u8(packet.stream_type as u8); encoded.put_u16_le(packet.destination_port); encoded.extend(packet.destination_hostname.bytes()); encoded @@ -108,7 +126,7 @@ impl From for Vec { #[derive(Debug)] pub enum PacketType { Connect(ConnectPacket), - Data(Vec), + Data(Bytes), Continue(ContinuePacket), Close(ClosePacket), } @@ -130,7 +148,7 @@ impl From for Vec { use PacketType::*; match packet { Connect(x) => x.into(), - Data(x) => x, + Data(x) => x.to_vec(), Continue(x) => x.into(), Close(x) => x.into(), } @@ -150,7 +168,7 @@ impl Packet { pub fn new_connect( stream_id: u32, - stream_type: u8, + stream_type: StreamType, destination_port: u16, destination_hostname: String, ) -> Self { @@ -164,7 +182,7 @@ impl Packet { } } - pub fn new_data(stream_id: u32, data: Vec) -> Self { + pub fn new_data(stream_id: u32, data: Bytes) -> Self { Self { stream_id, packet: PacketType::Data(data), @@ -198,7 +216,7 @@ impl TryFrom for Packet { stream_id: bytes.get_u32_le(), packet: match packet_type { 0x01 => Connect(ConnectPacket::try_from(bytes)?), - 0x02 => Data(bytes.to_vec()), + 0x02 => Data(bytes), 0x03 => Continue(ContinuePacket::try_from(bytes)?), 0x04 => Close(ClosePacket::try_from(bytes)?), _ => return Err(Self::Error::InvalidPacketType), diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index fbb1e56..dc8bdcc 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -1,10 +1,14 @@ use bytes::Bytes; +use futures::lock::Mutex; +use std::sync::Arc; #[derive(Debug, PartialEq, Clone, Copy)] pub enum OpCode { Text, Binary, Close, + Ping, + Pong, } pub struct Frame { @@ -38,3 +42,37 @@ impl Frame { } } } + +pub trait WebSocketRead { + fn wisp_read_frame( + &mut self, + tx: &mut crate::ws::LockedWebSocketWrite, + ) -> impl std::future::Future>; +} + +pub trait WebSocketWrite { + fn wisp_write_frame( + &mut self, + frame: Frame, + ) -> impl std::future::Future>; +} + +pub struct LockedWebSocketWrite(Arc>) +where + S: WebSocketWrite; + +impl LockedWebSocketWrite { + pub fn new(ws: S) -> Self { + Self(Arc::new(Mutex::new(ws))) + } + + pub async fn write_frame(&self, frame: Frame) -> Result<(), crate::WispError> { + self.0.lock().await.wisp_write_frame(frame).await + } +} + +impl Clone for LockedWebSocketWrite { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +}