From 2efb6412288365520fa90cca18e1a22d6977dcdc Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Wed, 23 Oct 2024 23:00:23 -0700 Subject: [PATCH] separate clientmux and servermux into new files --- wisp/src/lib.rs | 521 +---------------------------------------- wisp/src/mux/client.rs | 223 ++++++++++++++++++ wisp/src/mux/mod.rs | 109 +++++++++ wisp/src/mux/server.rs | 225 ++++++++++++++++++ 4 files changed, 559 insertions(+), 519 deletions(-) create mode 100644 wisp/src/mux/client.rs create mode 100644 wisp/src/mux/mod.rs create mode 100644 wisp/src/mux/server.rs diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index c3505dc..73a8c90 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -12,29 +12,15 @@ mod fastwebsockets; #[cfg_attr(docsrs, doc(cfg(feature = "generic_stream")))] pub mod generic; mod inner; +mod mux; mod packet; mod sink_unfold; mod stream; pub mod ws; -pub use crate::{packet::*, stream::*}; +pub use crate::{mux::*, packet::*, stream::*}; -use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder}; -use flume as mpsc; -use futures::{channel::oneshot, select, Future, FutureExt}; -use futures_timer::Delay; -use inner::{MuxInner, WsEvent}; -use std::{ - ops::DerefMut, - pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - time::Duration, -}; use thiserror::Error; -use ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload}; /// Wisp version supported by this crate. pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 }; @@ -128,506 +114,3 @@ pub enum WispError { #[error("Certificate authentication protocol extension: Invalid signature")] CertAuthExtensionSigInvalid, } - -async fn maybe_wisp_v2( - read: &mut R, - write: &LockedWebSocketWrite, - role: Role, - builders: &mut [AnyProtocolExtensionBuilder], -) -> Result<(Vec, Option>, bool), WispError> -where - R: ws::WebSocketRead + Send, -{ - let mut supported_extensions = Vec::new(); - let mut extra_packet: Option> = None; - let mut downgraded = true; - - let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect(); - if let Some(frame) = select! { - x = read.wisp_read_frame(write).fuse() => Some(x?), - _ = Delay::new(Duration::from_secs(5)).fuse() => None - } { - let packet = Packet::maybe_parse_info(frame, role, builders)?; - if let PacketType::Info(info) = packet.packet_type { - supported_extensions = info - .extensions - .into_iter() - .filter(|x| extension_ids.contains(&x.get_id())) - .collect(); - downgraded = false; - } else { - extra_packet.replace(ws::Frame::from(packet).clone()); - } - } - - for extension in supported_extensions.iter_mut() { - extension.handle_handshake(read, write).await?; - } - Ok((supported_extensions, extra_packet, downgraded)) -} - -async fn send_info_packet( - write: &LockedWebSocketWrite, - builders: &mut [AnyProtocolExtensionBuilder], -) -> Result<(), WispError> { - write - .write_frame( - Packet::new_info( - builders - .iter_mut() - .map(|x| x.build_to_extension(Role::Server)) - .collect::, _>>()?, - ) - .into(), - ) - .await -} - -/// Wisp V2 handshake and protocol extension settings wrapper struct. -pub struct WispV2Extensions { - builders: Vec, - closure: Box< - dyn Fn( - &mut [AnyProtocolExtensionBuilder], - ) -> Pin> + Sync + Send>> - + Send, - >, -} - -impl WispV2Extensions { - /// Create a Wisp V2 settings struct with no middleware. - pub fn new(builders: Vec) -> Self { - Self { - builders, - closure: Box::new(|_| Box::pin(async { Ok(()) })), - } - } - - /// Create a Wisp V2 settings struct with some middleware. - pub fn new_with_middleware(builders: Vec, closure: C) -> Self - where - C: Fn( - &mut [AnyProtocolExtensionBuilder], - ) -> Pin> + Sync + Send>> - + Send - + 'static, - { - Self { - builders, - closure: Box::new(closure), - } - } - - /// Add a Wisp V2 extension builder to the settings struct. - pub fn add_extension(&mut self, extension: AnyProtocolExtensionBuilder) { - self.builders.push(extension); - } -} - -/// Server-side multiplexor. -pub struct ServerMux { - /// Whether the connection was downgraded to Wisp v1. - /// - /// If this variable is true you must assume no extensions are supported. - pub downgraded: bool, - /// Extensions that are supported by both sides. - pub supported_extensions: Vec, - actor_tx: mpsc::Sender, - muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>, - tx: ws::LockedWebSocketWrite, - actor_exited: Arc, -} - -impl ServerMux { - /// Create a new server-side multiplexor. - /// - /// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. - /// **It is not guaranteed that all extensions you specify are available.** You must manually check - /// if the extensions you need are available after the multiplexor has been created. - pub async fn create( - mut rx: R, - tx: W, - buffer_size: u32, - wisp_v2: Option, - ) -> Result> + Send>, WispError> - where - R: ws::WebSocketRead + Send, - W: ws::WebSocketWrite + Send + 'static, - { - let tx = ws::LockedWebSocketWrite::new(Box::new(tx)); - let ret_tx = tx.clone(); - let ret = async { - tx.write_frame(Packet::new_continue(0, buffer_size).into()) - .await?; - - let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions { - mut builders, - closure, - }) = wisp_v2 - { - send_info_packet(&tx, builders.deref_mut()).await?; - (closure)(builders.deref_mut()).await?; - maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await? - } else { - (Vec::new(), None, true) - }; - - let (mux_result, muxstream_recv) = MuxInner::new_server( - AppendingWebSocketRead(extra_packet, rx), - tx.clone(), - supported_extensions.clone(), - buffer_size, - ); - - Ok(ServerMuxResult( - Self { - muxstream_recv, - actor_tx: mux_result.actor_tx, - downgraded, - supported_extensions, - tx, - actor_exited: mux_result.actor_exited, - }, - mux_result.mux.into_future(), - )) - } - .await; - - match ret { - Ok(x) => Ok(x), - Err(x) => match x { - WispError::PasswordExtensionCredsInvalid => { - ret_tx - .write_frame( - Packet::new_close(0, CloseReason::ExtensionsPasswordAuthFailed).into(), - ) - .await?; - ret_tx.close().await?; - Err(x) - } - WispError::CertAuthExtensionSigInvalid => { - ret_tx - .write_frame( - Packet::new_close(0, CloseReason::ExtensionsCertAuthFailed).into(), - ) - .await?; - ret_tx.close().await?; - Err(x) - } - x => Err(x), - }, - } - } - - /// Wait for a stream to be created. - pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> { - if self.actor_exited.load(Ordering::Acquire) { - return None; - } - self.muxstream_recv.recv_async().await.ok() - } - - /// Send a ping to the client. - pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> { - if self.actor_exited.load(Ordering::Acquire) { - return Err(WispError::MuxTaskEnded); - } - let (tx, rx) = oneshot::channel(); - self.actor_tx - .send_async(WsEvent::SendPing(payload, tx)) - .await - .map_err(|_| WispError::MuxMessageFailedToSend)?; - rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? - } - - async fn close_internal(&self, reason: Option) -> Result<(), WispError> { - if self.actor_exited.load(Ordering::Acquire) { - return Err(WispError::MuxTaskEnded); - } - self.actor_tx - .send_async(WsEvent::EndFut(reason)) - .await - .map_err(|_| WispError::MuxMessageFailedToSend) - } - - /// Close all streams. - /// - /// Also terminates the multiplexor future. - pub async fn close(&self) -> Result<(), WispError> { - self.close_internal(None).await - } - - /// Close all streams and send a close reason on stream ID 0. - /// - /// Also terminates the multiplexor future. - pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> { - self.close_internal(Some(reason)).await - } - - /// Get a protocol extension stream for sending packets with stream id 0. - pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { - MuxProtocolExtensionStream { - stream_id: 0, - tx: self.tx.clone(), - is_closed: self.actor_exited.clone(), - } - } -} - -impl Drop for ServerMux { - fn drop(&mut self) { - let _ = self.actor_tx.send(WsEvent::EndFut(None)); - } -} - -/// Result of `ServerMux::new`. -pub struct ServerMuxResult(ServerMux, F) -where - F: Future> + Send; - -impl ServerMuxResult -where - F: Future> + Send, -{ - /// Require no protocol extensions. - pub fn with_no_required_extensions(self) -> (ServerMux, F) { - (self.0, self.1) - } - - /// Require protocol extensions by their ID. Will close the multiplexor connection if - /// extensions are not supported. - pub async fn with_required_extensions( - self, - extensions: &[u8], - ) -> Result<(ServerMux, F), WispError> { - let mut unsupported_extensions = Vec::new(); - for extension in extensions { - if !self - .0 - .supported_extensions - .iter() - .any(|x| x.get_id() == *extension) - { - unsupported_extensions.push(*extension); - } - } - if unsupported_extensions.is_empty() { - Ok((self.0, self.1)) - } else { - self.0 - .close_with_reason(CloseReason::ExtensionsIncompatible) - .await?; - self.1.await?; - Err(WispError::ExtensionsNotSupported(unsupported_extensions)) - } - } - - /// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])` - pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> { - self.with_required_extensions(&[UdpProtocolExtension::ID]) - .await - } -} - -/// Client side multiplexor. -pub struct ClientMux { - /// Whether the connection was downgraded to Wisp v1. - /// - /// If this variable is true you must assume no extensions are supported. - pub downgraded: bool, - /// Extensions that are supported by both sides. - pub supported_extensions: Vec, - actor_tx: mpsc::Sender, - tx: ws::LockedWebSocketWrite, - actor_exited: Arc, -} - -impl ClientMux { - /// Create a new client side multiplexor. - /// - /// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. - /// **It is not guaranteed that all extensions you specify are available.** You must manually check - /// if the extensions you need are available after the multiplexor has been created. - pub async fn create( - mut rx: R, - tx: W, - wisp_v2: Option, - ) -> Result> + Send>, WispError> - where - R: ws::WebSocketRead + Send, - W: ws::WebSocketWrite + Send + 'static, - { - let tx = ws::LockedWebSocketWrite::new(Box::new(tx)); - let first_packet = Packet::try_from(rx.wisp_read_frame(&tx).await?)?; - - if first_packet.stream_id != 0 { - return Err(WispError::InvalidStreamId); - } - - if let PacketType::Continue(packet) = first_packet.packet_type { - let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions { - mut builders, - closure, - }) = wisp_v2 - { - let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?; - // if not downgraded - if !res.2 { - (closure)(&mut builders).await?; - send_info_packet(&tx, &mut builders).await?; - } - res - } else { - (Vec::new(), None, true) - }; - - let mux_result = MuxInner::new_client( - AppendingWebSocketRead(extra_packet, rx), - tx.clone(), - supported_extensions.clone(), - packet.buffer_remaining, - ); - - Ok(ClientMuxResult( - Self { - actor_tx: mux_result.actor_tx, - downgraded, - supported_extensions, - tx, - actor_exited: mux_result.actor_exited, - }, - mux_result.mux.into_future(), - )) - } else { - Err(WispError::InvalidPacketType) - } - } - - /// Create a new stream, multiplexed through Wisp. - pub async fn client_new_stream( - &self, - stream_type: StreamType, - host: String, - port: u16, - ) -> Result { - if self.actor_exited.load(Ordering::Acquire) { - return Err(WispError::MuxTaskEnded); - } - if stream_type == StreamType::Udp - && !self - .supported_extensions - .iter() - .any(|x| x.get_id() == UdpProtocolExtension::ID) - { - return Err(WispError::ExtensionsNotSupported(vec![ - UdpProtocolExtension::ID, - ])); - } - let (tx, rx) = oneshot::channel(); - self.actor_tx - .send_async(WsEvent::CreateStream(stream_type, host, port, tx)) - .await - .map_err(|_| WispError::MuxMessageFailedToSend)?; - rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? - } - - /// Send a ping to the server. - pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> { - if self.actor_exited.load(Ordering::Acquire) { - return Err(WispError::MuxTaskEnded); - } - let (tx, rx) = oneshot::channel(); - self.actor_tx - .send_async(WsEvent::SendPing(payload, tx)) - .await - .map_err(|_| WispError::MuxMessageFailedToSend)?; - rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? - } - - async fn close_internal(&self, reason: Option) -> Result<(), WispError> { - if self.actor_exited.load(Ordering::Acquire) { - return Err(WispError::MuxTaskEnded); - } - self.actor_tx - .send_async(WsEvent::EndFut(reason)) - .await - .map_err(|_| WispError::MuxMessageFailedToSend) - } - - /// Close all streams. - /// - /// Also terminates the multiplexor future. - pub async fn close(&self) -> Result<(), WispError> { - self.close_internal(None).await - } - - /// Close all streams and send a close reason on stream ID 0. - /// - /// Also terminates the multiplexor future. - pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> { - self.close_internal(Some(reason)).await - } - - /// Get a protocol extension stream for sending packets with stream id 0. - pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { - MuxProtocolExtensionStream { - stream_id: 0, - tx: self.tx.clone(), - is_closed: self.actor_exited.clone(), - } - } -} - -impl Drop for ClientMux { - fn drop(&mut self) { - let _ = self.actor_tx.send(WsEvent::EndFut(None)); - } -} - -/// Result of `ClientMux::new`. -pub struct ClientMuxResult(ClientMux, F) -where - F: Future> + Send; - -impl ClientMuxResult -where - F: Future> + Send, -{ - /// Require no protocol extensions. - pub fn with_no_required_extensions(self) -> (ClientMux, F) { - (self.0, self.1) - } - - /// Require protocol extensions by their ID. - pub async fn with_required_extensions( - self, - extensions: &[u8], - ) -> Result<(ClientMux, F), WispError> { - let mut unsupported_extensions = Vec::new(); - for extension in extensions { - if !self - .0 - .supported_extensions - .iter() - .any(|x| x.get_id() == *extension) - { - unsupported_extensions.push(*extension); - } - } - if unsupported_extensions.is_empty() { - Ok((self.0, self.1)) - } else { - self.0 - .close_with_reason(CloseReason::ExtensionsIncompatible) - .await?; - self.1.await?; - Err(WispError::ExtensionsNotSupported(unsupported_extensions)) - } - } - - /// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])` - pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> { - self.with_required_extensions(&[UdpProtocolExtension::ID]) - .await - } -} diff --git a/wisp/src/mux/client.rs b/wisp/src/mux/client.rs new file mode 100644 index 0000000..54abb30 --- /dev/null +++ b/wisp/src/mux/client.rs @@ -0,0 +1,223 @@ +use std::{ + future::Future, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; + +use flume as mpsc; +use futures::channel::oneshot; + +use crate::{ + extensions::{udp::UdpProtocolExtension, AnyProtocolExtension}, + inner::{MuxInner, WsEvent}, + ws::{AppendingWebSocketRead, LockedWebSocketWrite, WebSocketRead, WebSocketWrite, Payload}, + CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType, + WispError, +}; + +use super::{maybe_wisp_v2, send_info_packet, WispV2Extensions}; + +/// Client side multiplexor. +pub struct ClientMux { + /// Whether the connection was downgraded to Wisp v1. + /// + /// If this variable is true you must assume no extensions are supported. + pub downgraded: bool, + /// Extensions that are supported by both sides. + pub supported_extensions: Vec, + actor_tx: mpsc::Sender, + tx: LockedWebSocketWrite, + actor_exited: Arc, +} + +impl ClientMux { + /// Create a new client side multiplexor. + /// + /// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. + /// **It is not guaranteed that all extensions you specify are available.** You must manually check + /// if the extensions you need are available after the multiplexor has been created. + pub async fn create( + mut rx: R, + tx: W, + wisp_v2: Option, + ) -> Result> + Send>, WispError> + where + R: WebSocketRead + Send, + W: WebSocketWrite + Send + 'static, + { + let tx = LockedWebSocketWrite::new(Box::new(tx)); + let first_packet = Packet::try_from(rx.wisp_read_frame(&tx).await?)?; + + if first_packet.stream_id != 0 { + return Err(WispError::InvalidStreamId); + } + + if let PacketType::Continue(packet) = first_packet.packet_type { + let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions { + mut builders, + closure, + }) = wisp_v2 + { + let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?; + // if not downgraded + if !res.2 { + (closure)(&mut builders).await?; + send_info_packet(&tx, &mut builders).await?; + } + res + } else { + (Vec::new(), None, true) + }; + + let mux_result = MuxInner::new_client( + AppendingWebSocketRead(extra_packet, rx), + tx.clone(), + supported_extensions.clone(), + packet.buffer_remaining, + ); + + Ok(ClientMuxResult( + Self { + actor_tx: mux_result.actor_tx, + downgraded, + supported_extensions, + tx, + actor_exited: mux_result.actor_exited, + }, + mux_result.mux.into_future(), + )) + } else { + Err(WispError::InvalidPacketType) + } + } + + /// Create a new stream, multiplexed through Wisp. + pub async fn client_new_stream( + &self, + stream_type: StreamType, + host: String, + port: u16, + ) -> Result { + if self.actor_exited.load(Ordering::Acquire) { + return Err(WispError::MuxTaskEnded); + } + if stream_type == StreamType::Udp + && !self + .supported_extensions + .iter() + .any(|x| x.get_id() == UdpProtocolExtension::ID) + { + return Err(WispError::ExtensionsNotSupported(vec![ + UdpProtocolExtension::ID, + ])); + } + let (tx, rx) = oneshot::channel(); + self.actor_tx + .send_async(WsEvent::CreateStream(stream_type, host, port, tx)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? + } + + /// Send a ping to the server. + pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> { + if self.actor_exited.load(Ordering::Acquire) { + return Err(WispError::MuxTaskEnded); + } + let (tx, rx) = oneshot::channel(); + self.actor_tx + .send_async(WsEvent::SendPing(payload, tx)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? + } + + async fn close_internal(&self, reason: Option) -> Result<(), WispError> { + if self.actor_exited.load(Ordering::Acquire) { + return Err(WispError::MuxTaskEnded); + } + self.actor_tx + .send_async(WsEvent::EndFut(reason)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend) + } + + /// Close all streams. + /// + /// Also terminates the multiplexor future. + pub async fn close(&self) -> Result<(), WispError> { + self.close_internal(None).await + } + + /// Close all streams and send a close reason on stream ID 0. + /// + /// Also terminates the multiplexor future. + pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> { + self.close_internal(Some(reason)).await + } + + /// Get a protocol extension stream for sending packets with stream id 0. + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + MuxProtocolExtensionStream { + stream_id: 0, + tx: self.tx.clone(), + is_closed: self.actor_exited.clone(), + } + } +} + +impl Drop for ClientMux { + fn drop(&mut self) { + let _ = self.actor_tx.send(WsEvent::EndFut(None)); + } +} + +/// Result of `ClientMux::new`. +pub struct ClientMuxResult(ClientMux, F) +where + F: Future> + Send; + +impl ClientMuxResult +where + F: Future> + Send, +{ + /// Require no protocol extensions. + pub fn with_no_required_extensions(self) -> (ClientMux, F) { + (self.0, self.1) + } + + /// Require protocol extensions by their ID. + pub async fn with_required_extensions( + self, + extensions: &[u8], + ) -> Result<(ClientMux, F), WispError> { + let mut unsupported_extensions = Vec::new(); + for extension in extensions { + if !self + .0 + .supported_extensions + .iter() + .any(|x| x.get_id() == *extension) + { + unsupported_extensions.push(*extension); + } + } + if unsupported_extensions.is_empty() { + Ok((self.0, self.1)) + } else { + self.0 + .close_with_reason(CloseReason::ExtensionsIncompatible) + .await?; + self.1.await?; + Err(WispError::ExtensionsNotSupported(unsupported_extensions)) + } + } + + /// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])` + pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> { + self.with_required_extensions(&[UdpProtocolExtension::ID]) + .await + } +} diff --git a/wisp/src/mux/mod.rs b/wisp/src/mux/mod.rs new file mode 100644 index 0000000..8e9c765 --- /dev/null +++ b/wisp/src/mux/mod.rs @@ -0,0 +1,109 @@ +mod client; +mod server; +use std::{future::Future, pin::Pin, time::Duration}; + +pub use client::ClientMux; +use futures::{select, FutureExt}; +use futures_timer::Delay; +pub use server::{ServerMux, ServerMuxResult}; + +use crate::{ + extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder}, + ws::{Frame, LockedWebSocketWrite, WebSocketRead}, + Packet, PacketType, Role, WispError, +}; + +async fn maybe_wisp_v2( + read: &mut R, + write: &LockedWebSocketWrite, + role: Role, + builders: &mut [AnyProtocolExtensionBuilder], +) -> Result<(Vec, Option>, bool), WispError> +where + R: WebSocketRead + Send, +{ + let mut supported_extensions = Vec::new(); + let mut extra_packet: Option> = None; + let mut downgraded = true; + + let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect(); + if let Some(frame) = select! { + x = read.wisp_read_frame(write).fuse() => Some(x?), + _ = Delay::new(Duration::from_secs(5)).fuse() => None + } { + let packet = Packet::maybe_parse_info(frame, role, builders)?; + if let PacketType::Info(info) = packet.packet_type { + supported_extensions = info + .extensions + .into_iter() + .filter(|x| extension_ids.contains(&x.get_id())) + .collect(); + downgraded = false; + } else { + extra_packet.replace(Frame::from(packet).clone()); + } + } + + for extension in supported_extensions.iter_mut() { + extension.handle_handshake(read, write).await?; + } + Ok((supported_extensions, extra_packet, downgraded)) +} + +async fn send_info_packet( + write: &LockedWebSocketWrite, + builders: &mut [AnyProtocolExtensionBuilder], +) -> Result<(), WispError> { + write + .write_frame( + Packet::new_info( + builders + .iter_mut() + .map(|x| x.build_to_extension(Role::Server)) + .collect::, _>>()?, + ) + .into(), + ) + .await +} + +/// Wisp V2 handshake and protocol extension settings wrapper struct. +pub struct WispV2Extensions { + builders: Vec, + closure: Box< + dyn Fn( + &mut [AnyProtocolExtensionBuilder], + ) -> Pin> + Sync + Send>> + + Send, + >, +} + +impl WispV2Extensions { + /// Create a Wisp V2 settings struct with no middleware. + pub fn new(builders: Vec) -> Self { + Self { + builders, + closure: Box::new(|_| Box::pin(async { Ok(()) })), + } + } + + /// Create a Wisp V2 settings struct with some middleware. + pub fn new_with_middleware(builders: Vec, closure: C) -> Self + where + C: Fn( + &mut [AnyProtocolExtensionBuilder], + ) -> Pin> + Sync + Send>> + + Send + + 'static, + { + Self { + builders, + closure: Box::new(closure), + } + } + + /// Add a Wisp V2 extension builder to the settings struct. + pub fn add_extension(&mut self, extension: AnyProtocolExtensionBuilder) { + self.builders.push(extension); + } +} diff --git a/wisp/src/mux/server.rs b/wisp/src/mux/server.rs new file mode 100644 index 0000000..b1ee07f --- /dev/null +++ b/wisp/src/mux/server.rs @@ -0,0 +1,225 @@ +use std::{ + future::Future, + ops::DerefMut, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; + +use flume as mpsc; +use futures::channel::oneshot; + +use crate::{ + extensions::{udp::UdpProtocolExtension, AnyProtocolExtension}, + inner::{MuxInner, WsEvent}, + ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, + CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, Role, WispError, +}; + +use super::{maybe_wisp_v2, send_info_packet, WispV2Extensions}; + +/// Server-side multiplexor. +pub struct ServerMux { + /// Whether the connection was downgraded to Wisp v1. + /// + /// If this variable is true you must assume no extensions are supported. + pub downgraded: bool, + /// Extensions that are supported by both sides. + pub supported_extensions: Vec, + actor_tx: mpsc::Sender, + muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>, + tx: LockedWebSocketWrite, + actor_exited: Arc, +} + +impl ServerMux { + /// Create a new server-side multiplexor. + /// + /// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. + /// **It is not guaranteed that all extensions you specify are available.** You must manually check + /// if the extensions you need are available after the multiplexor has been created. + pub async fn create( + mut rx: R, + tx: W, + buffer_size: u32, + wisp_v2: Option, + ) -> Result> + Send>, WispError> + where + R: WebSocketRead + Send, + W: WebSocketWrite + Send + 'static, + { + let tx = LockedWebSocketWrite::new(Box::new(tx)); + let ret_tx = tx.clone(); + let ret = async { + tx.write_frame(Packet::new_continue(0, buffer_size).into()) + .await?; + + let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions { + mut builders, + closure, + }) = wisp_v2 + { + send_info_packet(&tx, builders.deref_mut()).await?; + (closure)(builders.deref_mut()).await?; + maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await? + } else { + (Vec::new(), None, true) + }; + + let (mux_result, muxstream_recv) = MuxInner::new_server( + AppendingWebSocketRead(extra_packet, rx), + tx.clone(), + supported_extensions.clone(), + buffer_size, + ); + + Ok(ServerMuxResult( + Self { + muxstream_recv, + actor_tx: mux_result.actor_tx, + downgraded, + supported_extensions, + tx, + actor_exited: mux_result.actor_exited, + }, + mux_result.mux.into_future(), + )) + } + .await; + + match ret { + Ok(x) => Ok(x), + Err(x) => match x { + WispError::PasswordExtensionCredsInvalid => { + ret_tx + .write_frame( + Packet::new_close(0, CloseReason::ExtensionsPasswordAuthFailed).into(), + ) + .await?; + ret_tx.close().await?; + Err(x) + } + WispError::CertAuthExtensionSigInvalid => { + ret_tx + .write_frame( + Packet::new_close(0, CloseReason::ExtensionsCertAuthFailed).into(), + ) + .await?; + ret_tx.close().await?; + Err(x) + } + x => Err(x), + }, + } + } + + /// Wait for a stream to be created. + pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> { + if self.actor_exited.load(Ordering::Acquire) { + return None; + } + self.muxstream_recv.recv_async().await.ok() + } + + /// Send a ping to the client. + pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> { + if self.actor_exited.load(Ordering::Acquire) { + return Err(WispError::MuxTaskEnded); + } + let (tx, rx) = oneshot::channel(); + self.actor_tx + .send_async(WsEvent::SendPing(payload, tx)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? + } + + async fn close_internal(&self, reason: Option) -> Result<(), WispError> { + if self.actor_exited.load(Ordering::Acquire) { + return Err(WispError::MuxTaskEnded); + } + self.actor_tx + .send_async(WsEvent::EndFut(reason)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend) + } + + /// Close all streams. + /// + /// Also terminates the multiplexor future. + pub async fn close(&self) -> Result<(), WispError> { + self.close_internal(None).await + } + + /// Close all streams and send a close reason on stream ID 0. + /// + /// Also terminates the multiplexor future. + pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> { + self.close_internal(Some(reason)).await + } + + /// Get a protocol extension stream for sending packets with stream id 0. + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + MuxProtocolExtensionStream { + stream_id: 0, + tx: self.tx.clone(), + is_closed: self.actor_exited.clone(), + } + } +} + +impl Drop for ServerMux { + fn drop(&mut self) { + let _ = self.actor_tx.send(WsEvent::EndFut(None)); + } +} + +/// Result of `ServerMux::new`. +pub struct ServerMuxResult(ServerMux, F) +where + F: Future> + Send; + +impl ServerMuxResult +where + F: Future> + Send, +{ + /// Require no protocol extensions. + pub fn with_no_required_extensions(self) -> (ServerMux, F) { + (self.0, self.1) + } + + /// Require protocol extensions by their ID. Will close the multiplexor connection if + /// extensions are not supported. + pub async fn with_required_extensions( + self, + extensions: &[u8], + ) -> Result<(ServerMux, F), WispError> { + let mut unsupported_extensions = Vec::new(); + for extension in extensions { + if !self + .0 + .supported_extensions + .iter() + .any(|x| x.get_id() == *extension) + { + unsupported_extensions.push(*extension); + } + } + if unsupported_extensions.is_empty() { + Ok((self.0, self.1)) + } else { + self.0 + .close_with_reason(CloseReason::ExtensionsIncompatible) + .await?; + self.1.await?; + Err(WispError::ExtensionsNotSupported(unsupported_extensions)) + } + } + + /// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])` + pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> { + self.with_required_extensions(&[UdpProtocolExtension::ID]) + .await + } +}