fix password protocol extension, respect stream id 0 close packets, allow sending stream id 0 close packets

This commit is contained in:
Toshit Chawda 2024-04-13 22:34:26 -07:00
parent 4d433b60c4
commit d10b7691e4
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
6 changed files with 220 additions and 50 deletions

View file

@ -27,7 +27,14 @@ use tokio::{
};
use tokio_native_tls::{native_tls, TlsConnector};
use tokio_util::either::Either;
use wisp_mux::{extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, ClientMux, StreamType, WispError};
use wisp_mux::{
extensions::{
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder},
ProtocolExtensionBuilder,
},
ClientMux, StreamType, WispError,
};
#[derive(Debug)]
enum WispClientError {
@ -80,6 +87,11 @@ struct Cli {
/// Ask for UDP
#[arg(short, long)]
udp: bool,
/// Enable auth: format is `username:password`
///
/// Usernames and passwords are sent in plaintext!!
#[arg(long)]
auth: Option<String>,
}
#[tokio::main(flavor = "multi_thread")]
@ -103,6 +115,13 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let addr_dest = opts.tcp.ip().to_string();
let addr_dest_port = opts.tcp.port();
let auth = opts.auth.map(|auth| {
let split: Vec<_> = auth.split(':').collect();
let username = split[0].to_string();
let password = split[1..].join(":");
PasswordProtocolExtensionBuilder::new_client(username, password)
});
println!(
"connecting to {} and sending &[0; 1024 * {}] to {} with threads {}",
opts.wisp, opts.packet_size, opts.tcp, opts.streams,
@ -133,18 +152,49 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let (rx, tx) = ws.split(tokio::io::split);
let rx = FragmentCollectorRead::new(rx);
let (mut mux, fut) = if opts.udp {
let (mux, fut) = ClientMux::new(rx, tx, Some(&[&UdpProtocolExtensionBuilder()])).await?;
if !mux.supported_extension_ids.iter().any(|x| *x == UdpProtocolExtension::ID) {
println!("server did not support udp, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids);
exit(1);
}
(mux, fut)
} else {
ClientMux::new(rx, tx, Some(&[])).await?
};
let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Sync)>> = Vec::new();
if opts.udp {
extensions.push(Box::new(UdpProtocolExtensionBuilder()));
}
let enforce_auth = auth.is_some();
if let Some(auth) = auth {
extensions.push(Box::new(auth));
}
let extensions_mapped: Vec<&(dyn ProtocolExtensionBuilder + Sync)> =
extensions.iter().map(|x| x.as_ref()).collect();
println!("connected and created ClientMux, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids);
let (mut mux, fut) = ClientMux::new(rx, tx, Some(&extensions_mapped)).await?;
if opts.udp
&& !mux
.supported_extension_ids
.iter()
.any(|x| *x == UdpProtocolExtension::ID)
{
println!(
"server did not support udp, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids
);
mux.close_extension_incompat().await?;
exit(1);
}
if enforce_auth
&& !mux
.supported_extension_ids
.iter()
.any(|x| *x == PasswordProtocolExtension::ID)
{
println!(
"server did not support passwords or password was incorrect, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids
);
mux.close_extension_incompat().await?;
exit(1);
}
println!(
"connected and created ClientMux, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids
);
let mut threads = Vec::with_capacity(opts.streams * 2 + 3);