diff --git a/server/src/main.rs b/server/src/main.rs index 61d1897..12cc030 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,5 +1,5 @@ #![feature(let_chains, ip)] -use std::io::Error; +use std::{collections::HashMap, io::Error, path::PathBuf, sync::Arc}; use bytes::Bytes; use clap::Parser; @@ -20,8 +20,11 @@ use tokio_util::codec::{BytesCodec, Framed}; use tokio_util::either::Either; use wisp_mux::{ - extensions::udp::UdpProtocolExtensionBuilder, CloseReason, ConnectPacket, MuxStream, ServerMux, - StreamType, WispError, + extensions::{ + password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, + udp::UdpProtocolExtensionBuilder, + }, + CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, }; type HttpBody = http_body_util::Full; @@ -56,6 +59,20 @@ struct Cli { /// Whether the server should block ports other than 80 or 443 #[arg(long)] block_non_http: bool, + /// Path to a file containing `user:password` separated by newlines. This is plaintext!!! + /// + /// `user` cannot contain `:`. Whitespace will be trimmed. + #[arg(long)] + auth: Option, +} + +#[derive(Clone)] +struct MuxOptions { + pub block_local: bool, + pub block_udp: bool, + pub block_non_http: bool, + pub enforce_auth: bool, + pub auth: Arc, } #[cfg(not(unix))] @@ -138,19 +155,44 @@ async fn main() -> Result<(), Error> { "/".to_string() }; + let mut auth = HashMap::new(); + let enforce_auth = opt.auth.is_some(); + if let Some(file) = opt.auth { + let file = std::fs::read_to_string(file)?; + for entry in file.split('\n').filter_map(|x| { + if x.contains(':') { + Some(x.trim()) + } else { + None + } + }) { + let split: Vec<_> = entry.split(':').collect(); + let username = split[0]; + let password = split[1..].join(":"); + println!( + "adding username {:?} password {:?} to allowed auth", + username, password + ); + auth.insert(username.to_string(), password.to_string()); + } + } + let pw_ext = Arc::new(PasswordProtocolExtensionBuilder::new_server(auth)); + + let mux_options = MuxOptions { + block_local: opt.block_local, + block_non_http: opt.block_non_http, + block_udp: opt.block_udp, + auth: pw_ext, + enforce_auth, + }; + println!("listening on `{}` with prefix `{}`", addr, prefix); while let Ok((stream, addr)) = socket.accept().await { let prefix = prefix.clone(); + let mux_options = mux_options.clone(); tokio::spawn(async move { let service = service_fn(move |res| { - accept_http( - res, - addr.clone(), - prefix.clone(), - opt.block_local, - opt.block_udp, - opt.block_non_http, - ) + accept_http(res, addr.clone(), prefix.clone(), mux_options.clone()) }); let conn = http1::Builder::new() .serve_connection(TokioIo::new(stream), service) @@ -168,9 +210,7 @@ async fn accept_http( mut req: Request, addr: String, prefix: String, - block_local: bool, - block_udp: bool, - block_non_http: bool, + mux_options: MuxOptions, ) -> Result, WebSocketError> { let uri = req.uri().path().to_string(); if upgrade::is_upgrade_request(&req) @@ -179,12 +219,17 @@ async fn accept_http( let (res, fut) = upgrade::upgrade(&mut req)?; if uri.is_empty() { - tokio::spawn(async move { - accept_ws(fut, addr.clone(), block_local, block_udp, block_non_http).await - }); + tokio::spawn(async move { accept_ws(fut, addr.clone(), mux_options).await }); } else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) { tokio::spawn(async move { - accept_wsproxy(fut, uri, addr.clone(), block_local, block_non_http).await + accept_wsproxy( + fut, + uri, + addr.clone(), + mux_options.block_local, + mux_options.block_non_http, + ) + .await }); } @@ -258,19 +303,41 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result Result<(), Box> { let (rx, tx) = ws.await?.split(tokio::io::split); let rx = FragmentCollectorRead::new(rx); println!("{:?}: connected", addr); + let (mut mux, fut) = if mux_options.enforce_auth { + let (mut mux, fut) = ServerMux::new( + rx, + tx, + u32::MAX, + Some(&[&UdpProtocolExtensionBuilder(), mux_options.auth.as_ref()]), + ) + .await?; + if !mux + .supported_extension_ids + .iter() + .any(|x| *x == PasswordProtocolExtension::ID) + { + println!( + "{:?}: client did not support auth or password was invalid", + addr + ); + mux.close_extension_incompat().await?; + return Ok(()); + } + (mux, fut) + } else { + ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await? + }; - let (mut mux, fut) = - ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await?; - - println!("{:?}: downgraded: {} extensions supported: {:?}", addr, mux.downgraded, mux.supported_extension_ids); + println!( + "{:?}: downgraded: {} extensions supported: {:?}", + addr, mux.downgraded, mux.supported_extension_ids + ); tokio::spawn(async move { if let Err(e) = fut.await { @@ -280,14 +347,14 @@ async fn accept_ws( while let Some((packet, mut stream)) = mux.server_new_stream().await { tokio::spawn(async move { - if (block_non_http + if (mux_options.block_non_http && !(packet.destination_port == 80 || packet.destination_port == 443)) - || (block_udp && packet.stream_type == StreamType::Udp) + || (mux_options.block_udp && packet.stream_type == StreamType::Udp) { let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await; return; } - if block_local { + if mux_options.block_local { match lookup_host(format!( "{}:{}", packet.destination_hostname, packet.destination_port diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index 9c38cad..1f5802f 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -27,7 +27,14 @@ use tokio::{ }; use tokio_native_tls::{native_tls, TlsConnector}; use tokio_util::either::Either; -use wisp_mux::{extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, ClientMux, StreamType, WispError}; +use wisp_mux::{ + extensions::{ + password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, + udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, + ProtocolExtensionBuilder, + }, + ClientMux, StreamType, WispError, +}; #[derive(Debug)] enum WispClientError { @@ -80,6 +87,11 @@ struct Cli { /// Ask for UDP #[arg(short, long)] udp: bool, + /// Enable auth: format is `username:password` + /// + /// Usernames and passwords are sent in plaintext!! + #[arg(long)] + auth: Option, } #[tokio::main(flavor = "multi_thread")] @@ -103,6 +115,13 @@ async fn main() -> Result<(), Box> { let addr_dest = opts.tcp.ip().to_string(); let addr_dest_port = opts.tcp.port(); + let auth = opts.auth.map(|auth| { + let split: Vec<_> = auth.split(':').collect(); + let username = split[0].to_string(); + let password = split[1..].join(":"); + PasswordProtocolExtensionBuilder::new_client(username, password) + }); + println!( "connecting to {} and sending &[0; 1024 * {}] to {} with threads {}", opts.wisp, opts.packet_size, opts.tcp, opts.streams, @@ -133,18 +152,49 @@ async fn main() -> Result<(), Box> { let (rx, tx) = ws.split(tokio::io::split); let rx = FragmentCollectorRead::new(rx); - let (mut mux, fut) = if opts.udp { - let (mux, fut) = ClientMux::new(rx, tx, Some(&[&UdpProtocolExtensionBuilder()])).await?; - if !mux.supported_extension_ids.iter().any(|x| *x == UdpProtocolExtension::ID) { - println!("server did not support udp, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids); - exit(1); - } - (mux, fut) - } else { - ClientMux::new(rx, tx, Some(&[])).await? - }; + let mut extensions: Vec> = Vec::new(); + if opts.udp { + extensions.push(Box::new(UdpProtocolExtensionBuilder())); + } + let enforce_auth = auth.is_some(); + if let Some(auth) = auth { + extensions.push(Box::new(auth)); + } + let extensions_mapped: Vec<&(dyn ProtocolExtensionBuilder + Sync)> = + extensions.iter().map(|x| x.as_ref()).collect(); - println!("connected and created ClientMux, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids); + let (mut mux, fut) = ClientMux::new(rx, tx, Some(&extensions_mapped)).await?; + if opts.udp + && !mux + .supported_extension_ids + .iter() + .any(|x| *x == UdpProtocolExtension::ID) + { + println!( + "server did not support udp, was downgraded {}, extensions supported {:?}", + mux.downgraded, mux.supported_extension_ids + ); + mux.close_extension_incompat().await?; + exit(1); + } + if enforce_auth + && !mux + .supported_extension_ids + .iter() + .any(|x| *x == PasswordProtocolExtension::ID) + { + println!( + "server did not support passwords or password was incorrect, was downgraded {}, extensions supported {:?}", + mux.downgraded, mux.supported_extension_ids + ); + mux.close_extension_incompat().await?; + exit(1); + } + + println!( + "connected and created ClientMux, was downgraded {}, extensions supported {:?}", + mux.downgraded, mux.supported_extension_ids + ); let mut threads = Vec::with_capacity(opts.streams * 2 + 3); diff --git a/wisp/src/extensions.rs b/wisp/src/extensions.rs index dfefae6..661439c 100644 --- a/wisp/src/extensions.rs +++ b/wisp/src/extensions.rs @@ -202,6 +202,8 @@ pub mod udp { pub mod password { //! Password protocol extension. //! + //! Passwords are sent in plain text!! + //! //! # Example //! Server: //! ``` @@ -246,6 +248,7 @@ pub mod password { #[derive(Debug, Clone)] /// Password protocol extension. /// + /// **Passwords are sent in plain text!!** /// **This extension will panic when encoding if the username's length does not fit within a u8 /// or the password's length does not fit within a u16.** pub struct PasswordProtocolExtension { @@ -306,7 +309,7 @@ pub mod password { let password = Bytes::from(self.password.clone().into_bytes()); let username_len = u8::try_from(username.len()).expect("username was too long"); let password_len = - u16::try_from(username.len()).expect("password was too long"); + u16::try_from(password.len()).expect("password was too long"); let mut bytes = BytesMut::with_capacity(3 + username_len as usize + password_len as usize); @@ -380,6 +383,8 @@ pub mod password { } /// Password protocol extension builder. + /// + /// **Passwords are sent in plain text!!** pub struct PasswordProtocolExtensionBuilder { /// Map of users and their passwords to allow. Only used on server. pub users: HashMap, @@ -448,6 +453,7 @@ pub mod password { if *user.1 != password { return Err(EError::InvalidPassword.into()); } + Ok(PasswordProtocolExtension { username, password, diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 3c77584..d8f7dd0 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -64,6 +64,8 @@ pub enum WispError { UriHasNoPort, /// The max stream count was reached. MaxStreamCountReached, + /// The Wisp protocol version was incompatible. + IncompatibleProtocolVersion, /// The stream had already been closed. StreamAlreadyClosed, /// The websocket frame received had an invalid type. @@ -117,6 +119,7 @@ impl std::fmt::Display for WispError { Self::UriHasNoHost => write!(f, "URI has no host"), Self::UriHasNoPort => write!(f, "URI has no port"), Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"), + Self::IncompatibleProtocolVersion => write!(f, "Incompatible Wisp protocol version"), Self::StreamAlreadyClosed => write!(f, "Stream already closed"), Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"), Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"), @@ -286,7 +289,15 @@ impl MuxInner { let _ = channel.send(Err(WispError::InvalidStreamId)); } } - WsEvent::EndFut => break, + WsEvent::EndFut(x) => { + if let Some(reason) = x { + let _ = self + .tx + .write_frame(Packet::new_close(0, reason).into()) + .await; + } + break; + } } } } @@ -364,6 +375,9 @@ impl MuxInner { } Continue(_) | Info(_) => break Err(WispError::InvalidPacketType), Close(_) => { + if packet.stream_id == 0 { + break Ok(()); + } if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { stream.is_closed.store(true, Ordering::Release); stream.stream.disconnect(); @@ -410,6 +424,9 @@ impl MuxInner { } } Close(_) => { + if packet.stream_id == 0 { + break Ok(()); + } if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { stream.is_closed.store(true, Ordering::Release); stream.stream.disconnect(); @@ -532,15 +549,28 @@ impl ServerMux { self.muxstream_recv.next().await } + async fn close_internal(&mut self, reason: Option) -> Result<(), WispError> { + self.close_tx + .send(WsEvent::EndFut(reason)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend) + } + /// Close all streams. /// /// Also terminates the multiplexor future. Waiting for a new stream will never succeed after /// this function is called. pub async fn close(&mut self) -> Result<(), WispError> { - self.close_tx - .send(WsEvent::EndFut) + self.close_internal(None).await + } + + /// Close all streams and send an extension incompatibility error to the client. + /// + /// Also terminates the multiplexor future. Waiting for a new stream will never succed after + /// this function is called. + pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> { + self.close_internal(Some(CloseReason::IncompatibleExtensions)) .await - .map_err(|_| WispError::MuxMessageFailedToSend) } } /// Client side multiplexor. @@ -600,7 +630,7 @@ impl ClientMux { x = read.wisp_read_frame(&write).fuse() => Some(x?), _ = Delay::new(Duration::from_secs(5)).fuse() => None } { - let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; + let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?; if let PacketType::Info(info) = packet.packet_type { supported_extensions = info .extensions @@ -671,14 +701,27 @@ impl ClientMux { rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? } + async fn close_internal(&mut self, reason: Option) -> Result<(), WispError> { + self.close_tx + .send(WsEvent::EndFut(reason)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend) + } + /// Close all streams. /// /// Also terminates the multiplexor future. Creating a stream is UB after calling this /// function. pub async fn close(&mut self) -> Result<(), WispError> { - self.close_tx - .send(WsEvent::EndFut) + self.close_internal(None).await + } + + /// Close all streams and send an extension incompatibility error to the client. + /// + /// Also terminates the multiplexor future. Creating a stream is UB after calling this + /// function. + pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> { + self.close_internal(Some(CloseReason::IncompatibleExtensions)) .await - .map_err(|_| WispError::MuxMessageFailedToSend) } } diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 9ff6a3c..41554a3 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -446,6 +446,10 @@ impl Packet { minor: bytes.get_u8(), }; + if version.major != WISP_VERSION.major { + return Err(WispError::IncompatibleProtocolVersion); + } + let mut extensions = Vec::new(); while bytes.remaining() > 4 { diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index f579140..69c711b 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -27,7 +27,7 @@ pub(crate) enum WsEvent { u16, oneshot::Sender>, ), - EndFut, + EndFut(Option), } /// Read side of a multiplexor stream.