From 9d697416d9799d11b9321c78ca1cc9c0d2923f42 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Fri, 6 Sep 2024 20:47:16 -0700 Subject: [PATCH] they should add a cancellation safety lint --- wisp/src/inner.rs | 97 ++++++++++++++++++++++++++-------------------- wisp/src/lib.rs | 3 ++ wisp/src/stream.rs | 5 ++- 3 files changed, 62 insertions(+), 43 deletions(-) diff --git a/wisp/src/inner.rs b/wisp/src/inner.rs index 70a2f97..58a2338 100644 --- a/wisp/src/inner.rs +++ b/wisp/src/inner.rs @@ -1,22 +1,19 @@ -use std::{ - sync::{ - atomic::{AtomicBool, AtomicU32, Ordering}, - Arc, - }, +use std::sync::{ + atomic::{AtomicBool, AtomicU32, Ordering}, + Arc, }; - use crate::{ extensions::AnyProtocolExtension, ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead}, AtomicCloseReason, ClosePacket, CloseReason, ConnectPacket, MuxStream, Packet, PacketType, Role, StreamType, WispError, }; -use nohash_hasher::IntMap; use bytes::{Bytes, BytesMut}; use event_listener::Event; use flume as mpsc; -use futures::{channel::oneshot, FutureExt}; +use futures::{channel::oneshot, select, FutureExt}; +use nohash_hasher::IntMap; pub(crate) enum WsEvent { Close(Packet<'static>, oneshot::Sender>), @@ -26,7 +23,7 @@ pub(crate) enum WsEvent { u16, oneshot::Sender>, ), - WispMessage(Frame<'static>, Option>), + WispMessage(Option>, Option>), EndFut(Option), } @@ -43,12 +40,14 @@ struct MuxMapValue { } pub struct MuxInner { - rx: R, + // gets taken by the mux task + rx: Option, tx: LockedWebSocketWrite, extensions: Vec, role: Role, - fut_rx: mpsc::Receiver, + // gets taken by the mux task + fut_rx: Option>, fut_tx: mpsc::Sender, fut_exited: Arc, @@ -79,10 +78,10 @@ impl MuxInner { ( Self { - rx, + rx: Some(rx), tx, - fut_rx, + fut_rx: Some(fut_rx), fut_tx, fut_exited: fut_exited.clone(), @@ -115,10 +114,10 @@ impl MuxInner { ( Self { - rx, + rx: Some(rx), tx, - fut_rx, + fut_rx: Some(fut_rx), fut_tx, fut_exited: fut_exited.clone(), @@ -205,31 +204,52 @@ impl MuxInner { stream.flow_control_event.notify(usize::MAX); } - async fn get_message(&mut self) -> Result, WispError> { - futures::select! { - x = self.fut_rx.recv_async().fuse() => Ok(x.ok()), - x = self.rx.wisp_read_split(&self.tx).fuse() => { - let (mut frame, optional_frame) = x?; - if frame.opcode == OpCode::Close { - return Ok(None); - } + async fn process_wisp_message( + &mut self, + rx: &mut R, + msg: Result<(Frame<'static>, Option>), WispError>, + ) -> Result, WispError> { + let (mut frame, optional_frame) = msg?; + if frame.opcode == OpCode::Close { + return Ok(None); + } - if let Some(ref extra_frame) = optional_frame { - if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() { - let mut payload = BytesMut::from(frame.payload); - payload.extend_from_slice(&extra_frame.payload); - frame.payload = Payload::Bytes(payload); - } - } - - Ok(Some(WsEvent::WispMessage(frame, optional_frame))) + if let Some(ref extra_frame) = optional_frame { + if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() { + let mut payload = BytesMut::from(frame.payload); + payload.extend_from_slice(&extra_frame.payload); + frame.payload = Payload::Bytes(payload); } } + + let packet = + Packet::maybe_handle_extension(frame, &mut self.extensions, rx, &self.tx).await?; + + Ok(Some(WsEvent::WispMessage(packet, optional_frame))) } async fn stream_loop(&mut self) -> Result<(), WispError> { let mut next_free_stream_id: u32 = 1; - while let Some(msg) = self.get_message().await? { + + let mut rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?; + let tx = self.tx.clone(); + let fut_rx = self.fut_rx.take().ok_or(WispError::MuxTaskStarted)?; + + let mut recv_fut = fut_rx.recv_async().fuse(); + let mut read_fut = rx.wisp_read_split(&tx).fuse(); + while let Some(msg) = select! { + x = recv_fut => { + drop(recv_fut); + recv_fut = fut_rx.recv_async().fuse(); + Ok(x.ok()) + }, + x = read_fut => { + drop(read_fut); + let ret = self.process_wisp_message(&mut rx, x).await; + read_fut = rx.wisp_read_split(&tx).fuse(); + ret + } + }? { match msg { WsEvent::CreateStream(stream_type, host, port, channel) => { let ret: Result = async { @@ -275,15 +295,8 @@ impl MuxInner { } break; } - WsEvent::WispMessage(frame, optional_frame) => { - if let Some(packet) = Packet::maybe_handle_extension( - frame, - &mut self.extensions, - &mut self.rx, - &mut self.tx, - ) - .await? - { + WsEvent::WispMessage(packet, optional_frame) => { + if let Some(packet) = packet { let should_break = match self.role { Role::Server => { self.server_handle_packet(packet, optional_frame).await? diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 384d0a1..407df26 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -96,6 +96,8 @@ pub enum WispError { MuxMessageFailedToRecv, /// Multiplexor task ended. MuxTaskEnded, + /// Multiplexor task already started. + MuxTaskStarted, } impl From for WispError { @@ -150,6 +152,7 @@ impl std::fmt::Display for WispError { Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"), Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"), Self::MuxTaskEnded => write!(f, "Multiplexor task ended"), + Self::MuxTaskStarted => write!(f, "Multiplexor task already started"), } } } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 90e3ffc..4972a5b 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -1,5 +1,8 @@ use crate::{ - inner::WsEvent, sink_unfold, ws::{Frame, LockedWebSocketWrite, Payload}, AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError + inner::WsEvent, + sink_unfold, + ws::{Frame, LockedWebSocketWrite, Payload}, + AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError, }; use bytes::{BufMut, Bytes, BytesMut};