From 77e377c814729ff92527f52f6d608360ac94c7da Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Wed, 27 Mar 2024 17:04:21 -0700 Subject: [PATCH] improve simple-wisp-client --- Cargo.lock | 10 ++- simple-wisp-client/Cargo.toml | 2 + simple-wisp-client/src/main.rs | 158 +++++++++++++++++++++------------ 3 files changed, 109 insertions(+), 61 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a77d2ed..aa03ac4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -305,9 +305,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.5.3" +version = "4.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "949626d00e063efc93b6dca932419ceb5432f99769911c0b995f7e884c778813" +checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" dependencies = [ "clap_builder", "clap_derive", @@ -328,9 +328,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.3" +version = "4.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90239a040c80f5e14809ca132ddc4176ab33d5e17e49691793296e3fcb34d72f" +checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64" dependencies = [ "heck", "proc-macro2 1.0.79", @@ -1735,10 +1735,12 @@ version = "1.0.0" dependencies = [ "atomic-counter", "bytes", + "clap", "console-subscriber", "fastwebsockets 0.7.1", "futures", "http-body-util", + "humantime", "hyper 1.2.0", "simple_moving_average", "tokio", diff --git a/simple-wisp-client/Cargo.toml b/simple-wisp-client/Cargo.toml index fed1966..dc09cde 100644 --- a/simple-wisp-client/Cargo.toml +++ b/simple-wisp-client/Cargo.toml @@ -6,10 +6,12 @@ edition = "2021" [dependencies] atomic-counter = "1.0.1" bytes = "1.5.0" +clap = { version = "4.5.4", features = ["cargo", "derive"] } console-subscriber = { version = "0.2.0", optional = true } fastwebsockets = { version = "0.7.1", features = ["unstable-split", "upgrade"] } futures = "0.3.30" http-body-util = "0.1.0" +humantime = "2.1.0" hyper = { version = "1.1.0", features = ["http1", "client"] } simple_moving_average = "1.0.2" tokio = { version = "1.36.0", features = ["full"] } diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index 6a78943..3f9c609 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -1,42 +1,51 @@ use atomic_counter::{AtomicCounter, RelaxedCounter}; use bytes::Bytes; +use clap::Parser; use fastwebsockets::{handshake, FragmentCollectorRead}; use futures::future::select_all; use http_body_util::Empty; +use humantime::format_duration; use hyper::{ header::{CONNECTION, UPGRADE}, - Request, + Request, Uri, }; use simple_moving_average::{SingleSumSMA, SMA}; use std::{ error::Error, future::Future, io::{stdout, IsTerminal, Write}, + net::SocketAddr, sync::Arc, - time::Duration, + time::{Duration, Instant}, usize, }; -use tokio::{net::TcpStream, time::interval}; +use tokio::{ + net::TcpStream, + select, + signal::unix::{signal, SignalKind}, + time::{interval, sleep}, +}; use tokio_native_tls::{native_tls, TlsConnector}; use tokio_util::either::Either; use wisp_mux::{ClientMux, StreamType, WispError}; #[derive(Debug)] -struct StrError(String); - -impl StrError { - pub fn new(str: &str) -> Self { - Self(str.to_string()) - } +enum WispClientError { + InvalidUriScheme, + UriHasNoHost, } -impl std::fmt::Display for StrError { +impl std::fmt::Display for WispClientError { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { - write!(fmt, "{}", self.0) + use WispClientError as E; + match self { + E::InvalidUriScheme => write!(fmt, "Invalid URI scheme"), + E::UriHasNoHost => write!(fmt, "URI has no host"), + } } } -impl Error for StrError {} +impl Error for WispClientError {} struct SpawnExecutor; @@ -50,59 +59,63 @@ where } } +#[derive(Parser)] +#[command(version = clap::crate_version!())] +struct Cli { + /// Wisp server URL + #[arg(short, long)] + wisp: Uri, + /// TCP server address + #[arg(short, long)] + tcp: SocketAddr, + /// Number of streams + #[arg(short, long, default_value_t = 10)] + streams: usize, + /// Size of packets sent, in KB + #[arg(short, long, default_value_t = 1)] + packet_size: usize, + /// Duration to run the test for + #[arg(short, long)] + duration: Option, +} + #[tokio::main(flavor = "multi_thread")] async fn main() -> Result<(), Box> { #[cfg(feature = "tokio-console")] console_subscriber::init(); - let addr = std::env::args() - .nth(1) - .ok_or(StrError::new("no src addr"))?; + let opts = Cli::parse(); - let addr_port: u16 = std::env::args() - .nth(2) - .ok_or(StrError::new("no src port"))? - .parse()?; - - let addr_path = std::env::args() - .nth(3) - .ok_or(StrError::new("no src path"))?; - - let addr_dest = std::env::args() - .nth(4) - .ok_or(StrError::new("no dest addr"))?; - - let addr_dest_port: u16 = std::env::args() - .nth(5) - .ok_or(StrError::new("no dest port"))? - .parse()?; - let should_tls: bool = std::env::args() - .nth(6) - .ok_or(StrError::new("no should tls"))? - .parse()?; - let thread_cnt: usize = std::env::args().nth(7).unwrap_or("10".into()).parse()?; + let tls = match opts + .wisp + .scheme_str() + .ok_or(WispClientError::InvalidUriScheme)? + { + "wss" => Ok(true), + "ws" => Ok(false), + _ => Err(WispClientError::InvalidUriScheme), + }?; + let addr = opts.wisp.host().ok_or(WispClientError::UriHasNoHost)?; + let addr_port = opts.wisp.port_u16().unwrap_or(if tls { 443 } else { 80 }); + let addr_path = opts.wisp.path(); + let addr_dest = opts.tcp.ip().to_string(); + let addr_dest_port = opts.tcp.port(); println!( - "connecting to {}://{}:{}{} and sending &[0; 1024] to {}:{} with threads {}", - if should_tls { "wss" } else { "ws" }, - addr, - addr_port, - addr_path, - addr_dest, - addr_dest_port, - thread_cnt + "connecting to {} and sending &[0; 1024 * {}] to {} with threads {}", + opts.wisp, opts.packet_size, opts.tcp, opts.streams, ); let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?; - let socket = if should_tls { + let socket = if tls { let cx = TlsConnector::from(native_tls::TlsConnector::builder().build()?); - Either::Left(cx.connect(&addr, socket).await?) + Either::Left(cx.connect(addr, socket).await?) } else { Either::Right(socket) }; let req = Request::builder() .method("GET") .uri(addr_path) - .header("Host", &addr) + .header("Host", addr) .header(UPGRADE, "websocket") .header(CONNECTION, "upgrade") .header( @@ -119,7 +132,7 @@ async fn main() -> Result<(), Box> { let rx = FragmentCollectorRead::new(rx); let (mux, fut) = ClientMux::new(rx, tx).await?; - let mut threads = Vec::with_capacity(thread_cnt + 1); + let mut threads = Vec::with_capacity(opts.streams * 2 + 3); threads.push(tokio::spawn(fut)); @@ -127,23 +140,30 @@ async fn main() -> Result<(), Box> { let cnt = Arc::new(RelaxedCounter::new(0)); - for _ in 0..thread_cnt { - let mut channel = mux + let start_time = Instant::now(); + for _ in 0..opts.streams { + let (mut cr, mut cw) = mux .client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port) - .await?; + .await? + .into_split(); let cnt = cnt.clone(); let payload = payload.clone(); threads.push(tokio::spawn(async move { loop { - channel.write(payload.clone()).await?; - channel.read().await; + cw.write(payload.clone()).await?; cnt.inc(); } #[allow(unreachable_code)] Ok::<(), WispError>(()) })); + threads.push(tokio::spawn(async move { + loop { + cr.read().await; + } + })); } + let cnt_avg = cnt.clone(); threads.push(tokio::spawn(async move { let mut interval = interval(Duration::from_millis(100)); let mut avg: SingleSumSMA = SingleSumSMA::new(); @@ -151,7 +171,7 @@ async fn main() -> Result<(), Box> { let is_term = stdout().is_terminal(); loop { interval.tick().await; - let now = cnt.get(); + let now = cnt_avg.get(); let stat = format!( "sent &[0; 1024] cnt: {:?}, +{:?}, moving average (100): {:?}", now, @@ -169,9 +189,33 @@ async fn main() -> Result<(), Box> { } })); - let out = select_all(threads.into_iter()).await; + threads.push(tokio::spawn(async move { + let mut interrupt = + signal(SignalKind::interrupt()).map_err(|x| WispError::Other(Box::new(x)))?; + let mut terminate = + signal(SignalKind::terminate()).map_err(|x| WispError::Other(Box::new(x)))?; + select! { + _ = interrupt.recv() => (), + _ = terminate.recv() => (), + } + Ok(()) + })); - println!("\n\nout: {:?}", out.0); + if let Some(duration) = opts.duration { + threads.push(tokio::spawn(async move { + sleep(duration.into()).await; + Ok(()) + })); + } + + let _ = select_all(threads.into_iter()).await; + + println!( + "\n\nresults: {} packets of &[0; 1024 * {}] sent in {}", + cnt.get(), + opts.packet_size, + format_duration(Instant::now().duration_since(start_time)) + ); Ok(()) }