From 566fc38cc9b68b5ddb561f692413579c6bfde618 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sun, 17 Mar 2024 22:39:52 -0700 Subject: [PATCH] allow blocking local addresses --- server/src/main.rs | 91 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 73 insertions(+), 18 deletions(-) diff --git a/server/src/main.rs b/server/src/main.rs index 6364c79..a10a363 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,4 +1,4 @@ -#![feature(let_chains)] +#![feature(let_chains, ip)] use std::io::Error; use bytes::Bytes; @@ -12,7 +12,7 @@ use hyper::{ body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode, }; use hyper_util::rt::TokioIo; -use tokio::net::{TcpListener, TcpStream, UdpSocket}; +use tokio::net::{lookup_host, TcpListener, TcpStream, UdpSocket}; #[cfg(unix)] use tokio::net::{UnixListener, UnixStream}; use tokio_util::codec::{BytesCodec, Framed}; @@ -23,17 +23,29 @@ use wisp_mux::{CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, Wis type HttpBody = http_body_util::Full; +/// Server implementation of the Wisp protocol in Rust, made for epoxy #[derive(Parser)] -#[command(version = clap::crate_version!(), about = "Server implementation of the Wisp protocol in Rust, made for epoxy.")] +#[command(version = clap::crate_version!())] struct Cli { + /// URL prefix the server should serve on #[arg(long, default_value = "")] prefix: String, + /// Port the server should bind to #[arg(long, short, default_value = "4000")] port: String, + /// Host the server should bind to #[arg(long = "host", short, value_name = "HOST", default_value = "0.0.0.0")] bind_host: String, + /// Whether the server should listen on a Unix socket located at the value of the bind_host + /// argument #[arg(long, short)] unix_socket: bool, + /// Whether the server should block IP addresses that are not globally reachable + /// + /// See https://doc.rust-lang.org/std/net/struct.Ipv4Addr.html#method.is_global for which IP + /// addresses are blocked + #[arg(long, short)] + block_local: bool, } #[cfg(not(unix))] @@ -107,11 +119,12 @@ async fn main() -> Result<(), Error> { println!("listening on `{}`", addr); while let Ok((stream, addr)) = socket.accept().await { - let prefix_cloned = opt.prefix.clone(); + let prefix = opt.prefix.clone(); tokio::spawn(async move { let io = TokioIo::new(stream); - let service = - service_fn(move |res| accept_http(res, addr.clone(), prefix_cloned.clone())); + let service = service_fn(move |res| { + accept_http(res, addr.clone(), prefix.clone(), opt.block_local) + }); let conn = http1::Builder::new() .serve_connection(io, service) .with_upgrades(); @@ -128,17 +141,18 @@ async fn accept_http( mut req: Request, addr: String, prefix: String, + block_local: bool, ) -> Result, WebSocketError> { - let uri = req.uri().clone().path().to_string(); + let uri = req.uri().path().to_string(); if upgrade::is_upgrade_request(&req) && let Some(uri) = uri.strip_prefix(&prefix) { let (res, fut) = upgrade::upgrade(&mut req)?; - if uri.is_empty() || uri == "/" { - tokio::spawn(async move { accept_ws(fut, addr.clone()).await }); + if uri == "/" { + tokio::spawn(async move { accept_ws(fut, addr.clone(), block_local).await }); } else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) { - tokio::spawn(async move { accept_wsproxy(fut, uri, addr.clone()).await }); + tokio::spawn(async move { accept_wsproxy(fut, uri, addr.clone(), block_local).await }); } Ok(Response::from_parts( @@ -202,6 +216,7 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result Result<(), Box> { let (rx, tx) = fut.await?.split(tokio::io::split); let rx = FragmentCollectorRead::new(rx); @@ -218,6 +233,29 @@ async fn accept_ws( while let Some((packet, stream)) = mux.server_new_stream().await { tokio::spawn(async move { + if block_local { + match lookup_host(format!( + "{}:{}", + packet.destination_hostname, packet.destination_port + )) + .await + .ok() + .and_then(|mut x| x.next()) + .map(|x| !x.ip().is_global()) + { + Some(true) => { + let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await; + return; + } + Some(false) => {} + None => { + let _ = stream + .close(CloseReason::ServerStreamConnectionRefused) + .await; + return; + } + } + } let close_err = stream.get_close_handle(); let close_ok = stream.get_close_handle(); let _ = handle_mux(packet, stream) @@ -244,18 +282,35 @@ async fn accept_wsproxy( fut: upgrade::UpgradeFut, incoming_uri: String, addr: String, + block_local: bool, ) -> Result<(), Box> { let mut ws_stream = FragmentCollector::new(fut.await?); println!("{:?}: connected (wsproxy): {:?}", addr, incoming_uri); - match hyper::Uri::try_from(incoming_uri.clone()) { - Ok(_) => (), - Err(err) => { - ws_stream - .write_frame(Frame::close(CloseCode::Away.into(), b"invalid uri")) - .await?; - return Err(Box::new(err)); + if block_local { + match lookup_host(&incoming_uri) + .await + .ok() + .and_then(|mut x| x.next()) + .map(|x| !x.ip().is_global()) + { + Some(true) => { + ws_stream + .write_frame(Frame::close(CloseCode::Error.into(), b"blocked uri")) + .await?; + return Ok(()); + } + Some(false) => {} + None => { + ws_stream + .write_frame(Frame::close( + CloseCode::Error.into(), + b"failed to resolve uri", + )) + .await?; + return Ok(()); + } } } @@ -263,7 +318,7 @@ async fn accept_wsproxy( Ok(stream) => stream, Err(err) => { ws_stream - .write_frame(Frame::close(CloseCode::Away.into(), b"failed to connect")) + .write_frame(Frame::close(CloseCode::Error.into(), b"failed to connect")) .await?; return Err(Box::new(err)); }