diff --git a/Cargo.lock b/Cargo.lock index d301bbe..424984c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1952,12 +1952,14 @@ dependencies = [ "bytes", "clap", "console-subscriber", + "ed25519-dalek", "fastwebsockets", "futures", "http-body-util", "humantime", "hyper", "hyper-util", + "sha2", "simple_moving_average", "tokio", "wisp-mux", diff --git a/server/src/config.rs b/server/src/config.rs index 63374c6..c8ff5e1 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -7,7 +7,7 @@ use log::LevelFilter; use regex::RegexSet; use serde::{Deserialize, Serialize}; use wisp_mux::extensions::{ - cert::CertAuthProtocolExtensionBuilder, + cert::{CertAuthProtocolExtension, CertAuthProtocolExtensionBuilder}, motd::MotdProtocolExtensionBuilder, password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, udp::UdpProtocolExtensionBuilder, @@ -304,6 +304,7 @@ impl WispConfig { get_certificates_from_paths(self.certificate_extension_keys.clone()) .await?, ))); + required_extensions.push(CertAuthProtocolExtension::ID); } None => {} } diff --git a/server/src/handle/wisp/mod.rs b/server/src/handle/wisp/mod.rs index 8dbe507..ed4e999 100644 --- a/server/src/handle/wisp/mod.rs +++ b/server/src/handle/wisp/mod.rs @@ -179,9 +179,7 @@ async fn handle_stream( let id = muxstream.stream_id; let (mut rx, mut tx) = muxstream.into_io().into_asyncrw().into_split(); - match twisp::handle_twisp(id, &mut rx, &mut tx, twisp_map.clone(), pty, cmd) - .await - { + match twisp::handle_twisp(id, &mut rx, &mut tx, twisp_map.clone(), pty, cmd).await { Ok(()) => { let _ = closer.close(CloseReason::Voluntary).await; } diff --git a/simple-wisp-client/Cargo.toml b/simple-wisp-client/Cargo.toml index 73beca2..646485c 100644 --- a/simple-wisp-client/Cargo.toml +++ b/simple-wisp-client/Cargo.toml @@ -8,12 +8,14 @@ atomic-counter = "1.0.1" bytes = "1.7.1" clap = { version = "4.5.16", features = ["cargo", "derive"] } console-subscriber = { version = "0.4.0", optional = true } +ed25519-dalek = { version = "2.1.1", features = ["pem"] } fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] } futures = "0.3.30" http-body-util = "0.1.2" humantime = "2.1.0" hyper = { version = "1.4.1", features = ["http1", "client"] } hyper-util = { version = "0.1.7", features = ["tokio"] } +sha2 = "0.10.8" simple_moving_average = "1.0.2" tokio = { version = "1.39.3", features = ["full"] } wisp-mux = { path = "../wisp", features = ["fastwebsockets"]} diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index 870e828..48b0d49 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -1,6 +1,7 @@ use atomic_counter::{AtomicCounter, RelaxedCounter}; use bytes::Bytes; use clap::Parser; +use ed25519_dalek::pkcs8::DecodePrivateKey; use fastwebsockets::handshake; use futures::future::select_all; use http_body_util::Empty; @@ -10,12 +11,14 @@ use hyper::{ Request, Uri, }; use hyper_util::rt::TokioIo; +use sha2::{Digest, Sha512}; use simple_moving_average::{SingleSumSMA, SMA}; use std::{ error::Error, future::Future, io::{stdout, Cursor, IsTerminal, Write}, net::SocketAddr, + path::PathBuf, process::{abort, exit}, sync::Arc, time::{Duration, Instant}, @@ -29,6 +32,8 @@ use tokio::{ }; use wisp_mux::{ extensions::{ + cert::{CertAuthProtocolExtension, CertAuthProtocolExtensionBuilder, SigningKey}, + motd::{MotdProtocolExtension, MotdProtocolExtensionBuilder}, password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, ProtocolExtensionBuilder, @@ -92,11 +97,28 @@ struct Cli { /// Usernames and passwords are sent in plaintext!! #[arg(long)] auth: Option, + /// Enable certauth + #[arg(long)] + certauth: Option, + /// Enable motd parsing + #[arg(long)] + motd: bool, /// Make a Wisp V2 connection #[arg(long)] wisp_v2: bool, } +async fn get_cert(path: PathBuf) -> Result> { + let data = tokio::fs::read_to_string(path).await?; + let signer = ed25519_dalek::SigningKey::from_pkcs8_pem(&data)?; + let binary_key = signer.verifying_key().to_bytes(); + + let mut hasher = Sha512::new(); + hasher.update(binary_key); + let hash: [u8; 64] = hasher.finalize().into(); + Ok(SigningKey::new_ed25519(Arc::new(signer), hash)) +} + #[tokio::main(flavor = "multi_thread")] async fn main() -> Result<(), Box> { #[cfg(feature = "tokio-console")] @@ -153,10 +175,19 @@ async fn main() -> Result<(), Box> { extensions.push(Box::new(UdpProtocolExtensionBuilder)); extension_ids.push(UdpProtocolExtension::ID); } + if opts.motd { + extensions.push(Box::new(MotdProtocolExtensionBuilder::Client)); + } if let Some(auth) = auth { extensions.push(Box::new(auth)); extension_ids.push(PasswordProtocolExtension::ID); } + if let Some(certauth) = opts.certauth { + let key = get_cert(certauth).await?; + let extension = CertAuthProtocolExtensionBuilder::new_client(key); + extensions.push(Box::new(extension)); + extension_ids.push(CertAuthProtocolExtension::ID); + } let (mux, fut) = if !opts.wisp_v2 { ClientMux::create(rx, tx, None) @@ -169,9 +200,19 @@ async fn main() -> Result<(), Box> { .await? }; + let motd_extension = mux + .supported_extensions + .iter() + .find_map(|x| x.downcast_ref::()); + println!( - "connected and created ClientMux, was downgraded {}, extensions supported {:?}\n", - mux.downgraded, mux.supported_extension_ids + "connected and created ClientMux, was downgraded {}, extensions supported {:?}, motd {:?}\n\n", + mux.downgraded, + mux.supported_extensions + .iter() + .map(|x| x.get_id()) + .collect::>(), + motd_extension.map(|x| x.motd.clone()) ); let mut threads = Vec::with_capacity((opts.streams * 2) + 3); diff --git a/wisp/src/extensions/cert.rs b/wisp/src/extensions/cert.rs index 005b534..981ba16 100644 --- a/wisp/src/extensions/cert.rs +++ b/wisp/src/extensions/cert.rs @@ -74,7 +74,10 @@ pub struct VerifyKey { impl VerifyKey { /// Create a new ED25519 verification key. - pub fn new_ed25519(verifier: Arc + Sync + Send>, hash: [u8; 64]) -> Self { + pub fn new_ed25519( + verifier: Arc + Sync + Send>, + hash: [u8; 64], + ) -> Self { Self { cert_type: SupportedCertificateTypes::Ed25519, hash, @@ -314,9 +317,9 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder { fn build_to_extension(&mut self, _: Role) -> Result { match self { Self::ServerBeforeChallenge { verifiers } => { - let mut challenge = BytesMut::with_capacity(64); + let mut challenge = [0u8; 64]; getrandom::getrandom(&mut challenge).map_err(CertAuthError::from)?; - let challenge = challenge.freeze(); + let challenge = Bytes::from(challenge.to_vec()); *self = Self::ServerAfterChallenge { verifiers: verifiers.to_vec(), diff --git a/wisp/src/extensions/mod.rs b/wisp/src/extensions/mod.rs index 6c1347a..8f966df 100644 --- a/wisp/src/extensions/mod.rs +++ b/wisp/src/extensions/mod.rs @@ -32,6 +32,16 @@ impl AnyProtocolExtension { pub fn downcast(self) -> Result, Self> { self.0.__downcast().map_err(Self) } + + /// Downcast the protocol extension. + pub fn downcast_ref(&self) -> Option<&T> { + self.0.__downcast_ref() + } + + /// Downcast the protocol extension. + pub fn downcast_mut(&mut self) -> Option<&mut T> { + self.0.__downcast_mut() + } } impl Deref for AnyProtocolExtension { @@ -126,6 +136,22 @@ impl dyn ProtocolExtension { Err(self) } } + + fn __downcast_ref(&self) -> Option<&T> { + if self.__is::() { + unsafe { Some(&*(self as *const dyn ProtocolExtension as *const T)) } + } else { + None + } + } + + fn __downcast_mut(&mut self) -> Option<&mut T> { + if self.__is::() { + unsafe { Some(&mut *(self as *mut dyn ProtocolExtension as *mut T)) } + } else { + None + } + } } /// Trait to build a Wisp protocol extension from a payload.