they should add a cancellation safety lint

This commit is contained in:
Toshit Chawda 2024-09-06 20:47:16 -07:00
parent 9d1604cc3e
commit 9d697416d9
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 62 additions and 43 deletions

View file

@ -1,22 +1,19 @@
use std::{ use std::sync::{
sync::{ atomic::{AtomicBool, AtomicU32, Ordering},
atomic::{AtomicBool, AtomicU32, Ordering}, Arc,
Arc,
},
}; };
use crate::{ use crate::{
extensions::AnyProtocolExtension, extensions::AnyProtocolExtension,
ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead}, ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
AtomicCloseReason, ClosePacket, CloseReason, ConnectPacket, MuxStream, Packet, PacketType, AtomicCloseReason, ClosePacket, CloseReason, ConnectPacket, MuxStream, Packet, PacketType,
Role, StreamType, WispError, Role, StreamType, WispError,
}; };
use nohash_hasher::IntMap;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use event_listener::Event; use event_listener::Event;
use flume as mpsc; use flume as mpsc;
use futures::{channel::oneshot, FutureExt}; use futures::{channel::oneshot, select, FutureExt};
use nohash_hasher::IntMap;
pub(crate) enum WsEvent { pub(crate) enum WsEvent {
Close(Packet<'static>, oneshot::Sender<Result<(), WispError>>), Close(Packet<'static>, oneshot::Sender<Result<(), WispError>>),
@ -26,7 +23,7 @@ pub(crate) enum WsEvent {
u16, u16,
oneshot::Sender<Result<MuxStream, WispError>>, oneshot::Sender<Result<MuxStream, WispError>>,
), ),
WispMessage(Frame<'static>, Option<Frame<'static>>), WispMessage(Option<Packet<'static>>, Option<Frame<'static>>),
EndFut(Option<CloseReason>), EndFut(Option<CloseReason>),
} }
@ -43,12 +40,14 @@ struct MuxMapValue {
} }
pub struct MuxInner<R: WebSocketRead + Send> { pub struct MuxInner<R: WebSocketRead + Send> {
rx: R, // gets taken by the mux task
rx: Option<R>,
tx: LockedWebSocketWrite, tx: LockedWebSocketWrite,
extensions: Vec<AnyProtocolExtension>, extensions: Vec<AnyProtocolExtension>,
role: Role, role: Role,
fut_rx: mpsc::Receiver<WsEvent>, // gets taken by the mux task
fut_rx: Option<mpsc::Receiver<WsEvent>>,
fut_tx: mpsc::Sender<WsEvent>, fut_tx: mpsc::Sender<WsEvent>,
fut_exited: Arc<AtomicBool>, fut_exited: Arc<AtomicBool>,
@ -79,10 +78,10 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
( (
Self { Self {
rx, rx: Some(rx),
tx, tx,
fut_rx, fut_rx: Some(fut_rx),
fut_tx, fut_tx,
fut_exited: fut_exited.clone(), fut_exited: fut_exited.clone(),
@ -115,10 +114,10 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
( (
Self { Self {
rx, rx: Some(rx),
tx, tx,
fut_rx, fut_rx: Some(fut_rx),
fut_tx, fut_tx,
fut_exited: fut_exited.clone(), fut_exited: fut_exited.clone(),
@ -205,31 +204,52 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
stream.flow_control_event.notify(usize::MAX); stream.flow_control_event.notify(usize::MAX);
} }
async fn get_message(&mut self) -> Result<Option<WsEvent>, WispError> { async fn process_wisp_message(
futures::select! { &mut self,
x = self.fut_rx.recv_async().fuse() => Ok(x.ok()), rx: &mut R,
x = self.rx.wisp_read_split(&self.tx).fuse() => { msg: Result<(Frame<'static>, Option<Frame<'static>>), WispError>,
let (mut frame, optional_frame) = x?; ) -> Result<Option<WsEvent>, WispError> {
if frame.opcode == OpCode::Close { let (mut frame, optional_frame) = msg?;
return Ok(None); if frame.opcode == OpCode::Close {
} return Ok(None);
}
if let Some(ref extra_frame) = optional_frame { if let Some(ref extra_frame) = optional_frame {
if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() { if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() {
let mut payload = BytesMut::from(frame.payload); let mut payload = BytesMut::from(frame.payload);
payload.extend_from_slice(&extra_frame.payload); payload.extend_from_slice(&extra_frame.payload);
frame.payload = Payload::Bytes(payload); frame.payload = Payload::Bytes(payload);
}
}
Ok(Some(WsEvent::WispMessage(frame, optional_frame)))
} }
} }
let packet =
Packet::maybe_handle_extension(frame, &mut self.extensions, rx, &self.tx).await?;
Ok(Some(WsEvent::WispMessage(packet, optional_frame)))
} }
async fn stream_loop(&mut self) -> Result<(), WispError> { async fn stream_loop(&mut self) -> Result<(), WispError> {
let mut next_free_stream_id: u32 = 1; let mut next_free_stream_id: u32 = 1;
while let Some(msg) = self.get_message().await? {
let mut rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?;
let tx = self.tx.clone();
let fut_rx = self.fut_rx.take().ok_or(WispError::MuxTaskStarted)?;
let mut recv_fut = fut_rx.recv_async().fuse();
let mut read_fut = rx.wisp_read_split(&tx).fuse();
while let Some(msg) = select! {
x = recv_fut => {
drop(recv_fut);
recv_fut = fut_rx.recv_async().fuse();
Ok(x.ok())
},
x = read_fut => {
drop(read_fut);
let ret = self.process_wisp_message(&mut rx, x).await;
read_fut = rx.wisp_read_split(&tx).fuse();
ret
}
}? {
match msg { match msg {
WsEvent::CreateStream(stream_type, host, port, channel) => { WsEvent::CreateStream(stream_type, host, port, channel) => {
let ret: Result<MuxStream, WispError> = async { let ret: Result<MuxStream, WispError> = async {
@ -275,15 +295,8 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
} }
break; break;
} }
WsEvent::WispMessage(frame, optional_frame) => { WsEvent::WispMessage(packet, optional_frame) => {
if let Some(packet) = Packet::maybe_handle_extension( if let Some(packet) = packet {
frame,
&mut self.extensions,
&mut self.rx,
&mut self.tx,
)
.await?
{
let should_break = match self.role { let should_break = match self.role {
Role::Server => { Role::Server => {
self.server_handle_packet(packet, optional_frame).await? self.server_handle_packet(packet, optional_frame).await?

View file

@ -96,6 +96,8 @@ pub enum WispError {
MuxMessageFailedToRecv, MuxMessageFailedToRecv,
/// Multiplexor task ended. /// Multiplexor task ended.
MuxTaskEnded, MuxTaskEnded,
/// Multiplexor task already started.
MuxTaskStarted,
} }
impl From<std::str::Utf8Error> for WispError { impl From<std::str::Utf8Error> for WispError {
@ -150,6 +152,7 @@ impl std::fmt::Display for WispError {
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"), Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"), Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"),
Self::MuxTaskEnded => write!(f, "Multiplexor task ended"), Self::MuxTaskEnded => write!(f, "Multiplexor task ended"),
Self::MuxTaskStarted => write!(f, "Multiplexor task already started"),
} }
} }
} }

View file

@ -1,5 +1,8 @@
use crate::{ use crate::{
inner::WsEvent, sink_unfold, ws::{Frame, LockedWebSocketWrite, Payload}, AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError inner::WsEvent,
sink_unfold,
ws::{Frame, LockedWebSocketWrite, Payload},
AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError,
}; };
use bytes::{BufMut, Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};