diff --git a/Cargo.lock b/Cargo.lock index 2bad28e..4f62d28 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -554,6 +554,7 @@ version = "2.0.0" dependencies = [ "anyhow", "bytes", + "clap", "dashmap 6.0.1", "env_logger", "fastwebsockets", @@ -565,6 +566,7 @@ dependencies = [ "log", "regex", "serde", + "serde_json", "tokio", "tokio-util", "toml", diff --git a/server/Cargo.toml b/server/Cargo.toml index 03be1a4..30fa3f2 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] anyhow = "1.0.86" bytes = "1.6.1" +clap = { version = "4.5.9", features = ["cargo", "derive"] } dashmap = "6.0.1" env_logger = "0.11.3" fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] } @@ -17,6 +18,7 @@ lazy_static = "1.5.0" log = { version = "0.4.22", features = ["serde", "std"] } regex = "1.10.5" serde = { version = "1.0.204", features = ["derive"] } +serde_json = "1.0.120" tokio = { version = "1.38.1", features = ["full"] } tokio-util = { version = "0.7.11", features = ["compat", "io-util", "net"] } toml = "0.8.15" diff --git a/server/src/config.rs b/server/src/config.rs index 22a2dff..60886f8 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -1,5 +1,6 @@ -use std::{collections::HashMap, ops::RangeInclusive}; +use std::{collections::HashMap, ops::RangeInclusive, path::PathBuf}; +use clap::{Parser, ValueEnum}; use lazy_static::lazy_static; use log::LevelFilter; use regex::RegexSet; @@ -9,7 +10,7 @@ use wisp_mux::extensions::{ ProtocolExtensionBuilder, }; -use crate::CONFIG; +use crate::{CLI, CONFIG}; type AnyProtocolExtensionBuilder = Box; @@ -63,6 +64,7 @@ pub struct ServerConfig { pub bind: String, pub socket: SocketType, pub resolve_ipv6: bool, + pub tcp_nodelay: bool, pub verbose_stats: bool, pub enable_stats_endpoint: bool, @@ -84,6 +86,7 @@ impl Default for ServerConfig { bind: "127.0.0.1:4000".to_string(), socket: SocketType::default(), resolve_ipv6: false, + tcp_nodelay: false, verbose_stats: true, stats_endpoint: "/stats".to_string(), @@ -165,6 +168,8 @@ impl WispConfig { #[derive(Serialize, Deserialize)] #[serde(default)] pub struct StreamConfig { + pub tcp_nodelay: bool, + pub allow_udp: bool, pub allow_wsproxy_udp: bool, @@ -185,6 +190,8 @@ pub struct StreamConfig { impl Default for StreamConfig { fn default() -> Self { Self { + tcp_nodelay: false, + allow_udp: true, allow_wsproxy_udp: false, @@ -229,3 +236,42 @@ pub struct Config { pub wisp: WispConfig, pub stream: StreamConfig, } + +impl Config { + pub fn ser(&self) -> anyhow::Result { + Ok(match CLI.format { + ConfigFormat::Toml => toml::to_string_pretty(self)?, + ConfigFormat::Json => serde_json::to_string_pretty(self)?, + }) + } + + pub fn de(string: String) -> anyhow::Result { + Ok(match CLI.format { + ConfigFormat::Toml => toml::from_str(&string)?, + ConfigFormat::Json => serde_json::from_str(&string)?, + }) + } +} + +#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Default, ValueEnum)] +pub enum ConfigFormat { + #[default] + Toml, + Json, +} + +/// Server implementation of the Wisp protocol in Rust, made for epoxy. +#[derive(Parser)] +#[command(version = clap::crate_version!())] +pub struct Cli { + /// Config file to use. + pub config: Option, + + /// Config file format to use. + #[arg(short, long, value_enum, default_value_t = ConfigFormat::default())] + pub format: ConfigFormat, + + /// Show default config and exit. + #[arg(long)] + pub default_config: bool, +} diff --git a/server/src/main.rs b/server/src/main.rs index d0e87a2..b100dcb 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,9 +1,10 @@ #![feature(ip)] -use std::{env::args, fmt::Write, fs::read_to_string}; +use std::{fmt::Write, fs::read_to_string}; use bytes::Bytes; -use config::{validate_config_cache, Config}; +use clap::Parser; +use config::{validate_config_cache, Cli, Config}; use dashmap::DashMap; use handle::{handle_wisp, handle_wsproxy}; use http_body_util::Full; @@ -26,9 +27,10 @@ mod stream; type Client = (DashMap, bool); lazy_static! { + pub static ref CLI: Cli = Cli::parse(); pub static ref CONFIG: Config = { - if let Some(path) = args().nth(1) { - toml::from_str(&read_to_string(path).unwrap()).unwrap() + if let Some(path) = &CLI.config { + Config::de(read_to_string(path).unwrap()).unwrap() } else { Config::default() } @@ -159,10 +161,16 @@ fn generate_stats() -> Result { #[tokio::main(flavor = "multi_thread")] async fn main() -> anyhow::Result<()> { + if CLI.default_config { + println!("{}", Config::default().ser()?); + return Ok(()); + } + env_logger::builder() .filter_level(CONFIG.server.log_level) .parse_default_env() .init(); + validate_config_cache(); info!( diff --git a/server/src/stream.rs b/server/src/stream.rs index ee8c169..71d13ba 100644 --- a/server/src/stream.rs +++ b/server/src/stream.rs @@ -69,11 +69,18 @@ impl ServerListener { pub async fn accept(&self) -> anyhow::Result<(ServerStream, String)> { match self { - Self::Tcp(x) => x - .accept() - .await - .map(|(x, y)| (Either::Left(x), y.to_string())) - .context("failed to accept tcp connection"), + Self::Tcp(x) => { + let (stream, addr) = x + .accept() + .await + .context("failed to accept tcp connection")?; + if CONFIG.server.tcp_nodelay { + stream + .set_nodelay(true) + .context("failed to set tcp nodelay")?; + } + Ok((Either::Left(stream), addr.to_string())) + } Self::Unix(x) => x .accept() .await @@ -184,6 +191,12 @@ impl ClientStream { format!("failed to connect to host {}", packet.destination_hostname) })?; + if CONFIG.stream.tcp_nodelay { + stream + .set_nodelay(true) + .context("failed to set tcp nodelay")?; + } + Ok(ClientStream::Tcp(stream)) } StreamType::Udp => {