move pw and cert auth errors out

This commit is contained in:
Toshit Chawda 2024-09-14 17:11:22 -07:00
parent 694d87f731
commit 01ff6ee956
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
4 changed files with 137 additions and 99 deletions

View file

@ -20,10 +20,6 @@ use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
/// Certificate authentication protocol extension error.
#[derive(Debug)]
pub enum CertAuthError {
/// Invalid or unsupported certificate type
InvalidCertType,
/// Invalid signature
InvalidSignature,
/// ED25519 error
Ed25519(ed25519::Error),
/// Getrandom error
@ -33,8 +29,6 @@ pub enum CertAuthError {
impl std::fmt::Display for CertAuthError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::InvalidCertType => write!(f, "Invalid or unsupported certificate type"),
Self::InvalidSignature => write!(f, "Invalid signature"),
Self::Ed25519(x) => write!(f, "ED25519: {:?}", x),
Self::Getrandom(x) => write!(f, "getrandom: {:?}", x),
}
@ -78,6 +72,17 @@ pub struct VerifyKey {
pub verifier: Arc<dyn Verifier<Signature>>,
}
impl VerifyKey {
/// Create a new ED25519 verification key.
pub fn new_ed25519(verifier: Arc<dyn Verifier<Signature>>, hash: [u8; 64]) -> Self {
Self {
cert_type: SupportedCertificateTypes::Ed25519,
hash,
verifier,
}
}
}
/// Signing key.
#[derive(Clone)]
pub struct SigningKey {
@ -88,6 +93,16 @@ pub struct SigningKey {
/// Signer.
pub signer: Arc<dyn Signer<Signature>>,
}
impl SigningKey {
/// Create a new ED25519 signing key.
pub fn new_ed25519(signer: Arc<dyn Signer<Signature>>, hash: [u8; 64]) -> Self {
Self {
cert_type: SupportedCertificateTypes::Ed25519,
hash,
signer,
}
}
}
/// Certificate authentication protocol extension.
#[derive(Debug, Clone)]
@ -241,7 +256,7 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
} => {
// validate and parse response
let cert_type = SupportedCertificateTypes::from_bits(bytes.get_u8())
.ok_or(CertAuthError::InvalidCertType)?;
.ok_or(WispError::CertAuthExtensionSigInvalid)?;
let hash = bytes.split_to(64);
let sig = Signature::from_slice(&bytes).map_err(CertAuthError::from)?;
let is_valid = verifiers
@ -252,15 +267,15 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
if is_valid {
Ok(CertAuthProtocolExtension::ServerVerified.into())
} else {
Err(CertAuthError::InvalidSignature.into())
Err(WispError::CertAuthExtensionSigInvalid)
}
}
Self::ClientBeforeChallenge { signer } => {
// sign challenge
let cert_types = SupportedCertificateTypes::from_bits(bytes.get_u8())
.ok_or(CertAuthError::InvalidCertType)?;
.ok_or(WispError::CertAuthExtensionSigInvalid)?;
if !cert_types.iter().any(|x| x == signer.cert_type) {
return Err(CertAuthError::InvalidCertType.into());
return Err(WispError::CertAuthExtensionSigInvalid);
}
let signed: Bytes = signer

View file

@ -4,7 +4,7 @@
//!
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication)
use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error};
use std::collections::HashMap;
use async_trait::async_trait;
use bytes::{Buf, BufMut, Bytes, BytesMut};
@ -118,38 +118,6 @@ impl ProtocolExtension for PasswordProtocolExtension {
}
}
#[derive(Debug)]
enum PasswordProtocolExtensionError {
Utf8Error(FromUtf8Error),
InvalidUsername,
InvalidPassword,
}
impl Display for PasswordProtocolExtensionError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use PasswordProtocolExtensionError as E;
match self {
E::Utf8Error(e) => write!(f, "{}", e),
E::InvalidUsername => write!(f, "Invalid username"),
E::InvalidPassword => write!(f, "Invalid password"),
}
}
}
impl Error for PasswordProtocolExtensionError {}
impl From<PasswordProtocolExtensionError> for WispError {
fn from(value: PasswordProtocolExtensionError) -> Self {
WispError::ExtensionImplError(Box::new(value))
}
}
impl From<FromUtf8Error> for PasswordProtocolExtensionError {
fn from(value: FromUtf8Error) -> Self {
PasswordProtocolExtensionError::Utf8Error(value)
}
}
impl From<PasswordProtocolExtension> for AnyProtocolExtension {
fn from(value: PasswordProtocolExtension) -> Self {
AnyProtocolExtension(Box::new(value))
@ -212,20 +180,17 @@ impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder {
return Err(WispError::PacketTooSmall);
}
use PasswordProtocolExtensionError as EError;
let username =
String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec())
.map_err(|x| WispError::from(EError::from(x)))?;
std::str::from_utf8(&payload.split_to(username_len as usize))?.to_string();
let password =
String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec())
.map_err(|x| WispError::from(EError::from(x)))?;
std::str::from_utf8(&payload.split_to(password_len as usize))?.to_string();
let Some(user) = self.users.iter().find(|x| *x.0 == username) else {
return Err(EError::InvalidUsername.into());
return Err(WispError::PasswordExtensionCredsInvalid);
};
if *user.1 != password {
return Err(EError::InvalidPassword.into());
return Err(WispError::PasswordExtensionCredsInvalid);
}
Ok(PasswordProtocolExtension {

View file

@ -68,6 +68,7 @@ pub enum WispError {
IncompatibleProtocolVersion,
/// The stream had already been closed.
StreamAlreadyClosed,
/// The websocket frame received had an invalid type.
WsFrameInvalidType,
/// The websocket frame received was not finished.
@ -78,18 +79,14 @@ pub enum WispError {
WsImplSocketClosed,
/// The websocket implementation did not support the action.
WsImplNotSupported,
/// Error specific to the protocol extension implementation.
ExtensionImplError(Box<dyn std::error::Error + Sync + Send>),
/// The protocol extension implementation did not support the action.
ExtensionImplNotSupported,
/// The specified protocol extensions are not supported by the server.
ExtensionsNotSupported(Vec<u8>),
/// The string was invalid UTF-8.
Utf8Error(std::str::Utf8Error),
/// The integer failed to convert.
TryFromIntError(std::num::TryFromIntError),
/// Other error.
Other(Box<dyn std::error::Error + Sync + Send>),
/// Failed to send message to multiplexor task.
MuxMessageFailedToSend,
/// Failed to receive message from multiplexor task.
@ -98,6 +95,17 @@ pub enum WispError {
MuxTaskEnded,
/// Multiplexor task already started.
MuxTaskStarted,
/// Error specific to the protocol extension implementation.
ExtensionImplError(Box<dyn std::error::Error + Sync + Send>),
/// The protocol extension implementation did not support the action.
ExtensionImplNotSupported,
/// The specified protocol extensions are not supported by the server.
ExtensionsNotSupported(Vec<u8>),
/// The password authentication username/password was invalid.
PasswordExtensionCredsInvalid,
/// The certificate authentication signature was invalid.
CertAuthExtensionSigInvalid,
}
impl From<std::str::Utf8Error> for WispError {
@ -153,6 +161,12 @@ impl std::fmt::Display for WispError {
Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"),
Self::MuxTaskEnded => write!(f, "Multiplexor task ended"),
Self::MuxTaskStarted => write!(f, "Multiplexor task already started"),
Self::PasswordExtensionCredsInvalid => {
write!(f, "Password extension: Invalid username/password")
}
Self::CertAuthExtensionSigInvalid => {
write!(f, "Certificate authentication extension: Invalid signature")
}
}
}
}
@ -243,36 +257,64 @@ impl ServerMux {
W: ws::WebSocketWrite + Send + 'static,
{
let tx = ws::LockedWebSocketWrite::new(Box::new(tx));
let ret_tx = tx.clone();
let ret = async {
tx.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
tx.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
let (supported_extensions, extra_packet, downgraded) =
if let Some(mut builders) = extension_builders {
send_info_packet(&tx, &mut builders).await?;
maybe_wisp_v2(&mut rx, &tx, &mut builders).await?
} else {
(Vec::new(), None, true)
};
let (supported_extensions, extra_packet, downgraded) =
if let Some(mut builders) = extension_builders {
send_info_packet(&tx, &mut builders).await?;
maybe_wisp_v2(&mut rx, &tx, &mut builders).await?
} else {
(Vec::new(), None, true)
};
let (mux_result, muxstream_recv) = MuxInner::new_server(
AppendingWebSocketRead(extra_packet, rx),
tx.clone(),
supported_extensions.clone(),
buffer_size,
);
let (mux_result, muxstream_recv) = MuxInner::new_server(
AppendingWebSocketRead(extra_packet, rx),
tx.clone(),
supported_extensions.clone(),
buffer_size,
);
Ok(ServerMuxResult(
Self {
muxstream_recv,
actor_tx: mux_result.actor_tx,
downgraded,
supported_extensions,
tx,
actor_exited: mux_result.actor_exited,
},
mux_result.mux.into_future(),
))
}
.await;
Ok(ServerMuxResult(
Self {
muxstream_recv,
actor_tx: mux_result.actor_tx,
downgraded,
supported_extensions,
tx,
actor_exited: mux_result.actor_exited,
match ret {
Ok(x) => Ok(x),
Err(x) => match x {
WispError::PasswordExtensionCredsInvalid => {
ret_tx
.write_frame(
Packet::new_close(0, CloseReason::ExtensionsPasswordAuthFailed).into(),
)
.await?;
ret_tx.close().await?;
Err(x)
}
WispError::CertAuthExtensionSigInvalid => {
ret_tx
.write_frame(
Packet::new_close(0, CloseReason::ExtensionsCertAuthFailed).into(),
)
.await?;
ret_tx.close().await?;
Err(x)
}
x => Err(x),
},
mux_result.mux.into_future(),
))
}
}
/// Wait for a stream to be created.
@ -300,12 +342,11 @@ impl ServerMux {
self.close_internal(None).await
}
/// Close all streams and send an extension incompatibility error to the client.
/// Close all streams and send a close reason on stream ID 0.
///
/// Also terminates the multiplexor future.
pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await
pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
self.close_internal(Some(reason)).await
}
/// Get a protocol extension stream for sending packets with stream id 0.
@ -346,18 +387,23 @@ where
) -> Result<(ServerMux, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self.0.supported_extensions.iter().any(|x| x.get_id() == *extension) {
if !self
.0
.supported_extensions
.iter()
.any(|x| x.get_id() == *extension)
{
unsupported_extensions.push(*extension);
}
}
if unsupported_extensions.is_empty() {
Ok((self.0, self.1))
} else {
self.0.close_extension_incompat().await?;
self.0
.close_with_reason(CloseReason::ExtensionsIncompatible)
.await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(
unsupported_extensions,
))
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}
}
@ -483,12 +529,11 @@ impl ClientMux {
self.close_internal(None).await
}
/// Close all streams and send an extension incompatibility error to the client.
/// Close all streams and send a close reason on stream ID 0.
///
/// Also terminates the multiplexor future.
pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await
pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
self.close_internal(Some(reason)).await
}
/// Get a protocol extension stream for sending packets with stream id 0.
@ -528,14 +573,21 @@ where
) -> Result<(ClientMux, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self.0.supported_extensions.iter().any(|x| x.get_id() == *extension) {
if !self
.0
.supported_extensions
.iter()
.any(|x| x.get_id() == *extension)
{
unsupported_extensions.push(*extension);
}
}
if unsupported_extensions.is_empty() {
Ok((self.0, self.1))
} else {
self.0.close_extension_incompat().await?;
self.0
.close_with_reason(CloseReason::ExtensionsIncompatible)
.await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}

View file

@ -60,7 +60,7 @@ mod close {
/// Unexpected stream closure due to a network error.
Unexpected = 0x03,
/// Incompatible extensions. Only used during the handshake.
IncompatibleExtensions = 0x04,
ExtensionsIncompatible = 0x04,
/// Stream creation failed due to invalid information.
ServerStreamInvalidInfo = 0x41,
/// Stream creation failed due to an unreachable destination host.
@ -77,6 +77,10 @@ mod close {
ServerStreamThrottled = 0x49,
/// The client has encountered an unexpected error.
ClientUnexpected = 0x81,
/// Authentication failed due to invalid username/password.
ExtensionsPasswordAuthFailed = 0xc0,
/// Authentication failed due to invalid signature.
ExtensionsCertAuthFailed = 0xc1,
}
impl TryFrom<u8> for CloseReason {
@ -87,7 +91,7 @@ mod close {
0x01 => Ok(R::Unknown),
0x02 => Ok(R::Voluntary),
0x03 => Ok(R::Unexpected),
0x04 => Ok(R::IncompatibleExtensions),
0x04 => Ok(R::ExtensionsIncompatible),
0x41 => Ok(R::ServerStreamInvalidInfo),
0x42 => Ok(R::ServerStreamUnreachable),
0x43 => Ok(R::ServerStreamConnectionTimedOut),
@ -111,7 +115,7 @@ mod close {
C::Unknown => "Unknown close reason",
C::Voluntary => "Voluntarily closed",
C::Unexpected => "Unexpectedly closed",
C::IncompatibleExtensions => "Incompatible protocol extensions",
C::ExtensionsIncompatible => "Incompatible protocol extensions",
C::ServerStreamInvalidInfo =>
"Stream creation failed due to invalid information",
C::ServerStreamUnreachable =>
@ -124,6 +128,8 @@ mod close {
C::ServerStreamBlockedAddress => "Destination address is blocked",
C::ServerStreamThrottled => "Throttled",
C::ClientUnexpected => "Client encountered unexpected error",
C::ExtensionsPasswordAuthFailed => "Invalid username/password",
C::ExtensionsCertAuthFailed => "Invalid signature",
}
)
}