diff --git a/server/src/config.rs b/server/src/config.rs index ae4b00b..d9b4d3e 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -59,6 +59,18 @@ pub enum SocketTransport { LengthDelimitedLe, } +#[derive(Serialize, Deserialize, Default, Debug)] +#[serde(rename_all = "lowercase")] +pub enum RuntimeFlavor { + /// Single-threaded tokio runtime. + SingleThread, + /// Multi-threaded tokio runtime. + #[default] + MultiThread, + /// Multi-threaded tokio runtime with an alternate work in progress scheduler. + MultiThreadAlt, +} + pub type BindAddr = (SocketType, String); #[derive(Serialize, Deserialize, Debug)] @@ -106,6 +118,8 @@ pub struct ServerConfig { /// Server log level. pub log_level: LevelFilter, + /// Runtime type. + pub runtime: RuntimeFlavor, } #[derive(Serialize, Deserialize, PartialEq, Eq)] @@ -324,6 +338,7 @@ impl Default for ServerConfig { max_message_size: 64 * 1024, log_level: LevelFilter::Info, + runtime: RuntimeFlavor::default(), } } } diff --git a/server/src/main.rs b/server/src/main.rs index 6e89762..5fee908 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -5,7 +5,7 @@ use std::{collections::HashMap, fs::read_to_string, net::IpAddr}; use anyhow::Context; use clap::Parser; -use config::{validate_config_cache, Cli, Config}; +use config::{validate_config_cache, Cli, Config, RuntimeFlavor}; use dashmap::DashMap; use handle::{handle_wisp, handle_wsproxy}; use hickory_resolver::{ @@ -17,7 +17,10 @@ use listener::ServerListener; use log::{error, info}; use route::{route_stats, ServerRouteResult}; use serde::Serialize; -use tokio::signal::unix::{signal, SignalKind}; +use tokio::{ + runtime, + signal::unix::{signal, SignalKind}, +}; use uuid::Uuid; use wisp_mux::{ConnectPacket, StreamType}; @@ -199,8 +202,7 @@ fn handle_stream(stream: ServerRouteResult, id: String) { #[global_allocator] static JEMALLOCATOR: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; -#[tokio::main(flavor = "multi_thread")] -async fn main() -> anyhow::Result<()> { +fn main() -> anyhow::Result<()> { if CLI.default_config { println!("{}", Config::default().ser()?); return Ok(()); @@ -211,63 +213,75 @@ async fn main() -> anyhow::Result<()> { .parse_default_env() .init(); - validate_config_cache().await; + let mut builder: runtime::Builder = match CONFIG.server.runtime { + RuntimeFlavor::SingleThread => runtime::Builder::new_current_thread(), + RuntimeFlavor::MultiThread => runtime::Builder::new_multi_thread(), + RuntimeFlavor::MultiThreadAlt => runtime::Builder::new_multi_thread_alt(), + }; - info!( - "listening on {:?} with socket transport {:?}", - CONFIG.server.bind, CONFIG.server.transport - ); + builder.enable_all(); + let rt = builder.build()?; - tokio::spawn(async { - let mut sig = signal(SignalKind::user_defined1()).unwrap(); - while sig.recv().await.is_some() { - info!("Stats:\n{}", generate_stats().unwrap()); - } - }); + rt.block_on(async { + validate_config_cache().await; - let mut listener = ServerListener::new(&CONFIG.server.bind) - .await - .with_context(|| format!("failed to bind to address {}", CONFIG.server.bind.1))?; + info!( + "listening on {:?} with runtime flavor {:?} and socket transport {:?}", + CONFIG.server.bind, CONFIG.server.runtime, CONFIG.server.transport + ); - if CONFIG.server.enable_stats_endpoint { - if let Some(bind_addr) = CONFIG.server.stats_endpoint.get_bindaddr() { - info!("stats server listening on {:?}", bind_addr); - let mut stats_listener = ServerListener::new(&bind_addr).await.with_context(|| { - format!("failed to bind to address {} for stats server", bind_addr.1) - })?; + tokio::spawn(async { + let mut sig = signal(SignalKind::user_defined1()).unwrap(); + while sig.recv().await.is_some() { + info!("Stats:\n{}", generate_stats().unwrap()); + } + }); - tokio::spawn(async move { - loop { - match stats_listener.accept().await { - Ok((stream, _)) => { - if let Err(e) = route_stats(stream).await { - error!("error while routing stats client: {:?}", e); - } - } - Err(e) => error!("error while accepting stats client: {:?}", e), - } - } - }); - } - } + let mut listener = ServerListener::new(&CONFIG.server.bind) + .await + .with_context(|| format!("failed to bind to address {}", CONFIG.server.bind.1))?; + + if CONFIG.server.enable_stats_endpoint { + if let Some(bind_addr) = CONFIG.server.stats_endpoint.get_bindaddr() { + info!("stats server listening on {:?}", bind_addr); + let mut stats_listener = + ServerListener::new(&bind_addr).await.with_context(|| { + format!("failed to bind to address {} for stats server", bind_addr.1) + })?; - let stats_endpoint = CONFIG.server.stats_endpoint.get_endpoint(); - loop { - let stats_endpoint = stats_endpoint.clone(); - match listener.accept().await { - Ok((stream, id)) => { tokio::spawn(async move { - let res = route::route(stream, stats_endpoint, move |stream| { - handle_stream(stream, id) - }) - .await; - - if let Err(e) = res { - error!("error while routing client: {:?}", e); + loop { + match stats_listener.accept().await { + Ok((stream, _)) => { + if let Err(e) = route_stats(stream).await { + error!("error while routing stats client: {:?}", e); + } + } + Err(e) => error!("error while accepting stats client: {:?}", e), + } } }); } - Err(e) => error!("error while accepting client: {:?}", e), } - } + + let stats_endpoint = CONFIG.server.stats_endpoint.get_endpoint(); + loop { + let stats_endpoint = stats_endpoint.clone(); + match listener.accept().await { + Ok((stream, id)) => { + tokio::spawn(async move { + let res = route::route(stream, stats_endpoint, move |stream| { + handle_stream(stream, id) + }) + .await; + + if let Err(e) = res { + error!("error while routing client: {:?}", e); + } + }); + } + Err(e) => error!("error while accepting client: {:?}", e), + } + } + }) }