diff --git a/wisp/src/extensions/cert.rs b/wisp/src/extensions/cert.rs index fc4bcf0..f1dfcef 100644 --- a/wisp/src/extensions/cert.rs +++ b/wisp/src/extensions/cert.rs @@ -20,10 +20,6 @@ use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; /// Certificate authentication protocol extension error. #[derive(Debug)] pub enum CertAuthError { - /// Invalid or unsupported certificate type - InvalidCertType, - /// Invalid signature - InvalidSignature, /// ED25519 error Ed25519(ed25519::Error), /// Getrandom error @@ -33,8 +29,6 @@ pub enum CertAuthError { impl std::fmt::Display for CertAuthError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::InvalidCertType => write!(f, "Invalid or unsupported certificate type"), - Self::InvalidSignature => write!(f, "Invalid signature"), Self::Ed25519(x) => write!(f, "ED25519: {:?}", x), Self::Getrandom(x) => write!(f, "getrandom: {:?}", x), } @@ -78,6 +72,17 @@ pub struct VerifyKey { pub verifier: Arc>, } +impl VerifyKey { + /// Create a new ED25519 verification key. + pub fn new_ed25519(verifier: Arc>, hash: [u8; 64]) -> Self { + Self { + cert_type: SupportedCertificateTypes::Ed25519, + hash, + verifier, + } + } +} + /// Signing key. #[derive(Clone)] pub struct SigningKey { @@ -88,6 +93,16 @@ pub struct SigningKey { /// Signer. pub signer: Arc>, } +impl SigningKey { + /// Create a new ED25519 signing key. + pub fn new_ed25519(signer: Arc>, hash: [u8; 64]) -> Self { + Self { + cert_type: SupportedCertificateTypes::Ed25519, + hash, + signer, + } + } +} /// Certificate authentication protocol extension. #[derive(Debug, Clone)] @@ -241,7 +256,7 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder { } => { // validate and parse response let cert_type = SupportedCertificateTypes::from_bits(bytes.get_u8()) - .ok_or(CertAuthError::InvalidCertType)?; + .ok_or(WispError::CertAuthExtensionSigInvalid)?; let hash = bytes.split_to(64); let sig = Signature::from_slice(&bytes).map_err(CertAuthError::from)?; let is_valid = verifiers @@ -252,15 +267,15 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder { if is_valid { Ok(CertAuthProtocolExtension::ServerVerified.into()) } else { - Err(CertAuthError::InvalidSignature.into()) + Err(WispError::CertAuthExtensionSigInvalid) } } Self::ClientBeforeChallenge { signer } => { // sign challenge let cert_types = SupportedCertificateTypes::from_bits(bytes.get_u8()) - .ok_or(CertAuthError::InvalidCertType)?; + .ok_or(WispError::CertAuthExtensionSigInvalid)?; if !cert_types.iter().any(|x| x == signer.cert_type) { - return Err(CertAuthError::InvalidCertType.into()); + return Err(WispError::CertAuthExtensionSigInvalid); } let signed: Bytes = signer diff --git a/wisp/src/extensions/password.rs b/wisp/src/extensions/password.rs index 6d6b6ca..c20e01b 100644 --- a/wisp/src/extensions/password.rs +++ b/wisp/src/extensions/password.rs @@ -4,7 +4,7 @@ //! //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication) -use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error}; +use std::collections::HashMap; use async_trait::async_trait; use bytes::{Buf, BufMut, Bytes, BytesMut}; @@ -118,38 +118,6 @@ impl ProtocolExtension for PasswordProtocolExtension { } } -#[derive(Debug)] -enum PasswordProtocolExtensionError { - Utf8Error(FromUtf8Error), - InvalidUsername, - InvalidPassword, -} - -impl Display for PasswordProtocolExtensionError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - use PasswordProtocolExtensionError as E; - match self { - E::Utf8Error(e) => write!(f, "{}", e), - E::InvalidUsername => write!(f, "Invalid username"), - E::InvalidPassword => write!(f, "Invalid password"), - } - } -} - -impl Error for PasswordProtocolExtensionError {} - -impl From for WispError { - fn from(value: PasswordProtocolExtensionError) -> Self { - WispError::ExtensionImplError(Box::new(value)) - } -} - -impl From for PasswordProtocolExtensionError { - fn from(value: FromUtf8Error) -> Self { - PasswordProtocolExtensionError::Utf8Error(value) - } -} - impl From for AnyProtocolExtension { fn from(value: PasswordProtocolExtension) -> Self { AnyProtocolExtension(Box::new(value)) @@ -212,20 +180,17 @@ impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder { return Err(WispError::PacketTooSmall); } - use PasswordProtocolExtensionError as EError; let username = - String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec()) - .map_err(|x| WispError::from(EError::from(x)))?; + std::str::from_utf8(&payload.split_to(username_len as usize))?.to_string(); let password = - String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec()) - .map_err(|x| WispError::from(EError::from(x)))?; + std::str::from_utf8(&payload.split_to(password_len as usize))?.to_string(); let Some(user) = self.users.iter().find(|x| *x.0 == username) else { - return Err(EError::InvalidUsername.into()); + return Err(WispError::PasswordExtensionCredsInvalid); }; if *user.1 != password { - return Err(EError::InvalidPassword.into()); + return Err(WispError::PasswordExtensionCredsInvalid); } Ok(PasswordProtocolExtension { diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 4e2932c..0b81008 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -68,6 +68,7 @@ pub enum WispError { IncompatibleProtocolVersion, /// The stream had already been closed. StreamAlreadyClosed, + /// The websocket frame received had an invalid type. WsFrameInvalidType, /// The websocket frame received was not finished. @@ -78,18 +79,14 @@ pub enum WispError { WsImplSocketClosed, /// The websocket implementation did not support the action. WsImplNotSupported, - /// Error specific to the protocol extension implementation. - ExtensionImplError(Box), - /// The protocol extension implementation did not support the action. - ExtensionImplNotSupported, - /// The specified protocol extensions are not supported by the server. - ExtensionsNotSupported(Vec), + /// The string was invalid UTF-8. Utf8Error(std::str::Utf8Error), /// The integer failed to convert. TryFromIntError(std::num::TryFromIntError), /// Other error. Other(Box), + /// Failed to send message to multiplexor task. MuxMessageFailedToSend, /// Failed to receive message from multiplexor task. @@ -98,6 +95,17 @@ pub enum WispError { MuxTaskEnded, /// Multiplexor task already started. MuxTaskStarted, + + /// Error specific to the protocol extension implementation. + ExtensionImplError(Box), + /// The protocol extension implementation did not support the action. + ExtensionImplNotSupported, + /// The specified protocol extensions are not supported by the server. + ExtensionsNotSupported(Vec), + /// The password authentication username/password was invalid. + PasswordExtensionCredsInvalid, + /// The certificate authentication signature was invalid. + CertAuthExtensionSigInvalid, } impl From for WispError { @@ -153,6 +161,12 @@ impl std::fmt::Display for WispError { Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"), Self::MuxTaskEnded => write!(f, "Multiplexor task ended"), Self::MuxTaskStarted => write!(f, "Multiplexor task already started"), + Self::PasswordExtensionCredsInvalid => { + write!(f, "Password extension: Invalid username/password") + } + Self::CertAuthExtensionSigInvalid => { + write!(f, "Certificate authentication extension: Invalid signature") + } } } } @@ -243,36 +257,64 @@ impl ServerMux { W: ws::WebSocketWrite + Send + 'static, { let tx = ws::LockedWebSocketWrite::new(Box::new(tx)); + let ret_tx = tx.clone(); + let ret = async { + tx.write_frame(Packet::new_continue(0, buffer_size).into()) + .await?; - tx.write_frame(Packet::new_continue(0, buffer_size).into()) - .await?; + let (supported_extensions, extra_packet, downgraded) = + if let Some(mut builders) = extension_builders { + send_info_packet(&tx, &mut builders).await?; + maybe_wisp_v2(&mut rx, &tx, &mut builders).await? + } else { + (Vec::new(), None, true) + }; - let (supported_extensions, extra_packet, downgraded) = - if let Some(mut builders) = extension_builders { - send_info_packet(&tx, &mut builders).await?; - maybe_wisp_v2(&mut rx, &tx, &mut builders).await? - } else { - (Vec::new(), None, true) - }; + let (mux_result, muxstream_recv) = MuxInner::new_server( + AppendingWebSocketRead(extra_packet, rx), + tx.clone(), + supported_extensions.clone(), + buffer_size, + ); - let (mux_result, muxstream_recv) = MuxInner::new_server( - AppendingWebSocketRead(extra_packet, rx), - tx.clone(), - supported_extensions.clone(), - buffer_size, - ); + Ok(ServerMuxResult( + Self { + muxstream_recv, + actor_tx: mux_result.actor_tx, + downgraded, + supported_extensions, + tx, + actor_exited: mux_result.actor_exited, + }, + mux_result.mux.into_future(), + )) + } + .await; - Ok(ServerMuxResult( - Self { - muxstream_recv, - actor_tx: mux_result.actor_tx, - downgraded, - supported_extensions, - tx, - actor_exited: mux_result.actor_exited, + match ret { + Ok(x) => Ok(x), + Err(x) => match x { + WispError::PasswordExtensionCredsInvalid => { + ret_tx + .write_frame( + Packet::new_close(0, CloseReason::ExtensionsPasswordAuthFailed).into(), + ) + .await?; + ret_tx.close().await?; + Err(x) + } + WispError::CertAuthExtensionSigInvalid => { + ret_tx + .write_frame( + Packet::new_close(0, CloseReason::ExtensionsCertAuthFailed).into(), + ) + .await?; + ret_tx.close().await?; + Err(x) + } + x => Err(x), }, - mux_result.mux.into_future(), - )) + } } /// Wait for a stream to be created. @@ -300,12 +342,11 @@ impl ServerMux { self.close_internal(None).await } - /// Close all streams and send an extension incompatibility error to the client. + /// Close all streams and send a close reason on stream ID 0. /// /// Also terminates the multiplexor future. - pub async fn close_extension_incompat(&self) -> Result<(), WispError> { - self.close_internal(Some(CloseReason::IncompatibleExtensions)) - .await + pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> { + self.close_internal(Some(reason)).await } /// Get a protocol extension stream for sending packets with stream id 0. @@ -346,18 +387,23 @@ where ) -> Result<(ServerMux, F), WispError> { let mut unsupported_extensions = Vec::new(); for extension in extensions { - if !self.0.supported_extensions.iter().any(|x| x.get_id() == *extension) { + if !self + .0 + .supported_extensions + .iter() + .any(|x| x.get_id() == *extension) + { unsupported_extensions.push(*extension); } } if unsupported_extensions.is_empty() { Ok((self.0, self.1)) } else { - self.0.close_extension_incompat().await?; + self.0 + .close_with_reason(CloseReason::ExtensionsIncompatible) + .await?; self.1.await?; - Err(WispError::ExtensionsNotSupported( - unsupported_extensions, - )) + Err(WispError::ExtensionsNotSupported(unsupported_extensions)) } } @@ -483,12 +529,11 @@ impl ClientMux { self.close_internal(None).await } - /// Close all streams and send an extension incompatibility error to the client. + /// Close all streams and send a close reason on stream ID 0. /// /// Also terminates the multiplexor future. - pub async fn close_extension_incompat(&self) -> Result<(), WispError> { - self.close_internal(Some(CloseReason::IncompatibleExtensions)) - .await + pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> { + self.close_internal(Some(reason)).await } /// Get a protocol extension stream for sending packets with stream id 0. @@ -528,14 +573,21 @@ where ) -> Result<(ClientMux, F), WispError> { let mut unsupported_extensions = Vec::new(); for extension in extensions { - if !self.0.supported_extensions.iter().any(|x| x.get_id() == *extension) { + if !self + .0 + .supported_extensions + .iter() + .any(|x| x.get_id() == *extension) + { unsupported_extensions.push(*extension); } } if unsupported_extensions.is_empty() { Ok((self.0, self.1)) } else { - self.0.close_extension_incompat().await?; + self.0 + .close_with_reason(CloseReason::ExtensionsIncompatible) + .await?; self.1.await?; Err(WispError::ExtensionsNotSupported(unsupported_extensions)) } diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 0463f8f..86dfc56 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -60,7 +60,7 @@ mod close { /// Unexpected stream closure due to a network error. Unexpected = 0x03, /// Incompatible extensions. Only used during the handshake. - IncompatibleExtensions = 0x04, + ExtensionsIncompatible = 0x04, /// Stream creation failed due to invalid information. ServerStreamInvalidInfo = 0x41, /// Stream creation failed due to an unreachable destination host. @@ -77,6 +77,10 @@ mod close { ServerStreamThrottled = 0x49, /// The client has encountered an unexpected error. ClientUnexpected = 0x81, + /// Authentication failed due to invalid username/password. + ExtensionsPasswordAuthFailed = 0xc0, + /// Authentication failed due to invalid signature. + ExtensionsCertAuthFailed = 0xc1, } impl TryFrom for CloseReason { @@ -87,7 +91,7 @@ mod close { 0x01 => Ok(R::Unknown), 0x02 => Ok(R::Voluntary), 0x03 => Ok(R::Unexpected), - 0x04 => Ok(R::IncompatibleExtensions), + 0x04 => Ok(R::ExtensionsIncompatible), 0x41 => Ok(R::ServerStreamInvalidInfo), 0x42 => Ok(R::ServerStreamUnreachable), 0x43 => Ok(R::ServerStreamConnectionTimedOut), @@ -111,7 +115,7 @@ mod close { C::Unknown => "Unknown close reason", C::Voluntary => "Voluntarily closed", C::Unexpected => "Unexpectedly closed", - C::IncompatibleExtensions => "Incompatible protocol extensions", + C::ExtensionsIncompatible => "Incompatible protocol extensions", C::ServerStreamInvalidInfo => "Stream creation failed due to invalid information", C::ServerStreamUnreachable => @@ -124,6 +128,8 @@ mod close { C::ServerStreamBlockedAddress => "Destination address is blocked", C::ServerStreamThrottled => "Throttled", C::ClientUnexpected => "Client encountered unexpected error", + C::ExtensionsPasswordAuthFailed => "Invalid username/password", + C::ExtensionsCertAuthFailed => "Invalid signature", } ) }