use atomic_counter::{AtomicCounter, RelaxedCounter}; use bytes::Bytes; use clap::Parser; use fastwebsockets::{handshake, FragmentCollectorRead}; use futures::future::select_all; use http_body_util::Empty; use humantime::format_duration; use hyper::{ header::{CONNECTION, UPGRADE}, Request, Uri, }; use simple_moving_average::{SingleSumSMA, SMA}; use std::{ error::Error, future::Future, io::{stdout, IsTerminal, Write}, net::SocketAddr, process::exit, sync::Arc, time::{Duration, Instant}, }; use tokio::{ net::TcpStream, select, signal::unix::{signal, SignalKind}, time::{interval, sleep}, }; use tokio_native_tls::{native_tls, TlsConnector}; use tokio_util::either::Either; use wisp_mux::{extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, ClientMux, StreamType, WispError}; #[derive(Debug)] enum WispClientError { InvalidUriScheme, UriHasNoHost, } impl std::fmt::Display for WispClientError { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { use WispClientError as E; match self { E::InvalidUriScheme => write!(fmt, "Invalid URI scheme"), E::UriHasNoHost => write!(fmt, "URI has no host"), } } } impl Error for WispClientError {} struct SpawnExecutor; impl hyper::rt::Executor for SpawnExecutor where Fut: Future + Send + 'static, Fut::Output: Send + 'static, { fn execute(&self, fut: Fut) { tokio::task::spawn(fut); } } #[derive(Parser)] #[command(version = clap::crate_version!())] struct Cli { /// Wisp server URL #[arg(short, long)] wisp: Uri, /// TCP server address #[arg(short, long)] tcp: SocketAddr, /// Number of streams #[arg(short, long, default_value_t = 10)] streams: usize, /// Size of packets sent, in KB #[arg(short, long, default_value_t = 1)] packet_size: usize, /// Duration to run the test for #[arg(short, long)] duration: Option, /// Ask for UDP #[arg(short, long)] udp: bool, } #[tokio::main(flavor = "multi_thread")] async fn main() -> Result<(), Box> { #[cfg(feature = "tokio-console")] console_subscriber::init(); let opts = Cli::parse(); let tls = match opts .wisp .scheme_str() .ok_or(WispClientError::InvalidUriScheme)? { "wss" => Ok(true), "ws" => Ok(false), _ => Err(WispClientError::InvalidUriScheme), }?; let addr = opts.wisp.host().ok_or(WispClientError::UriHasNoHost)?; let addr_port = opts.wisp.port_u16().unwrap_or(if tls { 443 } else { 80 }); let addr_path = opts.wisp.path(); let addr_dest = opts.tcp.ip().to_string(); let addr_dest_port = opts.tcp.port(); println!( "connecting to {} and sending &[0; 1024 * {}] to {} with threads {}", opts.wisp, opts.packet_size, opts.tcp, opts.streams, ); let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?; let socket = if tls { let cx = TlsConnector::from(native_tls::TlsConnector::builder().build()?); Either::Left(cx.connect(addr, socket).await?) } else { Either::Right(socket) }; let req = Request::builder() .method("GET") .uri(addr_path) .header("Host", addr) .header(UPGRADE, "websocket") .header(CONNECTION, "upgrade") .header( "Sec-WebSocket-Key", fastwebsockets::handshake::generate_key(), ) .header("Sec-WebSocket-Version", "13") .body(Empty::::new())?; let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?; let (rx, tx) = ws.split(tokio::io::split); let rx = FragmentCollectorRead::new(rx); let (mut mux, fut) = if opts.udp { let (mux, fut) = ClientMux::new(rx, tx, Some(&[&UdpProtocolExtensionBuilder()])).await?; if !mux.supported_extension_ids.iter().any(|x| *x == UdpProtocolExtension::ID) { println!("server did not support udp, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids); exit(1); } (mux, fut) } else { ClientMux::new(rx, tx, Some(&[])).await? }; println!("connected and created ClientMux, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids); let mut threads = Vec::with_capacity(opts.streams * 2 + 3); threads.push(tokio::spawn(fut)); let payload = Bytes::from(vec![0; 1024 * opts.packet_size]); let cnt = Arc::new(RelaxedCounter::new(0)); let start_time = Instant::now(); for _ in 0..opts.streams { let (mut cr, mut cw) = mux .client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port) .await? .into_split(); let cnt = cnt.clone(); let payload = payload.clone(); threads.push(tokio::spawn(async move { loop { cw.write(payload.clone()).await?; cnt.inc(); } #[allow(unreachable_code)] Ok::<(), WispError>(()) })); threads.push(tokio::spawn(async move { loop { cr.read().await; } })); } let cnt_avg = cnt.clone(); threads.push(tokio::spawn(async move { let mut interval = interval(Duration::from_millis(100)); let mut avg: SingleSumSMA = SingleSumSMA::new(); let mut last_time = 0; let is_term = stdout().is_terminal(); loop { interval.tick().await; let now = cnt_avg.get(); let stat = format!( "sent &[0; 1024 * {}] cnt: {:?} ({} KiB), +{:?} ({} KiB / 100ms), moving average (10 s): {:?} ({} KiB / 10 s)", opts.packet_size, now, now * opts.packet_size, now - last_time, (now - last_time) * opts.packet_size, avg.get_average(), avg.get_average() * opts.packet_size, ); if is_term { print!("\x1b[2K{}\r", stat); } else { println!("{}", stat); } stdout().flush().unwrap(); avg.add_sample(now - last_time); last_time = now; } })); threads.push(tokio::spawn(async move { let mut interrupt = signal(SignalKind::interrupt()).map_err(|x| WispError::Other(Box::new(x)))?; let mut terminate = signal(SignalKind::terminate()).map_err(|x| WispError::Other(Box::new(x)))?; select! { _ = interrupt.recv() => (), _ = terminate.recv() => (), } Ok(()) })); if let Some(duration) = opts.duration { threads.push(tokio::spawn(async move { sleep(duration.into()).await; Ok(()) })); } let out = select_all(threads.into_iter()).await; if let Err(err) = out.0? { println!("\n\nerr: {:?}", err); exit(1); } out.2.into_iter().for_each(|x| x.abort()); let duration_since = Instant::now().duration_since(start_time); println!( "\n\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)", cnt.get(), opts.packet_size, cnt.get() * opts.packet_size, format_duration(duration_since), (cnt.get() * opts.packet_size) as u64 / duration_since.as_secs(), ); Ok(()) }