From 9905f45a9e6ee019919dcc58b58c721d4d2639e6 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 27 Jul 2024 23:33:00 -0700 Subject: [PATCH] add stream type specific whitelist/blacklists --- server/src/config.rs | 46 ++++++++++++++++++++++++++++++++++++++++++++ server/src/stream.rs | 43 +++++++++++++++++++++++++++++++++-------- 2 files changed, 81 insertions(+), 8 deletions(-) diff --git a/server/src/config.rs b/server/src/config.rs index c5a809e..338d7d0 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -103,6 +103,16 @@ pub struct StreamConfig { /// Whether or not to allow connections to non-globally-routable IP addresses. pub allow_non_global: bool, + /// Regex whitelist of hosts for TCP connections. + pub allow_tcp_hosts: Vec, + /// Regex blacklist of hosts for TCP connections. + pub block_tcp_hosts: Vec, + + /// Regex whitelist of hosts for UDP connections. + pub allow_udp_hosts: Vec, + /// Regex blacklist of hosts for UDP connections. + pub block_udp_hosts: Vec, + /// Regex whitelist of hosts. pub allow_hosts: Vec, /// Regex blacklist of hosts. @@ -131,6 +141,12 @@ struct ConfigCache { pub allowed_hosts: RegexSet, pub blocked_hosts: RegexSet, + pub allowed_tcp_hosts: RegexSet, + pub blocked_tcp_hosts: RegexSet, + + pub allowed_udp_hosts: RegexSet, + pub blocked_udp_hosts: RegexSet, + pub wisp_config: (Option>, u32), } @@ -149,8 +165,16 @@ lazy_static! { .iter() .map(|x| x[0]..=x[1]) .collect(), + allowed_hosts: RegexSet::new(&CONFIG.stream.allow_hosts).unwrap(), blocked_hosts: RegexSet::new(&CONFIG.stream.block_hosts).unwrap(), + + allowed_tcp_hosts: RegexSet::new(&CONFIG.stream.allow_tcp_hosts).unwrap(), + blocked_tcp_hosts: RegexSet::new(&CONFIG.stream.block_tcp_hosts).unwrap(), + + allowed_udp_hosts: RegexSet::new(&CONFIG.stream.allow_udp_hosts).unwrap(), + blocked_udp_hosts: RegexSet::new(&CONFIG.stream.block_udp_hosts).unwrap(), + wisp_config: CONFIG.wisp.to_opts_inner().unwrap(), } }; @@ -242,6 +266,12 @@ impl Default for StreamConfig { allow_global: true, allow_non_global: true, + allow_tcp_hosts: Vec::new(), + block_tcp_hosts: Vec::new(), + + allow_udp_hosts: Vec::new(), + block_udp_hosts: Vec::new(), + allow_hosts: Vec::new(), block_hosts: Vec::new(), @@ -267,6 +297,22 @@ impl StreamConfig { pub fn blocked_hosts(&self) -> &RegexSet { &CONFIG_CACHE.blocked_hosts } + + pub fn allowed_tcp_hosts(&self) -> &RegexSet { + &CONFIG_CACHE.allowed_tcp_hosts + } + + pub fn blocked_tcp_hosts(&self) -> &RegexSet { + &CONFIG_CACHE.blocked_tcp_hosts + } + + pub fn allowed_udp_hosts(&self) -> &RegexSet { + &CONFIG_CACHE.allowed_udp_hosts + } + + pub fn blocked_udp_hosts(&self) -> &RegexSet { + &CONFIG_CACHE.blocked_udp_hosts + } } impl Config { diff --git a/server/src/stream.rs b/server/src/stream.rs index 71d13ba..b7c5c07 100644 --- a/server/src/stream.rs +++ b/server/src/stream.rs @@ -8,6 +8,7 @@ use bytes::BytesMut; use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, WebSocketError}; use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; +use regex::RegexSet; use tokio::{ fs::{remove_file, try_exists}, net::{lookup_host, tcp, unix, TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream}, @@ -98,6 +99,26 @@ impl ServerListener { } } +fn match_addr(str: &str, allowed: &RegexSet, blocked: &RegexSet) -> bool { + blocked.is_match(str) && !allowed.is_match(str) +} + +fn allowed_set(stream_type: StreamType) -> &'static RegexSet { + match stream_type { + StreamType::Tcp => CONFIG.stream.allowed_tcp_hosts(), + StreamType::Udp => CONFIG.stream.allowed_udp_hosts(), + StreamType::Unknown(_) => unreachable!(), + } +} + +fn blocked_set(stream_type: StreamType) -> &'static RegexSet { + match stream_type { + StreamType::Tcp => CONFIG.stream.blocked_tcp_hosts(), + StreamType::Udp => CONFIG.stream.blocked_udp_hosts(), + StreamType::Unknown(_) => unreachable!(), + } +} + pub enum ClientStream { Tcp(TcpStream), Udp(UdpSocket), @@ -151,14 +172,20 @@ impl ClientStream { } } - if CONFIG - .stream - .blocked_hosts() - .is_match(&packet.destination_hostname) - && !CONFIG - .stream - .allowed_hosts() - .is_match(&packet.destination_hostname) + if match_addr( + &packet.destination_hostname, + allowed_set(packet.stream_type), + blocked_set(packet.stream_type), + ) { + return Ok(ResolvedPacket::Blocked); + } + + // allow stream type whitelists through + if match_addr( + &packet.destination_hostname, + CONFIG.stream.allowed_hosts(), + CONFIG.stream.blocked_hosts(), + ) && !allowed_set(packet.stream_type).is_match(&packet.destination_hostname) { return Ok(ResolvedPacket::Blocked); }