mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 22:10:01 -04:00
wisp-mux... SEVEN!!
This commit is contained in:
parent
194ad4e5c8
commit
3f381d6b39
53 changed files with 3721 additions and 4821 deletions
|
@ -3,18 +3,26 @@ pub(crate) mod inner;
|
|||
mod server;
|
||||
use std::{future::Future, pin::Pin};
|
||||
|
||||
pub use client::ClientMux;
|
||||
pub use server::ServerMux;
|
||||
use futures::SinkExt;
|
||||
use inner::{MultiplexorActor, MuxInner, WsEvent};
|
||||
|
||||
pub use client::ClientImpl;
|
||||
pub use server::ServerImpl;
|
||||
|
||||
pub type ServerMux<W> = Multiplexor<ServerImpl<W>, W>;
|
||||
pub type ClientMux<W> = Multiplexor<ClientImpl, W>;
|
||||
|
||||
use crate::{
|
||||
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder},
|
||||
ws::{LockedWebSocketWrite, WebSocketWrite},
|
||||
CloseReason, Packet, PacketType, Role, WispError,
|
||||
packet::{CloseReason, InfoPacket, Packet, PacketType},
|
||||
ws::{WebSocketRead, WebSocketWrite},
|
||||
LockedWebSocketWrite, LockedWebSocketWriteGuard, Role, WispError, WISP_VERSION,
|
||||
};
|
||||
|
||||
struct WispHandshakeResult {
|
||||
kind: WispHandshakeResultKind,
|
||||
downgraded: bool,
|
||||
buffer_size: u32,
|
||||
}
|
||||
|
||||
enum WispHandshakeResultKind {
|
||||
|
@ -22,7 +30,7 @@ enum WispHandshakeResultKind {
|
|||
extensions: Vec<AnyProtocolExtension>,
|
||||
},
|
||||
V1 {
|
||||
frame: Option<Packet<'static>>,
|
||||
packet: Option<Packet<'static>>,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -30,35 +38,56 @@ impl WispHandshakeResultKind {
|
|||
pub fn into_parts(self) -> (Vec<AnyProtocolExtension>, Option<Packet<'static>>) {
|
||||
match self {
|
||||
Self::V2 { extensions } => (extensions, None),
|
||||
Self::V1 { frame } => (vec![UdpProtocolExtension.into()], frame),
|
||||
Self::V1 { packet } => (vec![UdpProtocolExtension.into()], packet),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_info_packet<W: WebSocketWrite>(
|
||||
write: &LockedWebSocketWrite<W>,
|
||||
builders: &mut [AnyProtocolExtensionBuilder],
|
||||
async fn handle_handshake<R: WebSocketRead, W: WebSocketWrite>(
|
||||
read: &mut R,
|
||||
write: &mut LockedWebSocketWrite<W>,
|
||||
extensions: &mut [AnyProtocolExtension],
|
||||
) -> Result<(), WispError> {
|
||||
write
|
||||
.write_frame(
|
||||
Packet::new_info(
|
||||
builders
|
||||
.iter_mut()
|
||||
.map(|x| x.build_to_extension(Role::Server))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
)
|
||||
.into(),
|
||||
)
|
||||
.await
|
||||
write.lock().await;
|
||||
let mut handle = write.get_handle();
|
||||
for extension in extensions {
|
||||
extension.handle_handshake(read, &mut handle).await?;
|
||||
}
|
||||
drop(handle);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_continue_packet(packet: &Packet<'_>) -> Result<u32, WispError> {
|
||||
async fn send_info_packet<W: WebSocketWrite>(
|
||||
write: &mut LockedWebSocketWrite<W>,
|
||||
builders: &mut [AnyProtocolExtensionBuilder],
|
||||
role: Role,
|
||||
) -> Result<(), WispError> {
|
||||
let extensions = builders
|
||||
.iter_mut()
|
||||
.map(|x| x.build_to_extension(role))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let packet = InfoPacket {
|
||||
version: WISP_VERSION,
|
||||
extensions,
|
||||
}
|
||||
.encode();
|
||||
|
||||
write.lock().await;
|
||||
let ret = write.get().send(packet).await;
|
||||
write.unlock();
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
fn validate_continue_packet(packet: &Packet) -> Result<u32, WispError> {
|
||||
if packet.stream_id != 0 {
|
||||
return Err(WispError::InvalidStreamId);
|
||||
return Err(WispError::InvalidStreamId(packet.stream_id));
|
||||
}
|
||||
|
||||
let PacketType::Continue(continue_packet) = packet.packet_type else {
|
||||
return Err(WispError::InvalidPacketType);
|
||||
return Err(WispError::InvalidPacketType(packet.packet_type.get_type()));
|
||||
};
|
||||
|
||||
Ok(continue_packet.buffer_remaining)
|
||||
|
@ -75,35 +104,185 @@ fn get_supported_extensions(
|
|||
.collect()
|
||||
}
|
||||
|
||||
trait Multiplexor {
|
||||
fn has_extension(&self, extension_id: u8) -> bool;
|
||||
async fn exit(&self, reason: CloseReason) -> Result<(), WispError>;
|
||||
trait MultiplexorImpl<W: WebSocketWrite> {
|
||||
type Actor: MultiplexorActor<W> + 'static;
|
||||
|
||||
async fn handshake<R: WebSocketRead>(
|
||||
&mut self,
|
||||
rx: &mut R,
|
||||
tx: &mut LockedWebSocketWrite<W>,
|
||||
v2: Option<WispV2Handshake>,
|
||||
) -> Result<WispHandshakeResult, WispError>;
|
||||
|
||||
async fn handle_error(
|
||||
&mut self,
|
||||
err: WispError,
|
||||
tx: &mut LockedWebSocketWrite<W>,
|
||||
) -> Result<WispError, WispError>;
|
||||
}
|
||||
|
||||
#[expect(private_bounds)]
|
||||
pub struct Multiplexor<M: MultiplexorImpl<W>, W: WebSocketWrite> {
|
||||
mux: M,
|
||||
|
||||
downgraded: bool,
|
||||
supported_extensions: Vec<AnyProtocolExtension>,
|
||||
|
||||
actor_tx: flume::Sender<WsEvent<W>>,
|
||||
tx: LockedWebSocketWrite<W>,
|
||||
}
|
||||
|
||||
#[expect(private_bounds)]
|
||||
impl<M: MultiplexorImpl<W>, W: WebSocketWrite> Multiplexor<M, W> {
|
||||
async fn create<R>(
|
||||
mut rx: R,
|
||||
tx: W,
|
||||
wisp_v2: Option<WispV2Handshake>,
|
||||
mut muxer: M,
|
||||
actor: M::Actor,
|
||||
) -> Result<MuxResult<M, W>, WispError>
|
||||
where
|
||||
R: WebSocketRead,
|
||||
{
|
||||
let mut tx = LockedWebSocketWrite::new(tx);
|
||||
|
||||
let ret = async {
|
||||
let handshake_result = muxer.handshake(&mut rx, &mut tx, wisp_v2).await?;
|
||||
let (extensions, extra_packet) = handshake_result.kind.into_parts();
|
||||
|
||||
Ok((
|
||||
MuxInner::new(
|
||||
rx,
|
||||
tx.clone(),
|
||||
actor,
|
||||
extra_packet,
|
||||
extensions.clone(),
|
||||
handshake_result.buffer_size,
|
||||
),
|
||||
handshake_result.downgraded,
|
||||
extensions,
|
||||
))
|
||||
}
|
||||
.await;
|
||||
|
||||
match ret {
|
||||
Ok((mux_result, downgraded, extensions)) => Ok(MuxResult(
|
||||
Self {
|
||||
mux: muxer,
|
||||
|
||||
downgraded,
|
||||
supported_extensions: extensions,
|
||||
|
||||
actor_tx: mux_result.actor_tx,
|
||||
tx,
|
||||
},
|
||||
Box::pin(mux_result.mux.into_future()),
|
||||
)),
|
||||
Err(x) => Err(muxer.handle_error(x, &mut tx).await?),
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether the connection was downgraded to Wisp v1.
|
||||
pub fn was_downgraded(&self) -> bool {
|
||||
self.downgraded
|
||||
}
|
||||
|
||||
/// Get a shared reference to the extensions that are supported by both sides.
|
||||
pub fn get_extensions(&self) -> &[AnyProtocolExtension] {
|
||||
&self.supported_extensions
|
||||
}
|
||||
|
||||
/// Get a mutable reference to the extensions that are supported by both sides.
|
||||
pub fn get_extensions_mut(&mut self) -> &mut [AnyProtocolExtension] {
|
||||
&mut self.supported_extensions
|
||||
}
|
||||
|
||||
/// Get a `Vec` of all extension IDs that are supported by both sides.
|
||||
pub fn get_extension_ids(&self) -> Vec<u8> {
|
||||
self.supported_extensions
|
||||
.iter()
|
||||
.map(|x| x.get_id())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get a locked guard to the write half of the websocket.
|
||||
pub async fn lock_ws(&self) -> Result<LockedWebSocketWriteGuard<W>, WispError> {
|
||||
if self.actor_tx.is_disconnected() {
|
||||
Err(WispError::WsImplSocketClosed)
|
||||
} else {
|
||||
let mut cloned = self.tx.clone();
|
||||
cloned.lock().await;
|
||||
Ok(cloned.get_guard())
|
||||
}
|
||||
}
|
||||
|
||||
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
||||
self.actor_tx
|
||||
.send_async(WsEvent::EndFut(reason))
|
||||
.await
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||
}
|
||||
|
||||
/// Close all streams.
|
||||
///
|
||||
/// Also terminates the multiplexor future.
|
||||
pub async fn close(&self) -> Result<(), WispError> {
|
||||
self.close_internal(None).await
|
||||
}
|
||||
|
||||
/// Close all streams and send a close reason on stream ID 0.
|
||||
///
|
||||
/// Also terminates the multiplexor future.
|
||||
pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
|
||||
self.close_internal(Some(reason)).await
|
||||
}
|
||||
|
||||
/* TODO
|
||||
/// Get a protocol extension stream for sending packets with stream id 0.
|
||||
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream<W> {
|
||||
MuxProtocolExtensionStream {
|
||||
stream_id: 0,
|
||||
tx: self.tx.clone(),
|
||||
is_closed: self.actor_exited.clone(),
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
pub type MultiplexorActorFuture = Pin<Box<dyn Future<Output = Result<(), WispError>> + Send>>;
|
||||
|
||||
/// Result of creating a multiplexor. Helps require protocol extensions.
|
||||
#[expect(private_bounds)]
|
||||
pub struct MuxResult<M, F>(M, F)
|
||||
pub struct MuxResult<M, W>(Multiplexor<M, W>, MultiplexorActorFuture)
|
||||
where
|
||||
M: Multiplexor,
|
||||
F: Future<Output = Result<(), WispError>> + Send;
|
||||
M: MultiplexorImpl<W>,
|
||||
W: WebSocketWrite;
|
||||
|
||||
#[expect(private_bounds)]
|
||||
impl<M, F> MuxResult<M, F>
|
||||
impl<M, W> MuxResult<M, W>
|
||||
where
|
||||
M: Multiplexor,
|
||||
F: Future<Output = Result<(), WispError>> + Send,
|
||||
M: MultiplexorImpl<W>,
|
||||
W: WebSocketWrite,
|
||||
{
|
||||
/// Require no protocol extensions.
|
||||
pub fn with_no_required_extensions(self) -> (M, F) {
|
||||
pub fn with_no_required_extensions(self) -> (Multiplexor<M, W>, MultiplexorActorFuture) {
|
||||
(self.0, self.1)
|
||||
}
|
||||
|
||||
/// Require protocol extensions by their ID. Will close the multiplexor connection if
|
||||
/// extensions are not supported.
|
||||
pub async fn with_required_extensions(self, extensions: &[u8]) -> Result<(M, F), WispError> {
|
||||
pub async fn with_required_extensions(
|
||||
self,
|
||||
extensions: &[u8],
|
||||
) -> Result<(Multiplexor<M, W>, MultiplexorActorFuture), WispError> {
|
||||
let mut unsupported_extensions = Vec::new();
|
||||
let supported_extensions = self.0.get_extensions();
|
||||
|
||||
for extension in extensions {
|
||||
if !self.0.has_extension(*extension) {
|
||||
if !supported_extensions
|
||||
.iter()
|
||||
.any(|x| x.get_id() == *extension)
|
||||
{
|
||||
unsupported_extensions.push(*extension);
|
||||
}
|
||||
}
|
||||
|
@ -111,14 +290,18 @@ where
|
|||
if unsupported_extensions.is_empty() {
|
||||
Ok((self.0, self.1))
|
||||
} else {
|
||||
self.0.exit(CloseReason::ExtensionsIncompatible).await?;
|
||||
self.0
|
||||
.close_with_reason(CloseReason::ExtensionsIncompatible)
|
||||
.await?;
|
||||
self.1.await?;
|
||||
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
|
||||
}
|
||||
}
|
||||
|
||||
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
|
||||
pub async fn with_udp_extension_required(self) -> Result<(M, F), WispError> {
|
||||
pub async fn with_udp_extension_required(
|
||||
self,
|
||||
) -> Result<(Multiplexor<M, W>, MultiplexorActorFuture), WispError> {
|
||||
self.with_required_extensions(&[UdpProtocolExtension::ID])
|
||||
.await
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue