clean up muxresult types

This commit is contained in:
Toshit Chawda 2024-10-23 23:23:55 -07:00
parent fc63298397
commit c8de5524b4
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
4 changed files with 81 additions and 114 deletions

View file

@ -27,9 +27,9 @@ pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
/// The role of the multiplexor. /// The role of the multiplexor.
#[derive(Debug, PartialEq, Copy, Clone)] #[derive(Debug, PartialEq, Copy, Clone)]
pub enum Role { pub enum Role {
/// Client side, can create new channels to proxy. /// Client side, can create new streams.
Client, Client,
/// Server side, can listen for channels to proxy. /// Server side, can listen for streams created by the client.
Server, Server,
} }

View file

@ -12,12 +12,12 @@ use futures::channel::oneshot;
use crate::{ use crate::{
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension}, extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
inner::{MuxInner, WsEvent}, inner::{MuxInner, WsEvent},
ws::{AppendingWebSocketRead, LockedWebSocketWrite, WebSocketRead, WebSocketWrite, Payload}, ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType, CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType,
WispError, WispError,
}; };
use super::{maybe_wisp_v2, send_info_packet, WispV2Extensions}; use super::{maybe_wisp_v2, send_info_packet, Multiplexor, MuxResult, WispV2Extensions};
/// Client side multiplexor. /// Client side multiplexor.
pub struct ClientMux { pub struct ClientMux {
@ -42,7 +42,7 @@ impl ClientMux {
mut rx: R, mut rx: R,
tx: W, tx: W,
wisp_v2: Option<WispV2Extensions>, wisp_v2: Option<WispV2Extensions>,
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError> ) -> Result<MuxResult<ClientMux, impl Future<Output = Result<(), WispError>> + Send>, WispError>
where where
R: WebSocketRead + Send, R: WebSocketRead + Send,
W: WebSocketWrite + Send + 'static, W: WebSocketWrite + Send + 'static,
@ -78,7 +78,7 @@ impl ClientMux {
packet.buffer_remaining, packet.buffer_remaining,
); );
Ok(ClientMuxResult( Ok(MuxResult(
Self { Self {
actor_tx: mux_result.actor_tx, actor_tx: mux_result.actor_tx,
downgraded, downgraded,
@ -174,50 +174,13 @@ impl Drop for ClientMux {
} }
} }
/// Result of `ClientMux::new`. impl Multiplexor for ClientMux {
pub struct ClientMuxResult<F>(ClientMux, F) fn has_extension(&self, extension_id: u8) -> bool {
where self.supported_extensions
F: Future<Output = Result<(), WispError>> + Send;
impl<F> ClientMuxResult<F>
where
F: Future<Output = Result<(), WispError>> + Send,
{
/// Require no protocol extensions.
pub fn with_no_required_extensions(self) -> (ClientMux, F) {
(self.0, self.1)
}
/// Require protocol extensions by their ID.
pub async fn with_required_extensions(
self,
extensions: &[u8],
) -> Result<(ClientMux, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self
.0
.supported_extensions
.iter() .iter()
.any(|x| x.get_id() == *extension) .any(|x| x.get_id() == extension_id)
{
unsupported_extensions.push(*extension);
} }
} async fn exit(&self, reason: CloseReason) -> Result<(), WispError> {
if unsupported_extensions.is_empty() { self.close_with_reason(reason).await
Ok((self.0, self.1))
} else {
self.0
.close_with_reason(CloseReason::ExtensionsIncompatible)
.await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}
}
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> {
self.with_required_extensions(&[UdpProtocolExtension::ID])
.await
} }
} }

View file

@ -2,15 +2,15 @@ mod client;
mod server; mod server;
use std::{future::Future, pin::Pin, time::Duration}; use std::{future::Future, pin::Pin, time::Duration};
pub use client::{ClientMux, ClientMuxResult}; pub use client::ClientMux;
use futures::{select, FutureExt}; use futures::{select, FutureExt};
use futures_timer::Delay; use futures_timer::Delay;
pub use server::{ServerMux, ServerMuxResult}; pub use server::ServerMux;
use crate::{ use crate::{
extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder}, extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder},
ws::{Frame, LockedWebSocketWrite, WebSocketRead}, ws::{Frame, LockedWebSocketWrite, WebSocketRead},
Packet, PacketType, Role, WispError, CloseReason, Packet, PacketType, Role, WispError,
}; };
async fn maybe_wisp_v2<R>( async fn maybe_wisp_v2<R>(
@ -67,15 +67,61 @@ async fn send_info_packet(
.await .await
} }
trait Multiplexor {
fn has_extension(&self, extension_id: u8) -> bool;
async fn exit(&self, reason: CloseReason) -> Result<(), WispError>;
}
/// Result of creating a multiplexor. Helps require protocol extensions.
#[allow(private_bounds)]
pub struct MuxResult<M, F>(M, F)
where
M: Multiplexor,
F: Future<Output = Result<(), WispError>> + Send;
#[allow(private_bounds)]
impl<M, F> MuxResult<M, F>
where
M: Multiplexor,
F: Future<Output = Result<(), WispError>> + Send,
{
/// Require no protocol extensions.
pub fn with_no_required_extensions(self) -> (M, F) {
(self.0, self.1)
}
/// Require protocol extensions by their ID. Will close the multiplexor connection if
/// extensions are not supported.
pub async fn with_required_extensions(self, extensions: &[u8]) -> Result<(M, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self.0.has_extension(*extension) {
unsupported_extensions.push(*extension);
}
}
if unsupported_extensions.is_empty() {
Ok((self.0, self.1))
} else {
self.0.exit(CloseReason::ExtensionsIncompatible).await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}
}
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
pub async fn with_udp_extension_required(self) -> Result<(M, F), WispError> {
self.with_required_extensions(&[UdpProtocolExtension::ID])
.await
}
}
type WispV2ClosureResult = Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>;
type WispV2ClosureBuilders<'a> = &'a mut [AnyProtocolExtensionBuilder];
/// Wisp V2 handshake and protocol extension settings wrapper struct. /// Wisp V2 handshake and protocol extension settings wrapper struct.
pub struct WispV2Extensions { pub struct WispV2Extensions {
builders: Vec<AnyProtocolExtensionBuilder>, builders: Vec<AnyProtocolExtensionBuilder>,
closure: Box< closure: Box<dyn Fn(WispV2ClosureBuilders) -> WispV2ClosureResult + Send>,
dyn Fn(
&mut [AnyProtocolExtensionBuilder],
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send,
>,
} }
impl WispV2Extensions { impl WispV2Extensions {
@ -90,11 +136,7 @@ impl WispV2Extensions {
/// Create a Wisp V2 settings struct with some middleware. /// Create a Wisp V2 settings struct with some middleware.
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
where where
C: Fn( C: Fn(WispV2ClosureBuilders) -> WispV2ClosureResult + Send + 'static,
&mut [AnyProtocolExtensionBuilder],
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send
+ 'static,
{ {
Self { Self {
builders, builders,

View file

@ -11,13 +11,13 @@ use flume as mpsc;
use futures::channel::oneshot; use futures::channel::oneshot;
use crate::{ use crate::{
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension}, extensions::AnyProtocolExtension,
inner::{MuxInner, WsEvent}, inner::{MuxInner, WsEvent},
ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, Role, WispError, CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, Role, WispError,
}; };
use super::{maybe_wisp_v2, send_info_packet, WispV2Extensions}; use super::{maybe_wisp_v2, send_info_packet, Multiplexor, MuxResult, WispV2Extensions};
/// Server-side multiplexor. /// Server-side multiplexor.
pub struct ServerMux { pub struct ServerMux {
@ -44,7 +44,7 @@ impl ServerMux {
tx: W, tx: W,
buffer_size: u32, buffer_size: u32,
wisp_v2: Option<WispV2Extensions>, wisp_v2: Option<WispV2Extensions>,
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError> ) -> Result<MuxResult<ServerMux, impl Future<Output = Result<(), WispError>> + Send>, WispError>
where where
R: WebSocketRead + Send, R: WebSocketRead + Send,
W: WebSocketWrite + Send + 'static, W: WebSocketWrite + Send + 'static,
@ -74,7 +74,7 @@ impl ServerMux {
buffer_size, buffer_size,
); );
Ok(ServerMuxResult( Ok(MuxResult(
Self { Self {
muxstream_recv, muxstream_recv,
actor_tx: mux_result.actor_tx, actor_tx: mux_result.actor_tx,
@ -175,51 +175,13 @@ impl Drop for ServerMux {
} }
} }
/// Result of `ServerMux::new`. impl Multiplexor for ServerMux {
pub struct ServerMuxResult<F>(ServerMux, F) fn has_extension(&self, extension_id: u8) -> bool {
where self.supported_extensions
F: Future<Output = Result<(), WispError>> + Send;
impl<F> ServerMuxResult<F>
where
F: Future<Output = Result<(), WispError>> + Send,
{
/// Require no protocol extensions.
pub fn with_no_required_extensions(self) -> (ServerMux, F) {
(self.0, self.1)
}
/// Require protocol extensions by their ID. Will close the multiplexor connection if
/// extensions are not supported.
pub async fn with_required_extensions(
self,
extensions: &[u8],
) -> Result<(ServerMux, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self
.0
.supported_extensions
.iter() .iter()
.any(|x| x.get_id() == *extension) .any(|x| x.get_id() == extension_id)
{
unsupported_extensions.push(*extension);
} }
} async fn exit(&self, reason: CloseReason) -> Result<(), WispError> {
if unsupported_extensions.is_empty() { self.close_with_reason(reason).await
Ok((self.0, self.1))
} else {
self.0
.close_with_reason(CloseReason::ExtensionsIncompatible)
.await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}
}
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> {
self.with_required_extensions(&[UdpProtocolExtension::ID])
.await
} }
} }