helper traits for protocol extensions

This commit is contained in:
Toshit Chawda 2024-10-25 22:17:39 -07:00
parent 41f2139eb1
commit 1a8773f801
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
2 changed files with 59 additions and 19 deletions

View file

@ -264,3 +264,51 @@ impl<T: ProtocolExtensionBuilder> From<T> for AnyProtocolExtensionBuilder {
Self::new(value)
}
}
/// Helper functions for `Vec<AnyProtocolExtensionBuilder>`
pub trait ProtocolExtensionBuilderVecExt {
/// Returns a reference to the protocol extension builder specified, if it was found.
fn find_extension<T: ProtocolExtensionBuilder>(&self) -> Option<&T>;
/// Returns a mutable reference to the protocol extension builder specified, if it was found.
fn find_extension_mut<T: ProtocolExtensionBuilder>(&mut self) -> Option<&mut T>;
/// Removes any instances of the protocol extension builder specified, if it was found.
fn remove_extension<T: ProtocolExtensionBuilder>(&mut self);
}
impl ProtocolExtensionBuilderVecExt for Vec<AnyProtocolExtensionBuilder> {
fn find_extension<T: ProtocolExtensionBuilder>(&self) -> Option<&T> {
self.iter().find_map(|x| x.downcast_ref::<T>())
}
fn find_extension_mut<T: ProtocolExtensionBuilder>(&mut self) -> Option<&mut T> {
self.iter_mut().find_map(|x| x.downcast_mut::<T>())
}
fn remove_extension<T: ProtocolExtensionBuilder>(&mut self) {
self.retain(|x| x.downcast_ref::<T>().is_none());
}
}
/// Helper functions for `Vec<AnyProtocolExtension>`
pub trait ProtocolExtensionVecExt {
/// Returns a reference to the protocol extension specified, if it was found.
fn find_extension<T: ProtocolExtension>(&self) -> Option<&T>;
/// Returns a mutable reference to the protocol extension specified, if it was found.
fn find_extension_mut<T: ProtocolExtension>(&mut self) -> Option<&mut T>;
/// Removes any instances of the protocol extension specified, if it was found.
fn remove_extension<T: ProtocolExtension>(&mut self);
}
impl ProtocolExtensionVecExt for Vec<AnyProtocolExtension> {
fn find_extension<T: ProtocolExtension>(&self) -> Option<&T> {
self.iter().find_map(|x| x.downcast_ref::<T>())
}
fn find_extension_mut<T: ProtocolExtension>(&mut self) -> Option<&mut T> {
self.iter_mut().find_map(|x| x.downcast_mut::<T>())
}
fn remove_extension<T: ProtocolExtension>(&mut self) {
self.retain(|x| x.downcast_ref::<T>().is_none());
}
}

View file

@ -123,16 +123,15 @@ where
}
}
/// Wisp V2 middleware closure.
pub type WispV2Middleware = dyn for<'a> Fn(
&'a mut Vec<AnyProtocolExtensionBuilder>,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send + 'a>>
+ Send;
/// Wisp V2 handshake and protocol extension settings wrapper struct.
pub struct WispV2Handshake {
builders: Vec<AnyProtocolExtensionBuilder>,
#[expect(clippy::type_complexity)]
closure: Box<
dyn Fn(
&mut Vec<AnyProtocolExtensionBuilder>,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send,
>,
closure: Box<WispV2Middleware>,
}
impl WispV2Handshake {
@ -145,18 +144,11 @@ impl WispV2Handshake {
}
/// Create a Wisp V2 settings struct with some middleware.
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
where
C: Fn(
&mut Vec<AnyProtocolExtensionBuilder>,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send
+ 'static,
{
Self {
builders,
closure: Box::new(closure),
}
pub fn new_with_middleware(
builders: Vec<AnyProtocolExtensionBuilder>,
closure: Box<WispV2Middleware>,
) -> Self {
Self { builders, closure }
}
/// Add a Wisp V2 extension builder to the settings struct.