improve simple-wisp-client

This commit is contained in:
Toshit Chawda 2024-03-27 17:04:21 -07:00
parent 412e797321
commit 77e377c814
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 109 additions and 61 deletions

10
Cargo.lock generated
View file

@ -305,9 +305,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.3" version = "4.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "949626d00e063efc93b6dca932419ceb5432f99769911c0b995f7e884c778813" checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0"
dependencies = [ dependencies = [
"clap_builder", "clap_builder",
"clap_derive", "clap_derive",
@ -328,9 +328,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_derive" name = "clap_derive"
version = "4.5.3" version = "4.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90239a040c80f5e14809ca132ddc4176ab33d5e17e49691793296e3fcb34d72f" checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64"
dependencies = [ dependencies = [
"heck", "heck",
"proc-macro2 1.0.79", "proc-macro2 1.0.79",
@ -1735,10 +1735,12 @@ version = "1.0.0"
dependencies = [ dependencies = [
"atomic-counter", "atomic-counter",
"bytes", "bytes",
"clap",
"console-subscriber", "console-subscriber",
"fastwebsockets 0.7.1", "fastwebsockets 0.7.1",
"futures", "futures",
"http-body-util", "http-body-util",
"humantime",
"hyper 1.2.0", "hyper 1.2.0",
"simple_moving_average", "simple_moving_average",
"tokio", "tokio",

View file

@ -6,10 +6,12 @@ edition = "2021"
[dependencies] [dependencies]
atomic-counter = "1.0.1" atomic-counter = "1.0.1"
bytes = "1.5.0" bytes = "1.5.0"
clap = { version = "4.5.4", features = ["cargo", "derive"] }
console-subscriber = { version = "0.2.0", optional = true } console-subscriber = { version = "0.2.0", optional = true }
fastwebsockets = { version = "0.7.1", features = ["unstable-split", "upgrade"] } fastwebsockets = { version = "0.7.1", features = ["unstable-split", "upgrade"] }
futures = "0.3.30" futures = "0.3.30"
http-body-util = "0.1.0" http-body-util = "0.1.0"
humantime = "2.1.0"
hyper = { version = "1.1.0", features = ["http1", "client"] } hyper = { version = "1.1.0", features = ["http1", "client"] }
simple_moving_average = "1.0.2" simple_moving_average = "1.0.2"
tokio = { version = "1.36.0", features = ["full"] } tokio = { version = "1.36.0", features = ["full"] }

View file

@ -1,42 +1,51 @@
use atomic_counter::{AtomicCounter, RelaxedCounter}; use atomic_counter::{AtomicCounter, RelaxedCounter};
use bytes::Bytes; use bytes::Bytes;
use clap::Parser;
use fastwebsockets::{handshake, FragmentCollectorRead}; use fastwebsockets::{handshake, FragmentCollectorRead};
use futures::future::select_all; use futures::future::select_all;
use http_body_util::Empty; use http_body_util::Empty;
use humantime::format_duration;
use hyper::{ use hyper::{
header::{CONNECTION, UPGRADE}, header::{CONNECTION, UPGRADE},
Request, Request, Uri,
}; };
use simple_moving_average::{SingleSumSMA, SMA}; use simple_moving_average::{SingleSumSMA, SMA};
use std::{ use std::{
error::Error, error::Error,
future::Future, future::Future,
io::{stdout, IsTerminal, Write}, io::{stdout, IsTerminal, Write},
net::SocketAddr,
sync::Arc, sync::Arc,
time::Duration, time::{Duration, Instant},
usize, usize,
}; };
use tokio::{net::TcpStream, time::interval}; use tokio::{
net::TcpStream,
select,
signal::unix::{signal, SignalKind},
time::{interval, sleep},
};
use tokio_native_tls::{native_tls, TlsConnector}; use tokio_native_tls::{native_tls, TlsConnector};
use tokio_util::either::Either; use tokio_util::either::Either;
use wisp_mux::{ClientMux, StreamType, WispError}; use wisp_mux::{ClientMux, StreamType, WispError};
#[derive(Debug)] #[derive(Debug)]
struct StrError(String); enum WispClientError {
InvalidUriScheme,
impl StrError { UriHasNoHost,
pub fn new(str: &str) -> Self {
Self(str.to_string())
}
} }
impl std::fmt::Display for StrError { impl std::fmt::Display for WispClientError {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
write!(fmt, "{}", self.0) use WispClientError as E;
match self {
E::InvalidUriScheme => write!(fmt, "Invalid URI scheme"),
E::UriHasNoHost => write!(fmt, "URI has no host"),
}
} }
} }
impl Error for StrError {} impl Error for WispClientError {}
struct SpawnExecutor; struct SpawnExecutor;
@ -50,59 +59,63 @@ where
} }
} }
#[derive(Parser)]
#[command(version = clap::crate_version!())]
struct Cli {
/// Wisp server URL
#[arg(short, long)]
wisp: Uri,
/// TCP server address
#[arg(short, long)]
tcp: SocketAddr,
/// Number of streams
#[arg(short, long, default_value_t = 10)]
streams: usize,
/// Size of packets sent, in KB
#[arg(short, long, default_value_t = 1)]
packet_size: usize,
/// Duration to run the test for
#[arg(short, long)]
duration: Option<humantime::Duration>,
}
#[tokio::main(flavor = "multi_thread")] #[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<(), Box<dyn Error + Send + Sync>> { async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
#[cfg(feature = "tokio-console")] #[cfg(feature = "tokio-console")]
console_subscriber::init(); console_subscriber::init();
let addr = std::env::args() let opts = Cli::parse();
.nth(1)
.ok_or(StrError::new("no src addr"))?;
let addr_port: u16 = std::env::args() let tls = match opts
.nth(2) .wisp
.ok_or(StrError::new("no src port"))? .scheme_str()
.parse()?; .ok_or(WispClientError::InvalidUriScheme)?
{
let addr_path = std::env::args() "wss" => Ok(true),
.nth(3) "ws" => Ok(false),
.ok_or(StrError::new("no src path"))?; _ => Err(WispClientError::InvalidUriScheme),
}?;
let addr_dest = std::env::args() let addr = opts.wisp.host().ok_or(WispClientError::UriHasNoHost)?;
.nth(4) let addr_port = opts.wisp.port_u16().unwrap_or(if tls { 443 } else { 80 });
.ok_or(StrError::new("no dest addr"))?; let addr_path = opts.wisp.path();
let addr_dest = opts.tcp.ip().to_string();
let addr_dest_port: u16 = std::env::args() let addr_dest_port = opts.tcp.port();
.nth(5)
.ok_or(StrError::new("no dest port"))?
.parse()?;
let should_tls: bool = std::env::args()
.nth(6)
.ok_or(StrError::new("no should tls"))?
.parse()?;
let thread_cnt: usize = std::env::args().nth(7).unwrap_or("10".into()).parse()?;
println!( println!(
"connecting to {}://{}:{}{} and sending &[0; 1024] to {}:{} with threads {}", "connecting to {} and sending &[0; 1024 * {}] to {} with threads {}",
if should_tls { "wss" } else { "ws" }, opts.wisp, opts.packet_size, opts.tcp, opts.streams,
addr,
addr_port,
addr_path,
addr_dest,
addr_dest_port,
thread_cnt
); );
let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?; let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?;
let socket = if should_tls { let socket = if tls {
let cx = TlsConnector::from(native_tls::TlsConnector::builder().build()?); let cx = TlsConnector::from(native_tls::TlsConnector::builder().build()?);
Either::Left(cx.connect(&addr, socket).await?) Either::Left(cx.connect(addr, socket).await?)
} else { } else {
Either::Right(socket) Either::Right(socket)
}; };
let req = Request::builder() let req = Request::builder()
.method("GET") .method("GET")
.uri(addr_path) .uri(addr_path)
.header("Host", &addr) .header("Host", addr)
.header(UPGRADE, "websocket") .header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade") .header(CONNECTION, "upgrade")
.header( .header(
@ -119,7 +132,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let rx = FragmentCollectorRead::new(rx); let rx = FragmentCollectorRead::new(rx);
let (mux, fut) = ClientMux::new(rx, tx).await?; let (mux, fut) = ClientMux::new(rx, tx).await?;
let mut threads = Vec::with_capacity(thread_cnt + 1); let mut threads = Vec::with_capacity(opts.streams * 2 + 3);
threads.push(tokio::spawn(fut)); threads.push(tokio::spawn(fut));
@ -127,23 +140,30 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let cnt = Arc::new(RelaxedCounter::new(0)); let cnt = Arc::new(RelaxedCounter::new(0));
for _ in 0..thread_cnt { let start_time = Instant::now();
let mut channel = mux for _ in 0..opts.streams {
let (mut cr, mut cw) = mux
.client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port) .client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port)
.await?; .await?
.into_split();
let cnt = cnt.clone(); let cnt = cnt.clone();
let payload = payload.clone(); let payload = payload.clone();
threads.push(tokio::spawn(async move { threads.push(tokio::spawn(async move {
loop { loop {
channel.write(payload.clone()).await?; cw.write(payload.clone()).await?;
channel.read().await;
cnt.inc(); cnt.inc();
} }
#[allow(unreachable_code)] #[allow(unreachable_code)]
Ok::<(), WispError>(()) Ok::<(), WispError>(())
})); }));
threads.push(tokio::spawn(async move {
loop {
cr.read().await;
}
}));
} }
let cnt_avg = cnt.clone();
threads.push(tokio::spawn(async move { threads.push(tokio::spawn(async move {
let mut interval = interval(Duration::from_millis(100)); let mut interval = interval(Duration::from_millis(100));
let mut avg: SingleSumSMA<usize, usize, 100> = SingleSumSMA::new(); let mut avg: SingleSumSMA<usize, usize, 100> = SingleSumSMA::new();
@ -151,7 +171,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let is_term = stdout().is_terminal(); let is_term = stdout().is_terminal();
loop { loop {
interval.tick().await; interval.tick().await;
let now = cnt.get(); let now = cnt_avg.get();
let stat = format!( let stat = format!(
"sent &[0; 1024] cnt: {:?}, +{:?}, moving average (100): {:?}", "sent &[0; 1024] cnt: {:?}, +{:?}, moving average (100): {:?}",
now, now,
@ -169,9 +189,33 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
} }
})); }));
let out = select_all(threads.into_iter()).await; threads.push(tokio::spawn(async move {
let mut interrupt =
signal(SignalKind::interrupt()).map_err(|x| WispError::Other(Box::new(x)))?;
let mut terminate =
signal(SignalKind::terminate()).map_err(|x| WispError::Other(Box::new(x)))?;
select! {
_ = interrupt.recv() => (),
_ = terminate.recv() => (),
}
Ok(())
}));
println!("\n\nout: {:?}", out.0); if let Some(duration) = opts.duration {
threads.push(tokio::spawn(async move {
sleep(duration.into()).await;
Ok(())
}));
}
let _ = select_all(threads.into_iter()).await;
println!(
"\n\nresults: {} packets of &[0; 1024 * {}] sent in {}",
cnt.get(),
opts.packet_size,
format_duration(Instant::now().duration_since(start_time))
);
Ok(()) Ok(())
} }