fix the dreaded simple wisp client freeze

This commit is contained in:
Toshit Chawda 2024-11-29 22:36:41 -08:00
parent 68c0198784
commit 583a78bcc0
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D

View file

@ -2,8 +2,8 @@ use atomic_counter::{AtomicCounter, RelaxedCounter};
use bytes::Bytes; use bytes::Bytes;
use clap::Parser; use clap::Parser;
use ed25519_dalek::pkcs8::DecodePrivateKey; use ed25519_dalek::pkcs8::DecodePrivateKey;
use fastwebsockets::handshake; use fastwebsockets::{handshake, WebSocketWrite};
use futures::future::select_all; use futures::{future::select_all, FutureExt, TryFutureExt};
use http_body_util::Empty; use http_body_util::Empty;
use humantime::format_duration; use humantime::format_duration;
use hyper::{ use hyper::{
@ -19,13 +19,14 @@ use std::{
io::{stdout, Cursor, IsTerminal, Write}, io::{stdout, Cursor, IsTerminal, Write},
net::SocketAddr, net::SocketAddr,
path::PathBuf, path::PathBuf,
pin::Pin,
process::{abort, exit}, process::{abort, exit},
sync::Arc, sync::Arc,
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use tokio::{ use tokio::{
io::AsyncReadExt, io::AsyncReadExt,
net::TcpStream, net::{tcp::OwnedWriteHalf, TcpStream},
select, select,
signal::unix::{signal, SignalKind}, signal::unix::{signal, SignalKind},
time::{interval, sleep}, time::{interval, sleep},
@ -127,11 +128,15 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
tokio::spawn(real_main()).await? tokio::spawn(real_main()).await?
} }
async fn real_main() -> Result<(), Box<dyn Error + Send + Sync>> { async fn create_mux(
#[cfg(feature = "tokio-console")] opts: &Cli,
console_subscriber::init(); ) -> Result<
let opts = Cli::parse(); (
ClientMux<WebSocketWrite<OwnedWriteHalf>>,
impl Future<Output = Result<(), WispError>> + Send,
),
Box<dyn Error + Send + Sync>,
> {
if opts.wisp.scheme_str().unwrap_or_default() != "ws" { if opts.wisp.scheme_str().unwrap_or_default() != "ws" {
Err(Box::new(WispClientError::InvalidUriScheme))?; Err(Box::new(WispClientError::InvalidUriScheme))?;
} }
@ -139,10 +144,8 @@ async fn real_main() -> Result<(), Box<dyn Error + Send + Sync>> {
let addr = opts.wisp.host().ok_or(WispClientError::UriHasNoHost)?; let addr = opts.wisp.host().ok_or(WispClientError::UriHasNoHost)?;
let addr_port = opts.wisp.port_u16().unwrap_or(80); let addr_port = opts.wisp.port_u16().unwrap_or(80);
let addr_path = opts.wisp.path(); let addr_path = opts.wisp.path();
let addr_dest = opts.tcp.ip().to_string();
let addr_dest_port = opts.tcp.port();
let auth = opts.auth.map(|auth| { let auth = opts.auth.as_ref().map(|auth| {
let split: Vec<_> = auth.split(':').collect(); let split: Vec<_> = auth.split(':').collect();
let username = split[0].to_string(); let username = split[0].to_string();
let password = split[1..].join(":"); let password = split[1..].join(":");
@ -193,24 +196,37 @@ async fn real_main() -> Result<(), Box<dyn Error + Send + Sync>> {
extensions.push(AnyProtocolExtensionBuilder::new(auth)); extensions.push(AnyProtocolExtensionBuilder::new(auth));
extension_ids.push(PasswordProtocolExtension::ID); extension_ids.push(PasswordProtocolExtension::ID);
} }
if let Some(certauth) = opts.certauth { if let Some(certauth) = &opts.certauth {
let key = get_cert(certauth).await?; let key = get_cert(certauth.clone()).await?;
let extension = CertAuthProtocolExtensionBuilder::new_client(Some(key)); let extension = CertAuthProtocolExtensionBuilder::new_client(Some(key));
extensions.push(AnyProtocolExtensionBuilder::new(extension)); extensions.push(AnyProtocolExtensionBuilder::new(extension));
extension_ids.push(CertAuthProtocolExtension::ID); extension_ids.push(CertAuthProtocolExtension::ID);
} }
let (mux, fut) = if !opts.wisp_v2 { let (mux, fut) = if opts.wisp_v2 {
ClientMux::create(rx, tx, None)
.await?
.with_no_required_extensions()
} else {
ClientMux::create(rx, tx, Some(WispV2Handshake::new(extensions))) ClientMux::create(rx, tx, Some(WispV2Handshake::new(extensions)))
.await? .await?
.with_required_extensions(extension_ids.as_slice()) .with_required_extensions(extension_ids.as_slice())
.await? .await?
} else {
ClientMux::create(rx, tx, None)
.await?
.with_no_required_extensions()
}; };
Ok((mux, fut))
}
#[allow(clippy::too_many_lines)]
async fn real_main() -> Result<(), Box<dyn Error + Send + Sync>> {
#[cfg(feature = "tokio-console")]
console_subscriber::init();
let opts = Cli::parse();
let addr_dest = opts.tcp.ip().to_string();
let addr_dest_port = opts.tcp.port();
let (mux, fut) = create_mux(&opts).await?;
let motd_extension = mux let motd_extension = mux
.supported_extensions .supported_extensions
.iter() .iter()
@ -228,7 +244,12 @@ async fn real_main() -> Result<(), Box<dyn Error + Send + Sync>> {
let mut threads = Vec::with_capacity((opts.streams * 2) + 3); let mut threads = Vec::with_capacity((opts.streams * 2) + 3);
threads.push(tokio::spawn(fut)); threads.push(Box::pin(
tokio::spawn(fut)
.map_err(|x| WispError::Other(Box::new(x)))
.map(|x| x.and_then(|x| x)),
)
as Pin<Box<dyn Future<Output = Result<(), WispError>> + Send>>);
let payload = vec![0; 1024 * opts.packet_size]; let payload = vec![0; 1024 * opts.packet_size];
@ -242,15 +263,14 @@ async fn real_main() -> Result<(), Box<dyn Error + Send + Sync>> {
.into_split(); .into_split();
let cnt = cnt.clone(); let cnt = cnt.clone();
let payload = payload.clone(); let payload = payload.clone();
threads.push(tokio::spawn(async move { threads.push(Box::pin(async move {
loop { while let Ok(()) = cw.write(&payload).await {
cw.write(&payload).await?;
cnt.inc(); cnt.inc();
} }
#[allow(unreachable_code)] #[allow(unreachable_code)]
Ok::<(), WispError>(()) Ok::<(), WispError>(())
})); }));
threads.push(tokio::spawn(async move { threads.push(Box::pin(async move {
loop { loop {
let _ = cr.read().await; let _ = cr.read().await;
} }
@ -258,7 +278,7 @@ async fn real_main() -> Result<(), Box<dyn Error + Send + Sync>> {
} }
let cnt_avg = cnt.clone(); let cnt_avg = cnt.clone();
threads.push(tokio::spawn(async move { threads.push(Box::pin(async move {
let mut interval = interval(Duration::from_millis(100)); let mut interval = interval(Duration::from_millis(100));
let mut avg: SingleSumSMA<usize, usize, 100> = SingleSumSMA::new(); let mut avg: SingleSumSMA<usize, usize, 100> = SingleSumSMA::new();
let mut last_time = 0; let mut last_time = 0;
@ -287,7 +307,7 @@ async fn real_main() -> Result<(), Box<dyn Error + Send + Sync>> {
} }
})); }));
threads.push(tokio::spawn(async move { threads.push(Box::pin(async move {
let mut interrupt = let mut interrupt =
signal(SignalKind::interrupt()).map_err(|x| WispError::Other(Box::new(x)))?; signal(SignalKind::interrupt()).map_err(|x| WispError::Other(Box::new(x)))?;
let mut terminate = let mut terminate =
@ -300,13 +320,13 @@ async fn real_main() -> Result<(), Box<dyn Error + Send + Sync>> {
})); }));
if let Some(duration) = opts.duration { if let Some(duration) = opts.duration {
threads.push(tokio::spawn(async move { threads.push(Box::pin(async move {
sleep(duration.into()).await; sleep(duration.into()).await;
Ok(()) Ok(())
})); }));
} }
let out = select_all(threads.into_iter()).await; let out = select_all(threads.into_iter().map(tokio::spawn)).await;
let duration_since = Instant::now().duration_since(start_time); let duration_since = Instant::now().duration_since(start_time);