use knockoff dynosaur to remove async_trait on wsr/wsw

This commit is contained in:
Toshit Chawda 2024-11-23 15:00:12 -08:00
parent 5e54465e58
commit 9129d767f8
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
31 changed files with 692 additions and 258 deletions

View file

@ -5,23 +5,23 @@ use std::sync::{
use crate::{
extensions::AnyProtocolExtension,
ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead, WebSocketWrite},
AtomicCloseReason, ClosePacket, CloseReason, ConnectPacket, MuxStream, Packet, PacketType,
Role, StreamType, WispError,
};
use bytes::{Bytes, BytesMut};
use event_listener::Event;
use flume as mpsc;
use futures::{channel::oneshot, select, FutureExt};
use futures::{channel::oneshot, select, stream::unfold, FutureExt, StreamExt};
use nohash_hasher::IntMap;
pub(crate) enum WsEvent {
pub(crate) enum WsEvent<W: WebSocketWrite + 'static> {
Close(Packet<'static>, oneshot::Sender<Result<(), WispError>>),
CreateStream(
StreamType,
String,
u16,
oneshot::Sender<Result<MuxStream, WispError>>,
oneshot::Sender<Result<MuxStream<W>, WispError>>,
),
SendPing(Payload<'static>, oneshot::Sender<Result<(), WispError>>),
SendPong(Payload<'static>),
@ -43,20 +43,21 @@ struct MuxMapValue {
is_closed_event: Arc<Event>,
}
pub struct MuxInner<R: WebSocketRead + Send> {
pub struct MuxInner<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> {
// gets taken by the mux task
rx: Option<R>,
// gets taken by the mux task
maybe_downgrade_packet: Option<Packet<'static>>,
tx: LockedWebSocketWrite,
extensions: Vec<AnyProtocolExtension>,
tx: LockedWebSocketWrite<W>,
// gets taken by the mux task
extensions: Option<Vec<AnyProtocolExtension>>,
tcp_extensions: Vec<u8>,
role: Role,
// gets taken by the mux task
actor_rx: Option<mpsc::Receiver<WsEvent>>,
actor_tx: mpsc::Sender<WsEvent>,
actor_rx: Option<mpsc::Receiver<WsEvent<W>>>,
actor_tx: mpsc::Sender<WsEvent<W>>,
fut_exited: Arc<AtomicBool>,
stream_map: IntMap<u32, MuxMapValue>,
@ -64,16 +65,16 @@ pub struct MuxInner<R: WebSocketRead + Send> {
buffer_size: u32,
target_buffer_size: u32,
server_tx: mpsc::Sender<(ConnectPacket, MuxStream)>,
server_tx: mpsc::Sender<(ConnectPacket, MuxStream<W>)>,
}
pub struct MuxInnerResult<R: WebSocketRead + Send> {
pub mux: MuxInner<R>,
pub struct MuxInnerResult<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> {
pub mux: MuxInner<R, W>,
pub actor_exited: Arc<AtomicBool>,
pub actor_tx: mpsc::Sender<WsEvent>,
pub actor_tx: mpsc::Sender<WsEvent<W>>,
}
impl<R: WebSocketRead + Send> MuxInner<R> {
impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
fn get_tcp_extensions(extensions: &[AnyProtocolExtension]) -> Vec<u8> {
extensions
.iter()
@ -83,18 +84,19 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
.collect()
}
#[allow(clippy::type_complexity)]
pub fn new_server(
rx: R,
maybe_downgrade_packet: Option<Packet<'static>>,
tx: LockedWebSocketWrite,
tx: LockedWebSocketWrite<W>,
extensions: Vec<AnyProtocolExtension>,
buffer_size: u32,
) -> (
MuxInnerResult<R>,
mpsc::Receiver<(ConnectPacket, MuxStream)>,
MuxInnerResult<R, W>,
mpsc::Receiver<(ConnectPacket, MuxStream<W>)>,
) {
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent>(256);
let (server_tx, server_rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent<W>>(256);
let (server_tx, server_rx) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>();
let ret_fut_tx = fut_tx.clone();
let fut_exited = Arc::new(AtomicBool::new(false));
@ -110,7 +112,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
fut_exited: fut_exited.clone(),
tcp_extensions: Self::get_tcp_extensions(&extensions),
extensions,
extensions: Some(extensions),
buffer_size,
target_buffer_size: ((buffer_size as u64 * 90) / 100) as u32,
@ -130,12 +132,12 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
pub fn new_client(
rx: R,
maybe_downgrade_packet: Option<Packet<'static>>,
tx: LockedWebSocketWrite,
tx: LockedWebSocketWrite<W>,
extensions: Vec<AnyProtocolExtension>,
buffer_size: u32,
) -> MuxInnerResult<R> {
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent>(256);
let (server_tx, _) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
) -> MuxInnerResult<R, W> {
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent<W>>(256);
let (server_tx, _) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>();
let ret_fut_tx = fut_tx.clone();
let fut_exited = Arc::new(AtomicBool::new(false));
@ -150,7 +152,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
fut_exited: fut_exited.clone(),
tcp_extensions: Self::get_tcp_extensions(&extensions),
extensions,
extensions: Some(extensions),
buffer_size,
target_buffer_size: 0,
@ -183,7 +185,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
&mut self,
stream_id: u32,
stream_type: StreamType,
) -> Result<(MuxMapValue, MuxStream), WispError> {
) -> Result<(MuxMapValue, MuxStream<W>), WispError> {
let (ch_tx, ch_rx) = mpsc::bounded(if self.role == Role::Server {
self.buffer_size as usize
} else {
@ -241,11 +243,12 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
}
async fn process_wisp_message(
&mut self,
rx: &mut R,
msg: Result<(Frame<'static>, Option<Frame<'static>>), WispError>,
) -> Result<Option<WsEvent>, WispError> {
let (mut frame, optional_frame) = msg?;
tx: &LockedWebSocketWrite<W>,
extensions: &mut [AnyProtocolExtension],
msg: (Frame<'static>, Option<Frame<'static>>),
) -> Result<Option<WsEvent<W>>, WispError> {
let (mut frame, optional_frame) = msg;
if frame.opcode == OpCode::Close {
return Ok(None);
} else if frame.opcode == OpCode::Ping {
@ -262,8 +265,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
}
}
let packet =
Packet::maybe_handle_extension(frame, &mut self.extensions, rx, &self.tx).await?;
let packet = Packet::maybe_handle_extension(frame, extensions, rx, tx).await?;
Ok(Some(WsEvent::WispMessage(packet, optional_frame)))
}
@ -271,36 +273,47 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
async fn stream_loop(&mut self) -> Result<(), WispError> {
let mut next_free_stream_id: u32 = 1;
let mut rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?;
let rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?;
let maybe_downgrade_packet = self.maybe_downgrade_packet.take();
let tx = self.tx.clone();
let fut_rx = self.actor_rx.take().ok_or(WispError::MuxTaskStarted)?;
let extensions = self.extensions.take().ok_or(WispError::MuxTaskStarted)?;
if let Some(downgrade_packet) = maybe_downgrade_packet {
if self.handle_packet(downgrade_packet, None).await? {
return Ok(());
}
}
let mut read_stream = Box::pin(unfold(
(rx, tx.clone(), extensions),
|(mut rx, tx, mut extensions)| async {
let ret = async {
let msg = rx.wisp_read_split(&tx).await?;
Self::process_wisp_message(&mut rx, &tx, &mut extensions, msg).await
}
.await;
ret.transpose().map(|x| (x, (rx, tx, extensions)))
},
))
.fuse();
let mut recv_fut = fut_rx.recv_async().fuse();
let mut read_fut = rx.wisp_read_split(&tx).fuse();
while let Some(msg) = select! {
x = recv_fut => {
drop(recv_fut);
recv_fut = fut_rx.recv_async().fuse();
Ok(x.ok())
},
x = read_fut => {
drop(read_fut);
let ret = self.process_wisp_message(&mut rx, x).await;
read_fut = rx.wisp_read_split(&tx).fuse();
ret
x = read_stream.next() => {
x.transpose()
}
}? {
match msg {
WsEvent::CreateStream(stream_type, host, port, channel) => {
let ret: Result<MuxStream, WispError> = async {
let ret: Result<MuxStream<W>, WispError> = async {
let stream_id = next_free_stream_id;
let next_stream_id = next_free_stream_id
.checked_add(1)