use thiserror and add new close reason

This commit is contained in:
Toshit Chawda 2024-10-23 22:49:46 -07:00
parent 065de8e85f
commit 65a7904437
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
4 changed files with 78 additions and 96 deletions

9
Cargo.lock generated
View file

@ -2210,18 +2210,18 @@ checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394"
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.64" version = "1.0.65"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5"
dependencies = [ dependencies = [
"thiserror-impl", "thiserror-impl",
] ]
[[package]] [[package]]
name = "thiserror-impl" name = "thiserror-impl"
version = "1.0.64" version = "1.0.65"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -3015,6 +3015,7 @@ dependencies = [
"getrandom", "getrandom",
"nohash-hasher", "nohash-hasher",
"pin-project-lite", "pin-project-lite",
"thiserror",
"tokio", "tokio",
] ]

View file

@ -24,6 +24,7 @@ futures-timer = "3.0.3"
getrandom = { version = "0.2.15", features = ["std"], optional = true } getrandom = { version = "0.2.15", features = ["std"], optional = true }
nohash-hasher = "0.2.0" nohash-hasher = "0.2.0"
pin-project-lite = "0.2.14" pin-project-lite = "0.2.14"
thiserror = "1.0.65"
tokio = { version = "1.39.3", optional = true, default-features = false } tokio = { version = "1.39.3", optional = true, default-features = false }
[features] [features]

View file

