mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 22:10:01 -04:00
add downcasting
This commit is contained in:
parent
c5e93675de
commit
694d87f731
3 changed files with 54 additions and 24 deletions
|
@ -237,7 +237,7 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> {
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
"new wisp client id {:?} connected with extensions {:?}",
|
"new wisp client id {:?} connected with extensions {:?}",
|
||||||
id, mux.supported_extension_ids
|
id, mux.supported_extensions.iter().map(|x| x.get_id()).collect::<Vec<_>>()
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut set: JoinSet<()> = JoinSet::new();
|
let mut set: JoinSet<()> = JoinSet::new();
|
||||||
|
|
|
@ -1,11 +1,14 @@
|
||||||
//! Wisp protocol extensions.
|
//! Wisp protocol extensions.
|
||||||
pub mod password;
|
|
||||||
pub mod udp;
|
|
||||||
pub mod motd;
|
|
||||||
#[cfg(feature = "certificate")]
|
#[cfg(feature = "certificate")]
|
||||||
pub mod cert;
|
pub mod cert;
|
||||||
|
pub mod motd;
|
||||||
|
pub mod password;
|
||||||
|
pub mod udp;
|
||||||
|
|
||||||
use std::ops::{Deref, DerefMut};
|
use std::{
|
||||||
|
any::TypeId,
|
||||||
|
ops::{Deref, DerefMut},
|
||||||
|
};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use bytes::{BufMut, Bytes, BytesMut};
|
use bytes::{BufMut, Bytes, BytesMut};
|
||||||
|
@ -17,13 +20,18 @@ use crate::{
|
||||||
|
|
||||||
/// Type-erased protocol extension that implements Clone.
|
/// Type-erased protocol extension that implements Clone.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct AnyProtocolExtension(Box<dyn ProtocolExtension + Sync + Send>);
|
pub struct AnyProtocolExtension(Box<dyn ProtocolExtension>);
|
||||||
|
|
||||||
impl AnyProtocolExtension {
|
impl AnyProtocolExtension {
|
||||||
/// Create a new type-erased protocol extension.
|
/// Create a new type-erased protocol extension.
|
||||||
pub fn new<T: ProtocolExtension + Sync + Send + 'static>(extension: T) -> Self {
|
pub fn new<T: ProtocolExtension>(extension: T) -> Self {
|
||||||
Self(Box::new(extension))
|
Self(Box::new(extension))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Downcast the protocol extension.
|
||||||
|
pub fn downcast<T: ProtocolExtension>(self) -> Result<Box<T>, Self> {
|
||||||
|
self.0.__downcast().map_err(Self)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Deref for AnyProtocolExtension {
|
impl Deref for AnyProtocolExtension {
|
||||||
|
@ -61,7 +69,7 @@ impl From<AnyProtocolExtension> for Bytes {
|
||||||
/// See [the
|
/// See [the
|
||||||
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#protocol-extensions).
|
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#protocol-extensions).
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait ProtocolExtension: std::fmt::Debug {
|
pub trait ProtocolExtension: std::fmt::Debug + Sync + Send + 'static {
|
||||||
/// Get the protocol extension ID.
|
/// Get the protocol extension ID.
|
||||||
fn get_id(&self) -> u8;
|
fn get_id(&self) -> u8;
|
||||||
/// Get the protocol extension's supported packets.
|
/// Get the protocol extension's supported packets.
|
||||||
|
@ -95,6 +103,29 @@ pub trait ProtocolExtension: std::fmt::Debug {
|
||||||
|
|
||||||
/// Clone the protocol extension.
|
/// Clone the protocol extension.
|
||||||
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send>;
|
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send>;
|
||||||
|
|
||||||
|
/// Do not override.
|
||||||
|
fn __internal_type_id(&self) -> TypeId {
|
||||||
|
TypeId::of::<Self>()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl dyn ProtocolExtension {
|
||||||
|
fn __is<T: ProtocolExtension>(&self) -> bool {
|
||||||
|
let t = TypeId::of::<T>();
|
||||||
|
self.__internal_type_id() == t
|
||||||
|
}
|
||||||
|
|
||||||
|
fn __downcast<T: ProtocolExtension>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
|
||||||
|
if self.__is::<T>() {
|
||||||
|
unsafe {
|
||||||
|
let raw: *mut dyn ProtocolExtension = Box::into_raw(self);
|
||||||
|
Ok(Box::from_raw(raw as *mut T))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait to build a Wisp protocol extension from a payload.
|
/// Trait to build a Wisp protocol extension from a payload.
|
||||||
|
@ -105,8 +136,11 @@ pub trait ProtocolExtensionBuilder {
|
||||||
fn get_id(&self) -> u8;
|
fn get_id(&self) -> u8;
|
||||||
|
|
||||||
/// Build a protocol extension from the extension's metadata.
|
/// Build a protocol extension from the extension's metadata.
|
||||||
fn build_from_bytes(&mut self, bytes: Bytes, role: Role)
|
fn build_from_bytes(
|
||||||
-> Result<AnyProtocolExtension, WispError>;
|
&mut self,
|
||||||
|
bytes: Bytes,
|
||||||
|
role: Role,
|
||||||
|
) -> Result<AnyProtocolExtension, WispError>;
|
||||||
|
|
||||||
/// Build a protocol extension to send to the other side.
|
/// Build a protocol extension to send to the other side.
|
||||||
fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError>;
|
fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError>;
|
||||||
|
|
|
@ -219,7 +219,7 @@ pub struct ServerMux {
|
||||||
/// If this variable is true you must assume no extensions are supported.
|
/// If this variable is true you must assume no extensions are supported.
|
||||||
pub downgraded: bool,
|
pub downgraded: bool,
|
||||||
/// Extensions that are supported by both sides.
|
/// Extensions that are supported by both sides.
|
||||||
pub supported_extension_ids: Vec<u8>,
|
pub supported_extensions: Vec<AnyProtocolExtension>,
|
||||||
actor_tx: mpsc::Sender<WsEvent>,
|
actor_tx: mpsc::Sender<WsEvent>,
|
||||||
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
|
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
|
||||||
tx: ws::LockedWebSocketWrite,
|
tx: ws::LockedWebSocketWrite,
|
||||||
|
@ -255,12 +255,10 @@ impl ServerMux {
|
||||||
(Vec::new(), None, true)
|
(Vec::new(), None, true)
|
||||||
};
|
};
|
||||||
|
|
||||||
let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect();
|
|
||||||
|
|
||||||
let (mux_result, muxstream_recv) = MuxInner::new_server(
|
let (mux_result, muxstream_recv) = MuxInner::new_server(
|
||||||
AppendingWebSocketRead(extra_packet, rx),
|
AppendingWebSocketRead(extra_packet, rx),
|
||||||
tx.clone(),
|
tx.clone(),
|
||||||
supported_extensions,
|
supported_extensions.clone(),
|
||||||
buffer_size,
|
buffer_size,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -269,7 +267,7 @@ impl ServerMux {
|
||||||
muxstream_recv,
|
muxstream_recv,
|
||||||
actor_tx: mux_result.actor_tx,
|
actor_tx: mux_result.actor_tx,
|
||||||
downgraded,
|
downgraded,
|
||||||
supported_extension_ids,
|
supported_extensions,
|
||||||
tx,
|
tx,
|
||||||
actor_exited: mux_result.actor_exited,
|
actor_exited: mux_result.actor_exited,
|
||||||
},
|
},
|
||||||
|
@ -348,7 +346,7 @@ where
|
||||||
) -> Result<(ServerMux, F), WispError> {
|
) -> Result<(ServerMux, F), WispError> {
|
||||||
let mut unsupported_extensions = Vec::new();
|
let mut unsupported_extensions = Vec::new();
|
||||||
for extension in extensions {
|
for extension in extensions {
|
||||||
if !self.0.supported_extension_ids.contains(extension) {
|
if !self.0.supported_extensions.iter().any(|x| x.get_id() == *extension) {
|
||||||
unsupported_extensions.push(*extension);
|
unsupported_extensions.push(*extension);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -377,7 +375,7 @@ pub struct ClientMux {
|
||||||
/// If this variable is true you must assume no extensions are supported.
|
/// If this variable is true you must assume no extensions are supported.
|
||||||
pub downgraded: bool,
|
pub downgraded: bool,
|
||||||
/// Extensions that are supported by both sides.
|
/// Extensions that are supported by both sides.
|
||||||
pub supported_extension_ids: Vec<u8>,
|
pub supported_extensions: Vec<AnyProtocolExtension>,
|
||||||
actor_tx: mpsc::Sender<WsEvent>,
|
actor_tx: mpsc::Sender<WsEvent>,
|
||||||
tx: ws::LockedWebSocketWrite,
|
tx: ws::LockedWebSocketWrite,
|
||||||
actor_exited: Arc<AtomicBool>,
|
actor_exited: Arc<AtomicBool>,
|
||||||
|
@ -418,12 +416,10 @@ impl ClientMux {
|
||||||
(Vec::new(), None, true)
|
(Vec::new(), None, true)
|
||||||
};
|
};
|
||||||
|
|
||||||
let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect();
|
|
||||||
|
|
||||||
let mux_result = MuxInner::new_client(
|
let mux_result = MuxInner::new_client(
|
||||||
AppendingWebSocketRead(extra_packet, rx),
|
AppendingWebSocketRead(extra_packet, rx),
|
||||||
tx.clone(),
|
tx.clone(),
|
||||||
supported_extensions,
|
supported_extensions.clone(),
|
||||||
packet.buffer_remaining,
|
packet.buffer_remaining,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -431,7 +427,7 @@ impl ClientMux {
|
||||||
Self {
|
Self {
|
||||||
actor_tx: mux_result.actor_tx,
|
actor_tx: mux_result.actor_tx,
|
||||||
downgraded,
|
downgraded,
|
||||||
supported_extension_ids,
|
supported_extensions,
|
||||||
tx,
|
tx,
|
||||||
actor_exited: mux_result.actor_exited,
|
actor_exited: mux_result.actor_exited,
|
||||||
},
|
},
|
||||||
|
@ -454,9 +450,9 @@ impl ClientMux {
|
||||||
}
|
}
|
||||||
if stream_type == StreamType::Udp
|
if stream_type == StreamType::Udp
|
||||||
&& !self
|
&& !self
|
||||||
.supported_extension_ids
|
.supported_extensions
|
||||||
.iter()
|
.iter()
|
||||||
.any(|x| *x == UdpProtocolExtension::ID)
|
.any(|x| x.get_id() == UdpProtocolExtension::ID)
|
||||||
{
|
{
|
||||||
return Err(WispError::ExtensionsNotSupported(vec![
|
return Err(WispError::ExtensionsNotSupported(vec![
|
||||||
UdpProtocolExtension::ID,
|
UdpProtocolExtension::ID,
|
||||||
|
@ -532,7 +528,7 @@ where
|
||||||
) -> Result<(ClientMux, F), WispError> {
|
) -> Result<(ClientMux, F), WispError> {
|
||||||
let mut unsupported_extensions = Vec::new();
|
let mut unsupported_extensions = Vec::new();
|
||||||
for extension in extensions {
|
for extension in extensions {
|
||||||
if !self.0.supported_extension_ids.contains(extension) {
|
if !self.0.supported_extensions.iter().any(|x| x.get_id() == *extension) {
|
||||||
unsupported_extensions.push(*extension);
|
unsupported_extensions.push(*extension);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue