improve simple-wisp-client

This commit is contained in:
Toshit Chawda 2024-03-27 17:04:21 -07:00
parent 412e797321
commit 77e377c814
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 109 additions and 61 deletions

View file

@ -6,10 +6,12 @@ edition = "2021"
[dependencies]
atomic-counter = "1.0.1"
bytes = "1.5.0"
clap = { version = "4.5.4", features = ["cargo", "derive"] }
console-subscriber = { version = "0.2.0", optional = true }
fastwebsockets = { version = "0.7.1", features = ["unstable-split", "upgrade"] }
futures = "0.3.30"
http-body-util = "0.1.0"
humantime = "2.1.0"
hyper = { version = "1.1.0", features = ["http1", "client"] }
simple_moving_average = "1.0.2"
tokio = { version = "1.36.0", features = ["full"] }

View file

@ -1,42 +1,51 @@
use atomic_counter::{AtomicCounter, RelaxedCounter};
use bytes::Bytes;
use clap::Parser;
use fastwebsockets::{handshake, FragmentCollectorRead};
use futures::future::select_all;
use http_body_util::Empty;
use humantime::format_duration;
use hyper::{
header::{CONNECTION, UPGRADE},
Request,
Request, Uri,
};
use simple_moving_average::{SingleSumSMA, SMA};
use std::{
error::Error,
future::Future,
io::{stdout, IsTerminal, Write},
net::SocketAddr,
sync::Arc,
time::Duration,
time::{Duration, Instant},
usize,
};
use tokio::{net::TcpStream, time::interval};
use tokio::{
net::TcpStream,
select,
signal::unix::{signal, SignalKind},
time::{interval, sleep},
};
use tokio_native_tls::{native_tls, TlsConnector};
use tokio_util::either::Either;
use wisp_mux::{ClientMux, StreamType, WispError};
#[derive(Debug)]
struct StrError(String);
impl StrError {
pub fn new(str: &str) -> Self {
Self(str.to_string())
}
enum WispClientError {
InvalidUriScheme,
UriHasNoHost,
}
impl std::fmt::Display for StrError {
impl std::fmt::Display for WispClientError {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
write!(fmt, "{}", self.0)
use WispClientError as E;
match self {
E::InvalidUriScheme => write!(fmt, "Invalid URI scheme"),
E::UriHasNoHost => write!(fmt, "URI has no host"),
}
}
}
impl Error for StrError {}
impl Error for WispClientError {}
struct SpawnExecutor;
@ -50,59 +59,63 @@ where
}
}
#[derive(Parser)]
#[command(version = clap::crate_version!())]
struct Cli {
/// Wisp server URL
#[arg(short, long)]
wisp: Uri,
/// TCP server address
#[arg(short, long)]
tcp: SocketAddr,
/// Number of streams
#[arg(short, long, default_value_t = 10)]
streams: usize,
/// Size of packets sent, in KB
#[arg(short, long, default_value_t = 1)]
packet_size: usize,
/// Duration to run the test for
#[arg(short, long)]
duration: Option<humantime::Duration>,
}
#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
#[cfg(feature = "tokio-console")]
console_subscriber::init();
let addr = std::env::args()
.nth(1)
.ok_or(StrError::new("no src addr"))?;
let opts = Cli::parse();
let addr_port: u16 = std::env::args()
.nth(2)
.ok_or(StrError::new("no src port"))?
.parse()?;
let addr_path = std::env::args()
.nth(3)
.ok_or(StrError::new("no src path"))?;
let addr_dest = std::env::args()
.nth(4)
.ok_or(StrError::new("no dest addr"))?;
let addr_dest_port: u16 = std::env::args()
.nth(5)
.ok_or(StrError::new("no dest port"))?
.parse()?;
let should_tls: bool = std::env::args()
.nth(6)
.ok_or(StrError::new("no should tls"))?
.parse()?;
let thread_cnt: usize = std::env::args().nth(7).unwrap_or("10".into()).parse()?;
let tls = match opts
.wisp
.scheme_str()
.ok_or(WispClientError::InvalidUriScheme)?
{
"wss" => Ok(true),
"ws" => Ok(false),
_ => Err(WispClientError::InvalidUriScheme),
}?;
let addr = opts.wisp.host().ok_or(WispClientError::UriHasNoHost)?;
let addr_port = opts.wisp.port_u16().unwrap_or(if tls { 443 } else { 80 });
let addr_path = opts.wisp.path();
let addr_dest = opts.tcp.ip().to_string();
let addr_dest_port = opts.tcp.port();
println!(
"connecting to {}://{}:{}{} and sending &[0; 1024] to {}:{} with threads {}",
if should_tls { "wss" } else { "ws" },
addr,
addr_port,
addr_path,
addr_dest,
addr_dest_port,
thread_cnt
"connecting to {} and sending &[0; 1024 * {}] to {} with threads {}",
opts.wisp, opts.packet_size, opts.tcp, opts.streams,
);
let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?;
let socket = if should_tls {
let socket = if tls {
let cx = TlsConnector::from(native_tls::TlsConnector::builder().build()?);
Either::Left(cx.connect(&addr, socket).await?)
Either::Left(cx.connect(addr, socket).await?)
} else {
Either::Right(socket)
};
let req = Request::builder()
.method("GET")
.uri(addr_path)
.header("Host", &addr)
.header("Host", addr)
.header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade")
.header(
@ -119,7 +132,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let rx = FragmentCollectorRead::new(rx);
let (mux, fut) = ClientMux::new(rx, tx).await?;
let mut threads = Vec::with_capacity(thread_cnt + 1);
let mut threads = Vec::with_capacity(opts.streams * 2 + 3);
threads.push(tokio::spawn(fut));
@ -127,23 +140,30 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let cnt = Arc::new(RelaxedCounter::new(0));
for _ in 0..thread_cnt {
let mut channel = mux
let start_time = Instant::now();
for _ in 0..opts.streams {
let (mut cr, mut cw) = mux
.client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port)
.await?;
.await?
.into_split();
let cnt = cnt.clone();
let payload = payload.clone();
threads.push(tokio::spawn(async move {
loop {
channel.write(payload.clone()).await?;
channel.read().await;
cw.write(payload.clone()).await?;
cnt.inc();
}
#[allow(unreachable_code)]
Ok::<(), WispError>(())
}));
threads.push(tokio::spawn(async move {
loop {
cr.read().await;
}
}));
}
let cnt_avg = cnt.clone();
threads.push(tokio::spawn(async move {
let mut interval = interval(Duration::from_millis(100));
let mut avg: SingleSumSMA<usize, usize, 100> = SingleSumSMA::new();
@ -151,7 +171,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let is_term = stdout().is_terminal();
loop {
interval.tick().await;
let now = cnt.get();
let now = cnt_avg.get();
let stat = format!(
"sent &[0; 1024] cnt: {:?}, +{:?}, moving average (100): {:?}",
now,
@ -169,9 +189,33 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
}
}));
let out = select_all(threads.into_iter()).await;
threads.push(tokio::spawn(async move {
let mut interrupt =
signal(SignalKind::interrupt()).map_err(|x| WispError::Other(Box::new(x)))?;
let mut terminate =
signal(SignalKind::terminate()).map_err(|x| WispError::Other(Box::new(x)))?;
select! {
_ = interrupt.recv() => (),
_ = terminate.recv() => (),
}
Ok(())
}));
println!("\n\nout: {:?}", out.0);
if let Some(duration) = opts.duration {
threads.push(tokio::spawn(async move {
sleep(duration.into()).await;
Ok(())
}));
}
let _ = select_all(threads.into_iter()).await;
println!(
"\n\nresults: {} packets of &[0; 1024 * {}] sent in {}",
cnt.get(),
opts.packet_size,
format_duration(Instant::now().duration_since(start_time))
);
Ok(())
}