diff --git a/client/src/stream_provider.rs b/client/src/stream_provider.rs index 11c1097..cc231f3 100644 --- a/client/src/stream_provider.rs +++ b/client/src/stream_provider.rs @@ -15,9 +15,9 @@ use pin_project_lite::pin_project; use wasm_bindgen_futures::spawn_local; use webpki_roots::TLS_SERVER_ROOTS; use wisp_mux::{ - extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder}, + extensions::{udp::UdpProtocolExtensionBuilder, AnyProtocolExtensionBuilder}, ws::{WebSocketRead, WebSocketWrite}, - ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, + ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, WispV2Extensions, }; use crate::{ @@ -106,10 +106,12 @@ impl StreamProvider { &self, mut locked: MutexGuard<'_, Option>, ) -> Result<(), EpoxyError> { - let extensions_vec: Vec> = - vec![Box::new(UdpProtocolExtensionBuilder)]; + let extensions_vec: Vec = + vec![AnyProtocolExtensionBuilder::new( + UdpProtocolExtensionBuilder, + )]; let extensions = if self.wisp_v2 { - Some(extensions_vec) + Some(WispV2Extensions::new(extensions_vec)) } else { None }; diff --git a/server/src/config.rs b/server/src/config.rs index 95c18cc..f55f5d0 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -6,12 +6,15 @@ use lazy_static::lazy_static; use log::LevelFilter; use regex::RegexSet; use serde::{Deserialize, Serialize}; -use wisp_mux::extensions::{ - cert::{CertAuthProtocolExtension, CertAuthProtocolExtensionBuilder}, - motd::MotdProtocolExtensionBuilder, - password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, - udp::UdpProtocolExtensionBuilder, - ProtocolExtensionBuilder, +use wisp_mux::{ + extensions::{ + cert::{CertAuthProtocolExtension, CertAuthProtocolExtensionBuilder}, + motd::MotdProtocolExtensionBuilder, + password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, + udp::UdpProtocolExtensionBuilder, + AnyProtocolExtensionBuilder, + }, + WispV2Extensions, }; use crate::{handle::wisp::utils::get_certificates_from_paths, CLI, CONFIG, RESOLVER}; @@ -195,8 +198,6 @@ pub struct Config { pub stream: StreamConfig, } -type AnyProtocolExtensionBuilder = Box; - struct ConfigCache { pub blocked_ports: Vec>, pub allowed_ports: Vec>, @@ -293,41 +294,49 @@ impl Default for WispConfig { } impl WispConfig { - pub async fn to_opts( - &self, - ) -> anyhow::Result<(Option>, Vec, u32)> { + pub async fn to_opts(&self) -> anyhow::Result<(Option, Vec, u32)> { if self.wisp_v2 { let mut extensions: Vec = Vec::new(); let mut required_extensions: Vec = Vec::new(); if self.extensions.contains(&ProtocolExtension::Udp) { - extensions.push(Box::new(UdpProtocolExtensionBuilder)); + extensions.push(AnyProtocolExtensionBuilder::new( + UdpProtocolExtensionBuilder, + )); } if self.extensions.contains(&ProtocolExtension::Motd) { - extensions.push(Box::new(MotdProtocolExtensionBuilder::Server( - self.motd_extension.clone(), - ))); + extensions.push(AnyProtocolExtensionBuilder::new( + MotdProtocolExtensionBuilder::Server(self.motd_extension.clone()), + )); } match self.auth_extension { Some(ProtocolExtensionAuth::Password) => { - extensions.push(Box::new(PasswordProtocolExtensionBuilder::new_server( - self.password_extension_users.clone(), - ))); + extensions.push(AnyProtocolExtensionBuilder::new( + PasswordProtocolExtensionBuilder::new_server( + self.password_extension_users.clone(), + ), + )); required_extensions.push(PasswordProtocolExtension::ID); } Some(ProtocolExtensionAuth::Certificate) => { - extensions.push(Box::new(CertAuthProtocolExtensionBuilder::new_server( - get_certificates_from_paths(self.certificate_extension_keys.clone()) - .await?, - ))); + extensions.push(AnyProtocolExtensionBuilder::new( + CertAuthProtocolExtensionBuilder::new_server( + get_certificates_from_paths(self.certificate_extension_keys.clone()) + .await?, + ), + )); required_extensions.push(CertAuthProtocolExtension::ID); } None => {} } - Ok((Some(extensions), required_extensions, self.buffer_size)) + Ok(( + Some(WispV2Extensions::new(extensions)), + required_extensions, + self.buffer_size, + )) } else { Ok((None, Vec::new(), self.buffer_size)) } diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index 48b0d49..a3c705a 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -36,9 +36,9 @@ use wisp_mux::{ motd::{MotdProtocolExtension, MotdProtocolExtensionBuilder}, password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, - ProtocolExtensionBuilder, + AnyProtocolExtensionBuilder, }, - ClientMux, StreamType, WispError, + ClientMux, StreamType, WispError, WispV2Extensions, }; #[derive(Debug)] @@ -169,23 +169,27 @@ async fn main() -> Result<(), Box> { (Cursor::new(parts.read_buf).chain(r), w) }); - let mut extensions: Vec> = Vec::new(); + let mut extensions: Vec = Vec::new(); let mut extension_ids: Vec = Vec::new(); if opts.udp { - extensions.push(Box::new(UdpProtocolExtensionBuilder)); + extensions.push(AnyProtocolExtensionBuilder::new( + UdpProtocolExtensionBuilder, + )); extension_ids.push(UdpProtocolExtension::ID); } if opts.motd { - extensions.push(Box::new(MotdProtocolExtensionBuilder::Client)); + extensions.push(AnyProtocolExtensionBuilder::new( + MotdProtocolExtensionBuilder::Client, + )); } if let Some(auth) = auth { - extensions.push(Box::new(auth)); + extensions.push(AnyProtocolExtensionBuilder::new(auth)); extension_ids.push(PasswordProtocolExtension::ID); } if let Some(certauth) = opts.certauth { let key = get_cert(certauth).await?; let extension = CertAuthProtocolExtensionBuilder::new_client(key); - extensions.push(Box::new(extension)); + extensions.push(AnyProtocolExtensionBuilder::new(extension)); extension_ids.push(CertAuthProtocolExtension::ID); } @@ -194,7 +198,7 @@ async fn main() -> Result<(), Box> { .await? .with_no_required_extensions() } else { - ClientMux::create(rx, tx, Some(extensions)) + ClientMux::create(rx, tx, Some(WispV2Extensions::new(extensions))) .await? .with_required_extensions(extension_ids.as_slice()) .await? diff --git a/wisp/src/extensions/mod.rs b/wisp/src/extensions/mod.rs index 8f966df..3c81875 100644 --- a/wisp/src/extensions/mod.rs +++ b/wisp/src/extensions/mod.rs @@ -155,7 +155,7 @@ impl dyn ProtocolExtension { } /// Trait to build a Wisp protocol extension from a payload. -pub trait ProtocolExtensionBuilder { +pub trait ProtocolExtensionBuilder: Sync + Send + 'static { /// Get the protocol extension ID. /// /// Used to decide whether this builder should be used. @@ -170,4 +170,81 @@ pub trait ProtocolExtensionBuilder { /// Build a protocol extension to send to the other side. fn build_to_extension(&mut self, role: Role) -> Result; + + /// Do not override. + fn __internal_type_id(&self) -> TypeId { + TypeId::of::() + } +} + +impl dyn ProtocolExtensionBuilder { + fn __is(&self) -> bool { + let t = TypeId::of::(); + self.__internal_type_id() == t + } + + fn __downcast(self: Box) -> Result, Box> { + if self.__is::() { + unsafe { + let raw: *mut dyn ProtocolExtensionBuilder = Box::into_raw(self); + Ok(Box::from_raw(raw as *mut T)) + } + } else { + Err(self) + } + } + + fn __downcast_ref(&self) -> Option<&T> { + if self.__is::() { + unsafe { Some(&*(self as *const dyn ProtocolExtensionBuilder as *const T)) } + } else { + None + } + } + + fn __downcast_mut(&mut self) -> Option<&mut T> { + if self.__is::() { + unsafe { Some(&mut *(self as *mut dyn ProtocolExtensionBuilder as *mut T)) } + } else { + None + } + } +} + +/// Type-erased protocol extension builder. +pub struct AnyProtocolExtensionBuilder(Box); + +impl AnyProtocolExtensionBuilder { + /// Create a new type-erased protocol extension builder. + pub fn new(extension: T) -> Self { + Self(Box::new(extension)) + } + + /// Downcast the protocol extension builder. + pub fn downcast(self) -> Result, Self> { + self.0.__downcast().map_err(Self) + } + + /// Downcast the protocol extension builder. + pub fn downcast_ref(&self) -> Option<&T> { + self.0.__downcast_ref() + } + + /// Downcast the protocol extension builder. + pub fn downcast_mut(&mut self) -> Option<&mut T> { + self.0.__downcast_mut() + } +} + +impl Deref for AnyProtocolExtensionBuilder { + type Target = dyn ProtocolExtensionBuilder; + fn deref(&self) -> &Self::Target { + self.0.deref() + } +} + +impl DerefMut for AnyProtocolExtensionBuilder { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.deref_mut() + } } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index cf1e2f1..a9d846c 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -19,13 +19,17 @@ pub mod ws; pub use crate::{packet::*, stream::*}; -use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder}; +use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder}; use flume as mpsc; use futures::{channel::oneshot, Future}; use inner::{MuxInner, WsEvent}; -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, +use std::{ + ops::DerefMut, + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, }; use ws::{AppendingWebSocketRead, LockedWebSocketWrite}; @@ -173,7 +177,7 @@ async fn maybe_wisp_v2( read: &mut R, write: &LockedWebSocketWrite, role: Role, - builders: &mut [Box], + builders: &mut [AnyProtocolExtensionBuilder], ) -> Result<(Vec, Option>, bool), WispError> where R: ws::WebSocketRead + Send, @@ -205,7 +209,7 @@ where async fn send_info_packet( write: &LockedWebSocketWrite, - builders: &mut [Box], + builders: &mut [AnyProtocolExtensionBuilder], ) -> Result<(), WispError> { write .write_frame( @@ -220,6 +224,42 @@ async fn send_info_packet( .await } +/// Wisp V2 handshake and protocol extension settings wrapper struct. +pub struct WispV2Extensions { + builders: Vec, + closure: Box< + dyn Fn( + &mut [AnyProtocolExtensionBuilder], + ) -> Pin> + Sync + Send>> + + Send, + >, +} + +impl WispV2Extensions { + /// Create a Wisp V2 settings struct with no middleware. + pub fn new(builders: Vec) -> Self { + Self { + builders, + closure: Box::new(|_| Box::pin(async { Ok(()) })), + } + } + + /// Create a Wisp V2 settings struct with some middleware. + pub fn new_with_middleware(builders: Vec, closure: C) -> Self + where + C: Fn( + &mut [AnyProtocolExtensionBuilder], + ) -> Pin> + Sync + Send>> + + Send + + 'static, + { + Self { + builders, + closure: Box::new(closure), + } + } +} + /// Server-side multiplexor. pub struct ServerMux { /// Whether the connection was downgraded to Wisp v1. @@ -237,14 +277,14 @@ pub struct ServerMux { impl ServerMux { /// Create a new server-side multiplexor. /// - /// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. + /// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. /// **It is not guaranteed that all extensions you specify are available.** You must manually check /// if the extensions you need are available after the multiplexor has been created. pub async fn create( mut rx: R, tx: W, buffer_size: u32, - extension_builders: Option>>, + wisp_v2: Option, ) -> Result> + Send>, WispError> where R: ws::WebSocketRead + Send, @@ -256,13 +296,17 @@ impl ServerMux { tx.write_frame(Packet::new_continue(0, buffer_size).into()) .await?; - let (supported_extensions, extra_packet, downgraded) = - if let Some(mut builders) = extension_builders { - send_info_packet(&tx, &mut builders).await?; - maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await? - } else { - (Vec::new(), None, true) - }; + let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions { + mut builders, + closure, + }) = wisp_v2 + { + send_info_packet(&tx, builders.deref_mut()).await?; + (closure)(builders.deref_mut()).await?; + maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await? + } else { + (Vec::new(), None, true) + }; let (mux_result, muxstream_recv) = MuxInner::new_server( AppendingWebSocketRead(extra_packet, rx), @@ -424,13 +468,13 @@ pub struct ClientMux { impl ClientMux { /// Create a new client side multiplexor. /// - /// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. + /// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. /// **It is not guaranteed that all extensions you specify are available.** You must manually check /// if the extensions you need are available after the multiplexor has been created. pub async fn create( mut rx: R, tx: W, - extension_builders: Option>>, + wisp_v2: Option, ) -> Result> + Send>, WispError> where R: ws::WebSocketRead + Send, @@ -444,17 +488,21 @@ impl ClientMux { } if let PacketType::Continue(packet) = first_packet.packet_type { - let (supported_extensions, extra_packet, downgraded) = - if let Some(mut builders) = extension_builders { - let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?; - // if not downgraded - if !res.2 { - send_info_packet(&tx, &mut builders).await?; - } - res - } else { - (Vec::new(), None, true) - }; + let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions { + mut builders, + closure, + }) = wisp_v2 + { + let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?; + // if not downgraded + if !res.2 { + (closure)(&mut builders).await?; + send_info_packet(&tx, &mut builders).await?; + } + res + } else { + (Vec::new(), None, true) + }; let mux_result = MuxInner::new_client( AppendingWebSocketRead(extra_packet, rx), diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 86dfc56..3d143f4 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -1,5 +1,5 @@ use crate::{ - extensions::{AnyProtocolExtension, ProtocolExtensionBuilder}, + extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder}, ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead}, Role, WispError, WISP_VERSION, }; @@ -431,10 +431,57 @@ impl<'a> Packet<'a> { }) } + fn parse_info( + mut bytes: Payload<'a>, + role: Role, + extension_builders: &mut [AnyProtocolExtensionBuilder], + ) -> Result { + // packet type is already read by code that calls this + if bytes.remaining() < 4 + 2 { + return Err(WispError::PacketTooSmall); + } + if bytes.get_u32_le() != 0 { + return Err(WispError::InvalidStreamId); + } + + let version = WispVersion { + major: bytes.get_u8(), + minor: bytes.get_u8(), + }; + + if version.major != WISP_VERSION.major { + return Err(WispError::IncompatibleProtocolVersion); + } + + let mut extensions = Vec::new(); + + while bytes.remaining() > 4 { + // We have some extensions + let id = bytes.get_u8(); + let length = usize::try_from(bytes.get_u32_le())?; + if bytes.remaining() < length { + return Err(WispError::PacketTooSmall); + } + if let Some(builder) = extension_builders.iter_mut().find(|x| x.get_id() == id) { + extensions.push(builder.build_from_bytes(bytes.copy_to_bytes(length), role)?) + } else { + bytes.advance(length) + } + } + + Ok(Self { + stream_id: 0, + packet_type: PacketType::Info(InfoPacket { + version, + extensions, + }), + }) + } + pub(crate) fn maybe_parse_info( frame: Frame<'a>, role: Role, - extension_builders: &mut [Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], + extension_builders: &mut [AnyProtocolExtensionBuilder], ) -> Result { if !frame.finished { return Err(WispError::WsFrameNotFinished); @@ -504,53 +551,6 @@ impl<'a> Packet<'a> { } } } - - fn parse_info( - mut bytes: Payload<'a>, - role: Role, - extension_builders: &mut [Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], - ) -> Result { - // packet type is already read by code that calls this - if bytes.remaining() < 4 + 2 { - return Err(WispError::PacketTooSmall); - } - if bytes.get_u32_le() != 0 { - return Err(WispError::InvalidStreamId); - } - - let version = WispVersion { - major: bytes.get_u8(), - minor: bytes.get_u8(), - }; - - if version.major != WISP_VERSION.major { - return Err(WispError::IncompatibleProtocolVersion); - } - - let mut extensions = Vec::new(); - - while bytes.remaining() > 4 { - // We have some extensions - let id = bytes.get_u8(); - let length = usize::try_from(bytes.get_u32_le())?; - if bytes.remaining() < length { - return Err(WispError::PacketTooSmall); - } - if let Some(builder) = extension_builders.iter_mut().find(|x| x.get_id() == id) { - extensions.push(builder.build_from_bytes(bytes.copy_to_bytes(length), role)?) - } else { - bytes.advance(length) - } - } - - Ok(Self { - stream_id: 0, - packet_type: PacketType::Info(InfoPacket { - version, - extensions, - }), - }) - } } impl Encode for Packet<'_> {