more improvements and fix wisp impl

This commit is contained in:
Toshit Chawda 2024-02-07 08:38:37 -08:00
parent 1a897ec03a
commit 85a30aeec5
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
12 changed files with 478 additions and 413 deletions

View file

@ -3,17 +3,18 @@
mod fastwebsockets;
mod packet;
mod stream;
pub mod ws;
#[cfg(feature = "ws_stream_wasm")]
mod ws_stream_wasm;
#[cfg(feature = "hyper_tower")]
pub mod tokioio;
#[cfg(feature = "hyper_tower")]
pub mod tower;
pub mod ws;
#[cfg(feature = "ws_stream_wasm")]
mod ws_stream_wasm;
pub use crate::packet::*;
pub use crate::stream::*;
use event_listener::Event;
use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt};
use std::{
collections::HashMap,
@ -23,7 +24,7 @@ use std::{
},
};
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Copy, Clone)]
pub enum Role {
Client,
Server,
@ -96,13 +97,14 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
rx: R,
close_rx: mpsc::UnboundedReceiver<MuxEvent>,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>,
buffer_size: u32
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
let ret = futures::select! {
x = self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()).fuse() => x,
x = self.server_msg_loop(rx, muxstream_sender).fuse() => x
x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x
};
self.stream_map.lock().await.iter().for_each(|x| {
let _ = x.1.unbounded_send(WsEvent::Close(ClosePacket::new(0x01)));
@ -137,12 +139,13 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
&self,
mut rx: R,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>,
buffer_size: u32,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
self.tx
.write_frame(Packet::new_continue(0, u32::MAX).into())
.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
while let Ok(frame) = rx.wisp_read_frame(&self.tx).await {
@ -157,10 +160,13 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
inner_packet,
MuxStream::new(
packet.stream_id,
Role::Server,
ch_rx,
self.tx.clone(),
self.close_tx.clone(),
AtomicBool::new(false).into(),
AtomicU32::new(buffer_size).into(),
Event::new().into(),
),
))
.map_err(|x| WispError::Other(Box::new(x)))?;
@ -168,11 +174,6 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Send(data));
self.tx
.write_frame(
Packet::new_continue(packet.stream_id, u32::MAX).into(),
)
.await?;
}
}
Continue(_) => unreachable!(),
@ -200,7 +201,7 @@ where
}
impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
pub fn new<R>(read: R, write: W) -> (Self, impl Future<Output = Result<(), WispError>>)
pub fn new<R>(read: R, write: W, buffer_size: u32) -> (Self, impl Future<Output = Result<(), WispError>>)
where
R: ws::WebSocketRead,
{
@ -215,7 +216,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
close_tx,
stream_map: map.clone(),
}
.into_future(read, close_rx, tx),
.into_future(read, close_rx, tx, buffer_size),
)
}
@ -229,7 +230,8 @@ where
W: ws::WebSocketWrite,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>,
stream_map:
Arc<Mutex<HashMap<u32, (mpsc::UnboundedSender<WsEvent>, Arc<AtomicU32>, Arc<Event>)>>>,
}
impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
@ -280,13 +282,20 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
Connect(_) => unreachable!(),
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Send(data));
let _ = stream.0.unbounded_send(WsEvent::Send(data));
}
}
Continue(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
stream
.1
.store(inner_packet.buffer_remaining, Ordering::Release);
let _ = stream.2.notify(u32::MAX);
}
}
Continue(_) => {}
Close(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Close(inner_packet));
let _ = stream.0.unbounded_send(WsEvent::Close(inner_packet));
}
self.stream_map.lock().await.remove(&packet.stream_id);
}
@ -302,32 +311,46 @@ where
W: ws::WebSocketWrite,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>,
stream_map:
Arc<Mutex<HashMap<u32, (mpsc::UnboundedSender<WsEvent>, Arc<AtomicU32>, Arc<Event>)>>>,
next_free_stream_id: AtomicU32,
close_tx: mpsc::UnboundedSender<MuxEvent>,
buf_size: u32,
}
impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
pub fn new<R>(read: R, write: W) -> (Self, impl Future<Output = Result<(), WispError>>)
pub async fn new<R>(
mut read: R,
write: W,
) -> Result<(Self, impl Future<Output = Result<(), WispError>>), WispError>
where
R: ws::WebSocketRead,
{
let (tx, rx) = mpsc::unbounded::<MuxEvent>();
let map = Arc::new(Mutex::new(HashMap::new()));
let write = ws::LockedWebSocketWrite::new(write);
(
Self {
tx: write.clone(),
stream_map: map.clone(),
next_free_stream_id: AtomicU32::new(1),
close_tx: tx,
},
ClientMuxInner {
tx: write.clone(),
stream_map: map.clone(),
}
.into_future(read, rx),
)
let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?;
if first_packet.stream_id != 0 {
return Err(WispError::InvalidStreamId);
}
if let PacketType::Continue(packet) = first_packet.packet {
let (tx, rx) = mpsc::unbounded::<MuxEvent>();
let map = Arc::new(Mutex::new(HashMap::new()));
Ok((
Self {
tx: write.clone(),
stream_map: map.clone(),
next_free_stream_id: AtomicU32::new(1),
close_tx: tx,
buf_size: packet.buffer_remaining,
},
ClientMuxInner {
tx: write.clone(),
stream_map: map.clone(),
}
.into_future(read, rx),
))
} else {
Err(WispError::InvalidPacketType)
}
}
pub async fn client_new_stream(
@ -337,6 +360,8 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
port: u16,
) -> Result<MuxStream<W>, WispError> {
let (ch_tx, ch_rx) = mpsc::unbounded();
let evt: Arc<Event> = Event::new().into();
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buf_size).into();
let stream_id = self.next_free_stream_id.load(Ordering::Acquire);
self.tx
.write_frame(Packet::new_connect(stream_id, stream_type, port, host).into())
@ -347,13 +372,19 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
.ok_or(WispError::MaxStreamCountReached)?,
Ordering::Release,
);
self.stream_map.lock().await.insert(stream_id, ch_tx);
self.stream_map
.lock()
.await
.insert(stream_id, (ch_tx, flow_control.clone(), evt.clone()));
Ok(MuxStream::new(
stream_id,
Role::Client,
ch_rx,
self.tx.clone(),
self.close_tx.clone(),
AtomicBool::new(false).into(),
flow_control,
evt,
))
}
}