diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index f9dce1c..213d214 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -27,9 +27,9 @@ pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 }; /// The role of the multiplexor. #[derive(Debug, PartialEq, Copy, Clone)] pub enum Role { - /// Client side, can create new channels to proxy. + /// Client side, can create new streams. Client, - /// Server side, can listen for channels to proxy. + /// Server side, can listen for streams created by the client. Server, } diff --git a/wisp/src/mux/client.rs b/wisp/src/mux/client.rs index 54abb30..2544c1e 100644 --- a/wisp/src/mux/client.rs +++ b/wisp/src/mux/client.rs @@ -12,12 +12,12 @@ use futures::channel::oneshot; use crate::{ extensions::{udp::UdpProtocolExtension, AnyProtocolExtension}, inner::{MuxInner, WsEvent}, - ws::{AppendingWebSocketRead, LockedWebSocketWrite, WebSocketRead, WebSocketWrite, Payload}, + ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType, WispError, }; -use super::{maybe_wisp_v2, send_info_packet, WispV2Extensions}; +use super::{maybe_wisp_v2, send_info_packet, Multiplexor, MuxResult, WispV2Extensions}; /// Client side multiplexor. pub struct ClientMux { @@ -42,7 +42,7 @@ impl ClientMux { mut rx: R, tx: W, wisp_v2: Option, - ) -> Result> + Send>, WispError> + ) -> Result> + Send>, WispError> where R: WebSocketRead + Send, W: WebSocketWrite + Send + 'static, @@ -78,7 +78,7 @@ impl ClientMux { packet.buffer_remaining, ); - Ok(ClientMuxResult( + Ok(MuxResult( Self { actor_tx: mux_result.actor_tx, downgraded, @@ -174,50 +174,13 @@ impl Drop for ClientMux { } } -/// Result of `ClientMux::new`. -pub struct ClientMuxResult(ClientMux, F) -where - F: Future> + Send; - -impl ClientMuxResult -where - F: Future> + Send, -{ - /// Require no protocol extensions. - pub fn with_no_required_extensions(self) -> (ClientMux, F) { - (self.0, self.1) +impl Multiplexor for ClientMux { + fn has_extension(&self, extension_id: u8) -> bool { + self.supported_extensions + .iter() + .any(|x| x.get_id() == extension_id) } - - /// Require protocol extensions by their ID. - pub async fn with_required_extensions( - self, - extensions: &[u8], - ) -> Result<(ClientMux, F), WispError> { - let mut unsupported_extensions = Vec::new(); - for extension in extensions { - if !self - .0 - .supported_extensions - .iter() - .any(|x| x.get_id() == *extension) - { - unsupported_extensions.push(*extension); - } - } - if unsupported_extensions.is_empty() { - Ok((self.0, self.1)) - } else { - self.0 - .close_with_reason(CloseReason::ExtensionsIncompatible) - .await?; - self.1.await?; - Err(WispError::ExtensionsNotSupported(unsupported_extensions)) - } - } - - /// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])` - pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> { - self.with_required_extensions(&[UdpProtocolExtension::ID]) - .await + async fn exit(&self, reason: CloseReason) -> Result<(), WispError> { + self.close_with_reason(reason).await } } diff --git a/wisp/src/mux/mod.rs b/wisp/src/mux/mod.rs index 63a0a61..1212359 100644 --- a/wisp/src/mux/mod.rs +++ b/wisp/src/mux/mod.rs @@ -2,15 +2,15 @@ mod client; mod server; use std::{future::Future, pin::Pin, time::Duration}; -pub use client::{ClientMux, ClientMuxResult}; +pub use client::ClientMux; use futures::{select, FutureExt}; use futures_timer::Delay; -pub use server::{ServerMux, ServerMuxResult}; +pub use server::ServerMux; use crate::{ - extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder}, + extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder}, ws::{Frame, LockedWebSocketWrite, WebSocketRead}, - Packet, PacketType, Role, WispError, + CloseReason, Packet, PacketType, Role, WispError, }; async fn maybe_wisp_v2( @@ -67,15 +67,61 @@ async fn send_info_packet( .await } +trait Multiplexor { + fn has_extension(&self, extension_id: u8) -> bool; + async fn exit(&self, reason: CloseReason) -> Result<(), WispError>; +} + +/// Result of creating a multiplexor. Helps require protocol extensions. +#[allow(private_bounds)] +pub struct MuxResult(M, F) +where + M: Multiplexor, + F: Future> + Send; + +#[allow(private_bounds)] +impl MuxResult +where + M: Multiplexor, + F: Future> + Send, +{ + /// Require no protocol extensions. + pub fn with_no_required_extensions(self) -> (M, F) { + (self.0, self.1) + } + + /// Require protocol extensions by their ID. Will close the multiplexor connection if + /// extensions are not supported. + pub async fn with_required_extensions(self, extensions: &[u8]) -> Result<(M, F), WispError> { + let mut unsupported_extensions = Vec::new(); + for extension in extensions { + if !self.0.has_extension(*extension) { + unsupported_extensions.push(*extension); + } + } + + if unsupported_extensions.is_empty() { + Ok((self.0, self.1)) + } else { + self.0.exit(CloseReason::ExtensionsIncompatible).await?; + self.1.await?; + Err(WispError::ExtensionsNotSupported(unsupported_extensions)) + } + } + + /// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])` + pub async fn with_udp_extension_required(self) -> Result<(M, F), WispError> { + self.with_required_extensions(&[UdpProtocolExtension::ID]) + .await + } +} + +type WispV2ClosureResult = Pin> + Sync + Send>>; +type WispV2ClosureBuilders<'a> = &'a mut [AnyProtocolExtensionBuilder]; /// Wisp V2 handshake and protocol extension settings wrapper struct. pub struct WispV2Extensions { builders: Vec, - closure: Box< - dyn Fn( - &mut [AnyProtocolExtensionBuilder], - ) -> Pin> + Sync + Send>> - + Send, - >, + closure: Box WispV2ClosureResult + Send>, } impl WispV2Extensions { @@ -90,11 +136,7 @@ impl WispV2Extensions { /// Create a Wisp V2 settings struct with some middleware. pub fn new_with_middleware(builders: Vec, closure: C) -> Self where - C: Fn( - &mut [AnyProtocolExtensionBuilder], - ) -> Pin> + Sync + Send>> - + Send - + 'static, + C: Fn(WispV2ClosureBuilders) -> WispV2ClosureResult + Send + 'static, { Self { builders, diff --git a/wisp/src/mux/server.rs b/wisp/src/mux/server.rs index b1ee07f..aefbf38 100644 --- a/wisp/src/mux/server.rs +++ b/wisp/src/mux/server.rs @@ -11,13 +11,13 @@ use flume as mpsc; use futures::channel::oneshot; use crate::{ - extensions::{udp::UdpProtocolExtension, AnyProtocolExtension}, + extensions::AnyProtocolExtension, inner::{MuxInner, WsEvent}, ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, Role, WispError, }; -use super::{maybe_wisp_v2, send_info_packet, WispV2Extensions}; +use super::{maybe_wisp_v2, send_info_packet, Multiplexor, MuxResult, WispV2Extensions}; /// Server-side multiplexor. pub struct ServerMux { @@ -44,7 +44,7 @@ impl ServerMux { tx: W, buffer_size: u32, wisp_v2: Option, - ) -> Result> + Send>, WispError> + ) -> Result> + Send>, WispError> where R: WebSocketRead + Send, W: WebSocketWrite + Send + 'static, @@ -74,7 +74,7 @@ impl ServerMux { buffer_size, ); - Ok(ServerMuxResult( + Ok(MuxResult( Self { muxstream_recv, actor_tx: mux_result.actor_tx, @@ -175,51 +175,13 @@ impl Drop for ServerMux { } } -/// Result of `ServerMux::new`. -pub struct ServerMuxResult(ServerMux, F) -where - F: Future> + Send; - -impl ServerMuxResult -where - F: Future> + Send, -{ - /// Require no protocol extensions. - pub fn with_no_required_extensions(self) -> (ServerMux, F) { - (self.0, self.1) +impl Multiplexor for ServerMux { + fn has_extension(&self, extension_id: u8) -> bool { + self.supported_extensions + .iter() + .any(|x| x.get_id() == extension_id) } - - /// Require protocol extensions by their ID. Will close the multiplexor connection if - /// extensions are not supported. - pub async fn with_required_extensions( - self, - extensions: &[u8], - ) -> Result<(ServerMux, F), WispError> { - let mut unsupported_extensions = Vec::new(); - for extension in extensions { - if !self - .0 - .supported_extensions - .iter() - .any(|x| x.get_id() == *extension) - { - unsupported_extensions.push(*extension); - } - } - if unsupported_extensions.is_empty() { - Ok((self.0, self.1)) - } else { - self.0 - .close_with_reason(CloseReason::ExtensionsIncompatible) - .await?; - self.1.await?; - Err(WispError::ExtensionsNotSupported(unsupported_extensions)) - } - } - - /// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])` - pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> { - self.with_required_extensions(&[UdpProtocolExtension::ID]) - .await + async fn exit(&self, reason: CloseReason) -> Result<(), WispError> { + self.close_with_reason(reason).await } }