mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 14:00:01 -04:00
add middleware to wispv2 handshake
This commit is contained in:
parent
d6f1a8da43
commit
7fdacb2623
6 changed files with 254 additions and 114 deletions
|
@ -15,9 +15,9 @@ use pin_project_lite::pin_project;
|
||||||
use wasm_bindgen_futures::spawn_local;
|
use wasm_bindgen_futures::spawn_local;
|
||||||
use webpki_roots::TLS_SERVER_ROOTS;
|
use webpki_roots::TLS_SERVER_ROOTS;
|
||||||
use wisp_mux::{
|
use wisp_mux::{
|
||||||
extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder},
|
extensions::{udp::UdpProtocolExtensionBuilder, AnyProtocolExtensionBuilder},
|
||||||
ws::{WebSocketRead, WebSocketWrite},
|
ws::{WebSocketRead, WebSocketWrite},
|
||||||
ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType,
|
ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, WispV2Extensions,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -106,10 +106,12 @@ impl StreamProvider {
|
||||||
&self,
|
&self,
|
||||||
mut locked: MutexGuard<'_, Option<ClientMux>>,
|
mut locked: MutexGuard<'_, Option<ClientMux>>,
|
||||||
) -> Result<(), EpoxyError> {
|
) -> Result<(), EpoxyError> {
|
||||||
let extensions_vec: Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>> =
|
let extensions_vec: Vec<AnyProtocolExtensionBuilder> =
|
||||||
vec![Box::new(UdpProtocolExtensionBuilder)];
|
vec![AnyProtocolExtensionBuilder::new(
|
||||||
|
UdpProtocolExtensionBuilder,
|
||||||
|
)];
|
||||||
let extensions = if self.wisp_v2 {
|
let extensions = if self.wisp_v2 {
|
||||||
Some(extensions_vec)
|
Some(WispV2Extensions::new(extensions_vec))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
|
@ -6,12 +6,15 @@ use lazy_static::lazy_static;
|
||||||
use log::LevelFilter;
|
use log::LevelFilter;
|
||||||
use regex::RegexSet;
|
use regex::RegexSet;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use wisp_mux::extensions::{
|
use wisp_mux::{
|
||||||
cert::{CertAuthProtocolExtension, CertAuthProtocolExtensionBuilder},
|
extensions::{
|
||||||
motd::MotdProtocolExtensionBuilder,
|
cert::{CertAuthProtocolExtension, CertAuthProtocolExtensionBuilder},
|
||||||
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
|
motd::MotdProtocolExtensionBuilder,
|
||||||
udp::UdpProtocolExtensionBuilder,
|
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
|
||||||
ProtocolExtensionBuilder,
|
udp::UdpProtocolExtensionBuilder,
|
||||||
|
AnyProtocolExtensionBuilder,
|
||||||
|
},
|
||||||
|
WispV2Extensions,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{handle::wisp::utils::get_certificates_from_paths, CLI, CONFIG, RESOLVER};
|
use crate::{handle::wisp::utils::get_certificates_from_paths, CLI, CONFIG, RESOLVER};
|
||||||
|
@ -195,8 +198,6 @@ pub struct Config {
|
||||||
pub stream: StreamConfig,
|
pub stream: StreamConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
type AnyProtocolExtensionBuilder = Box<dyn ProtocolExtensionBuilder + Sync + Send>;
|
|
||||||
|
|
||||||
struct ConfigCache {
|
struct ConfigCache {
|
||||||
pub blocked_ports: Vec<RangeInclusive<u16>>,
|
pub blocked_ports: Vec<RangeInclusive<u16>>,
|
||||||
pub allowed_ports: Vec<RangeInclusive<u16>>,
|
pub allowed_ports: Vec<RangeInclusive<u16>>,
|
||||||
|
@ -293,41 +294,49 @@ impl Default for WispConfig {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WispConfig {
|
impl WispConfig {
|
||||||
pub async fn to_opts(
|
pub async fn to_opts(&self) -> anyhow::Result<(Option<WispV2Extensions>, Vec<u8>, u32)> {
|
||||||
&self,
|
|
||||||
) -> anyhow::Result<(Option<Vec<AnyProtocolExtensionBuilder>>, Vec<u8>, u32)> {
|
|
||||||
if self.wisp_v2 {
|
if self.wisp_v2 {
|
||||||
let mut extensions: Vec<AnyProtocolExtensionBuilder> = Vec::new();
|
let mut extensions: Vec<AnyProtocolExtensionBuilder> = Vec::new();
|
||||||
let mut required_extensions: Vec<u8> = Vec::new();
|
let mut required_extensions: Vec<u8> = Vec::new();
|
||||||
|
|
||||||
if self.extensions.contains(&ProtocolExtension::Udp) {
|
if self.extensions.contains(&ProtocolExtension::Udp) {
|
||||||
extensions.push(Box::new(UdpProtocolExtensionBuilder));
|
extensions.push(AnyProtocolExtensionBuilder::new(
|
||||||
|
UdpProtocolExtensionBuilder,
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.extensions.contains(&ProtocolExtension::Motd) {
|
if self.extensions.contains(&ProtocolExtension::Motd) {
|
||||||
extensions.push(Box::new(MotdProtocolExtensionBuilder::Server(
|
extensions.push(AnyProtocolExtensionBuilder::new(
|
||||||
self.motd_extension.clone(),
|
MotdProtocolExtensionBuilder::Server(self.motd_extension.clone()),
|
||||||
)));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
match self.auth_extension {
|
match self.auth_extension {
|
||||||
Some(ProtocolExtensionAuth::Password) => {
|
Some(ProtocolExtensionAuth::Password) => {
|
||||||
extensions.push(Box::new(PasswordProtocolExtensionBuilder::new_server(
|
extensions.push(AnyProtocolExtensionBuilder::new(
|
||||||
self.password_extension_users.clone(),
|
PasswordProtocolExtensionBuilder::new_server(
|
||||||
)));
|
self.password_extension_users.clone(),
|
||||||
|
),
|
||||||
|
));
|
||||||
required_extensions.push(PasswordProtocolExtension::ID);
|
required_extensions.push(PasswordProtocolExtension::ID);
|
||||||
}
|
}
|
||||||
Some(ProtocolExtensionAuth::Certificate) => {
|
Some(ProtocolExtensionAuth::Certificate) => {
|
||||||
extensions.push(Box::new(CertAuthProtocolExtensionBuilder::new_server(
|
extensions.push(AnyProtocolExtensionBuilder::new(
|
||||||
get_certificates_from_paths(self.certificate_extension_keys.clone())
|
CertAuthProtocolExtensionBuilder::new_server(
|
||||||
.await?,
|
get_certificates_from_paths(self.certificate_extension_keys.clone())
|
||||||
)));
|
.await?,
|
||||||
|
),
|
||||||
|
));
|
||||||
required_extensions.push(CertAuthProtocolExtension::ID);
|
required_extensions.push(CertAuthProtocolExtension::ID);
|
||||||
}
|
}
|
||||||
None => {}
|
None => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok((Some(extensions), required_extensions, self.buffer_size))
|
Ok((
|
||||||
|
Some(WispV2Extensions::new(extensions)),
|
||||||
|
required_extensions,
|
||||||
|
self.buffer_size,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
Ok((None, Vec::new(), self.buffer_size))
|
Ok((None, Vec::new(), self.buffer_size))
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,9 +36,9 @@ use wisp_mux::{
|
||||||
motd::{MotdProtocolExtension, MotdProtocolExtensionBuilder},
|
motd::{MotdProtocolExtension, MotdProtocolExtensionBuilder},
|
||||||
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
|
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
|
||||||
udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder},
|
udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder},
|
||||||
ProtocolExtensionBuilder,
|
AnyProtocolExtensionBuilder,
|
||||||
},
|
},
|
||||||
ClientMux, StreamType, WispError,
|
ClientMux, StreamType, WispError, WispV2Extensions,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -169,23 +169,27 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
(Cursor::new(parts.read_buf).chain(r), w)
|
(Cursor::new(parts.read_buf).chain(r), w)
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new();
|
let mut extensions: Vec<AnyProtocolExtensionBuilder> = Vec::new();
|
||||||
let mut extension_ids: Vec<u8> = Vec::new();
|
let mut extension_ids: Vec<u8> = Vec::new();
|
||||||
if opts.udp {
|
if opts.udp {
|
||||||
extensions.push(Box::new(UdpProtocolExtensionBuilder));
|
extensions.push(AnyProtocolExtensionBuilder::new(
|
||||||
|
UdpProtocolExtensionBuilder,
|
||||||
|
));
|
||||||
extension_ids.push(UdpProtocolExtension::ID);
|
extension_ids.push(UdpProtocolExtension::ID);
|
||||||
}
|
}
|
||||||
if opts.motd {
|
if opts.motd {
|
||||||
extensions.push(Box::new(MotdProtocolExtensionBuilder::Client));
|
extensions.push(AnyProtocolExtensionBuilder::new(
|
||||||
|
MotdProtocolExtensionBuilder::Client,
|
||||||
|
));
|
||||||
}
|
}
|
||||||
if let Some(auth) = auth {
|
if let Some(auth) = auth {
|
||||||
extensions.push(Box::new(auth));
|
extensions.push(AnyProtocolExtensionBuilder::new(auth));
|
||||||
extension_ids.push(PasswordProtocolExtension::ID);
|
extension_ids.push(PasswordProtocolExtension::ID);
|
||||||
}
|
}
|
||||||
if let Some(certauth) = opts.certauth {
|
if let Some(certauth) = opts.certauth {
|
||||||
let key = get_cert(certauth).await?;
|
let key = get_cert(certauth).await?;
|
||||||
let extension = CertAuthProtocolExtensionBuilder::new_client(key);
|
let extension = CertAuthProtocolExtensionBuilder::new_client(key);
|
||||||
extensions.push(Box::new(extension));
|
extensions.push(AnyProtocolExtensionBuilder::new(extension));
|
||||||
extension_ids.push(CertAuthProtocolExtension::ID);
|
extension_ids.push(CertAuthProtocolExtension::ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -194,7 +198,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
.await?
|
.await?
|
||||||
.with_no_required_extensions()
|
.with_no_required_extensions()
|
||||||
} else {
|
} else {
|
||||||
ClientMux::create(rx, tx, Some(extensions))
|
ClientMux::create(rx, tx, Some(WispV2Extensions::new(extensions)))
|
||||||
.await?
|
.await?
|
||||||
.with_required_extensions(extension_ids.as_slice())
|
.with_required_extensions(extension_ids.as_slice())
|
||||||
.await?
|
.await?
|
||||||
|
|
|
@ -155,7 +155,7 @@ impl dyn ProtocolExtension {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait to build a Wisp protocol extension from a payload.
|
/// Trait to build a Wisp protocol extension from a payload.
|
||||||
pub trait ProtocolExtensionBuilder {
|
pub trait ProtocolExtensionBuilder: Sync + Send + 'static {
|
||||||
/// Get the protocol extension ID.
|
/// Get the protocol extension ID.
|
||||||
///
|
///
|
||||||
/// Used to decide whether this builder should be used.
|
/// Used to decide whether this builder should be used.
|
||||||
|
@ -170,4 +170,81 @@ pub trait ProtocolExtensionBuilder {
|
||||||
|
|
||||||
/// Build a protocol extension to send to the other side.
|
/// Build a protocol extension to send to the other side.
|
||||||
fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError>;
|
fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError>;
|
||||||
|
|
||||||
|
/// Do not override.
|
||||||
|
fn __internal_type_id(&self) -> TypeId {
|
||||||
|
TypeId::of::<Self>()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl dyn ProtocolExtensionBuilder {
|
||||||
|
fn __is<T: ProtocolExtensionBuilder>(&self) -> bool {
|
||||||
|
let t = TypeId::of::<T>();
|
||||||
|
self.__internal_type_id() == t
|
||||||
|
}
|
||||||
|
|
||||||
|
fn __downcast<T: ProtocolExtensionBuilder>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
|
||||||
|
if self.__is::<T>() {
|
||||||
|
unsafe {
|
||||||
|
let raw: *mut dyn ProtocolExtensionBuilder = Box::into_raw(self);
|
||||||
|
Ok(Box::from_raw(raw as *mut T))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn __downcast_ref<T: ProtocolExtensionBuilder>(&self) -> Option<&T> {
|
||||||
|
if self.__is::<T>() {
|
||||||
|
unsafe { Some(&*(self as *const dyn ProtocolExtensionBuilder as *const T)) }
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn __downcast_mut<T: ProtocolExtensionBuilder>(&mut self) -> Option<&mut T> {
|
||||||
|
if self.__is::<T>() {
|
||||||
|
unsafe { Some(&mut *(self as *mut dyn ProtocolExtensionBuilder as *mut T)) }
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Type-erased protocol extension builder.
|
||||||
|
pub struct AnyProtocolExtensionBuilder(Box<dyn ProtocolExtensionBuilder>);
|
||||||
|
|
||||||
|
impl AnyProtocolExtensionBuilder {
|
||||||
|
/// Create a new type-erased protocol extension builder.
|
||||||
|
pub fn new<T: ProtocolExtensionBuilder>(extension: T) -> Self {
|
||||||
|
Self(Box::new(extension))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Downcast the protocol extension builder.
|
||||||
|
pub fn downcast<T: ProtocolExtensionBuilder>(self) -> Result<Box<T>, Self> {
|
||||||
|
self.0.__downcast().map_err(Self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Downcast the protocol extension builder.
|
||||||
|
pub fn downcast_ref<T: ProtocolExtensionBuilder>(&self) -> Option<&T> {
|
||||||
|
self.0.__downcast_ref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Downcast the protocol extension builder.
|
||||||
|
pub fn downcast_mut<T: ProtocolExtensionBuilder>(&mut self) -> Option<&mut T> {
|
||||||
|
self.0.__downcast_mut()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Deref for AnyProtocolExtensionBuilder {
|
||||||
|
type Target = dyn ProtocolExtensionBuilder;
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
self.0.deref()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DerefMut for AnyProtocolExtensionBuilder {
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
self.0.deref_mut()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
104
wisp/src/lib.rs
104
wisp/src/lib.rs
|
@ -19,13 +19,17 @@ pub mod ws;
|
||||||
|
|
||||||
pub use crate::{packet::*, stream::*};
|
pub use crate::{packet::*, stream::*};
|
||||||
|
|
||||||
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
|
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder};
|
||||||
use flume as mpsc;
|
use flume as mpsc;
|
||||||
use futures::{channel::oneshot, Future};
|
use futures::{channel::oneshot, Future};
|
||||||
use inner::{MuxInner, WsEvent};
|
use inner::{MuxInner, WsEvent};
|
||||||
use std::sync::{
|
use std::{
|
||||||
atomic::{AtomicBool, Ordering},
|
ops::DerefMut,
|
||||||
Arc,
|
pin::Pin,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
use ws::{AppendingWebSocketRead, LockedWebSocketWrite};
|
use ws::{AppendingWebSocketRead, LockedWebSocketWrite};
|
||||||
|
|
||||||
|
@ -173,7 +177,7 @@ async fn maybe_wisp_v2<R>(
|
||||||
read: &mut R,
|
read: &mut R,
|
||||||
write: &LockedWebSocketWrite,
|
write: &LockedWebSocketWrite,
|
||||||
role: Role,
|
role: Role,
|
||||||
builders: &mut [Box<dyn ProtocolExtensionBuilder + Sync + Send>],
|
builders: &mut [AnyProtocolExtensionBuilder],
|
||||||
) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame<'static>>, bool), WispError>
|
) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame<'static>>, bool), WispError>
|
||||||
where
|
where
|
||||||
R: ws::WebSocketRead + Send,
|
R: ws::WebSocketRead + Send,
|
||||||
|
@ -205,7 +209,7 @@ where
|
||||||
|
|
||||||
async fn send_info_packet(
|
async fn send_info_packet(
|
||||||
write: &LockedWebSocketWrite,
|
write: &LockedWebSocketWrite,
|
||||||
builders: &mut [Box<dyn ProtocolExtensionBuilder + Sync + Send>],
|
builders: &mut [AnyProtocolExtensionBuilder],
|
||||||
) -> Result<(), WispError> {
|
) -> Result<(), WispError> {
|
||||||
write
|
write
|
||||||
.write_frame(
|
.write_frame(
|
||||||
|
@ -220,6 +224,42 @@ async fn send_info_packet(
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Wisp V2 handshake and protocol extension settings wrapper struct.
|
||||||
|
pub struct WispV2Extensions {
|
||||||
|
builders: Vec<AnyProtocolExtensionBuilder>,
|
||||||
|
closure: Box<
|
||||||
|
dyn Fn(
|
||||||
|
&mut [AnyProtocolExtensionBuilder],
|
||||||
|
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
|
||||||
|
+ Send,
|
||||||
|
>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WispV2Extensions {
|
||||||
|
/// Create a Wisp V2 settings struct with no middleware.
|
||||||
|
pub fn new(builders: Vec<AnyProtocolExtensionBuilder>) -> Self {
|
||||||
|
Self {
|
||||||
|
builders,
|
||||||
|
closure: Box::new(|_| Box::pin(async { Ok(()) })),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a Wisp V2 settings struct with some middleware.
|
||||||
|
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
|
||||||
|
where
|
||||||
|
C: Fn(
|
||||||
|
&mut [AnyProtocolExtensionBuilder],
|
||||||
|
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
|
||||||
|
+ Send
|
||||||
|
+ 'static,
|
||||||
|
{
|
||||||
|
Self {
|
||||||
|
builders,
|
||||||
|
closure: Box::new(closure),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Server-side multiplexor.
|
/// Server-side multiplexor.
|
||||||
pub struct ServerMux {
|
pub struct ServerMux {
|
||||||
/// Whether the connection was downgraded to Wisp v1.
|
/// Whether the connection was downgraded to Wisp v1.
|
||||||
|
@ -237,14 +277,14 @@ pub struct ServerMux {
|
||||||
impl ServerMux {
|
impl ServerMux {
|
||||||
/// Create a new server-side multiplexor.
|
/// Create a new server-side multiplexor.
|
||||||
///
|
///
|
||||||
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||||
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||||
/// if the extensions you need are available after the multiplexor has been created.
|
/// if the extensions you need are available after the multiplexor has been created.
|
||||||
pub async fn create<R, W>(
|
pub async fn create<R, W>(
|
||||||
mut rx: R,
|
mut rx: R,
|
||||||
tx: W,
|
tx: W,
|
||||||
buffer_size: u32,
|
buffer_size: u32,
|
||||||
extension_builders: Option<Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>>>,
|
wisp_v2: Option<WispV2Extensions>,
|
||||||
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||||
where
|
where
|
||||||
R: ws::WebSocketRead + Send,
|
R: ws::WebSocketRead + Send,
|
||||||
|
@ -256,13 +296,17 @@ impl ServerMux {
|
||||||
tx.write_frame(Packet::new_continue(0, buffer_size).into())
|
tx.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let (supported_extensions, extra_packet, downgraded) =
|
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
|
||||||
if let Some(mut builders) = extension_builders {
|
mut builders,
|
||||||
send_info_packet(&tx, &mut builders).await?;
|
closure,
|
||||||
maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await?
|
}) = wisp_v2
|
||||||
} else {
|
{
|
||||||
(Vec::new(), None, true)
|
send_info_packet(&tx, builders.deref_mut()).await?;
|
||||||
};
|
(closure)(builders.deref_mut()).await?;
|
||||||
|
maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await?
|
||||||
|
} else {
|
||||||
|
(Vec::new(), None, true)
|
||||||
|
};
|
||||||
|
|
||||||
let (mux_result, muxstream_recv) = MuxInner::new_server(
|
let (mux_result, muxstream_recv) = MuxInner::new_server(
|
||||||
AppendingWebSocketRead(extra_packet, rx),
|
AppendingWebSocketRead(extra_packet, rx),
|
||||||
|
@ -424,13 +468,13 @@ pub struct ClientMux {
|
||||||
impl ClientMux {
|
impl ClientMux {
|
||||||
/// Create a new client side multiplexor.
|
/// Create a new client side multiplexor.
|
||||||
///
|
///
|
||||||
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||||
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||||
/// if the extensions you need are available after the multiplexor has been created.
|
/// if the extensions you need are available after the multiplexor has been created.
|
||||||
pub async fn create<R, W>(
|
pub async fn create<R, W>(
|
||||||
mut rx: R,
|
mut rx: R,
|
||||||
tx: W,
|
tx: W,
|
||||||
extension_builders: Option<Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>>>,
|
wisp_v2: Option<WispV2Extensions>,
|
||||||
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||||
where
|
where
|
||||||
R: ws::WebSocketRead + Send,
|
R: ws::WebSocketRead + Send,
|
||||||
|
@ -444,17 +488,21 @@ impl ClientMux {
|
||||||
}
|
}
|
||||||
|
|
||||||
if let PacketType::Continue(packet) = first_packet.packet_type {
|
if let PacketType::Continue(packet) = first_packet.packet_type {
|
||||||
let (supported_extensions, extra_packet, downgraded) =
|
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
|
||||||
if let Some(mut builders) = extension_builders {
|
mut builders,
|
||||||
let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?;
|
closure,
|
||||||
// if not downgraded
|
}) = wisp_v2
|
||||||
if !res.2 {
|
{
|
||||||
send_info_packet(&tx, &mut builders).await?;
|
let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?;
|
||||||
}
|
// if not downgraded
|
||||||
res
|
if !res.2 {
|
||||||
} else {
|
(closure)(&mut builders).await?;
|
||||||
(Vec::new(), None, true)
|
send_info_packet(&tx, &mut builders).await?;
|
||||||
};
|
}
|
||||||
|
res
|
||||||
|
} else {
|
||||||
|
(Vec::new(), None, true)
|
||||||
|
};
|
||||||
|
|
||||||
let mux_result = MuxInner::new_client(
|
let mux_result = MuxInner::new_client(
|
||||||
AppendingWebSocketRead(extra_packet, rx),
|
AppendingWebSocketRead(extra_packet, rx),
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
extensions::{AnyProtocolExtension, ProtocolExtensionBuilder},
|
extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder},
|
||||||
ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
|
ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
|
||||||
Role, WispError, WISP_VERSION,
|
Role, WispError, WISP_VERSION,
|
||||||
};
|
};
|
||||||
|
@ -431,10 +431,57 @@ impl<'a> Packet<'a> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn parse_info(
|
||||||
|
mut bytes: Payload<'a>,
|
||||||
|
role: Role,
|
||||||
|
extension_builders: &mut [AnyProtocolExtensionBuilder],
|
||||||
|
) -> Result<Self, WispError> {
|
||||||
|
// packet type is already read by code that calls this
|
||||||
|
if bytes.remaining() < 4 + 2 {
|
||||||
|
return Err(WispError::PacketTooSmall);
|
||||||
|
}
|
||||||
|
if bytes.get_u32_le() != 0 {
|
||||||
|
return Err(WispError::InvalidStreamId);
|
||||||
|
}
|
||||||
|
|
||||||
|
let version = WispVersion {
|
||||||
|
major: bytes.get_u8(),
|
||||||
|
minor: bytes.get_u8(),
|
||||||
|
};
|
||||||
|
|
||||||
|
if version.major != WISP_VERSION.major {
|
||||||
|
return Err(WispError::IncompatibleProtocolVersion);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut extensions = Vec::new();
|
||||||
|
|
||||||
|
while bytes.remaining() > 4 {
|
||||||
|
// We have some extensions
|
||||||
|
let id = bytes.get_u8();
|
||||||
|
let length = usize::try_from(bytes.get_u32_le())?;
|
||||||
|
if bytes.remaining() < length {
|
||||||
|
return Err(WispError::PacketTooSmall);
|
||||||
|
}
|
||||||
|
if let Some(builder) = extension_builders.iter_mut().find(|x| x.get_id() == id) {
|
||||||
|
extensions.push(builder.build_from_bytes(bytes.copy_to_bytes(length), role)?)
|
||||||
|
} else {
|
||||||
|
bytes.advance(length)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
stream_id: 0,
|
||||||
|
packet_type: PacketType::Info(InfoPacket {
|
||||||
|
version,
|
||||||
|
extensions,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn maybe_parse_info(
|
pub(crate) fn maybe_parse_info(
|
||||||
frame: Frame<'a>,
|
frame: Frame<'a>,
|
||||||
role: Role,
|
role: Role,
|
||||||
extension_builders: &mut [Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
extension_builders: &mut [AnyProtocolExtensionBuilder],
|
||||||
) -> Result<Self, WispError> {
|
) -> Result<Self, WispError> {
|
||||||
if !frame.finished {
|
if !frame.finished {
|
||||||
return Err(WispError::WsFrameNotFinished);
|
return Err(WispError::WsFrameNotFinished);
|
||||||
|
@ -504,53 +551,6 @@ impl<'a> Packet<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_info(
|
|
||||||
mut bytes: Payload<'a>,
|
|
||||||
role: Role,
|
|
||||||
extension_builders: &mut [Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
|
||||||
) -> Result<Self, WispError> {
|
|
||||||
// packet type is already read by code that calls this
|
|
||||||
if bytes.remaining() < 4 + 2 {
|
|
||||||
return Err(WispError::PacketTooSmall);
|
|
||||||
}
|
|
||||||
if bytes.get_u32_le() != 0 {
|
|
||||||
return Err(WispError::InvalidStreamId);
|
|
||||||
}
|
|
||||||
|
|
||||||
let version = WispVersion {
|
|
||||||
major: bytes.get_u8(),
|
|
||||||
minor: bytes.get_u8(),
|
|
||||||
};
|
|
||||||
|
|
||||||
if version.major != WISP_VERSION.major {
|
|
||||||
return Err(WispError::IncompatibleProtocolVersion);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut extensions = Vec::new();
|
|
||||||
|
|
||||||
while bytes.remaining() > 4 {
|
|
||||||
// We have some extensions
|
|
||||||
let id = bytes.get_u8();
|
|
||||||
let length = usize::try_from(bytes.get_u32_le())?;
|
|
||||||
if bytes.remaining() < length {
|
|
||||||
return Err(WispError::PacketTooSmall);
|
|
||||||
}
|
|
||||||
if let Some(builder) = extension_builders.iter_mut().find(|x| x.get_id() == id) {
|
|
||||||
extensions.push(builder.build_from_bytes(bytes.copy_to_bytes(length), role)?)
|
|
||||||
} else {
|
|
||||||
bytes.advance(length)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
stream_id: 0,
|
|
||||||
packet_type: PacketType::Info(InfoPacket {
|
|
||||||
version,
|
|
||||||
extensions,
|
|
||||||
}),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Encode for Packet<'_> {
|
impl Encode for Packet<'_> {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue