fix password protocol extension, respect stream id 0 close packets, allow sending stream id 0 close packets

This commit is contained in:
Toshit Chawda 2024-04-13 22:34:26 -07:00
parent 4d433b60c4
commit d10b7691e4
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
6 changed files with 220 additions and 50 deletions

View file

@ -1,5 +1,5 @@
#![feature(let_chains, ip)] #![feature(let_chains, ip)]
use std::io::Error; use std::{collections::HashMap, io::Error, path::PathBuf, sync::Arc};
use bytes::Bytes; use bytes::Bytes;
use clap::Parser; use clap::Parser;
@ -20,8 +20,11 @@ use tokio_util::codec::{BytesCodec, Framed};
use tokio_util::either::Either; use tokio_util::either::Either;
use wisp_mux::{ use wisp_mux::{
extensions::udp::UdpProtocolExtensionBuilder, CloseReason, ConnectPacket, MuxStream, ServerMux, extensions::{
StreamType, WispError, password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
udp::UdpProtocolExtensionBuilder,
},
CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError,
}; };
type HttpBody = http_body_util::Full<hyper::body::Bytes>; type HttpBody = http_body_util::Full<hyper::body::Bytes>;
@ -56,6 +59,20 @@ struct Cli {
/// Whether the server should block ports other than 80 or 443 /// Whether the server should block ports other than 80 or 443
#[arg(long)] #[arg(long)]
block_non_http: bool, block_non_http: bool,
/// Path to a file containing `user:password` separated by newlines. This is plaintext!!!
///
/// `user` cannot contain `:`. Whitespace will be trimmed.
#[arg(long)]
auth: Option<PathBuf>,
}
#[derive(Clone)]
struct MuxOptions {
pub block_local: bool,
pub block_udp: bool,
pub block_non_http: bool,
pub enforce_auth: bool,
pub auth: Arc<PasswordProtocolExtensionBuilder>,
} }
#[cfg(not(unix))] #[cfg(not(unix))]
@ -138,19 +155,44 @@ async fn main() -> Result<(), Error> {
"/".to_string() "/".to_string()
}; };
let mut auth = HashMap::new();
let enforce_auth = opt.auth.is_some();
if let Some(file) = opt.auth {
let file = std::fs::read_to_string(file)?;
for entry in file.split('\n').filter_map(|x| {
if x.contains(':') {
Some(x.trim())
} else {
None
}
}) {
let split: Vec<_> = entry.split(':').collect();
let username = split[0];
let password = split[1..].join(":");
println!(
"adding username {:?} password {:?} to allowed auth",
username, password
);
auth.insert(username.to_string(), password.to_string());
}
}
let pw_ext = Arc::new(PasswordProtocolExtensionBuilder::new_server(auth));
let mux_options = MuxOptions {
block_local: opt.block_local,
block_non_http: opt.block_non_http,
block_udp: opt.block_udp,
auth: pw_ext,
enforce_auth,
};
println!("listening on `{}` with prefix `{}`", addr, prefix); println!("listening on `{}` with prefix `{}`", addr, prefix);
while let Ok((stream, addr)) = socket.accept().await { while let Ok((stream, addr)) = socket.accept().await {
let prefix = prefix.clone(); let prefix = prefix.clone();
let mux_options = mux_options.clone();
tokio::spawn(async move { tokio::spawn(async move {
let service = service_fn(move |res| { let service = service_fn(move |res| {
accept_http( accept_http(res, addr.clone(), prefix.clone(), mux_options.clone())
res,
addr.clone(),
prefix.clone(),
opt.block_local,
opt.block_udp,
opt.block_non_http,
)
}); });
let conn = http1::Builder::new() let conn = http1::Builder::new()
.serve_connection(TokioIo::new(stream), service) .serve_connection(TokioIo::new(stream), service)
@ -168,9 +210,7 @@ async fn accept_http(
mut req: Request<Incoming>, mut req: Request<Incoming>,
addr: String, addr: String,
prefix: String, prefix: String,
block_local: bool, mux_options: MuxOptions,
block_udp: bool,
block_non_http: bool,
) -> Result<Response<HttpBody>, WebSocketError> { ) -> Result<Response<HttpBody>, WebSocketError> {
let uri = req.uri().path().to_string(); let uri = req.uri().path().to_string();
if upgrade::is_upgrade_request(&req) if upgrade::is_upgrade_request(&req)
@ -179,12 +219,17 @@ async fn accept_http(
let (res, fut) = upgrade::upgrade(&mut req)?; let (res, fut) = upgrade::upgrade(&mut req)?;
if uri.is_empty() { if uri.is_empty() {
tokio::spawn(async move { tokio::spawn(async move { accept_ws(fut, addr.clone(), mux_options).await });
accept_ws(fut, addr.clone(), block_local, block_udp, block_non_http).await
});
} else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) { } else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) {
tokio::spawn(async move { tokio::spawn(async move {
accept_wsproxy(fut, uri, addr.clone(), block_local, block_non_http).await accept_wsproxy(
fut,
uri,
addr.clone(),
mux_options.block_local,
mux_options.block_non_http,
)
.await
}); });
} }
@ -258,19 +303,41 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result<bool
async fn accept_ws( async fn accept_ws(
ws: UpgradeFut, ws: UpgradeFut,
addr: String, addr: String,
block_local: bool, mux_options: MuxOptions,
block_non_http: bool,
block_udp: bool,
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> { ) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
let (rx, tx) = ws.await?.split(tokio::io::split); let (rx, tx) = ws.await?.split(tokio::io::split);
let rx = FragmentCollectorRead::new(rx); let rx = FragmentCollectorRead::new(rx);
println!("{:?}: connected", addr); println!("{:?}: connected", addr);
let (mut mux, fut) = if mux_options.enforce_auth {
let (mut mux, fut) = ServerMux::new(
rx,
tx,
u32::MAX,
Some(&[&UdpProtocolExtensionBuilder(), mux_options.auth.as_ref()]),
)
.await?;
if !mux
.supported_extension_ids
.iter()
.any(|x| *x == PasswordProtocolExtension::ID)
{
println!(
"{:?}: client did not support auth or password was invalid",
addr
);
mux.close_extension_incompat().await?;
return Ok(());
}
(mux, fut)
} else {
ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await?
};
let (mut mux, fut) = println!(
ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await?; "{:?}: downgraded: {} extensions supported: {:?}",
addr, mux.downgraded, mux.supported_extension_ids
println!("{:?}: downgraded: {} extensions supported: {:?}", addr, mux.downgraded, mux.supported_extension_ids); );
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = fut.await { if let Err(e) = fut.await {
@ -280,14 +347,14 @@ async fn accept_ws(
while let Some((packet, mut stream)) = mux.server_new_stream().await { while let Some((packet, mut stream)) = mux.server_new_stream().await {
tokio::spawn(async move { tokio::spawn(async move {
if (block_non_http if (mux_options.block_non_http
&& !(packet.destination_port == 80 || packet.destination_port == 443)) && !(packet.destination_port == 80 || packet.destination_port == 443))
|| (block_udp && packet.stream_type == StreamType::Udp) || (mux_options.block_udp && packet.stream_type == StreamType::Udp)
{ {
let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await; let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await;
return; return;
} }
if block_local { if mux_options.block_local {
match lookup_host(format!( match lookup_host(format!(
"{}:{}", "{}:{}",
packet.destination_hostname, packet.destination_port packet.destination_hostname, packet.destination_port

View file

@ -27,7 +27,14 @@ use tokio::{
}; };
use tokio_native_tls::{native_tls, TlsConnector}; use tokio_native_tls::{native_tls, TlsConnector};
use tokio_util::either::Either; use tokio_util::either::Either;
use wisp_mux::{extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, ClientMux, StreamType, WispError}; use wisp_mux::{
extensions::{
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder},
ProtocolExtensionBuilder,
},
ClientMux, StreamType, WispError,
};
#[derive(Debug)] #[derive(Debug)]
enum WispClientError { enum WispClientError {
@ -80,6 +87,11 @@ struct Cli {
/// Ask for UDP /// Ask for UDP
#[arg(short, long)] #[arg(short, long)]
udp: bool, udp: bool,
/// Enable auth: format is `username:password`
///
/// Usernames and passwords are sent in plaintext!!
#[arg(long)]
auth: Option<String>,
} }
#[tokio::main(flavor = "multi_thread")] #[tokio::main(flavor = "multi_thread")]
@ -103,6 +115,13 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let addr_dest = opts.tcp.ip().to_string(); let addr_dest = opts.tcp.ip().to_string();
let addr_dest_port = opts.tcp.port(); let addr_dest_port = opts.tcp.port();
let auth = opts.auth.map(|auth| {
let split: Vec<_> = auth.split(':').collect();
let username = split[0].to_string();
let password = split[1..].join(":");
PasswordProtocolExtensionBuilder::new_client(username, password)
});
println!( println!(
"connecting to {} and sending &[0; 1024 * {}] to {} with threads {}", "connecting to {} and sending &[0; 1024 * {}] to {} with threads {}",
opts.wisp, opts.packet_size, opts.tcp, opts.streams, opts.wisp, opts.packet_size, opts.tcp, opts.streams,
@ -133,18 +152,49 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let (rx, tx) = ws.split(tokio::io::split); let (rx, tx) = ws.split(tokio::io::split);
let rx = FragmentCollectorRead::new(rx); let rx = FragmentCollectorRead::new(rx);
let (mut mux, fut) = if opts.udp { let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Sync)>> = Vec::new();
let (mux, fut) = ClientMux::new(rx, tx, Some(&[&UdpProtocolExtensionBuilder()])).await?; if opts.udp {
if !mux.supported_extension_ids.iter().any(|x| *x == UdpProtocolExtension::ID) { extensions.push(Box::new(UdpProtocolExtensionBuilder()));
println!("server did not support udp, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids); }
exit(1); let enforce_auth = auth.is_some();
} if let Some(auth) = auth {
(mux, fut) extensions.push(Box::new(auth));
} else { }
ClientMux::new(rx, tx, Some(&[])).await? let extensions_mapped: Vec<&(dyn ProtocolExtensionBuilder + Sync)> =
}; extensions.iter().map(|x| x.as_ref()).collect();
println!("connected and created ClientMux, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids); let (mut mux, fut) = ClientMux::new(rx, tx, Some(&extensions_mapped)).await?;
if opts.udp
&& !mux
.supported_extension_ids
.iter()
.any(|x| *x == UdpProtocolExtension::ID)
{
println!(
"server did not support udp, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids
);
mux.close_extension_incompat().await?;
exit(1);
}
if enforce_auth
&& !mux
.supported_extension_ids
.iter()
.any(|x| *x == PasswordProtocolExtension::ID)
{
println!(
"server did not support passwords or password was incorrect, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids
);
mux.close_extension_incompat().await?;
exit(1);
}
println!(
"connected and created ClientMux, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids
);
let mut threads = Vec::with_capacity(opts.streams * 2 + 3); let mut threads = Vec::with_capacity(opts.streams * 2 + 3);

View file

@ -202,6 +202,8 @@ pub mod udp {
pub mod password { pub mod password {
//! Password protocol extension. //! Password protocol extension.
//! //!
//! Passwords are sent in plain text!!
//!
//! # Example //! # Example
//! Server: //! Server:
//! ``` //! ```
@ -246,6 +248,7 @@ pub mod password {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
/// Password protocol extension. /// Password protocol extension.
/// ///
/// **Passwords are sent in plain text!!**
/// **This extension will panic when encoding if the username's length does not fit within a u8 /// **This extension will panic when encoding if the username's length does not fit within a u8
/// or the password's length does not fit within a u16.** /// or the password's length does not fit within a u16.**
pub struct PasswordProtocolExtension { pub struct PasswordProtocolExtension {
@ -306,7 +309,7 @@ pub mod password {
let password = Bytes::from(self.password.clone().into_bytes()); let password = Bytes::from(self.password.clone().into_bytes());
let username_len = u8::try_from(username.len()).expect("username was too long"); let username_len = u8::try_from(username.len()).expect("username was too long");
let password_len = let password_len =
u16::try_from(username.len()).expect("password was too long"); u16::try_from(password.len()).expect("password was too long");
let mut bytes = let mut bytes =
BytesMut::with_capacity(3 + username_len as usize + password_len as usize); BytesMut::with_capacity(3 + username_len as usize + password_len as usize);
@ -380,6 +383,8 @@ pub mod password {
} }
/// Password protocol extension builder. /// Password protocol extension builder.
///
/// **Passwords are sent in plain text!!**
pub struct PasswordProtocolExtensionBuilder { pub struct PasswordProtocolExtensionBuilder {
/// Map of users and their passwords to allow. Only used on server. /// Map of users and their passwords to allow. Only used on server.
pub users: HashMap<String, String>, pub users: HashMap<String, String>,
@ -448,6 +453,7 @@ pub mod password {
if *user.1 != password { if *user.1 != password {
return Err(EError::InvalidPassword.into()); return Err(EError::InvalidPassword.into());
} }
Ok(PasswordProtocolExtension { Ok(PasswordProtocolExtension {
username, username,
password, password,

View file

@ -64,6 +64,8 @@ pub enum WispError {
UriHasNoPort, UriHasNoPort,
/// The max stream count was reached. /// The max stream count was reached.
MaxStreamCountReached, MaxStreamCountReached,
/// The Wisp protocol version was incompatible.
IncompatibleProtocolVersion,
/// The stream had already been closed. /// The stream had already been closed.
StreamAlreadyClosed, StreamAlreadyClosed,
/// The websocket frame received had an invalid type. /// The websocket frame received had an invalid type.
@ -117,6 +119,7 @@ impl std::fmt::Display for WispError {
Self::UriHasNoHost => write!(f, "URI has no host"), Self::UriHasNoHost => write!(f, "URI has no host"),
Self::UriHasNoPort => write!(f, "URI has no port"), Self::UriHasNoPort => write!(f, "URI has no port"),
Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"), Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
Self::IncompatibleProtocolVersion => write!(f, "Incompatible Wisp protocol version"),
Self::StreamAlreadyClosed => write!(f, "Stream already closed"), Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"), Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"), Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
@ -286,7 +289,15 @@ impl MuxInner {
let _ = channel.send(Err(WispError::InvalidStreamId)); let _ = channel.send(Err(WispError::InvalidStreamId));
} }
} }
WsEvent::EndFut => break, WsEvent::EndFut(x) => {
if let Some(reason) = x {
let _ = self
.tx
.write_frame(Packet::new_close(0, reason).into())
.await;
}
break;
}
} }
} }
} }
@ -364,6 +375,9 @@ impl MuxInner {
} }
Continue(_) | Info(_) => break Err(WispError::InvalidPacketType), Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
Close(_) => { Close(_) => {
if packet.stream_id == 0 {
break Ok(());
}
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release); stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect(); stream.stream.disconnect();
@ -410,6 +424,9 @@ impl MuxInner {
} }
} }
Close(_) => { Close(_) => {
if packet.stream_id == 0 {
break Ok(());
}
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release); stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect(); stream.stream.disconnect();
@ -532,15 +549,28 @@ impl ServerMux {
self.muxstream_recv.next().await self.muxstream_recv.next().await
} }
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut(reason))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
/// Close all streams. /// Close all streams.
/// ///
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after /// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
/// this function is called. /// this function is called.
pub async fn close(&mut self) -> Result<(), WispError> { pub async fn close(&mut self) -> Result<(), WispError> {
self.close_tx self.close_internal(None).await
.send(WsEvent::EndFut) }
/// Close all streams and send an extension incompatibility error to the client.
///
/// Also terminates the multiplexor future. Waiting for a new stream will never succed after
/// this function is called.
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await .await
.map_err(|_| WispError::MuxMessageFailedToSend)
} }
} }
/// Client side multiplexor. /// Client side multiplexor.
@ -600,7 +630,7 @@ impl ClientMux {
x = read.wisp_read_frame(&write).fuse() => Some(x?), x = read.wisp_read_frame(&write).fuse() => Some(x?),
_ = Delay::new(Duration::from_secs(5)).fuse() => None _ = Delay::new(Duration::from_secs(5)).fuse() => None
} { } {
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?;
if let PacketType::Info(info) = packet.packet_type { if let PacketType::Info(info) = packet.packet_type {
supported_extensions = info supported_extensions = info
.extensions .extensions
@ -671,14 +701,27 @@ impl ClientMux {
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
} }
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut(reason))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
/// Close all streams. /// Close all streams.
/// ///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this /// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function. /// function.
pub async fn close(&mut self) -> Result<(), WispError> { pub async fn close(&mut self) -> Result<(), WispError> {
self.close_tx self.close_internal(None).await
.send(WsEvent::EndFut) }
/// Close all streams and send an extension incompatibility error to the client.
///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function.
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await .await
.map_err(|_| WispError::MuxMessageFailedToSend)
} }
} }

View file

@ -446,6 +446,10 @@ impl Packet {
minor: bytes.get_u8(), minor: bytes.get_u8(),
}; };
if version.major != WISP_VERSION.major {
return Err(WispError::IncompatibleProtocolVersion);
}
let mut extensions = Vec::new(); let mut extensions = Vec::new();
while bytes.remaining() > 4 { while bytes.remaining() > 4 {

View file

@ -27,7 +27,7 @@ pub(crate) enum WsEvent {
u16, u16,
oneshot::Sender<Result<MuxStream, WispError>>, oneshot::Sender<Result<MuxStream, WispError>>,
), ),
EndFut, EndFut(Option<CloseReason>),
} }
/// Read side of a multiplexor stream. /// Read side of a multiplexor stream.