add cli opts, add default_config, add json config support

This commit is contained in:
Toshit Chawda 2024-07-22 13:46:22 -07:00
parent d78e6cef0c
commit 29f05a2ddd
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 82 additions and 11 deletions

2
Cargo.lock generated
View file

@ -554,6 +554,7 @@ version = "2.0.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bytes", "bytes",
"clap",
"dashmap 6.0.1", "dashmap 6.0.1",
"env_logger", "env_logger",
"fastwebsockets", "fastwebsockets",
@ -565,6 +566,7 @@ dependencies = [
"log", "log",
"regex", "regex",
"serde", "serde",
"serde_json",
"tokio", "tokio",
"tokio-util", "tokio-util",
"toml", "toml",

View file

@ -6,6 +6,7 @@ edition = "2021"
[dependencies] [dependencies]
anyhow = "1.0.86" anyhow = "1.0.86"
bytes = "1.6.1" bytes = "1.6.1"
clap = { version = "4.5.9", features = ["cargo", "derive"] }
dashmap = "6.0.1" dashmap = "6.0.1"
env_logger = "0.11.3" env_logger = "0.11.3"
fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] } fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] }
@ -17,6 +18,7 @@ lazy_static = "1.5.0"
log = { version = "0.4.22", features = ["serde", "std"] } log = { version = "0.4.22", features = ["serde", "std"] }
regex = "1.10.5" regex = "1.10.5"
serde = { version = "1.0.204", features = ["derive"] } serde = { version = "1.0.204", features = ["derive"] }
serde_json = "1.0.120"
tokio = { version = "1.38.1", features = ["full"] } tokio = { version = "1.38.1", features = ["full"] }
tokio-util = { version = "0.7.11", features = ["compat", "io-util", "net"] } tokio-util = { version = "0.7.11", features = ["compat", "io-util", "net"] }
toml = "0.8.15" toml = "0.8.15"

View file

