add middleware to wispv2 handshake

This commit is contained in:
r58Playz 2024-09-16 23:18:32 -07:00
parent d6f1a8da43
commit 7fdacb2623
6 changed files with 254 additions and 114 deletions

View file

@ -15,9 +15,9 @@ use pin_project_lite::pin_project;
use wasm_bindgen_futures::spawn_local;
use webpki_roots::TLS_SERVER_ROOTS;
use wisp_mux::{
extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder},
extensions::{udp::UdpProtocolExtensionBuilder, AnyProtocolExtensionBuilder},
ws::{WebSocketRead, WebSocketWrite},
ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType,
ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, WispV2Extensions,
};
use crate::{
@ -106,10 +106,12 @@ impl StreamProvider {
&self,
mut locked: MutexGuard<'_, Option<ClientMux>>,
) -> Result<(), EpoxyError> {
let extensions_vec: Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>> =
vec![Box::new(UdpProtocolExtensionBuilder)];
let extensions_vec: Vec<AnyProtocolExtensionBuilder> =
vec![AnyProtocolExtensionBuilder::new(
UdpProtocolExtensionBuilder,
)];
let extensions = if self.wisp_v2 {
Some(extensions_vec)
Some(WispV2Extensions::new(extensions_vec))
} else {
None
};

View file

@ -6,12 +6,15 @@ use lazy_static::lazy_static;
use log::LevelFilter;
use regex::RegexSet;
use serde::{Deserialize, Serialize};
use wisp_mux::extensions::{
use wisp_mux::{
extensions::{
cert::{CertAuthProtocolExtension, CertAuthProtocolExtensionBuilder},
motd::MotdProtocolExtensionBuilder,
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
udp::UdpProtocolExtensionBuilder,
ProtocolExtensionBuilder,
AnyProtocolExtensionBuilder,
},
WispV2Extensions,
};
use crate::{handle::wisp::utils::get_certificates_from_paths, CLI, CONFIG, RESOLVER};
@ -195,8 +198,6 @@ pub struct Config {
pub stream: StreamConfig,
}
type AnyProtocolExtensionBuilder = Box<dyn ProtocolExtensionBuilder + Sync + Send>;
struct ConfigCache {
pub blocked_ports: Vec<RangeInclusive<u16>>,
pub allowed_ports: Vec<RangeInclusive<u16>>,
@ -293,41 +294,49 @@ impl Default for WispConfig {
}
impl WispConfig {
pub async fn to_opts(
&self,
) -> anyhow::Result<(Option<Vec<AnyProtocolExtensionBuilder>>, Vec<u8>, u32)> {
pub async fn to_opts(&self) -> anyhow::Result<(Option<WispV2Extensions>, Vec<u8>, u32)> {
if self.wisp_v2 {
let mut extensions: Vec<AnyProtocolExtensionBuilder> = Vec::new();
let mut required_extensions: Vec<u8> = Vec::new();
if self.extensions.contains(&ProtocolExtension::Udp) {
extensions.push(Box::new(UdpProtocolExtensionBuilder));
extensions.push(AnyProtocolExtensionBuilder::new(
UdpProtocolExtensionBuilder,
));
}
if self.extensions.contains(&ProtocolExtension::Motd) {
extensions.push(Box::new(MotdProtocolExtensionBuilder::Server(
self.motd_extension.clone(),
)));
extensions.push(AnyProtocolExtensionBuilder::new(
MotdProtocolExtensionBuilder::Server(self.motd_extension.clone()),
));
}
match self.auth_extension {
Some(ProtocolExtensionAuth::Password) => {
extensions.push(Box::new(PasswordProtocolExtensionBuilder::new_server(
extensions.push(AnyProtocolExtensionBuilder::new(
PasswordProtocolExtensionBuilder::new_server(
self.password_extension_users.clone(),
)));
),
));
required_extensions.push(PasswordProtocolExtension::ID);
}
Some(ProtocolExtensionAuth::Certificate) => {
extensions.push(Box::new(CertAuthProtocolExtensionBuilder::new_server(
extensions.push(AnyProtocolExtensionBuilder::new(
CertAuthProtocolExtensionBuilder::new_server(
get_certificates_from_paths(self.certificate_extension_keys.clone())
.await?,
)));
),
));
required_extensions.push(CertAuthProtocolExtension::ID);
}
None => {}
}
Ok((Some(extensions), required_extensions, self.buffer_size))
Ok((
Some(WispV2Extensions::new(extensions)),
required_extensions,
self.buffer_size,
))
} else {
Ok((None, Vec::new(), self.buffer_size))
}

View file

@ -36,9 +36,9 @@ use wisp_mux::{
motd::{MotdProtocolExtension, MotdProtocolExtensionBuilder},
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder},
ProtocolExtensionBuilder,
AnyProtocolExtensionBuilder,
},
ClientMux, StreamType, WispError,
ClientMux, StreamType, WispError, WispV2Extensions,
};
#[derive(Debug)]
@ -169,23 +169,27 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
(Cursor::new(parts.read_buf).chain(r), w)
});
let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new();
let mut extensions: Vec<AnyProtocolExtensionBuilder> = Vec::new();
let mut extension_ids: Vec<u8> = Vec::new();
if opts.udp {
extensions.push(Box::new(UdpProtocolExtensionBuilder));
extensions.push(AnyProtocolExtensionBuilder::new(
UdpProtocolExtensionBuilder,
));
extension_ids.push(UdpProtocolExtension::ID);
}
if opts.motd {
extensions.push(Box::new(MotdProtocolExtensionBuilder::Client));
extensions.push(AnyProtocolExtensionBuilder::new(
MotdProtocolExtensionBuilder::Client,
));
}
if let Some(auth) = auth {
extensions.push(Box::new(auth));
extensions.push(AnyProtocolExtensionBuilder::new(auth));
extension_ids.push(PasswordProtocolExtension::ID);
}
if let Some(certauth) = opts.certauth {
let key = get_cert(certauth).await?;
let extension = CertAuthProtocolExtensionBuilder::new_client(key);
extensions.push(Box::new(extension));
extensions.push(AnyProtocolExtensionBuilder::new(extension));
extension_ids.push(CertAuthProtocolExtension::ID);
}
@ -194,7 +198,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
.await?
.with_no_required_extensions()
} else {
ClientMux::create(rx, tx, Some(extensions))
ClientMux::create(rx, tx, Some(WispV2Extensions::new(extensions)))
.await?
.with_required_extensions(extension_ids.as_slice())
.await?

View file

@ -155,7 +155,7 @@ impl dyn ProtocolExtension {
}
/// Trait to build a Wisp protocol extension from a payload.
pub trait ProtocolExtensionBuilder {
pub trait ProtocolExtensionBuilder: Sync + Send + 'static {
/// Get the protocol extension ID.
///
/// Used to decide whether this builder should be used.
@ -170,4 +170,81 @@ pub trait ProtocolExtensionBuilder {
/// Build a protocol extension to send to the other side.
fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError>;
/// Do not override.
fn __internal_type_id(&self) -> TypeId {
TypeId::of::<Self>()
}
}
impl dyn ProtocolExtensionBuilder {
fn __is<T: ProtocolExtensionBuilder>(&self) -> bool {
let t = TypeId::of::<T>();
self.__internal_type_id() == t
}
fn __downcast<T: ProtocolExtensionBuilder>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
if self.__is::<T>() {
unsafe {
let raw: *mut dyn ProtocolExtensionBuilder = Box::into_raw(self);
Ok(Box::from_raw(raw as *mut T))
}
} else {
Err(self)
}
}
fn __downcast_ref<T: ProtocolExtensionBuilder>(&self) -> Option<&T> {
if self.__is::<T>() {
unsafe { Some(&*(self as *const dyn ProtocolExtensionBuilder as *const T)) }
} else {
None
}
}
fn __downcast_mut<T: ProtocolExtensionBuilder>(&mut self) -> Option<&mut T> {
if self.__is::<T>() {
unsafe { Some(&mut *(self as *mut dyn ProtocolExtensionBuilder as *mut T)) }
} else {
None
}
}
}
/// Type-erased protocol extension builder.
pub struct AnyProtocolExtensionBuilder(Box<dyn ProtocolExtensionBuilder>);
impl AnyProtocolExtensionBuilder {
/// Create a new type-erased protocol extension builder.
pub fn new<T: ProtocolExtensionBuilder>(extension: T) -> Self {
Self(Box::new(extension))
}
/// Downcast the protocol extension builder.
pub fn downcast<T: ProtocolExtensionBuilder>(self) -> Result<Box<T>, Self> {
self.0.__downcast().map_err(Self)
}
/// Downcast the protocol extension builder.
pub fn downcast_ref<T: ProtocolExtensionBuilder>(&self) -> Option<&T> {
self.0.__downcast_ref()
}
/// Downcast the protocol extension builder.
pub fn downcast_mut<T: ProtocolExtensionBuilder>(&mut self) -> Option<&mut T> {
self.0.__downcast_mut()
}
}
impl Deref for AnyProtocolExtensionBuilder {
type Target = dyn ProtocolExtensionBuilder;
fn deref(&self) -> &Self::Target {
self.0.deref()
}
}
impl DerefMut for AnyProtocolExtensionBuilder {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.deref_mut()
}
}

View file

@ -19,13 +19,17 @@ pub mod ws;
pub use crate::{packet::*, stream::*};
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder};
use flume as mpsc;
use futures::{channel::oneshot, Future};
use inner::{MuxInner, WsEvent};
use std::sync::{
use std::{
ops::DerefMut,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use ws::{AppendingWebSocketRead, LockedWebSocketWrite};
@ -173,7 +177,7 @@ async fn maybe_wisp_v2<R>(
read: &mut R,
write: &LockedWebSocketWrite,
role: Role,
builders: &mut [Box<dyn ProtocolExtensionBuilder + Sync + Send>],
builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame<'static>>, bool), WispError>
where
R: ws::WebSocketRead + Send,
@ -205,7 +209,7 @@ where
async fn send_info_packet(
write: &LockedWebSocketWrite,
builders: &mut [Box<dyn ProtocolExtensionBuilder + Sync + Send>],
builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<(), WispError> {
write
.write_frame(
@ -220,6 +224,42 @@ async fn send_info_packet(
.await
}
/// Wisp V2 handshake and protocol extension settings wrapper struct.
pub struct WispV2Extensions {
builders: Vec<AnyProtocolExtensionBuilder>,
closure: Box<
dyn Fn(
&mut [AnyProtocolExtensionBuilder],
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send,
>,
}
impl WispV2Extensions {
/// Create a Wisp V2 settings struct with no middleware.
pub fn new(builders: Vec<AnyProtocolExtensionBuilder>) -> Self {
Self {
builders,
closure: Box::new(|_| Box::pin(async { Ok(()) })),
}
}
/// Create a Wisp V2 settings struct with some middleware.
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
where
C: Fn(
&mut [AnyProtocolExtensionBuilder],
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send
+ 'static,
{
Self {
builders,
closure: Box::new(closure),
}
}
}
/// Server-side multiplexor.
pub struct ServerMux {
/// Whether the connection was downgraded to Wisp v1.
@ -237,14 +277,14 @@ pub struct ServerMux {
impl ServerMux {
/// Create a new server-side multiplexor.
///
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created.
pub async fn create<R, W>(
mut rx: R,
tx: W,
buffer_size: u32,
extension_builders: Option<Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>>>,
wisp_v2: Option<WispV2Extensions>,
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
where
R: ws::WebSocketRead + Send,
@ -256,9 +296,13 @@ impl ServerMux {
tx.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
let (supported_extensions, extra_packet, downgraded) =
if let Some(mut builders) = extension_builders {
send_info_packet(&tx, &mut builders).await?;
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
mut builders,
closure,
}) = wisp_v2
{
send_info_packet(&tx, builders.deref_mut()).await?;
(closure)(builders.deref_mut()).await?;
maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await?
} else {
(Vec::new(), None, true)
@ -424,13 +468,13 @@ pub struct ClientMux {
impl ClientMux {
/// Create a new client side multiplexor.
///
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created.
pub async fn create<R, W>(
mut rx: R,
tx: W,
extension_builders: Option<Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>>>,
wisp_v2: Option<WispV2Extensions>,
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
where
R: ws::WebSocketRead + Send,
@ -444,11 +488,15 @@ impl ClientMux {
}
if let PacketType::Continue(packet) = first_packet.packet_type {
let (supported_extensions, extra_packet, downgraded) =
if let Some(mut builders) = extension_builders {
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
mut builders,
closure,
}) = wisp_v2
{
let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?;
// if not downgraded
if !res.2 {
(closure)(&mut builders).await?;
send_info_packet(&tx, &mut builders).await?;
}
res

View file

@ -1,5 +1,5 @@
use crate::{
extensions::{AnyProtocolExtension, ProtocolExtensionBuilder},
extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder},
ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
Role, WispError, WISP_VERSION,
};
@ -431,10 +431,57 @@ impl<'a> Packet<'a> {
})
}
fn parse_info(
mut bytes: Payload<'a>,
role: Role,
extension_builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<Self, WispError> {
// packet type is already read by code that calls this
if bytes.remaining() < 4 + 2 {
return Err(WispError::PacketTooSmall);
}
if bytes.get_u32_le() != 0 {
return Err(WispError::InvalidStreamId);
}
let version = WispVersion {
major: bytes.get_u8(),
minor: bytes.get_u8(),
};
if version.major != WISP_VERSION.major {
return Err(WispError::IncompatibleProtocolVersion);
}
let mut extensions = Vec::new();
while bytes.remaining() > 4 {
// We have some extensions
let id = bytes.get_u8();
let length = usize::try_from(bytes.get_u32_le())?;
if bytes.remaining() < length {
return Err(WispError::PacketTooSmall);
}
if let Some(builder) = extension_builders.iter_mut().find(|x| x.get_id() == id) {
extensions.push(builder.build_from_bytes(bytes.copy_to_bytes(length), role)?)
} else {
bytes.advance(length)
}
}
Ok(Self {
stream_id: 0,
packet_type: PacketType::Info(InfoPacket {
version,
extensions,
}),
})
}
pub(crate) fn maybe_parse_info(
frame: Frame<'a>,
role: Role,
extension_builders: &mut [Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
extension_builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<Self, WispError> {
if !frame.finished {
return Err(WispError::WsFrameNotFinished);
@ -504,53 +551,6 @@ impl<'a> Packet<'a> {
}
}
}
fn parse_info(
mut bytes: Payload<'a>,
role: Role,
extension_builders: &mut [Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
) -> Result<Self, WispError> {
// packet type is already read by code that calls this
if bytes.remaining() < 4 + 2 {
return Err(WispError::PacketTooSmall);
}
if bytes.get_u32_le() != 0 {
return Err(WispError::InvalidStreamId);
}
let version = WispVersion {
major: bytes.get_u8(),
minor: bytes.get_u8(),
};
if version.major != WISP_VERSION.major {
return Err(WispError::IncompatibleProtocolVersion);
}
let mut extensions = Vec::new();
while bytes.remaining() > 4 {
// We have some extensions
let id = bytes.get_u8();
let length = usize::try_from(bytes.get_u32_le())?;
if bytes.remaining() < length {
return Err(WispError::PacketTooSmall);
}
if let Some(builder) = extension_builders.iter_mut().find(|x| x.get_id() == id) {
extensions.push(builder.build_from_bytes(bytes.copy_to_bytes(length), role)?)
} else {
bytes.advance(length)
}
}
Ok(Self {
stream_id: 0,
packet_type: PacketType::Info(InfoPacket {
version,
extensions,
}),
})
}
}
impl Encode for Packet<'_> {