mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 14:00:01 -04:00
clean up muxresult types
This commit is contained in:
parent
fc63298397
commit
c8de5524b4
4 changed files with 81 additions and 114 deletions
|
@ -27,9 +27,9 @@ pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
|
||||||
/// The role of the multiplexor.
|
/// The role of the multiplexor.
|
||||||
#[derive(Debug, PartialEq, Copy, Clone)]
|
#[derive(Debug, PartialEq, Copy, Clone)]
|
||||||
pub enum Role {
|
pub enum Role {
|
||||||
/// Client side, can create new channels to proxy.
|
/// Client side, can create new streams.
|
||||||
Client,
|
Client,
|
||||||
/// Server side, can listen for channels to proxy.
|
/// Server side, can listen for streams created by the client.
|
||||||
Server,
|
Server,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,12 +12,12 @@ use futures::channel::oneshot;
|
||||||
use crate::{
|
use crate::{
|
||||||
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
|
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
|
||||||
inner::{MuxInner, WsEvent},
|
inner::{MuxInner, WsEvent},
|
||||||
ws::{AppendingWebSocketRead, LockedWebSocketWrite, WebSocketRead, WebSocketWrite, Payload},
|
ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||||
CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType,
|
CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType,
|
||||||
WispError,
|
WispError,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{maybe_wisp_v2, send_info_packet, WispV2Extensions};
|
use super::{maybe_wisp_v2, send_info_packet, Multiplexor, MuxResult, WispV2Extensions};
|
||||||
|
|
||||||
/// Client side multiplexor.
|
/// Client side multiplexor.
|
||||||
pub struct ClientMux {
|
pub struct ClientMux {
|
||||||
|
@ -42,7 +42,7 @@ impl ClientMux {
|
||||||
mut rx: R,
|
mut rx: R,
|
||||||
tx: W,
|
tx: W,
|
||||||
wisp_v2: Option<WispV2Extensions>,
|
wisp_v2: Option<WispV2Extensions>,
|
||||||
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
) -> Result<MuxResult<ClientMux, impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||||
where
|
where
|
||||||
R: WebSocketRead + Send,
|
R: WebSocketRead + Send,
|
||||||
W: WebSocketWrite + Send + 'static,
|
W: WebSocketWrite + Send + 'static,
|
||||||
|
@ -78,7 +78,7 @@ impl ClientMux {
|
||||||
packet.buffer_remaining,
|
packet.buffer_remaining,
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(ClientMuxResult(
|
Ok(MuxResult(
|
||||||
Self {
|
Self {
|
||||||
actor_tx: mux_result.actor_tx,
|
actor_tx: mux_result.actor_tx,
|
||||||
downgraded,
|
downgraded,
|
||||||
|
@ -174,50 +174,13 @@ impl Drop for ClientMux {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Result of `ClientMux::new`.
|
impl Multiplexor for ClientMux {
|
||||||
pub struct ClientMuxResult<F>(ClientMux, F)
|
fn has_extension(&self, extension_id: u8) -> bool {
|
||||||
where
|
self.supported_extensions
|
||||||
F: Future<Output = Result<(), WispError>> + Send;
|
.iter()
|
||||||
|
.any(|x| x.get_id() == extension_id)
|
||||||
impl<F> ClientMuxResult<F>
|
|
||||||
where
|
|
||||||
F: Future<Output = Result<(), WispError>> + Send,
|
|
||||||
{
|
|
||||||
/// Require no protocol extensions.
|
|
||||||
pub fn with_no_required_extensions(self) -> (ClientMux, F) {
|
|
||||||
(self.0, self.1)
|
|
||||||
}
|
}
|
||||||
|
async fn exit(&self, reason: CloseReason) -> Result<(), WispError> {
|
||||||
/// Require protocol extensions by their ID.
|
self.close_with_reason(reason).await
|
||||||
pub async fn with_required_extensions(
|
|
||||||
self,
|
|
||||||
extensions: &[u8],
|
|
||||||
) -> Result<(ClientMux, F), WispError> {
|
|
||||||
let mut unsupported_extensions = Vec::new();
|
|
||||||
for extension in extensions {
|
|
||||||
if !self
|
|
||||||
.0
|
|
||||||
.supported_extensions
|
|
||||||
.iter()
|
|
||||||
.any(|x| x.get_id() == *extension)
|
|
||||||
{
|
|
||||||
unsupported_extensions.push(*extension);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if unsupported_extensions.is_empty() {
|
|
||||||
Ok((self.0, self.1))
|
|
||||||
} else {
|
|
||||||
self.0
|
|
||||||
.close_with_reason(CloseReason::ExtensionsIncompatible)
|
|
||||||
.await?;
|
|
||||||
self.1.await?;
|
|
||||||
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
|
|
||||||
pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> {
|
|
||||||
self.with_required_extensions(&[UdpProtocolExtension::ID])
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,15 +2,15 @@ mod client;
|
||||||
mod server;
|
mod server;
|
||||||
use std::{future::Future, pin::Pin, time::Duration};
|
use std::{future::Future, pin::Pin, time::Duration};
|
||||||
|
|
||||||
pub use client::{ClientMux, ClientMuxResult};
|
pub use client::ClientMux;
|
||||||
use futures::{select, FutureExt};
|
use futures::{select, FutureExt};
|
||||||
use futures_timer::Delay;
|
use futures_timer::Delay;
|
||||||
pub use server::{ServerMux, ServerMuxResult};
|
pub use server::ServerMux;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder},
|
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder},
|
||||||
ws::{Frame, LockedWebSocketWrite, WebSocketRead},
|
ws::{Frame, LockedWebSocketWrite, WebSocketRead},
|
||||||
Packet, PacketType, Role, WispError,
|
CloseReason, Packet, PacketType, Role, WispError,
|
||||||
};
|
};
|
||||||
|
|
||||||
async fn maybe_wisp_v2<R>(
|
async fn maybe_wisp_v2<R>(
|
||||||
|
@ -67,15 +67,61 @@ async fn send_info_packet(
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trait Multiplexor {
|
||||||
|
fn has_extension(&self, extension_id: u8) -> bool;
|
||||||
|
async fn exit(&self, reason: CloseReason) -> Result<(), WispError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of creating a multiplexor. Helps require protocol extensions.
|
||||||
|
#[allow(private_bounds)]
|
||||||
|
pub struct MuxResult<M, F>(M, F)
|
||||||
|
where
|
||||||
|
M: Multiplexor,
|
||||||
|
F: Future<Output = Result<(), WispError>> + Send;
|
||||||
|
|
||||||
|
#[allow(private_bounds)]
|
||||||
|
impl<M, F> MuxResult<M, F>
|
||||||
|
where
|
||||||
|
M: Multiplexor,
|
||||||
|
F: Future<Output = Result<(), WispError>> + Send,
|
||||||
|
{
|
||||||
|
/// Require no protocol extensions.
|
||||||
|
pub fn with_no_required_extensions(self) -> (M, F) {
|
||||||
|
(self.0, self.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Require protocol extensions by their ID. Will close the multiplexor connection if
|
||||||
|
/// extensions are not supported.
|
||||||
|
pub async fn with_required_extensions(self, extensions: &[u8]) -> Result<(M, F), WispError> {
|
||||||
|
let mut unsupported_extensions = Vec::new();
|
||||||
|
for extension in extensions {
|
||||||
|
if !self.0.has_extension(*extension) {
|
||||||
|
unsupported_extensions.push(*extension);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if unsupported_extensions.is_empty() {
|
||||||
|
Ok((self.0, self.1))
|
||||||
|
} else {
|
||||||
|
self.0.exit(CloseReason::ExtensionsIncompatible).await?;
|
||||||
|
self.1.await?;
|
||||||
|
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
|
||||||
|
pub async fn with_udp_extension_required(self) -> Result<(M, F), WispError> {
|
||||||
|
self.with_required_extensions(&[UdpProtocolExtension::ID])
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type WispV2ClosureResult = Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>;
|
||||||
|
type WispV2ClosureBuilders<'a> = &'a mut [AnyProtocolExtensionBuilder];
|
||||||
/// Wisp V2 handshake and protocol extension settings wrapper struct.
|
/// Wisp V2 handshake and protocol extension settings wrapper struct.
|
||||||
pub struct WispV2Extensions {
|
pub struct WispV2Extensions {
|
||||||
builders: Vec<AnyProtocolExtensionBuilder>,
|
builders: Vec<AnyProtocolExtensionBuilder>,
|
||||||
closure: Box<
|
closure: Box<dyn Fn(WispV2ClosureBuilders) -> WispV2ClosureResult + Send>,
|
||||||
dyn Fn(
|
|
||||||
&mut [AnyProtocolExtensionBuilder],
|
|
||||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
|
|
||||||
+ Send,
|
|
||||||
>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WispV2Extensions {
|
impl WispV2Extensions {
|
||||||
|
@ -90,11 +136,7 @@ impl WispV2Extensions {
|
||||||
/// Create a Wisp V2 settings struct with some middleware.
|
/// Create a Wisp V2 settings struct with some middleware.
|
||||||
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
|
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
|
||||||
where
|
where
|
||||||
C: Fn(
|
C: Fn(WispV2ClosureBuilders) -> WispV2ClosureResult + Send + 'static,
|
||||||
&mut [AnyProtocolExtensionBuilder],
|
|
||||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
|
|
||||||
+ Send
|
|
||||||
+ 'static,
|
|
||||||
{
|
{
|
||||||
Self {
|
Self {
|
||||||
builders,
|
builders,
|
||||||
|
|
|
@ -11,13 +11,13 @@ use flume as mpsc;
|
||||||
use futures::channel::oneshot;
|
use futures::channel::oneshot;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
|
extensions::AnyProtocolExtension,
|
||||||
inner::{MuxInner, WsEvent},
|
inner::{MuxInner, WsEvent},
|
||||||
ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||||
CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, Role, WispError,
|
CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, Role, WispError,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{maybe_wisp_v2, send_info_packet, WispV2Extensions};
|
use super::{maybe_wisp_v2, send_info_packet, Multiplexor, MuxResult, WispV2Extensions};
|
||||||
|
|
||||||
/// Server-side multiplexor.
|
/// Server-side multiplexor.
|
||||||
pub struct ServerMux {
|
pub struct ServerMux {
|
||||||
|
@ -44,7 +44,7 @@ impl ServerMux {
|
||||||
tx: W,
|
tx: W,
|
||||||
buffer_size: u32,
|
buffer_size: u32,
|
||||||
wisp_v2: Option<WispV2Extensions>,
|
wisp_v2: Option<WispV2Extensions>,
|
||||||
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
) -> Result<MuxResult<ServerMux, impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||||
where
|
where
|
||||||
R: WebSocketRead + Send,
|
R: WebSocketRead + Send,
|
||||||
W: WebSocketWrite + Send + 'static,
|
W: WebSocketWrite + Send + 'static,
|
||||||
|
@ -74,7 +74,7 @@ impl ServerMux {
|
||||||
buffer_size,
|
buffer_size,
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(ServerMuxResult(
|
Ok(MuxResult(
|
||||||
Self {
|
Self {
|
||||||
muxstream_recv,
|
muxstream_recv,
|
||||||
actor_tx: mux_result.actor_tx,
|
actor_tx: mux_result.actor_tx,
|
||||||
|
@ -175,51 +175,13 @@ impl Drop for ServerMux {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Result of `ServerMux::new`.
|
impl Multiplexor for ServerMux {
|
||||||
pub struct ServerMuxResult<F>(ServerMux, F)
|
fn has_extension(&self, extension_id: u8) -> bool {
|
||||||
where
|
self.supported_extensions
|
||||||
F: Future<Output = Result<(), WispError>> + Send;
|
.iter()
|
||||||
|
.any(|x| x.get_id() == extension_id)
|
||||||
impl<F> ServerMuxResult<F>
|
|
||||||
where
|
|
||||||
F: Future<Output = Result<(), WispError>> + Send,
|
|
||||||
{
|
|
||||||
/// Require no protocol extensions.
|
|
||||||
pub fn with_no_required_extensions(self) -> (ServerMux, F) {
|
|
||||||
(self.0, self.1)
|
|
||||||
}
|
}
|
||||||
|
async fn exit(&self, reason: CloseReason) -> Result<(), WispError> {
|
||||||
/// Require protocol extensions by their ID. Will close the multiplexor connection if
|
self.close_with_reason(reason).await
|
||||||
/// extensions are not supported.
|
|
||||||
pub async fn with_required_extensions(
|
|
||||||
self,
|
|
||||||
extensions: &[u8],
|
|
||||||
) -> Result<(ServerMux, F), WispError> {
|
|
||||||
let mut unsupported_extensions = Vec::new();
|
|
||||||
for extension in extensions {
|
|
||||||
if !self
|
|
||||||
.0
|
|
||||||
.supported_extensions
|
|
||||||
.iter()
|
|
||||||
.any(|x| x.get_id() == *extension)
|
|
||||||
{
|
|
||||||
unsupported_extensions.push(*extension);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if unsupported_extensions.is_empty() {
|
|
||||||
Ok((self.0, self.1))
|
|
||||||
} else {
|
|
||||||
self.0
|
|
||||||
.close_with_reason(CloseReason::ExtensionsIncompatible)
|
|
||||||
.await?;
|
|
||||||
self.1.await?;
|
|
||||||
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
|
|
||||||
pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> {
|
|
||||||
self.with_required_extensions(&[UdpProtocolExtension::ID])
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue