fix password protocol extension, respect stream id 0 close packets, allow sending stream id 0 close packets

This commit is contained in:
Toshit Chawda 2024-04-13 22:34:26 -07:00
parent 4d433b60c4
commit d10b7691e4
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
6 changed files with 220 additions and 50 deletions

View file

@ -1,5 +1,5 @@
#![feature(let_chains, ip)]
use std::io::Error;
use std::{collections::HashMap, io::Error, path::PathBuf, sync::Arc};
use bytes::Bytes;
use clap::Parser;
@ -20,8 +20,11 @@ use tokio_util::codec::{BytesCodec, Framed};
use tokio_util::either::Either;
use wisp_mux::{
extensions::udp::UdpProtocolExtensionBuilder, CloseReason, ConnectPacket, MuxStream, ServerMux,
StreamType, WispError,
extensions::{
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
udp::UdpProtocolExtensionBuilder,
},
CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError,
};
type HttpBody = http_body_util::Full<hyper::body::Bytes>;
@ -56,6 +59,20 @@ struct Cli {
/// Whether the server should block ports other than 80 or 443
#[arg(long)]
block_non_http: bool,
/// Path to a file containing `user:password` separated by newlines. This is plaintext!!!
///
/// `user` cannot contain `:`. Whitespace will be trimmed.
#[arg(long)]
auth: Option<PathBuf>,
}
#[derive(Clone)]
struct MuxOptions {
pub block_local: bool,
pub block_udp: bool,
pub block_non_http: bool,
pub enforce_auth: bool,
pub auth: Arc<PasswordProtocolExtensionBuilder>,
}
#[cfg(not(unix))]
@ -138,19 +155,44 @@ async fn main() -> Result<(), Error> {
"/".to_string()
};
let mut auth = HashMap::new();
let enforce_auth = opt.auth.is_some();
if let Some(file) = opt.auth {
let file = std::fs::read_to_string(file)?;
for entry in file.split('\n').filter_map(|x| {
if x.contains(':') {
Some(x.trim())
} else {
None
}
}) {
let split: Vec<_> = entry.split(':').collect();
let username = split[0];
let password = split[1..].join(":");
println!(
"adding username {:?} password {:?} to allowed auth",
username, password
);
auth.insert(username.to_string(), password.to_string());
}
}
let pw_ext = Arc::new(PasswordProtocolExtensionBuilder::new_server(auth));
let mux_options = MuxOptions {
block_local: opt.block_local,
block_non_http: opt.block_non_http,
block_udp: opt.block_udp,
auth: pw_ext,
enforce_auth,
};
println!("listening on `{}` with prefix `{}`", addr, prefix);
while let Ok((stream, addr)) = socket.accept().await {
let prefix = prefix.clone();
let mux_options = mux_options.clone();
tokio::spawn(async move {
let service = service_fn(move |res| {
accept_http(
res,
addr.clone(),
prefix.clone(),
opt.block_local,
opt.block_udp,
opt.block_non_http,
)
accept_http(res, addr.clone(), prefix.clone(), mux_options.clone())
});
let conn = http1::Builder::new()
.serve_connection(TokioIo::new(stream), service)
@ -168,9 +210,7 @@ async fn accept_http(
mut req: Request<Incoming>,
addr: String,
prefix: String,
block_local: bool,
block_udp: bool,
block_non_http: bool,
mux_options: MuxOptions,
) -> Result<Response<HttpBody>, WebSocketError> {
let uri = req.uri().path().to_string();
if upgrade::is_upgrade_request(&req)
@ -179,12 +219,17 @@ async fn accept_http(
let (res, fut) = upgrade::upgrade(&mut req)?;
if uri.is_empty() {
tokio::spawn(async move {
accept_ws(fut, addr.clone(), block_local, block_udp, block_non_http).await
});
tokio::spawn(async move { accept_ws(fut, addr.clone(), mux_options).await });
} else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) {
tokio::spawn(async move {
accept_wsproxy(fut, uri, addr.clone(), block_local, block_non_http).await
accept_wsproxy(
fut,
uri,
addr.clone(),
mux_options.block_local,
mux_options.block_non_http,
)
.await
});
}
@ -258,19 +303,41 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result<bool
async fn accept_ws(
ws: UpgradeFut,
addr: String,
block_local: bool,
block_non_http: bool,
block_udp: bool,
mux_options: MuxOptions,
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
let (rx, tx) = ws.await?.split(tokio::io::split);
let rx = FragmentCollectorRead::new(rx);
println!("{:?}: connected", addr);
let (mut mux, fut) = if mux_options.enforce_auth {
let (mut mux, fut) = ServerMux::new(
rx,
tx,
u32::MAX,
Some(&[&UdpProtocolExtensionBuilder(), mux_options.auth.as_ref()]),
)
.await?;
if !mux
.supported_extension_ids
.iter()
.any(|x| *x == PasswordProtocolExtension::ID)
{
println!(
"{:?}: client did not support auth or password was invalid",
addr
);
mux.close_extension_incompat().await?;
return Ok(());
}
(mux, fut)
} else {
ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await?
};
let (mut mux, fut) =
ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await?;
println!("{:?}: downgraded: {} extensions supported: {:?}", addr, mux.downgraded, mux.supported_extension_ids);
println!(
"{:?}: downgraded: {} extensions supported: {:?}",
addr, mux.downgraded, mux.supported_extension_ids
);
tokio::spawn(async move {
if let Err(e) = fut.await {
@ -280,14 +347,14 @@ async fn accept_ws(
while let Some((packet, mut stream)) = mux.server_new_stream().await {
tokio::spawn(async move {
if (block_non_http
if (mux_options.block_non_http
&& !(packet.destination_port == 80 || packet.destination_port == 443))
|| (block_udp && packet.stream_type == StreamType::Udp)
|| (mux_options.block_udp && packet.stream_type == StreamType::Udp)
{
let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await;
return;
}
if block_local {
if mux_options.block_local {
match lookup_host(format!(
"{}:{}",
packet.destination_hostname, packet.destination_port