diff --git a/Cargo.lock b/Cargo.lock index 33b2300..d68ecaa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3230,7 +3230,6 @@ dependencies = [ "flume", "futures", "futures-timer", - "futures-util", "pin-project-lite", "tokio", ] diff --git a/client/src/utils.rs b/client/src/utils.rs index 98717a5..f0279e3 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -203,7 +203,11 @@ pub async fn make_mux( let (wtx, wrx) = WebSocketWrapper::connect(url, vec![]).map_err(|_| WispError::WsImplSocketClosed)?; wtx.wait_for_open().await; - ClientMux::new(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await + Ok( + ClientMux::create(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())])) + .await? + .with_no_required_extensions(), + ) } pub fn spawn_mux_fut( diff --git a/server/src/main.rs b/server/src/main.rs index 1910a99..68b3da9 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -253,45 +253,39 @@ async fn accept_http( } } -async fn handle_mux(packet: ConnectPacket, stream: MuxStream) -> Result { +async fn handle_mux( + packet: ConnectPacket, + stream: MuxStream, +) -> Result> { let uri = format!( "{}:{}", packet.destination_hostname, packet.destination_port ); match packet.stream_type { StreamType::Tcp => { - let mut tcp_stream = TcpStream::connect(uri) - .await - .map_err(|x| WispError::Other(Box::new(x)))?; + let mut tcp_stream = TcpStream::connect(uri).await?; let mut mux_stream = stream.into_io().into_asyncrw(); - copy_bidirectional(&mut mux_stream, &mut tcp_stream) - .await - .map_err(|x| WispError::Other(Box::new(x)))?; + copy_bidirectional(&mut mux_stream, &mut tcp_stream).await?; } StreamType::Udp => { let uri = lookup_host(uri) - .await - .map_err(|x| WispError::Other(Box::new(x)))? + .await? .next() .ok_or(WispError::InvalidUri)?; - let udp_socket = UdpSocket::bind(if uri.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" }) - .await - .map_err(|x| WispError::Other(Box::new(x)))?; - udp_socket - .connect(uri) - .await - .map_err(|x| WispError::Other(Box::new(x)))?; + let udp_socket = + UdpSocket::bind(if uri.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" }).await?; + udp_socket.connect(uri).await?; let mut data = vec![0u8; 65507]; // udp standard max datagram size loop { tokio::select! { - size = udp_socket.recv(&mut data).map_err(|x| WispError::Other(Box::new(x))) => { + size = udp_socket.recv(&mut data) => { let size = size?; stream.write(Bytes::copy_from_slice(&data[..size])).await? }, event = stream.read() => { match event { Some(event) => { - let _ = udp_socket.send(&event).await.map_err(|x| WispError::Other(Box::new(x)))?; + let _ = udp_socket.send(&event).await?; } None => break, } @@ -319,28 +313,18 @@ async fn accept_ws( // to prevent memory ""leaks"" because users are sending in packets way too fast the buffer // size is set to 128 let (mux, fut) = if mux_options.enforce_auth { - let (mux, fut) = ServerMux::new(rx, tx, 128, Some(mux_options.auth.as_slice())).await?; - if !mux - .supported_extension_ids - .iter() - .any(|x| *x == PasswordProtocolExtension::ID) - { - println!( - "{:?}: client did not support auth or password was invalid", - addr - ); - mux.close_extension_incompat().await?; - return Ok(()); - } - (mux, fut) + ServerMux::create(rx, tx, 128, Some(mux_options.auth.as_slice())) + .await? + .with_required_extensions(&[PasswordProtocolExtension::ID]).await? } else { - ServerMux::new( + ServerMux::create( rx, tx, 128, Some(&[Box::new(UdpProtocolExtensionBuilder())]), ) .await? + .with_no_required_extensions() }; println!( diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index fecbf44..a478dc6 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -156,53 +156,33 @@ async fn main() -> Result<(), Box> { let rx = FragmentCollectorRead::new(rx); let mut extensions: Vec> = Vec::new(); + let mut extension_ids: Vec = Vec::new(); if opts.udp { extensions.push(Box::new(UdpProtocolExtensionBuilder())); + extension_ids.push(UdpProtocolExtension::ID); } - let enforce_auth = auth.is_some(); if let Some(auth) = auth { extensions.push(Box::new(auth)); + extension_ids.push(PasswordProtocolExtension::ID); } let (mux, fut) = if opts.wisp_v1 { - ClientMux::new(rx, tx, None).await? + ClientMux::create(rx, tx, None) + .await? + .with_no_required_extensions() } else { - ClientMux::new(rx, tx, Some(extensions.as_slice())).await? + ClientMux::create(rx, tx, Some(extensions.as_slice())) + .await? + .with_required_extensions(extension_ids.as_slice()).await? }; - if opts.udp - && !mux - .supported_extension_ids - .iter() - .any(|x| *x == UdpProtocolExtension::ID) - { - println!( - "server did not support udp, was downgraded {}, extensions supported {:?}", - mux.downgraded, mux.supported_extension_ids - ); - mux.close_extension_incompat().await?; - exit(1); - } - if enforce_auth - && !mux - .supported_extension_ids - .iter() - .any(|x| *x == PasswordProtocolExtension::ID) - { - println!( - "server did not support passwords or password was incorrect, was downgraded {}, extensions supported {:?}", - mux.downgraded, mux.supported_extension_ids - ); - mux.close_extension_incompat().await?; - exit(1); - } - println!( "connected and created ClientMux, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids ); - let mut threads = Vec::with_capacity(opts.streams * 2 + 3); + let mut threads = Vec::with_capacity(opts.streams + 4); + let mut reads = Vec::with_capacity(opts.streams); threads.push(tokio::spawn(fut)); @@ -226,13 +206,15 @@ async fn main() -> Result<(), Box> { #[allow(unreachable_code)] Ok::<(), WispError>(()) })); - threads.push(tokio::spawn(async move { - loop { - cr.read().await; - } - })); + reads.push(cr); } + threads.push(tokio::spawn(async move { + loop { + select_all(reads.iter().map(|x| Box::pin(x.read()))).await; + } + })); + let cnt_avg = cnt.clone(); threads.push(tokio::spawn(async move { let mut interval = interval(Duration::from_millis(100)); @@ -295,14 +277,16 @@ async fn main() -> Result<(), Box> { mux.close().await?; - println!( - "\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)", - cnt.get(), - opts.packet_size, - cnt.get() * opts.packet_size, - format_duration(duration_since), - (cnt.get() * opts.packet_size) as u64 / duration_since.as_secs(), - ); + if duration_since.as_secs() != 0 { + println!( + "\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)", + cnt.get(), + opts.packet_size, + cnt.get() * opts.packet_size, + format_duration(duration_since), + (cnt.get() * opts.packet_size) as u64 / duration_since.as_secs(), + ); + } Ok(()) } diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 795a1c6..d34e360 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -18,7 +18,6 @@ fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = flume = "0.11.0" futures = "0.3.30" futures-timer = "3.0.3" -futures-util = "0.3.30" pin-project-lite = "0.2.13" tokio = { version = "1.35.1", optional = true, default-features = false } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 1cf170f..1b88da8 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -80,8 +80,8 @@ pub enum WispError { ExtensionImplError(Box), /// The protocol extension implementation did not support the action. ExtensionImplNotSupported, - /// The UDP protocol extension is not supported by the server. - UdpExtensionNotSupported, + /// The specified protocol extensions are not supported by the server. + ExtensionsNotSupported(Vec), /// The string was invalid UTF-8. Utf8Error(std::str::Utf8Error), /// The integer failed to convert. @@ -137,7 +137,9 @@ impl std::fmt::Display for WispError { "Protocol extension implementation error: unsupported feature" ) } - Self::UdpExtensionNotSupported => write!(f, "UDP protocol extension not supported"), + Self::ExtensionsNotSupported(list) => { + write!(f, "Protocol extensions {:?} not supported", list) + } Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err), Self::TryFromIntError(err) => write!(f, "Integer conversion error: {}", err), Self::Other(err) => write!(f, "Other error: {}", err), @@ -483,12 +485,12 @@ impl ServerMux { /// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. /// **It is not guaranteed that all extensions you specify are available.** You must manually check /// if the extensions you need are available after the multiplexor has been created. - pub async fn new( + pub async fn create( mut read: R, write: W, buffer_size: u32, extension_builders: Option<&[Box]>, - ) -> Result<(Self, impl Future> + Send), WispError> + ) -> Result> + Send>, WispError> where R: ws::WebSocketRead + Send, W: ws::WebSocketWrite + Send + 'static, @@ -532,7 +534,7 @@ impl ServerMux { } } - Ok(( + Ok(ServerMuxResult( Self { muxstream_recv: rx, close_tx: close_tx.clone(), @@ -590,6 +592,48 @@ impl Drop for ServerMux { } } +/// Result of `ServerMux::new`. +pub struct ServerMuxResult(ServerMux, F) +where + F: Future> + Send; + +impl ServerMuxResult +where + F: Future> + Send, +{ + /// Require no protocol extensions. + pub fn with_no_required_extensions(self) -> (ServerMux, F) { + (self.0, self.1) + } + + /// Require protocol extensions by their ID. Will close the multiplexor connection if + /// extensions are not supported. + pub async fn with_required_extensions( + self, + extensions: &[u8], + ) -> Result<(ServerMux, F), WispError> { + let mut unsupported_extensions = Vec::new(); + for extension in extensions { + if !self.0.supported_extension_ids.contains(extension) { + unsupported_extensions.push(*extension); + } + } + if unsupported_extensions.is_empty() { + Ok((self.0, self.1)) + } else { + self.0.close_extension_incompat().await?; + self.1.await?; + Err(WispError::ExtensionsNotSupported(unsupported_extensions)) + } + } + + /// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])` + pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> { + self.with_required_extensions(&[UdpProtocolExtension::ID]) + .await + } +} + /// Client side multiplexor. /// /// # Example @@ -620,11 +664,11 @@ impl ClientMux { /// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. /// **It is not guaranteed that all extensions you specify are available.** You must manually check /// if the extensions you need are available after the multiplexor has been created. - pub async fn new( + pub async fn create( mut read: R, write: W, extension_builders: Option<&[Box]>, - ) -> Result<(Self, impl Future> + Send), WispError> + ) -> Result> + Send>, WispError> where R: ws::WebSocketRead + Send, W: ws::WebSocketWrite + Send + 'static, @@ -671,7 +715,7 @@ impl ClientMux { } let (tx, rx) = mpsc::bounded::(256); - Ok(( + Ok(ClientMuxResult( Self { stream_tx: tx.clone(), downgraded, @@ -710,7 +754,9 @@ impl ClientMux { .iter() .any(|x| *x == UdpProtocolExtension::ID) { - return Err(WispError::UdpExtensionNotSupported); + return Err(WispError::ExtensionsNotSupported(vec![ + UdpProtocolExtension::ID, + ])); } let (tx, rx) = oneshot::channel(); self.stream_tx @@ -750,3 +796,40 @@ impl Drop for ClientMux { let _ = self.stream_tx.send(WsEvent::EndFut(None)); } } + +/// Result of `ClientMux::new`. +pub struct ClientMuxResult(ClientMux, F) +where + F: Future> + Send; + +impl ClientMuxResult +where + F: Future> + Send, +{ + /// Require no protocol extensions. + pub fn with_no_required_extensions(self) -> (ClientMux, F) { + (self.0, self.1) + } + + /// Require protocol extensions by their ID. + pub async fn with_required_extensions(self, extensions: &[u8]) -> Result<(ClientMux, F), WispError> { + let mut unsupported_extensions = Vec::new(); + for extension in extensions { + if !self.0.supported_extension_ids.contains(extension) { + unsupported_extensions.push(*extension); + } + } + if unsupported_extensions.is_empty() { + Ok((self.0, self.1)) + } else { + self.0.close_extension_incompat().await?; + self.1.await?; + Err(WispError::ExtensionsNotSupported(unsupported_extensions)) + } + } + + /// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])` + pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> { + self.with_required_extensions(&[UdpProtocolExtension::ID]).await + } +}