move prefix to wisp config, add x-real-ip support

This commit is contained in:
Toshit Chawda 2024-10-02 17:37:25 -07:00
parent 88a35039c9
commit bca8be0bd2
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 111 additions and 84 deletions

View file

@ -93,24 +93,17 @@ pub struct ServerConfig {
pub tcp_nodelay: bool, pub tcp_nodelay: bool,
/// Whether or not to set "raw mode" for the file. /// Whether or not to set "raw mode" for the file.
pub file_raw_mode: bool, pub file_raw_mode: bool,
#[serde(skip_serializing_if = "Option::is_none")]
/// Keypair (public, private) in PEM format for TLS. /// Keypair (public, private) in PEM format for TLS.
pub tls_keypair: Option<[PathBuf; 2]>, pub tls_keypair: Option<[PathBuf; 2]>,
/// Whether or not to show what upstreams each client is connected to in stats. This can
/// heavily increase the size of the stats.
pub verbose_stats: bool,
/// Whether or not to respond to stats requests over HTTP.
pub enable_stats_endpoint: bool,
/// Where to listen for stats requests over HTTP. /// Where to listen for stats requests over HTTP.
pub stats_endpoint: StatsEndpoint, pub stats_endpoint: Option<StatsEndpoint>,
/// Whether or not to search for the x-real-ip or x-forwarded-for headers.
pub use_real_ip_headers: bool,
/// String sent to a request that is not a websocket upgrade request. /// String sent to a request that is not a websocket upgrade request.
pub non_ws_response: String, pub non_ws_response: String,
/// Prefix of Wisp server. Do NOT add a trailing slash here.
pub prefix: String,
/// Max WebSocket message size that can be recieved. /// Max WebSocket message size that can be recieved.
pub max_message_size: usize, pub max_message_size: usize,
@ -153,13 +146,13 @@ pub struct WispConfig {
pub allow_wsproxy: bool, pub allow_wsproxy: bool,
/// Buffer size advertised to the client. /// Buffer size advertised to the client.
pub buffer_size: u32, pub buffer_size: u32,
/// Prefix of Wisp server. Do NOT add a trailing slash here.
pub prefix: String,
/// Whether or not to use Wisp version 2. /// Whether or not to use Wisp version 2.
pub wisp_v2: bool, pub wisp_v2: bool,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Wisp version 2 extensions advertised. /// Wisp version 2 extensions advertised.
pub extensions: Vec<ProtocolExtension>, pub extensions: Vec<ProtocolExtension>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Wisp version 2 authentication extension advertised. /// Wisp version 2 authentication extension advertised.
pub auth_extension: Option<ProtocolExtensionAuth>, pub auth_extension: Option<ProtocolExtensionAuth>,
@ -189,7 +182,6 @@ pub struct StreamConfig {
#[cfg(feature = "twisp")] #[cfg(feature = "twisp")]
pub allow_twisp: bool, pub allow_twisp: bool,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// DNS servers to resolve with. Will default to system configuration. /// DNS servers to resolve with. Will default to system configuration.
pub dns_servers: Vec<IpAddr>, pub dns_servers: Vec<IpAddr>,
@ -205,31 +197,23 @@ pub struct StreamConfig {
/// Whether or not to allow connections to non-globally-routable IP addresses. /// Whether or not to allow connections to non-globally-routable IP addresses.
pub allow_non_global: bool, pub allow_non_global: bool,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex whitelist of hosts for TCP connections. /// Regex whitelist of hosts for TCP connections.
pub allow_tcp_hosts: Vec<String>, pub allow_tcp_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex blacklist of hosts for TCP connections. /// Regex blacklist of hosts for TCP connections.
pub block_tcp_hosts: Vec<String>, pub block_tcp_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex whitelist of hosts for UDP connections. /// Regex whitelist of hosts for UDP connections.
pub allow_udp_hosts: Vec<String>, pub allow_udp_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex blacklist of hosts for UDP connections. /// Regex blacklist of hosts for UDP connections.
pub block_udp_hosts: Vec<String>, pub block_udp_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex whitelist of hosts. /// Regex whitelist of hosts.
pub allow_hosts: Vec<String>, pub allow_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex blacklist of hosts. /// Regex blacklist of hosts.
pub block_hosts: Vec<String>, pub block_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Range whitelist of ports. Format is `[lower_bound, upper_bound]`. /// Range whitelist of ports. Format is `[lower_bound, upper_bound]`.
pub allow_ports: Vec<Vec<u16>>, pub allow_ports: Vec<Vec<u16>>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Range blacklist of ports. Format is `[lower_bound, upper_bound]`. /// Range blacklist of ports. Format is `[lower_bound, upper_bound]`.
pub block_ports: Vec<Vec<u16>>, pub block_ports: Vec<Vec<u16>>,
} }
@ -287,18 +271,12 @@ lazy_static! {
pub async fn validate_config_cache() { pub async fn validate_config_cache() {
// constructs regexes // constructs regexes
let _ = CONFIG_CACHE.allowed_ports; let _ = CONFIG_CACHE.allowed_ports;
// constructs wisp config // validates wisp config
CONFIG.wisp.to_opts().await.unwrap(); CONFIG.wisp.to_opts().await.unwrap();
// constructs resolver // constructs resolver
RESOLVER.clear_cache(); RESOLVER.clear_cache();
} }
impl Default for StatsEndpoint {
fn default() -> Self {
Self::SameServer("/stats".to_string())
}
}
impl StatsEndpoint { impl StatsEndpoint {
pub fn get_endpoint(&self) -> Option<String> { pub fn get_endpoint(&self) -> Option<String> {
match self { match self {
@ -325,14 +303,11 @@ impl Default for ServerConfig {
file_raw_mode: false, file_raw_mode: false,
tls_keypair: None, tls_keypair: None,
verbose_stats: true, stats_endpoint: None,
enable_stats_endpoint: false,
stats_endpoint: StatsEndpoint::default(),
use_real_ip_headers: false,
non_ws_response: ":3".to_string(), non_ws_response: ":3".to_string(),
prefix: String::new(),
max_message_size: 64 * 1024, max_message_size: 64 * 1024,
log_level: LevelFilter::Info, log_level: LevelFilter::Info,
@ -346,6 +321,7 @@ impl Default for WispConfig {
Self { Self {
buffer_size: 128, buffer_size: 128,
allow_wsproxy: true, allow_wsproxy: true,
prefix: String::new(),
wisp_v2: true, wisp_v2: true,
extensions: vec![ProtocolExtension::Udp, ProtocolExtension::Motd], extensions: vec![ProtocolExtension::Udp, ProtocolExtension::Motd],

View file

@ -10,11 +10,12 @@ use dashmap::DashMap;
use handle::{handle_wisp, handle_wsproxy}; use handle::{handle_wisp, handle_wsproxy};
use hickory_resolver::{ use hickory_resolver::{
config::{NameServerConfigGroup, ResolverConfig, ResolverOpts}, config::{NameServerConfigGroup, ResolverConfig, ResolverOpts},
system_conf::read_system_conf,
TokioAsyncResolver, TokioAsyncResolver,
}; };
use lazy_static::lazy_static; use lazy_static::lazy_static;
use listener::ServerListener; use listener::ServerListener;
use log::{error, info}; use log::{error, info, warn};
use route::{route_stats, ServerRouteResult}; use route::{route_stats, ServerRouteResult};
use serde::Serialize; use serde::Serialize;
use tokio::{ use tokio::{
@ -73,7 +74,12 @@ lazy_static! {
pub static ref CLIENTS: DashMap<String, Client> = DashMap::new(); pub static ref CLIENTS: DashMap<String, Client> = DashMap::new();
pub static ref RESOLVER: Resolver = { pub static ref RESOLVER: Resolver = {
if CONFIG.stream.dns_servers.is_empty() { if CONFIG.stream.dns_servers.is_empty() {
Resolver::System if let Ok((config, opts)) = read_system_conf() {
Resolver::Hickory(TokioAsyncResolver::tokio(config, opts))
} else {
warn!("unable to read system dns configuration. using system dns resolver with no caching");
Resolver::System
}
} else { } else {
Resolver::Hickory(TokioAsyncResolver::tokio( Resolver::Hickory(TokioAsyncResolver::tokio(
ResolverConfig::from_parts( ResolverConfig::from_parts(
@ -240,37 +246,48 @@ fn main() -> anyhow::Result<()> {
.await .await
.with_context(|| format!("failed to bind to address {}", CONFIG.server.bind.1))?; .with_context(|| format!("failed to bind to address {}", CONFIG.server.bind.1))?;
if CONFIG.server.enable_stats_endpoint { if let Some(bind_addr) = CONFIG
if let Some(bind_addr) = CONFIG.server.stats_endpoint.get_bindaddr() { .server
info!("stats server listening on {:?}", bind_addr); .stats_endpoint
let mut stats_listener = .as_ref()
ServerListener::new(&bind_addr).await.with_context(|| { .and_then(|x| x.get_bindaddr())
format!("failed to bind to address {} for stats server", bind_addr.1) {
})?; info!("stats server listening on {:?}", bind_addr);
let mut stats_listener = ServerListener::new(&bind_addr).await.with_context(|| {
format!("failed to bind to address {} for stats server", bind_addr.1)
})?;
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
match stats_listener.accept().await { match stats_listener.accept().await {
Ok((stream, _)) => { Ok((stream, _)) => {
if let Err(e) = route_stats(stream).await { if let Err(e) = route_stats(stream).await {
error!("error while routing stats client: {:?}", e); error!("error while routing stats client: {:?}", e);
}
} }
Err(e) => error!("error while accepting stats client: {:?}", e),
} }
Err(e) => error!("error while accepting stats client: {:?}", e),
} }
}); }
} });
} }
let stats_endpoint = CONFIG.server.stats_endpoint.get_endpoint(); let stats_endpoint = CONFIG
.server
.stats_endpoint
.as_ref()
.and_then(|x| x.get_endpoint());
loop { loop {
let stats_endpoint = stats_endpoint.clone(); let stats_endpoint = stats_endpoint.clone();
match listener.accept().await { match listener.accept().await {
Ok((stream, id)) => { Ok((stream, client_id)) => {
tokio::spawn(async move { tokio::spawn(async move {
let res = route::route(stream, stats_endpoint, move |stream| { let res = route::route(stream, stats_endpoint, move |stream, maybe_ip| {
handle_stream(stream, id) let client_id = if let Some(ip) = maybe_ip {
format!("{} ({})", client_id, ip)
} else {
client_id
};
handle_stream(stream, client_id)
}) })
.await; .await;

View file

@ -5,8 +5,8 @@ use bytes::Bytes;
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector}; use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector};
use http_body_util::Full; use http_body_util::Full;
use hyper::{ use hyper::{
body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response, body::Incoming, server::conn::http1::Builder, service::service_fn, HeaderMap, Request,
StatusCode, Response, StatusCode,
}; };
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use log::{debug, error}; use log::{debug, error};
@ -52,20 +52,29 @@ fn send_stats() -> anyhow::Result<Response<Body>> {
} }
} }
async fn ws_upgrade<T, R>( fn get_header(headers: &HeaderMap, header: &str) -> Option<String> {
headers.get(header).and_then(|x| x.to_str().ok()).map(|x| x.to_string())
}
enum HttpUpgradeResult {
Wisp,
WsProxy(String, bool),
}
async fn ws_upgrade<F, R>(
mut req: Request<Incoming>, mut req: Request<Incoming>,
stats_endpoint: Option<String>, stats_endpoint: Option<String>,
callback: T, callback: F,
) -> anyhow::Result<Response<Body>> ) -> anyhow::Result<Response<Body>>
where where
T: FnOnce(UpgradeFut, bool, bool, String) -> R + Send + 'static, F: FnOnce(UpgradeFut, HttpUpgradeResult, Option<String>) -> R + Send + 'static,
R: Future<Output = anyhow::Result<()>> + Send, R: Future<Output = anyhow::Result<()>> + Send,
{ {
let is_upgrade = fastwebsockets::upgrade::is_upgrade_request(&req); let is_upgrade = fastwebsockets::upgrade::is_upgrade_request(&req);
if !is_upgrade { if !is_upgrade {
if let Some(stats_endpoint) = stats_endpoint { if let Some(stats_endpoint) = stats_endpoint {
if CONFIG.server.enable_stats_endpoint && req.uri().path() == stats_endpoint { if req.uri().path() == stats_endpoint {
return send_stats(); return send_stats();
} else { } else {
debug!("sent non_ws_response to http client"); debug!("sent non_ws_response to http client");
@ -81,20 +90,33 @@ where
// replace body of Empty<Bytes> with Full<Bytes> // replace body of Empty<Bytes> with Full<Bytes>
let resp = Response::from_parts(resp.into_parts().0, Body::new(Bytes::new())); let resp = Response::from_parts(resp.into_parts().0, Body::new(Bytes::new()));
let headers = req.headers();
let ip_header = if CONFIG.server.use_real_ip_headers {
get_header(headers, "x-real-ip").or_else(|| get_header(headers, "x-forwarded-for"))
} else {
None
};
if req if req
.uri() .uri()
.path() .path()
.starts_with(&(CONFIG.server.prefix.clone() + "/")) .starts_with(&(CONFIG.wisp.prefix.clone() + "/"))
{ {
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = (callback)(fut, false, false, req.uri().path().to_string()).await { if let Err(err) = (callback)(fut, HttpUpgradeResult::Wisp, ip_header).await {
error!("error while serving client: {:?}", err); error!("error while serving client: {:?}", err);
} }
}); });
} else if CONFIG.wisp.allow_wsproxy { } else if CONFIG.wisp.allow_wsproxy {
let udp = req.uri().query().unwrap_or_default() == "?udp"; let udp = req.uri().query().unwrap_or_default() == "?udp";
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = (callback)(fut, false, udp, req.uri().path().to_string()).await { if let Err(err) = (callback)(
fut,
HttpUpgradeResult::WsProxy(req.uri().path().to_string(), udp),
ip_header,
)
.await
{
error!("error while serving client: {:?}", err); error!("error while serving client: {:?}", err);
} }
}); });
@ -117,7 +139,7 @@ pub async fn route_stats(stream: ServerStream) -> anyhow::Result<()> {
pub async fn route( pub async fn route(
stream: ServerStream, stream: ServerStream,
stats_endpoint: Option<String>, stats_endpoint: Option<String>,
callback: impl FnOnce(ServerRouteResult) + Clone + Send + 'static, callback: impl FnOnce(ServerRouteResult, Option<String>) + Clone + Send + 'static,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
match CONFIG.server.transport { match CONFIG.server.transport {
SocketTransport::WebSocket => { SocketTransport::WebSocket => {
@ -132,28 +154,37 @@ pub async fn route(
ws_upgrade( ws_upgrade(
req, req,
stats_endpoint.clone(), stats_endpoint.clone(),
|fut, wsproxy, udp, path| async move { |fut, res, maybe_ip| async move {
let mut ws = fut.await.context("failed to await upgrade future")?; let mut ws = fut.await.context("failed to await upgrade future")?;
ws.set_max_message_size(CONFIG.server.max_message_size); ws.set_max_message_size(CONFIG.server.max_message_size);
ws.set_auto_pong(false); ws.set_auto_pong(false);
if wsproxy { match res {
let ws = WebSocketStreamWrapper(FragmentCollector::new(ws)); HttpUpgradeResult::Wisp => {
(callback)(ServerRouteResult::WsProxy(ws, path, udp)); let (read, write) = ws.split(|x| {
} else { let parts = x
let (read, write) = ws.split(|x| { .into_inner()
let parts = x .downcast::<TokioIo<ServerStream>>()
.into_inner() .unwrap();
.downcast::<TokioIo<ServerStream>>() let (r, w) = parts.io.into_inner().split();
.unwrap(); (Cursor::new(parts.read_buf).chain(r), w)
let (r, w) = parts.io.into_inner().split(); });
(Cursor::new(parts.read_buf).chain(r), w)
});
(callback)(ServerRouteResult::Wisp(( (callback)(
Box::new(read), ServerRouteResult::Wisp((
Box::new(write), Box::new(read),
))) Box::new(write),
)),
maybe_ip,
)
}
HttpUpgradeResult::WsProxy(path, udp) => {
let ws = WebSocketStreamWrapper(FragmentCollector::new(ws));
(callback)(
ServerRouteResult::WsProxy(ws, path, udp),
maybe_ip,
);
}
} }
Ok(()) Ok(())
@ -174,7 +205,10 @@ pub async fn route(
let read = GenericWebSocketRead::new(FramedRead::new(read, codec.clone())); let read = GenericWebSocketRead::new(FramedRead::new(read, codec.clone()));
let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec)); let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec));
(callback)(ServerRouteResult::Wisp((Box::new(read), Box::new(write)))); (callback)(
ServerRouteResult::Wisp((Box::new(read), Box::new(write))),
None,
);
} }
} }
Ok(()) Ok(())