WispV2Extensions -> WispV2Handshake and motd helpers

This commit is contained in:
Toshit Chawda 2024-10-25 18:23:16 -07:00
parent 36fddc8943
commit 41f2139eb1
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
7 changed files with 41 additions and 20 deletions

View file

@ -17,7 +17,7 @@ use webpki_roots::TLS_SERVER_ROOTS;
use wisp_mux::{ use wisp_mux::{
extensions::{udp::UdpProtocolExtensionBuilder, AnyProtocolExtensionBuilder}, extensions::{udp::UdpProtocolExtensionBuilder, AnyProtocolExtensionBuilder},
ws::{WebSocketRead, WebSocketWrite}, ws::{WebSocketRead, WebSocketWrite},
ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, WispV2Extensions, ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, WispV2Handshake,
}; };
use crate::{ use crate::{
@ -119,7 +119,7 @@ impl StreamProvider {
UdpProtocolExtensionBuilder, UdpProtocolExtensionBuilder,
)]; )];
let extensions = if self.wisp_v2 { let extensions = if self.wisp_v2 {
Some(WispV2Extensions::new(extensions_vec)) Some(WispV2Handshake::new(extensions_vec))
} else { } else {
None None
}; };

View file

@ -14,7 +14,7 @@ use wisp_mux::{
udp::UdpProtocolExtensionBuilder, udp::UdpProtocolExtensionBuilder,
AnyProtocolExtensionBuilder, AnyProtocolExtensionBuilder,
}, },
WispV2Extensions, WispV2Handshake,
}; };
use crate::{handle::wisp::utils::get_certificates_from_paths, CLI, CONFIG, RESOLVER}; use crate::{handle::wisp::utils::get_certificates_from_paths, CLI, CONFIG, RESOLVER};
@ -346,7 +346,7 @@ impl Default for WispConfig {
} }
impl WispConfig { impl WispConfig {
pub async fn to_opts(&self) -> anyhow::Result<(Option<WispV2Extensions>, Vec<u8>, u32)> { pub async fn to_opts(&self) -> anyhow::Result<(Option<WispV2Handshake>, Vec<u8>, u32)> {
if self.wisp_v2 { if self.wisp_v2 {
let mut extensions: Vec<AnyProtocolExtensionBuilder> = Vec::new(); let mut extensions: Vec<AnyProtocolExtensionBuilder> = Vec::new();
let mut required_extensions: Vec<u8> = Vec::new(); let mut required_extensions: Vec<u8> = Vec::new();
@ -391,7 +391,7 @@ impl WispConfig {
} }
Ok(( Ok((
Some(WispV2Extensions::new(extensions)), Some(WispV2Handshake::new(extensions)),
required_extensions, required_extensions,
self.buffer_size, self.buffer_size,
)) ))

View file

@ -38,7 +38,7 @@ use wisp_mux::{
udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder},
AnyProtocolExtensionBuilder, AnyProtocolExtensionBuilder,
}, },
ClientMux, StreamType, WispError, WispV2Extensions, ClientMux, StreamType, WispError, WispV2Handshake,
}; };
#[derive(Debug)] #[derive(Debug)]
@ -198,7 +198,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
.await? .await?
.with_no_required_extensions() .with_no_required_extensions()
} else { } else {
ClientMux::create(rx, tx, Some(WispV2Extensions::new(extensions))) ClientMux::create(rx, tx, Some(WispV2Handshake::new(extensions)))
.await? .await?
.with_required_extensions(extension_ids.as_slice()) .with_required_extensions(extension_ids.as_slice())
.await? .await?

View file

@ -76,6 +76,18 @@ pub enum MotdProtocolExtensionBuilder {
Client, Client,
} }
impl MotdProtocolExtensionBuilder {
/// Create a new server variant of the MOTD protocol extension builder.
pub fn new_server(motd: String) -> Self {
Self::Server(motd)
}
/// Create a new client variant of the MOTD protocol extension builder.
pub fn new_client() -> Self {
Self::Client
}
}
impl ProtocolExtensionBuilder for MotdProtocolExtensionBuilder { impl ProtocolExtensionBuilder for MotdProtocolExtensionBuilder {
fn get_id(&self) -> u8 { fn get_id(&self) -> u8 {
MotdProtocolExtension::ID MotdProtocolExtension::ID

View file

@ -20,15 +20,15 @@ use crate::{
use super::{ use super::{
get_supported_extensions, validate_continue_packet, Multiplexor, MuxResult, get_supported_extensions, validate_continue_packet, Multiplexor, MuxResult,
WispHandshakeResult, WispHandshakeResultKind, WispV2Extensions, WispHandshakeResult, WispHandshakeResultKind, WispV2Handshake,
}; };
async fn handshake<R: WebSocketRead>( async fn handshake<R: WebSocketRead>(
rx: &mut R, rx: &mut R,
tx: &LockedWebSocketWrite, tx: &LockedWebSocketWrite,
v2_info: Option<WispV2Extensions>, v2_info: Option<WispV2Handshake>,
) -> Result<(WispHandshakeResult, u32), WispError> { ) -> Result<(WispHandshakeResult, u32), WispError> {
if let Some(WispV2Extensions { if let Some(WispV2Handshake {
mut builders, mut builders,
closure, closure,
}) = v2_info }) = v2_info
@ -100,7 +100,7 @@ impl ClientMux {
pub async fn create<R, W>( pub async fn create<R, W>(
mut rx: R, mut rx: R,
tx: W, tx: W,
wisp_v2: Option<WispV2Extensions>, wisp_v2: Option<WispV2Handshake>,
) -> Result<MuxResult<ClientMux, impl Future<Output = Result<(), WispError>> + Send>, WispError> ) -> Result<MuxResult<ClientMux, impl Future<Output = Result<(), WispError>> + Send>, WispError>
where where
R: WebSocketRead + Send, R: WebSocketRead + Send,

View file

@ -123,14 +123,19 @@ where
} }
} }
type WispV2ClosureResult = Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>;
/// Wisp V2 handshake and protocol extension settings wrapper struct. /// Wisp V2 handshake and protocol extension settings wrapper struct.
pub struct WispV2Extensions { pub struct WispV2Handshake {
builders: Vec<AnyProtocolExtensionBuilder>, builders: Vec<AnyProtocolExtensionBuilder>,
closure: Box<dyn Fn(&mut Vec<AnyProtocolExtensionBuilder>) -> WispV2ClosureResult + Send>, #[expect(clippy::type_complexity)]
closure: Box<
dyn Fn(
&mut Vec<AnyProtocolExtensionBuilder>,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send,
>,
} }
impl WispV2Extensions { impl WispV2Handshake {
/// Create a Wisp V2 settings struct with no middleware. /// Create a Wisp V2 settings struct with no middleware.
pub fn new(builders: Vec<AnyProtocolExtensionBuilder>) -> Self { pub fn new(builders: Vec<AnyProtocolExtensionBuilder>) -> Self {
Self { Self {
@ -142,7 +147,11 @@ impl WispV2Extensions {
/// Create a Wisp V2 settings struct with some middleware. /// Create a Wisp V2 settings struct with some middleware.
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
where where
C: Fn(&mut Vec<AnyProtocolExtensionBuilder>) -> WispV2ClosureResult + Send + 'static, C: Fn(
&mut Vec<AnyProtocolExtensionBuilder>,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send
+ 'static,
{ {
Self { Self {
builders, builders,

View file

@ -19,16 +19,16 @@ use crate::{
use super::{ use super::{
get_supported_extensions, send_info_packet, Multiplexor, MuxResult, WispHandshakeResult, get_supported_extensions, send_info_packet, Multiplexor, MuxResult, WispHandshakeResult,
WispHandshakeResultKind, WispV2Extensions, WispHandshakeResultKind, WispV2Handshake,
}; };
async fn handshake<R: WebSocketRead>( async fn handshake<R: WebSocketRead>(
rx: &mut R, rx: &mut R,
tx: &LockedWebSocketWrite, tx: &LockedWebSocketWrite,
buffer_size: u32, buffer_size: u32,
v2_info: Option<WispV2Extensions>, v2_info: Option<WispV2Handshake>,
) -> Result<WispHandshakeResult, WispError> { ) -> Result<WispHandshakeResult, WispError> {
if let Some(WispV2Extensions { if let Some(WispV2Handshake {
mut builders, mut builders,
closure, closure,
}) = v2_info }) = v2_info
@ -95,7 +95,7 @@ impl ServerMux {
mut rx: R, mut rx: R,
tx: W, tx: W,
buffer_size: u32, buffer_size: u32,
wisp_v2: Option<WispV2Extensions>, wisp_v2: Option<WispV2Handshake>,
) -> Result<MuxResult<ServerMux, impl Future<Output = Result<(), WispError>> + Send>, WispError> ) -> Result<MuxResult<ServerMux, impl Future<Output = Result<(), WispError>> + Send>, WispError>
where where
R: WebSocketRead + Send, R: WebSocketRead + Send,