finish server side cert auth and motd

This commit is contained in:
Toshit Chawda 2024-09-14 17:47:16 -07:00
parent 01ff6ee956
commit 577ce71b89
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
10 changed files with 199 additions and 38 deletions

View file

@ -10,6 +10,7 @@ bytes = "1.7.1"
cfg-if = "1.0.0"
clap = { version = "4.5.16", features = ["cargo", "derive"] }
dashmap = "6.0.1"
ed25519-dalek = { version = "2.1.1", features = ["pem"] }
env_logger = "0.11.5"
event-listener = "5.3.1"
fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] }
@ -27,6 +28,7 @@ regex = "1.10.6"
serde = { version = "1.0.208", features = ["derive"] }
serde_json = { version = "1.0.125", optional = true }
serde_yaml = { version = "0.9.34", optional = true }
sha2 = "0.10.8"
shell-words = { version = "1.1.0", optional = true }
tikv-jemalloc-ctl = { version = "0.6.0", features = ["stats", "use_std"] }
tikv-jemallocator = "0.6.0"
@ -34,7 +36,7 @@ tokio = { version = "1.39.3", features = ["full"] }
tokio-util = { version = "0.7.11", features = ["codec", "compat", "io-util", "net"] }
toml = { version = "0.8.19", optional = true }
uuid = { version = "1.10.0", features = ["v4"] }
wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets", "generic_stream"] }
wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets", "generic_stream", "certificate"] }
[features]
default = ["toml"]

View file

