add autoreconnect, wisp_mux 1.2.0

This commit is contained in:
Toshit Chawda 2024-03-08 22:40:15 -08:00
parent 5b4fb1392a
commit a8709255b2
20 changed files with 404 additions and 333 deletions

View file

@ -11,12 +11,6 @@ mod fastwebsockets;
mod packet;
mod sink_unfold;
mod stream;
#[cfg(feature = "hyper_tower")]
#[cfg_attr(docsrs, doc(cfg(feature = "hyper_tower")))]
pub mod tokioio;
#[cfg(feature = "hyper_tower")]
#[cfg_attr(docsrs, doc(cfg(feature = "hyper_tower")))]
pub mod tower;
pub mod ws;
pub use crate::packet::*;
@ -140,10 +134,10 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
R: ws::WebSocketRead,
{
let ret = futures::select! {
x = self.server_close_loop(close_rx).fuse() => x,
x = self.server_bg_loop(close_rx).fuse() => x,
x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x
};
self.stream_map.lock().await.iter().for_each(|x| {
self.stream_map.lock().await.drain().for_each(|x| {
let _ =
x.1.stream
.unbounded_send(MuxEvent::Close(ClosePacket::new(CloseReason::Unknown)));
@ -151,7 +145,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
ret
}
async fn server_close_loop(
async fn server_bg_loop(
&self,
mut close_rx: mpsc::UnboundedReceiver<WsEvent>,
) -> Result<(), WispError> {
@ -168,6 +162,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::EndFut => break,
}
}
Ok(())
@ -186,66 +181,62 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
while let Ok(frame) = rx.wisp_read_frame(&self.tx).await {
if let Ok(packet) = Packet::try_from(frame) {
use PacketType::*;
match packet.packet {
Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::unbounded();
let stream_type = inner_packet.stream_type;
let flow_control: Arc<AtomicU32> = AtomicU32::new(buffer_size).into();
let flow_control_event: Arc<Event> = Event::new().into();
loop {
let packet: Packet = rx.wisp_read_frame(&self.tx).await?.try_into()?;
use PacketType::*;
match packet.packet_type {
Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::unbounded();
let stream_type = inner_packet.stream_type;
let flow_control: Arc<AtomicU32> = AtomicU32::new(buffer_size).into();
let flow_control_event: Arc<Event> = Event::new().into();
self.stream_map.lock().await.insert(
packet.stream_id,
MuxMapValue {
stream: ch_tx,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
},
self.stream_map.lock().await.insert(
packet.stream_id,
MuxMapValue {
stream: ch_tx,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
},
);
muxstream_sender
.unbounded_send((
inner_packet,
MuxStream::new(
packet.stream_id,
Role::Server,
stream_type,
ch_rx,
self.tx.clone(),
self.close_tx.clone(),
AtomicBool::new(false).into(),
flow_control,
flow_control_event,
),
))
.map_err(|x| WispError::Other(Box::new(x)))?;
}
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(MuxEvent::Send(data));
stream.flow_control.store(
stream
.flow_control
.load(Ordering::Acquire)
.saturating_sub(1),
Ordering::Release,
);
muxstream_sender
.unbounded_send((
inner_packet,
MuxStream::new(
packet.stream_id,
Role::Server,
stream_type,
ch_rx,
self.tx.clone(),
self.close_tx.clone(),
AtomicBool::new(false).into(),
flow_control,
flow_control_event,
),
))
.map_err(|x| WispError::Other(Box::new(x)))?;
}
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(MuxEvent::Send(data));
stream.flow_control.store(
stream.flow_control
.load(Ordering::Acquire)
.saturating_sub(1),
Ordering::Release,
);
}
}
Continue(_) => unreachable!(),
Close(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(MuxEvent::Close(inner_packet));
}
self.stream_map.lock().await.remove(&packet.stream_id);
}
}
} else {
break;
Continue(_) => unreachable!(),
Close(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(MuxEvent::Close(inner_packet));
}
self.stream_map.lock().await.remove(&packet.stream_id);
}
}
}
drop(muxstream_sender);
Ok(())
}
}
@ -272,6 +263,8 @@ pub struct ServerMux<W>
where
W: ws::WebSocketWrite + Send + 'static,
{
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
close_tx: mpsc::UnboundedSender<WsEvent>,
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream<W>)>,
}
@ -290,7 +283,11 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
let write = ws::LockedWebSocketWrite::new(write);
let map = Arc::new(Mutex::new(HashMap::new()));
(
Self { muxstream_recv: rx },
Self {
muxstream_recv: rx,
close_tx: close_tx.clone(),
stream_map: map.clone(),
},
ServerMuxInner {
tx: write,
close_tx,
@ -304,6 +301,19 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream<W>)> {
self.muxstream_recv.next().await
}
/// Close all streams.
///
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
/// this function is called.
pub async fn close(&self, reason: CloseReason) {
self.stream_map.lock().await.drain().for_each(|x| {
let _ =
x.1.stream
.unbounded_send(MuxEvent::Close(ClosePacket::new(reason)));
});
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
}
}
struct ClientMuxInner<W>
@ -346,6 +356,7 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::EndFut => break,
}
}
Ok(())
@ -355,10 +366,11 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
where
R: ws::WebSocketRead,
{
while let Ok(frame) = rx.wisp_read_frame(&self.tx).await {
loop {
let frame = rx.wisp_read_frame(&self.tx).await?;
if let Ok(packet) = Packet::try_from(frame) {
use PacketType::*;
match packet.packet {
match packet.packet_type {
Connect(_) => unreachable!(),
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
@ -382,7 +394,6 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
}
}
}
Ok(())
}
}
@ -425,7 +436,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
if first_packet.stream_id != 0 {
return Err(WispError::InvalidStreamId);
}
if let PacketType::Continue(packet) = first_packet.packet {
if let PacketType::Continue(packet) = first_packet.packet_type {
let (tx, rx) = mpsc::unbounded::<WsEvent>();
let map = Arc::new(Mutex::new(HashMap::new()));
Ok((
@ -487,4 +498,17 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
evt,
))
}
/// Close all streams.
///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function.
pub async fn close(&self, reason: CloseReason) {
self.stream_map.lock().await.drain().for_each(|x| {
let _ =
x.1.stream
.unbounded_send(MuxEvent::Close(ClosePacket::new(reason)));
});
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
}
}