add middleware to wispv2 handshake

This commit is contained in:
r58Playz 2024-09-16 23:18:32 -07:00
parent d6f1a8da43
commit 7fdacb2623
6 changed files with 254 additions and 114 deletions

View file

@ -15,9 +15,9 @@ use pin_project_lite::pin_project;
use wasm_bindgen_futures::spawn_local; use wasm_bindgen_futures::spawn_local;
use webpki_roots::TLS_SERVER_ROOTS; use webpki_roots::TLS_SERVER_ROOTS;
use wisp_mux::{ use wisp_mux::{
extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder}, extensions::{udp::UdpProtocolExtensionBuilder, AnyProtocolExtensionBuilder},
ws::{WebSocketRead, WebSocketWrite}, ws::{WebSocketRead, WebSocketWrite},
ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, WispV2Extensions,
}; };
use crate::{ use crate::{
@ -106,10 +106,12 @@ impl StreamProvider {
&self, &self,
mut locked: MutexGuard<'_, Option<ClientMux>>, mut locked: MutexGuard<'_, Option<ClientMux>>,
) -> Result<(), EpoxyError> { ) -> Result<(), EpoxyError> {
let extensions_vec: Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>> = let extensions_vec: Vec<AnyProtocolExtensionBuilder> =
vec![Box::new(UdpProtocolExtensionBuilder)]; vec![AnyProtocolExtensionBuilder::new(
UdpProtocolExtensionBuilder,
)];
let extensions = if self.wisp_v2 { let extensions = if self.wisp_v2 {
Some(extensions_vec) Some(WispV2Extensions::new(extensions_vec))
} else { } else {
None None
}; };

View file

@ -6,12 +6,15 @@ use lazy_static::lazy_static;
use log::LevelFilter; use log::LevelFilter;
use regex::RegexSet; use regex::RegexSet;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use wisp_mux::extensions::{ use wisp_mux::{
cert::{CertAuthProtocolExtension, CertAuthProtocolExtensionBuilder}, extensions::{
motd::MotdProtocolExtensionBuilder, cert::{CertAuthProtocolExtension, CertAuthProtocolExtensionBuilder},
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, motd::MotdProtocolExtensionBuilder,
udp::UdpProtocolExtensionBuilder, password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
ProtocolExtensionBuilder, udp::UdpProtocolExtensionBuilder,
AnyProtocolExtensionBuilder,
},
WispV2Extensions,
}; };
use crate::{handle::wisp::utils::get_certificates_from_paths, CLI, CONFIG, RESOLVER}; use crate::{handle::wisp::utils::get_certificates_from_paths, CLI, CONFIG, RESOLVER};
@ -195,8 +198,6 @@ pub struct Config {
pub stream: StreamConfig, pub stream: StreamConfig,
} }
type AnyProtocolExtensionBuilder = Box<dyn ProtocolExtensionBuilder + Sync + Send>;
struct ConfigCache { struct ConfigCache {
pub blocked_ports: Vec<RangeInclusive<u16>>, pub blocked_ports: Vec<RangeInclusive<u16>>,
pub allowed_ports: Vec<RangeInclusive<u16>>, pub allowed_ports: Vec<RangeInclusive<u16>>,
@ -293,41 +294,49 @@ impl Default for WispConfig {
} }
impl WispConfig { impl WispConfig {
pub async fn to_opts( pub async fn to_opts(&self) -> anyhow::Result<(Option<WispV2Extensions>, Vec<u8>, u32)> {
&self,
) -> anyhow::Result<(Option<Vec<AnyProtocolExtensionBuilder>>, Vec<u8>, u32)> {
if self.wisp_v2 { if self.wisp_v2 {
let mut extensions: Vec<AnyProtocolExtensionBuilder> = Vec::new(); let mut extensions: Vec<AnyProtocolExtensionBuilder> = Vec::new();
let mut required_extensions: Vec<u8> = Vec::new(); let mut required_extensions: Vec<u8> = Vec::new();
if self.extensions.contains(&ProtocolExtension::Udp) { if self.extensions.contains(&ProtocolExtension::Udp) {
extensions.push(Box::new(UdpProtocolExtensionBuilder)); extensions.push(AnyProtocolExtensionBuilder::new(
UdpProtocolExtensionBuilder,
));
} }
if self.extensions.contains(&ProtocolExtension::Motd) { if self.extensions.contains(&ProtocolExtension::Motd) {
extensions.push(Box::new(MotdProtocolExtensionBuilder::Server( extensions.push(AnyProtocolExtensionBuilder::new(
self.motd_extension.clone(), MotdProtocolExtensionBuilder::Server(self.motd_extension.clone()),
))); ));
} }
match self.auth_extension { match self.auth_extension {
Some(ProtocolExtensionAuth::Password) => { Some(ProtocolExtensionAuth::Password) => {
extensions.push(Box::new(PasswordProtocolExtensionBuilder::new_server( extensions.push(AnyProtocolExtensionBuilder::new(
self.password_extension_users.clone(), PasswordProtocolExtensionBuilder::new_server(
))); self.password_extension_users.clone(),
),
));
required_extensions.push(PasswordProtocolExtension::ID); required_extensions.push(PasswordProtocolExtension::ID);
} }
Some(ProtocolExtensionAuth::Certificate) => { Some(ProtocolExtensionAuth::Certificate) => {
extensions.push(Box::new(CertAuthProtocolExtensionBuilder::new_server( extensions.push(AnyProtocolExtensionBuilder::new(
get_certificates_from_paths(self.certificate_extension_keys.clone()) CertAuthProtocolExtensionBuilder::new_server(
.await?, get_certificates_from_paths(self.certificate_extension_keys.clone())
))); .await?,
),
));
required_extensions.push(CertAuthProtocolExtension::ID); required_extensions.push(CertAuthProtocolExtension::ID);
} }
None => {} None => {}
} }
Ok((Some(extensions), required_extensions, self.buffer_size)) Ok((
Some(WispV2Extensions::new(extensions)),
required_extensions,
self.buffer_size,
))
} else { } else {
Ok((None, Vec::new(), self.buffer_size)) Ok((None, Vec::new(), self.buffer_size))
} }

View file

@ -36,9 +36,9 @@ use wisp_mux::{
motd::{MotdProtocolExtension, MotdProtocolExtensionBuilder}, motd::{MotdProtocolExtension, MotdProtocolExtensionBuilder},
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder},
ProtocolExtensionBuilder, AnyProtocolExtensionBuilder,
}, },
ClientMux, StreamType, WispError, ClientMux, StreamType, WispError, WispV2Extensions,
}; };
#[derive(Debug)] #[derive(Debug)]
@ -169,23 +169,27 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
(Cursor::new(parts.read_buf).chain(r), w) (Cursor::new(parts.read_buf).chain(r), w)
}); });
let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new(); let mut extensions: Vec<AnyProtocolExtensionBuilder> = Vec::new();
let mut extension_ids: Vec<u8> = Vec::new(); let mut extension_ids: Vec<u8> = Vec::new();
if opts.udp { if opts.udp {
extensions.push(Box::new(UdpProtocolExtensionBuilder)); extensions.push(AnyProtocolExtensionBuilder::new(
UdpProtocolExtensionBuilder,
));
extension_ids.push(UdpProtocolExtension::ID); extension_ids.push(UdpProtocolExtension::ID);
} }
if opts.motd { if opts.motd {
extensions.push(Box::new(MotdProtocolExtensionBuilder::Client)); extensions.push(AnyProtocolExtensionBuilder::new(
MotdProtocolExtensionBuilder::Client,
));
} }
if let Some(auth) = auth { if let Some(auth) = auth {
extensions.push(Box::new(auth)); extensions.push(AnyProtocolExtensionBuilder::new(auth));
extension_ids.push(PasswordProtocolExtension::ID); extension_ids.push(PasswordProtocolExtension::ID);
} }
if let Some(certauth) = opts.certauth { if let Some(certauth) = opts.certauth {
let key = get_cert(certauth).await?; let key = get_cert(certauth).await?;
let extension = CertAuthProtocolExtensionBuilder::new_client(key); let extension = CertAuthProtocolExtensionBuilder::new_client(key);
extensions.push(Box::new(extension)); extensions.push(AnyProtocolExtensionBuilder::new(extension));
extension_ids.push(CertAuthProtocolExtension::ID); extension_ids.push(CertAuthProtocolExtension::ID);
} }
@ -194,7 +198,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
.await? .await?
.with_no_required_extensions() .with_no_required_extensions()
} else { } else {
ClientMux::create(rx, tx, Some(extensions)) ClientMux::create(rx, tx, Some(WispV2Extensions::new(extensions)))
.await? .await?
.with_required_extensions(extension_ids.as_slice()) .with_required_extensions(extension_ids.as_slice())
.await? .await?

View file

@ -155,7 +155,7 @@ impl dyn ProtocolExtension {
} }
/// Trait to build a Wisp protocol extension from a payload. /// Trait to build a Wisp protocol extension from a payload.
pub trait ProtocolExtensionBuilder { pub trait ProtocolExtensionBuilder: Sync + Send + 'static {
/// Get the protocol extension ID. /// Get the protocol extension ID.
/// ///
/// Used to decide whether this builder should be used. /// Used to decide whether this builder should be used.
@ -170,4 +170,81 @@ pub trait ProtocolExtensionBuilder {
/// Build a protocol extension to send to the other side. /// Build a protocol extension to send to the other side.
fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError>; fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError>;
/// Do not override.
fn __internal_type_id(&self) -> TypeId {
TypeId::of::<Self>()
}
}
impl dyn ProtocolExtensionBuilder {
fn __is<T: ProtocolExtensionBuilder>(&self) -> bool {
let t = TypeId::of::<T>();
self.__internal_type_id() == t
}
fn __downcast<T: ProtocolExtensionBuilder>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
if self.__is::<T>() {
unsafe {
let raw: *mut dyn ProtocolExtensionBuilder = Box::into_raw(self);
Ok(Box::from_raw(raw as *mut T))
}
} else {
Err(self)
}
}
fn __downcast_ref<T: ProtocolExtensionBuilder>(&self) -> Option<&T> {
if self.__is::<T>() {
unsafe { Some(&*(self as *const dyn ProtocolExtensionBuilder as *const T)) }
} else {
None
}
}
fn __downcast_mut<T: ProtocolExtensionBuilder>(&mut self) -> Option<&mut T> {
if self.__is::<T>() {
unsafe { Some(&mut *(self as *mut dyn ProtocolExtensionBuilder as *mut T)) }
} else {
None
}
}
}
/// Type-erased protocol extension builder.
pub struct AnyProtocolExtensionBuilder(Box<dyn ProtocolExtensionBuilder>);
impl AnyProtocolExtensionBuilder {
/// Create a new type-erased protocol extension builder.
pub fn new<T: ProtocolExtensionBuilder>(extension: T) -> Self {
Self(Box::new(extension))
}
/// Downcast the protocol extension builder.
pub fn downcast<T: ProtocolExtensionBuilder>(self) -> Result<Box<T>, Self> {
self.0.__downcast().map_err(Self)
}
/// Downcast the protocol extension builder.
pub fn downcast_ref<T: ProtocolExtensionBuilder>(&self) -> Option<&T> {
self.0.__downcast_ref()
}
/// Downcast the protocol extension builder.
pub fn downcast_mut<T: ProtocolExtensionBuilder>(&mut self) -> Option<&mut T> {
self.0.__downcast_mut()
}
}
impl Deref for AnyProtocolExtensionBuilder {
type Target = dyn ProtocolExtensionBuilder;
fn deref(&self) -> &Self::Target {
self.0.deref()
}
}
impl DerefMut for AnyProtocolExtensionBuilder {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.deref_mut()
}
} }

