diff --git a/server/src/main.rs b/server/src/main.rs index f76b0e2..feebd83 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -5,11 +5,12 @@ use bytes::Bytes; use clap::Parser; use fastwebsockets::{ upgrade, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, - WebSocketError, + WebSocket, WebSocketError, }; use futures_util::{SinkExt, StreamExt, TryFutureExt}; use hyper::{ - body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode, + body::Incoming, server::conn::http1, service::service_fn, upgrade::Upgraded, Request, Response, + StatusCode, }; use hyper_util::rt::TokioIo; use tokio::net::{lookup_host, TcpListener, TcpStream, UdpSocket}; @@ -46,6 +47,17 @@ struct Cli { /// addresses are blocked #[arg(long, short = 'B')] block_local: bool, + /// Whether the server should block UDP + /// + /// This does nothing for wsproxy as that is always TCP + #[arg(long)] + block_udp: bool, + /// Whether the server should block ports other than 80 or 443 + #[arg(long)] + block_non_http: bool, + /// Maximum WebSocket frame size allowed + #[arg(long, short, default_value_t = 64 << 20)] + frame_size: usize, } #[cfg(not(unix))] @@ -134,7 +146,15 @@ async fn main() -> Result<(), Error> { tokio::spawn(async move { let io = TokioIo::new(stream); let service = service_fn(move |res| { - accept_http(res, addr.clone(), prefix.clone(), opt.block_local) + accept_http( + res, + addr.clone(), + prefix.clone(), + opt.block_local, + opt.block_udp, + opt.block_non_http, + opt.frame_size, + ) }); let conn = http1::Builder::new() .serve_connection(io, service) @@ -153,6 +173,9 @@ async fn accept_http( addr: String, prefix: String, block_local: bool, + block_udp: bool, + block_non_http: bool, + max_size: usize, ) -> Result, WebSocketError> { let uri = req.uri().path().to_string(); if upgrade::is_upgrade_request(&req) @@ -160,10 +183,18 @@ async fn accept_http( { let (res, fut) = upgrade::upgrade(&mut req)?; + let mut ws = fut.await?; + + ws.set_max_message_size(max_size); + if uri.is_empty() { - tokio::spawn(async move { accept_ws(fut, addr.clone(), block_local).await }); + tokio::spawn(async move { + accept_ws(ws, addr.clone(), block_local, block_udp, block_non_http).await + }); } else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) { - tokio::spawn(async move { accept_wsproxy(fut, uri, addr.clone(), block_local).await }); + tokio::spawn(async move { + accept_wsproxy(ws, uri, addr.clone(), block_local, block_non_http).await + }); } Ok(Response::from_parts( @@ -230,11 +261,13 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result>, addr: String, block_local: bool, + block_non_http: bool, + block_udp: bool, ) -> Result<(), Box> { - let (rx, tx) = fut.await?.split(tokio::io::split); + let (rx, tx) = ws.split(tokio::io::split); let rx = FragmentCollectorRead::new(rx); println!("{:?}: connected", addr); @@ -249,6 +282,13 @@ async fn accept_ws( while let Some((packet, mut stream)) = mux.server_new_stream().await { tokio::spawn(async move { + if (block_non_http + && !(packet.destination_port == 80 || packet.destination_port == 443)) + || (block_udp && packet.stream_type == StreamType::Udp) + { + let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await; + return; + } if block_local { match lookup_host(format!( "{}:{}", @@ -295,39 +335,42 @@ async fn accept_ws( } async fn accept_wsproxy( - fut: upgrade::UpgradeFut, + ws: WebSocket>, incoming_uri: String, addr: String, block_local: bool, + block_non_http: bool, ) -> Result<(), Box> { - let mut ws_stream = FragmentCollector::new(fut.await?); + let mut ws_stream = FragmentCollector::new(ws); println!("{:?}: connected (wsproxy): {:?}", addr, incoming_uri); - if block_local { - match lookup_host(&incoming_uri) - .await - .ok() - .and_then(|mut x| x.next()) - .map(|x| !x.ip().is_global()) - { - Some(true) => { - ws_stream - .write_frame(Frame::close(CloseCode::Error.into(), b"blocked uri")) - .await?; - return Ok(()); - } - Some(false) => {} - None => { - ws_stream - .write_frame(Frame::close( - CloseCode::Error.into(), - b"failed to resolve uri", - )) - .await?; - return Ok(()); - } - } + let Some(host) = lookup_host(&incoming_uri) + .await + .ok() + .and_then(|mut x| x.next()) + else { + ws_stream + .write_frame(Frame::close( + CloseCode::Error.into(), + b"failed to resolve uri", + )) + .await?; + return Ok(()); + }; + + if block_local && !host.ip().is_global() { + ws_stream + .write_frame(Frame::close(CloseCode::Error.into(), b"blocked uri")) + .await?; + return Ok(()); + } + + if block_non_http && !(host.port() == 80 || host.port() == 443) { + ws_stream + .write_frame(Frame::close(CloseCode::Error.into(), b"blocked uri")) + .await?; + return Ok(()); } let tcp_stream = match TcpStream::connect(incoming_uri).await { diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index 3f9c609..39a0f2e 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -126,6 +126,8 @@ async fn main() -> Result<(), Box> { .header("Sec-WebSocket-Protocol", "wisp-v1") .body(Empty::::new())?; + println!("{:?}", req); + let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?; let (rx, tx) = ws.split(tokio::io::split); @@ -136,7 +138,7 @@ async fn main() -> Result<(), Box> { threads.push(tokio::spawn(fut)); - let payload = Bytes::from_static(&[0; 1024]); + let payload = Bytes::from(vec![0; 1024 * opts.packet_size]); let cnt = Arc::new(RelaxedCounter::new(0)); @@ -173,10 +175,14 @@ async fn main() -> Result<(), Box> { interval.tick().await; let now = cnt_avg.get(); let stat = format!( - "sent &[0; 1024] cnt: {:?}, +{:?}, moving average (100): {:?}", + "sent &[0; 1024 * {}] cnt: {:?} ({} KiB), +{:?} ({} KiB / 100ms), moving average (10 s): {:?} ({} KiB / 10 s)", + opts.packet_size, now, + now * opts.packet_size, now - last_time, - avg.get_average() + (now - last_time) * opts.packet_size, + avg.get_average(), + avg.get_average() * opts.packet_size, ); if is_term { print!("\x1b[2K{}\r", stat); @@ -208,13 +214,23 @@ async fn main() -> Result<(), Box> { })); } - let _ = select_all(threads.into_iter()).await; + let out = select_all(threads.into_iter()).await; + + if let Err(err) = out.0? { + println!("\n\nerr: {:?}", err); + } + + out.2.into_iter().for_each(|x| x.abort()); + + let duration_since = Instant::now().duration_since(start_time); println!( - "\n\nresults: {} packets of &[0; 1024 * {}] sent in {}", + "\n\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)", cnt.get(), opts.packet_size, - format_duration(Instant::now().duration_since(start_time)) + cnt.get() * opts.packet_size, + format_duration(duration_since), + (cnt.get() * opts.packet_size) as u64 / duration_since.as_secs(), ); Ok(())