mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
add middleware to wispv2 handshake
This commit is contained in:
parent
d6f1a8da43
commit
7fdacb2623
6 changed files with 254 additions and 114 deletions
|
@ -155,7 +155,7 @@ impl dyn ProtocolExtension {
|
|||
}
|
||||
|
||||
/// Trait to build a Wisp protocol extension from a payload.
|
||||
pub trait ProtocolExtensionBuilder {
|
||||
pub trait ProtocolExtensionBuilder: Sync + Send + 'static {
|
||||
/// Get the protocol extension ID.
|
||||
///
|
||||
/// Used to decide whether this builder should be used.
|
||||
|
@ -170,4 +170,81 @@ pub trait ProtocolExtensionBuilder {
|
|||
|
||||
/// Build a protocol extension to send to the other side.
|
||||
fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError>;
|
||||
|
||||
/// Do not override.
|
||||
fn __internal_type_id(&self) -> TypeId {
|
||||
TypeId::of::<Self>()
|
||||
}
|
||||
}
|
||||
|
||||
impl dyn ProtocolExtensionBuilder {
|
||||
fn __is<T: ProtocolExtensionBuilder>(&self) -> bool {
|
||||
let t = TypeId::of::<T>();
|
||||
self.__internal_type_id() == t
|
||||
}
|
||||
|
||||
fn __downcast<T: ProtocolExtensionBuilder>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
|
||||
if self.__is::<T>() {
|
||||
unsafe {
|
||||
let raw: *mut dyn ProtocolExtensionBuilder = Box::into_raw(self);
|
||||
Ok(Box::from_raw(raw as *mut T))
|
||||
}
|
||||
} else {
|
||||
Err(self)
|
||||
}
|
||||
}
|
||||
|
||||
fn __downcast_ref<T: ProtocolExtensionBuilder>(&self) -> Option<&T> {
|
||||
if self.__is::<T>() {
|
||||
unsafe { Some(&*(self as *const dyn ProtocolExtensionBuilder as *const T)) }
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn __downcast_mut<T: ProtocolExtensionBuilder>(&mut self) -> Option<&mut T> {
|
||||
if self.__is::<T>() {
|
||||
unsafe { Some(&mut *(self as *mut dyn ProtocolExtensionBuilder as *mut T)) }
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Type-erased protocol extension builder.
|
||||
pub struct AnyProtocolExtensionBuilder(Box<dyn ProtocolExtensionBuilder>);
|
||||
|
||||
impl AnyProtocolExtensionBuilder {
|
||||
/// Create a new type-erased protocol extension builder.
|
||||
pub fn new<T: ProtocolExtensionBuilder>(extension: T) -> Self {
|
||||
Self(Box::new(extension))
|
||||
}
|
||||
|
||||
/// Downcast the protocol extension builder.
|
||||
pub fn downcast<T: ProtocolExtensionBuilder>(self) -> Result<Box<T>, Self> {
|
||||
self.0.__downcast().map_err(Self)
|
||||
}
|
||||
|
||||
/// Downcast the protocol extension builder.
|
||||
pub fn downcast_ref<T: ProtocolExtensionBuilder>(&self) -> Option<&T> {
|
||||
self.0.__downcast_ref()
|
||||
}
|
||||
|
||||
/// Downcast the protocol extension builder.
|
||||
pub fn downcast_mut<T: ProtocolExtensionBuilder>(&mut self) -> Option<&mut T> {
|
||||
self.0.__downcast_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for AnyProtocolExtensionBuilder {
|
||||
type Target = dyn ProtocolExtensionBuilder;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0.deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for AnyProtocolExtensionBuilder {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.0.deref_mut()
|
||||
}
|
||||
}
|
||||
|
|
104
wisp/src/lib.rs
104
wisp/src/lib.rs
|
@ -19,13 +19,17 @@ pub mod ws;
|
|||
|
||||
pub use crate::{packet::*, stream::*};
|
||||
|
||||
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
|
||||
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder};
|
||||
use flume as mpsc;
|
||||
use futures::{channel::oneshot, Future};
|
||||
use inner::{MuxInner, WsEvent};
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
use std::{
|
||||
ops::DerefMut,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use ws::{AppendingWebSocketRead, LockedWebSocketWrite};
|
||||
|
||||
|
@ -173,7 +177,7 @@ async fn maybe_wisp_v2<R>(
|
|||
read: &mut R,
|
||||
write: &LockedWebSocketWrite,
|
||||
role: Role,
|
||||
builders: &mut [Box<dyn ProtocolExtensionBuilder + Sync + Send>],
|
||||
builders: &mut [AnyProtocolExtensionBuilder],
|
||||
) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame<'static>>, bool), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead + Send,
|
||||
|
@ -205,7 +209,7 @@ where
|
|||
|
||||
async fn send_info_packet(
|
||||
write: &LockedWebSocketWrite,
|
||||
builders: &mut [Box<dyn ProtocolExtensionBuilder + Sync + Send>],
|
||||
builders: &mut [AnyProtocolExtensionBuilder],
|
||||
) -> Result<(), WispError> {
|
||||
write
|
||||
.write_frame(
|
||||
|
@ -220,6 +224,42 @@ async fn send_info_packet(
|
|||
.await
|
||||
}
|
||||
|
||||
/// Wisp V2 handshake and protocol extension settings wrapper struct.
|
||||
pub struct WispV2Extensions {
|
||||
builders: Vec<AnyProtocolExtensionBuilder>,
|
||||
closure: Box<
|
||||
dyn Fn(
|
||||
&mut [AnyProtocolExtensionBuilder],
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
|
||||
+ Send,
|
||||
>,
|
||||
}
|
||||
|
||||
impl WispV2Extensions {
|
||||
/// Create a Wisp V2 settings struct with no middleware.
|
||||
pub fn new(builders: Vec<AnyProtocolExtensionBuilder>) -> Self {
|
||||
Self {
|
||||
builders,
|
||||
closure: Box::new(|_| Box::pin(async { Ok(()) })),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a Wisp V2 settings struct with some middleware.
|
||||
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
|
||||
where
|
||||
C: Fn(
|
||||
&mut [AnyProtocolExtensionBuilder],
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
|
||||
+ Send
|
||||
+ 'static,
|
||||
{
|
||||
Self {
|
||||
builders,
|
||||
closure: Box::new(closure),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Server-side multiplexor.
|
||||
pub struct ServerMux {
|
||||
/// Whether the connection was downgraded to Wisp v1.
|
||||
|
@ -237,14 +277,14 @@ pub struct ServerMux {
|
|||
impl ServerMux {
|
||||
/// Create a new server-side multiplexor.
|
||||
///
|
||||
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||
/// if the extensions you need are available after the multiplexor has been created.
|
||||
pub async fn create<R, W>(
|
||||
mut rx: R,
|
||||
tx: W,
|
||||
buffer_size: u32,
|
||||
extension_builders: Option<Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>>>,
|
||||
wisp_v2: Option<WispV2Extensions>,
|
||||
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||
where
|
||||
R: ws::WebSocketRead + Send,
|
||||
|
@ -256,13 +296,17 @@ impl ServerMux {
|
|||
tx.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||
.await?;
|
||||
|
||||
let (supported_extensions, extra_packet, downgraded) =
|
||||
if let Some(mut builders) = extension_builders {
|
||||
send_info_packet(&tx, &mut builders).await?;
|
||||
maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await?
|
||||
} else {
|
||||
(Vec::new(), None, true)
|
||||
};
|
||||
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
|
||||
mut builders,
|
||||
closure,
|
||||
}) = wisp_v2
|
||||
{
|
||||
send_info_packet(&tx, builders.deref_mut()).await?;
|
||||
(closure)(builders.deref_mut()).await?;
|
||||
maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await?
|
||||
} else {
|
||||
(Vec::new(), None, true)
|
||||
};
|
||||
|
||||
let (mux_result, muxstream_recv) = MuxInner::new_server(
|
||||
AppendingWebSocketRead(extra_packet, rx),
|
||||
|
@ -424,13 +468,13 @@ pub struct ClientMux {
|
|||
impl ClientMux {
|
||||
/// Create a new client side multiplexor.
|
||||
///
|
||||
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||
/// if the extensions you need are available after the multiplexor has been created.
|
||||
pub async fn create<R, W>(
|
||||
mut rx: R,
|
||||
tx: W,
|
||||
extension_builders: Option<Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>>>,
|
||||
wisp_v2: Option<WispV2Extensions>,
|
||||
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||
where
|
||||
R: ws::WebSocketRead + Send,
|
||||
|
@ -444,17 +488,21 @@ impl ClientMux {
|
|||
}
|
||||
|
||||
if let PacketType::Continue(packet) = first_packet.packet_type {
|
||||
let (supported_extensions, extra_packet, downgraded) =
|
||||
if let Some(mut builders) = extension_builders {
|
||||
let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?;
|
||||
// if not downgraded
|
||||
if !res.2 {
|
||||
send_info_packet(&tx, &mut builders).await?;
|
||||
}
|
||||
res
|
||||
} else {
|
||||
(Vec::new(), None, true)
|
||||
};
|
||||
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
|
||||
mut builders,
|
||||
closure,
|
||||
}) = wisp_v2
|
||||
{
|
||||
let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?;
|
||||
// if not downgraded
|
||||
if !res.2 {
|
||||
(closure)(&mut builders).await?;
|
||||
send_info_packet(&tx, &mut builders).await?;
|
||||
}
|
||||
res
|
||||
} else {
|
||||
(Vec::new(), None, true)
|
||||
};
|
||||
|
||||
let mux_result = MuxInner::new_client(
|
||||
AppendingWebSocketRead(extra_packet, rx),
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
extensions::{AnyProtocolExtension, ProtocolExtensionBuilder},
|
||||
extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder},
|
||||
ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
|
||||
Role, WispError, WISP_VERSION,
|
||||
};
|
||||
|
@ -431,10 +431,57 @@ impl<'a> Packet<'a> {
|
|||
})
|
||||
}
|
||||
|
||||
fn parse_info(
|
||||
mut bytes: Payload<'a>,
|
||||
role: Role,
|
||||
extension_builders: &mut [AnyProtocolExtensionBuilder],
|
||||
) -> Result<Self, WispError> {
|
||||
// packet type is already read by code that calls this
|
||||
if bytes.remaining() < 4 + 2 {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
if bytes.get_u32_le() != 0 {
|
||||
return Err(WispError::InvalidStreamId);
|
||||
}
|
||||
|
||||
let version = WispVersion {
|
||||
major: bytes.get_u8(),
|
||||
minor: bytes.get_u8(),
|
||||
};
|
||||
|
||||
if version.major != WISP_VERSION.major {
|
||||
return Err(WispError::IncompatibleProtocolVersion);
|
||||
}
|
||||
|
||||
let mut extensions = Vec::new();
|
||||
|
||||
while bytes.remaining() > 4 {
|
||||
// We have some extensions
|
||||
let id = bytes.get_u8();
|
||||
let length = usize::try_from(bytes.get_u32_le())?;
|
||||
if bytes.remaining() < length {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
if let Some(builder) = extension_builders.iter_mut().find(|x| x.get_id() == id) {
|
||||
extensions.push(builder.build_from_bytes(bytes.copy_to_bytes(length), role)?)
|
||||
} else {
|
||||
bytes.advance(length)
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
stream_id: 0,
|
||||
packet_type: PacketType::Info(InfoPacket {
|
||||
version,
|
||||
extensions,
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn maybe_parse_info(
|
||||
frame: Frame<'a>,
|
||||
role: Role,
|
||||
extension_builders: &mut [Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
||||
extension_builders: &mut [AnyProtocolExtensionBuilder],
|
||||
) -> Result<Self, WispError> {
|
||||
if !frame.finished {
|
||||
return Err(WispError::WsFrameNotFinished);
|
||||
|
@ -504,53 +551,6 @@ impl<'a> Packet<'a> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_info(
|
||||
mut bytes: Payload<'a>,
|
||||
role: Role,
|
||||
extension_builders: &mut [Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
||||
) -> Result<Self, WispError> {
|
||||
// packet type is already read by code that calls this
|
||||
if bytes.remaining() < 4 + 2 {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
if bytes.get_u32_le() != 0 {
|
||||
return Err(WispError::InvalidStreamId);
|
||||
}
|
||||
|
||||
let version = WispVersion {
|
||||
major: bytes.get_u8(),
|
||||
minor: bytes.get_u8(),
|
||||
};
|
||||
|
||||
if version.major != WISP_VERSION.major {
|
||||
return Err(WispError::IncompatibleProtocolVersion);
|
||||
}
|
||||
|
||||
let mut extensions = Vec::new();
|
||||
|
||||
while bytes.remaining() > 4 {
|
||||
// We have some extensions
|
||||
let id = bytes.get_u8();
|
||||
let length = usize::try_from(bytes.get_u32_le())?;
|
||||
if bytes.remaining() < length {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
if let Some(builder) = extension_builders.iter_mut().find(|x| x.get_id() == id) {
|
||||
extensions.push(builder.build_from_bytes(bytes.copy_to_bytes(length), role)?)
|
||||
} else {
|
||||
bytes.advance(length)
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
stream_id: 0,
|
||||
packet_type: PacketType::Info(InfoPacket {
|
||||
version,
|
||||
extensions,
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode for Packet<'_> {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue