diff --git a/Cargo.lock b/Cargo.lock index ad0128c..6326ae1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -271,6 +271,7 @@ dependencies = [ "wasm-streams", "web-sys", "webpki-roots", + "wisp-mux", "ws_stream_wasm", ] @@ -1535,6 +1536,7 @@ dependencies = [ "futures", "futures-util", "tokio", + "ws_stream_wasm", ] [[package]] diff --git a/client/Cargo.toml b/client/Cargo.toml index 9fc6149..3a1b3bd 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -33,6 +33,7 @@ async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] fastwebsockets = { version = "0.6.0", features = ["simdutf8", "unstable-split"] } rand = "0.8.5" base64 = "0.21.7" +wisp-mux = { path = "../wisp", features = ["ws_stream_wasm"] } [dependencies.getrandom] features = ["js"] diff --git a/server/src/main.rs b/server/src/main.rs index 43993fd..11f6478 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -16,7 +16,7 @@ use tokio::net::{TcpListener, TcpStream, UdpSocket}; use tokio_native_tls::{native_tls, TlsAcceptor}; use tokio_util::codec::{BytesCodec, Framed}; -use wisp_mux::{ws, ConnectPacket, MuxStream, Packet, ServerMux, StreamType, WispError, WsEvent}; +use wisp_mux::{ws, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, WsEvent}; type HttpBody = http_body_util::Empty; @@ -162,7 +162,7 @@ async fn handle_mux( } } } - Ok(false) + Ok(true) } async fn accept_ws( @@ -177,22 +177,17 @@ async fn accept_ws( let mut mux = ServerMux::new(rx, tx); mux.server_loop(&mut |packet, stream| async move { - let tx_cloned_err = stream.get_write_half(); - let tx_cloned_ok = stream.get_write_half(); - let stream_id = stream.stream_id; + let mut close_err = stream.get_close_handle(); + let mut close_ok = stream.get_close_handle(); tokio::spawn(async move { let _ = handle_mux(packet, stream) .or_else(|err| async move { - let _ = tx_cloned_err - .write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x03))) - .await; + let _ = close_err.close(0x03).await; Err(err) }) .and_then(|should_send| async move { if should_send { - tx_cloned_ok - .write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x02))) - .await + close_ok.close(0x02).await } else { Ok(()) } diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 14d3e92..ae279ae 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -10,6 +10,8 @@ fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = futures = "0.3.30" futures-util = "0.3.30" tokio = { version = "1.35.1", optional = true } +ws_stream_wasm = { version = "0.7.4", optional = true } [features] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] +ws_stream_wasm = ["dep:ws_stream_wasm"] diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index 6aacb28..f020bfd 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -28,10 +28,11 @@ impl From> for crate::ws::Frame { } } -impl From for Frame<'_> { - fn from(frame: crate::ws::Frame) -> Self { +impl TryFrom for Frame<'_> { + type Error = crate::WispError; + fn try_from(frame: crate::ws::Frame) -> Result { use crate::ws::OpCode::*; - match frame.opcode { + Ok(match frame.opcode { Text => Self::text(Payload::Owned(frame.payload.to_vec())), Binary => Self::binary(Payload::Owned(frame.payload.to_vec())), Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())), @@ -42,7 +43,7 @@ impl From for Frame<'_> { Payload::Owned(frame.payload.to_vec()), ), Pong => Self::pong(Payload::Owned(frame.payload.to_vec())), - } + }) } } @@ -66,6 +67,6 @@ impl crate::ws::WebSocketRead for FragmentCollectorRead impl crate::ws::WebSocketWrite for WebSocketWrite { async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { - self.write_frame(frame.into()).await.map_err(|e| e.into()) + self.write_frame(frame.try_into()?).await.map_err(|e| e.into()) } } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index c1318a5..673e182 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -2,13 +2,18 @@ mod fastwebsockets; mod packet; pub mod ws; +#[cfg(feature = "ws_stream_wasm")] +mod ws_stream_wasm; pub use crate::packet::*; use bytes::Bytes; use dashmap::DashMap; -use futures::{channel::mpsc, StreamExt}; -use std::sync::Arc; +use futures::{channel::mpsc, channel::oneshot, SinkExt, StreamExt}; +use std::sync::{ + atomic::{AtomicBool, AtomicU32, Ordering}, + Arc, +}; #[derive(Debug, PartialEq)] pub enum Role { @@ -21,9 +26,13 @@ pub enum WispError { PacketTooSmall, InvalidPacketType, InvalidStreamType, + InvalidStreamId, + MaxStreamCountReached, + StreamAlreadyClosed, WsFrameInvalidType, WsFrameNotFinished, WsImplError(Box), + WsImplSocketClosed, WsImplNotSupported, Utf8Error(std::str::Utf8Error), Other(Box), @@ -42,9 +51,13 @@ impl std::fmt::Display for WispError { PacketTooSmall => write!(f, "Packet too small"), InvalidPacketType => write!(f, "Invalid packet type"), InvalidStreamType => write!(f, "Invalid stream type"), + InvalidStreamId => write!(f, "Invalid stream id"), + MaxStreamCountReached => write!(f, "Maximum stream count reached"), + StreamAlreadyClosed => write!(f, "Stream already closed"), WsFrameInvalidType => write!(f, "Invalid websocket frame type"), WsFrameNotFinished => write!(f, "Unfinished websocket frame"), WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err), + WsImplSocketClosed => write!(f, "Websocket implementation error: websocket closed"), WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"), Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err), Other(err) => write!(f, "Other error: {:?}", err), @@ -59,6 +72,10 @@ pub enum WsEvent { Close(ClosePacket), } +pub enum MuxEvent { + Close(u32, u8, oneshot::Sender>), +} + pub struct MuxStream where W: ws::WebSocketWrite, @@ -66,21 +83,75 @@ where pub stream_id: u32, rx: mpsc::UnboundedReceiver, tx: ws::LockedWebSocketWrite, + close_channel: mpsc::UnboundedSender, + is_closed: Arc, } impl MuxStream { pub async fn read(&mut self) -> Option { - self.rx.next().await + if self.is_closed.load(Ordering::Acquire) { + return None; + } + match self.rx.next().await? { + WsEvent::Send(bytes) => Some(WsEvent::Send(bytes)), + WsEvent::Close(packet) => { + self.is_closed.store(true, Ordering::Release); + Some(WsEvent::Close(packet)) + } + } } pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(WispError::StreamAlreadyClosed); + } self.tx - .write_frame(ws::Frame::from(Packet::new_data(self.stream_id, data))) + .write_frame(Packet::new_data(self.stream_id, data).into()) .await } - pub fn get_write_half(&self) -> ws::LockedWebSocketWrite { - self.tx.clone() + pub fn get_close_handle(&self) -> MuxStreamCloser { + MuxStreamCloser { + stream_id: self.stream_id, + close_channel: self.close_channel.clone(), + is_closed: self.is_closed.clone(), + } + } + + pub async fn close(&mut self, reason: u8) -> Result<(), WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(WispError::StreamAlreadyClosed); + } + let (tx, rx) = oneshot::channel::>(); + self.close_channel + .send(MuxEvent::Close(self.stream_id, reason, tx)) + .await + .map_err(|x| WispError::Other(Box::new(x)))?; + rx.await.map_err(|x| WispError::Other(Box::new(x)))??; + self.is_closed.store(true, Ordering::Release); + Ok(()) + } +} + +pub struct MuxStreamCloser { + stream_id: u32, + close_channel: mpsc::UnboundedSender, + is_closed: Arc, +} + +impl MuxStreamCloser { + pub async fn close(&mut self, reason: u8) -> Result<(), WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(WispError::StreamAlreadyClosed); + } + let (tx, rx) = oneshot::channel::>(); + self.close_channel + .send(MuxEvent::Close(self.stream_id, reason, tx)) + .await + .map_err(|x| WispError::Other(Box::new(x)))?; + rx.await.map_err(|x| WispError::Other(Box::new(x)))??; + self.is_closed.store(true, Ordering::Release); + Ok(()) } } @@ -92,14 +163,37 @@ where rx: R, tx: ws::LockedWebSocketWrite, stream_map: Arc>>, + close_rx: mpsc::UnboundedReceiver, + close_tx: mpsc::UnboundedSender, } impl ServerMux { pub fn new(read: R, write: W) -> Self { + let (tx, rx) = mpsc::unbounded::(); Self { rx: read, tx: ws::LockedWebSocketWrite::new(write), stream_map: Arc::new(DashMap::new()), + close_rx: rx, + close_tx: tx, + } + } + + pub async fn server_bg_loop(&mut self) { + while let Some(msg) = self.close_rx.next().await { + match msg { + MuxEvent::Close(stream_id, reason, channel) => { + if self.stream_map.clone().remove(&stream_id).is_some() { + let _ = channel.send( + self.tx + .write_frame(Packet::new_close(stream_id, reason).into()) + .await, + ); + } else { + let _ = channel.send(Err(WispError::InvalidStreamId)); + } + } + } } } @@ -111,7 +205,7 @@ impl ServerMux { FR: std::future::Future>, { self.tx - .write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX))) + .write_frame(Packet::new_continue(0, u32::MAX).into()) .await?; while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await { @@ -127,14 +221,19 @@ impl ServerMux { stream_id: packet.stream_id, rx: ch_rx, tx: self.tx.clone(), + close_channel: self.close_tx.clone(), + is_closed: AtomicBool::new(false).into(), }, - ).await; + ) + .await; } Data(data) => { if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { let _ = stream.unbounded_send(WsEvent::Send(data)); self.tx - .write_frame(ws::Frame::from(Packet::new_continue(packet.stream_id, u32::MAX))) + .write_frame( + Packet::new_continue(packet.stream_id, u32::MAX).into(), + ) .await?; } } @@ -142,6 +241,7 @@ impl ServerMux { Close(inner_packet) => { if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { let _ = stream.unbounded_send(WsEvent::Close(inner_packet)); + self.stream_map.clone().remove(&packet.stream_id); } } } @@ -150,3 +250,97 @@ impl ServerMux { Ok(()) } } + +pub struct ClientMux +where + R: ws::WebSocketRead, + W: ws::WebSocketWrite, +{ + rx: R, + tx: ws::LockedWebSocketWrite, + stream_map: Arc>>, + next_free_stream_id: AtomicU32, + close_rx: mpsc::UnboundedReceiver, + close_tx: mpsc::UnboundedSender, +} + +impl ClientMux { + pub fn new(read: R, write: W) -> Self { + let (tx, rx) = mpsc::unbounded::(); + Self { + rx: read, + tx: ws::LockedWebSocketWrite::new(write), + stream_map: Arc::new(DashMap::new()), + next_free_stream_id: AtomicU32::new(1), + close_rx: rx, + close_tx: tx, + } + } + + pub async fn client_bg_loop(&mut self) { + while let Some(msg) = self.close_rx.next().await { + match msg { + MuxEvent::Close(stream_id, reason, channel) => { + if self.stream_map.clone().remove(&stream_id).is_some() { + let _ = channel.send( + self.tx + .write_frame(Packet::new_close(stream_id, reason).into()) + .await, + ); + } else { + let _ = channel.send(Err(WispError::InvalidStreamId)); + } + } + } + } + } + + pub async fn client_loop(&mut self) -> Result<(), WispError> { + self.tx + .write_frame(Packet::new_continue(0, u32::MAX).into()) + .await?; + + while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await { + if let Ok(packet) = Packet::try_from(frame) { + use PacketType::*; + match packet.packet { + Connect(_) => unreachable!(), + Data(data) => { + if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + let _ = stream.unbounded_send(WsEvent::Send(data)); + } + } + Continue(_) => {} + Close(inner_packet) => { + if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { + let _ = stream.unbounded_send(WsEvent::Close(inner_packet)); + self.stream_map.clone().remove(&packet.stream_id); + } + } + } + } + } + Ok(()) + } + + pub async fn client_new_stream( + &mut self, + ) -> Result, WispError> { + let (ch_tx, ch_rx) = mpsc::unbounded(); + let stream_id = self.next_free_stream_id.load(Ordering::Acquire); + self.next_free_stream_id.store( + stream_id + .checked_add(1) + .ok_or(WispError::MaxStreamCountReached)?, + Ordering::Release, + ); + self.stream_map.clone().insert(stream_id, ch_tx); + Ok(MuxStream { + stream_id, + rx: ch_rx, + tx: self.tx.clone(), + close_channel: self.close_tx.clone(), + is_closed: AtomicBool::new(false).into(), + }) + } +} diff --git a/wisp/src/ws_stream_wasm.rs b/wisp/src/ws_stream_wasm.rs new file mode 100644 index 0000000..6e15816 --- /dev/null +++ b/wisp/src/ws_stream_wasm.rs @@ -0,0 +1,57 @@ +use futures::{SinkExt, StreamExt}; +use ws_stream_wasm::{WsErr, WsMessage, WsStream}; + +impl From for crate::ws::Frame { + fn from(msg: WsMessage) -> Self { + use crate::ws::OpCode; + match msg { + WsMessage::Text(str) => Self { + finished: true, + opcode: OpCode::Text, + payload: str.into(), + }, + WsMessage::Binary(bin) => Self { + finished: true, + opcode: OpCode::Binary, + payload: bin.into(), + }, + } + } +} + +impl TryFrom for WsMessage { + type Error = crate::WispError; + fn try_from(msg: crate::ws::Frame) -> Result { + use crate::ws::OpCode; + match msg.opcode { + OpCode::Text => Ok(Self::Text(std::str::from_utf8(&msg.payload)?.to_string())), + OpCode::Binary => Ok(Self::Binary(msg.payload.to_vec())), + _ => Err(Self::Error::WsImplNotSupported), + } + } +} + +impl From for crate::WispError { + fn from(err: WsErr) -> Self { + Self::WsImplError(Box::new(err)) + } +} + +impl crate::ws::WebSocketRead for WsStream { + async fn wisp_read_frame( + &mut self, + _: &mut crate::ws::LockedWebSocketWrite, + ) -> Result { + Ok(self + .next() + .await + .ok_or(crate::WispError::WsImplSocketClosed)? + .into()) + } +} + +impl crate::ws::WebSocketWrite for WsStream { + async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { + self.send(frame.try_into()?).await.map_err(|e| e.into()) + } +}