enforce UdpProtocolExtension if requested

This commit is contained in:
Toshit Chawda 2024-04-13 20:32:21 -07:00
parent 397fd43dc5
commit 4d433b60c4
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
2 changed files with 11 additions and 2 deletions

View file

@ -270,6 +270,8 @@ async fn accept_ws(
let (mut mux, fut) = let (mut mux, fut) =
ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await?; ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await?;
println!("{:?}: downgraded: {} extensions supported: {:?}", addr, mux.downgraded, mux.supported_extension_ids);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = fut.await { if let Err(e) = fut.await {
println!("err in mux: {:?}", e); println!("err in mux: {:?}", e);

View file

@ -27,7 +27,7 @@ use tokio::{
}; };
use tokio_native_tls::{native_tls, TlsConnector}; use tokio_native_tls::{native_tls, TlsConnector};
use tokio_util::either::Either; use tokio_util::either::Either;
use wisp_mux::{extensions::udp::UdpProtocolExtensionBuilder, ClientMux, StreamType, WispError}; use wisp_mux::{extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, ClientMux, StreamType, WispError};
#[derive(Debug)] #[derive(Debug)]
enum WispClientError { enum WispClientError {
@ -134,11 +134,18 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let rx = FragmentCollectorRead::new(rx); let rx = FragmentCollectorRead::new(rx);
let (mut mux, fut) = if opts.udp { let (mut mux, fut) = if opts.udp {
ClientMux::new(rx, tx, Some(&[&UdpProtocolExtensionBuilder()])).await? let (mux, fut) = ClientMux::new(rx, tx, Some(&[&UdpProtocolExtensionBuilder()])).await?;
if !mux.supported_extension_ids.iter().any(|x| *x == UdpProtocolExtension::ID) {
println!("server did not support udp, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids);
exit(1);
}
(mux, fut)
} else { } else {
ClientMux::new(rx, tx, Some(&[])).await? ClientMux::new(rx, tx, Some(&[])).await?
}; };
println!("connected and created ClientMux, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids);
let mut threads = Vec::with_capacity(opts.streams * 2 + 3); let mut threads = Vec::with_capacity(opts.streams * 2 + 3);
threads.push(tokio::spawn(fut)); threads.push(tokio::spawn(fut));