@ -7,11 +7,14 @@ use log::LevelFilter;
use regex::RegexSet;
use serde::{Deserialize, Serialize};
use wisp_mux::extensions::{
password::PasswordProtocolExtensionBuilder, udp::UdpProtocolExtensionBuilder,
cert::CertAuthProtocolExtensionBuilder,
motd::MotdProtocolExtensionBuilder,
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
udp::UdpProtocolExtensionBuilder,
ProtocolExtensionBuilder,
};
use crate::{CLI, CONFIG, RESOLVER};
use crate::{handle::wisp::utils::get_certificates_from_paths, CLI, CONFIG, RESOLVER};
#[derive(Serialize, Deserialize, Default, Debug)]
#[serde(rename_all = "lowercase")]
@ -75,13 +78,22 @@ pub struct ServerConfig {
pub log_level: LevelFilter,
}
#[derive(Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
#[derive(Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ProtocolExtension {
/// Wisp draft version 2 UDP protocol extension.
Udp,
/// Wisp draft version 2 password protocol extension.
/// Wisp draft version 2 MOTD protocol extension.
Motd,
}
#[derive(Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ProtocolExtensionAuth {
/// Wisp draft version 2 password authentication protocol extension.
Password,
/// Wisp draft version 2 certificate authentication protocol extension.
Certificate,
}
#[derive(Serialize, Deserialize)]
@ -96,8 +108,16 @@ pub struct WispConfig {
pub wisp_v2: bool,
/// Wisp draft version 2 extensions advertised.
pub extensions: Vec<ProtocolExtension>,
/// Wisp draft version 2 password extension username/passwords.
/// Wisp draft version 2 authentication extension advertised.
pub auth_extension: Option<ProtocolExtensionAuth>,
/// Wisp draft version 2 password authentication extension username/passwords.
pub password_extension_users: HashMap<String, String>,
/// Wisp draft version 2 certificate authentication extension public ed25519 keys.
pub certificate_extension_keys: Vec<PathBuf>,
/// Wisp draft version 2 MOTD extension message.
pub motd_extension: String,
}
#[derive(Serialize, Deserialize)]
@ -202,11 +222,11 @@ lazy_static! {
};
}
pub fn validate_config_cache() {
pub async fn validate_config_cache() {
// constructs regexes
let _ = CONFIG_CACHE.allowed_ports;
// constructs wisp config
CONFIG.wisp.to_opts().unwrap();
CONFIG.wisp.to_opts().await.unwrap();
// constructs resolver
RESOLVER.clear_cache();
}
@ -244,29 +264,53 @@ impl Default for WispConfig {
wisp_v2: false,
extensions: vec![ProtocolExtension::Udp],
auth_extension: None,
password_extension_users: HashMap::new(),
certificate_extension_keys: Vec::new(),
motd_extension: String::new(),
}
}
}
impl WispConfig {
pub fn to_opts(&self) -> anyhow::Result<(Option<Vec<AnyProtocolExtensionBuilder>>, u32)> {
pub async fn to_opts(
&self,
) -> anyhow::Result<(Option<Vec<AnyProtocolExtensionBuilder>>, Vec<u8>, u32)> {
if self.wisp_v2 {
let mut extensions: Vec<AnyProtocolExtensionBuilder> = Vec::new();
let mut required_extensions: Vec<u8> = Vec::new();
if self.extensions.contains(&ProtocolExtension::Udp) {
extensions.push(Box::new(UdpProtocolExtensionBuilder));
}
if self.extensions.contains(&ProtocolExtension::Password) {
extensions.push(Box::new(PasswordProtocolExtensionBuilder::new_server(
self.password_extension_users.clone(),
if self.extensions.contains(&ProtocolExtension::Motd) {
extensions.push(Box::new(MotdProtocolExtensionBuilder::Server(
self.motd_extension.clone(),
)));
}
Ok((Some(extensions), self.buffer_size))
match self.auth_extension {
Some(ProtocolExtensionAuth::Password) => {
extensions.push(Box::new(PasswordProtocolExtensionBuilder::new_server(
self.password_extension_users.clone(),
)));
required_extensions.push(PasswordProtocolExtension::ID);
}
Some(ProtocolExtensionAuth::Certificate) => {
extensions.push(Box::new(CertAuthProtocolExtensionBuilder::new_server(
get_certificates_from_paths(self.certificate_extension_keys.clone())
.await?,
)));
}
None => {}
}
Ok((Some(extensions), required_extensions, self.buffer_size))
} else {
Ok((None, self.buffer_size))
Ok((None, Vec::new(), self.buffer_size))
}
}
}
@ -370,7 +414,7 @@ impl Config {
}
}
#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, ValueEnum)]
#[derive(Clone, Copy, Eq, PartialEq, ValueEnum)]
pub enum ConfigFormat {
#[cfg(feature = "toml")]
Toml,

View file

@ -1,7 +1,5 @@
#[cfg(feature = "twisp")]
pub mod twisp;
mod wisp;
mod wsproxy;
pub mod wisp;
pub mod wsproxy;
pub use wisp::handle_wisp;
pub use wsproxy::handle_wsproxy;

View file

@ -1,3 +1,7 @@
#[cfg(feature = "twisp")]
pub mod twisp;
pub mod utils;
use std::sync::Arc;
use anyhow::Context;
@ -64,7 +68,7 @@ async fn handle_stream(
muxstream: MuxStream,
id: String,
event: Arc<Event>,
#[cfg(feature = "twisp")] twisp_map: super::twisp::TwispMap,
#[cfg(feature = "twisp")] twisp_map: twisp::TwispMap,
) {
let requested_stream = connect.clone();
@ -175,7 +179,7 @@ async fn handle_stream(
let id = muxstream.stream_id;
let (mut rx, mut tx) = muxstream.into_io().into_asyncrw().into_split();
match super::twisp::handle_twisp(id, &mut rx, &mut tx, twisp_map.clone(), pty, cmd)
match twisp::handle_twisp(id, &mut rx, &mut tx, twisp_map.clone(), pty, cmd)
.await
{
Ok(()) => {
@ -213,12 +217,12 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> {
let (read, write) = stream;
cfg_if! {
if #[cfg(feature = "twisp")] {
let twisp_map = super::twisp::new_map();
let (extensions, buffer_size) = CONFIG.wisp.to_opts()?;
let twisp_map = twisp::new_map();
let (extensions, required_extensions, buffer_size) = CONFIG.wisp.to_opts().await?;
let extensions = match extensions {
Some(mut exts) => {
exts.push(super::twisp::new_ext(twisp_map.clone()));
exts.push(twisp::new_ext(twisp_map.clone()));
Some(exts)
},
None => {
@ -226,18 +230,23 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> {
}
};
} else {
let (extensions, buffer_size) = CONFIG.wisp.to_opts()?;
let (extensions, required_extensions, buffer_size) = CONFIG.wisp.to_opts().await?;
}
}
let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions)
.await
.context("failed to create server multiplexor")?
.with_no_required_extensions();
.with_required_extensions(&required_extensions)
.await?;
debug!(
"new wisp client id {:?} connected with extensions {:?}",
id, mux.supported_extensions.iter().map(|x| x.get_id()).collect::<Vec<_>>()
id,
mux.supported_extensions
.iter()
.map(|x| x.get_id())
.collect::<Vec<_>>()
);
let mut set: JoinSet<()> = JoinSet::new();

View file

@ -91,15 +91,15 @@ impl ProtocolExtensionBuilder for TWispServerProtocolExtensionBuilder {
}
fn build_from_bytes(
&self,
&mut self,
_: Bytes,
_: wisp_mux::Role,
) -> std::result::Result<AnyProtocolExtension, WispError> {
Ok(TWispServerProtocolExtension(self.0.clone()).into())
}
fn build_to_extension(&self, _: wisp_mux::Role) -> AnyProtocolExtension {
TWispServerProtocolExtension(self.0.clone()).into()
fn build_to_extension(&mut self, _: wisp_mux::Role) -> Result<AnyProtocolExtension, WispError> {
Ok(TWispServerProtocolExtension(self.0.clone()).into())
}
}

View file

@ -0,0 +1,20 @@
use std::{path::PathBuf, sync::Arc};
use ed25519_dalek::{pkcs8::DecodePublicKey, VerifyingKey};
use sha2::{Digest, Sha512};
use wisp_mux::extensions::cert::VerifyKey;
pub async fn get_certificates_from_paths(paths: Vec<PathBuf>) -> anyhow::Result<Vec<VerifyKey>> {
let mut out = Vec::new();
for path in paths {
let data = tokio::fs::read_to_string(path).await?;
let verifier = VerifyingKey::from_public_key_pem(&data)?;
let binary_key = verifier.to_bytes();
let mut hasher = Sha512::new();
hasher.update(binary_key);
let hash: [u8; 64] = hasher.finalize().into();
out.push(VerifyKey::new_ed25519(Arc::new(verifier), hash));
}
Ok(out)
}

View file

@ -1,4 +1,5 @@
#![feature(ip)]
#![deny(clippy::todo)]
use std::{fmt::Write, fs::read_to_string};
@ -59,7 +60,7 @@ fn format_stream_type(stream_type: StreamType) -> &'static str {
StreamType::Tcp => "tcp",
StreamType::Udp => "udp",
#[cfg(feature = "twisp")]
StreamType::Unknown(crate::handle::twisp::STREAM_TYPE) => "twisp",
StreamType::Unknown(crate::handle::wisp::twisp::STREAM_TYPE) => "twisp",
StreamType::Unknown(_) => unreachable!(),
}
}
@ -183,7 +184,7 @@ async fn main() -> anyhow::Result<()> {
.parse_default_env()
.init();
validate_config_cache();
validate_config_cache().await;
info!(
"listening on {:?} with socket type {:?} and socket transport {:?}",

View file

@ -56,7 +56,7 @@ impl ClientStream {
cfg_if! {
if #[cfg(feature = "twisp")] {
if let StreamType::Unknown(ty) = packet.stream_type {
if ty == crate::handle::twisp::STREAM_TYPE && CONFIG.stream.allow_twisp && CONFIG.wisp.wisp_v2 {
if ty == crate::handle::wisp::twisp::STREAM_TYPE && CONFIG.stream.allow_twisp && CONFIG.wisp.wisp_v2 {
return Ok(ResolvedPacket::Valid(packet));
} else {
return Ok(ResolvedPacket::Invalid);
@ -185,7 +185,7 @@ impl ClientStream {
Ok(ClientStream::Udp(stream))
}
#[cfg(feature = "twisp")]
StreamType::Unknown(crate::handle::twisp::STREAM_TYPE) => {
StreamType::Unknown(crate::handle::wisp::twisp::STREAM_TYPE) => {
if !CONFIG.stream.allow_twisp {
return Ok(ClientStream::Blocked);
}