move prefix to wisp config, add x-real-ip support

This commit is contained in:
Toshit Chawda 2024-10-02 17:37:25 -07:00
parent 88a35039c9
commit bca8be0bd2
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 111 additions and 84 deletions

View file

@ -93,24 +93,17 @@ pub struct ServerConfig {
pub tcp_nodelay: bool,
/// Whether or not to set "raw mode" for the file.
pub file_raw_mode: bool,
#[serde(skip_serializing_if = "Option::is_none")]
/// Keypair (public, private) in PEM format for TLS.
pub tls_keypair: Option<[PathBuf; 2]>,
/// Whether or not to show what upstreams each client is connected to in stats. This can
/// heavily increase the size of the stats.
pub verbose_stats: bool,
/// Whether or not to respond to stats requests over HTTP.
pub enable_stats_endpoint: bool,
/// Where to listen for stats requests over HTTP.
pub stats_endpoint: StatsEndpoint,
pub stats_endpoint: Option<StatsEndpoint>,
/// Whether or not to search for the x-real-ip or x-forwarded-for headers.
pub use_real_ip_headers: bool,
/// String sent to a request that is not a websocket upgrade request.
pub non_ws_response: String,
/// Prefix of Wisp server. Do NOT add a trailing slash here.
pub prefix: String,
/// Max WebSocket message size that can be recieved.
pub max_message_size: usize,
@ -153,13 +146,13 @@ pub struct WispConfig {
pub allow_wsproxy: bool,
/// Buffer size advertised to the client.
pub buffer_size: u32,
/// Prefix of Wisp server. Do NOT add a trailing slash here.
pub prefix: String,
/// Whether or not to use Wisp version 2.
pub wisp_v2: bool,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Wisp version 2 extensions advertised.
pub extensions: Vec<ProtocolExtension>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Wisp version 2 authentication extension advertised.
pub auth_extension: Option<ProtocolExtensionAuth>,
@ -189,7 +182,6 @@ pub struct StreamConfig {
#[cfg(feature = "twisp")]
pub allow_twisp: bool,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// DNS servers to resolve with. Will default to system configuration.
pub dns_servers: Vec<IpAddr>,
@ -205,31 +197,23 @@ pub struct StreamConfig {
/// Whether or not to allow connections to non-globally-routable IP addresses.
pub allow_non_global: bool,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex whitelist of hosts for TCP connections.
pub allow_tcp_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex blacklist of hosts for TCP connections.
pub block_tcp_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex whitelist of hosts for UDP connections.
pub allow_udp_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex blacklist of hosts for UDP connections.
pub block_udp_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex whitelist of hosts.
pub allow_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Regex blacklist of hosts.
pub block_hosts: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Range whitelist of ports. Format is `[lower_bound, upper_bound]`.
pub allow_ports: Vec<Vec<u16>>,
#[serde(skip_serializing_if = "Vec::is_empty")]
/// Range blacklist of ports. Format is `[lower_bound, upper_bound]`.
pub block_ports: Vec<Vec<u16>>,
}
@ -287,18 +271,12 @@ lazy_static! {
pub async fn validate_config_cache() {
// constructs regexes
let _ = CONFIG_CACHE.allowed_ports;
// constructs wisp config
// validates wisp config
CONFIG.wisp.to_opts().await.unwrap();
// constructs resolver
RESOLVER.clear_cache();
}
impl Default for StatsEndpoint {
fn default() -> Self {
Self::SameServer("/stats".to_string())
}
}
impl StatsEndpoint {
pub fn get_endpoint(&self) -> Option<String> {
match self {
@ -325,14 +303,11 @@ impl Default for ServerConfig {
file_raw_mode: false,
tls_keypair: None,
verbose_stats: true,
enable_stats_endpoint: false,
stats_endpoint: StatsEndpoint::default(),
stats_endpoint: None,
use_real_ip_headers: false,
non_ws_response: ":3".to_string(),
prefix: String::new(),
max_message_size: 64 * 1024,
log_level: LevelFilter::Info,
@ -346,6 +321,7 @@ impl Default for WispConfig {
Self {
buffer_size: 128,
allow_wsproxy: true,
prefix: String::new(),
wisp_v2: true,
extensions: vec![ProtocolExtension::Udp, ProtocolExtension::Motd],

View file

@ -10,11 +10,12 @@ use dashmap::DashMap;
use handle::{handle_wisp, handle_wsproxy};
use hickory_resolver::{
config::{NameServerConfigGroup, ResolverConfig, ResolverOpts},
system_conf::read_system_conf,
TokioAsyncResolver,
};
use lazy_static::lazy_static;
use listener::ServerListener;
use log::{error, info};
use log::{error, info, warn};
use route::{route_stats, ServerRouteResult};
use serde::Serialize;
use tokio::{
@ -73,7 +74,12 @@ lazy_static! {
pub static ref CLIENTS: DashMap<String, Client> = DashMap::new();
pub static ref RESOLVER: Resolver = {
if CONFIG.stream.dns_servers.is_empty() {
if let Ok((config, opts)) = read_system_conf() {
Resolver::Hickory(TokioAsyncResolver::tokio(config, opts))
} else {
warn!("unable to read system dns configuration. using system dns resolver with no caching");
Resolver::System
}
} else {
Resolver::Hickory(TokioAsyncResolver::tokio(
ResolverConfig::from_parts(
@ -240,11 +246,14 @@ fn main() -> anyhow::Result<()> {
.await
.with_context(|| format!("failed to bind to address {}", CONFIG.server.bind.1))?;
if CONFIG.server.enable_stats_endpoint {
if let Some(bind_addr) = CONFIG.server.stats_endpoint.get_bindaddr() {
if let Some(bind_addr) = CONFIG
.server
.stats_endpoint
.as_ref()
.and_then(|x| x.get_bindaddr())
{
info!("stats server listening on {:?}", bind_addr);
let mut stats_listener =
ServerListener::new(&bind_addr).await.with_context(|| {
let mut stats_listener = ServerListener::new(&bind_addr).await.with_context(|| {
format!("failed to bind to address {} for stats server", bind_addr.1)
})?;
@ -261,16 +270,24 @@ fn main() -> anyhow::Result<()> {
}
});
}
}
let stats_endpoint = CONFIG.server.stats_endpoint.get_endpoint();
let stats_endpoint = CONFIG
.server
.stats_endpoint
.as_ref()
.and_then(|x| x.get_endpoint());
loop {
let stats_endpoint = stats_endpoint.clone();
match listener.accept().await {
Ok((stream, id)) => {
Ok((stream, client_id)) => {
tokio::spawn(async move {
let res = route::route(stream, stats_endpoint, move |stream| {
handle_stream(stream, id)
let res = route::route(stream, stats_endpoint, move |stream, maybe_ip| {
let client_id = if let Some(ip) = maybe_ip {
format!("{} ({})", client_id, ip)
} else {
client_id
};
handle_stream(stream, client_id)
})
.await;

View file

@ -5,8 +5,8 @@ use bytes::Bytes;
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector};
use http_body_util::Full;
use hyper::{
body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response,
StatusCode,
body::Incoming, server::conn::http1::Builder, service::service_fn, HeaderMap, Request,
Response, StatusCode,
};
use hyper_util::rt::TokioIo;
use log::{debug, error};
@ -52,20 +52,29 @@ fn send_stats() -> anyhow::Result<Response<Body>> {
}
}
async fn ws_upgrade<T, R>(
fn get_header(headers: &HeaderMap, header: &str) -> Option<String> {
headers.get(header).and_then(|x| x.to_str().ok()).map(|x| x.to_string())
}
enum HttpUpgradeResult {
Wisp,
WsProxy(String, bool),
}
async fn ws_upgrade<F, R>(
mut req: Request<Incoming>,
stats_endpoint: Option<String>,
callback: T,
callback: F,
) -> anyhow::Result<Response<Body>>
where
T: FnOnce(UpgradeFut, bool, bool, String) -> R + Send + 'static,
F: FnOnce(UpgradeFut, HttpUpgradeResult, Option<String>) -> R + Send + 'static,
R: Future<Output = anyhow::Result<()>> + Send,
{
let is_upgrade = fastwebsockets::upgrade::is_upgrade_request(&req);
if !is_upgrade {
if let Some(stats_endpoint) = stats_endpoint {
if CONFIG.server.enable_stats_endpoint && req.uri().path() == stats_endpoint {
if req.uri().path() == stats_endpoint {
return send_stats();
} else {
debug!("sent non_ws_response to http client");
@ -81,20 +90,33 @@ where
// replace body of Empty<Bytes> with Full<Bytes>
let resp = Response::from_parts(resp.into_parts().0, Body::new(Bytes::new()));
let headers = req.headers();
let ip_header = if CONFIG.server.use_real_ip_headers {
get_header(headers, "x-real-ip").or_else(|| get_header(headers, "x-forwarded-for"))
} else {
None
};
if req
.uri()
.path()
.starts_with(&(CONFIG.server.prefix.clone() + "/"))
.starts_with(&(CONFIG.wisp.prefix.clone() + "/"))
{
tokio::spawn(async move {
if let Err(err) = (callback)(fut, false, false, req.uri().path().to_string()).await {
if let Err(err) = (callback)(fut, HttpUpgradeResult::Wisp, ip_header).await {
error!("error while serving client: {:?}", err);
}
});
} else if CONFIG.wisp.allow_wsproxy {
let udp = req.uri().query().unwrap_or_default() == "?udp";
tokio::spawn(async move {
if let Err(err) = (callback)(fut, false, udp, req.uri().path().to_string()).await {
if let Err(err) = (callback)(
fut,
HttpUpgradeResult::WsProxy(req.uri().path().to_string(), udp),
ip_header,
)
.await
{
error!("error while serving client: {:?}", err);
}
});
@ -117,7 +139,7 @@ pub async fn route_stats(stream: ServerStream) -> anyhow::Result<()> {
pub async fn route(
stream: ServerStream,
stats_endpoint: Option<String>,
callback: impl FnOnce(ServerRouteResult) + Clone + Send + 'static,
callback: impl FnOnce(ServerRouteResult, Option<String>) + Clone + Send + 'static,
) -> anyhow::Result<()> {
match CONFIG.server.transport {
SocketTransport::WebSocket => {
@ -132,15 +154,13 @@ pub async fn route(
ws_upgrade(
req,
stats_endpoint.clone(),
|fut, wsproxy, udp, path| async move {
|fut, res, maybe_ip| async move {
let mut ws = fut.await.context("failed to await upgrade future")?;
ws.set_max_message_size(CONFIG.server.max_message_size);
ws.set_auto_pong(false);
if wsproxy {
let ws = WebSocketStreamWrapper(FragmentCollector::new(ws));
(callback)(ServerRouteResult::WsProxy(ws, path, udp));
} else {
match res {
HttpUpgradeResult::Wisp => {
let (read, write) = ws.split(|x| {
let parts = x
.into_inner()
@ -150,10 +170,21 @@ pub async fn route(
(Cursor::new(parts.read_buf).chain(r), w)
});
(callback)(ServerRouteResult::Wisp((
(callback)(
ServerRouteResult::Wisp((
Box::new(read),
Box::new(write),
)))
)),
maybe_ip,
)
}
HttpUpgradeResult::WsProxy(path, udp) => {
let ws = WebSocketStreamWrapper(FragmentCollector::new(ws));
(callback)(
ServerRouteResult::WsProxy(ws, path, udp),
maybe_ip,
);
}
}
Ok(())
@ -174,7 +205,10 @@ pub async fn route(
let read = GenericWebSocketRead::new(FramedRead::new(read, codec.clone()));
let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec));
(callback)(ServerRouteResult::Wisp((Box::new(read), Box::new(write))));
(callback)(
ServerRouteResult::Wisp((Box::new(read), Box::new(write))),
None,
);
}
}
Ok(())