use system resolver if no dns servers specified, make invalid frame type more verbose

This commit is contained in:
Toshit Chawda 2024-09-22 09:07:44 -07:00
parent f798b5544e
commit fdd641c67f
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 53 additions and 16 deletions

View file

@ -74,6 +74,7 @@ pub struct ServerConfig {
pub tcp_nodelay: bool, pub tcp_nodelay: bool,
/// Whether or not to set "raw mode" for the file. /// Whether or not to set "raw mode" for the file.
pub file_raw_mode: bool, pub file_raw_mode: bool,
#[serde(skip_serializing_if = "Option::is_none")]
/// Keypair (public, private) in PEM format for TLS. /// Keypair (public, private) in PEM format for TLS.
pub tls_keypair: Option<[PathBuf; 2]>, pub tls_keypair: Option<[PathBuf; 2]>,
@ -82,12 +83,15 @@ pub struct ServerConfig {
pub verbose_stats: bool, pub verbose_stats: bool,
/// Whether or not to respond to stats requests over HTTP. /// Whether or not to respond to stats requests over HTTP.
pub enable_stats_endpoint: bool, pub enable_stats_endpoint: bool,
#[serde(skip_serializing_if = "String::is_empty")]
/// Path of stats HTTP endpoint. /// Path of stats HTTP endpoint.
pub stats_endpoint: String, pub stats_endpoint: String,
#[serde(skip_serializing_if = "String::is_empty")]
/// String sent to a request that is not a websocket upgrade request. /// String sent to a request that is not a websocket upgrade request.
pub non_ws_response: String, pub non_ws_response: String,
#[serde(skip_serializing_if = "String::is_empty")]
/// Prefix of Wisp server. Do NOT add a trailing slash here. /// Prefix of Wisp server. Do NOT add a trailing slash here.
pub prefix: String, pub prefix: String,
@ -126,13 +130,17 @@ pub struct WispConfig {
/// Whether or not to use Wisp version 2. /// Whether or not to use Wisp version 2.
pub wisp_v2: bool, pub wisp_v2: bool,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Wisp version 2 extensions advertised. /// Wisp version 2 extensions advertised.
pub extensions: Vec<ProtocolExtension>, pub extensions: Vec<ProtocolExtension>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Wisp version 2 authentication extension advertised. /// Wisp version 2 authentication extension advertised.
pub auth_extension: Option<ProtocolExtensionAuth>, pub auth_extension: Option<ProtocolExtensionAuth>,
#[serde(skip_serializing_if = "HashMap::is_empty")]
/// Wisp version 2 password authentication extension username/passwords. /// Wisp version 2 password authentication extension username/passwords.
pub password_extension_users: HashMap<String, String>, pub password_extension_users: HashMap<String, String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Wisp version 2 certificate authentication extension public ed25519 pem keys. /// Wisp version 2 certificate authentication extension public ed25519 pem keys.
pub certificate_extension_keys: Vec<PathBuf>, pub certificate_extension_keys: Vec<PathBuf>,
@ -154,6 +162,7 @@ pub struct StreamConfig {
#[cfg(feature = "twisp")] #[cfg(feature = "twisp")]
pub allow_twisp: bool, pub allow_twisp: bool,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// DNS servers to resolve with. Will default to system configuration. /// DNS servers to resolve with. Will default to system configuration.
pub dns_servers: Vec<IpAddr>, pub dns_servers: Vec<IpAddr>,
@ -169,23 +178,31 @@ pub struct StreamConfig {
/// Whether or not to allow connections to non-globally-routable IP addresses. /// Whether or not to allow connections to non-globally-routable IP addresses.
pub allow_non_global: bool, pub allow_non_global: bool,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex whitelist of hosts for TCP connections. /// Regex whitelist of hosts for TCP connections.
pub allow_tcp_hosts: Vec<String>, pub allow_tcp_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex blacklist of hosts for TCP connections. /// Regex blacklist of hosts for TCP connections.
pub block_tcp_hosts: Vec<String>, pub block_tcp_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex whitelist of hosts for UDP connections. /// Regex whitelist of hosts for UDP connections.
pub allow_udp_hosts: Vec<String>, pub allow_udp_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex blacklist of hosts for UDP connections. /// Regex blacklist of hosts for UDP connections.
pub block_udp_hosts: Vec<String>, pub block_udp_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex whitelist of hosts. /// Regex whitelist of hosts.
pub allow_hosts: Vec<String>, pub allow_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex blacklist of hosts. /// Regex blacklist of hosts.
pub block_hosts: Vec<String>, pub block_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Range whitelist of ports. Format is `[lower_bound, upper_bound]`. /// Range whitelist of ports. Format is `[lower_bound, upper_bound]`.
pub allow_ports: Vec<Vec<u16>>, pub allow_ports: Vec<Vec<u16>>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Range blacklist of ports. Format is `[lower_bound, upper_bound]`. /// Range blacklist of ports. Format is `[lower_bound, upper_bound]`.
pub block_ports: Vec<Vec<u16>>, pub block_ports: Vec<Vec<u16>>,
} }

View file

@ -1,7 +1,7 @@
#![feature(ip)] #![feature(ip)]
#![deny(clippy::todo)] #![deny(clippy::todo)]
use std::{fmt::Write, fs::read_to_string}; use std::{fmt::Write, fs::read_to_string, net::IpAddr};
use clap::Parser; use clap::Parser;
use config::{validate_config_cache, Cli, Config}; use config::{validate_config_cache, Cli, Config};
@ -27,6 +27,29 @@ mod stream;
type Client = (DashMap<Uuid, (ConnectPacket, ConnectPacket)>, bool); type Client = (DashMap<Uuid, (ConnectPacket, ConnectPacket)>, bool);
pub enum Resolver {
Hickory(TokioAsyncResolver),
System,
}
impl Resolver {
pub async fn resolve(&self, host: String) -> anyhow::Result<Box<dyn Iterator<Item = IpAddr>>> {
match self {
Self::Hickory(resolver) => Ok(Box::new(resolver.lookup_ip(host).await?.into_iter())),
Self::System => Ok(Box::new(
tokio::net::lookup_host(host + ":0").await?.map(|x| x.ip()),
)),
}
}
pub fn clear_cache(&self) {
match self {
Self::Hickory(resolver) => resolver.clear_cache(),
Self::System => {}
}
}
}
lazy_static! { lazy_static! {
pub static ref CLI: Cli = Cli::parse(); pub static ref CLI: Cli = Cli::parse();
pub static ref CONFIG: Config = { pub static ref CONFIG: Config = {
@ -37,21 +60,19 @@ lazy_static! {
} }
}; };
pub static ref CLIENTS: DashMap<String, Client> = DashMap::new(); pub static ref CLIENTS: DashMap<String, Client> = DashMap::new();
pub static ref RESOLVER: TokioAsyncResolver = { pub static ref RESOLVER: Resolver = {
let (config, opts) = if CONFIG.stream.dns_servers.is_empty() { if CONFIG.stream.dns_servers.is_empty() {
hickory_resolver::system_conf::read_system_conf().unwrap() Resolver::System
} else { } else {
( Resolver::Hickory(TokioAsyncResolver::tokio(
ResolverConfig::from_parts( ResolverConfig::from_parts(
None, None,
Vec::new(), Vec::new(),
NameServerConfigGroup::from_ips_clear(&CONFIG.stream.dns_servers, 53, true), NameServerConfigGroup::from_ips_clear(&CONFIG.stream.dns_servers, 53, true),
), ),
ResolverOpts::default(), ResolverOpts::default(),
) ))
}; }
TokioAsyncResolver::tokio(config, opts)
}; };
} }

View file

@ -126,10 +126,9 @@ impl ClientStream {
} }
let packet = RESOLVER let packet = RESOLVER
.lookup_ip(packet.destination_hostname) .resolve(packet.destination_hostname)
.await .await
.context("failed to resolve hostname")? .context("failed to resolve hostname")?
.iter()
.filter(|x| CONFIG.server.resolve_ipv6 || x.is_ipv4()) .filter(|x| CONFIG.server.resolve_ipv6 || x.is_ipv4())
.map(|x| ConnectPacket { .map(|x| ConnectPacket {
stream_type: packet.stream_type, stream_type: packet.stream_type,

View file

@ -70,7 +70,7 @@ pub enum WispError {
StreamAlreadyClosed, StreamAlreadyClosed,
/// The websocket frame received had an invalid type. /// The websocket frame received had an invalid type.
WsFrameInvalidType, WsFrameInvalidType(ws::OpCode),
/// The websocket frame received was not finished. /// The websocket frame received was not finished.
WsFrameNotFinished, WsFrameNotFinished,
/// Error specific to the websocket implementation. /// Error specific to the websocket implementation.
@ -133,7 +133,7 @@ impl std::fmt::Display for WispError {
Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"), Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
Self::IncompatibleProtocolVersion => write!(f, "Incompatible Wisp protocol version"), Self::IncompatibleProtocolVersion => write!(f, "Incompatible Wisp protocol version"),
Self::StreamAlreadyClosed => write!(f, "Stream already closed"), Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"), Self::WsFrameInvalidType(ty) => write!(f, "Invalid websocket frame type: {:?}", ty),
Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"), Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
Self::WsImplError(err) => write!(f, "Websocket implementation error: {}", err), Self::WsImplError(err) => write!(f, "Websocket implementation error: {}", err),
Self::WsImplSocketClosed => { Self::WsImplSocketClosed => {

View file

@ -487,7 +487,7 @@ impl<'a> Packet<'a> {
return Err(WispError::WsFrameNotFinished); return Err(WispError::WsFrameNotFinished);
} }
if frame.opcode != OpCode::Binary { if frame.opcode != OpCode::Binary {
return Err(WispError::WsFrameInvalidType); return Err(WispError::WsFrameInvalidType(frame.opcode));
} }
let mut bytes = frame.payload; let mut bytes = frame.payload;
if bytes.remaining() < 1 { if bytes.remaining() < 1 {
@ -511,7 +511,7 @@ impl<'a> Packet<'a> {
return Err(WispError::WsFrameNotFinished); return Err(WispError::WsFrameNotFinished);
} }
if frame.opcode != OpCode::Binary { if frame.opcode != OpCode::Binary {
return Err(WispError::WsFrameInvalidType); return Err(WispError::WsFrameInvalidType(frame.opcode));
} }
let mut bytes = frame.payload; let mut bytes = frame.payload;
if bytes.remaining() < 5 { if bytes.remaining() < 5 {
@ -587,7 +587,7 @@ impl<'a> TryFrom<ws::Frame<'a>> for Packet<'a> {
return Err(Self::Error::WsFrameNotFinished); return Err(Self::Error::WsFrameNotFinished);
} }
if frame.opcode != ws::OpCode::Binary { if frame.opcode != ws::OpCode::Binary {
return Err(Self::Error::WsFrameInvalidType); return Err(Self::Error::WsFrameInvalidType(frame.opcode));
} }
Packet::try_from(frame.payload) Packet::try_from(frame.payload)
} }