diff --git a/server/src/handle/wisp/mod.rs b/server/src/handle/wisp/mod.rs index d3eac0d..c48763d 100644 --- a/server/src/handle/wisp/mod.rs +++ b/server/src/handle/wisp/mod.rs @@ -153,7 +153,7 @@ async fn handle_stream( muxstream.write(&data[..size]).await?; } data = muxstream.read() => { - if let Some(data) = data { + if let Some(data) = data? { stream.send(&data).await?; } else { break Ok(()); diff --git a/wisp/src/mux/inner.rs b/wisp/src/mux/inner.rs index 301eacc..265bc48 100644 --- a/wisp/src/mux/inner.rs +++ b/wisp/src/mux/inner.rs @@ -177,7 +177,11 @@ impl MuxInner { stream_id: u32, stream_type: StreamType, ) -> Result<(MuxMapValue, MuxStream), WispError> { - let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize); + let (ch_tx, ch_rx) = mpsc::bounded(if self.role == Role::Server { + self.buffer_size as usize + } else { + usize::MAX + }); let should_flow_control = self.tcp_extensions.contains(&stream_type.into()); let flow_control_event: Arc = Event::new().into(); diff --git a/wisp/src/stream/compat.rs b/wisp/src/stream/compat.rs index 508ef9a..dc5c7e5 100644 --- a/wisp/src/stream/compat.rs +++ b/wisp/src/stream/compat.rs @@ -73,7 +73,7 @@ pin_project! { /// Read side of a multiplexor stream that implements futures `Stream`. pub struct MuxStreamIoStream { #[pin] - pub(crate) rx: Pin + Send>>, + pub(crate) rx: Pin> + Send>>, pub(crate) is_closed: Arc, pub(crate) close_reason: Arc, } @@ -98,7 +98,7 @@ impl MuxStreamIoStream { impl Stream for MuxStreamIoStream { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().rx.poll_next(cx).map(|x| x.map(Ok)) + self.project().rx.poll_next(cx).map_err(std::io::Error::other) } } diff --git a/wisp/src/stream/mod.rs b/wisp/src/stream/mod.rs index 7635a7f..38e9f4c 100644 --- a/wisp/src/stream/mod.rs +++ b/wisp/src/stream/mod.rs @@ -3,7 +3,9 @@ mod sink_unfold; pub use compat::*; use crate::{ - inner::WsEvent, ws::{Frame, LockedWebSocketWrite, Payload}, AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError + inner::WsEvent, + ws::{Frame, LockedWebSocketWrite, Payload}, + AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError, }; use bytes::{BufMut, Bytes, BytesMut}; @@ -42,13 +44,13 @@ pub struct MuxStreamRead { impl MuxStreamRead { /// Read an event from the stream. - pub async fn read(&self) -> Option { - if self.is_closed.load(Ordering::Acquire) { - return None; + pub async fn read(&self) -> Result, WispError> { + if self.rx.is_empty() && self.is_closed.load(Ordering::Acquire) { + return Ok(None); } let bytes = select! { - x = self.rx.recv_async() => x.ok()?, - _ = self.is_closed_event.listen().fuse() => return None + x = self.rx.recv_async() => x.map_err(|_| WispError::MuxMessageFailedToRecv)?, + _ = self.is_closed_event.listen().fuse() => return Ok(None) }; if self.role == Role::Server && self.should_flow_control { let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1; @@ -61,17 +63,18 @@ impl MuxStreamRead { ) .into(), ) - .await - .ok()?; + .await?; self.flow_control_read.store(0, Ordering::Release); } } - Some(bytes) + Ok(Some(bytes)) } - pub(crate) fn into_inner_stream(self) -> Pin + Send>> { + pub(crate) fn into_inner_stream( + self, + ) -> Pin> + Send>> { Box::pin(stream::unfold(self, |rx| async move { - Some((rx.read().await?, rx)) + Some((rx.read().await.transpose()?, rx)) })) } @@ -311,7 +314,7 @@ impl MuxStream { } /// Read an event from the stream. - pub async fn read(&self) -> Option { + pub async fn read(&self) -> Result, WispError> { self.rx.read().await }