move pw and cert auth errors out

This commit is contained in:
Toshit Chawda 2024-09-14 17:11:22 -07:00
parent 694d87f731
commit 01ff6ee956
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
4 changed files with 137 additions and 99 deletions

View file

@ -68,6 +68,7 @@ pub enum WispError {
IncompatibleProtocolVersion,
/// The stream had already been closed.
StreamAlreadyClosed,
/// The websocket frame received had an invalid type.
WsFrameInvalidType,
/// The websocket frame received was not finished.
@ -78,18 +79,14 @@ pub enum WispError {
WsImplSocketClosed,
/// The websocket implementation did not support the action.
WsImplNotSupported,
/// Error specific to the protocol extension implementation.
ExtensionImplError(Box<dyn std::error::Error + Sync + Send>),
/// The protocol extension implementation did not support the action.
ExtensionImplNotSupported,
/// The specified protocol extensions are not supported by the server.
ExtensionsNotSupported(Vec<u8>),
/// The string was invalid UTF-8.
Utf8Error(std::str::Utf8Error),
/// The integer failed to convert.
TryFromIntError(std::num::TryFromIntError),
/// Other error.
Other(Box<dyn std::error::Error + Sync + Send>),
/// Failed to send message to multiplexor task.
MuxMessageFailedToSend,
/// Failed to receive message from multiplexor task.
@ -98,6 +95,17 @@ pub enum WispError {
MuxTaskEnded,
/// Multiplexor task already started.
MuxTaskStarted,
/// Error specific to the protocol extension implementation.
ExtensionImplError(Box<dyn std::error::Error + Sync + Send>),
/// The protocol extension implementation did not support the action.
ExtensionImplNotSupported,
/// The specified protocol extensions are not supported by the server.
ExtensionsNotSupported(Vec<u8>),
/// The password authentication username/password was invalid.
PasswordExtensionCredsInvalid,
/// The certificate authentication signature was invalid.
CertAuthExtensionSigInvalid,
}
impl From<std::str::Utf8Error> for WispError {
@ -153,6 +161,12 @@ impl std::fmt::Display for WispError {
Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"),
Self::MuxTaskEnded => write!(f, "Multiplexor task ended"),
Self::MuxTaskStarted => write!(f, "Multiplexor task already started"),
Self::PasswordExtensionCredsInvalid => {
write!(f, "Password extension: Invalid username/password")
}
Self::CertAuthExtensionSigInvalid => {
write!(f, "Certificate authentication extension: Invalid signature")
}
}
}
}
@ -243,36 +257,64 @@ impl ServerMux {
W: ws::WebSocketWrite + Send + 'static,
{
let tx = ws::LockedWebSocketWrite::new(Box::new(tx));
let ret_tx = tx.clone();
let ret = async {
tx.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
tx.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
let (supported_extensions, extra_packet, downgraded) =
if let Some(mut builders) = extension_builders {
send_info_packet(&tx, &mut builders).await?;
maybe_wisp_v2(&mut rx, &tx, &mut builders).await?
} else {
(Vec::new(), None, true)
};
let (supported_extensions, extra_packet, downgraded) =
if let Some(mut builders) = extension_builders {
send_info_packet(&tx, &mut builders).await?;
maybe_wisp_v2(&mut rx, &tx, &mut builders).await?
} else {
(Vec::new(), None, true)
};
let (mux_result, muxstream_recv) = MuxInner::new_server(
AppendingWebSocketRead(extra_packet, rx),
tx.clone(),
supported_extensions.clone(),
buffer_size,
);
let (mux_result, muxstream_recv) = MuxInner::new_server(
AppendingWebSocketRead(extra_packet, rx),
tx.clone(),
supported_extensions.clone(),
buffer_size,
);
Ok(ServerMuxResult(
Self {
muxstream_recv,
actor_tx: mux_result.actor_tx,
downgraded,
supported_extensions,
tx,
actor_exited: mux_result.actor_exited,
},
mux_result.mux.into_future(),
))
}
.await;
Ok(ServerMuxResult(
Self {
muxstream_recv,
actor_tx: mux_result.actor_tx,
downgraded,
supported_extensions,
tx,
actor_exited: mux_result.actor_exited,
match ret {
Ok(x) => Ok(x),
Err(x) => match x {
WispError::PasswordExtensionCredsInvalid => {
ret_tx
.write_frame(
Packet::new_close(0, CloseReason::ExtensionsPasswordAuthFailed).into(),
)
.await?;
ret_tx.close().await?;
Err(x)
}
WispError::CertAuthExtensionSigInvalid => {
ret_tx
.write_frame(
Packet::new_close(0, CloseReason::ExtensionsCertAuthFailed).into(),
)
.await?;
ret_tx.close().await?;
Err(x)
}
x => Err(x),
},
mux_result.mux.into_future(),
))
}
}
/// Wait for a stream to be created.
@ -300,12 +342,11 @@ impl ServerMux {
self.close_internal(None).await
}
/// Close all streams and send an extension incompatibility error to the client.
/// Close all streams and send a close reason on stream ID 0.
///
/// Also terminates the multiplexor future.
pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await
pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
self.close_internal(Some(reason)).await
}
/// Get a protocol extension stream for sending packets with stream id 0.
@ -346,18 +387,23 @@ where
) -> Result<(ServerMux, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self.0.supported_extensions.iter().any(|x| x.get_id() == *extension) {
if !self
.0
.supported_extensions
.iter()
.any(|x| x.get_id() == *extension)
{
unsupported_extensions.push(*extension);
}
}
if unsupported_extensions.is_empty() {
Ok((self.0, self.1))
} else {
self.0.close_extension_incompat().await?;
self.0
.close_with_reason(CloseReason::ExtensionsIncompatible)
.await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(
unsupported_extensions,
))
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}
}
@ -483,12 +529,11 @@ impl ClientMux {
self.close_internal(None).await
}
/// Close all streams and send an extension incompatibility error to the client.
/// Close all streams and send a close reason on stream ID 0.
///
/// Also terminates the multiplexor future.
pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await
pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
self.close_internal(Some(reason)).await
}
/// Get a protocol extension stream for sending packets with stream id 0.
@ -528,14 +573,21 @@ where
) -> Result<(ClientMux, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self.0.supported_extensions.iter().any(|x| x.get_id() == *extension) {
if !self
.0
.supported_extensions
.iter()
.any(|x| x.get_id() == *extension)
{
unsupported_extensions.push(*extension);
}
}
if unsupported_extensions.is_empty() {
Ok((self.0, self.1))
} else {
self.0.close_extension_incompat().await?;
self.0
.close_with_reason(CloseReason::ExtensionsIncompatible)
.await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}