make requiring protocol extensions easy

This commit is contained in:
Toshit Chawda 2024-04-20 18:38:38 -07:00
parent 063b527914
commit 01d7ac5002
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
6 changed files with 143 additions and 90 deletions

View file

@ -156,53 +156,33 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let rx = FragmentCollectorRead::new(rx);
let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new();
let mut extension_ids: Vec<u8> = Vec::new();
if opts.udp {
extensions.push(Box::new(UdpProtocolExtensionBuilder()));
extension_ids.push(UdpProtocolExtension::ID);
}
let enforce_auth = auth.is_some();
if let Some(auth) = auth {
extensions.push(Box::new(auth));
extension_ids.push(PasswordProtocolExtension::ID);
}
let (mux, fut) = if opts.wisp_v1 {
ClientMux::new(rx, tx, None).await?
ClientMux::create(rx, tx, None)
.await?
.with_no_required_extensions()
} else {
ClientMux::new(rx, tx, Some(extensions.as_slice())).await?
ClientMux::create(rx, tx, Some(extensions.as_slice()))
.await?
.with_required_extensions(extension_ids.as_slice()).await?
};
if opts.udp
&& !mux
.supported_extension_ids
.iter()
.any(|x| *x == UdpProtocolExtension::ID)
{
println!(
"server did not support udp, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids
);
mux.close_extension_incompat().await?;
exit(1);
}
if enforce_auth
&& !mux
.supported_extension_ids
.iter()
.any(|x| *x == PasswordProtocolExtension::ID)
{
println!(
"server did not support passwords or password was incorrect, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids
);
mux.close_extension_incompat().await?;
exit(1);
}
println!(
"connected and created ClientMux, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids
);
let mut threads = Vec::with_capacity(opts.streams * 2 + 3);
let mut threads = Vec::with_capacity(opts.streams + 4);
let mut reads = Vec::with_capacity(opts.streams);
threads.push(tokio::spawn(fut));
@ -226,13 +206,15 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
#[allow(unreachable_code)]
Ok::<(), WispError>(())
}));
threads.push(tokio::spawn(async move {
loop {
cr.read().await;
}
}));
reads.push(cr);
}
threads.push(tokio::spawn(async move {
loop {
select_all(reads.iter().map(|x| Box::pin(x.read()))).await;
}
}));
let cnt_avg = cnt.clone();
threads.push(tokio::spawn(async move {
let mut interval = interval(Duration::from_millis(100));
@ -295,14 +277,16 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
mux.close().await?;
println!(
"\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)",
cnt.get(),
opts.packet_size,
cnt.get() * opts.packet_size,
format_duration(duration_since),
(cnt.get() * opts.packet_size) as u64 / duration_since.as_secs(),
);
if duration_since.as_secs() != 0 {
println!(
"\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)",
cnt.get(),
opts.packet_size,
cnt.get() * opts.packet_size,
format_duration(duration_since),
(cnt.get() * opts.packet_size) as u64 / duration_since.as_secs(),
);
}
Ok(())
}