allow blocking local addresses

This commit is contained in:
Toshit Chawda 2024-03-17 22:39:52 -07:00
parent ce86e7b095
commit 566fc38cc9
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D

View file

@ -1,4 +1,4 @@
#![feature(let_chains)] #![feature(let_chains, ip)]
use std::io::Error; use std::io::Error;
use bytes::Bytes; use bytes::Bytes;
@ -12,7 +12,7 @@ use hyper::{
body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode, body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode,
}; };
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use tokio::net::{TcpListener, TcpStream, UdpSocket}; use tokio::net::{lookup_host, TcpListener, TcpStream, UdpSocket};
#[cfg(unix)] #[cfg(unix)]
use tokio::net::{UnixListener, UnixStream}; use tokio::net::{UnixListener, UnixStream};
use tokio_util::codec::{BytesCodec, Framed}; use tokio_util::codec::{BytesCodec, Framed};
@ -23,17 +23,29 @@ use wisp_mux::{CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, Wis
type HttpBody = http_body_util::Full<hyper::body::Bytes>; type HttpBody = http_body_util::Full<hyper::body::Bytes>;
/// Server implementation of the Wisp protocol in Rust, made for epoxy
#[derive(Parser)] #[derive(Parser)]
#[command(version = clap::crate_version!(), about = "Server implementation of the Wisp protocol in Rust, made for epoxy.")] #[command(version = clap::crate_version!())]
struct Cli { struct Cli {
/// URL prefix the server should serve on
#[arg(long, default_value = "")] #[arg(long, default_value = "")]
prefix: String, prefix: String,
/// Port the server should bind to
#[arg(long, short, default_value = "4000")] #[arg(long, short, default_value = "4000")]
port: String, port: String,
/// Host the server should bind to
#[arg(long = "host", short, value_name = "HOST", default_value = "0.0.0.0")] #[arg(long = "host", short, value_name = "HOST", default_value = "0.0.0.0")]
bind_host: String, bind_host: String,
/// Whether the server should listen on a Unix socket located at the value of the bind_host
/// argument
#[arg(long, short)] #[arg(long, short)]
unix_socket: bool, unix_socket: bool,
/// Whether the server should block IP addresses that are not globally reachable
///
/// See https://doc.rust-lang.org/std/net/struct.Ipv4Addr.html#method.is_global for which IP
/// addresses are blocked
#[arg(long, short)]
block_local: bool,
} }
#[cfg(not(unix))] #[cfg(not(unix))]
@ -107,11 +119,12 @@ async fn main() -> Result<(), Error> {
println!("listening on `{}`", addr); println!("listening on `{}`", addr);
while let Ok((stream, addr)) = socket.accept().await { while let Ok((stream, addr)) = socket.accept().await {
let prefix_cloned = opt.prefix.clone(); let prefix = opt.prefix.clone();
tokio::spawn(async move { tokio::spawn(async move {
let io = TokioIo::new(stream); let io = TokioIo::new(stream);
let service = let service = service_fn(move |res| {
service_fn(move |res| accept_http(res, addr.clone(), prefix_cloned.clone())); accept_http(res, addr.clone(), prefix.clone(), opt.block_local)
});
let conn = http1::Builder::new() let conn = http1::Builder::new()
.serve_connection(io, service) .serve_connection(io, service)
.with_upgrades(); .with_upgrades();
@ -128,17 +141,18 @@ async fn accept_http(
mut req: Request<Incoming>, mut req: Request<Incoming>,
addr: String, addr: String,
prefix: String, prefix: String,
block_local: bool,
) -> Result<Response<HttpBody>, WebSocketError> { ) -> Result<Response<HttpBody>, WebSocketError> {
let uri = req.uri().clone().path().to_string(); let uri = req.uri().path().to_string();
if upgrade::is_upgrade_request(&req) if upgrade::is_upgrade_request(&req)
&& let Some(uri) = uri.strip_prefix(&prefix) && let Some(uri) = uri.strip_prefix(&prefix)
{ {
let (res, fut) = upgrade::upgrade(&mut req)?; let (res, fut) = upgrade::upgrade(&mut req)?;
if uri.is_empty() || uri == "/" { if uri == "/" {
tokio::spawn(async move { accept_ws(fut, addr.clone()).await }); tokio::spawn(async move { accept_ws(fut, addr.clone(), block_local).await });
} else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) { } else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) {
tokio::spawn(async move { accept_wsproxy(fut, uri, addr.clone()).await }); tokio::spawn(async move { accept_wsproxy(fut, uri, addr.clone(), block_local).await });
} }
Ok(Response::from_parts( Ok(Response::from_parts(
@ -202,6 +216,7 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result<bool
async fn accept_ws( async fn accept_ws(
fut: upgrade::UpgradeFut, fut: upgrade::UpgradeFut,
addr: String, addr: String,
block_local: bool,
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> { ) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
let (rx, tx) = fut.await?.split(tokio::io::split); let (rx, tx) = fut.await?.split(tokio::io::split);
let rx = FragmentCollectorRead::new(rx); let rx = FragmentCollectorRead::new(rx);
@ -218,6 +233,29 @@ async fn accept_ws(
while let Some((packet, stream)) = mux.server_new_stream().await { while let Some((packet, stream)) = mux.server_new_stream().await {
tokio::spawn(async move { tokio::spawn(async move {
if block_local {
match lookup_host(format!(
"{}:{}",
packet.destination_hostname, packet.destination_port
))
.await
.ok()
.and_then(|mut x| x.next())
.map(|x| !x.ip().is_global())
{
Some(true) => {
let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await;
return;
}
Some(false) => {}
None => {
let _ = stream
.close(CloseReason::ServerStreamConnectionRefused)
.await;
return;
}
}
}
let close_err = stream.get_close_handle(); let close_err = stream.get_close_handle();
let close_ok = stream.get_close_handle(); let close_ok = stream.get_close_handle();
let _ = handle_mux(packet, stream) let _ = handle_mux(packet, stream)
@ -244,18 +282,35 @@ async fn accept_wsproxy(
fut: upgrade::UpgradeFut, fut: upgrade::UpgradeFut,
incoming_uri: String, incoming_uri: String,
addr: String, addr: String,
block_local: bool,
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> { ) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
let mut ws_stream = FragmentCollector::new(fut.await?); let mut ws_stream = FragmentCollector::new(fut.await?);
println!("{:?}: connected (wsproxy): {:?}", addr, incoming_uri); println!("{:?}: connected (wsproxy): {:?}", addr, incoming_uri);
match hyper::Uri::try_from(incoming_uri.clone()) { if block_local {
Ok(_) => (), match lookup_host(&incoming_uri)
Err(err) => { .await
ws_stream .ok()
.write_frame(Frame::close(CloseCode::Away.into(), b"invalid uri")) .and_then(|mut x| x.next())
.await?; .map(|x| !x.ip().is_global())
return Err(Box::new(err)); {
Some(true) => {
ws_stream
.write_frame(Frame::close(CloseCode::Error.into(), b"blocked uri"))
.await?;
return Ok(());
}
Some(false) => {}
None => {
ws_stream
.write_frame(Frame::close(
CloseCode::Error.into(),
b"failed to resolve uri",
))
.await?;
return Ok(());
}
} }
} }
@ -263,7 +318,7 @@ async fn accept_wsproxy(
Ok(stream) => stream, Ok(stream) => stream,
Err(err) => { Err(err) => {
ws_stream ws_stream
.write_frame(Frame::close(CloseCode::Away.into(), b"failed to connect")) .write_frame(Frame::close(CloseCode::Error.into(), b"failed to connect"))
.await?; .await?;
return Err(Box::new(err)); return Err(Box::new(err));
} }