they should add a cancellation safety lint

This commit is contained in:
Toshit Chawda 2024-09-06 20:47:16 -07:00
parent 9d1604cc3e
commit 9d697416d9
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 62 additions and 43 deletions

View file

@ -1,22 +1,19 @@
use std::{
sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
},
use std::sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
};
use crate::{
extensions::AnyProtocolExtension,
ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
AtomicCloseReason, ClosePacket, CloseReason, ConnectPacket, MuxStream, Packet, PacketType,
Role, StreamType, WispError,
};
use nohash_hasher::IntMap;
use bytes::{Bytes, BytesMut};
use event_listener::Event;
use flume as mpsc;
use futures::{channel::oneshot, FutureExt};
use futures::{channel::oneshot, select, FutureExt};
use nohash_hasher::IntMap;
pub(crate) enum WsEvent {
Close(Packet<'static>, oneshot::Sender<Result<(), WispError>>),
@ -26,7 +23,7 @@ pub(crate) enum WsEvent {
u16,
oneshot::Sender<Result<MuxStream, WispError>>,
),
WispMessage(Frame<'static>, Option<Frame<'static>>),
WispMessage(Option<Packet<'static>>, Option<Frame<'static>>),
EndFut(Option<CloseReason>),
}
@ -43,12 +40,14 @@ struct MuxMapValue {
}
pub struct MuxInner<R: WebSocketRead + Send> {
rx: R,
// gets taken by the mux task
rx: Option<R>,
tx: LockedWebSocketWrite,
extensions: Vec<AnyProtocolExtension>,
role: Role,
fut_rx: mpsc::Receiver<WsEvent>,
// gets taken by the mux task
fut_rx: Option<mpsc::Receiver<WsEvent>>,
fut_tx: mpsc::Sender<WsEvent>,
fut_exited: Arc<AtomicBool>,
@ -79,10 +78,10 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
(
Self {
rx,
rx: Some(rx),
tx,
fut_rx,
fut_rx: Some(fut_rx),
fut_tx,
fut_exited: fut_exited.clone(),
@ -115,10 +114,10 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
(
Self {
rx,
rx: Some(rx),
tx,
fut_rx,
fut_rx: Some(fut_rx),
fut_tx,
fut_exited: fut_exited.clone(),
@ -205,31 +204,52 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
stream.flow_control_event.notify(usize::MAX);
}
async fn get_message(&mut self) -> Result<Option<WsEvent>, WispError> {
futures::select! {
x = self.fut_rx.recv_async().fuse() => Ok(x.ok()),
x = self.rx.wisp_read_split(&self.tx).fuse() => {
let (mut frame, optional_frame) = x?;
if frame.opcode == OpCode::Close {
return Ok(None);
}
async fn process_wisp_message(
&mut self,
rx: &mut R,
msg: Result<(Frame<'static>, Option<Frame<'static>>), WispError>,
) -> Result<Option<WsEvent>, WispError> {
let (mut frame, optional_frame) = msg?;
if frame.opcode == OpCode::Close {
return Ok(None);
}
if let Some(ref extra_frame) = optional_frame {
if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() {
let mut payload = BytesMut::from(frame.payload);
payload.extend_from_slice(&extra_frame.payload);
frame.payload = Payload::Bytes(payload);
}
}
Ok(Some(WsEvent::WispMessage(frame, optional_frame)))
if let Some(ref extra_frame) = optional_frame {
if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() {
let mut payload = BytesMut::from(frame.payload);
payload.extend_from_slice(&extra_frame.payload);
frame.payload = Payload::Bytes(payload);
}
}
let packet =
Packet::maybe_handle_extension(frame, &mut self.extensions, rx, &self.tx).await?;
Ok(Some(WsEvent::WispMessage(packet, optional_frame)))
}
async fn stream_loop(&mut self) -> Result<(), WispError> {
let mut next_free_stream_id: u32 = 1;
while let Some(msg) = self.get_message().await? {
let mut rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?;
let tx = self.tx.clone();
let fut_rx = self.fut_rx.take().ok_or(WispError::MuxTaskStarted)?;
let mut recv_fut = fut_rx.recv_async().fuse();
let mut read_fut = rx.wisp_read_split(&tx).fuse();
while let Some(msg) = select! {
x = recv_fut => {
drop(recv_fut);
recv_fut = fut_rx.recv_async().fuse();
Ok(x.ok())
},
x = read_fut => {
drop(read_fut);
let ret = self.process_wisp_message(&mut rx, x).await;
read_fut = rx.wisp_read_split(&tx).fuse();
ret
}
}? {
match msg {
WsEvent::CreateStream(stream_type, host, port, channel) => {
let ret: Result<MuxStream, WispError> = async {
@ -275,15 +295,8 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
}
break;
}
WsEvent::WispMessage(frame, optional_frame) => {
if let Some(packet) = Packet::maybe_handle_extension(
frame,
&mut self.extensions,
&mut self.rx,
&mut self.tx,
)
.await?
{
WsEvent::WispMessage(packet, optional_frame) => {
if let Some(packet) = packet {
let should_break = match self.role {
Role::Server => {
self.server_handle_packet(packet, optional_frame).await?

View file

@ -96,6 +96,8 @@ pub enum WispError {
MuxMessageFailedToRecv,
/// Multiplexor task ended.
MuxTaskEnded,
/// Multiplexor task already started.
MuxTaskStarted,
}
impl From<std::str::Utf8Error> for WispError {
@ -150,6 +152,7 @@ impl std::fmt::Display for WispError {
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"),
Self::MuxTaskEnded => write!(f, "Multiplexor task ended"),
Self::MuxTaskStarted => write!(f, "Multiplexor task already started"),
}
}
}

View file

@ -1,5 +1,8 @@
use crate::{
inner::WsEvent, sink_unfold, ws::{Frame, LockedWebSocketWrite, Payload}, AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError
inner::WsEvent,
sink_unfold,
ws::{Frame, LockedWebSocketWrite, Payload},
AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError,
};
use bytes::{BufMut, Bytes, BytesMut};