simplify bindaddr, add separate stats server

This commit is contained in:
Toshit Chawda 2024-09-26 17:30:02 -07:00
parent cbbe5308f5
commit 14b5bd796b
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 197 additions and 94 deletions

View file

@ -30,7 +30,7 @@ const VERSION_STRING: &str = concat!(
env!("VERGEN_RUSTC_HOST_TRIPLE")
);
#[derive(Serialize, Deserialize, Default, Debug)]
#[derive(Serialize, Deserialize, Default, Debug, Clone, Copy)]
#[serde(rename_all = "lowercase")]
pub enum SocketType {
/// TCP socket listener.
@ -59,13 +59,22 @@ pub enum SocketTransport {
LengthDelimitedLe,
}
pub type BindAddr = (SocketType, String);
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum StatsEndpoint {
/// Stats on the same listener as the Wisp server.
SameServer(String),
/// Stats on this address and socket type.
SeparateServer((SocketType, String)),
}
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct ServerConfig {
/// Address to listen on.
pub bind: String,
/// Socket type to listen on.
pub socket: SocketType,
/// Address and socket type to listen on.
pub bind: BindAddr,
/// Transport to listen on.
pub transport: SocketTransport,
/// Whether or not to resolve and connect to IPV6 upstream addresses.
@ -83,15 +92,12 @@ pub struct ServerConfig {
pub verbose_stats: bool,
/// Whether or not to respond to stats requests over HTTP.
pub enable_stats_endpoint: bool,
#[serde(skip_serializing_if = "String::is_empty")]
/// Path of stats HTTP endpoint.
pub stats_endpoint: String,
/// Where to listen for stats requests over HTTP.
pub stats_endpoint: StatsEndpoint,
#[serde(skip_serializing_if = "String::is_empty")]
/// String sent to a request that is not a websocket upgrade request.
pub non_ws_response: String,
#[serde(skip_serializing_if = "String::is_empty")]
/// Prefix of Wisp server. Do NOT add a trailing slash here.
pub prefix: String,
@ -120,6 +126,14 @@ pub enum ProtocolExtensionAuth {
Certificate,
}
fn default_motd() -> String {
format!("epoxy_server ({})", VERSION_STRING)
}
fn is_default_motd(str: &String) -> bool {
*str == default_motd()
}
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct WispConfig {
@ -144,6 +158,7 @@ pub struct WispConfig {
/// Wisp version 2 certificate authentication extension public ed25519 pem keys.
pub certificate_extension_keys: Vec<PathBuf>,
#[serde(skip_serializing_if = "is_default_motd")]
/// Wisp version 2 MOTD extension message.
pub motd_extension: String,
}
@ -266,11 +281,32 @@ pub async fn validate_config_cache() {
RESOLVER.clear_cache();
}
impl Default for StatsEndpoint {
fn default() -> Self {
Self::SameServer("/stats".to_string())
}
}
impl StatsEndpoint {
pub fn get_endpoint(&self) -> Option<String> {
match self {
Self::SameServer(x) => Some(x.clone()),
Self::SeparateServer(_) => None,
}
}
pub fn get_bindaddr(&self) -> Option<BindAddr> {
match self {
Self::SameServer(_) => None,
Self::SeparateServer(x) => Some(x.clone()),
}
}
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
bind: "127.0.0.1:4000".to_string(),
socket: SocketType::default(),
bind: (SocketType::default(), "127.0.0.1:4000".to_string()),
transport: SocketTransport::default(),
resolve_ipv6: false,
tcp_nodelay: false,
@ -278,8 +314,8 @@ impl Default for ServerConfig {
tls_keypair: None,
verbose_stats: true,
stats_endpoint: "/stats".to_string(),
enable_stats_endpoint: false,
stats_endpoint: StatsEndpoint::default(),
non_ws_response: ":3".to_string(),
@ -305,7 +341,7 @@ impl Default for WispConfig {
password_extension_users: HashMap::new(),
certificate_extension_keys: Vec::new(),
motd_extension: format!("epoxy_server ({})", VERSION_STRING),
motd_extension: default_motd(),
}
}
}

View file

@ -16,7 +16,10 @@ use tokio::{
use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor};
use uuid::Uuid;
use crate::{config::SocketType, CONFIG};
use crate::{
config::{BindAddr, SocketType},
CONFIG,
};
pub enum Quintet<A, B, C, D, E> {
One(A),
@ -282,18 +285,18 @@ pub enum ServerListener {
}
impl ServerListener {
async fn bind_tcp() -> anyhow::Result<TcpListener> {
TcpListener::bind(&CONFIG.server.bind)
async fn bind_tcp(bind: &BindAddr) -> anyhow::Result<TcpListener> {
TcpListener::bind(&bind.1)
.await
.with_context(|| format!("failed to bind to tcp address `{}`", CONFIG.server.bind))
.with_context(|| format!("failed to bind to tcp address `{}`", bind.1))
}
async fn bind_unix() -> anyhow::Result<UnixListener> {
if try_exists(&CONFIG.server.bind).await? {
remove_file(&CONFIG.server.bind).await?;
async fn bind_unix(bind: &BindAddr) -> anyhow::Result<UnixListener> {
if try_exists(&bind.1).await? {
remove_file(&bind.1).await?;
}
UnixListener::bind(&CONFIG.server.bind)
.with_context(|| format!("failed to bind to unix socket at `{}`", CONFIG.server.bind))
UnixListener::bind(&bind.1)
.with_context(|| format!("failed to bind to unix socket at `{}`", bind.1))
}
async fn create_tls() -> anyhow::Result<TlsAcceptor> {
@ -330,15 +333,17 @@ impl ServerListener {
Ok(TlsAcceptor::from(cfg))
}
pub async fn new() -> anyhow::Result<Self> {
Ok(match CONFIG.server.socket {
SocketType::Tcp => Self::Tcp(Self::bind_tcp().await?),
SocketType::TlsTcp => Self::TlsTcp(Self::bind_tcp().await?, Self::create_tls().await?),
SocketType::Unix => Self::Unix(Self::bind_unix().await?),
SocketType::TlsUnix => {
Self::TlsUnix(Self::bind_unix().await?, Self::create_tls().await?)
pub async fn new(bind: &BindAddr) -> anyhow::Result<Self> {
Ok(match bind.0 {
SocketType::Tcp => Self::Tcp(Self::bind_tcp(bind).await?),
SocketType::TlsTcp => {
Self::TlsTcp(Self::bind_tcp(bind).await?, Self::create_tls().await?)
}
SocketType::File => Self::File(Some(CONFIG.server.bind.clone().into())),
SocketType::Unix => Self::Unix(Self::bind_unix(bind).await?),
SocketType::TlsUnix => {
Self::TlsUnix(Self::bind_unix(bind).await?, Self::create_tls().await?)
}
SocketType::File => Self::File(Some(bind.1.clone().into())),
})
}

View file

@ -3,6 +3,7 @@
use std::{fmt::Write, fs::read_to_string, net::IpAddr};
use anyhow::Context;
use clap::Parser;
use config::{validate_config_cache, Cli, Config};
use dashmap::DashMap;
@ -14,7 +15,7 @@ use hickory_resolver::{
use lazy_static::lazy_static;
use listener::ServerListener;
use log::{error, info};
use route::ServerRouteResult;
use route::{route_stats, ServerRouteResult};
use tokio::signal::unix::{signal, SignalKind};
use uuid::Uuid;
use wisp_mux::{ConnectPacket, StreamType};
@ -54,7 +55,13 @@ lazy_static! {
pub static ref CLI: Cli = Cli::parse();
pub static ref CONFIG: Config = {
if let Some(path) = &CLI.config {
Config::de(read_to_string(path).unwrap()).unwrap()
Config::de(
read_to_string(path)
.context("failed to read config")
.unwrap(),
)
.context("failed to parse config")
.unwrap()
} else {
Config::default()
}
@ -191,7 +198,7 @@ fn handle_stream(stream: ServerRouteResult, id: String) {
}
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
static JEMALLOCATOR: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
#[tokio::main(flavor = "multi_thread")]
async fn main() -> anyhow::Result<()> {
@ -208,8 +215,8 @@ async fn main() -> anyhow::Result<()> {
validate_config_cache().await;
info!(
"listening on {:?} with socket type {:?} and socket transport {:?}",
CONFIG.server.bind, CONFIG.server.socket, CONFIG.server.transport
"listening on {:?} with socket transport {:?}",
CONFIG.server.bind, CONFIG.server.transport
);
tokio::spawn(async {
@ -219,13 +226,40 @@ async fn main() -> anyhow::Result<()> {
}
});
let mut listener = ServerListener::new().await?;
let mut listener = ServerListener::new(&CONFIG.server.bind)
.await
.with_context(|| format!("failed to bind to address {}", CONFIG.server.bind.1))?;
if let Some(bind_addr) = CONFIG.server.stats_endpoint.get_bindaddr() {
info!("stats server listening on {:?}", bind_addr);
let mut stats_listener = ServerListener::new(&bind_addr).await.with_context(|| {
format!("failed to bind to address {} for stats server", bind_addr.1)
})?;
tokio::spawn(async move {
loop {
match stats_listener.accept().await {
Ok((stream, _)) => {
if let Err(e) = route_stats(stream).await {
error!("error while routing stats client: {:?}", e);
}
}
Err(e) => error!("error while accepting stats client: {:?}", e),
}
}
});
}
let stats_endpoint = CONFIG.server.stats_endpoint.get_endpoint();
loop {
let ret = listener.accept().await;
match ret {
let stats_endpoint = stats_endpoint.clone();
match listener.accept().await {
Ok((stream, id)) => {
tokio::spawn(async move {
let res = route::route(stream, move |stream| handle_stream(stream, id)).await;
let res = route::route(stream, stats_endpoint, move |stream| {
handle_stream(stream, id)
})
.await;
if let Err(e) = res {
error!("error while routing client: {:?}", e);

View file

@ -33,36 +33,48 @@ fn non_ws_resp() -> Response<Body> {
.unwrap()
}
async fn ws_upgrade<T, R>(mut req: Request<Incoming>, callback: T) -> anyhow::Result<Response<Body>>
fn send_stats() -> anyhow::Result<Response<Body>> {
match generate_stats() {
Ok(x) => {
debug!("sent server stats to http client");
Ok(Response::builder()
.status(StatusCode::OK)
.body(Body::new(x.into()))
.unwrap())
}
Err(x) => {
error!("failed to send stats to http client: {:?}", x);
Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::new(x.to_string().into()))
.unwrap())
}
}
}
async fn ws_upgrade<T, R>(
mut req: Request<Incoming>,
stats_endpoint: Option<String>,
callback: T,
) -> anyhow::Result<Response<Body>>
where
T: FnOnce(UpgradeFut, bool, bool, String) -> R + Send + 'static,
R: Future<Output = anyhow::Result<()>> + Send,
{
let is_upgrade = fastwebsockets::upgrade::is_upgrade_request(&req);
if !is_upgrade
&& CONFIG.server.enable_stats_endpoint
&& req.uri().path() == CONFIG.server.stats_endpoint
{
match generate_stats() {
Ok(x) => {
debug!("sent server stats to http client");
return Ok(Response::builder()
.status(StatusCode::OK)
.body(Body::new(x.into()))
.unwrap());
}
Err(x) => {
error!("failed to send stats to http client: {:?}", x);
return Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::new(x.to_string().into()))
.unwrap());
if !is_upgrade {
if let Some(stats_endpoint) = stats_endpoint {
if CONFIG.server.enable_stats_endpoint && req.uri().path() == stats_endpoint {
return send_stats();
} else {
debug!("sent non_ws_response to http client");
return Ok(non_ws_resp());
}
} else {
debug!("sent non_ws_response to http client");
return Ok(non_ws_resp());
}
} else if !is_upgrade {
debug!("sent non_ws_response to http client");
return Ok(non_ws_resp());
}
let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?;
@ -94,51 +106,63 @@ where
Ok(resp)
}
pub async fn route_stats(stream: ServerStream) -> anyhow::Result<()> {
let stream = TokioIo::new(stream);
Builder::new()
.serve_connection(stream, service_fn(move |_| async { send_stats() }))
.await?;
Ok(())
}
pub async fn route(
stream: ServerStream,
stats_endpoint: Option<String>,
callback: impl FnOnce(ServerRouteResult) + Clone + Send + 'static,
) -> anyhow::Result<()> {
match CONFIG.server.transport {
SocketTransport::WebSocket => {
let stream = TokioIo::new(stream);
let fut = Builder::new()
Builder::new()
.serve_connection(
stream,
service_fn(move |req| {
let callback = callback.clone();
ws_upgrade(req, |fut, wsproxy, udp, path| async move {
let mut ws = fut.await.context("failed to await upgrade future")?;
ws.set_max_message_size(CONFIG.server.max_message_size);
ws.set_auto_pong(false);
ws_upgrade(
req,
stats_endpoint.clone(),
|fut, wsproxy, udp, path| async move {
let mut ws = fut.await.context("failed to await upgrade future")?;
ws.set_max_message_size(CONFIG.server.max_message_size);
ws.set_auto_pong(false);
if wsproxy {
let ws = WebSocketStreamWrapper(FragmentCollector::new(ws));
(callback)(ServerRouteResult::WsProxy(ws, path, udp));
} else {
let (read, write) = ws.split(|x| {
let parts =
x.into_inner().downcast::<TokioIo<ServerStream>>().unwrap();
let (r, w) = parts.io.into_inner().split();
(Cursor::new(parts.read_buf).chain(r), w)
});
if wsproxy {
let ws = WebSocketStreamWrapper(FragmentCollector::new(ws));
(callback)(ServerRouteResult::WsProxy(ws, path, udp));
} else {
let (read, write) = ws.split(|x| {
let parts = x
.into_inner()
.downcast::<TokioIo<ServerStream>>()
.unwrap();
let (r, w) = parts.io.into_inner().split();
(Cursor::new(parts.read_buf).chain(r), w)
});
(callback)(ServerRouteResult::Wisp((
Box::new(read),
Box::new(write),
)))
}
(callback)(ServerRouteResult::Wisp((
Box::new(read),
Box::new(write),
)))
}
Ok(())
})
Ok(())
},
)
}),
)
.with_upgrades();
if let Err(e) = fut.await {
error!("error while serving client: {:?}", e);
}
.with_upgrades()
.await?;
}
SocketTransport::LengthDelimitedLe => {
let codec = LengthDelimitedCodec::builder()

View file

@ -6,7 +6,7 @@ use futures::{Sink, SinkExt, Stream, StreamExt};
use std::error::Error;
use crate::{
ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead, WebSocketWrite},
WispError,
};
@ -72,10 +72,14 @@ impl<T: Sink<Bytes, Error = E> + Send + Unpin, E: Error + Sync + Send + 'static>
for GenericWebSocketWrite<T, E>
{
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
self.0
.send(BytesMut::from(frame.payload).freeze())
.await
.map_err(|x| WispError::WsImplError(Box::new(x)))
if frame.opcode == OpCode::Binary {
self.0
.send(BytesMut::from(frame.payload).freeze())
.await
.map_err(|x| WispError::WsImplError(Box::new(x)))
} else {
Ok(())
}
}
async fn wisp_close(&mut self) -> Result<(), WispError> {