View file

@ -19,13 +19,17 @@ pub mod ws;
pub use crate::{packet::*, stream::*}; pub use crate::{packet::*, stream::*};
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder}; use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder};
use flume as mpsc; use flume as mpsc;
use futures::{channel::oneshot, Future}; use futures::{channel::oneshot, Future};
use inner::{MuxInner, WsEvent}; use inner::{MuxInner, WsEvent};
use std::sync::{ use std::{
atomic::{AtomicBool, Ordering}, ops::DerefMut,
Arc, pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
}; };
use ws::{AppendingWebSocketRead, LockedWebSocketWrite}; use ws::{AppendingWebSocketRead, LockedWebSocketWrite};
@ -173,7 +177,7 @@ async fn maybe_wisp_v2<R>(
read: &mut R, read: &mut R,
write: &LockedWebSocketWrite, write: &LockedWebSocketWrite,
role: Role, role: Role,
builders: &mut [Box<dyn ProtocolExtensionBuilder + Sync + Send>], builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame<'static>>, bool), WispError> ) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame<'static>>, bool), WispError>
where where
R: ws::WebSocketRead + Send, R: ws::WebSocketRead + Send,
@ -205,7 +209,7 @@ where
async fn send_info_packet( async fn send_info_packet(
write: &LockedWebSocketWrite, write: &LockedWebSocketWrite,
builders: &mut [Box<dyn ProtocolExtensionBuilder + Sync + Send>], builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<(), WispError> { ) -> Result<(), WispError> {
write write
.write_frame( .write_frame(
@ -220,6 +224,42 @@ async fn send_info_packet(
.await .await
} }
/// Wisp V2 handshake and protocol extension settings wrapper struct.
pub struct WispV2Extensions {
builders: Vec<AnyProtocolExtensionBuilder>,
closure: Box<
dyn Fn(
&mut [AnyProtocolExtensionBuilder],
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send,
>,
}
impl WispV2Extensions {
/// Create a Wisp V2 settings struct with no middleware.
pub fn new(builders: Vec<AnyProtocolExtensionBuilder>) -> Self {
Self {
builders,
closure: Box::new(|_| Box::pin(async { Ok(()) })),
}
}
/// Create a Wisp V2 settings struct with some middleware.
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
where
C: Fn(
&mut [AnyProtocolExtensionBuilder],
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send
+ 'static,
{
Self {
builders,
closure: Box::new(closure),
}
}
}
/// Server-side multiplexor. /// Server-side multiplexor.
pub struct ServerMux { pub struct ServerMux {
/// Whether the connection was downgraded to Wisp v1. /// Whether the connection was downgraded to Wisp v1.
@ -237,14 +277,14 @@ pub struct ServerMux {
impl ServerMux { impl ServerMux {
/// Create a new server-side multiplexor. /// Create a new server-side multiplexor.
/// ///
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. /// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check /// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created. /// if the extensions you need are available after the multiplexor has been created.
pub async fn create<R, W>( pub async fn create<R, W>(
mut rx: R, mut rx: R,
tx: W, tx: W,
buffer_size: u32, buffer_size: u32,
extension_builders: Option<Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>>>, wisp_v2: Option<WispV2Extensions>,
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError> ) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
where where
R: ws::WebSocketRead + Send, R: ws::WebSocketRead + Send,
@ -256,13 +296,17 @@ impl ServerMux {
tx.write_frame(Packet::new_continue(0, buffer_size).into()) tx.write_frame(Packet::new_continue(0, buffer_size).into())
.await?; .await?;
let (supported_extensions, extra_packet, downgraded) = let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
if let Some(mut builders) = extension_builders { mut builders,
send_info_packet(&tx, &mut builders).await?; closure,
maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await? }) = wisp_v2
} else { {
(Vec::new(), None, true) send_info_packet(&tx, builders.deref_mut()).await?;
}; (closure)(builders.deref_mut()).await?;
maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await?
} else {
(Vec::new(), None, true)
};
let (mux_result, muxstream_recv) = MuxInner::new_server( let (mux_result, muxstream_recv) = MuxInner::new_server(
AppendingWebSocketRead(extra_packet, rx), AppendingWebSocketRead(extra_packet, rx),
@ -424,13 +468,13 @@ pub struct ClientMux {
impl ClientMux { impl ClientMux {
/// Create a new client side multiplexor. /// Create a new client side multiplexor.
/// ///
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. /// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check /// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created. /// if the extensions you need are available after the multiplexor has been created.
pub async fn create<R, W>( pub async fn create<R, W>(
mut rx: R, mut rx: R,
tx: W, tx: W,
extension_builders: Option<Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>>>, wisp_v2: Option<WispV2Extensions>,
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError> ) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
where where
R: ws::WebSocketRead + Send, R: ws::WebSocketRead + Send,
@ -444,17 +488,21 @@ impl ClientMux {
} }
if let PacketType::Continue(packet) = first_packet.packet_type { if let PacketType::Continue(packet) = first_packet.packet_type {
let (supported_extensions, extra_packet, downgraded) = let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
if let Some(mut builders) = extension_builders { mut builders,
let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?; closure,
// if not downgraded }) = wisp_v2
if !res.2 { {
send_info_packet(&tx, &mut builders).await?; let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?;
} // if not downgraded
res if !res.2 {
} else { (closure)(&mut builders).await?;
(Vec::new(), None, true) send_info_packet(&tx, &mut builders).await?;
}; }
res
} else {
(Vec::new(), None, true)
};
let mux_result = MuxInner::new_client( let mux_result = MuxInner::new_client(
AppendingWebSocketRead(extra_packet, rx), AppendingWebSocketRead(extra_packet, rx),

View file

@ -1,5 +1,5 @@
use crate::{ use crate::{
extensions::{AnyProtocolExtension, ProtocolExtensionBuilder}, extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder},
ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead}, ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
Role, WispError, WISP_VERSION, Role, WispError, WISP_VERSION,
}; };
@ -431,10 +431,57 @@ impl<'a> Packet<'a> {
}) })
} }
fn parse_info(
mut bytes: Payload<'a>,
role: Role,
extension_builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<Self, WispError> {
// packet type is already read by code that calls this
if bytes.remaining() < 4 + 2 {
return Err(WispError::PacketTooSmall);
}
if bytes.get_u32_le() != 0 {
return Err(WispError::InvalidStreamId);
}
let version = WispVersion {
major: bytes.get_u8(),
minor: bytes.get_u8(),
};
if version.major != WISP_VERSION.major {
return Err(WispError::IncompatibleProtocolVersion);
}
let mut extensions = Vec::new();
while bytes.remaining() > 4 {
// We have some extensions
let id = bytes.get_u8();
let length = usize::try_from(bytes.get_u32_le())?;
if bytes.remaining() < length {
return Err(WispError::PacketTooSmall);
}
if let Some(builder) = extension_builders.iter_mut().find(|x| x.get_id() == id) {
extensions.push(builder.build_from_bytes(bytes.copy_to_bytes(length), role)?)
} else {
bytes.advance(length)
}
}
Ok(Self {
stream_id: 0,
packet_type: PacketType::Info(InfoPacket {
version,
extensions,
}),
})
}
pub(crate) fn maybe_parse_info( pub(crate) fn maybe_parse_info(
frame: Frame<'a>, frame: Frame<'a>,
role: Role, role: Role,
extension_builders: &mut [Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], extension_builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<Self, WispError> { ) -> Result<Self, WispError> {
if !frame.finished { if !frame.finished {
return Err(WispError::WsFrameNotFinished); return Err(WispError::WsFrameNotFinished);
@ -504,53 +551,6 @@ impl<'a> Packet<'a> {
} }
} }
} }
fn parse_info(
mut bytes: Payload<'a>,
role: Role,
extension_builders: &mut [Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
) -> Result<Self, WispError> {
// packet type is already read by code that calls this
if bytes.remaining() < 4 + 2 {
return Err(WispError::PacketTooSmall);
}
if bytes.get_u32_le() != 0 {
return Err(WispError::InvalidStreamId);
}
let version = WispVersion {
major: bytes.get_u8(),
minor: bytes.get_u8(),
};
if version.major != WISP_VERSION.major {
return Err(WispError::IncompatibleProtocolVersion);
}
let mut extensions = Vec::new();
while bytes.remaining() > 4 {
// We have some extensions
let id = bytes.get_u8();
let length = usize::try_from(bytes.get_u32_le())?;
if bytes.remaining() < length {
return Err(WispError::PacketTooSmall);
}
if let Some(builder) = extension_builders.iter_mut().find(|x| x.get_id() == id) {
extensions.push(builder.build_from_bytes(bytes.copy_to_bytes(length), role)?)
} else {
bytes.advance(length)
}
}
Ok(Self {
stream_id: 0,
packet_type: PacketType::Info(InfoPacket {
version,
extensions,
}),
})
}
} }
impl Encode for Packet<'_> { impl Encode for Packet<'_> {