@ -1,5 +1,6 @@
use std::{collections::HashMap, ops::RangeInclusive}; use std::{collections::HashMap, ops::RangeInclusive, path::PathBuf};
use clap::{Parser, ValueEnum};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use log::LevelFilter; use log::LevelFilter;
use regex::RegexSet; use regex::RegexSet;
@ -9,7 +10,7 @@ use wisp_mux::extensions::{
ProtocolExtensionBuilder, ProtocolExtensionBuilder,
}; };
use crate::CONFIG; use crate::{CLI, CONFIG};
type AnyProtocolExtensionBuilder = Box<dyn ProtocolExtensionBuilder + Sync + Send>; type AnyProtocolExtensionBuilder = Box<dyn ProtocolExtensionBuilder + Sync + Send>;
@ -63,6 +64,7 @@ pub struct ServerConfig {
pub bind: String, pub bind: String,
pub socket: SocketType, pub socket: SocketType,
pub resolve_ipv6: bool, pub resolve_ipv6: bool,
pub tcp_nodelay: bool,
pub verbose_stats: bool, pub verbose_stats: bool,
pub enable_stats_endpoint: bool, pub enable_stats_endpoint: bool,
@ -84,6 +86,7 @@ impl Default for ServerConfig {
bind: "127.0.0.1:4000".to_string(), bind: "127.0.0.1:4000".to_string(),
socket: SocketType::default(), socket: SocketType::default(),
resolve_ipv6: false, resolve_ipv6: false,
tcp_nodelay: false,
verbose_stats: true, verbose_stats: true,
stats_endpoint: "/stats".to_string(), stats_endpoint: "/stats".to_string(),
@ -165,6 +168,8 @@ impl WispConfig {
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
#[serde(default)] #[serde(default)]
pub struct StreamConfig { pub struct StreamConfig {
pub tcp_nodelay: bool,
pub allow_udp: bool, pub allow_udp: bool,
pub allow_wsproxy_udp: bool, pub allow_wsproxy_udp: bool,
@ -185,6 +190,8 @@ pub struct StreamConfig {
impl Default for StreamConfig { impl Default for StreamConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
tcp_nodelay: false,
allow_udp: true, allow_udp: true,
allow_wsproxy_udp: false, allow_wsproxy_udp: false,
@ -229,3 +236,42 @@ pub struct Config {
pub wisp: WispConfig, pub wisp: WispConfig,
pub stream: StreamConfig, pub stream: StreamConfig,
} }
impl Config {
pub fn ser(&self) -> anyhow::Result<String> {
Ok(match CLI.format {
ConfigFormat::Toml => toml::to_string_pretty(self)?,
ConfigFormat::Json => serde_json::to_string_pretty(self)?,
})
}
pub fn de(string: String) -> anyhow::Result<Self> {
Ok(match CLI.format {
ConfigFormat::Toml => toml::from_str(&string)?,
ConfigFormat::Json => serde_json::from_str(&string)?,
})
}
}
#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Default, ValueEnum)]
pub enum ConfigFormat {
#[default]
Toml,
Json,
}
/// Server implementation of the Wisp protocol in Rust, made for epoxy.
#[derive(Parser)]
#[command(version = clap::crate_version!())]
pub struct Cli {
/// Config file to use.
pub config: Option<PathBuf>,
/// Config file format to use.
#[arg(short, long, value_enum, default_value_t = ConfigFormat::default())]
pub format: ConfigFormat,
/// Show default config and exit.
#[arg(long)]
pub default_config: bool,
}

View file

@ -1,9 +1,10 @@
#![feature(ip)] #![feature(ip)]
use std::{env::args, fmt::Write, fs::read_to_string}; use std::{fmt::Write, fs::read_to_string};
use bytes::Bytes; use bytes::Bytes;
use config::{validate_config_cache, Config}; use clap::Parser;
use config::{validate_config_cache, Cli, Config};
use dashmap::DashMap; use dashmap::DashMap;
use handle::{handle_wisp, handle_wsproxy}; use handle::{handle_wisp, handle_wsproxy};
use http_body_util::Full; use http_body_util::Full;
@ -26,9 +27,10 @@ mod stream;
type Client = (DashMap<Uuid, (ConnectPacket, ConnectPacket)>, bool); type Client = (DashMap<Uuid, (ConnectPacket, ConnectPacket)>, bool);
lazy_static! { lazy_static! {
pub static ref CLI: Cli = Cli::parse();
pub static ref CONFIG: Config = { pub static ref CONFIG: Config = {
if let Some(path) = args().nth(1) { if let Some(path) = &CLI.config {
toml::from_str(&read_to_string(path).unwrap()).unwrap() Config::de(read_to_string(path).unwrap()).unwrap()
} else { } else {
Config::default() Config::default()
} }
@ -159,10 +161,16 @@ fn generate_stats() -> Result<String, std::fmt::Error> {
#[tokio::main(flavor = "multi_thread")] #[tokio::main(flavor = "multi_thread")]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
if CLI.default_config {
println!("{}", Config::default().ser()?);
return Ok(());
}
env_logger::builder() env_logger::builder()
.filter_level(CONFIG.server.log_level) .filter_level(CONFIG.server.log_level)
.parse_default_env() .parse_default_env()
.init(); .init();
validate_config_cache(); validate_config_cache();
info!( info!(

View file

@ -69,11 +69,18 @@ impl ServerListener {
pub async fn accept(&self) -> anyhow::Result<(ServerStream, String)> { pub async fn accept(&self) -> anyhow::Result<(ServerStream, String)> {
match self { match self {
Self::Tcp(x) => x Self::Tcp(x) => {
.accept() let (stream, addr) = x
.await .accept()
.map(|(x, y)| (Either::Left(x), y.to_string())) .await
.context("failed to accept tcp connection"), .context("failed to accept tcp connection")?;
if CONFIG.server.tcp_nodelay {
stream
.set_nodelay(true)
.context("failed to set tcp nodelay")?;
}
Ok((Either::Left(stream), addr.to_string()))
}
Self::Unix(x) => x Self::Unix(x) => x
.accept() .accept()
.await .await
@ -184,6 +191,12 @@ impl ClientStream {
format!("failed to connect to host {}", packet.destination_hostname) format!("failed to connect to host {}", packet.destination_hostname)
})?; })?;
if CONFIG.stream.tcp_nodelay {
stream
.set_nodelay(true)
.context("failed to set tcp nodelay")?;
}
Ok(ClientStream::Tcp(stream)) Ok(ClientStream::Tcp(stream))
} }
StreamType::Udp => { StreamType::Udp => {