force a bounded channel

This commit is contained in:
Toshit Chawda 2024-04-14 17:59:24 -07:00
parent f2021e2382
commit 5af56fe582
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
4 changed files with 35 additions and 28 deletions

View file

@ -310,12 +310,12 @@ async fn accept_ws(
println!("{:?}: connected", addr); println!("{:?}: connected", addr);
// to prevent memory ""leaks"" because users are sending in packets way too fast the buffer // to prevent memory ""leaks"" because users are sending in packets way too fast the buffer
// size is set to 32 // size is set to 128
let (mut mux, fut) = if mux_options.enforce_auth { let (mut mux, fut) = if mux_options.enforce_auth {
let (mut mux, fut) = ServerMux::new( let (mut mux, fut) = ServerMux::new(
rx, rx,
tx, tx,
32, 128,
Some(mux_options.auth.as_slice()), Some(mux_options.auth.as_slice()),
) )
.await?; .await?;
@ -333,7 +333,7 @@ async fn accept_ws(
} }
(mux, fut) (mux, fut)
} else { } else {
ServerMux::new(rx, tx, 32, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await? ServerMux::new(rx, tx, 128, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await?
}; };
println!( println!(

View file

@ -92,6 +92,9 @@ struct Cli {
/// Usernames and passwords are sent in plaintext!! /// Usernames and passwords are sent in plaintext!!
#[arg(long)] #[arg(long)]
auth: Option<String>, auth: Option<String>,
/// Make a Wisp V1 connection
#[arg(long)]
wisp_v1: bool,
} }
#[tokio::main(flavor = "multi_thread")] #[tokio::main(flavor = "multi_thread")]
@ -161,7 +164,12 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
extensions.push(Box::new(auth)); extensions.push(Box::new(auth));
} }
let (mut mux, fut) = ClientMux::new(rx, tx, Some(extensions.as_slice())).await?; let (mut mux, fut) = if opts.wisp_v1 {
ClientMux::new(rx, tx, None).await?
} else {
ClientMux::new(rx, tx, Some(extensions.as_slice())).await?
};
if opts.udp if opts.udp
&& !mux && !mux
.supported_extension_ids .supported_extension_ids

View file

@ -20,8 +20,7 @@ use dashmap::DashMap;
use event_listener::Event; use event_listener::Event;
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder}; use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
use futures::{ use futures::{
channel::{mpsc, oneshot}, channel::{mpsc, oneshot}, lock::Mutex, select, Future, FutureExt, SinkExt, StreamExt
select, Future, FutureExt, SinkExt, StreamExt,
}; };
use futures_timer::Delay; use futures_timer::Delay;
use std::{ use std::{
@ -152,7 +151,7 @@ impl std::fmt::Display for WispError {
impl std::error::Error for WispError {} impl std::error::Error for WispError {}
struct MuxMapValue { struct MuxMapValue {
stream: mpsc::UnboundedSender<Bytes>, stream: Mutex<mpsc::Sender<Bytes>>,
stream_type: StreamType, stream_type: StreamType,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
flow_control_event: Arc<Event>, flow_control_event: Arc<Event>,
@ -209,11 +208,11 @@ impl MuxInner {
_ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()), _ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()),
x = wisp_fut.fuse() => x, x = wisp_fut.fuse() => x,
}; };
self.stream_map.iter_mut().for_each(|mut x| { for x in self.stream_map.iter_mut() {
x.is_closed.store(true, Ordering::Release); x.is_closed.store(true, Ordering::Release);
x.stream.disconnect(); x.stream.lock().await.disconnect();
x.stream.close_channel(); x.stream.lock().await.close_channel();
}); }
self.stream_map.clear(); self.stream_map.clear();
ret ret
} }
@ -235,7 +234,7 @@ impl MuxInner {
} }
WsEvent::CreateStream(stream_type, host, port, channel) => { WsEvent::CreateStream(stream_type, host, port, channel) => {
let ret: Result<MuxStream, WispError> = async { let ret: Result<MuxStream, WispError> = async {
let (ch_tx, ch_rx) = mpsc::unbounded(); let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize);
let stream_id = next_free_stream_id; let stream_id = next_free_stream_id;
let next_stream_id = next_free_stream_id let next_stream_id = next_free_stream_id
.checked_add(1) .checked_add(1)
@ -257,7 +256,7 @@ impl MuxInner {
self.stream_map.insert( self.stream_map.insert(
stream_id, stream_id,
MuxMapValue { MuxMapValue {
stream: ch_tx, stream: ch_tx.into(),
stream_type, stream_type,
flow_control: flow_control.clone(), flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(), flow_control_event: flow_control_event.clone(),
@ -281,9 +280,9 @@ impl MuxInner {
let _ = channel.send(ret); let _ = channel.send(ret);
} }
WsEvent::Close(packet, channel) => { WsEvent::Close(packet, channel) => {
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
stream.stream.disconnect(); stream.stream.lock().await.disconnect();
stream.stream.close_channel(); stream.stream.lock().await.close_channel();
let _ = channel.send(self.tx.write_frame(packet.into()).await); let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else { } else {
let _ = channel.send(Err(WispError::InvalidStreamId)); let _ = channel.send(Err(WispError::InvalidStreamId));
@ -326,7 +325,7 @@ impl MuxInner {
use PacketType::*; use PacketType::*;
match packet.packet_type { match packet.packet_type {
Connect(inner_packet) => { Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::unbounded(); let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize);
let stream_type = inner_packet.stream_type; let stream_type = inner_packet.stream_type;
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into(); let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
let flow_control_event: Arc<Event> = Event::new().into(); let flow_control_event: Arc<Event> = Event::new().into();
@ -335,7 +334,7 @@ impl MuxInner {
self.stream_map.insert( self.stream_map.insert(
packet.stream_id, packet.stream_id,
MuxMapValue { MuxMapValue {
stream: ch_tx, stream: ch_tx.into(),
stream_type, stream_type,
flow_control: flow_control.clone(), flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(), flow_control_event: flow_control_event.clone(),
@ -361,7 +360,7 @@ impl MuxInner {
} }
Data(data) => { Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) { if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(data); let _ = stream.stream.lock().await.send(data).await;
if stream.stream_type == StreamType::Tcp { if stream.stream_type == StreamType::Tcp {
stream.flow_control.store( stream.flow_control.store(
stream stream
@ -378,10 +377,10 @@ impl MuxInner {
if packet.stream_id == 0 { if packet.stream_id == 0 {
break Ok(()); break Ok(());
} }
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release); stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect(); stream.stream.lock().await.disconnect();
stream.stream.close_channel(); stream.stream.lock().await.close_channel();
} }
} }
} }
@ -410,7 +409,7 @@ impl MuxInner {
Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
Data(data) => { Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) { if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(data); let _ = stream.stream.lock().await.send(data).await;
} }
} }
Continue(inner_packet) => { Continue(inner_packet) => {
@ -427,10 +426,10 @@ impl MuxInner {
if packet.stream_id == 0 { if packet.stream_id == 0 {
break Ok(()); break Ok(());
} }
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release); stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect(); stream.stream.lock().await.disconnect();
stream.stream.close_channel(); stream.stream.lock().await.close_channel();
} }
} }
} }

View file

@ -38,7 +38,7 @@ pub struct MuxStreamRead {
pub stream_type: StreamType, pub stream_type: StreamType,
role: Role, role: Role,
tx: mpsc::Sender<WsEvent>, tx: mpsc::Sender<WsEvent>,
rx: mpsc::UnboundedReceiver<Bytes>, rx: mpsc::Receiver<Bytes>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
flow_control_read: AtomicU32, flow_control_read: AtomicU32,
@ -193,7 +193,7 @@ impl MuxStream {
stream_id: u32, stream_id: u32,
role: Role, role: Role,
stream_type: StreamType, stream_type: StreamType,
rx: mpsc::UnboundedReceiver<Bytes>, rx: mpsc::Receiver<Bytes>,
tx: mpsc::Sender<WsEvent>, tx: mpsc::Sender<WsEvent>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,