add downcasting

This commit is contained in:
Toshit Chawda 2024-09-14 13:09:14 -07:00
parent c5e93675de
commit 694d87f731
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 54 additions and 24 deletions

View file

@ -237,7 +237,7 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> {
debug!( debug!(
"new wisp client id {:?} connected with extensions {:?}", "new wisp client id {:?} connected with extensions {:?}",
id, mux.supported_extension_ids id, mux.supported_extensions.iter().map(|x| x.get_id()).collect::<Vec<_>>()
); );
let mut set: JoinSet<()> = JoinSet::new(); let mut set: JoinSet<()> = JoinSet::new();

View file

@ -1,11 +1,14 @@
//! Wisp protocol extensions. //! Wisp protocol extensions.
pub mod password;
pub mod udp;
pub mod motd;
#[cfg(feature = "certificate")] #[cfg(feature = "certificate")]
pub mod cert; pub mod cert;
pub mod motd;
pub mod password;
pub mod udp;
use std::ops::{Deref, DerefMut}; use std::{
any::TypeId,
ops::{Deref, DerefMut},
};
use async_trait::async_trait; use async_trait::async_trait;
use bytes::{BufMut, Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
@ -17,13 +20,18 @@ use crate::{
/// Type-erased protocol extension that implements Clone. /// Type-erased protocol extension that implements Clone.
#[derive(Debug)] #[derive(Debug)]
pub struct AnyProtocolExtension(Box<dyn ProtocolExtension + Sync + Send>); pub struct AnyProtocolExtension(Box<dyn ProtocolExtension>);
impl AnyProtocolExtension { impl AnyProtocolExtension {
/// Create a new type-erased protocol extension. /// Create a new type-erased protocol extension.
pub fn new<T: ProtocolExtension + Sync + Send + 'static>(extension: T) -> Self { pub fn new<T: ProtocolExtension>(extension: T) -> Self {
Self(Box::new(extension)) Self(Box::new(extension))
} }
/// Downcast the protocol extension.
pub fn downcast<T: ProtocolExtension>(self) -> Result<Box<T>, Self> {
self.0.__downcast().map_err(Self)
}
} }
impl Deref for AnyProtocolExtension { impl Deref for AnyProtocolExtension {
@ -61,7 +69,7 @@ impl From<AnyProtocolExtension> for Bytes {
/// See [the /// See [the
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#protocol-extensions). /// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#protocol-extensions).
#[async_trait] #[async_trait]
pub trait ProtocolExtension: std::fmt::Debug { pub trait ProtocolExtension: std::fmt::Debug + Sync + Send + 'static {
/// Get the protocol extension ID. /// Get the protocol extension ID.
fn get_id(&self) -> u8; fn get_id(&self) -> u8;
/// Get the protocol extension's supported packets. /// Get the protocol extension's supported packets.
@ -95,6 +103,29 @@ pub trait ProtocolExtension: std::fmt::Debug {
/// Clone the protocol extension. /// Clone the protocol extension.
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send>; fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send>;
/// Do not override.
fn __internal_type_id(&self) -> TypeId {
TypeId::of::<Self>()
}
}
impl dyn ProtocolExtension {
fn __is<T: ProtocolExtension>(&self) -> bool {
let t = TypeId::of::<T>();
self.__internal_type_id() == t
}
fn __downcast<T: ProtocolExtension>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
if self.__is::<T>() {
unsafe {
let raw: *mut dyn ProtocolExtension = Box::into_raw(self);
Ok(Box::from_raw(raw as *mut T))
}
} else {
Err(self)
}
}
} }
/// Trait to build a Wisp protocol extension from a payload. /// Trait to build a Wisp protocol extension from a payload.
@ -105,8 +136,11 @@ pub trait ProtocolExtensionBuilder {
fn get_id(&self) -> u8; fn get_id(&self) -> u8;
/// Build a protocol extension from the extension's metadata. /// Build a protocol extension from the extension's metadata.
fn build_from_bytes(&mut self, bytes: Bytes, role: Role) fn build_from_bytes(
-> Result<AnyProtocolExtension, WispError>; &mut self,
bytes: Bytes,
role: Role,
) -> Result<AnyProtocolExtension, WispError>;
/// Build a protocol extension to send to the other side. /// Build a protocol extension to send to the other side.
fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError>; fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError>;

View file

@ -219,7 +219,7 @@ pub struct ServerMux {
/// If this variable is true you must assume no extensions are supported. /// If this variable is true you must assume no extensions are supported.
pub downgraded: bool, pub downgraded: bool,
/// Extensions that are supported by both sides. /// Extensions that are supported by both sides.
pub supported_extension_ids: Vec<u8>, pub supported_extensions: Vec<AnyProtocolExtension>,
actor_tx: mpsc::Sender<WsEvent>, actor_tx: mpsc::Sender<WsEvent>,
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>, muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
tx: ws::LockedWebSocketWrite, tx: ws::LockedWebSocketWrite,
@ -255,12 +255,10 @@ impl ServerMux {
(Vec::new(), None, true) (Vec::new(), None, true)
}; };
let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect();
let (mux_result, muxstream_recv) = MuxInner::new_server( let (mux_result, muxstream_recv) = MuxInner::new_server(
AppendingWebSocketRead(extra_packet, rx), AppendingWebSocketRead(extra_packet, rx),
tx.clone(), tx.clone(),
supported_extensions, supported_extensions.clone(),
buffer_size, buffer_size,
); );
@ -269,7 +267,7 @@ impl ServerMux {
muxstream_recv, muxstream_recv,
actor_tx: mux_result.actor_tx, actor_tx: mux_result.actor_tx,
downgraded, downgraded,
supported_extension_ids, supported_extensions,
tx, tx,
actor_exited: mux_result.actor_exited, actor_exited: mux_result.actor_exited,
}, },
@ -348,7 +346,7 @@ where
) -> Result<(ServerMux, F), WispError> { ) -> Result<(ServerMux, F), WispError> {
let mut unsupported_extensions = Vec::new(); let mut unsupported_extensions = Vec::new();
for extension in extensions { for extension in extensions {
if !self.0.supported_extension_ids.contains(extension) { if !self.0.supported_extensions.iter().any(|x| x.get_id() == *extension) {
unsupported_extensions.push(*extension); unsupported_extensions.push(*extension);
} }
} }
@ -377,7 +375,7 @@ pub struct ClientMux {
/// If this variable is true you must assume no extensions are supported. /// If this variable is true you must assume no extensions are supported.
pub downgraded: bool, pub downgraded: bool,
/// Extensions that are supported by both sides. /// Extensions that are supported by both sides.
pub supported_extension_ids: Vec<u8>, pub supported_extensions: Vec<AnyProtocolExtension>,
actor_tx: mpsc::Sender<WsEvent>, actor_tx: mpsc::Sender<WsEvent>,
tx: ws::LockedWebSocketWrite, tx: ws::LockedWebSocketWrite,
actor_exited: Arc<AtomicBool>, actor_exited: Arc<AtomicBool>,
@ -418,12 +416,10 @@ impl ClientMux {
(Vec::new(), None, true) (Vec::new(), None, true)
}; };
let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect();
let mux_result = MuxInner::new_client( let mux_result = MuxInner::new_client(
AppendingWebSocketRead(extra_packet, rx), AppendingWebSocketRead(extra_packet, rx),
tx.clone(), tx.clone(),
supported_extensions, supported_extensions.clone(),
packet.buffer_remaining, packet.buffer_remaining,
); );
@ -431,7 +427,7 @@ impl ClientMux {
Self { Self {
actor_tx: mux_result.actor_tx, actor_tx: mux_result.actor_tx,
downgraded, downgraded,
supported_extension_ids, supported_extensions,
tx, tx,
actor_exited: mux_result.actor_exited, actor_exited: mux_result.actor_exited,
}, },
@ -454,9 +450,9 @@ impl ClientMux {
} }
if stream_type == StreamType::Udp if stream_type == StreamType::Udp
&& !self && !self
.supported_extension_ids .supported_extensions
.iter() .iter()
.any(|x| *x == UdpProtocolExtension::ID) .any(|x| x.get_id() == UdpProtocolExtension::ID)
{ {
return Err(WispError::ExtensionsNotSupported(vec![ return Err(WispError::ExtensionsNotSupported(vec![
UdpProtocolExtension::ID, UdpProtocolExtension::ID,
@ -532,7 +528,7 @@ where
) -> Result<(ClientMux, F), WispError> { ) -> Result<(ClientMux, F), WispError> {
let mut unsupported_extensions = Vec::new(); let mut unsupported_extensions = Vec::new();
for extension in extensions { for extension in extensions {
if !self.0.supported_extension_ids.contains(extension) { if !self.0.supported_extensions.iter().any(|x| x.get_id() == *extension) {
unsupported_extensions.push(*extension); unsupported_extensions.push(*extension);
} }
} }