From 14b5bd796b2e5977a02c3443952b422c3c81f0a7 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Thu, 26 Sep 2024 17:30:02 -0700 Subject: [PATCH] simplify bindaddr, add separate stats server --- server/src/config.rs | 64 ++++++++++++++++----- server/src/listener.rs | 39 +++++++------ server/src/main.rs | 52 +++++++++++++++--- server/src/route.rs | 122 ++++++++++++++++++++++++----------------- wisp/src/generic.rs | 14 +++-- 5 files changed, 197 insertions(+), 94 deletions(-) diff --git a/server/src/config.rs b/server/src/config.rs index 19f3ff6..ef8920a 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -30,7 +30,7 @@ const VERSION_STRING: &str = concat!( env!("VERGEN_RUSTC_HOST_TRIPLE") ); -#[derive(Serialize, Deserialize, Default, Debug)] +#[derive(Serialize, Deserialize, Default, Debug, Clone, Copy)] #[serde(rename_all = "lowercase")] pub enum SocketType { /// TCP socket listener. @@ -59,13 +59,22 @@ pub enum SocketTransport { LengthDelimitedLe, } +pub type BindAddr = (SocketType, String); + +#[derive(Serialize, Deserialize, Debug)] +#[serde(untagged)] +pub enum StatsEndpoint { + /// Stats on the same listener as the Wisp server. + SameServer(String), + /// Stats on this address and socket type. + SeparateServer((SocketType, String)), +} + #[derive(Serialize, Deserialize)] #[serde(default)] pub struct ServerConfig { - /// Address to listen on. - pub bind: String, - /// Socket type to listen on. - pub socket: SocketType, + /// Address and socket type to listen on. + pub bind: BindAddr, /// Transport to listen on. pub transport: SocketTransport, /// Whether or not to resolve and connect to IPV6 upstream addresses. @@ -83,15 +92,12 @@ pub struct ServerConfig { pub verbose_stats: bool, /// Whether or not to respond to stats requests over HTTP. pub enable_stats_endpoint: bool, - #[serde(skip_serializing_if = "String::is_empty")] - /// Path of stats HTTP endpoint. - pub stats_endpoint: String, + /// Where to listen for stats requests over HTTP. + pub stats_endpoint: StatsEndpoint, - #[serde(skip_serializing_if = "String::is_empty")] /// String sent to a request that is not a websocket upgrade request. pub non_ws_response: String, - #[serde(skip_serializing_if = "String::is_empty")] /// Prefix of Wisp server. Do NOT add a trailing slash here. pub prefix: String, @@ -120,6 +126,14 @@ pub enum ProtocolExtensionAuth { Certificate, } +fn default_motd() -> String { + format!("epoxy_server ({})", VERSION_STRING) +} + +fn is_default_motd(str: &String) -> bool { + *str == default_motd() +} + #[derive(Serialize, Deserialize)] #[serde(default)] pub struct WispConfig { @@ -144,6 +158,7 @@ pub struct WispConfig { /// Wisp version 2 certificate authentication extension public ed25519 pem keys. pub certificate_extension_keys: Vec, + #[serde(skip_serializing_if = "is_default_motd")] /// Wisp version 2 MOTD extension message. pub motd_extension: String, } @@ -266,11 +281,32 @@ pub async fn validate_config_cache() { RESOLVER.clear_cache(); } +impl Default for StatsEndpoint { + fn default() -> Self { + Self::SameServer("/stats".to_string()) + } +} + +impl StatsEndpoint { + pub fn get_endpoint(&self) -> Option { + match self { + Self::SameServer(x) => Some(x.clone()), + Self::SeparateServer(_) => None, + } + } + + pub fn get_bindaddr(&self) -> Option { + match self { + Self::SameServer(_) => None, + Self::SeparateServer(x) => Some(x.clone()), + } + } +} + impl Default for ServerConfig { fn default() -> Self { Self { - bind: "127.0.0.1:4000".to_string(), - socket: SocketType::default(), + bind: (SocketType::default(), "127.0.0.1:4000".to_string()), transport: SocketTransport::default(), resolve_ipv6: false, tcp_nodelay: false, @@ -278,8 +314,8 @@ impl Default for ServerConfig { tls_keypair: None, verbose_stats: true, - stats_endpoint: "/stats".to_string(), enable_stats_endpoint: false, + stats_endpoint: StatsEndpoint::default(), non_ws_response: ":3".to_string(), @@ -305,7 +341,7 @@ impl Default for WispConfig { password_extension_users: HashMap::new(), certificate_extension_keys: Vec::new(), - motd_extension: format!("epoxy_server ({})", VERSION_STRING), + motd_extension: default_motd(), } } } diff --git a/server/src/listener.rs b/server/src/listener.rs index 32e0136..ff83ca8 100644 --- a/server/src/listener.rs +++ b/server/src/listener.rs @@ -16,7 +16,10 @@ use tokio::{ use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor}; use uuid::Uuid; -use crate::{config::SocketType, CONFIG}; +use crate::{ + config::{BindAddr, SocketType}, + CONFIG, +}; pub enum Quintet { One(A), @@ -282,18 +285,18 @@ pub enum ServerListener { } impl ServerListener { - async fn bind_tcp() -> anyhow::Result { - TcpListener::bind(&CONFIG.server.bind) + async fn bind_tcp(bind: &BindAddr) -> anyhow::Result { + TcpListener::bind(&bind.1) .await - .with_context(|| format!("failed to bind to tcp address `{}`", CONFIG.server.bind)) + .with_context(|| format!("failed to bind to tcp address `{}`", bind.1)) } - async fn bind_unix() -> anyhow::Result { - if try_exists(&CONFIG.server.bind).await? { - remove_file(&CONFIG.server.bind).await?; + async fn bind_unix(bind: &BindAddr) -> anyhow::Result { + if try_exists(&bind.1).await? { + remove_file(&bind.1).await?; } - UnixListener::bind(&CONFIG.server.bind) - .with_context(|| format!("failed to bind to unix socket at `{}`", CONFIG.server.bind)) + UnixListener::bind(&bind.1) + .with_context(|| format!("failed to bind to unix socket at `{}`", bind.1)) } async fn create_tls() -> anyhow::Result { @@ -330,15 +333,17 @@ impl ServerListener { Ok(TlsAcceptor::from(cfg)) } - pub async fn new() -> anyhow::Result { - Ok(match CONFIG.server.socket { - SocketType::Tcp => Self::Tcp(Self::bind_tcp().await?), - SocketType::TlsTcp => Self::TlsTcp(Self::bind_tcp().await?, Self::create_tls().await?), - SocketType::Unix => Self::Unix(Self::bind_unix().await?), - SocketType::TlsUnix => { - Self::TlsUnix(Self::bind_unix().await?, Self::create_tls().await?) + pub async fn new(bind: &BindAddr) -> anyhow::Result { + Ok(match bind.0 { + SocketType::Tcp => Self::Tcp(Self::bind_tcp(bind).await?), + SocketType::TlsTcp => { + Self::TlsTcp(Self::bind_tcp(bind).await?, Self::create_tls().await?) } - SocketType::File => Self::File(Some(CONFIG.server.bind.clone().into())), + SocketType::Unix => Self::Unix(Self::bind_unix(bind).await?), + SocketType::TlsUnix => { + Self::TlsUnix(Self::bind_unix(bind).await?, Self::create_tls().await?) + } + SocketType::File => Self::File(Some(bind.1.clone().into())), }) } diff --git a/server/src/main.rs b/server/src/main.rs index bf86b69..bcf9f4a 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -3,6 +3,7 @@ use std::{fmt::Write, fs::read_to_string, net::IpAddr}; +use anyhow::Context; use clap::Parser; use config::{validate_config_cache, Cli, Config}; use dashmap::DashMap; @@ -14,7 +15,7 @@ use hickory_resolver::{ use lazy_static::lazy_static; use listener::ServerListener; use log::{error, info}; -use route::ServerRouteResult; +use route::{route_stats, ServerRouteResult}; use tokio::signal::unix::{signal, SignalKind}; use uuid::Uuid; use wisp_mux::{ConnectPacket, StreamType}; @@ -54,7 +55,13 @@ lazy_static! { pub static ref CLI: Cli = Cli::parse(); pub static ref CONFIG: Config = { if let Some(path) = &CLI.config { - Config::de(read_to_string(path).unwrap()).unwrap() + Config::de( + read_to_string(path) + .context("failed to read config") + .unwrap(), + ) + .context("failed to parse config") + .unwrap() } else { Config::default() } @@ -191,7 +198,7 @@ fn handle_stream(stream: ServerRouteResult, id: String) { } #[global_allocator] -static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; +static JEMALLOCATOR: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; #[tokio::main(flavor = "multi_thread")] async fn main() -> anyhow::Result<()> { @@ -208,8 +215,8 @@ async fn main() -> anyhow::Result<()> { validate_config_cache().await; info!( - "listening on {:?} with socket type {:?} and socket transport {:?}", - CONFIG.server.bind, CONFIG.server.socket, CONFIG.server.transport + "listening on {:?} with socket transport {:?}", + CONFIG.server.bind, CONFIG.server.transport ); tokio::spawn(async { @@ -219,13 +226,40 @@ async fn main() -> anyhow::Result<()> { } }); - let mut listener = ServerListener::new().await?; + let mut listener = ServerListener::new(&CONFIG.server.bind) + .await + .with_context(|| format!("failed to bind to address {}", CONFIG.server.bind.1))?; + + if let Some(bind_addr) = CONFIG.server.stats_endpoint.get_bindaddr() { + info!("stats server listening on {:?}", bind_addr); + let mut stats_listener = ServerListener::new(&bind_addr).await.with_context(|| { + format!("failed to bind to address {} for stats server", bind_addr.1) + })?; + + tokio::spawn(async move { + loop { + match stats_listener.accept().await { + Ok((stream, _)) => { + if let Err(e) = route_stats(stream).await { + error!("error while routing stats client: {:?}", e); + } + } + Err(e) => error!("error while accepting stats client: {:?}", e), + } + } + }); + } + + let stats_endpoint = CONFIG.server.stats_endpoint.get_endpoint(); loop { - let ret = listener.accept().await; - match ret { + let stats_endpoint = stats_endpoint.clone(); + match listener.accept().await { Ok((stream, id)) => { tokio::spawn(async move { - let res = route::route(stream, move |stream| handle_stream(stream, id)).await; + let res = route::route(stream, stats_endpoint, move |stream| { + handle_stream(stream, id) + }) + .await; if let Err(e) = res { error!("error while routing client: {:?}", e); diff --git a/server/src/route.rs b/server/src/route.rs index ace7d44..8220f75 100644 --- a/server/src/route.rs +++ b/server/src/route.rs @@ -33,36 +33,48 @@ fn non_ws_resp() -> Response { .unwrap() } -async fn ws_upgrade(mut req: Request, callback: T) -> anyhow::Result> +fn send_stats() -> anyhow::Result> { + match generate_stats() { + Ok(x) => { + debug!("sent server stats to http client"); + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::new(x.into())) + .unwrap()) + } + Err(x) => { + error!("failed to send stats to http client: {:?}", x); + Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::new(x.to_string().into())) + .unwrap()) + } + } +} + +async fn ws_upgrade( + mut req: Request, + stats_endpoint: Option, + callback: T, +) -> anyhow::Result> where T: FnOnce(UpgradeFut, bool, bool, String) -> R + Send + 'static, R: Future> + Send, { let is_upgrade = fastwebsockets::upgrade::is_upgrade_request(&req); - if !is_upgrade - && CONFIG.server.enable_stats_endpoint - && req.uri().path() == CONFIG.server.stats_endpoint - { - match generate_stats() { - Ok(x) => { - debug!("sent server stats to http client"); - return Ok(Response::builder() - .status(StatusCode::OK) - .body(Body::new(x.into())) - .unwrap()); - } - Err(x) => { - error!("failed to send stats to http client: {:?}", x); - return Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::new(x.to_string().into())) - .unwrap()); + if !is_upgrade { + if let Some(stats_endpoint) = stats_endpoint { + if CONFIG.server.enable_stats_endpoint && req.uri().path() == stats_endpoint { + return send_stats(); + } else { + debug!("sent non_ws_response to http client"); + return Ok(non_ws_resp()); } + } else { + debug!("sent non_ws_response to http client"); + return Ok(non_ws_resp()); } - } else if !is_upgrade { - debug!("sent non_ws_response to http client"); - return Ok(non_ws_resp()); } let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?; @@ -94,51 +106,63 @@ where Ok(resp) } +pub async fn route_stats(stream: ServerStream) -> anyhow::Result<()> { + let stream = TokioIo::new(stream); + Builder::new() + .serve_connection(stream, service_fn(move |_| async { send_stats() })) + .await?; + Ok(()) +} + pub async fn route( stream: ServerStream, + stats_endpoint: Option, callback: impl FnOnce(ServerRouteResult) + Clone + Send + 'static, ) -> anyhow::Result<()> { match CONFIG.server.transport { SocketTransport::WebSocket => { let stream = TokioIo::new(stream); - let fut = Builder::new() + Builder::new() .serve_connection( stream, service_fn(move |req| { let callback = callback.clone(); - ws_upgrade(req, |fut, wsproxy, udp, path| async move { - let mut ws = fut.await.context("failed to await upgrade future")?; - ws.set_max_message_size(CONFIG.server.max_message_size); - ws.set_auto_pong(false); + ws_upgrade( + req, + stats_endpoint.clone(), + |fut, wsproxy, udp, path| async move { + let mut ws = fut.await.context("failed to await upgrade future")?; + ws.set_max_message_size(CONFIG.server.max_message_size); + ws.set_auto_pong(false); - if wsproxy { - let ws = WebSocketStreamWrapper(FragmentCollector::new(ws)); - (callback)(ServerRouteResult::WsProxy(ws, path, udp)); - } else { - let (read, write) = ws.split(|x| { - let parts = - x.into_inner().downcast::>().unwrap(); - let (r, w) = parts.io.into_inner().split(); - (Cursor::new(parts.read_buf).chain(r), w) - }); + if wsproxy { + let ws = WebSocketStreamWrapper(FragmentCollector::new(ws)); + (callback)(ServerRouteResult::WsProxy(ws, path, udp)); + } else { + let (read, write) = ws.split(|x| { + let parts = x + .into_inner() + .downcast::>() + .unwrap(); + let (r, w) = parts.io.into_inner().split(); + (Cursor::new(parts.read_buf).chain(r), w) + }); - (callback)(ServerRouteResult::Wisp(( - Box::new(read), - Box::new(write), - ))) - } + (callback)(ServerRouteResult::Wisp(( + Box::new(read), + Box::new(write), + ))) + } - Ok(()) - }) + Ok(()) + }, + ) }), ) - .with_upgrades(); - - if let Err(e) = fut.await { - error!("error while serving client: {:?}", e); - } + .with_upgrades() + .await?; } SocketTransport::LengthDelimitedLe => { let codec = LengthDelimitedCodec::builder() diff --git a/wisp/src/generic.rs b/wisp/src/generic.rs index 262749c..5589623 100644 --- a/wisp/src/generic.rs +++ b/wisp/src/generic.rs @@ -6,7 +6,7 @@ use futures::{Sink, SinkExt, Stream, StreamExt}; use std::error::Error; use crate::{ - ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, + ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead, WebSocketWrite}, WispError, }; @@ -72,10 +72,14 @@ impl + Send + Unpin, E: Error + Sync + Send + 'static> for GenericWebSocketWrite { async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { - self.0 - .send(BytesMut::from(frame.payload).freeze()) - .await - .map_err(|x| WispError::WsImplError(Box::new(x))) + if frame.opcode == OpCode::Binary { + self.0 + .send(BytesMut::from(frame.payload).freeze()) + .await + .map_err(|x| WispError::WsImplError(Box::new(x))) + } else { + Ok(()) + } } async fn wisp_close(&mut self) -> Result<(), WispError> {