diff --git a/server/src/config.rs b/server/src/config.rs index 79f9520..8187b08 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -93,24 +93,17 @@ pub struct ServerConfig { pub tcp_nodelay: bool, /// Whether or not to set "raw mode" for the file. pub file_raw_mode: bool, - #[serde(skip_serializing_if = "Option::is_none")] /// Keypair (public, private) in PEM format for TLS. pub tls_keypair: Option<[PathBuf; 2]>, - /// Whether or not to show what upstreams each client is connected to in stats. This can - /// heavily increase the size of the stats. - pub verbose_stats: bool, - /// Whether or not to respond to stats requests over HTTP. - pub enable_stats_endpoint: bool, /// Where to listen for stats requests over HTTP. - pub stats_endpoint: StatsEndpoint, + pub stats_endpoint: Option, + /// Whether or not to search for the x-real-ip or x-forwarded-for headers. + pub use_real_ip_headers: bool, /// String sent to a request that is not a websocket upgrade request. pub non_ws_response: String, - /// Prefix of Wisp server. Do NOT add a trailing slash here. - pub prefix: String, - /// Max WebSocket message size that can be recieved. pub max_message_size: usize, @@ -153,13 +146,13 @@ pub struct WispConfig { pub allow_wsproxy: bool, /// Buffer size advertised to the client. pub buffer_size: u32, + /// Prefix of Wisp server. Do NOT add a trailing slash here. + pub prefix: String, /// Whether or not to use Wisp version 2. pub wisp_v2: bool, - #[serde(skip_serializing_if = "Vec::is_empty")] /// Wisp version 2 extensions advertised. pub extensions: Vec, - #[serde(skip_serializing_if = "Option::is_none")] /// Wisp version 2 authentication extension advertised. pub auth_extension: Option, @@ -189,7 +182,6 @@ pub struct StreamConfig { #[cfg(feature = "twisp")] pub allow_twisp: bool, - #[serde(skip_serializing_if = "Vec::is_empty")] /// DNS servers to resolve with. Will default to system configuration. pub dns_servers: Vec, @@ -205,31 +197,23 @@ pub struct StreamConfig { /// Whether or not to allow connections to non-globally-routable IP addresses. pub allow_non_global: bool, - #[serde(skip_serializing_if = "Vec::is_empty")] /// Regex whitelist of hosts for TCP connections. pub allow_tcp_hosts: Vec, - #[serde(skip_serializing_if = "Vec::is_empty")] /// Regex blacklist of hosts for TCP connections. pub block_tcp_hosts: Vec, - #[serde(skip_serializing_if = "Vec::is_empty")] /// Regex whitelist of hosts for UDP connections. pub allow_udp_hosts: Vec, - #[serde(skip_serializing_if = "Vec::is_empty")] /// Regex blacklist of hosts for UDP connections. pub block_udp_hosts: Vec, - #[serde(skip_serializing_if = "Vec::is_empty")] /// Regex whitelist of hosts. pub allow_hosts: Vec, - #[serde(skip_serializing_if = "Vec::is_empty")] /// Regex blacklist of hosts. pub block_hosts: Vec, - #[serde(skip_serializing_if = "Vec::is_empty")] /// Range whitelist of ports. Format is `[lower_bound, upper_bound]`. pub allow_ports: Vec>, - #[serde(skip_serializing_if = "Vec::is_empty")] /// Range blacklist of ports. Format is `[lower_bound, upper_bound]`. pub block_ports: Vec>, } @@ -287,18 +271,12 @@ lazy_static! { pub async fn validate_config_cache() { // constructs regexes let _ = CONFIG_CACHE.allowed_ports; - // constructs wisp config + // validates wisp config CONFIG.wisp.to_opts().await.unwrap(); // constructs resolver RESOLVER.clear_cache(); } -impl Default for StatsEndpoint { - fn default() -> Self { - Self::SameServer("/stats".to_string()) - } -} - impl StatsEndpoint { pub fn get_endpoint(&self) -> Option { match self { @@ -325,14 +303,11 @@ impl Default for ServerConfig { file_raw_mode: false, tls_keypair: None, - verbose_stats: true, - enable_stats_endpoint: false, - stats_endpoint: StatsEndpoint::default(), + stats_endpoint: None, + use_real_ip_headers: false, non_ws_response: ":3".to_string(), - prefix: String::new(), - max_message_size: 64 * 1024, log_level: LevelFilter::Info, @@ -346,6 +321,7 @@ impl Default for WispConfig { Self { buffer_size: 128, allow_wsproxy: true, + prefix: String::new(), wisp_v2: true, extensions: vec![ProtocolExtension::Udp, ProtocolExtension::Motd], diff --git a/server/src/main.rs b/server/src/main.rs index 9549338..f1fa6c4 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -10,11 +10,12 @@ use dashmap::DashMap; use handle::{handle_wisp, handle_wsproxy}; use hickory_resolver::{ config::{NameServerConfigGroup, ResolverConfig, ResolverOpts}, + system_conf::read_system_conf, TokioAsyncResolver, }; use lazy_static::lazy_static; use listener::ServerListener; -use log::{error, info}; +use log::{error, info, warn}; use route::{route_stats, ServerRouteResult}; use serde::Serialize; use tokio::{ @@ -73,7 +74,12 @@ lazy_static! { pub static ref CLIENTS: DashMap = DashMap::new(); pub static ref RESOLVER: Resolver = { if CONFIG.stream.dns_servers.is_empty() { - Resolver::System + if let Ok((config, opts)) = read_system_conf() { + Resolver::Hickory(TokioAsyncResolver::tokio(config, opts)) + } else { + warn!("unable to read system dns configuration. using system dns resolver with no caching"); + Resolver::System + } } else { Resolver::Hickory(TokioAsyncResolver::tokio( ResolverConfig::from_parts( @@ -240,37 +246,48 @@ fn main() -> anyhow::Result<()> { .await .with_context(|| format!("failed to bind to address {}", CONFIG.server.bind.1))?; - if CONFIG.server.enable_stats_endpoint { - if let Some(bind_addr) = CONFIG.server.stats_endpoint.get_bindaddr() { - info!("stats server listening on {:?}", bind_addr); - let mut stats_listener = - ServerListener::new(&bind_addr).await.with_context(|| { - format!("failed to bind to address {} for stats server", bind_addr.1) - })?; + if let Some(bind_addr) = CONFIG + .server + .stats_endpoint + .as_ref() + .and_then(|x| x.get_bindaddr()) + { + info!("stats server listening on {:?}", bind_addr); + let mut stats_listener = ServerListener::new(&bind_addr).await.with_context(|| { + format!("failed to bind to address {} for stats server", bind_addr.1) + })?; - tokio::spawn(async move { - loop { - match stats_listener.accept().await { - Ok((stream, _)) => { - if let Err(e) = route_stats(stream).await { - error!("error while routing stats client: {:?}", e); - } + tokio::spawn(async move { + loop { + match stats_listener.accept().await { + Ok((stream, _)) => { + if let Err(e) = route_stats(stream).await { + error!("error while routing stats client: {:?}", e); } - Err(e) => error!("error while accepting stats client: {:?}", e), } + Err(e) => error!("error while accepting stats client: {:?}", e), } - }); - } + } + }); } - let stats_endpoint = CONFIG.server.stats_endpoint.get_endpoint(); + let stats_endpoint = CONFIG + .server + .stats_endpoint + .as_ref() + .and_then(|x| x.get_endpoint()); loop { let stats_endpoint = stats_endpoint.clone(); match listener.accept().await { - Ok((stream, id)) => { + Ok((stream, client_id)) => { tokio::spawn(async move { - let res = route::route(stream, stats_endpoint, move |stream| { - handle_stream(stream, id) + let res = route::route(stream, stats_endpoint, move |stream, maybe_ip| { + let client_id = if let Some(ip) = maybe_ip { + format!("{} ({})", client_id, ip) + } else { + client_id + }; + handle_stream(stream, client_id) }) .await; diff --git a/server/src/route.rs b/server/src/route.rs index 8220f75..1efc978 100644 --- a/server/src/route.rs +++ b/server/src/route.rs @@ -5,8 +5,8 @@ use bytes::Bytes; use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector}; use http_body_util::Full; use hyper::{ - body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response, - StatusCode, + body::Incoming, server::conn::http1::Builder, service::service_fn, HeaderMap, Request, + Response, StatusCode, }; use hyper_util::rt::TokioIo; use log::{debug, error}; @@ -52,20 +52,29 @@ fn send_stats() -> anyhow::Result> { } } -async fn ws_upgrade( +fn get_header(headers: &HeaderMap, header: &str) -> Option { + headers.get(header).and_then(|x| x.to_str().ok()).map(|x| x.to_string()) +} + +enum HttpUpgradeResult { + Wisp, + WsProxy(String, bool), +} + +async fn ws_upgrade( mut req: Request, stats_endpoint: Option, - callback: T, + callback: F, ) -> anyhow::Result> where - T: FnOnce(UpgradeFut, bool, bool, String) -> R + Send + 'static, + F: FnOnce(UpgradeFut, HttpUpgradeResult, Option) -> R + Send + 'static, R: Future> + Send, { let is_upgrade = fastwebsockets::upgrade::is_upgrade_request(&req); if !is_upgrade { if let Some(stats_endpoint) = stats_endpoint { - if CONFIG.server.enable_stats_endpoint && req.uri().path() == stats_endpoint { + if req.uri().path() == stats_endpoint { return send_stats(); } else { debug!("sent non_ws_response to http client"); @@ -81,20 +90,33 @@ where // replace body of Empty with Full let resp = Response::from_parts(resp.into_parts().0, Body::new(Bytes::new())); + let headers = req.headers(); + let ip_header = if CONFIG.server.use_real_ip_headers { + get_header(headers, "x-real-ip").or_else(|| get_header(headers, "x-forwarded-for")) + } else { + None + }; + if req .uri() .path() - .starts_with(&(CONFIG.server.prefix.clone() + "/")) + .starts_with(&(CONFIG.wisp.prefix.clone() + "/")) { tokio::spawn(async move { - if let Err(err) = (callback)(fut, false, false, req.uri().path().to_string()).await { + if let Err(err) = (callback)(fut, HttpUpgradeResult::Wisp, ip_header).await { error!("error while serving client: {:?}", err); } }); } else if CONFIG.wisp.allow_wsproxy { let udp = req.uri().query().unwrap_or_default() == "?udp"; tokio::spawn(async move { - if let Err(err) = (callback)(fut, false, udp, req.uri().path().to_string()).await { + if let Err(err) = (callback)( + fut, + HttpUpgradeResult::WsProxy(req.uri().path().to_string(), udp), + ip_header, + ) + .await + { error!("error while serving client: {:?}", err); } }); @@ -117,7 +139,7 @@ pub async fn route_stats(stream: ServerStream) -> anyhow::Result<()> { pub async fn route( stream: ServerStream, stats_endpoint: Option, - callback: impl FnOnce(ServerRouteResult) + Clone + Send + 'static, + callback: impl FnOnce(ServerRouteResult, Option) + Clone + Send + 'static, ) -> anyhow::Result<()> { match CONFIG.server.transport { SocketTransport::WebSocket => { @@ -132,28 +154,37 @@ pub async fn route( ws_upgrade( req, stats_endpoint.clone(), - |fut, wsproxy, udp, path| async move { + |fut, res, maybe_ip| async move { let mut ws = fut.await.context("failed to await upgrade future")?; ws.set_max_message_size(CONFIG.server.max_message_size); ws.set_auto_pong(false); - if wsproxy { - let ws = WebSocketStreamWrapper(FragmentCollector::new(ws)); - (callback)(ServerRouteResult::WsProxy(ws, path, udp)); - } else { - let (read, write) = ws.split(|x| { - let parts = x - .into_inner() - .downcast::>() - .unwrap(); - let (r, w) = parts.io.into_inner().split(); - (Cursor::new(parts.read_buf).chain(r), w) - }); + match res { + HttpUpgradeResult::Wisp => { + let (read, write) = ws.split(|x| { + let parts = x + .into_inner() + .downcast::>() + .unwrap(); + let (r, w) = parts.io.into_inner().split(); + (Cursor::new(parts.read_buf).chain(r), w) + }); - (callback)(ServerRouteResult::Wisp(( - Box::new(read), - Box::new(write), - ))) + (callback)( + ServerRouteResult::Wisp(( + Box::new(read), + Box::new(write), + )), + maybe_ip, + ) + } + HttpUpgradeResult::WsProxy(path, udp) => { + let ws = WebSocketStreamWrapper(FragmentCollector::new(ws)); + (callback)( + ServerRouteResult::WsProxy(ws, path, udp), + maybe_ip, + ); + } } Ok(()) @@ -174,7 +205,10 @@ pub async fn route( let read = GenericWebSocketRead::new(FramedRead::new(read, codec.clone())); let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec)); - (callback)(ServerRouteResult::Wisp((Box::new(read), Box::new(write)))); + (callback)( + ServerRouteResult::Wisp((Box::new(read), Box::new(write))), + None, + ); } } Ok(())