use crate::{ sink_unfold, ws::{Frame, LockedWebSocketWrite}, CloseReason, Packet, Role, StreamType, WispError, }; use bytes::{BufMut, Bytes, BytesMut}; use event_listener::Event; use flume as mpsc; use futures::{ channel::oneshot, select, stream::{self, IntoAsyncRead, SplitSink, SplitStream}, task::{Context, Poll}, AsyncBufRead, AsyncRead, AsyncWrite, FutureExt, Sink, Stream, StreamExt, TryStreamExt, }; use pin_project_lite::pin_project; use std::{ pin::Pin, sync::{ atomic::{AtomicBool, AtomicU32, Ordering}, Arc, }, task::ready, }; pub(crate) enum WsEvent { Close(Packet, oneshot::Sender>), CreateStream( StreamType, String, u16, oneshot::Sender>, ), EndFut(Option), } /// Read side of a multiplexor stream. pub struct MuxStreamRead { /// ID of the stream. pub stream_id: u32, /// Type of the stream. pub stream_type: StreamType, role: Role, tx: LockedWebSocketWrite, rx: mpsc::Receiver, is_closed: Arc, is_closed_event: Arc, flow_control: Arc, flow_control_read: AtomicU32, target_flow_control: u32, } impl MuxStreamRead { /// Read an event from the stream. pub async fn read(&self) -> Option { if self.is_closed.load(Ordering::Acquire) { return None; } let bytes = select! { x = self.rx.recv_async() => x.ok()?, _ = self.is_closed_event.listen().fuse() => return None }; if self.role == Role::Server && self.stream_type == StreamType::Tcp { let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1; if val > self.target_flow_control && !self.is_closed.load(Ordering::Acquire) { self.tx .write_frame( Packet::new_continue( self.stream_id, self.flow_control.fetch_add(val, Ordering::AcqRel) + val, ) .into(), ) .await .ok()?; self.flow_control_read.store(0, Ordering::Release); } } Some(bytes) } pub(crate) fn into_stream(self) -> Pin + Send>> { Box::pin(stream::unfold(self, |rx| async move { Some((rx.read().await?, rx)) })) } } /// Write side of a multiplexor stream. pub struct MuxStreamWrite { /// ID of the stream. pub stream_id: u32, /// Type of the stream. pub stream_type: StreamType, role: Role, mux_tx: mpsc::Sender, tx: LockedWebSocketWrite, is_closed: Arc, continue_recieved: Arc, flow_control: Arc, } impl MuxStreamWrite { /// Write data to the stream. pub async fn write(&self, data: Bytes) -> Result<(), WispError> { if self.role == Role::Client && self.stream_type == StreamType::Tcp && self.flow_control.load(Ordering::Acquire) == 0 { self.continue_recieved.listen().await; } if self.is_closed.load(Ordering::Acquire) { return Err(WispError::StreamAlreadyClosed); } self.tx .write_frame(Frame::from(Packet::new_data(self.stream_id, data))) .await?; if self.role == Role::Client && self.stream_type == StreamType::Tcp { self.flow_control.store( self.flow_control.load(Ordering::Acquire).saturating_sub(1), Ordering::Release, ); } Ok(()) } /// Get a handle to close the connection. /// /// Useful to close the connection without having access to the stream. /// /// # Example /// ``` /// let handle = stream.get_close_handle(); /// if let Err(error) = handle_stream(stream) { /// handle.close(0x01); /// } /// ``` pub fn get_close_handle(&self) -> MuxStreamCloser { MuxStreamCloser { stream_id: self.stream_id, close_channel: self.mux_tx.clone(), is_closed: self.is_closed.clone(), } } /// Get a protocol extension stream to send protocol extension packets. pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { MuxProtocolExtensionStream { stream_id: self.stream_id, tx: self.tx.clone(), is_closed: self.is_closed.clone(), } } /// Close the stream. You will no longer be able to write or read after this has been called. pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(WispError::StreamAlreadyClosed); } self.is_closed.store(true, Ordering::Release); let (tx, rx) = oneshot::channel::>(); self.mux_tx .send_async(WsEvent::Close( Packet::new_close(self.stream_id, reason), tx, )) .await .map_err(|_| WispError::MuxMessageFailedToSend)?; rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; Ok(()) } pub(crate) fn into_sink(self) -> Pin + Send>> { let handle = self.get_close_handle(); Box::pin(sink_unfold::unfold( self, |tx, data| async move { tx.write(data).await?; Ok(tx) }, handle, move |handle| async { handle.close(CloseReason::Unknown).await?; Ok(handle) }, )) } } impl Drop for MuxStreamWrite { fn drop(&mut self) { if !self.is_closed.load(Ordering::Acquire) { self.is_closed.store(true, Ordering::Release); let (tx, _) = oneshot::channel(); let _ = self.mux_tx.send(WsEvent::Close( Packet::new_close(self.stream_id, CloseReason::Unknown), tx, )); } } } /// Multiplexor stream. pub struct MuxStream { /// ID of the stream. pub stream_id: u32, rx: MuxStreamRead, tx: MuxStreamWrite, } impl MuxStream { #[allow(clippy::too_many_arguments)] pub(crate) fn new( stream_id: u32, role: Role, stream_type: StreamType, rx: mpsc::Receiver, mux_tx: mpsc::Sender, tx: LockedWebSocketWrite, is_closed: Arc, is_closed_event: Arc, flow_control: Arc, continue_recieved: Arc, target_flow_control: u32, ) -> Self { Self { stream_id, rx: MuxStreamRead { stream_id, stream_type, role, tx: tx.clone(), rx, is_closed: is_closed.clone(), is_closed_event: is_closed_event.clone(), flow_control: flow_control.clone(), flow_control_read: AtomicU32::new(0), target_flow_control, }, tx: MuxStreamWrite { stream_id, stream_type, role, mux_tx, tx, is_closed: is_closed.clone(), flow_control: flow_control.clone(), continue_recieved: continue_recieved.clone(), }, } } /// Read an event from the stream. pub async fn read(&self) -> Option { self.rx.read().await } /// Write data to the stream. pub async fn write(&self, data: Bytes) -> Result<(), WispError> { self.tx.write(data).await } /// Get a handle to close the connection. /// /// Useful to close the connection without having access to the stream. /// /// # Example /// ``` /// let handle = stream.get_close_handle(); /// if let Err(error) = handle_stream(stream) { /// handle.close(0x01); /// } /// ``` pub fn get_close_handle(&self) -> MuxStreamCloser { self.tx.get_close_handle() } /// Get a protocol extension stream to send protocol extension packets. pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { self.tx.get_protocol_extension_stream() } /// Close the stream. You will no longer be able to write or read after this has been called. pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { self.tx.close(reason).await } /// Split the stream into read and write parts, consuming it. pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) { (self.rx, self.tx) } /// Turn the stream into one that implements futures `Stream + Sink`, consuming it. pub fn into_io(self) -> MuxStreamIo { MuxStreamIo { rx: self.rx.into_stream(), tx: self.tx.into_sink(), } } } /// Close handle for a multiplexor stream. #[derive(Clone)] pub struct MuxStreamCloser { /// ID of the stream. pub stream_id: u32, close_channel: mpsc::Sender, is_closed: Arc, } impl MuxStreamCloser { /// Close the stream. You will no longer be able to write or read after this has been called. pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(WispError::StreamAlreadyClosed); } self.is_closed.store(true, Ordering::Release); let (tx, rx) = oneshot::channel::>(); self.close_channel .send_async(WsEvent::Close( Packet::new_close(self.stream_id, reason), tx, )) .await .map_err(|_| WispError::MuxMessageFailedToSend)?; rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; Ok(()) } } /// Stream for sending arbitrary protocol extension packets. pub struct MuxProtocolExtensionStream { /// ID of the stream. pub stream_id: u32, pub(crate) tx: LockedWebSocketWrite, pub(crate) is_closed: Arc, } impl MuxProtocolExtensionStream { /// Send a protocol extension packet with this stream's ID. pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(WispError::StreamAlreadyClosed); } let mut encoded = BytesMut::with_capacity(1 + 4 + data.len()); encoded.put_u8(packet_type); encoded.put_u32_le(self.stream_id); encoded.extend(data); self.tx.write_frame(Frame::binary(encoded)).await } } pin_project! { /// Multiplexor stream that implements futures `Stream + Sink`. pub struct MuxStreamIo { #[pin] rx: Pin + Send>>, #[pin] tx: Pin + Send>>, } } impl MuxStreamIo { /// Turn the stream into one that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`. pub fn into_asyncrw(self) -> MuxStreamAsyncRW { let (tx, rx) = self.split(); MuxStreamAsyncRW { rx: MuxStreamAsyncRead::new(rx), tx: MuxStreamAsyncWrite::new(tx), } } } impl Stream for MuxStreamIo { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().rx.poll_next(cx).map(|x| x.map(Ok)) } } impl Sink for MuxStreamIo { type Error = std::io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project() .tx .poll_ready(cx) .map_err(std::io::Error::other) } fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { self.project() .tx .start_send(item) .map_err(std::io::Error::other) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project() .tx .poll_flush(cx) .map_err(std::io::Error::other) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project() .tx .poll_close(cx) .map_err(std::io::Error::other) } } pin_project! { /// Multiplexor stream that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`. pub struct MuxStreamAsyncRW { #[pin] rx: MuxStreamAsyncRead, #[pin] tx: MuxStreamAsyncWrite, } } impl MuxStreamAsyncRW { /// Split the stream into read and write parts, consuming it. pub fn into_split(self) -> (MuxStreamAsyncRead, MuxStreamAsyncWrite) { (self.rx, self.tx) } } impl AsyncRead for MuxStreamAsyncRW { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { self.project().rx.poll_read(cx, buf) } fn poll_read_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [std::io::IoSliceMut<'_>], ) -> Poll> { self.project().rx.poll_read_vectored(cx, bufs) } } impl AsyncBufRead for MuxStreamAsyncRW { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().rx.poll_fill_buf(cx) } fn consume(self: Pin<&mut Self>, amt: usize) { self.project().rx.consume(amt) } } impl AsyncWrite for MuxStreamAsyncRW { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { self.project().tx.poll_write(cx, buf) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().tx.poll_flush(cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().tx.poll_close(cx) } } pin_project! { /// Read side of a multiplexor stream that implements futures `AsyncRead + AsyncBufRead`. pub struct MuxStreamAsyncRead { #[pin] rx: IntoAsyncRead>, } } impl MuxStreamAsyncRead { pub(crate) fn new(stream: SplitStream) -> Self { Self { rx: stream.into_async_read(), } } } impl AsyncRead for MuxStreamAsyncRead { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { self.project().rx.poll_read(cx, buf) } fn poll_read_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [std::io::IoSliceMut<'_>], ) -> Poll> { self.project().rx.poll_read_vectored(cx, bufs) } } impl AsyncBufRead for MuxStreamAsyncRead { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().rx.poll_fill_buf(cx) } fn consume(self: Pin<&mut Self>, amt: usize) { self.project().rx.consume(amt) } } pin_project! { /// Write side of a multiplexor stream that implements futures `AsyncWrite`. pub struct MuxStreamAsyncWrite { #[pin] tx: SplitSink, } } impl MuxStreamAsyncWrite { pub(crate) fn new(sink: SplitSink) -> Self { Self { tx: sink } } } impl AsyncWrite for MuxStreamAsyncWrite { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let mut this = self.project(); ready!(this.tx.as_mut().poll_ready(cx))?; match this.tx.start_send(Bytes::copy_from_slice(buf)) { Ok(()) => Poll::Ready(Ok(buf.len())), Err(e) => Poll::Ready(Err(e)), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().tx.poll_flush(cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().tx.poll_close(cx) } }