diff --git a/server/src/main.rs b/server/src/main.rs index a529872..7e0e581 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -310,12 +310,12 @@ async fn accept_ws( println!("{:?}: connected", addr); // to prevent memory ""leaks"" because users are sending in packets way too fast the buffer - // size is set to 32 + // size is set to 128 let (mut mux, fut) = if mux_options.enforce_auth { let (mut mux, fut) = ServerMux::new( rx, tx, - 32, + 128, Some(mux_options.auth.as_slice()), ) .await?; @@ -333,7 +333,7 @@ async fn accept_ws( } (mux, fut) } else { - ServerMux::new(rx, tx, 32, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await? + ServerMux::new(rx, tx, 128, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await? }; println!( diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index 97a7b15..198d915 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -92,6 +92,9 @@ struct Cli { /// Usernames and passwords are sent in plaintext!! #[arg(long)] auth: Option, + /// Make a Wisp V1 connection + #[arg(long)] + wisp_v1: bool, } #[tokio::main(flavor = "multi_thread")] @@ -161,7 +164,12 @@ async fn main() -> Result<(), Box> { extensions.push(Box::new(auth)); } - let (mut mux, fut) = ClientMux::new(rx, tx, Some(extensions.as_slice())).await?; + let (mut mux, fut) = if opts.wisp_v1 { + ClientMux::new(rx, tx, None).await? + } else { + ClientMux::new(rx, tx, Some(extensions.as_slice())).await? + }; + if opts.udp && !mux .supported_extension_ids diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index ff732af..1454145 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -20,8 +20,7 @@ use dashmap::DashMap; use event_listener::Event; use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder}; use futures::{ - channel::{mpsc, oneshot}, - select, Future, FutureExt, SinkExt, StreamExt, + channel::{mpsc, oneshot}, lock::Mutex, select, Future, FutureExt, SinkExt, StreamExt }; use futures_timer::Delay; use std::{ @@ -152,7 +151,7 @@ impl std::fmt::Display for WispError { impl std::error::Error for WispError {} struct MuxMapValue { - stream: mpsc::UnboundedSender, + stream: Mutex>, stream_type: StreamType, flow_control: Arc, flow_control_event: Arc, @@ -209,11 +208,11 @@ impl MuxInner { _ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()), x = wisp_fut.fuse() => x, }; - self.stream_map.iter_mut().for_each(|mut x| { + for x in self.stream_map.iter_mut() { x.is_closed.store(true, Ordering::Release); - x.stream.disconnect(); - x.stream.close_channel(); - }); + x.stream.lock().await.disconnect(); + x.stream.lock().await.close_channel(); + } self.stream_map.clear(); ret } @@ -235,7 +234,7 @@ impl MuxInner { } WsEvent::CreateStream(stream_type, host, port, channel) => { let ret: Result = async { - let (ch_tx, ch_rx) = mpsc::unbounded(); + let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize); let stream_id = next_free_stream_id; let next_stream_id = next_free_stream_id .checked_add(1) @@ -257,7 +256,7 @@ impl MuxInner { self.stream_map.insert( stream_id, MuxMapValue { - stream: ch_tx, + stream: ch_tx.into(), stream_type, flow_control: flow_control.clone(), flow_control_event: flow_control_event.clone(), @@ -281,9 +280,9 @@ impl MuxInner { let _ = channel.send(ret); } WsEvent::Close(packet, channel) => { - if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { - stream.stream.disconnect(); - stream.stream.close_channel(); + if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { + stream.stream.lock().await.disconnect(); + stream.stream.lock().await.close_channel(); let _ = channel.send(self.tx.write_frame(packet.into()).await); } else { let _ = channel.send(Err(WispError::InvalidStreamId)); @@ -326,7 +325,7 @@ impl MuxInner { use PacketType::*; match packet.packet_type { Connect(inner_packet) => { - let (ch_tx, ch_rx) = mpsc::unbounded(); + let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize); let stream_type = inner_packet.stream_type; let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); let flow_control_event: Arc = Event::new().into(); @@ -335,7 +334,7 @@ impl MuxInner { self.stream_map.insert( packet.stream_id, MuxMapValue { - stream: ch_tx, + stream: ch_tx.into(), stream_type, flow_control: flow_control.clone(), flow_control_event: flow_control_event.clone(), @@ -361,7 +360,7 @@ impl MuxInner { } Data(data) => { if let Some(stream) = self.stream_map.get(&packet.stream_id) { - let _ = stream.stream.unbounded_send(data); + let _ = stream.stream.lock().await.send(data).await; if stream.stream_type == StreamType::Tcp { stream.flow_control.store( stream @@ -378,10 +377,10 @@ impl MuxInner { if packet.stream_id == 0 { break Ok(()); } - if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { + if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { stream.is_closed.store(true, Ordering::Release); - stream.stream.disconnect(); - stream.stream.close_channel(); + stream.stream.lock().await.disconnect(); + stream.stream.lock().await.close_channel(); } } } @@ -410,7 +409,7 @@ impl MuxInner { Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), Data(data) => { if let Some(stream) = self.stream_map.get(&packet.stream_id) { - let _ = stream.stream.unbounded_send(data); + let _ = stream.stream.lock().await.send(data).await; } } Continue(inner_packet) => { @@ -427,10 +426,10 @@ impl MuxInner { if packet.stream_id == 0 { break Ok(()); } - if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { + if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { stream.is_closed.store(true, Ordering::Release); - stream.stream.disconnect(); - stream.stream.close_channel(); + stream.stream.lock().await.disconnect(); + stream.stream.lock().await.close_channel(); } } } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 69c711b..1a8c2da 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -38,7 +38,7 @@ pub struct MuxStreamRead { pub stream_type: StreamType, role: Role, tx: mpsc::Sender, - rx: mpsc::UnboundedReceiver, + rx: mpsc::Receiver, is_closed: Arc, flow_control: Arc, flow_control_read: AtomicU32, @@ -193,7 +193,7 @@ impl MuxStream { stream_id: u32, role: Role, stream_type: StreamType, - rx: mpsc::UnboundedReceiver, + rx: mpsc::Receiver, tx: mpsc::Sender, is_closed: Arc, flow_control: Arc,