add tlstcp and tlsunix

This commit is contained in:
Toshit Chawda 2024-09-14 22:29:30 -07:00
parent 24ccd8d393
commit 06cc16c692
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
4 changed files with 345 additions and 54 deletions

175
Cargo.lock generated
View file

@ -433,6 +433,22 @@ version = "0.9.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8"
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]]
name = "cpufeatures"
version = "0.2.13"
@ -679,6 +695,7 @@ dependencies = [
"tikv-jemalloc-ctl",
"tikv-jemallocator",
"tokio",
"tokio-native-tls",
"tokio-util",
"toml",
"uuid",
@ -712,6 +729,12 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "fastrand"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
[[package]]
name = "fastwebsockets"
version = "0.8.0"
@ -765,6 +788,21 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared",
]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "form_urlencoded"
version = "1.2.1"
@ -1383,6 +1421,23 @@ dependencies = [
"getrandom",
]
[[package]]
name = "native-tls"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "nix"
version = "0.29.0"
@ -1435,6 +1490,50 @@ version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "openssl"
version = "0.10.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1"
dependencies = [
"bitflags",
"cfg-if",
"foreign-types",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "openssl-probe"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]]
name = "openssl-sys"
version = "0.9.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "parking"
version = "2.2.0"
@ -1527,6 +1626,12 @@ dependencies = [
"spki",
]
[[package]]
name = "pkg-config"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec"
[[package]]
name = "ppv-lite86"
version = "0.2.20"
@ -1796,12 +1901,44 @@ version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "schannel"
version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "security-framework"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "semver"
version = "1.0.23"
@ -2053,6 +2190,19 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394"
[[package]]
name = "tempfile"
version = "3.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64"
dependencies = [
"cfg-if",
"fastrand",
"once_cell",
"rustix",
"windows-sys 0.59.0",
]
[[package]]
name = "thiserror"
version = "1.0.63"
@ -2159,6 +2309,16 @@ dependencies = [
"syn",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.15"
@ -2410,6 +2570,12 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "version_check"
version = "0.9.5"
@ -2600,6 +2766,15 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.59.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-targets"
version = "0.48.5"

View file

@ -33,6 +33,7 @@ shell-words = { version = "1.1.0", optional = true }
tikv-jemalloc-ctl = { version = "0.6.0", features = ["stats", "use_std"] }
tikv-jemallocator = "0.6.0"
tokio = { version = "1.39.3", features = ["full"] }
tokio-native-tls = "0.3.1"
tokio-util = { version = "0.7.11", features = ["codec", "compat", "io-util", "net"] }
toml = { version = "0.8.19", optional = true }
uuid = { version = "1.10.0", features = ["v4"] }

View file

@ -22,8 +22,12 @@ pub enum SocketType {
/// TCP socket listener.
#[default]
Tcp,
/// TCP socket listener with TLS.
TlsTcp,
/// Unix socket listener.
Unix,
/// Unix socket listener with TLS.
TlsUnix,
/// File "socket" "listener".
/// "Accepts" a "connection" immediately.
File,
@ -56,6 +60,8 @@ pub struct ServerConfig {
pub tcp_nodelay: bool,
/// Whether or not to set "raw mode" for the file.
pub file_raw_mode: bool,
/// Keypair (public, private) in PEM format for TLS.
pub tls_keypair: Option<[PathBuf; 2]>,
/// Whether or not to show what upstreams each client is connected to in stats. This can
/// heavily increase the size of the stats.
@ -113,7 +119,7 @@ pub struct WispConfig {
/// Wisp draft version 2 password authentication extension username/passwords.
pub password_extension_users: HashMap<String, String>,
/// Wisp draft version 2 certificate authentication extension public ed25519 keys.
/// Wisp draft version 2 certificate authentication extension public ed25519 pem keys.
pub certificate_extension_keys: Vec<PathBuf>,
/// Wisp draft version 2 MOTD extension message.
@ -123,7 +129,7 @@ pub struct WispConfig {
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct StreamConfig {
/// Whether or not to enable TCP nodelay on proxied streams.
/// Whether or not to enable TCP nodelay.
pub tcp_nodelay: bool,
/// Whether or not to allow Wisp clients to create UDP streams.
@ -240,6 +246,7 @@ impl Default for ServerConfig {
resolve_ipv6: false,
tcp_nodelay: false,
file_raw_mode: false,
tls_keypair: None,
verbose_stats: true,
stats_endpoint: "/stats".to_string(),

View file

@ -3,20 +3,33 @@ use std::{os::fd::AsFd, path::PathBuf, pin::Pin};
use anyhow::Context;
use tokio::{
fs::{remove_file, try_exists, File},
io::{AsyncBufRead, AsyncRead, AsyncWrite},
io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadHalf, WriteHalf},
net::{tcp, unix, TcpListener, TcpStream, UnixListener, UnixStream},
};
use tokio_native_tls::{
native_tls::{self, Identity},
TlsAcceptor, TlsStream,
};
use uuid::Uuid;
use crate::{config::SocketType, CONFIG};
pub enum Trio<A, B, C> {
pub enum Quintet<A, B, C, D, E> {
One(A),
Two(B),
Three(C),
Four(D),
Five(E),
}
impl<A: AsyncRead + Unpin, B: AsyncRead + Unpin, C: AsyncRead + Unpin> AsyncRead for Trio<A, B, C> {
impl<
A: AsyncRead + Unpin,
B: AsyncRead + Unpin,
C: AsyncRead + Unpin,
D: AsyncRead + Unpin,
E: AsyncRead + Unpin,
> AsyncRead for Quintet<A, B, C, D, E>
{
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
@ -26,12 +39,19 @@ impl<A: AsyncRead + Unpin, B: AsyncRead + Unpin, C: AsyncRead + Unpin> AsyncRead
Self::One(x) => Pin::new(x).poll_read(cx, buf),
Self::Two(x) => Pin::new(x).poll_read(cx, buf),
Self::Three(x) => Pin::new(x).poll_read(cx, buf),
Self::Four(x) => Pin::new(x).poll_read(cx, buf),
Self::Five(x) => Pin::new(x).poll_read(cx, buf),
}
}
}
impl<A: AsyncBufRead + Unpin, B: AsyncBufRead + Unpin, C: AsyncBufRead + Unpin> AsyncBufRead
for Trio<A, B, C>
impl<
A: AsyncBufRead + Unpin,
B: AsyncBufRead + Unpin,
C: AsyncBufRead + Unpin,
D: AsyncBufRead + Unpin,
E: AsyncBufRead + Unpin,
> AsyncBufRead for Quintet<A, B, C, D, E>
{
fn poll_fill_buf(
self: Pin<&mut Self>,
@ -41,6 +61,8 @@ impl<A: AsyncBufRead + Unpin, B: AsyncBufRead + Unpin, C: AsyncBufRead + Unpin>
Self::One(x) => Pin::new(x).poll_fill_buf(cx),
Self::Two(x) => Pin::new(x).poll_fill_buf(cx),
Self::Three(x) => Pin::new(x).poll_fill_buf(cx),
Self::Four(x) => Pin::new(x).poll_fill_buf(cx),
Self::Five(x) => Pin::new(x).poll_fill_buf(cx),
}
}
@ -49,12 +71,19 @@ impl<A: AsyncBufRead + Unpin, B: AsyncBufRead + Unpin, C: AsyncBufRead + Unpin>
Self::One(x) => Pin::new(x).consume(amt),
Self::Two(x) => Pin::new(x).consume(amt),
Self::Three(x) => Pin::new(x).consume(amt),
Self::Four(x) => Pin::new(x).consume(amt),
Self::Five(x) => Pin::new(x).consume(amt),
}
}
}
impl<A: AsyncWrite + Unpin, B: AsyncWrite + Unpin, C: AsyncWrite + Unpin> AsyncWrite
for Trio<A, B, C>
impl<
A: AsyncWrite + Unpin,
B: AsyncWrite + Unpin,
C: AsyncWrite + Unpin,
D: AsyncWrite + Unpin,
E: AsyncWrite + Unpin,
> AsyncWrite for Quintet<A, B, C, D, E>
{
fn poll_write(
self: Pin<&mut Self>,
@ -65,6 +94,8 @@ impl<A: AsyncWrite + Unpin, B: AsyncWrite + Unpin, C: AsyncWrite + Unpin> AsyncW
Self::One(x) => Pin::new(x).poll_write(cx, buf),
Self::Two(x) => Pin::new(x).poll_write(cx, buf),
Self::Three(x) => Pin::new(x).poll_write(cx, buf),
Self::Four(x) => Pin::new(x).poll_write(cx, buf),
Self::Five(x) => Pin::new(x).poll_write(cx, buf),
}
}
@ -73,6 +104,8 @@ impl<A: AsyncWrite + Unpin, B: AsyncWrite + Unpin, C: AsyncWrite + Unpin> AsyncW
Self::One(x) => x.is_write_vectored(),
Self::Two(x) => x.is_write_vectored(),
Self::Three(x) => x.is_write_vectored(),
Self::Four(x) => x.is_write_vectored(),
Self::Five(x) => x.is_write_vectored(),
}
}
@ -85,6 +118,8 @@ impl<A: AsyncWrite + Unpin, B: AsyncWrite + Unpin, C: AsyncWrite + Unpin> AsyncW
Self::One(x) => Pin::new(x).poll_write_vectored(cx, bufs),
Self::Two(x) => Pin::new(x).poll_write_vectored(cx, bufs),
Self::Three(x) => Pin::new(x).poll_write_vectored(cx, bufs),
Self::Four(x) => Pin::new(x).poll_write_vectored(cx, bufs),
Self::Five(x) => Pin::new(x).poll_write_vectored(cx, bufs),
}
}
@ -96,6 +131,8 @@ impl<A: AsyncWrite + Unpin, B: AsyncWrite + Unpin, C: AsyncWrite + Unpin> AsyncW
Self::One(x) => Pin::new(x).poll_flush(cx),
Self::Two(x) => Pin::new(x).poll_flush(cx),
Self::Three(x) => Pin::new(x).poll_flush(cx),
Self::Four(x) => Pin::new(x).poll_flush(cx),
Self::Five(x) => Pin::new(x).poll_flush(cx),
}
}
@ -107,6 +144,8 @@ impl<A: AsyncWrite + Unpin, B: AsyncWrite + Unpin, C: AsyncWrite + Unpin> AsyncW
Self::One(x) => Pin::new(x).poll_shutdown(cx),
Self::Two(x) => Pin::new(x).poll_shutdown(cx),
Self::Three(x) => Pin::new(x).poll_shutdown(cx),
Self::Four(x) => Pin::new(x).poll_shutdown(cx),
Self::Five(x) => Pin::new(x).poll_shutdown(cx),
}
}
}
@ -182,9 +221,22 @@ impl<A: Unpin, B: AsyncWrite + Unpin> AsyncWrite for Duplex<A, B> {
}
}
pub type ServerStream = Trio<TcpStream, UnixStream, Duplex<File, File>>;
pub type ServerStreamRead = Trio<tcp::OwnedReadHalf, unix::OwnedReadHalf, File>;
pub type ServerStreamWrite = Trio<tcp::OwnedWriteHalf, unix::OwnedWriteHalf, File>;
pub type ServerStream =
Quintet<TcpStream, TlsStream<TcpStream>, UnixStream, TlsStream<UnixStream>, Duplex<File, File>>;
pub type ServerStreamRead = Quintet<
tcp::OwnedReadHalf,
ReadHalf<TlsStream<TcpStream>>,
unix::OwnedReadHalf,
ReadHalf<TlsStream<UnixStream>>,
File,
>;
pub type ServerStreamWrite = Quintet<
tcp::OwnedWriteHalf,
WriteHalf<TlsStream<TcpStream>>,
unix::OwnedWriteHalf,
WriteHalf<TlsStream<UnixStream>>,
File,
>;
pub trait ServerStreamExt {
fn split(self) -> (ServerStreamRead, ServerStreamWrite);
@ -195,15 +247,23 @@ impl ServerStreamExt for ServerStream {
match self {
Self::One(x) => {
let (r, w) = x.into_split();
(Trio::One(r), Trio::One(w))
(Quintet::One(r), Quintet::One(w))
}
Self::Two(x) => {
let (r, w) = x.into_split();
(Trio::Two(r), Trio::Two(w))
let (r, w) = tokio::io::split(x);
(Quintet::Two(r), Quintet::Two(w))
}
Self::Three(x) => {
let (r, w) = x.into_split();
(Trio::Three(r), Trio::Three(w))
(Quintet::Three(r), Quintet::Three(w))
}
Self::Four(x) => {
let (r, w) = tokio::io::split(x);
(Quintet::Four(r), Quintet::Four(w))
}
Self::Five(x) => {
let (r, w) = x.into_split();
(Quintet::Five(r), Quintet::Five(w))
}
}
}
@ -211,27 +271,54 @@ impl ServerStreamExt for ServerStream {
pub enum ServerListener {
Tcp(TcpListener),
TlsTcp(TcpListener, TlsAcceptor),
Unix(UnixListener),
TlsUnix(UnixListener, TlsAcceptor),
File(Option<PathBuf>),
}
impl ServerListener {
pub async fn new() -> anyhow::Result<Self> {
Ok(match CONFIG.server.socket {
SocketType::Tcp => Self::Tcp(
async fn bind_tcp() -> anyhow::Result<TcpListener> {
TcpListener::bind(&CONFIG.server.bind)
.await
.with_context(|| {
format!("failed to bind to tcp address `{}`", CONFIG.server.bind)
})?,
),
SocketType::Unix => {
.with_context(|| format!("failed to bind to tcp address `{}`", CONFIG.server.bind))
}
async fn bind_unix() -> anyhow::Result<UnixListener> {
if try_exists(&CONFIG.server.bind).await? {
remove_file(&CONFIG.server.bind).await?;
}
Self::Unix(UnixListener::bind(&CONFIG.server.bind).with_context(|| {
format!("failed to bind to unix socket at `{}`", CONFIG.server.bind)
})?)
UnixListener::bind(&CONFIG.server.bind)
.with_context(|| format!("failed to bind to unix socket at `{}`", CONFIG.server.bind))
}
async fn create_tls() -> anyhow::Result<TlsAcceptor> {
let tls_keypair = CONFIG
.server
.tls_keypair
.as_ref()
.context("no tls keypair provided")?;
let public = tokio::fs::read(&tls_keypair[0])
.await
.context("failed to read public key")?;
let private = tokio::fs::read(&tls_keypair[1])
.await
.context("failed to read private key")?;
let identity =
Identity::from_pkcs8(&public, &private).context("failed to create tls identity")?;
Ok(TlsAcceptor::from(native_tls::TlsAcceptor::new(identity)?))
}
pub async fn new() -> anyhow::Result<Self> {
Ok(match CONFIG.server.socket {
SocketType::Tcp => Self::Tcp(Self::bind_tcp().await?),
SocketType::TlsTcp => Self::TlsTcp(Self::bind_tcp().await?, Self::create_tls().await?),
SocketType::Unix => Self::Unix(Self::bind_unix().await?),
SocketType::TlsUnix => {
Self::TlsUnix(Self::bind_unix().await?, Self::create_tls().await?)
}
SocketType::File => {
Self::File(Some(PathBuf::try_from(&CONFIG.server.bind).with_context(
@ -241,10 +328,8 @@ impl ServerListener {
})
}
pub async fn accept(&mut self) -> anyhow::Result<(ServerStream, String)> {
match self {
Self::Tcp(x) => {
let (stream, addr) = x
async fn accept_tcp(listener: &mut TcpListener) -> anyhow::Result<(TcpStream, String)> {
let (stream, addr) = listener
.accept()
.await
.context("failed to accept tcp connection")?;
@ -253,21 +338,44 @@ impl ServerListener {
.set_nodelay(true)
.context("failed to set tcp nodelay")?;
}
Ok((Trio::One(stream), addr.to_string()))
Ok((stream, addr.to_string()))
}
Self::Unix(x) => x
async fn accept_unix(listener: &mut UnixListener) -> anyhow::Result<(UnixStream, String)> {
let (stream, addr) = listener
.accept()
.await
.map(|(x, y)| {
(
Trio::Two(x),
y.as_pathname()
.context("failed to accept unix socket connection")?;
Ok((
stream,
addr.as_pathname()
.and_then(|x| x.to_str())
.map(ToString::to_string)
.unwrap_or_else(|| Uuid::new_v4().to_string() + "-unix_socket"),
)
})
.context("failed to accept unix socket connection"),
))
}
pub async fn accept(&mut self) -> anyhow::Result<(ServerStream, String)> {
match self {
Self::Tcp(x) => {
let (x, y) = Self::accept_tcp(x).await?;
Ok((Quintet::One(x), y))
}
Self::TlsTcp(tcp, tls) => {
let (x, y) = Self::accept_tcp(tcp).await?;
let x = tls.accept(x).await?;
Ok((Quintet::Two(x), y))
}
Self::Unix(x) => {
let (x, y) = Self::accept_unix(x).await?;
Ok((Quintet::Three(x), y))
}
Self::TlsUnix(unix, tls) => {
let (x, y) = Self::accept_unix(unix).await?;
let x = tls.accept(x).await?;
Ok((Quintet::Four(x), y))
}
Self::File(path) => {
if let Some(path) = path.take() {
let rx = File::options()
@ -311,7 +419,7 @@ impl ServerListener {
}
Ok((
Trio::Three(Duplex::new(rx, tx)),
Quintet::Five(Duplex::new(rx, tx)),
path.to_string_lossy().to_string(),
))
} else {