mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
remove the mutex<hashmap> in wisp_mux, other improvements
This commit is contained in:
parent
ff2a1ad269
commit
7001ee8fa5
16 changed files with 346 additions and 309 deletions
341
wisp/src/lib.rs
341
wisp/src/lib.rs
|
@ -16,14 +16,13 @@ pub use crate::packet::*;
|
|||
pub use crate::stream::*;
|
||||
|
||||
use bytes::Bytes;
|
||||
use dashmap::DashMap;
|
||||
use event_listener::Event;
|
||||
use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU32, Ordering},
|
||||
Arc,
|
||||
},
|
||||
use futures::SinkExt;
|
||||
use futures::{channel::mpsc, Future, FutureExt, StreamExt};
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, AtomicU32, Ordering},
|
||||
Arc,
|
||||
};
|
||||
|
||||
/// The role of the multiplexor.
|
||||
|
@ -72,6 +71,8 @@ pub enum WispError {
|
|||
Utf8Error(std::str::Utf8Error),
|
||||
/// Other error.
|
||||
Other(Box<dyn std::error::Error + Sync + Send>),
|
||||
/// Failed to send message to multiplexor task.
|
||||
MuxMessageFailedToSend,
|
||||
}
|
||||
|
||||
impl From<std::str::Utf8Error> for WispError {
|
||||
|
@ -82,25 +83,29 @@ impl From<std::str::Utf8Error> for WispError {
|
|||
|
||||
impl std::fmt::Display for WispError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
|
||||
use WispError::*;
|
||||
match self {
|
||||
PacketTooSmall => write!(f, "Packet too small"),
|
||||
InvalidPacketType => write!(f, "Invalid packet type"),
|
||||
InvalidStreamType => write!(f, "Invalid stream type"),
|
||||
InvalidStreamId => write!(f, "Invalid stream id"),
|
||||
InvalidCloseReason => write!(f, "Invalid close reason"),
|
||||
InvalidUri => write!(f, "Invalid URI"),
|
||||
UriHasNoHost => write!(f, "URI has no host"),
|
||||
UriHasNoPort => write!(f, "URI has no port"),
|
||||
MaxStreamCountReached => write!(f, "Maximum stream count reached"),
|
||||
StreamAlreadyClosed => write!(f, "Stream already closed"),
|
||||
WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
|
||||
WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
|
||||
WsImplError(err) => write!(f, "Websocket implementation error: {}", err),
|
||||
WsImplSocketClosed => write!(f, "Websocket implementation error: websocket closed"),
|
||||
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"),
|
||||
Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
|
||||
Other(err) => write!(f, "Other error: {}", err),
|
||||
Self::PacketTooSmall => write!(f, "Packet too small"),
|
||||
Self::InvalidPacketType => write!(f, "Invalid packet type"),
|
||||
Self::InvalidStreamType => write!(f, "Invalid stream type"),
|
||||
Self::InvalidStreamId => write!(f, "Invalid stream id"),
|
||||
Self::InvalidCloseReason => write!(f, "Invalid close reason"),
|
||||
Self::InvalidUri => write!(f, "Invalid URI"),
|
||||
Self::UriHasNoHost => write!(f, "URI has no host"),
|
||||
Self::UriHasNoPort => write!(f, "URI has no port"),
|
||||
Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
|
||||
Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
|
||||
Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
|
||||
Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
|
||||
Self::WsImplError(err) => write!(f, "Websocket implementation error: {}", err),
|
||||
Self::WsImplSocketClosed => {
|
||||
write!(f, "Websocket implementation error: websocket closed")
|
||||
}
|
||||
Self::WsImplNotSupported => {
|
||||
write!(f, "Websocket implementation error: unsupported feature")
|
||||
}
|
||||
Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
|
||||
Self::Other(err) => write!(f, "Other error: {}", err),
|
||||
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -115,61 +120,74 @@ struct MuxMapValue {
|
|||
is_closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
struct ServerMuxInner<W>
|
||||
struct MuxInner<W>
|
||||
where
|
||||
W: ws::WebSocketWrite,
|
||||
{
|
||||
tx: ws::LockedWebSocketWrite<W>,
|
||||
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
|
||||
close_tx: mpsc::UnboundedSender<WsEvent>,
|
||||
stream_map: Arc<DashMap<u32, MuxMapValue>>,
|
||||
}
|
||||
|
||||
impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
|
||||
pub async fn into_future<R>(
|
||||
impl<W: ws::WebSocketWrite> MuxInner<W> {
|
||||
pub async fn server_into_future<R>(
|
||||
self,
|
||||
rx: R,
|
||||
close_rx: mpsc::UnboundedReceiver<WsEvent>,
|
||||
close_rx: mpsc::Receiver<WsEvent>,
|
||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
|
||||
buffer_size: u32,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
self.into_future(
|
||||
close_rx,
|
||||
self.server_loop(rx, muxstream_sender, buffer_size, close_tx),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn client_into_future<R>(
|
||||
self,
|
||||
rx: R,
|
||||
close_rx: mpsc::Receiver<WsEvent>,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
self.into_future(close_rx, self.client_loop(rx)).await
|
||||
}
|
||||
|
||||
async fn into_future(
|
||||
&self,
|
||||
close_rx: mpsc::Receiver<WsEvent>,
|
||||
wisp_fut: impl Future<Output = Result<(), WispError>>,
|
||||
) -> Result<(), WispError> {
|
||||
let ret = futures::select! {
|
||||
x = self.server_bg_loop(close_rx).fuse() => x,
|
||||
x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x
|
||||
_ = self.stream_loop(close_rx).fuse() => Ok(()),
|
||||
x = wisp_fut.fuse() => x,
|
||||
};
|
||||
self.stream_map.lock().await.drain().for_each(|mut x| {
|
||||
x.1.is_closed.store(true, Ordering::Release);
|
||||
x.1.stream.disconnect();
|
||||
x.1.stream.close_channel();
|
||||
self.stream_map.iter_mut().for_each(|mut x| {
|
||||
x.is_closed.store(true, Ordering::Release);
|
||||
x.stream.disconnect();
|
||||
x.stream.close_channel();
|
||||
});
|
||||
self.stream_map.clear();
|
||||
ret
|
||||
}
|
||||
|
||||
async fn server_bg_loop(
|
||||
&self,
|
||||
mut close_rx: mpsc::UnboundedReceiver<WsEvent>,
|
||||
) -> Result<(), WispError> {
|
||||
while let Some(msg) = close_rx.next().await {
|
||||
async fn stream_loop(&self, mut stream_rx: mpsc::Receiver<WsEvent>) {
|
||||
while let Some(msg) = stream_rx.next().await {
|
||||
match msg {
|
||||
WsEvent::SendPacket(packet, channel) => {
|
||||
if self
|
||||
.stream_map
|
||||
.lock()
|
||||
.await
|
||||
.get(&packet.stream_id)
|
||||
.is_some()
|
||||
{
|
||||
if self.stream_map.get(&packet.stream_id).is_some() {
|
||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||
} else {
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
}
|
||||
}
|
||||
WsEvent::Close(packet, channel) => {
|
||||
if let Some(mut stream) =
|
||||
self.stream_map.lock().await.remove(&packet.stream_id)
|
||||
{
|
||||
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||
stream.stream.disconnect();
|
||||
stream.stream.close_channel();
|
||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||
|
@ -180,20 +198,20 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
|
|||
WsEvent::EndFut => break,
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn server_msg_loop<R>(
|
||||
async fn server_loop<R>(
|
||||
&self,
|
||||
mut rx: R,
|
||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
|
||||
buffer_size: u32,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
// will send continues once flow_control is at 10% of max
|
||||
let target_buffer_size = buffer_size * 90 / 100;
|
||||
let target_buffer_size = ((buffer_size as u64 * 90) / 100) as u32;
|
||||
self.tx
|
||||
.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||
.await?;
|
||||
|
@ -214,7 +232,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
|
|||
let flow_control_event: Arc<Event> = Event::new().into();
|
||||
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
||||
|
||||
self.stream_map.lock().await.insert(
|
||||
self.stream_map.insert(
|
||||
packet.stream_id,
|
||||
MuxMapValue {
|
||||
stream: ch_tx,
|
||||
|
@ -232,7 +250,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
|
|||
Role::Server,
|
||||
stream_type,
|
||||
ch_rx,
|
||||
self.close_tx.clone(),
|
||||
close_tx.clone(),
|
||||
is_closed,
|
||||
flow_control,
|
||||
flow_control_event,
|
||||
|
@ -242,7 +260,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
|
|||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
}
|
||||
Data(data) => {
|
||||
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||
let _ = stream.stream.unbounded_send(data);
|
||||
if stream.stream_type == StreamType::Tcp {
|
||||
stream.flow_control.store(
|
||||
|
@ -257,9 +275,47 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
|
|||
}
|
||||
Continue(_) => break Err(WispError::InvalidPacketType),
|
||||
Close(_) => {
|
||||
if let Some(mut stream) =
|
||||
self.stream_map.lock().await.remove(&packet.stream_id)
|
||||
{
|
||||
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||
stream.is_closed.store(true, Ordering::Release);
|
||||
stream.stream.disconnect();
|
||||
stream.stream.close_channel();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn client_loop<R>(&self, mut rx: R) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
loop {
|
||||
let frame = rx.wisp_read_frame(&self.tx).await?;
|
||||
if frame.opcode == ws::OpCode::Close {
|
||||
break Ok(());
|
||||
}
|
||||
let packet = Packet::try_from(frame)?;
|
||||
|
||||
use PacketType::*;
|
||||
match packet.packet_type {
|
||||
Connect(_) => break Err(WispError::InvalidPacketType),
|
||||
Data(data) => {
|
||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||
let _ = stream.stream.unbounded_send(data);
|
||||
}
|
||||
}
|
||||
Continue(inner_packet) => {
|
||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||
if stream.stream_type == StreamType::Tcp {
|
||||
stream
|
||||
.flow_control
|
||||
.store(inner_packet.buffer_remaining, Ordering::Release);
|
||||
let _ = stream.flow_control_event.notify(u32::MAX);
|
||||
}
|
||||
}
|
||||
}
|
||||
Close(_) => {
|
||||
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||
stream.is_closed.store(true, Ordering::Release);
|
||||
stream.stream.disconnect();
|
||||
stream.stream.close_channel();
|
||||
|
@ -290,8 +346,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
|
|||
/// }
|
||||
/// ```
|
||||
pub struct ServerMux {
|
||||
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
|
||||
close_tx: mpsc::UnboundedSender<WsEvent>,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>,
|
||||
}
|
||||
|
||||
|
@ -305,22 +360,19 @@ impl ServerMux {
|
|||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
let (close_tx, close_rx) = mpsc::unbounded::<WsEvent>();
|
||||
let (close_tx, close_rx) = mpsc::channel::<WsEvent>(256);
|
||||
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
|
||||
let write = ws::LockedWebSocketWrite::new(write);
|
||||
let map = Arc::new(Mutex::new(HashMap::new()));
|
||||
(
|
||||
Self {
|
||||
muxstream_recv: rx,
|
||||
close_tx: close_tx.clone(),
|
||||
stream_map: map.clone(),
|
||||
},
|
||||
ServerMuxInner {
|
||||
MuxInner {
|
||||
tx: write,
|
||||
close_tx,
|
||||
stream_map: map.clone(),
|
||||
stream_map: DashMap::new().into(),
|
||||
}
|
||||
.into_future(read, close_rx, tx, buffer_size),
|
||||
.server_into_future(read, close_rx, tx, buffer_size, close_tx),
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -333,124 +385,13 @@ impl ServerMux {
|
|||
///
|
||||
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
|
||||
/// this function is called.
|
||||
pub async fn close(&self) {
|
||||
self.stream_map.lock().await.drain().for_each(|mut x| {
|
||||
x.1.is_closed.store(true, Ordering::Release);
|
||||
x.1.stream.disconnect();
|
||||
x.1.stream.close_channel();
|
||||
});
|
||||
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
|
||||
pub async fn close(&mut self) -> Result<(), WispError> {
|
||||
self.close_tx
|
||||
.send(WsEvent::EndFut)
|
||||
.await
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||
}
|
||||
}
|
||||
|
||||
struct ClientMuxInner<W>
|
||||
where
|
||||
W: ws::WebSocketWrite,
|
||||
{
|
||||
tx: ws::LockedWebSocketWrite<W>,
|
||||
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
|
||||
}
|
||||
|
||||
impl<W: ws::WebSocketWrite> ClientMuxInner<W> {
|
||||
pub(crate) async fn into_future<R>(
|
||||
self,
|
||||
rx: R,
|
||||
close_rx: mpsc::UnboundedReceiver<WsEvent>,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
let ret = futures::select! {
|
||||
x = self.client_bg_loop(close_rx).fuse() => x,
|
||||
x = self.client_loop(rx).fuse() => x
|
||||
};
|
||||
self.stream_map.lock().await.drain().for_each(|mut x| {
|
||||
x.1.is_closed.store(true, Ordering::Release);
|
||||
x.1.stream.disconnect();
|
||||
x.1.stream.close_channel();
|
||||
});
|
||||
ret
|
||||
}
|
||||
|
||||
async fn client_bg_loop(
|
||||
&self,
|
||||
mut close_rx: mpsc::UnboundedReceiver<WsEvent>,
|
||||
) -> Result<(), WispError> {
|
||||
while let Some(msg) = close_rx.next().await {
|
||||
match msg {
|
||||
WsEvent::SendPacket(packet, channel) => {
|
||||
if self
|
||||
.stream_map
|
||||
.lock()
|
||||
.await
|
||||
.get(&packet.stream_id)
|
||||
.is_some()
|
||||
{
|
||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||
} else {
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
}
|
||||
}
|
||||
WsEvent::Close(packet, channel) => {
|
||||
if let Some(mut stream) =
|
||||
self.stream_map.lock().await.remove(&packet.stream_id)
|
||||
{
|
||||
stream.stream.disconnect();
|
||||
stream.stream.close_channel();
|
||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||
} else {
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
}
|
||||
}
|
||||
WsEvent::EndFut => break,
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn client_loop<R>(&self, mut rx: R) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
loop {
|
||||
let frame = rx.wisp_read_frame(&self.tx).await?;
|
||||
if frame.opcode == ws::OpCode::Close {
|
||||
break Ok(());
|
||||
}
|
||||
let packet = Packet::try_from(frame)?;
|
||||
|
||||
use PacketType::*;
|
||||
match packet.packet_type {
|
||||
Connect(_) => break Err(WispError::InvalidPacketType),
|
||||
Data(data) => {
|
||||
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||
let _ = stream.stream.unbounded_send(data);
|
||||
}
|
||||
}
|
||||
Continue(inner_packet) => {
|
||||
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||
if stream.stream_type == StreamType::Tcp {
|
||||
stream
|
||||
.flow_control
|
||||
.store(inner_packet.buffer_remaining, Ordering::Release);
|
||||
let _ = stream.flow_control_event.notify(u32::MAX);
|
||||
}
|
||||
}
|
||||
}
|
||||
Close(_) => {
|
||||
if let Some(mut stream) =
|
||||
self.stream_map.lock().await.remove(&packet.stream_id)
|
||||
{
|
||||
stream.is_closed.store(true, Ordering::Release);
|
||||
stream.stream.disconnect();
|
||||
stream.stream.close_channel();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Client side multiplexor.
|
||||
///
|
||||
/// # Example
|
||||
|
@ -470,9 +411,9 @@ where
|
|||
W: ws::WebSocketWrite,
|
||||
{
|
||||
tx: ws::LockedWebSocketWrite<W>,
|
||||
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
|
||||
stream_map: Arc<DashMap<u32, MuxMapValue>>,
|
||||
next_free_stream_id: AtomicU32,
|
||||
close_tx: mpsc::UnboundedSender<WsEvent>,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
buf_size: u32,
|
||||
target_buf_size: u32,
|
||||
}
|
||||
|
@ -492,23 +433,23 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
|
|||
return Err(WispError::InvalidStreamId);
|
||||
}
|
||||
if let PacketType::Continue(packet) = first_packet.packet_type {
|
||||
let (tx, rx) = mpsc::unbounded::<WsEvent>();
|
||||
let map = Arc::new(Mutex::new(HashMap::new()));
|
||||
let (tx, rx) = mpsc::channel::<WsEvent>(256);
|
||||
let map = Arc::new(DashMap::new());
|
||||
Ok((
|
||||
Self {
|
||||
tx: write.clone(),
|
||||
stream_map: map.clone(),
|
||||
next_free_stream_id: AtomicU32::new(1),
|
||||
close_tx: tx,
|
||||
close_tx: tx.clone(),
|
||||
buf_size: packet.buffer_remaining,
|
||||
// server-only
|
||||
target_buf_size: 0,
|
||||
},
|
||||
ClientMuxInner {
|
||||
MuxInner {
|
||||
tx: write.clone(),
|
||||
stream_map: map.clone(),
|
||||
}
|
||||
.into_future(read, rx),
|
||||
.client_into_future(read, rx),
|
||||
))
|
||||
} else {
|
||||
Err(WispError::InvalidPacketType)
|
||||
|
@ -540,7 +481,7 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
|
|||
self.next_free_stream_id
|
||||
.store(next_stream_id, Ordering::Release);
|
||||
|
||||
self.stream_map.lock().await.insert(
|
||||
self.stream_map.insert(
|
||||
stream_id,
|
||||
MuxMapValue {
|
||||
stream: ch_tx,
|
||||
|
@ -568,12 +509,10 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
|
|||
///
|
||||
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
|
||||
/// function.
|
||||
pub async fn close(&self) {
|
||||
self.stream_map.lock().await.drain().for_each(|mut x| {
|
||||
x.1.is_closed.store(true, Ordering::Release);
|
||||
x.1.stream.disconnect();
|
||||
x.1.stream.close_channel();
|
||||
});
|
||||
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
|
||||
pub async fn close(&mut self) -> Result<(), WispError> {
|
||||
self.close_tx
|
||||
.send(WsEvent::EndFut)
|
||||
.await
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue