add cli opts, add default_config, add json config support

This commit is contained in:
Toshit Chawda 2024-07-22 13:46:22 -07:00
parent d78e6cef0c
commit 29f05a2ddd
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 82 additions and 11 deletions

2
Cargo.lock generated
View file

@ -554,6 +554,7 @@ version = "2.0.0"
dependencies = [
"anyhow",
"bytes",
"clap",
"dashmap 6.0.1",
"env_logger",
"fastwebsockets",
@ -565,6 +566,7 @@ dependencies = [
"log",
"regex",
"serde",
"serde_json",
"tokio",
"tokio-util",
"toml",

View file

@ -6,6 +6,7 @@ edition = "2021"
[dependencies]
anyhow = "1.0.86"
bytes = "1.6.1"
clap = { version = "4.5.9", features = ["cargo", "derive"] }
dashmap = "6.0.1"
env_logger = "0.11.3"
fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] }
@ -17,6 +18,7 @@ lazy_static = "1.5.0"
log = { version = "0.4.22", features = ["serde", "std"] }
regex = "1.10.5"
serde = { version = "1.0.204", features = ["derive"] }
serde_json = "1.0.120"
tokio = { version = "1.38.1", features = ["full"] }
tokio-util = { version = "0.7.11", features = ["compat", "io-util", "net"] }
toml = "0.8.15"

View file

@ -1,5 +1,6 @@
use std::{collections::HashMap, ops::RangeInclusive};
use std::{collections::HashMap, ops::RangeInclusive, path::PathBuf};
use clap::{Parser, ValueEnum};
use lazy_static::lazy_static;
use log::LevelFilter;
use regex::RegexSet;
@ -9,7 +10,7 @@ use wisp_mux::extensions::{
ProtocolExtensionBuilder,
};
use crate::CONFIG;
use crate::{CLI, CONFIG};
type AnyProtocolExtensionBuilder = Box<dyn ProtocolExtensionBuilder + Sync + Send>;
@ -63,6 +64,7 @@ pub struct ServerConfig {
pub bind: String,
pub socket: SocketType,
pub resolve_ipv6: bool,
pub tcp_nodelay: bool,
pub verbose_stats: bool,
pub enable_stats_endpoint: bool,
@ -84,6 +86,7 @@ impl Default for ServerConfig {
bind: "127.0.0.1:4000".to_string(),
socket: SocketType::default(),
resolve_ipv6: false,
tcp_nodelay: false,
verbose_stats: true,
stats_endpoint: "/stats".to_string(),
@ -165,6 +168,8 @@ impl WispConfig {
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct StreamConfig {
pub tcp_nodelay: bool,
pub allow_udp: bool,
pub allow_wsproxy_udp: bool,
@ -185,6 +190,8 @@ pub struct StreamConfig {
impl Default for StreamConfig {
fn default() -> Self {
Self {
tcp_nodelay: false,
allow_udp: true,
allow_wsproxy_udp: false,
@ -229,3 +236,42 @@ pub struct Config {
pub wisp: WispConfig,
pub stream: StreamConfig,
}
impl Config {
pub fn ser(&self) -> anyhow::Result<String> {
Ok(match CLI.format {
ConfigFormat::Toml => toml::to_string_pretty(self)?,
ConfigFormat::Json => serde_json::to_string_pretty(self)?,
})
}
pub fn de(string: String) -> anyhow::Result<Self> {
Ok(match CLI.format {
ConfigFormat::Toml => toml::from_str(&string)?,
ConfigFormat::Json => serde_json::from_str(&string)?,
})
}
}
#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Default, ValueEnum)]
pub enum ConfigFormat {
#[default]
Toml,
Json,
}
/// Server implementation of the Wisp protocol in Rust, made for epoxy.
#[derive(Parser)]
#[command(version = clap::crate_version!())]
pub struct Cli {
/// Config file to use.
pub config: Option<PathBuf>,
/// Config file format to use.
#[arg(short, long, value_enum, default_value_t = ConfigFormat::default())]
pub format: ConfigFormat,
/// Show default config and exit.
#[arg(long)]
pub default_config: bool,
}

View file

@ -1,9 +1,10 @@
#![feature(ip)]
use std::{env::args, fmt::Write, fs::read_to_string};
use std::{fmt::Write, fs::read_to_string};
use bytes::Bytes;
use config::{validate_config_cache, Config};
use clap::Parser;
use config::{validate_config_cache, Cli, Config};
use dashmap::DashMap;
use handle::{handle_wisp, handle_wsproxy};
use http_body_util::Full;
@ -26,9 +27,10 @@ mod stream;
type Client = (DashMap<Uuid, (ConnectPacket, ConnectPacket)>, bool);
lazy_static! {
pub static ref CLI: Cli = Cli::parse();
pub static ref CONFIG: Config = {
if let Some(path) = args().nth(1) {
toml::from_str(&read_to_string(path).unwrap()).unwrap()
if let Some(path) = &CLI.config {
Config::de(read_to_string(path).unwrap()).unwrap()
} else {
Config::default()
}
@ -159,10 +161,16 @@ fn generate_stats() -> Result<String, std::fmt::Error> {
#[tokio::main(flavor = "multi_thread")]
async fn main() -> anyhow::Result<()> {
if CLI.default_config {
println!("{}", Config::default().ser()?);
return Ok(());
}
env_logger::builder()
.filter_level(CONFIG.server.log_level)
.parse_default_env()
.init();
validate_config_cache();
info!(

View file

@ -69,11 +69,18 @@ impl ServerListener {
pub async fn accept(&self) -> anyhow::Result<(ServerStream, String)> {
match self {
Self::Tcp(x) => x
.accept()
.await
.map(|(x, y)| (Either::Left(x), y.to_string()))
.context("failed to accept tcp connection"),
Self::Tcp(x) => {
let (stream, addr) = x
.accept()
.await
.context("failed to accept tcp connection")?;
if CONFIG.server.tcp_nodelay {
stream
.set_nodelay(true)
.context("failed to set tcp nodelay")?;
}
Ok((Either::Left(stream), addr.to_string()))
}
Self::Unix(x) => x
.accept()
.await
@ -184,6 +191,12 @@ impl ClientStream {
format!("failed to connect to host {}", packet.destination_hostname)
})?;
if CONFIG.stream.tcp_nodelay {
stream
.set_nodelay(true)
.context("failed to set tcp nodelay")?;
}
Ok(ClientStream::Tcp(stream))
}
StreamType::Udp => {