add middleware to wispv2 handshake

This commit is contained in:
r58Playz 2024-09-16 23:18:32 -07:00
parent d6f1a8da43
commit 7fdacb2623
6 changed files with 254 additions and 114 deletions

View file

@ -36,9 +36,9 @@ use wisp_mux::{
motd::{MotdProtocolExtension, MotdProtocolExtensionBuilder},
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder},
ProtocolExtensionBuilder,
AnyProtocolExtensionBuilder,
},
ClientMux, StreamType, WispError,
ClientMux, StreamType, WispError, WispV2Extensions,
};
#[derive(Debug)]
@ -169,23 +169,27 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
(Cursor::new(parts.read_buf).chain(r), w)
});
let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new();
let mut extensions: Vec<AnyProtocolExtensionBuilder> = Vec::new();
let mut extension_ids: Vec<u8> = Vec::new();
if opts.udp {
extensions.push(Box::new(UdpProtocolExtensionBuilder));
extensions.push(AnyProtocolExtensionBuilder::new(
UdpProtocolExtensionBuilder,
));
extension_ids.push(UdpProtocolExtension::ID);
}
if opts.motd {
extensions.push(Box::new(MotdProtocolExtensionBuilder::Client));
extensions.push(AnyProtocolExtensionBuilder::new(
MotdProtocolExtensionBuilder::Client,
));
}
if let Some(auth) = auth {
extensions.push(Box::new(auth));
extensions.push(AnyProtocolExtensionBuilder::new(auth));
extension_ids.push(PasswordProtocolExtension::ID);
}
if let Some(certauth) = opts.certauth {
let key = get_cert(certauth).await?;
let extension = CertAuthProtocolExtensionBuilder::new_client(key);
extensions.push(Box::new(extension));
extensions.push(AnyProtocolExtensionBuilder::new(extension));
extension_ids.push(CertAuthProtocolExtension::ID);
}
@ -194,7 +198,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
.await?
.with_no_required_extensions()
} else {
ClientMux::create(rx, tx, Some(extensions))
ClientMux::create(rx, tx, Some(WispV2Extensions::new(extensions)))
.await?
.with_required_extensions(extension_ids.as_slice())
.await?