diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index f4205e6..e4e32e6 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -2,8 +2,8 @@ use atomic_counter::{AtomicCounter, RelaxedCounter}; use bytes::Bytes; use clap::Parser; use ed25519_dalek::pkcs8::DecodePrivateKey; -use fastwebsockets::handshake; -use futures::future::select_all; +use fastwebsockets::{handshake, WebSocketWrite}; +use futures::{future::select_all, FutureExt, TryFutureExt}; use http_body_util::Empty; use humantime::format_duration; use hyper::{ @@ -19,13 +19,14 @@ use std::{ io::{stdout, Cursor, IsTerminal, Write}, net::SocketAddr, path::PathBuf, + pin::Pin, process::{abort, exit}, sync::Arc, time::{Duration, Instant}, }; use tokio::{ io::AsyncReadExt, - net::TcpStream, + net::{tcp::OwnedWriteHalf, TcpStream}, select, signal::unix::{signal, SignalKind}, time::{interval, sleep}, @@ -127,11 +128,15 @@ async fn main() -> Result<(), Box> { tokio::spawn(real_main()).await? } -async fn real_main() -> Result<(), Box> { - #[cfg(feature = "tokio-console")] - console_subscriber::init(); - let opts = Cli::parse(); - +async fn create_mux( + opts: &Cli, +) -> Result< + ( + ClientMux>, + impl Future> + Send, + ), + Box, +> { if opts.wisp.scheme_str().unwrap_or_default() != "ws" { Err(Box::new(WispClientError::InvalidUriScheme))?; } @@ -139,10 +144,8 @@ async fn real_main() -> Result<(), Box> { let addr = opts.wisp.host().ok_or(WispClientError::UriHasNoHost)?; let addr_port = opts.wisp.port_u16().unwrap_or(80); let addr_path = opts.wisp.path(); - let addr_dest = opts.tcp.ip().to_string(); - let addr_dest_port = opts.tcp.port(); - let auth = opts.auth.map(|auth| { + let auth = opts.auth.as_ref().map(|auth| { let split: Vec<_> = auth.split(':').collect(); let username = split[0].to_string(); let password = split[1..].join(":"); @@ -193,24 +196,37 @@ async fn real_main() -> Result<(), Box> { extensions.push(AnyProtocolExtensionBuilder::new(auth)); extension_ids.push(PasswordProtocolExtension::ID); } - if let Some(certauth) = opts.certauth { - let key = get_cert(certauth).await?; + if let Some(certauth) = &opts.certauth { + let key = get_cert(certauth.clone()).await?; let extension = CertAuthProtocolExtensionBuilder::new_client(Some(key)); extensions.push(AnyProtocolExtensionBuilder::new(extension)); extension_ids.push(CertAuthProtocolExtension::ID); } - let (mux, fut) = if !opts.wisp_v2 { - ClientMux::create(rx, tx, None) - .await? - .with_no_required_extensions() - } else { + let (mux, fut) = if opts.wisp_v2 { ClientMux::create(rx, tx, Some(WispV2Handshake::new(extensions))) .await? .with_required_extensions(extension_ids.as_slice()) .await? + } else { + ClientMux::create(rx, tx, None) + .await? + .with_no_required_extensions() }; + Ok((mux, fut)) +} + +#[allow(clippy::too_many_lines)] +async fn real_main() -> Result<(), Box> { + #[cfg(feature = "tokio-console")] + console_subscriber::init(); + let opts = Cli::parse(); + + let addr_dest = opts.tcp.ip().to_string(); + let addr_dest_port = opts.tcp.port(); + let (mux, fut) = create_mux(&opts).await?; + let motd_extension = mux .supported_extensions .iter() @@ -228,7 +244,12 @@ async fn real_main() -> Result<(), Box> { let mut threads = Vec::with_capacity((opts.streams * 2) + 3); - threads.push(tokio::spawn(fut)); + threads.push(Box::pin( + tokio::spawn(fut) + .map_err(|x| WispError::Other(Box::new(x))) + .map(|x| x.and_then(|x| x)), + ) + as Pin> + Send>>); let payload = vec![0; 1024 * opts.packet_size]; @@ -242,15 +263,14 @@ async fn real_main() -> Result<(), Box> { .into_split(); let cnt = cnt.clone(); let payload = payload.clone(); - threads.push(tokio::spawn(async move { - loop { - cw.write(&payload).await?; + threads.push(Box::pin(async move { + while let Ok(()) = cw.write(&payload).await { cnt.inc(); } #[allow(unreachable_code)] Ok::<(), WispError>(()) })); - threads.push(tokio::spawn(async move { + threads.push(Box::pin(async move { loop { let _ = cr.read().await; } @@ -258,7 +278,7 @@ async fn real_main() -> Result<(), Box> { } let cnt_avg = cnt.clone(); - threads.push(tokio::spawn(async move { + threads.push(Box::pin(async move { let mut interval = interval(Duration::from_millis(100)); let mut avg: SingleSumSMA = SingleSumSMA::new(); let mut last_time = 0; @@ -287,7 +307,7 @@ async fn real_main() -> Result<(), Box> { } })); - threads.push(tokio::spawn(async move { + threads.push(Box::pin(async move { let mut interrupt = signal(SignalKind::interrupt()).map_err(|x| WispError::Other(Box::new(x)))?; let mut terminate = @@ -300,13 +320,13 @@ async fn real_main() -> Result<(), Box> { })); if let Some(duration) = opts.duration { - threads.push(tokio::spawn(async move { + threads.push(Box::pin(async move { sleep(duration.into()).await; Ok(()) })); } - let out = select_all(threads.into_iter()).await; + let out = select_all(threads.into_iter().map(tokio::spawn)).await; let duration_since = Instant::now().duration_since(start_time);