use blazingly fast flume channels 🚀

This commit is contained in:
Toshit Chawda 2024-04-15 17:42:49 -07:00
parent 5af56fe582
commit 5e741d3808
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
11 changed files with 225 additions and 135 deletions

View file

@ -3,7 +3,7 @@ use std::ops::Deref;
use async_trait::async_trait;
use bytes::BytesMut;
use fastwebsockets::{
FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite,
CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite
};
use tokio::io::{AsyncRead, AsyncWrite};
@ -77,4 +77,8 @@ impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), WispError> {
self.write_frame(frame.into()).await.map_err(|e| e.into())
}
async fn wisp_close(&mut self) -> Result<(), WispError> {
self.write_frame(Frame::close(CloseCode::Normal.into(), b"")).await.map_err(|e| e.into())
}
}

View file

@ -1,4 +1,4 @@
#![deny(missing_docs)]
#![deny(missing_docs, warnings)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! A library for easily creating [Wisp] clients and servers.
//!
@ -19,9 +19,8 @@ use bytes::Bytes;
use dashmap::DashMap;
use event_listener::Event;
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
use futures::{
channel::{mpsc, oneshot}, lock::Mutex, select, Future, FutureExt, SinkExt, StreamExt
};
use flume as mpsc;
use futures::{channel::oneshot, select, Future, FutureExt};
use futures_timer::Delay;
use std::{
sync::{
@ -151,11 +150,12 @@ impl std::fmt::Display for WispError {
impl std::error::Error for WispError {}
struct MuxMapValue {
stream: Mutex<mpsc::Sender<Bytes>>,
stream: mpsc::Sender<Bytes>,
stream_type: StreamType,
flow_control: Arc<AtomicU32>,
flow_control_event: Arc<Event>,
is_closed: Arc<AtomicBool>,
is_closed_event: Arc<Event>,
}
struct MuxInner {
@ -170,7 +170,7 @@ impl MuxInner {
rx: R,
extensions: Vec<AnyProtocolExtension>,
close_rx: mpsc::Receiver<WsEvent>,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
close_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError>
where
@ -210,20 +210,60 @@ impl MuxInner {
};
for x in self.stream_map.iter_mut() {
x.is_closed.store(true, Ordering::Release);
x.stream.lock().await.disconnect();
x.stream.lock().await.close_channel();
x.is_closed_event.notify(usize::MAX);
}
self.stream_map.clear();
let _ = self.tx.close().await;
ret
}
async fn create_new_stream(
&self,
stream_id: u32,
stream_type: StreamType,
role: Role,
stream_tx: mpsc::Sender<WsEvent>,
target_buffer_size: u32,
) -> Result<(MuxMapValue, MuxStream), WispError> {
let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize);
let flow_control_event: Arc<Event> = Event::new().into();
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
let is_closed_event: Arc<Event> = Event::new().into();
Ok((
MuxMapValue {
stream: ch_tx,
stream_type,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
is_closed: is_closed.clone(),
is_closed_event: is_closed_event.clone(),
},
MuxStream::new(
stream_id,
role,
stream_type,
ch_rx,
stream_tx.clone(),
is_closed,
is_closed_event,
flow_control,
flow_control_event,
target_buffer_size,
),
))
}
async fn stream_loop(
&self,
mut stream_rx: mpsc::Receiver<WsEvent>,
stream_rx: mpsc::Receiver<WsEvent>,
stream_tx: mpsc::Sender<WsEvent>,
) {
let mut next_free_stream_id: u32 = 1;
while let Some(msg) = stream_rx.next().await {
while let Ok(msg) = stream_rx.recv_async().await {
match msg {
WsEvent::SendPacket(packet, channel) => {
if self.stream_map.get(&packet.stream_id).is_some() {
@ -234,16 +274,20 @@ impl MuxInner {
}
WsEvent::CreateStream(stream_type, host, port, channel) => {
let ret: Result<MuxStream, WispError> = async {
let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize);
let stream_id = next_free_stream_id;
let next_stream_id = next_free_stream_id
.checked_add(1)
.ok_or(WispError::MaxStreamCountReached)?;
let flow_control_event: Arc<Event> = Event::new().into();
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
let (map_value, stream) = self
.create_new_stream(
stream_id,
stream_type,
Role::Client,
stream_tx.clone(),
0,
)
.await?;
self.tx
.write_frame(
@ -251,39 +295,19 @@ impl MuxInner {
)
.await?;
self.stream_map.insert(stream_id, map_value);
next_free_stream_id = next_stream_id;
self.stream_map.insert(
stream_id,
MuxMapValue {
stream: ch_tx.into(),
stream_type,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
is_closed: is_closed.clone(),
},
);
Ok(MuxStream::new(
stream_id,
Role::Client,
stream_type,
ch_rx,
stream_tx.clone(),
is_closed,
flow_control,
flow_control_event,
0,
))
Ok(stream)
}
.await;
let _ = channel.send(ret);
}
WsEvent::Close(packet, channel) => {
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
stream.stream.lock().await.disconnect();
stream.stream.lock().await.close_channel();
let _ = channel.send(self.tx.write_frame(packet.into()).await);
drop(stream.stream)
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
@ -305,8 +329,8 @@ impl MuxInner {
&self,
mut rx: R,
mut extensions: Vec<AnyProtocolExtension>,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
close_tx: mpsc::Sender<WsEvent>,
muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
stream_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead + Send,
@ -325,42 +349,24 @@ impl MuxInner {
use PacketType::*;
match packet.packet_type {
Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize);
let stream_type = inner_packet.stream_type;
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
let flow_control_event: Arc<Event> = Event::new().into();
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
self.stream_map.insert(
packet.stream_id,
MuxMapValue {
stream: ch_tx.into(),
stream_type,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
is_closed: is_closed.clone(),
},
);
let (map_value, stream) = self
.create_new_stream(
packet.stream_id,
inner_packet.stream_type,
Role::Server,
stream_tx.clone(),
target_buffer_size,
)
.await?;
muxstream_sender
.unbounded_send((
inner_packet,
MuxStream::new(
packet.stream_id,
Role::Server,
stream_type,
ch_rx,
close_tx.clone(),
is_closed,
flow_control,
flow_control_event,
target_buffer_size,
),
))
.map_err(|x| WispError::Other(Box::new(x)))?;
.send_async((inner_packet, stream))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)?;
self.stream_map.insert(packet.stream_id, map_value);
}
Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.lock().await.send(data).await;
let _ = stream.stream.send_async(data).await;
if stream.stream_type == StreamType::Tcp {
stream.flow_control.store(
stream
@ -379,8 +385,8 @@ impl MuxInner {
}
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release);
stream.stream.lock().await.disconnect();
stream.stream.lock().await.close_channel();
stream.is_closed_event.notify(usize::MAX);
drop(stream.stream)
}
}
}
@ -409,7 +415,7 @@ impl MuxInner {
Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.lock().await.send(data).await;
let _ = stream.stream.send_async(data).await;
}
}
Continue(inner_packet) => {
@ -428,8 +434,8 @@ impl MuxInner {
}
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release);
stream.stream.lock().await.disconnect();
stream.stream.lock().await.close_channel();
stream.is_closed_event.notify(usize::MAX);
drop(stream.stream)
}
}
}
@ -465,7 +471,7 @@ pub struct ServerMux {
/// Extensions that are supported by both sides.
pub supported_extension_ids: Vec<u8>,
close_tx: mpsc::Sender<WsEvent>,
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>,
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
}
impl ServerMux {
@ -484,7 +490,7 @@ impl ServerMux {
R: ws::WebSocketRead + Send,
W: ws::WebSocketWrite + Send + 'static,
{
let (close_tx, close_rx) = mpsc::channel::<WsEvent>(256);
let (close_tx, close_rx) = mpsc::bounded::<WsEvent>(256);
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
let write = ws::LockedWebSocketWrite::new(Box::new(write));
@ -547,12 +553,12 @@ impl ServerMux {
/// Wait for a stream to be created.
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> {
self.muxstream_recv.next().await
self.muxstream_recv.recv_async().await.ok()
}
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut(reason))
.send_async(WsEvent::EndFut(reason))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
@ -574,6 +580,13 @@ impl ServerMux {
.await
}
}
impl Drop for ServerMux {
fn drop(&mut self) {
let _ = self.close_tx.send(WsEvent::EndFut(None));
}
}
/// Client side multiplexor.
///
/// # Example
@ -595,7 +608,7 @@ pub struct ClientMux {
pub downgraded: bool,
/// Extensions that are supported by both sides.
pub supported_extension_ids: Vec<u8>,
close_tx: mpsc::Sender<WsEvent>,
stream_tx: mpsc::Sender<WsEvent>,
}
impl ClientMux {
@ -654,10 +667,10 @@ impl ClientMux {
extension.handle_handshake(&mut read, &write).await?;
}
let (tx, rx) = mpsc::channel::<WsEvent>(256);
let (tx, rx) = mpsc::bounded::<WsEvent>(256);
Ok((
Self {
close_tx: tx.clone(),
stream_tx: tx.clone(),
downgraded,
supported_extension_ids: supported_extensions
.iter()
@ -697,16 +710,16 @@ impl ClientMux {
return Err(WispError::UdpExtensionNotSupported);
}
let (tx, rx) = oneshot::channel();
self.close_tx
.send(WsEvent::CreateStream(stream_type, host, port, tx))
self.stream_tx
.send_async(WsEvent::CreateStream(stream_type, host, port, tx))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
}
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut(reason))
self.stream_tx
.send_async(WsEvent::EndFut(reason))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
@ -728,3 +741,9 @@ impl ClientMux {
.await
}
}
impl Drop for ClientMux {
fn drop(&mut self) {
let _ = self.stream_tx.send(WsEvent::EndFut(None));
}
}

View file

@ -1,13 +1,14 @@
use crate::{sink_unfold, CloseReason, Packet, Role, StreamType, WispError};
use async_io_stream::IoStream;
pub use async_io_stream::IoStream;
use bytes::Bytes;
use event_listener::Event;
use flume as mpsc;
use futures::{
channel::{mpsc, oneshot},
stream,
channel::oneshot,
select, stream,
task::{Context, Poll},
Sink, SinkExt, Stream, StreamExt,
FutureExt, Sink, Stream,
};
use pin_project_lite::pin_project;
use std::{
@ -40,6 +41,7 @@ pub struct MuxStreamRead {
tx: mpsc::Sender<WsEvent>,
rx: mpsc::Receiver<Bytes>,
is_closed: Arc<AtomicBool>,
is_closed_event: Arc<Event>,
flow_control: Arc<AtomicU32>,
flow_control_read: AtomicU32,
target_flow_control: u32,
@ -51,13 +53,16 @@ impl MuxStreamRead {
if self.is_closed.load(Ordering::Acquire) {
return None;
}
let bytes = self.rx.next().await?;
let bytes = select! {
x = self.rx.recv_async() => x.ok()?,
_ = self.is_closed_event.listen().fuse() => return None
};
if self.role == Role::Server && self.stream_type == StreamType::Tcp {
let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1;
if val > self.target_flow_control {
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx
.send(WsEvent::SendPacket(
.send_async(WsEvent::SendPacket(
Packet::new_continue(
self.stream_id,
self.flow_control.fetch_add(val, Ordering::AcqRel) + val,
@ -107,13 +112,13 @@ impl MuxStreamWrite {
}
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx
.send(WsEvent::SendPacket(
.send_async(WsEvent::SendPacket(
Packet::new_data(self.stream_id, data),
tx,
))
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
.map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
if self.role == Role::Client && self.stream_type == StreamType::Tcp {
self.flow_control.store(
self.flow_control.load(Ordering::Acquire).saturating_sub(1),
@ -151,13 +156,13 @@ impl MuxStreamWrite {
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx
.send(WsEvent::Close(
.send_async(WsEvent::Close(
Packet::new_close(self.stream_id, reason),
tx,
))
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
.map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
Ok(())
}
@ -179,6 +184,16 @@ impl MuxStreamWrite {
}
}
impl Drop for MuxStreamWrite {
fn drop(&mut self) {
if !self.is_closed.load(Ordering::Acquire) {
self.is_closed.store(true, Ordering::Release);
let (tx, _) = oneshot::channel();
let _ = self.tx.send(WsEvent::Close(Packet::new_close(self.stream_id, CloseReason::Unknown), tx));
}
}
}
/// Multiplexor stream.
pub struct MuxStream {
/// ID of the stream.
@ -196,6 +211,7 @@ impl MuxStream {
rx: mpsc::Receiver<Bytes>,
tx: mpsc::Sender<WsEvent>,
is_closed: Arc<AtomicBool>,
is_closed_event: Arc<Event>,
flow_control: Arc<AtomicU32>,
continue_recieved: Arc<Event>,
target_flow_control: u32,
@ -209,6 +225,7 @@ impl MuxStream {
tx: tx.clone(),
rx,
is_closed: is_closed.clone(),
is_closed_event: is_closed_event.clone(),
flow_control: flow_control.clone(),
flow_control_read: AtomicU32::new(0),
target_flow_control,
@ -288,13 +305,13 @@ impl MuxStreamCloser {
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.close_channel
.send(WsEvent::Close(
.send_async(WsEvent::Close(
Packet::new_close(self.stream_id, reason),
tx,
))
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
.map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
Ok(())
}

View file

@ -76,6 +76,9 @@ pub trait WebSocketRead {
pub trait WebSocketWrite {
/// Write a frame to the socket.
async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError>;
/// Close the socket.
async fn wisp_close(&mut self) -> Result<(), WispError>;
}
/// Locked WebSocket.
@ -88,9 +91,14 @@ impl LockedWebSocketWrite {
}
/// Write a frame to the websocket.
pub async fn write_frame(&self, frame: Frame) -> Result<(), crate::WispError> {
pub async fn write_frame(&self, frame: Frame) -> Result<(), WispError> {
self.0.lock().await.wisp_write_frame(frame).await
}
/// Close the websocket.
pub async fn close(&self) -> Result<(), WispError> {
self.0.lock().await.wisp_close().await
}
}
pub(crate) struct AppendingWebSocketRead<R>(pub Vec<Frame>, pub R)