remove the mutex<hashmap> in wisp_mux, other improvements

This commit is contained in:
Toshit Chawda 2024-03-26 18:55:54 -07:00
parent ff2a1ad269
commit 7001ee8fa5
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
16 changed files with 346 additions and 309 deletions

View file

@ -16,14 +16,13 @@ pub use crate::packet::*;
pub use crate::stream::*;
use bytes::Bytes;
use dashmap::DashMap;
use event_listener::Event;
use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt};
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
},
use futures::SinkExt;
use futures::{channel::mpsc, Future, FutureExt, StreamExt};
use std::sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
};
/// The role of the multiplexor.
@ -72,6 +71,8 @@ pub enum WispError {
Utf8Error(std::str::Utf8Error),
/// Other error.
Other(Box<dyn std::error::Error + Sync + Send>),
/// Failed to send message to multiplexor task.
MuxMessageFailedToSend,
}
impl From<std::str::Utf8Error> for WispError {
@ -82,25 +83,29 @@ impl From<std::str::Utf8Error> for WispError {
impl std::fmt::Display for WispError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
use WispError::*;
match self {
PacketTooSmall => write!(f, "Packet too small"),
InvalidPacketType => write!(f, "Invalid packet type"),
InvalidStreamType => write!(f, "Invalid stream type"),
InvalidStreamId => write!(f, "Invalid stream id"),
InvalidCloseReason => write!(f, "Invalid close reason"),
InvalidUri => write!(f, "Invalid URI"),
UriHasNoHost => write!(f, "URI has no host"),
UriHasNoPort => write!(f, "URI has no port"),
MaxStreamCountReached => write!(f, "Maximum stream count reached"),
StreamAlreadyClosed => write!(f, "Stream already closed"),
WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
WsImplError(err) => write!(f, "Websocket implementation error: {}", err),
WsImplSocketClosed => write!(f, "Websocket implementation error: websocket closed"),
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"),
Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
Other(err) => write!(f, "Other error: {}", err),
Self::PacketTooSmall => write!(f, "Packet too small"),
Self::InvalidPacketType => write!(f, "Invalid packet type"),
Self::InvalidStreamType => write!(f, "Invalid stream type"),
Self::InvalidStreamId => write!(f, "Invalid stream id"),
Self::InvalidCloseReason => write!(f, "Invalid close reason"),
Self::InvalidUri => write!(f, "Invalid URI"),
Self::UriHasNoHost => write!(f, "URI has no host"),
Self::UriHasNoPort => write!(f, "URI has no port"),
Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
Self::WsImplError(err) => write!(f, "Websocket implementation error: {}", err),
Self::WsImplSocketClosed => {
write!(f, "Websocket implementation error: websocket closed")
}
Self::WsImplNotSupported => {
write!(f, "Websocket implementation error: unsupported feature")
}
Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
Self::Other(err) => write!(f, "Other error: {}", err),
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
}
}
}
@ -115,61 +120,74 @@ struct MuxMapValue {
is_closed: Arc<AtomicBool>,
}
struct ServerMuxInner<W>
struct MuxInner<W>
where
W: ws::WebSocketWrite,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
close_tx: mpsc::UnboundedSender<WsEvent>,
stream_map: Arc<DashMap<u32, MuxMapValue>>,
}
impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
pub async fn into_future<R>(
impl<W: ws::WebSocketWrite> MuxInner<W> {
pub async fn server_into_future<R>(
self,
rx: R,
close_rx: mpsc::UnboundedReceiver<WsEvent>,
close_rx: mpsc::Receiver<WsEvent>,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
buffer_size: u32,
close_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
self.into_future(
close_rx,
self.server_loop(rx, muxstream_sender, buffer_size, close_tx),
)
.await
}
pub async fn client_into_future<R>(
self,
rx: R,
close_rx: mpsc::Receiver<WsEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
self.into_future(close_rx, self.client_loop(rx)).await
}
async fn into_future(
&self,
close_rx: mpsc::Receiver<WsEvent>,
wisp_fut: impl Future<Output = Result<(), WispError>>,
) -> Result<(), WispError> {
let ret = futures::select! {
x = self.server_bg_loop(close_rx).fuse() => x,
x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x
_ = self.stream_loop(close_rx).fuse() => Ok(()),
x = wisp_fut.fuse() => x,
};
self.stream_map.lock().await.drain().for_each(|mut x| {
x.1.is_closed.store(true, Ordering::Release);
x.1.stream.disconnect();
x.1.stream.close_channel();
self.stream_map.iter_mut().for_each(|mut x| {
x.is_closed.store(true, Ordering::Release);
x.stream.disconnect();
x.stream.close_channel();
});
self.stream_map.clear();
ret
}
async fn server_bg_loop(
&self,
mut close_rx: mpsc::UnboundedReceiver<WsEvent>,
) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await {
async fn stream_loop(&self, mut stream_rx: mpsc::Receiver<WsEvent>) {
while let Some(msg) = stream_rx.next().await {
match msg {
WsEvent::SendPacket(packet, channel) => {
if self
.stream_map
.lock()
.await
.get(&packet.stream_id)
.is_some()
{
if self.stream_map.get(&packet.stream_id).is_some() {
let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::Close(packet, channel) => {
if let Some(mut stream) =
self.stream_map.lock().await.remove(&packet.stream_id)
{
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.stream.disconnect();
stream.stream.close_channel();
let _ = channel.send(self.tx.write_frame(packet.into()).await);
@ -180,20 +198,20 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
WsEvent::EndFut => break,
}
}
Ok(())
}
async fn server_msg_loop<R>(
async fn server_loop<R>(
&self,
mut rx: R,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
buffer_size: u32,
close_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
// will send continues once flow_control is at 10% of max
let target_buffer_size = buffer_size * 90 / 100;
let target_buffer_size = ((buffer_size as u64 * 90) / 100) as u32;
self.tx
.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
@ -214,7 +232,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
let flow_control_event: Arc<Event> = Event::new().into();
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
self.stream_map.lock().await.insert(
self.stream_map.insert(
packet.stream_id,
MuxMapValue {
stream: ch_tx,
@ -232,7 +250,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
Role::Server,
stream_type,
ch_rx,
self.close_tx.clone(),
close_tx.clone(),
is_closed,
flow_control,
flow_control_event,
@ -242,7 +260,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
.map_err(|x| WispError::Other(Box::new(x)))?;
}
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(data);
if stream.stream_type == StreamType::Tcp {
stream.flow_control.store(
@ -257,9 +275,47 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
}
Continue(_) => break Err(WispError::InvalidPacketType),
Close(_) => {
if let Some(mut stream) =
self.stream_map.lock().await.remove(&packet.stream_id)
{
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect();
stream.stream.close_channel();
}
}
}
}
}
async fn client_loop<R>(&self, mut rx: R) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
loop {
let frame = rx.wisp_read_frame(&self.tx).await?;
if frame.opcode == ws::OpCode::Close {
break Ok(());
}
let packet = Packet::try_from(frame)?;
use PacketType::*;
match packet.packet_type {
Connect(_) => break Err(WispError::InvalidPacketType),
Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(data);
}
}
Continue(inner_packet) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
if stream.stream_type == StreamType::Tcp {
stream
.flow_control
.store(inner_packet.buffer_remaining, Ordering::Release);
let _ = stream.flow_control_event.notify(u32::MAX);
}
}
}
Close(_) => {
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect();
stream.stream.close_channel();
@ -290,8 +346,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
/// }
/// ```
pub struct ServerMux {
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
close_tx: mpsc::UnboundedSender<WsEvent>,
close_tx: mpsc::Sender<WsEvent>,
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>,
}
@ -305,22 +360,19 @@ impl ServerMux {
where
R: ws::WebSocketRead,
{
let (close_tx, close_rx) = mpsc::unbounded::<WsEvent>();
let (close_tx, close_rx) = mpsc::channel::<WsEvent>(256);
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
let write = ws::LockedWebSocketWrite::new(write);
let map = Arc::new(Mutex::new(HashMap::new()));
(
Self {
muxstream_recv: rx,
close_tx: close_tx.clone(),
stream_map: map.clone(),
},
ServerMuxInner {
MuxInner {
tx: write,
close_tx,
stream_map: map.clone(),
stream_map: DashMap::new().into(),
}
.into_future(read, close_rx, tx, buffer_size),
.server_into_future(read, close_rx, tx, buffer_size, close_tx),
)
}
@ -333,124 +385,13 @@ impl ServerMux {
///
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
/// this function is called.
pub async fn close(&self) {
self.stream_map.lock().await.drain().for_each(|mut x| {
x.1.is_closed.store(true, Ordering::Release);
x.1.stream.disconnect();
x.1.stream.close_channel();
});
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
pub async fn close(&mut self) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut)
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
}
struct ClientMuxInner<W>
where
W: ws::WebSocketWrite,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
}
impl<W: ws::WebSocketWrite> ClientMuxInner<W> {
pub(crate) async fn into_future<R>(
self,
rx: R,
close_rx: mpsc::UnboundedReceiver<WsEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
let ret = futures::select! {
x = self.client_bg_loop(close_rx).fuse() => x,
x = self.client_loop(rx).fuse() => x
};
self.stream_map.lock().await.drain().for_each(|mut x| {
x.1.is_closed.store(true, Ordering::Release);
x.1.stream.disconnect();
x.1.stream.close_channel();
});
ret
}
async fn client_bg_loop(
&self,
mut close_rx: mpsc::UnboundedReceiver<WsEvent>,
) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await {
match msg {
WsEvent::SendPacket(packet, channel) => {
if self
.stream_map
.lock()
.await
.get(&packet.stream_id)
.is_some()
{
let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::Close(packet, channel) => {
if let Some(mut stream) =
self.stream_map.lock().await.remove(&packet.stream_id)
{
stream.stream.disconnect();
stream.stream.close_channel();
let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::EndFut => break,
}
}
Ok(())
}
async fn client_loop<R>(&self, mut rx: R) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
loop {
let frame = rx.wisp_read_frame(&self.tx).await?;
if frame.opcode == ws::OpCode::Close {
break Ok(());
}
let packet = Packet::try_from(frame)?;
use PacketType::*;
match packet.packet_type {
Connect(_) => break Err(WispError::InvalidPacketType),
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(data);
}
}
Continue(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
if stream.stream_type == StreamType::Tcp {
stream
.flow_control
.store(inner_packet.buffer_remaining, Ordering::Release);
let _ = stream.flow_control_event.notify(u32::MAX);
}
}
}
Close(_) => {
if let Some(mut stream) =
self.stream_map.lock().await.remove(&packet.stream_id)
{
stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect();
stream.stream.close_channel();
}
}
}
}
}
}
/// Client side multiplexor.
///
/// # Example
@ -470,9 +411,9 @@ where
W: ws::WebSocketWrite,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
stream_map: Arc<DashMap<u32, MuxMapValue>>,
next_free_stream_id: AtomicU32,
close_tx: mpsc::UnboundedSender<WsEvent>,
close_tx: mpsc::Sender<WsEvent>,
buf_size: u32,
target_buf_size: u32,
}
@ -492,23 +433,23 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
return Err(WispError::InvalidStreamId);
}
if let PacketType::Continue(packet) = first_packet.packet_type {
let (tx, rx) = mpsc::unbounded::<WsEvent>();
let map = Arc::new(Mutex::new(HashMap::new()));
let (tx, rx) = mpsc::channel::<WsEvent>(256);
let map = Arc::new(DashMap::new());
Ok((
Self {
tx: write.clone(),
stream_map: map.clone(),
next_free_stream_id: AtomicU32::new(1),
close_tx: tx,
close_tx: tx.clone(),
buf_size: packet.buffer_remaining,
// server-only
target_buf_size: 0,
},
ClientMuxInner {
MuxInner {
tx: write.clone(),
stream_map: map.clone(),
}
.into_future(read, rx),
.client_into_future(read, rx),
))
} else {
Err(WispError::InvalidPacketType)
@ -540,7 +481,7 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
self.next_free_stream_id
.store(next_stream_id, Ordering::Release);
self.stream_map.lock().await.insert(
self.stream_map.insert(
stream_id,
MuxMapValue {
stream: ch_tx,
@ -568,12 +509,10 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function.
pub async fn close(&self) {
self.stream_map.lock().await.drain().for_each(|mut x| {
x.1.is_closed.store(true, Ordering::Release);
x.1.stream.disconnect();
x.1.stream.close_channel();
});
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
pub async fn close(&mut self) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut)
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
}