force a bounded channel

This commit is contained in:
Toshit Chawda 2024-04-14 17:59:24 -07:00
parent f2021e2382
commit 5af56fe582
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
4 changed files with 35 additions and 28 deletions

View file

@ -20,8 +20,7 @@ use dashmap::DashMap;
use event_listener::Event;
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
use futures::{
channel::{mpsc, oneshot},
select, Future, FutureExt, SinkExt, StreamExt,
channel::{mpsc, oneshot}, lock::Mutex, select, Future, FutureExt, SinkExt, StreamExt
};
use futures_timer::Delay;
use std::{
@ -152,7 +151,7 @@ impl std::fmt::Display for WispError {
impl std::error::Error for WispError {}
struct MuxMapValue {
stream: mpsc::UnboundedSender<Bytes>,
stream: Mutex<mpsc::Sender<Bytes>>,
stream_type: StreamType,
flow_control: Arc<AtomicU32>,
flow_control_event: Arc<Event>,
@ -209,11 +208,11 @@ impl MuxInner {
_ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()),
x = wisp_fut.fuse() => x,
};
self.stream_map.iter_mut().for_each(|mut x| {
for x in self.stream_map.iter_mut() {
x.is_closed.store(true, Ordering::Release);
x.stream.disconnect();
x.stream.close_channel();
});
x.stream.lock().await.disconnect();
x.stream.lock().await.close_channel();
}
self.stream_map.clear();
ret
}
@ -235,7 +234,7 @@ impl MuxInner {
}
WsEvent::CreateStream(stream_type, host, port, channel) => {
let ret: Result<MuxStream, WispError> = async {
let (ch_tx, ch_rx) = mpsc::unbounded();
let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize);
let stream_id = next_free_stream_id;
let next_stream_id = next_free_stream_id
.checked_add(1)
@ -257,7 +256,7 @@ impl MuxInner {
self.stream_map.insert(
stream_id,
MuxMapValue {
stream: ch_tx,
stream: ch_tx.into(),
stream_type,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
@ -281,9 +280,9 @@ impl MuxInner {
let _ = channel.send(ret);
}
WsEvent::Close(packet, channel) => {
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.stream.disconnect();
stream.stream.close_channel();
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
stream.stream.lock().await.disconnect();
stream.stream.lock().await.close_channel();
let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
@ -326,7 +325,7 @@ impl MuxInner {
use PacketType::*;
match packet.packet_type {
Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::unbounded();
let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize);
let stream_type = inner_packet.stream_type;
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
let flow_control_event: Arc<Event> = Event::new().into();
@ -335,7 +334,7 @@ impl MuxInner {
self.stream_map.insert(
packet.stream_id,
MuxMapValue {
stream: ch_tx,
stream: ch_tx.into(),
stream_type,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
@ -361,7 +360,7 @@ impl MuxInner {
}
Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(data);
let _ = stream.stream.lock().await.send(data).await;
if stream.stream_type == StreamType::Tcp {
stream.flow_control.store(
stream
@ -378,10 +377,10 @@ impl MuxInner {
if packet.stream_id == 0 {
break Ok(());
}
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect();
stream.stream.close_channel();
stream.stream.lock().await.disconnect();
stream.stream.lock().await.close_channel();
}
}
}
@ -410,7 +409,7 @@ impl MuxInner {
Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(data);
let _ = stream.stream.lock().await.send(data).await;
}
}
Continue(inner_packet) => {
@ -427,10 +426,10 @@ impl MuxInner {
if packet.stream_id == 0 {
break Ok(());
}
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect();
stream.stream.close_channel();
stream.stream.lock().await.disconnect();
stream.stream.lock().await.close_channel();
}
}
}