diff --git a/Cargo.lock b/Cargo.lock index 595234a..d301bbe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -476,6 +476,33 @@ dependencies = [ "typenum", ] +[[package]] +name = "curve25519-dalek" +version = "4.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" +dependencies = [ + "cfg-if", + "cpufeatures", + "curve25519-dalek-derive", + "digest", + "fiat-crypto", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "dashmap" version = "6.0.1" @@ -528,6 +555,20 @@ dependencies = [ "zeroize", ] +[[package]] +name = "ed25519-dalek" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a3daa8e81a3963a60642bcc1f90a670680bd4a77535faa384e9d1c79d620871" +dependencies = [ + "curve25519-dalek", + "ed25519", + "serde", + "sha2", + "subtle", + "zeroize", +] + [[package]] name = "either" version = "1.13.0" @@ -615,6 +656,7 @@ dependencies = [ "cfg-if", "clap", "dashmap", + "ed25519-dalek", "env_logger", "event-listener", "fastwebsockets", @@ -632,6 +674,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "sha2", "shell-words", "tikv-jemalloc-ctl", "tikv-jemallocator", @@ -688,6 +731,12 @@ dependencies = [ "utf-8", ] +[[package]] +name = "fiat-crypto" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" + [[package]] name = "flate2" version = "1.0.33" @@ -1668,6 +1717,15 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.38.35" @@ -1744,6 +1802,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "semver" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" + [[package]] name = "send_wrapper" version = "0.4.0" @@ -1824,6 +1888,17 @@ dependencies = [ "digest", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" diff --git a/server/Cargo.toml b/server/Cargo.toml index b08d655..3c34fde 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -10,6 +10,7 @@ bytes = "1.7.1" cfg-if = "1.0.0" clap = { version = "4.5.16", features = ["cargo", "derive"] } dashmap = "6.0.1" +ed25519-dalek = { version = "2.1.1", features = ["pem"] } env_logger = "0.11.5" event-listener = "5.3.1" fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] } @@ -27,6 +28,7 @@ regex = "1.10.6" serde = { version = "1.0.208", features = ["derive"] } serde_json = { version = "1.0.125", optional = true } serde_yaml = { version = "0.9.34", optional = true } +sha2 = "0.10.8" shell-words = { version = "1.1.0", optional = true } tikv-jemalloc-ctl = { version = "0.6.0", features = ["stats", "use_std"] } tikv-jemallocator = "0.6.0" @@ -34,7 +36,7 @@ tokio = { version = "1.39.3", features = ["full"] } tokio-util = { version = "0.7.11", features = ["codec", "compat", "io-util", "net"] } toml = { version = "0.8.19", optional = true } uuid = { version = "1.10.0", features = ["v4"] } -wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets", "generic_stream"] } +wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets", "generic_stream", "certificate"] } [features] default = ["toml"] diff --git a/server/src/config.rs b/server/src/config.rs index 7623f12..63374c6 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -7,11 +7,14 @@ use log::LevelFilter; use regex::RegexSet; use serde::{Deserialize, Serialize}; use wisp_mux::extensions::{ - password::PasswordProtocolExtensionBuilder, udp::UdpProtocolExtensionBuilder, + cert::CertAuthProtocolExtensionBuilder, + motd::MotdProtocolExtensionBuilder, + password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, + udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder, }; -use crate::{CLI, CONFIG, RESOLVER}; +use crate::{handle::wisp::utils::get_certificates_from_paths, CLI, CONFIG, RESOLVER}; #[derive(Serialize, Deserialize, Default, Debug)] #[serde(rename_all = "lowercase")] @@ -75,13 +78,22 @@ pub struct ServerConfig { pub log_level: LevelFilter, } -#[derive(Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "lowercase")] pub enum ProtocolExtension { /// Wisp draft version 2 UDP protocol extension. Udp, - /// Wisp draft version 2 password protocol extension. + /// Wisp draft version 2 MOTD protocol extension. + Motd, +} + +#[derive(Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum ProtocolExtensionAuth { + /// Wisp draft version 2 password authentication protocol extension. Password, + /// Wisp draft version 2 certificate authentication protocol extension. + Certificate, } #[derive(Serialize, Deserialize)] @@ -96,8 +108,16 @@ pub struct WispConfig { pub wisp_v2: bool, /// Wisp draft version 2 extensions advertised. pub extensions: Vec, - /// Wisp draft version 2 password extension username/passwords. + /// Wisp draft version 2 authentication extension advertised. + pub auth_extension: Option, + + /// Wisp draft version 2 password authentication extension username/passwords. pub password_extension_users: HashMap, + /// Wisp draft version 2 certificate authentication extension public ed25519 keys. + pub certificate_extension_keys: Vec, + + /// Wisp draft version 2 MOTD extension message. + pub motd_extension: String, } #[derive(Serialize, Deserialize)] @@ -202,11 +222,11 @@ lazy_static! { }; } -pub fn validate_config_cache() { +pub async fn validate_config_cache() { // constructs regexes let _ = CONFIG_CACHE.allowed_ports; // constructs wisp config - CONFIG.wisp.to_opts().unwrap(); + CONFIG.wisp.to_opts().await.unwrap(); // constructs resolver RESOLVER.clear_cache(); } @@ -244,29 +264,53 @@ impl Default for WispConfig { wisp_v2: false, extensions: vec![ProtocolExtension::Udp], + auth_extension: None, + password_extension_users: HashMap::new(), + certificate_extension_keys: Vec::new(), + + motd_extension: String::new(), } } } impl WispConfig { - pub fn to_opts(&self) -> anyhow::Result<(Option>, u32)> { + pub async fn to_opts( + &self, + ) -> anyhow::Result<(Option>, Vec, u32)> { if self.wisp_v2 { let mut extensions: Vec = Vec::new(); + let mut required_extensions: Vec = Vec::new(); if self.extensions.contains(&ProtocolExtension::Udp) { extensions.push(Box::new(UdpProtocolExtensionBuilder)); } - if self.extensions.contains(&ProtocolExtension::Password) { - extensions.push(Box::new(PasswordProtocolExtensionBuilder::new_server( - self.password_extension_users.clone(), + if self.extensions.contains(&ProtocolExtension::Motd) { + extensions.push(Box::new(MotdProtocolExtensionBuilder::Server( + self.motd_extension.clone(), ))); } - Ok((Some(extensions), self.buffer_size)) + match self.auth_extension { + Some(ProtocolExtensionAuth::Password) => { + extensions.push(Box::new(PasswordProtocolExtensionBuilder::new_server( + self.password_extension_users.clone(), + ))); + required_extensions.push(PasswordProtocolExtension::ID); + } + Some(ProtocolExtensionAuth::Certificate) => { + extensions.push(Box::new(CertAuthProtocolExtensionBuilder::new_server( + get_certificates_from_paths(self.certificate_extension_keys.clone()) + .await?, + ))); + } + None => {} + } + + Ok((Some(extensions), required_extensions, self.buffer_size)) } else { - Ok((None, self.buffer_size)) + Ok((None, Vec::new(), self.buffer_size)) } } } @@ -370,7 +414,7 @@ impl Config { } } -#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, ValueEnum)] +#[derive(Clone, Copy, Eq, PartialEq, ValueEnum)] pub enum ConfigFormat { #[cfg(feature = "toml")] Toml, diff --git a/server/src/handle/mod.rs b/server/src/handle/mod.rs index 2ad7f43..80106ea 100644 --- a/server/src/handle/mod.rs +++ b/server/src/handle/mod.rs @@ -1,7 +1,5 @@ -#[cfg(feature = "twisp")] -pub mod twisp; -mod wisp; -mod wsproxy; +pub mod wisp; +pub mod wsproxy; pub use wisp::handle_wisp; pub use wsproxy::handle_wsproxy; diff --git a/server/src/handle/wisp.rs b/server/src/handle/wisp/mod.rs similarity index 90% rename from server/src/handle/wisp.rs rename to server/src/handle/wisp/mod.rs index 01e816d..8dbe507 100644 --- a/server/src/handle/wisp.rs +++ b/server/src/handle/wisp/mod.rs @@ -1,3 +1,7 @@ +#[cfg(feature = "twisp")] +pub mod twisp; +pub mod utils; + use std::sync::Arc; use anyhow::Context; @@ -64,7 +68,7 @@ async fn handle_stream( muxstream: MuxStream, id: String, event: Arc, - #[cfg(feature = "twisp")] twisp_map: super::twisp::TwispMap, + #[cfg(feature = "twisp")] twisp_map: twisp::TwispMap, ) { let requested_stream = connect.clone(); @@ -175,7 +179,7 @@ async fn handle_stream( let id = muxstream.stream_id; let (mut rx, mut tx) = muxstream.into_io().into_asyncrw().into_split(); - match super::twisp::handle_twisp(id, &mut rx, &mut tx, twisp_map.clone(), pty, cmd) + match twisp::handle_twisp(id, &mut rx, &mut tx, twisp_map.clone(), pty, cmd) .await { Ok(()) => { @@ -213,12 +217,12 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> { let (read, write) = stream; cfg_if! { if #[cfg(feature = "twisp")] { - let twisp_map = super::twisp::new_map(); - let (extensions, buffer_size) = CONFIG.wisp.to_opts()?; + let twisp_map = twisp::new_map(); + let (extensions, required_extensions, buffer_size) = CONFIG.wisp.to_opts().await?; let extensions = match extensions { Some(mut exts) => { - exts.push(super::twisp::new_ext(twisp_map.clone())); + exts.push(twisp::new_ext(twisp_map.clone())); Some(exts) }, None => { @@ -226,18 +230,23 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> { } }; } else { - let (extensions, buffer_size) = CONFIG.wisp.to_opts()?; + let (extensions, required_extensions, buffer_size) = CONFIG.wisp.to_opts().await?; } } let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions) .await .context("failed to create server multiplexor")? - .with_no_required_extensions(); + .with_required_extensions(&required_extensions) + .await?; debug!( "new wisp client id {:?} connected with extensions {:?}", - id, mux.supported_extensions.iter().map(|x| x.get_id()).collect::>() + id, + mux.supported_extensions + .iter() + .map(|x| x.get_id()) + .collect::>() ); let mut set: JoinSet<()> = JoinSet::new(); diff --git a/server/src/handle/twisp.rs b/server/src/handle/wisp/twisp.rs similarity index 95% rename from server/src/handle/twisp.rs rename to server/src/handle/wisp/twisp.rs index aa34b57..cbf0071 100644 --- a/server/src/handle/twisp.rs +++ b/server/src/handle/wisp/twisp.rs @@ -91,15 +91,15 @@ impl ProtocolExtensionBuilder for TWispServerProtocolExtensionBuilder { } fn build_from_bytes( - &self, + &mut self, _: Bytes, _: wisp_mux::Role, ) -> std::result::Result { Ok(TWispServerProtocolExtension(self.0.clone()).into()) } - fn build_to_extension(&self, _: wisp_mux::Role) -> AnyProtocolExtension { - TWispServerProtocolExtension(self.0.clone()).into() + fn build_to_extension(&mut self, _: wisp_mux::Role) -> Result { + Ok(TWispServerProtocolExtension(self.0.clone()).into()) } } diff --git a/server/src/handle/wisp/utils.rs b/server/src/handle/wisp/utils.rs new file mode 100644 index 0000000..b098cab --- /dev/null +++ b/server/src/handle/wisp/utils.rs @@ -0,0 +1,20 @@ +use std::{path::PathBuf, sync::Arc}; + +use ed25519_dalek::{pkcs8::DecodePublicKey, VerifyingKey}; +use sha2::{Digest, Sha512}; +use wisp_mux::extensions::cert::VerifyKey; + +pub async fn get_certificates_from_paths(paths: Vec) -> anyhow::Result> { + let mut out = Vec::new(); + for path in paths { + let data = tokio::fs::read_to_string(path).await?; + let verifier = VerifyingKey::from_public_key_pem(&data)?; + let binary_key = verifier.to_bytes(); + + let mut hasher = Sha512::new(); + hasher.update(binary_key); + let hash: [u8; 64] = hasher.finalize().into(); + out.push(VerifyKey::new_ed25519(Arc::new(verifier), hash)); + } + Ok(out) +} diff --git a/server/src/main.rs b/server/src/main.rs index 3f20eab..25b6e46 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,4 +1,5 @@ #![feature(ip)] +#![deny(clippy::todo)] use std::{fmt::Write, fs::read_to_string}; @@ -59,7 +60,7 @@ fn format_stream_type(stream_type: StreamType) -> &'static str { StreamType::Tcp => "tcp", StreamType::Udp => "udp", #[cfg(feature = "twisp")] - StreamType::Unknown(crate::handle::twisp::STREAM_TYPE) => "twisp", + StreamType::Unknown(crate::handle::wisp::twisp::STREAM_TYPE) => "twisp", StreamType::Unknown(_) => unreachable!(), } } @@ -183,7 +184,7 @@ async fn main() -> anyhow::Result<()> { .parse_default_env() .init(); - validate_config_cache(); + validate_config_cache().await; info!( "listening on {:?} with socket type {:?} and socket transport {:?}", diff --git a/server/src/stream.rs b/server/src/stream.rs index 8f79dc2..d6f4f46 100644 --- a/server/src/stream.rs +++ b/server/src/stream.rs @@ -56,7 +56,7 @@ impl ClientStream { cfg_if! { if #[cfg(feature = "twisp")] { if let StreamType::Unknown(ty) = packet.stream_type { - if ty == crate::handle::twisp::STREAM_TYPE && CONFIG.stream.allow_twisp && CONFIG.wisp.wisp_v2 { + if ty == crate::handle::wisp::twisp::STREAM_TYPE && CONFIG.stream.allow_twisp && CONFIG.wisp.wisp_v2 { return Ok(ResolvedPacket::Valid(packet)); } else { return Ok(ResolvedPacket::Invalid); @@ -185,7 +185,7 @@ impl ClientStream { Ok(ClientStream::Udp(stream)) } #[cfg(feature = "twisp")] - StreamType::Unknown(crate::handle::twisp::STREAM_TYPE) => { + StreamType::Unknown(crate::handle::wisp::twisp::STREAM_TYPE) => { if !CONFIG.stream.allow_twisp { return Ok(ClientStream::Blocked); } diff --git a/wisp/src/extensions/cert.rs b/wisp/src/extensions/cert.rs index f1dfcef..005b534 100644 --- a/wisp/src/extensions/cert.rs +++ b/wisp/src/extensions/cert.rs @@ -69,12 +69,12 @@ pub struct VerifyKey { /// SHA-512 hash of the public key. pub hash: [u8; 64], /// Verifier. - pub verifier: Arc>, + pub verifier: Arc + Sync + Send>, } impl VerifyKey { /// Create a new ED25519 verification key. - pub fn new_ed25519(verifier: Arc>, hash: [u8; 64]) -> Self { + pub fn new_ed25519(verifier: Arc + Sync + Send>, hash: [u8; 64]) -> Self { Self { cert_type: SupportedCertificateTypes::Ed25519, hash, @@ -91,11 +91,11 @@ pub struct SigningKey { /// SHA-512 hash of the public key. pub hash: [u8; 64], /// Signer. - pub signer: Arc>, + pub signer: Arc + Sync + Send>, } impl SigningKey { /// Create a new ED25519 signing key. - pub fn new_ed25519(signer: Arc>, hash: [u8; 64]) -> Self { + pub fn new_ed25519(signer: Arc + Sync + Send>, hash: [u8; 64]) -> Self { Self { cert_type: SupportedCertificateTypes::Ed25519, hash, @@ -234,6 +234,18 @@ pub enum CertAuthProtocolExtensionBuilder { }, } +impl CertAuthProtocolExtensionBuilder { + /// Create a new server variant of the certificate authentication protocol extension. + pub fn new_server(verifiers: Vec) -> Self { + Self::ServerBeforeChallenge { verifiers } + } + + /// Create a new client variant of the certificate authentication protocol extension. + pub fn new_client(signer: SigningKey) -> Self { + Self::ClientBeforeChallenge { signer } + } +} + #[async_trait] impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder { fn get_id(&self) -> u8 {