From 06cc16c692731a9833c00e32f9fa429a1ac21ce8 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 14 Sep 2024 22:29:30 -0700 Subject: [PATCH] add tlstcp and tlsunix --- Cargo.lock | 175 ++++++++++++++++++++++++++++++++++ server/Cargo.toml | 1 + server/src/config.rs | 11 ++- server/src/listener.rs | 212 +++++++++++++++++++++++++++++++---------- 4 files changed, 345 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 424984c..edfda28 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -433,6 +433,22 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + [[package]] name = "cpufeatures" version = "0.2.13" @@ -679,6 +695,7 @@ dependencies = [ "tikv-jemalloc-ctl", "tikv-jemallocator", "tokio", + "tokio-native-tls", "tokio-util", "toml", "uuid", @@ -712,6 +729,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "fastrand" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" + [[package]] name = "fastwebsockets" version = "0.8.0" @@ -765,6 +788,21 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -1383,6 +1421,23 @@ dependencies = [ "getrandom", ] +[[package]] +name = "native-tls" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nix" version = "0.29.0" @@ -1435,6 +1490,50 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "openssl" +version = "0.10.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "openssl-sys" +version = "0.9.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "parking" version = "2.2.0" @@ -1527,6 +1626,12 @@ dependencies = [ "spki", ] +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + [[package]] name = "ppv-lite86" version = "0.2.20" @@ -1796,12 +1901,44 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = "schannel" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.23" @@ -2053,6 +2190,19 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +[[package]] +name = "tempfile" +version = "3.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" +dependencies = [ + "cfg-if", + "fastrand", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "thiserror" version = "1.0.63" @@ -2159,6 +2309,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.15" @@ -2410,6 +2570,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" @@ -2600,6 +2766,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-targets" version = "0.48.5" diff --git a/server/Cargo.toml b/server/Cargo.toml index 3c34fde..5f6651d 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -33,6 +33,7 @@ shell-words = { version = "1.1.0", optional = true } tikv-jemalloc-ctl = { version = "0.6.0", features = ["stats", "use_std"] } tikv-jemallocator = "0.6.0" tokio = { version = "1.39.3", features = ["full"] } +tokio-native-tls = "0.3.1" tokio-util = { version = "0.7.11", features = ["codec", "compat", "io-util", "net"] } toml = { version = "0.8.19", optional = true } uuid = { version = "1.10.0", features = ["v4"] } diff --git a/server/src/config.rs b/server/src/config.rs index c8ff5e1..8d341ca 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -22,8 +22,12 @@ pub enum SocketType { /// TCP socket listener. #[default] Tcp, + /// TCP socket listener with TLS. + TlsTcp, /// Unix socket listener. Unix, + /// Unix socket listener with TLS. + TlsUnix, /// File "socket" "listener". /// "Accepts" a "connection" immediately. File, @@ -56,6 +60,8 @@ pub struct ServerConfig { pub tcp_nodelay: bool, /// Whether or not to set "raw mode" for the file. pub file_raw_mode: bool, + /// Keypair (public, private) in PEM format for TLS. + pub tls_keypair: Option<[PathBuf; 2]>, /// Whether or not to show what upstreams each client is connected to in stats. This can /// heavily increase the size of the stats. @@ -113,7 +119,7 @@ pub struct WispConfig { /// Wisp draft version 2 password authentication extension username/passwords. pub password_extension_users: HashMap, - /// Wisp draft version 2 certificate authentication extension public ed25519 keys. + /// Wisp draft version 2 certificate authentication extension public ed25519 pem keys. pub certificate_extension_keys: Vec, /// Wisp draft version 2 MOTD extension message. @@ -123,7 +129,7 @@ pub struct WispConfig { #[derive(Serialize, Deserialize)] #[serde(default)] pub struct StreamConfig { - /// Whether or not to enable TCP nodelay on proxied streams. + /// Whether or not to enable TCP nodelay. pub tcp_nodelay: bool, /// Whether or not to allow Wisp clients to create UDP streams. @@ -240,6 +246,7 @@ impl Default for ServerConfig { resolve_ipv6: false, tcp_nodelay: false, file_raw_mode: false, + tls_keypair: None, verbose_stats: true, stats_endpoint: "/stats".to_string(), diff --git a/server/src/listener.rs b/server/src/listener.rs index 6529455..fe7ecd3 100644 --- a/server/src/listener.rs +++ b/server/src/listener.rs @@ -3,20 +3,33 @@ use std::{os::fd::AsFd, path::PathBuf, pin::Pin}; use anyhow::Context; use tokio::{ fs::{remove_file, try_exists, File}, - io::{AsyncBufRead, AsyncRead, AsyncWrite}, + io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadHalf, WriteHalf}, net::{tcp, unix, TcpListener, TcpStream, UnixListener, UnixStream}, }; +use tokio_native_tls::{ + native_tls::{self, Identity}, + TlsAcceptor, TlsStream, +}; use uuid::Uuid; use crate::{config::SocketType, CONFIG}; -pub enum Trio { +pub enum Quintet { One(A), Two(B), Three(C), + Four(D), + Five(E), } -impl AsyncRead for Trio { +impl< + A: AsyncRead + Unpin, + B: AsyncRead + Unpin, + C: AsyncRead + Unpin, + D: AsyncRead + Unpin, + E: AsyncRead + Unpin, + > AsyncRead for Quintet +{ fn poll_read( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -26,12 +39,19 @@ impl AsyncRead Self::One(x) => Pin::new(x).poll_read(cx, buf), Self::Two(x) => Pin::new(x).poll_read(cx, buf), Self::Three(x) => Pin::new(x).poll_read(cx, buf), + Self::Four(x) => Pin::new(x).poll_read(cx, buf), + Self::Five(x) => Pin::new(x).poll_read(cx, buf), } } } -impl AsyncBufRead - for Trio +impl< + A: AsyncBufRead + Unpin, + B: AsyncBufRead + Unpin, + C: AsyncBufRead + Unpin, + D: AsyncBufRead + Unpin, + E: AsyncBufRead + Unpin, + > AsyncBufRead for Quintet { fn poll_fill_buf( self: Pin<&mut Self>, @@ -41,6 +61,8 @@ impl Self::One(x) => Pin::new(x).poll_fill_buf(cx), Self::Two(x) => Pin::new(x).poll_fill_buf(cx), Self::Three(x) => Pin::new(x).poll_fill_buf(cx), + Self::Four(x) => Pin::new(x).poll_fill_buf(cx), + Self::Five(x) => Pin::new(x).poll_fill_buf(cx), } } @@ -49,12 +71,19 @@ impl Self::One(x) => Pin::new(x).consume(amt), Self::Two(x) => Pin::new(x).consume(amt), Self::Three(x) => Pin::new(x).consume(amt), + Self::Four(x) => Pin::new(x).consume(amt), + Self::Five(x) => Pin::new(x).consume(amt), } } } -impl AsyncWrite - for Trio +impl< + A: AsyncWrite + Unpin, + B: AsyncWrite + Unpin, + C: AsyncWrite + Unpin, + D: AsyncWrite + Unpin, + E: AsyncWrite + Unpin, + > AsyncWrite for Quintet { fn poll_write( self: Pin<&mut Self>, @@ -65,6 +94,8 @@ impl AsyncW Self::One(x) => Pin::new(x).poll_write(cx, buf), Self::Two(x) => Pin::new(x).poll_write(cx, buf), Self::Three(x) => Pin::new(x).poll_write(cx, buf), + Self::Four(x) => Pin::new(x).poll_write(cx, buf), + Self::Five(x) => Pin::new(x).poll_write(cx, buf), } } @@ -73,6 +104,8 @@ impl AsyncW Self::One(x) => x.is_write_vectored(), Self::Two(x) => x.is_write_vectored(), Self::Three(x) => x.is_write_vectored(), + Self::Four(x) => x.is_write_vectored(), + Self::Five(x) => x.is_write_vectored(), } } @@ -85,6 +118,8 @@ impl AsyncW Self::One(x) => Pin::new(x).poll_write_vectored(cx, bufs), Self::Two(x) => Pin::new(x).poll_write_vectored(cx, bufs), Self::Three(x) => Pin::new(x).poll_write_vectored(cx, bufs), + Self::Four(x) => Pin::new(x).poll_write_vectored(cx, bufs), + Self::Five(x) => Pin::new(x).poll_write_vectored(cx, bufs), } } @@ -96,6 +131,8 @@ impl AsyncW Self::One(x) => Pin::new(x).poll_flush(cx), Self::Two(x) => Pin::new(x).poll_flush(cx), Self::Three(x) => Pin::new(x).poll_flush(cx), + Self::Four(x) => Pin::new(x).poll_flush(cx), + Self::Five(x) => Pin::new(x).poll_flush(cx), } } @@ -107,6 +144,8 @@ impl AsyncW Self::One(x) => Pin::new(x).poll_shutdown(cx), Self::Two(x) => Pin::new(x).poll_shutdown(cx), Self::Three(x) => Pin::new(x).poll_shutdown(cx), + Self::Four(x) => Pin::new(x).poll_shutdown(cx), + Self::Five(x) => Pin::new(x).poll_shutdown(cx), } } } @@ -182,9 +221,22 @@ impl AsyncWrite for Duplex { } } -pub type ServerStream = Trio>; -pub type ServerStreamRead = Trio; -pub type ServerStreamWrite = Trio; +pub type ServerStream = + Quintet, UnixStream, TlsStream, Duplex>; +pub type ServerStreamRead = Quintet< + tcp::OwnedReadHalf, + ReadHalf>, + unix::OwnedReadHalf, + ReadHalf>, + File, +>; +pub type ServerStreamWrite = Quintet< + tcp::OwnedWriteHalf, + WriteHalf>, + unix::OwnedWriteHalf, + WriteHalf>, + File, +>; pub trait ServerStreamExt { fn split(self) -> (ServerStreamRead, ServerStreamWrite); @@ -195,15 +247,23 @@ impl ServerStreamExt for ServerStream { match self { Self::One(x) => { let (r, w) = x.into_split(); - (Trio::One(r), Trio::One(w)) + (Quintet::One(r), Quintet::One(w)) } Self::Two(x) => { - let (r, w) = x.into_split(); - (Trio::Two(r), Trio::Two(w)) + let (r, w) = tokio::io::split(x); + (Quintet::Two(r), Quintet::Two(w)) } Self::Three(x) => { let (r, w) = x.into_split(); - (Trio::Three(r), Trio::Three(w)) + (Quintet::Three(r), Quintet::Three(w)) + } + Self::Four(x) => { + let (r, w) = tokio::io::split(x); + (Quintet::Four(r), Quintet::Four(w)) + } + Self::Five(x) => { + let (r, w) = x.into_split(); + (Quintet::Five(r), Quintet::Five(w)) } } } @@ -211,27 +271,54 @@ impl ServerStreamExt for ServerStream { pub enum ServerListener { Tcp(TcpListener), + TlsTcp(TcpListener, TlsAcceptor), Unix(UnixListener), + TlsUnix(UnixListener, TlsAcceptor), File(Option), } impl ServerListener { + async fn bind_tcp() -> anyhow::Result { + TcpListener::bind(&CONFIG.server.bind) + .await + .with_context(|| format!("failed to bind to tcp address `{}`", CONFIG.server.bind)) + } + + async fn bind_unix() -> anyhow::Result { + if try_exists(&CONFIG.server.bind).await? { + remove_file(&CONFIG.server.bind).await?; + } + UnixListener::bind(&CONFIG.server.bind) + .with_context(|| format!("failed to bind to unix socket at `{}`", CONFIG.server.bind)) + } + + async fn create_tls() -> anyhow::Result { + let tls_keypair = CONFIG + .server + .tls_keypair + .as_ref() + .context("no tls keypair provided")?; + + let public = tokio::fs::read(&tls_keypair[0]) + .await + .context("failed to read public key")?; + let private = tokio::fs::read(&tls_keypair[1]) + .await + .context("failed to read private key")?; + + let identity = + Identity::from_pkcs8(&public, &private).context("failed to create tls identity")?; + + Ok(TlsAcceptor::from(native_tls::TlsAcceptor::new(identity)?)) + } + pub async fn new() -> anyhow::Result { Ok(match CONFIG.server.socket { - SocketType::Tcp => Self::Tcp( - TcpListener::bind(&CONFIG.server.bind) - .await - .with_context(|| { - format!("failed to bind to tcp address `{}`", CONFIG.server.bind) - })?, - ), - SocketType::Unix => { - if try_exists(&CONFIG.server.bind).await? { - remove_file(&CONFIG.server.bind).await?; - } - Self::Unix(UnixListener::bind(&CONFIG.server.bind).with_context(|| { - format!("failed to bind to unix socket at `{}`", CONFIG.server.bind) - })?) + SocketType::Tcp => Self::Tcp(Self::bind_tcp().await?), + SocketType::TlsTcp => Self::TlsTcp(Self::bind_tcp().await?, Self::create_tls().await?), + SocketType::Unix => Self::Unix(Self::bind_unix().await?), + SocketType::TlsUnix => { + Self::TlsUnix(Self::bind_unix().await?, Self::create_tls().await?) } SocketType::File => { Self::File(Some(PathBuf::try_from(&CONFIG.server.bind).with_context( @@ -241,33 +328,54 @@ impl ServerListener { }) } + async fn accept_tcp(listener: &mut TcpListener) -> anyhow::Result<(TcpStream, String)> { + let (stream, addr) = listener + .accept() + .await + .context("failed to accept tcp connection")?; + if CONFIG.server.tcp_nodelay { + stream + .set_nodelay(true) + .context("failed to set tcp nodelay")?; + } + Ok((stream, addr.to_string())) + } + + async fn accept_unix(listener: &mut UnixListener) -> anyhow::Result<(UnixStream, String)> { + let (stream, addr) = listener + .accept() + .await + .context("failed to accept unix socket connection")?; + + Ok(( + stream, + addr.as_pathname() + .and_then(|x| x.to_str()) + .map(ToString::to_string) + .unwrap_or_else(|| Uuid::new_v4().to_string() + "-unix_socket"), + )) + } + pub async fn accept(&mut self) -> anyhow::Result<(ServerStream, String)> { match self { Self::Tcp(x) => { - let (stream, addr) = x - .accept() - .await - .context("failed to accept tcp connection")?; - if CONFIG.server.tcp_nodelay { - stream - .set_nodelay(true) - .context("failed to set tcp nodelay")?; - } - Ok((Trio::One(stream), addr.to_string())) + let (x, y) = Self::accept_tcp(x).await?; + Ok((Quintet::One(x), y)) + } + Self::TlsTcp(tcp, tls) => { + let (x, y) = Self::accept_tcp(tcp).await?; + let x = tls.accept(x).await?; + Ok((Quintet::Two(x), y)) + } + Self::Unix(x) => { + let (x, y) = Self::accept_unix(x).await?; + Ok((Quintet::Three(x), y)) + } + Self::TlsUnix(unix, tls) => { + let (x, y) = Self::accept_unix(unix).await?; + let x = tls.accept(x).await?; + Ok((Quintet::Four(x), y)) } - Self::Unix(x) => x - .accept() - .await - .map(|(x, y)| { - ( - Trio::Two(x), - y.as_pathname() - .and_then(|x| x.to_str()) - .map(ToString::to_string) - .unwrap_or_else(|| Uuid::new_v4().to_string() + "-unix_socket"), - ) - }) - .context("failed to accept unix socket connection"), Self::File(path) => { if let Some(path) = path.take() { let rx = File::options() @@ -311,7 +419,7 @@ impl ServerListener { } Ok(( - Trio::Three(Duplex::new(rx, tx)), + Quintet::Five(Duplex::new(rx, tx)), path.to_string_lossy().to_string(), )) } else {