@ -33,6 +33,7 @@ use std::{
}, },
time::Duration, time::Duration,
}; };
use thiserror::Error;
use ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload}; use ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload};
/// Wisp version supported by this crate. /// Wisp version supported by this crate.
@ -48,133 +49,86 @@ pub enum Role {
} }
/// Errors the Wisp implementation can return. /// Errors the Wisp implementation can return.
#[derive(Debug)] #[derive(Error, Debug)]
pub enum WispError { pub enum WispError {
/// The packet received did not have enough data. /// The packet received did not have enough data.
#[error("Packet too small")]
PacketTooSmall, PacketTooSmall,
/// The packet received had an invalid type. /// The packet received had an invalid type.
#[error("Invalid packet type")]
InvalidPacketType, InvalidPacketType,
/// The stream had an invalid ID. /// The stream had an invalid ID.
#[error("Invalid steam ID")]
InvalidStreamId, InvalidStreamId,
/// The close packet had an invalid reason. /// The close packet had an invalid reason.
#[error("Invalid close reason")]
InvalidCloseReason, InvalidCloseReason,
/// The URI received was invalid.
InvalidUri,
/// The URI received had no host.
UriHasNoHost,
/// The URI received had no port.
UriHasNoPort,
/// The max stream count was reached. /// The max stream count was reached.
#[error("Maximum stream count reached")]
MaxStreamCountReached, MaxStreamCountReached,
/// The Wisp protocol version was incompatible. /// The Wisp protocol version was incompatible.
IncompatibleProtocolVersion, #[error("Incompatible Wisp protocol version: found {0} but needed {1}")]
IncompatibleProtocolVersion(WispVersion, WispVersion),
/// The stream had already been closed. /// The stream had already been closed.
#[error("Stream already closed")]
StreamAlreadyClosed, StreamAlreadyClosed,
/// The websocket frame received had an invalid type. /// The websocket frame received had an invalid type.
#[error("Invalid websocket frame type: {0:?}")]
WsFrameInvalidType(ws::OpCode), WsFrameInvalidType(ws::OpCode),
/// The websocket frame received was not finished. /// The websocket frame received was not finished.
#[error("Unfinished websocket frame")]
WsFrameNotFinished, WsFrameNotFinished,
/// Error specific to the websocket implementation. /// Error specific to the websocket implementation.
#[error("Websocket implementation error:")]
WsImplError(Box<dyn std::error::Error + Sync + Send>), WsImplError(Box<dyn std::error::Error + Sync + Send>),
/// The websocket implementation socket closed. /// The websocket implementation socket closed.
#[error("Websocket implementation error: socket closed")]
WsImplSocketClosed, WsImplSocketClosed,
/// The websocket implementation did not support the action. /// The websocket implementation did not support the action.
#[error("Websocket implementation error: not supported")]
WsImplNotSupported, WsImplNotSupported,
/// The string was invalid UTF-8. /// The string was invalid UTF-8.
Utf8Error(std::str::Utf8Error), #[error("UTF-8 error: {0}")]
Utf8Error(#[from] std::str::Utf8Error),
/// The integer failed to convert. /// The integer failed to convert.
TryFromIntError(std::num::TryFromIntError), #[error("Integer conversion error: {0}")]
TryFromIntError(#[from] std::num::TryFromIntError),
/// Other error. /// Other error.
#[error("Other: {0:?}")]
Other(Box<dyn std::error::Error + Sync + Send>), Other(Box<dyn std::error::Error + Sync + Send>),
/// Failed to send message to multiplexor task. /// Failed to send message to multiplexor task.
#[error("Failed to send multiplexor message")]
MuxMessageFailedToSend, MuxMessageFailedToSend,
/// Failed to receive message from multiplexor task. /// Failed to receive message from multiplexor task.
#[error("Failed to receive multiplexor message")]
MuxMessageFailedToRecv, MuxMessageFailedToRecv,
/// Multiplexor task ended. /// Multiplexor task ended.
#[error("Multiplexor task ended")]
MuxTaskEnded, MuxTaskEnded,
/// Multiplexor task already started. /// Multiplexor task already started.
#[error("Multiplexor task already started")]
MuxTaskStarted, MuxTaskStarted,
/// Error specific to the protocol extension implementation. /// Error specific to the protocol extension implementation.
#[error("Protocol extension implementation error: {0:?}")]
ExtensionImplError(Box<dyn std::error::Error + Sync + Send>), ExtensionImplError(Box<dyn std::error::Error + Sync + Send>),
/// The protocol extension implementation did not support the action. /// The protocol extension implementation did not support the action.
#[error("Protocol extension implementation error: unsupported feature")]
ExtensionImplNotSupported, ExtensionImplNotSupported,
/// The specified protocol extensions are not supported by the server. /// The specified protocol extensions are not supported by the other side.
#[error("Protocol extensions {0:?} not supported")]
ExtensionsNotSupported(Vec<u8>), ExtensionsNotSupported(Vec<u8>),
/// The password authentication username/password was invalid. /// The password authentication username/password was invalid.
#[error("Password protocol extension: Invalid username/password")]
PasswordExtensionCredsInvalid, PasswordExtensionCredsInvalid,
/// The certificate authentication signature was invalid. /// The certificate authentication signature was invalid.
#[error("Certificate authentication protocol extension: Invalid signature")]
CertAuthExtensionSigInvalid, CertAuthExtensionSigInvalid,
} }
impl From<std::str::Utf8Error> for WispError {
fn from(err: std::str::Utf8Error) -> Self {
Self::Utf8Error(err)
}
}
impl From<std::num::TryFromIntError> for WispError {
fn from(value: std::num::TryFromIntError) -> Self {
Self::TryFromIntError(value)
}
}
impl std::fmt::Display for WispError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
match self {
Self::PacketTooSmall => write!(f, "Packet too small"),
Self::InvalidPacketType => write!(f, "Invalid packet type"),
Self::InvalidStreamId => write!(f, "Invalid stream id"),
Self::InvalidCloseReason => write!(f, "Invalid close reason"),
Self::InvalidUri => write!(f, "Invalid URI"),
Self::UriHasNoHost => write!(f, "URI has no host"),
Self::UriHasNoPort => write!(f, "URI has no port"),
Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
Self::IncompatibleProtocolVersion => write!(f, "Incompatible Wisp protocol version"),
Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
Self::WsFrameInvalidType(ty) => write!(f, "Invalid websocket frame type: {:?}", ty),
Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
Self::WsImplError(err) => write!(f, "Websocket implementation error: {}", err),
Self::WsImplSocketClosed => {
write!(f, "Websocket implementation error: websocket closed")
}
Self::WsImplNotSupported => {
write!(f, "Websocket implementation error: unsupported feature")
}
Self::ExtensionImplError(err) => {
write!(f, "Protocol extension implementation error: {}", err)
}
Self::ExtensionImplNotSupported => {
write!(
f,
"Protocol extension implementation error: unsupported feature"
)
}
Self::ExtensionsNotSupported(list) => {
write!(f, "Protocol extensions {:?} not supported", list)
}
Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
Self::TryFromIntError(err) => write!(f, "Integer conversion error: {}", err),
Self::Other(err) => write!(f, "Other error: {}", err),
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"),
Self::MuxTaskEnded => write!(f, "Multiplexor task ended"),
Self::MuxTaskStarted => write!(f, "Multiplexor task already started"),
Self::PasswordExtensionCredsInvalid => {
write!(f, "Password extension: Invalid username/password")
}
Self::CertAuthExtensionSigInvalid => {
write!(f, "Certificate authentication extension: Invalid signature")
}
}
}
}
impl std::error::Error for WispError {}
async fn maybe_wisp_v2<R>( async fn maybe_wisp_v2<R>(
read: &mut R, read: &mut R,
write: &LockedWebSocketWrite, write: &LockedWebSocketWrite,

View file

@ -1,3 +1,5 @@
use std::fmt::Display;
use crate::{ use crate::{
extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder}, extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder},
ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead}, ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
@ -59,8 +61,9 @@ mod close {
Voluntary = 0x02, Voluntary = 0x02,
/// Unexpected stream closure due to a network error. /// Unexpected stream closure due to a network error.
Unexpected = 0x03, Unexpected = 0x03,
/// Incompatible extensions. Only used during the handshake. /// Incompatible extensions.
ExtensionsIncompatible = 0x04, ExtensionsIncompatible = 0x04,
/// Stream creation failed due to invalid information. /// Stream creation failed due to invalid information.
ServerStreamInvalidInfo = 0x41, ServerStreamInvalidInfo = 0x41,
/// Stream creation failed due to an unreachable destination host. /// Stream creation failed due to an unreachable destination host.
@ -75,31 +78,41 @@ mod close {
ServerStreamBlockedAddress = 0x48, ServerStreamBlockedAddress = 0x48,
/// Connection throttled by the server. /// Connection throttled by the server.
ServerStreamThrottled = 0x49, ServerStreamThrottled = 0x49,
/// The client has encountered an unexpected error.
/// The client has encountered an unexpected error and is unable to recieve any more data.
ClientUnexpected = 0x81, ClientUnexpected = 0x81,
/// Authentication failed due to invalid username/password. /// Authentication failed due to invalid username/password.
ExtensionsPasswordAuthFailed = 0xc0, ExtensionsPasswordAuthFailed = 0xc0,
/// Authentication failed due to invalid signature. /// Authentication failed due to invalid signature.
ExtensionsCertAuthFailed = 0xc1, ExtensionsCertAuthFailed = 0xc1,
/// Authentication required but the client did not provide credentials.
ExtensionsAuthRequired = 0xc2,
} }
impl TryFrom<u8> for CloseReason { impl TryFrom<u8> for CloseReason {
type Error = WispError; type Error = WispError;
fn try_from(close_reason: u8) -> Result<Self, Self::Error> { fn try_from(close_reason: u8) -> Result<Self, Self::Error> {
use CloseReason as R;
match close_reason { match close_reason {
0x01 => Ok(R::Unknown), 0x01 => Ok(Self::Unknown),
0x02 => Ok(R::Voluntary), 0x02 => Ok(Self::Voluntary),
0x03 => Ok(R::Unexpected), 0x03 => Ok(Self::Unexpected),
0x04 => Ok(R::ExtensionsIncompatible), 0x04 => Ok(Self::ExtensionsIncompatible),
0x41 => Ok(R::ServerStreamInvalidInfo),
0x42 => Ok(R::ServerStreamUnreachable), 0x41 => Ok(Self::ServerStreamInvalidInfo),
0x43 => Ok(R::ServerStreamConnectionTimedOut), 0x42 => Ok(Self::ServerStreamUnreachable),
0x44 => Ok(R::ServerStreamConnectionRefused), 0x43 => Ok(Self::ServerStreamConnectionTimedOut),
0x47 => Ok(R::ServerStreamTimedOut), 0x44 => Ok(Self::ServerStreamConnectionRefused),
0x48 => Ok(R::ServerStreamBlockedAddress), 0x47 => Ok(Self::ServerStreamTimedOut),
0x49 => Ok(R::ServerStreamThrottled), 0x48 => Ok(Self::ServerStreamBlockedAddress),
0x81 => Ok(R::ClientUnexpected), 0x49 => Ok(Self::ServerStreamThrottled),
0x81 => Ok(Self::ClientUnexpected),
0xc0 => Ok(Self::ExtensionsPasswordAuthFailed),
0xc1 => Ok(Self::ExtensionsCertAuthFailed),
0xc2 => Ok(Self::ExtensionsAuthRequired),
_ => Err(Self::Error::InvalidCloseReason), _ => Err(Self::Error::InvalidCloseReason),
} }
} }
@ -116,6 +129,7 @@ mod close {
C::Voluntary => "Voluntarily closed", C::Voluntary => "Voluntarily closed",
C::Unexpected => "Unexpectedly closed", C::Unexpected => "Unexpectedly closed",
C::ExtensionsIncompatible => "Incompatible protocol extensions", C::ExtensionsIncompatible => "Incompatible protocol extensions",
C::ServerStreamInvalidInfo => C::ServerStreamInvalidInfo =>
"Stream creation failed due to invalid information", "Stream creation failed due to invalid information",
C::ServerStreamUnreachable => C::ServerStreamUnreachable =>
@ -127,9 +141,12 @@ mod close {
C::ServerStreamTimedOut => "TCP timed out", C::ServerStreamTimedOut => "TCP timed out",
C::ServerStreamBlockedAddress => "Destination address is blocked", C::ServerStreamBlockedAddress => "Destination address is blocked",
C::ServerStreamThrottled => "Throttled", C::ServerStreamThrottled => "Throttled",
C::ClientUnexpected => "Client encountered unexpected error", C::ClientUnexpected => "Client encountered unexpected error",
C::ExtensionsPasswordAuthFailed => "Invalid username/password", C::ExtensionsPasswordAuthFailed => "Invalid username/password",
C::ExtensionsCertAuthFailed => "Invalid signature", C::ExtensionsCertAuthFailed => "Invalid signature",
C::ExtensionsAuthRequired => "Authentication required",
} }
) )
} }
@ -271,6 +288,12 @@ pub struct WispVersion {
pub minor: u8, pub minor: u8,
} }
impl Display for WispVersion {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}.{}", self.major, self.minor)
}
}
/// Packet used in the initial handshake. /// Packet used in the initial handshake.
/// ///
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x05---info) /// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x05---info)
@ -450,7 +473,10 @@ impl<'a> Packet<'a> {
}; };
if version.major != WISP_VERSION.major { if version.major != WISP_VERSION.major {
return Err(WispError::IncompatibleProtocolVersion); return Err(WispError::IncompatibleProtocolVersion(
version,
WISP_VERSION,
));
} }
let mut extensions = Vec::new(); let mut extensions = Vec::new();