enforce UdpProtocolExtension if requested

This commit is contained in:
Toshit Chawda 2024-04-13 20:32:21 -07:00
parent 397fd43dc5
commit 4d433b60c4
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
2 changed files with 11 additions and 2 deletions

View file

@ -270,6 +270,8 @@ async fn accept_ws(
let (mut mux, fut) =
ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await?;
println!("{:?}: downgraded: {} extensions supported: {:?}", addr, mux.downgraded, mux.supported_extension_ids);
tokio::spawn(async move {
if let Err(e) = fut.await {
println!("err in mux: {:?}", e);

View file

@ -27,7 +27,7 @@ use tokio::{
};
use tokio_native_tls::{native_tls, TlsConnector};
use tokio_util::either::Either;
use wisp_mux::{extensions::udp::UdpProtocolExtensionBuilder, ClientMux, StreamType, WispError};
use wisp_mux::{extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, ClientMux, StreamType, WispError};
#[derive(Debug)]
enum WispClientError {
@ -134,11 +134,18 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let rx = FragmentCollectorRead::new(rx);
let (mut mux, fut) = if opts.udp {
ClientMux::new(rx, tx, Some(&[&UdpProtocolExtensionBuilder()])).await?
let (mux, fut) = ClientMux::new(rx, tx, Some(&[&UdpProtocolExtensionBuilder()])).await?;
if !mux.supported_extension_ids.iter().any(|x| *x == UdpProtocolExtension::ID) {
println!("server did not support udp, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids);
exit(1);
}
(mux, fut)
} else {
ClientMux::new(rx, tx, Some(&[])).await?
};
println!("connected and created ClientMux, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids);
let mut threads = Vec::with_capacity(opts.streams * 2 + 3);
threads.push(tokio::spawn(fut));