diff --git a/server/src/config.rs b/server/src/config.rs index 653cdad..95aae18 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -70,6 +70,17 @@ pub enum RuntimeFlavor { /// Alternate multi-threaded tokio runtime. #[cfg(tokio_unstable)] MultiThreadAlt, + /// Thread-per-core tokio runtimes. + ThreadPerCore, +} + +impl RuntimeFlavor { + pub fn is_thread_per_core(&self) -> bool { + match self { + Self::ThreadPerCore => true, + _ => false, + } + } } pub type BindAddr = (SocketType, String); diff --git a/server/src/listener.rs b/server/src/listener.rs index e7ae5af..2040baf 100644 --- a/server/src/listener.rs +++ b/server/src/listener.rs @@ -1,9 +1,5 @@ use std::{ - io::{BufReader, Cursor}, - os::fd::AsFd, - path::PathBuf, - pin::Pin, - sync::Arc, + io::{BufReader, Cursor}, net::SocketAddr, os::fd::AsFd, path::PathBuf, pin::Pin, str::FromStr, sync::Arc }; use anyhow::Context; @@ -11,7 +7,7 @@ use rustls_pemfile::{certs, private_key}; use tokio::{ fs::{remove_file, try_exists, File}, io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadHalf, WriteHalf}, - net::{tcp, unix, TcpListener, TcpStream, UnixListener, UnixStream}, + net::{tcp, unix, TcpListener, TcpSocket, TcpStream, UnixListener, UnixStream}, }; use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor}; use uuid::Uuid; @@ -286,9 +282,20 @@ pub enum ServerListener { impl ServerListener { async fn bind_tcp(bind: &BindAddr) -> anyhow::Result { - TcpListener::bind(&bind.1) - .await - .with_context(|| format!("failed to bind to tcp address `{}`", bind.1)) + if CONFIG.server.runtime.is_thread_per_core() { + let listener = TcpSocket::new_v4()?; + listener + .set_reuseport(true) + .context("failed to set SO_REUSEPORT")?; + listener + .bind(SocketAddr::from_str(&bind.1)?) + .with_context(|| format!("failed to bind to tcp address `{}`", bind.1))?; + Ok(listener.listen(64)?) + } else { + TcpListener::bind(&bind.1) + .await + .with_context(|| format!("failed to bind to tcp address `{}`", bind.1)) + } } async fn bind_unix(bind: &BindAddr) -> anyhow::Result { diff --git a/server/src/main.rs b/server/src/main.rs index 5f0a098..d70ce98 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -2,11 +2,12 @@ #![deny(clippy::todo)] #![allow(unexpected_cfgs)] -use std::{collections::HashMap, fs::read_to_string, net::IpAddr}; +use std::{collections::HashMap, fs::read_to_string, future::Future, net::IpAddr, pin::Pin}; -use anyhow::{Context, Result}; +use anyhow::{anyhow, Context, Result}; use clap::Parser; -use config::{validate_config_cache, Cli, Config, RuntimeFlavor, StatsEndpoint}; +use config::{validate_config_cache, BindAddr, Cli, Config, RuntimeFlavor, StatsEndpoint}; +use futures_util::{future::select_all, FutureExt, TryFutureExt}; use handle::{handle_wisp, handle_wsproxy, wisp::wispnet::handle_wispnet}; use hickory_resolver::{ config::{NameServerConfigGroup, ResolverConfig, ResolverOpts}, @@ -15,13 +16,13 @@ use hickory_resolver::{ }; use lazy_static::lazy_static; use listener::ServerListener; -use log::{error, info, trace, warn}; +use log::{debug, error, info, trace, warn}; use route::{route_stats, ServerRouteResult}; use stats::generate_stats; use tokio::{ runtime, signal::unix::{signal, SignalKind}, - sync::Mutex, + sync::{oneshot, Mutex}, }; use uuid::Uuid; use wisp_mux::packet::ConnectPacket; @@ -134,6 +135,9 @@ fn main() -> Result<()> { RuntimeFlavor::MultiThread => runtime::Builder::new_multi_thread(), #[cfg(tokio_unstable)] RuntimeFlavor::MultiThreadAlt => runtime::Builder::new_multi_thread_alt(), + + // threadpercore has completely different runtime setup + RuntimeFlavor::ThreadPerCore => return threadpercore_main(), }; builder.enable_all(); @@ -146,7 +150,7 @@ fn main() -> Result<()> { } #[doc(hidden)] -async fn async_main() -> Result<()> { +async fn async_init() { #[cfg(feature = "tokio-console")] console_subscriber::init(); @@ -160,20 +164,68 @@ async fn async_main() -> Result<()> { trace!("CLI: {:#?}", &*CLI); trace!("CONFIG: {:#?}", &*CONFIG); trace!("RESOLVER: {:?}", &*RESOLVER); +} - tokio::spawn(async { - let mut sig = signal(SignalKind::user_defined1()).unwrap(); - while sig.recv().await.is_some() { - match generate_stats().await { - Ok(stats) => info!("Stats:\n{}", stats), - Err(err) => error!("error while creating stats {:?}", err), - } +#[doc(hidden)] +fn threadpercore_main() -> Result<()> { + let rt = runtime::Builder::new_current_thread() + .enable_all() + .build()?; + + rt.block_on(async_init()); + + let cores = std::thread::available_parallelism()?.get(); + + let mut threads = Vec::with_capacity(cores); + + for _ in 1..cores { + threads.push(Box::pin(threadpercore_init_thread(listen_wisp()).map_err(|x| anyhow!(x)).map(|x| x?)) as Pin>>>); + } + + rt.block_on(async move { + tokio::spawn(listen_stats_cli()); + + if let Some(bind_addr) = CONFIG + .server + .stats_endpoint + .as_ref() + .and_then(StatsEndpoint::get_bindaddr) + { + tokio::spawn(listen_stats(bind_addr)); } - }); - let mut listener = ServerListener::new(&CONFIG.server.bind) - .await - .with_context(|| format!("failed to bind to address {}", CONFIG.server.bind.1))?; + let wisp = Box::pin(tokio::spawn(listen_wisp()).map_err(|x| anyhow!(x)).map(|x| x?)) as Pin>>>; + + select_all(threads.into_iter().chain(std::iter::once(wisp))).await.0 + }) +} + +#[doc(hidden)] +fn threadpercore_init_thread( + func: impl Future> + Sync + Send + 'static, +) -> oneshot::Receiver> { + let (tx, rx) = oneshot::channel(); + std::thread::spawn(move || { + let ret = (|| { + let rt = runtime::Builder::new_current_thread() + .enable_all() + .build()?; + + debug!("created threadpercore thread"); + + rt.block_on(func) + })(); + + let _ = tx.send(ret.context("thread per core thread failed")); + }); + rx +} + +#[doc(hidden)] +async fn async_main() -> Result<()> { + async_init().await; + + tokio::spawn(listen_stats_cli()); if let Some(bind_addr) = CONFIG .server @@ -181,27 +233,50 @@ async fn async_main() -> Result<()> { .as_ref() .and_then(StatsEndpoint::get_bindaddr) { - info!("stats server listening on {:?}", bind_addr); - let mut stats_listener = ServerListener::new(&bind_addr).await.with_context(|| { - format!("failed to bind to address {} for stats server", bind_addr.1) - })?; - - tokio::spawn(async move { - loop { - match stats_listener.accept().await { - Ok((stream, _)) => { - tokio::spawn(async move { - if let Err(e) = Box::pin(route_stats(stream)).await { - error!("error while routing stats client: {:?}", e); - } - }); - } - Err(e) => error!("error while accepting stats client: {:?}", e), - } - } - }); + tokio::spawn(listen_stats(bind_addr)); } + listen_wisp().await +} + +#[doc(hidden)] +async fn listen_stats_cli() { + let mut sig = signal(SignalKind::user_defined1()).unwrap(); + while sig.recv().await.is_some() { + match generate_stats().await { + Ok(stats) => info!("Stats:\n{}", stats), + Err(err) => error!("error while creating stats {:?}", err), + } + } +} + +#[doc(hidden)] +async fn listen_stats(bind_addr: BindAddr) -> Result<()> { + info!("stats server listening on {:?}", bind_addr); + let mut stats_listener = ServerListener::new(&bind_addr) + .await + .with_context(|| format!("failed to bind to address {} for stats server", bind_addr.1))?; + + loop { + match stats_listener.accept().await { + Ok((stream, _)) => { + tokio::spawn(async move { + if let Err(e) = Box::pin(route_stats(stream)).await { + error!("error while routing stats client: {:?}", e); + } + }); + } + Err(e) => error!("error while accepting stats client: {:?}", e), + } + } +} + +#[doc(hidden)] +async fn listen_wisp() -> Result<()> { + let mut listener = ServerListener::new(&CONFIG.server.bind) + .await + .with_context(|| format!("failed to bind to address {}", CONFIG.server.bind.1))?; + let stats_endpoint = CONFIG .server .stats_endpoint