move pw and cert auth errors out

This commit is contained in:
Toshit Chawda 2024-09-14 17:11:22 -07:00
parent 694d87f731
commit 01ff6ee956
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
4 changed files with 137 additions and 99 deletions

View file

@ -20,10 +20,6 @@ use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
/// Certificate authentication protocol extension error. /// Certificate authentication protocol extension error.
#[derive(Debug)] #[derive(Debug)]
pub enum CertAuthError { pub enum CertAuthError {
/// Invalid or unsupported certificate type
InvalidCertType,
/// Invalid signature
InvalidSignature,
/// ED25519 error /// ED25519 error
Ed25519(ed25519::Error), Ed25519(ed25519::Error),
/// Getrandom error /// Getrandom error
@ -33,8 +29,6 @@ pub enum CertAuthError {
impl std::fmt::Display for CertAuthError { impl std::fmt::Display for CertAuthError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
Self::InvalidCertType => write!(f, "Invalid or unsupported certificate type"),
Self::InvalidSignature => write!(f, "Invalid signature"),
Self::Ed25519(x) => write!(f, "ED25519: {:?}", x), Self::Ed25519(x) => write!(f, "ED25519: {:?}", x),
Self::Getrandom(x) => write!(f, "getrandom: {:?}", x), Self::Getrandom(x) => write!(f, "getrandom: {:?}", x),
} }
@ -78,6 +72,17 @@ pub struct VerifyKey {
pub verifier: Arc<dyn Verifier<Signature>>, pub verifier: Arc<dyn Verifier<Signature>>,
} }
impl VerifyKey {
/// Create a new ED25519 verification key.
pub fn new_ed25519(verifier: Arc<dyn Verifier<Signature>>, hash: [u8; 64]) -> Self {
Self {
cert_type: SupportedCertificateTypes::Ed25519,
hash,
verifier,
}
}
}
/// Signing key. /// Signing key.
#[derive(Clone)] #[derive(Clone)]
pub struct SigningKey { pub struct SigningKey {
@ -88,6 +93,16 @@ pub struct SigningKey {
/// Signer. /// Signer.
pub signer: Arc<dyn Signer<Signature>>, pub signer: Arc<dyn Signer<Signature>>,
} }
impl SigningKey {
/// Create a new ED25519 signing key.
pub fn new_ed25519(signer: Arc<dyn Signer<Signature>>, hash: [u8; 64]) -> Self {
Self {
cert_type: SupportedCertificateTypes::Ed25519,
hash,
signer,
}
}
}
/// Certificate authentication protocol extension. /// Certificate authentication protocol extension.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -241,7 +256,7 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
} => { } => {
// validate and parse response // validate and parse response
let cert_type = SupportedCertificateTypes::from_bits(bytes.get_u8()) let cert_type = SupportedCertificateTypes::from_bits(bytes.get_u8())
.ok_or(CertAuthError::InvalidCertType)?; .ok_or(WispError::CertAuthExtensionSigInvalid)?;
let hash = bytes.split_to(64); let hash = bytes.split_to(64);
let sig = Signature::from_slice(&bytes).map_err(CertAuthError::from)?; let sig = Signature::from_slice(&bytes).map_err(CertAuthError::from)?;
let is_valid = verifiers let is_valid = verifiers
@ -252,15 +267,15 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
if is_valid { if is_valid {
Ok(CertAuthProtocolExtension::ServerVerified.into()) Ok(CertAuthProtocolExtension::ServerVerified.into())
} else { } else {
Err(CertAuthError::InvalidSignature.into()) Err(WispError::CertAuthExtensionSigInvalid)
} }
} }
Self::ClientBeforeChallenge { signer } => { Self::ClientBeforeChallenge { signer } => {
// sign challenge // sign challenge
let cert_types = SupportedCertificateTypes::from_bits(bytes.get_u8()) let cert_types = SupportedCertificateTypes::from_bits(bytes.get_u8())
.ok_or(CertAuthError::InvalidCertType)?; .ok_or(WispError::CertAuthExtensionSigInvalid)?;
if !cert_types.iter().any(|x| x == signer.cert_type) { if !cert_types.iter().any(|x| x == signer.cert_type) {
return Err(CertAuthError::InvalidCertType.into()); return Err(WispError::CertAuthExtensionSigInvalid);
} }
let signed: Bytes = signer let signed: Bytes = signer

View file

@ -4,7 +4,7 @@
//! //!
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication) //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication)
use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error}; use std::collections::HashMap;
use async_trait::async_trait; use async_trait::async_trait;
use bytes::{Buf, BufMut, Bytes, BytesMut}; use bytes::{Buf, BufMut, Bytes, BytesMut};
@ -118,38 +118,6 @@ impl ProtocolExtension for PasswordProtocolExtension {
} }
} }
#[derive(Debug)]
enum PasswordProtocolExtensionError {
Utf8Error(FromUtf8Error),
InvalidUsername,
InvalidPassword,
}
impl Display for PasswordProtocolExtensionError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use PasswordProtocolExtensionError as E;
match self {
E::Utf8Error(e) => write!(f, "{}", e),
E::InvalidUsername => write!(f, "Invalid username"),
E::InvalidPassword => write!(f, "Invalid password"),
}
}
}
impl Error for PasswordProtocolExtensionError {}
impl From<PasswordProtocolExtensionError> for WispError {
fn from(value: PasswordProtocolExtensionError) -> Self {
WispError::ExtensionImplError(Box::new(value))
}
}
impl From<FromUtf8Error> for PasswordProtocolExtensionError {
fn from(value: FromUtf8Error) -> Self {
PasswordProtocolExtensionError::Utf8Error(value)
}
}
impl From<PasswordProtocolExtension> for AnyProtocolExtension { impl From<PasswordProtocolExtension> for AnyProtocolExtension {
fn from(value: PasswordProtocolExtension) -> Self { fn from(value: PasswordProtocolExtension) -> Self {
AnyProtocolExtension(Box::new(value)) AnyProtocolExtension(Box::new(value))
@ -212,20 +180,17 @@ impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder {
return Err(WispError::PacketTooSmall); return Err(WispError::PacketTooSmall);
} }
use PasswordProtocolExtensionError as EError;
let username = let username =
String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec()) std::str::from_utf8(&payload.split_to(username_len as usize))?.to_string();
.map_err(|x| WispError::from(EError::from(x)))?;
let password = let password =
String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec()) std::str::from_utf8(&payload.split_to(password_len as usize))?.to_string();
.map_err(|x| WispError::from(EError::from(x)))?;
let Some(user) = self.users.iter().find(|x| *x.0 == username) else { let Some(user) = self.users.iter().find(|x| *x.0 == username) else {
return Err(EError::InvalidUsername.into()); return Err(WispError::PasswordExtensionCredsInvalid);
}; };
if *user.1 != password { if *user.1 != password {
return Err(EError::InvalidPassword.into()); return Err(WispError::PasswordExtensionCredsInvalid);
} }
Ok(PasswordProtocolExtension { Ok(PasswordProtocolExtension {

View file

@ -68,6 +68,7 @@ pub enum WispError {
IncompatibleProtocolVersion, IncompatibleProtocolVersion,
/// The stream had already been closed. /// The stream had already been closed.
StreamAlreadyClosed, StreamAlreadyClosed,
/// The websocket frame received had an invalid type. /// The websocket frame received had an invalid type.
WsFrameInvalidType, WsFrameInvalidType,
/// The websocket frame received was not finished. /// The websocket frame received was not finished.
@ -78,18 +79,14 @@ pub enum WispError {
WsImplSocketClosed, WsImplSocketClosed,
/// The websocket implementation did not support the action. /// The websocket implementation did not support the action.
WsImplNotSupported, WsImplNotSupported,
/// Error specific to the protocol extension implementation.
ExtensionImplError(Box<dyn std::error::Error + Sync + Send>),
/// The protocol extension implementation did not support the action.
ExtensionImplNotSupported,
/// The specified protocol extensions are not supported by the server.
ExtensionsNotSupported(Vec<u8>),
/// The string was invalid UTF-8. /// The string was invalid UTF-8.
Utf8Error(std::str::Utf8Error), Utf8Error(std::str::Utf8Error),
/// The integer failed to convert. /// The integer failed to convert.
TryFromIntError(std::num::TryFromIntError), TryFromIntError(std::num::TryFromIntError),
/// Other error. /// Other error.
Other(Box<dyn std::error::Error + Sync + Send>), Other(Box<dyn std::error::Error + Sync + Send>),
/// Failed to send message to multiplexor task. /// Failed to send message to multiplexor task.
MuxMessageFailedToSend, MuxMessageFailedToSend,
/// Failed to receive message from multiplexor task. /// Failed to receive message from multiplexor task.
@ -98,6 +95,17 @@ pub enum WispError {
MuxTaskEnded, MuxTaskEnded,
/// Multiplexor task already started. /// Multiplexor task already started.
MuxTaskStarted, MuxTaskStarted,
/// Error specific to the protocol extension implementation.
ExtensionImplError(Box<dyn std::error::Error + Sync + Send>),
/// The protocol extension implementation did not support the action.
ExtensionImplNotSupported,
/// The specified protocol extensions are not supported by the server.
ExtensionsNotSupported(Vec<u8>),
/// The password authentication username/password was invalid.
PasswordExtensionCredsInvalid,
/// The certificate authentication signature was invalid.
CertAuthExtensionSigInvalid,
} }
impl From<std::str::Utf8Error> for WispError { impl From<std::str::Utf8Error> for WispError {
@ -153,6 +161,12 @@ impl std::fmt::Display for WispError {
Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"), Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"),
Self::MuxTaskEnded => write!(f, "Multiplexor task ended"), Self::MuxTaskEnded => write!(f, "Multiplexor task ended"),
Self::MuxTaskStarted => write!(f, "Multiplexor task already started"), Self::MuxTaskStarted => write!(f, "Multiplexor task already started"),
Self::PasswordExtensionCredsInvalid => {
write!(f, "Password extension: Invalid username/password")
}
Self::CertAuthExtensionSigInvalid => {
write!(f, "Certificate authentication extension: Invalid signature")
}
} }
} }
} }
@ -243,36 +257,64 @@ impl ServerMux {
W: ws::WebSocketWrite + Send + 'static, W: ws::WebSocketWrite + Send + 'static,
{ {
let tx = ws::LockedWebSocketWrite::new(Box::new(tx)); let tx = ws::LockedWebSocketWrite::new(Box::new(tx));
let ret_tx = tx.clone();
let ret = async {
tx.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
tx.write_frame(Packet::new_continue(0, buffer_size).into()) let (supported_extensions, extra_packet, downgraded) =
.await?; if let Some(mut builders) = extension_builders {
send_info_packet(&tx, &mut builders).await?;
maybe_wisp_v2(&mut rx, &tx, &mut builders).await?
} else {
(Vec::new(), None, true)
};
let (supported_extensions, extra_packet, downgraded) = let (mux_result, muxstream_recv) = MuxInner::new_server(
if let Some(mut builders) = extension_builders { AppendingWebSocketRead(extra_packet, rx),
send_info_packet(&tx, &mut builders).await?; tx.clone(),
maybe_wisp_v2(&mut rx, &tx, &mut builders).await? supported_extensions.clone(),
} else { buffer_size,
(Vec::new(), None, true) );
};
let (mux_result, muxstream_recv) = MuxInner::new_server( Ok(ServerMuxResult(
AppendingWebSocketRead(extra_packet, rx), Self {
tx.clone(), muxstream_recv,
supported_extensions.clone(), actor_tx: mux_result.actor_tx,
buffer_size, downgraded,
); supported_extensions,
tx,
actor_exited: mux_result.actor_exited,
},
mux_result.mux.into_future(),
))
}
.await;
Ok(ServerMuxResult( match ret {
Self { Ok(x) => Ok(x),
muxstream_recv, Err(x) => match x {
actor_tx: mux_result.actor_tx, WispError::PasswordExtensionCredsInvalid => {
downgraded, ret_tx
supported_extensions, .write_frame(
tx, Packet::new_close(0, CloseReason::ExtensionsPasswordAuthFailed).into(),
actor_exited: mux_result.actor_exited, )
.await?;
ret_tx.close().await?;
Err(x)
}
WispError::CertAuthExtensionSigInvalid => {
ret_tx
.write_frame(
Packet::new_close(0, CloseReason::ExtensionsCertAuthFailed).into(),
)
.await?;
ret_tx.close().await?;
Err(x)
}
x => Err(x),
}, },
mux_result.mux.into_future(), }
))
} }
/// Wait for a stream to be created. /// Wait for a stream to be created.
@ -300,12 +342,11 @@ impl ServerMux {
self.close_internal(None).await self.close_internal(None).await
} }
/// Close all streams and send an extension incompatibility error to the client. /// Close all streams and send a close reason on stream ID 0.
/// ///
/// Also terminates the multiplexor future. /// Also terminates the multiplexor future.
pub async fn close_extension_incompat(&self) -> Result<(), WispError> { pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions)) self.close_internal(Some(reason)).await
.await
} }
/// Get a protocol extension stream for sending packets with stream id 0. /// Get a protocol extension stream for sending packets with stream id 0.
@ -346,18 +387,23 @@ where
) -> Result<(ServerMux, F), WispError> { ) -> Result<(ServerMux, F), WispError> {
let mut unsupported_extensions = Vec::new(); let mut unsupported_extensions = Vec::new();
for extension in extensions { for extension in extensions {
if !self.0.supported_extensions.iter().any(|x| x.get_id() == *extension) { if !self
.0
.supported_extensions
.iter()
.any(|x| x.get_id() == *extension)
{
unsupported_extensions.push(*extension); unsupported_extensions.push(*extension);
} }
} }
if unsupported_extensions.is_empty() { if unsupported_extensions.is_empty() {
Ok((self.0, self.1)) Ok((self.0, self.1))
} else { } else {
self.0.close_extension_incompat().await?; self.0
.close_with_reason(CloseReason::ExtensionsIncompatible)
.await?;
self.1.await?; self.1.await?;
Err(WispError::ExtensionsNotSupported( Err(WispError::ExtensionsNotSupported(unsupported_extensions))
unsupported_extensions,
))
} }
} }
@ -483,12 +529,11 @@ impl ClientMux {
self.close_internal(None).await self.close_internal(None).await
} }
/// Close all streams and send an extension incompatibility error to the client. /// Close all streams and send a close reason on stream ID 0.
/// ///
/// Also terminates the multiplexor future. /// Also terminates the multiplexor future.
pub async fn close_extension_incompat(&self) -> Result<(), WispError> { pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions)) self.close_internal(Some(reason)).await
.await
} }
/// Get a protocol extension stream for sending packets with stream id 0. /// Get a protocol extension stream for sending packets with stream id 0.
@ -528,14 +573,21 @@ where
) -> Result<(ClientMux, F), WispError> { ) -> Result<(ClientMux, F), WispError> {
let mut unsupported_extensions = Vec::new(); let mut unsupported_extensions = Vec::new();
for extension in extensions { for extension in extensions {
if !self.0.supported_extensions.iter().any(|x| x.get_id() == *extension) { if !self
.0
.supported_extensions
.iter()
.any(|x| x.get_id() == *extension)
{
unsupported_extensions.push(*extension); unsupported_extensions.push(*extension);
} }
} }
if unsupported_extensions.is_empty() { if unsupported_extensions.is_empty() {
Ok((self.0, self.1)) Ok((self.0, self.1))
} else { } else {
self.0.close_extension_incompat().await?; self.0
.close_with_reason(CloseReason::ExtensionsIncompatible)
.await?;
self.1.await?; self.1.await?;
Err(WispError::ExtensionsNotSupported(unsupported_extensions)) Err(WispError::ExtensionsNotSupported(unsupported_extensions))
} }

View file

@ -60,7 +60,7 @@ mod close {
/// Unexpected stream closure due to a network error. /// Unexpected stream closure due to a network error.
Unexpected = 0x03, Unexpected = 0x03,
/// Incompatible extensions. Only used during the handshake. /// Incompatible extensions. Only used during the handshake.
IncompatibleExtensions = 0x04, ExtensionsIncompatible = 0x04,
/// Stream creation failed due to invalid information. /// Stream creation failed due to invalid information.
ServerStreamInvalidInfo = 0x41, ServerStreamInvalidInfo = 0x41,
/// Stream creation failed due to an unreachable destination host. /// Stream creation failed due to an unreachable destination host.
@ -77,6 +77,10 @@ mod close {
ServerStreamThrottled = 0x49, ServerStreamThrottled = 0x49,
/// The client has encountered an unexpected error. /// The client has encountered an unexpected error.
ClientUnexpected = 0x81, ClientUnexpected = 0x81,
/// Authentication failed due to invalid username/password.
ExtensionsPasswordAuthFailed = 0xc0,
/// Authentication failed due to invalid signature.
ExtensionsCertAuthFailed = 0xc1,
} }
impl TryFrom<u8> for CloseReason { impl TryFrom<u8> for CloseReason {
@ -87,7 +91,7 @@ mod close {
0x01 => Ok(R::Unknown), 0x01 => Ok(R::Unknown),
0x02 => Ok(R::Voluntary), 0x02 => Ok(R::Voluntary),
0x03 => Ok(R::Unexpected), 0x03 => Ok(R::Unexpected),
0x04 => Ok(R::IncompatibleExtensions), 0x04 => Ok(R::ExtensionsIncompatible),
0x41 => Ok(R::ServerStreamInvalidInfo), 0x41 => Ok(R::ServerStreamInvalidInfo),
0x42 => Ok(R::ServerStreamUnreachable), 0x42 => Ok(R::ServerStreamUnreachable),
0x43 => Ok(R::ServerStreamConnectionTimedOut), 0x43 => Ok(R::ServerStreamConnectionTimedOut),
@ -111,7 +115,7 @@ mod close {
C::Unknown => "Unknown close reason", C::Unknown => "Unknown close reason",
C::Voluntary => "Voluntarily closed", C::Voluntary => "Voluntarily closed",
C::Unexpected => "Unexpectedly closed", C::Unexpected => "Unexpectedly closed",
C::IncompatibleExtensions => "Incompatible protocol extensions", C::ExtensionsIncompatible => "Incompatible protocol extensions",
C::ServerStreamInvalidInfo => C::ServerStreamInvalidInfo =>
"Stream creation failed due to invalid information", "Stream creation failed due to invalid information",
C::ServerStreamUnreachable => C::ServerStreamUnreachable =>
@ -124,6 +128,8 @@ mod close {
C::ServerStreamBlockedAddress => "Destination address is blocked", C::ServerStreamBlockedAddress => "Destination address is blocked",
C::ServerStreamThrottled => "Throttled", C::ServerStreamThrottled => "Throttled",
C::ClientUnexpected => "Client encountered unexpected error", C::ClientUnexpected => "Client encountered unexpected error",
C::ExtensionsPasswordAuthFailed => "Invalid username/password",
C::ExtensionsCertAuthFailed => "Invalid signature",
} }
) )
} }