From 694d87f7316f3a9c4465903d6acc4d44d7579078 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 14 Sep 2024 13:09:14 -0700 Subject: [PATCH] add downcasting --- server/src/handle/wisp.rs | 2 +- wisp/src/extensions/mod.rs | 52 +++++++++++++++++++++++++++++++------- wisp/src/lib.rs | 24 ++++++++---------- 3 files changed, 54 insertions(+), 24 deletions(-) diff --git a/server/src/handle/wisp.rs b/server/src/handle/wisp.rs index 1e59863..01e816d 100644 --- a/server/src/handle/wisp.rs +++ b/server/src/handle/wisp.rs @@ -237,7 +237,7 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> { debug!( "new wisp client id {:?} connected with extensions {:?}", - id, mux.supported_extension_ids + id, mux.supported_extensions.iter().map(|x| x.get_id()).collect::>() ); let mut set: JoinSet<()> = JoinSet::new(); diff --git a/wisp/src/extensions/mod.rs b/wisp/src/extensions/mod.rs index 1970387..6c1347a 100644 --- a/wisp/src/extensions/mod.rs +++ b/wisp/src/extensions/mod.rs @@ -1,11 +1,14 @@ //! Wisp protocol extensions. -pub mod password; -pub mod udp; -pub mod motd; #[cfg(feature = "certificate")] pub mod cert; +pub mod motd; +pub mod password; +pub mod udp; -use std::ops::{Deref, DerefMut}; +use std::{ + any::TypeId, + ops::{Deref, DerefMut}, +}; use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; @@ -17,13 +20,18 @@ use crate::{ /// Type-erased protocol extension that implements Clone. #[derive(Debug)] -pub struct AnyProtocolExtension(Box); +pub struct AnyProtocolExtension(Box); impl AnyProtocolExtension { /// Create a new type-erased protocol extension. - pub fn new(extension: T) -> Self { + pub fn new(extension: T) -> Self { Self(Box::new(extension)) } + + /// Downcast the protocol extension. + pub fn downcast(self) -> Result, Self> { + self.0.__downcast().map_err(Self) + } } impl Deref for AnyProtocolExtension { @@ -61,7 +69,7 @@ impl From for Bytes { /// See [the /// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#protocol-extensions). #[async_trait] -pub trait ProtocolExtension: std::fmt::Debug { +pub trait ProtocolExtension: std::fmt::Debug + Sync + Send + 'static { /// Get the protocol extension ID. fn get_id(&self) -> u8; /// Get the protocol extension's supported packets. @@ -95,6 +103,29 @@ pub trait ProtocolExtension: std::fmt::Debug { /// Clone the protocol extension. fn box_clone(&self) -> Box; + + /// Do not override. + fn __internal_type_id(&self) -> TypeId { + TypeId::of::() + } +} + +impl dyn ProtocolExtension { + fn __is(&self) -> bool { + let t = TypeId::of::(); + self.__internal_type_id() == t + } + + fn __downcast(self: Box) -> Result, Box> { + if self.__is::() { + unsafe { + let raw: *mut dyn ProtocolExtension = Box::into_raw(self); + Ok(Box::from_raw(raw as *mut T)) + } + } else { + Err(self) + } + } } /// Trait to build a Wisp protocol extension from a payload. @@ -105,8 +136,11 @@ pub trait ProtocolExtensionBuilder { fn get_id(&self) -> u8; /// Build a protocol extension from the extension's metadata. - fn build_from_bytes(&mut self, bytes: Bytes, role: Role) - -> Result; + fn build_from_bytes( + &mut self, + bytes: Bytes, + role: Role, + ) -> Result; /// Build a protocol extension to send to the other side. fn build_to_extension(&mut self, role: Role) -> Result; diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index b39f7c4..4e2932c 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -219,7 +219,7 @@ pub struct ServerMux { /// If this variable is true you must assume no extensions are supported. pub downgraded: bool, /// Extensions that are supported by both sides. - pub supported_extension_ids: Vec, + pub supported_extensions: Vec, actor_tx: mpsc::Sender, muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>, tx: ws::LockedWebSocketWrite, @@ -255,12 +255,10 @@ impl ServerMux { (Vec::new(), None, true) }; - let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect(); - let (mux_result, muxstream_recv) = MuxInner::new_server( AppendingWebSocketRead(extra_packet, rx), tx.clone(), - supported_extensions, + supported_extensions.clone(), buffer_size, ); @@ -269,7 +267,7 @@ impl ServerMux { muxstream_recv, actor_tx: mux_result.actor_tx, downgraded, - supported_extension_ids, + supported_extensions, tx, actor_exited: mux_result.actor_exited, }, @@ -348,7 +346,7 @@ where ) -> Result<(ServerMux, F), WispError> { let mut unsupported_extensions = Vec::new(); for extension in extensions { - if !self.0.supported_extension_ids.contains(extension) { + if !self.0.supported_extensions.iter().any(|x| x.get_id() == *extension) { unsupported_extensions.push(*extension); } } @@ -377,7 +375,7 @@ pub struct ClientMux { /// If this variable is true you must assume no extensions are supported. pub downgraded: bool, /// Extensions that are supported by both sides. - pub supported_extension_ids: Vec, + pub supported_extensions: Vec, actor_tx: mpsc::Sender, tx: ws::LockedWebSocketWrite, actor_exited: Arc, @@ -418,12 +416,10 @@ impl ClientMux { (Vec::new(), None, true) }; - let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect(); - let mux_result = MuxInner::new_client( AppendingWebSocketRead(extra_packet, rx), tx.clone(), - supported_extensions, + supported_extensions.clone(), packet.buffer_remaining, ); @@ -431,7 +427,7 @@ impl ClientMux { Self { actor_tx: mux_result.actor_tx, downgraded, - supported_extension_ids, + supported_extensions, tx, actor_exited: mux_result.actor_exited, }, @@ -454,9 +450,9 @@ impl ClientMux { } if stream_type == StreamType::Udp && !self - .supported_extension_ids + .supported_extensions .iter() - .any(|x| *x == UdpProtocolExtension::ID) + .any(|x| x.get_id() == UdpProtocolExtension::ID) { return Err(WispError::ExtensionsNotSupported(vec![ UdpProtocolExtension::ID, @@ -532,7 +528,7 @@ where ) -> Result<(ClientMux, F), WispError> { let mut unsupported_extensions = Vec::new(); for extension in extensions { - if !self.0.supported_extension_ids.contains(extension) { + if !self.0.supported_extensions.iter().any(|x| x.get_id() == *extension) { unsupported_extensions.push(*extension); } }