diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 673e182..2eb0594 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -1,15 +1,16 @@ #[cfg(feature = "fastwebsockets")] mod fastwebsockets; mod packet; +mod stream; pub mod ws; #[cfg(feature = "ws_stream_wasm")] mod ws_stream_wasm; pub use crate::packet::*; +pub use crate::stream::*; -use bytes::Bytes; use dashmap::DashMap; -use futures::{channel::mpsc, channel::oneshot, SinkExt, StreamExt}; +use futures::{channel::mpsc, StreamExt}; use std::sync::{ atomic::{AtomicBool, AtomicU32, Ordering}, Arc, @@ -67,94 +68,6 @@ impl std::fmt::Display for WispError { impl std::error::Error for WispError {} -pub enum WsEvent { - Send(Bytes), - Close(ClosePacket), -} - -pub enum MuxEvent { - Close(u32, u8, oneshot::Sender>), -} - -pub struct MuxStream -where - W: ws::WebSocketWrite, -{ - pub stream_id: u32, - rx: mpsc::UnboundedReceiver, - tx: ws::LockedWebSocketWrite, - close_channel: mpsc::UnboundedSender, - is_closed: Arc, -} - -impl MuxStream { - pub async fn read(&mut self) -> Option { - if self.is_closed.load(Ordering::Acquire) { - return None; - } - match self.rx.next().await? { - WsEvent::Send(bytes) => Some(WsEvent::Send(bytes)), - WsEvent::Close(packet) => { - self.is_closed.store(true, Ordering::Release); - Some(WsEvent::Close(packet)) - } - } - } - - pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> { - if self.is_closed.load(Ordering::Acquire) { - return Err(WispError::StreamAlreadyClosed); - } - self.tx - .write_frame(Packet::new_data(self.stream_id, data).into()) - .await - } - - pub fn get_close_handle(&self) -> MuxStreamCloser { - MuxStreamCloser { - stream_id: self.stream_id, - close_channel: self.close_channel.clone(), - is_closed: self.is_closed.clone(), - } - } - - pub async fn close(&mut self, reason: u8) -> Result<(), WispError> { - if self.is_closed.load(Ordering::Acquire) { - return Err(WispError::StreamAlreadyClosed); - } - let (tx, rx) = oneshot::channel::>(); - self.close_channel - .send(MuxEvent::Close(self.stream_id, reason, tx)) - .await - .map_err(|x| WispError::Other(Box::new(x)))?; - rx.await.map_err(|x| WispError::Other(Box::new(x)))??; - self.is_closed.store(true, Ordering::Release); - Ok(()) - } -} - -pub struct MuxStreamCloser { - stream_id: u32, - close_channel: mpsc::UnboundedSender, - is_closed: Arc, -} - -impl MuxStreamCloser { - pub async fn close(&mut self, reason: u8) -> Result<(), WispError> { - if self.is_closed.load(Ordering::Acquire) { - return Err(WispError::StreamAlreadyClosed); - } - let (tx, rx) = oneshot::channel::>(); - self.close_channel - .send(MuxEvent::Close(self.stream_id, reason, tx)) - .await - .map_err(|x| WispError::Other(Box::new(x)))?; - rx.await.map_err(|x| WispError::Other(Box::new(x)))??; - self.is_closed.store(true, Ordering::Release); - Ok(()) - } -} - pub struct ServerMux where R: ws::WebSocketRead, @@ -217,13 +130,13 @@ impl ServerMux { self.stream_map.clone().insert(packet.stream_id, ch_tx); let _ = handler_fn( inner_packet, - MuxStream { - stream_id: packet.stream_id, - rx: ch_rx, - tx: self.tx.clone(), - close_channel: self.close_tx.clone(), - is_closed: AtomicBool::new(false).into(), - }, + MuxStream::new( + packet.stream_id, + ch_rx, + self.tx.clone(), + self.close_tx.clone(), + AtomicBool::new(false).into(), + ), ) .await; } @@ -335,12 +248,12 @@ impl ClientMux { Ordering::Release, ); self.stream_map.clone().insert(stream_id, ch_tx); - Ok(MuxStream { + Ok(MuxStream::new( stream_id, - rx: ch_rx, - tx: self.tx.clone(), - close_channel: self.close_tx.clone(), - is_closed: AtomicBool::new(false).into(), - }) + ch_rx, + self.tx.clone(), + self.close_tx.clone(), + AtomicBool::new(false).into(), + )) } } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs new file mode 100644 index 0000000..8c8a76b --- /dev/null +++ b/wisp/src/stream.rs @@ -0,0 +1,168 @@ +use bytes::Bytes; +use futures::{ + channel::{mpsc, oneshot}, + StreamExt, +}; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; + +pub enum WsEvent { + Send(Bytes), + Close(crate::ClosePacket), +} + +pub enum MuxEvent { + Close(u32, u8, oneshot::Sender>), +} + +pub struct MuxStreamRead { + pub stream_id: u32, + rx: mpsc::UnboundedReceiver, + is_closed: Arc, +} + +impl MuxStreamRead { + pub async fn read(&mut self) -> Option { + if self.is_closed.load(Ordering::Acquire) { + return None; + } + match self.rx.next().await? { + WsEvent::Send(bytes) => Some(WsEvent::Send(bytes)), + WsEvent::Close(packet) => { + self.is_closed.store(true, Ordering::Release); + Some(WsEvent::Close(packet)) + } + } + } +} + +pub struct MuxStreamWrite +where + W: crate::ws::WebSocketWrite, +{ + pub stream_id: u32, + tx: crate::ws::LockedWebSocketWrite, + close_channel: mpsc::UnboundedSender, + is_closed: Arc, +} + +impl MuxStreamWrite { + pub async fn write(&mut self, data: Bytes) -> Result<(), crate::WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(crate::WispError::StreamAlreadyClosed); + } + self.tx + .write_frame(crate::Packet::new_data(self.stream_id, data).into()) + .await + } + + pub fn get_close_handle(&self) -> MuxStreamCloser { + MuxStreamCloser { + stream_id: self.stream_id, + close_channel: self.close_channel.clone(), + is_closed: self.is_closed.clone(), + } + } + + pub async fn close(&mut self, reason: u8) -> Result<(), crate::WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(crate::WispError::StreamAlreadyClosed); + } + let (tx, rx) = oneshot::channel::>(); + self.close_channel + .unbounded_send(MuxEvent::Close(self.stream_id, reason, tx)) + .map_err(|x| crate::WispError::Other(Box::new(x)))?; + rx.await + .map_err(|x| crate::WispError::Other(Box::new(x)))??; + + self.is_closed.store(true, Ordering::Release); + Ok(()) + } +} + +impl Drop for MuxStreamWrite { + fn drop(&mut self) { + let (tx, _) = oneshot::channel::>(); + let _ = self + .close_channel + .unbounded_send(MuxEvent::Close(self.stream_id, 0x01, tx)); + } +} + +pub struct MuxStream +where + W: crate::ws::WebSocketWrite, +{ + pub stream_id: u32, + rx: MuxStreamRead, + tx: MuxStreamWrite, +} + +impl MuxStream { + pub(crate) fn new( + stream_id: u32, + rx: mpsc::UnboundedReceiver, + tx: crate::ws::LockedWebSocketWrite, + close_channel: mpsc::UnboundedSender, + is_closed: Arc, + ) -> Self { + Self { + stream_id, + rx: MuxStreamRead { + stream_id, + rx, + is_closed: is_closed.clone(), + }, + tx: MuxStreamWrite { + stream_id, + tx, + close_channel, + is_closed: is_closed.clone(), + }, + } + } + + pub async fn read(&mut self) -> Option { + self.rx.read().await + } + + pub async fn write(&mut self, data: Bytes) -> Result<(), crate::WispError> { + self.tx.write(data).await + } + + pub fn get_close_handle(&self) -> MuxStreamCloser { + self.tx.get_close_handle() + } + + pub async fn close(&mut self, reason: u8) -> Result<(), crate::WispError> { + self.tx.close(reason).await + } + + pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) { + (self.rx, self.tx) + } +} + +pub struct MuxStreamCloser { + stream_id: u32, + close_channel: mpsc::UnboundedSender, + is_closed: Arc, +} + +impl MuxStreamCloser { + pub async fn close(&mut self, reason: u8) -> Result<(), crate::WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(crate::WispError::StreamAlreadyClosed); + } + let (tx, rx) = oneshot::channel::>(); + self.close_channel + .unbounded_send(MuxEvent::Close(self.stream_id, reason, tx)) + .map_err(|x| crate::WispError::Other(Box::new(x)))?; + rx.await + .map_err(|x| crate::WispError::Other(Box::new(x)))??; + self.is_closed.store(true, Ordering::Release); + Ok(()) + } +}