add middleware to wispv2 handshake

This commit is contained in:
r58Playz 2024-09-16 23:18:32 -07:00
parent d6f1a8da43
commit 7fdacb2623
6 changed files with 254 additions and 114 deletions

View file

@ -19,13 +19,17 @@ pub mod ws;
pub use crate::{packet::*, stream::*};
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder};
use flume as mpsc;
use futures::{channel::oneshot, Future};
use inner::{MuxInner, WsEvent};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
use std::{
ops::DerefMut,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use ws::{AppendingWebSocketRead, LockedWebSocketWrite};
@ -173,7 +177,7 @@ async fn maybe_wisp_v2<R>(
read: &mut R,
write: &LockedWebSocketWrite,
role: Role,
builders: &mut [Box<dyn ProtocolExtensionBuilder + Sync + Send>],
builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame<'static>>, bool), WispError>
where
R: ws::WebSocketRead + Send,
@ -205,7 +209,7 @@ where
async fn send_info_packet(
write: &LockedWebSocketWrite,
builders: &mut [Box<dyn ProtocolExtensionBuilder + Sync + Send>],
builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<(), WispError> {
write
.write_frame(
@ -220,6 +224,42 @@ async fn send_info_packet(
.await
}
/// Wisp V2 handshake and protocol extension settings wrapper struct.
pub struct WispV2Extensions {
builders: Vec<AnyProtocolExtensionBuilder>,
closure: Box<
dyn Fn(
&mut [AnyProtocolExtensionBuilder],
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send,
>,
}
impl WispV2Extensions {
/// Create a Wisp V2 settings struct with no middleware.
pub fn new(builders: Vec<AnyProtocolExtensionBuilder>) -> Self {
Self {
builders,
closure: Box::new(|_| Box::pin(async { Ok(()) })),
}
}
/// Create a Wisp V2 settings struct with some middleware.
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
where
C: Fn(
&mut [AnyProtocolExtensionBuilder],
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send
+ 'static,
{
Self {
builders,
closure: Box::new(closure),
}
}
}
/// Server-side multiplexor.
pub struct ServerMux {
/// Whether the connection was downgraded to Wisp v1.
@ -237,14 +277,14 @@ pub struct ServerMux {
impl ServerMux {
/// Create a new server-side multiplexor.
///
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created.
pub async fn create<R, W>(
mut rx: R,
tx: W,
buffer_size: u32,
extension_builders: Option<Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>>>,
wisp_v2: Option<WispV2Extensions>,
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
where
R: ws::WebSocketRead + Send,
@ -256,13 +296,17 @@ impl ServerMux {
tx.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
let (supported_extensions, extra_packet, downgraded) =
if let Some(mut builders) = extension_builders {
send_info_packet(&tx, &mut builders).await?;
maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await?
} else {
(Vec::new(), None, true)
};
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
mut builders,
closure,
}) = wisp_v2
{
send_info_packet(&tx, builders.deref_mut()).await?;
(closure)(builders.deref_mut()).await?;
maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await?
} else {
(Vec::new(), None, true)
};
let (mux_result, muxstream_recv) = MuxInner::new_server(
AppendingWebSocketRead(extra_packet, rx),
@ -424,13 +468,13 @@ pub struct ClientMux {
impl ClientMux {
/// Create a new client side multiplexor.
///
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created.
pub async fn create<R, W>(
mut rx: R,
tx: W,
extension_builders: Option<Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>>>,
wisp_v2: Option<WispV2Extensions>,
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
where
R: ws::WebSocketRead + Send,
@ -444,17 +488,21 @@ impl ClientMux {
}
if let PacketType::Continue(packet) = first_packet.packet_type {
let (supported_extensions, extra_packet, downgraded) =
if let Some(mut builders) = extension_builders {
let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?;
// if not downgraded
if !res.2 {
send_info_packet(&tx, &mut builders).await?;
}
res
} else {
(Vec::new(), None, true)
};
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
mut builders,
closure,
}) = wisp_v2
{
let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?;
// if not downgraded
if !res.2 {
(closure)(&mut builders).await?;
send_info_packet(&tx, &mut builders).await?;
}
res
} else {
(Vec::new(), None, true)
};
let mux_result = MuxInner::new_client(
AppendingWebSocketRead(extra_packet, rx),