fix password protocol extension, respect stream id 0 close packets, allow sending stream id 0 close packets

This commit is contained in:
Toshit Chawda 2024-04-13 22:34:26 -07:00
parent 4d433b60c4
commit d10b7691e4
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
6 changed files with 220 additions and 50 deletions

View file

@ -202,6 +202,8 @@ pub mod udp {
pub mod password {
//! Password protocol extension.
//!
//! Passwords are sent in plain text!!
//!
//! # Example
//! Server:
//! ```
@ -246,6 +248,7 @@ pub mod password {
#[derive(Debug, Clone)]
/// Password protocol extension.
///
/// **Passwords are sent in plain text!!**
/// **This extension will panic when encoding if the username's length does not fit within a u8
/// or the password's length does not fit within a u16.**
pub struct PasswordProtocolExtension {
@ -306,7 +309,7 @@ pub mod password {
let password = Bytes::from(self.password.clone().into_bytes());
let username_len = u8::try_from(username.len()).expect("username was too long");
let password_len =
u16::try_from(username.len()).expect("password was too long");
u16::try_from(password.len()).expect("password was too long");
let mut bytes =
BytesMut::with_capacity(3 + username_len as usize + password_len as usize);
@ -380,6 +383,8 @@ pub mod password {
}
/// Password protocol extension builder.
///
/// **Passwords are sent in plain text!!**
pub struct PasswordProtocolExtensionBuilder {
/// Map of users and their passwords to allow. Only used on server.
pub users: HashMap<String, String>,
@ -448,6 +453,7 @@ pub mod password {
if *user.1 != password {
return Err(EError::InvalidPassword.into());
}
Ok(PasswordProtocolExtension {
username,
password,

View file

@ -64,6 +64,8 @@ pub enum WispError {
UriHasNoPort,
/// The max stream count was reached.
MaxStreamCountReached,
/// The Wisp protocol version was incompatible.
IncompatibleProtocolVersion,
/// The stream had already been closed.
StreamAlreadyClosed,
/// The websocket frame received had an invalid type.
@ -117,6 +119,7 @@ impl std::fmt::Display for WispError {
Self::UriHasNoHost => write!(f, "URI has no host"),
Self::UriHasNoPort => write!(f, "URI has no port"),
Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
Self::IncompatibleProtocolVersion => write!(f, "Incompatible Wisp protocol version"),
Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
@ -286,7 +289,15 @@ impl MuxInner {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::EndFut => break,
WsEvent::EndFut(x) => {
if let Some(reason) = x {
let _ = self
.tx
.write_frame(Packet::new_close(0, reason).into())
.await;
}
break;
}
}
}
}
@ -364,6 +375,9 @@ impl MuxInner {
}
Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
Close(_) => {
if packet.stream_id == 0 {
break Ok(());
}
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect();
@ -410,6 +424,9 @@ impl MuxInner {
}
}
Close(_) => {
if packet.stream_id == 0 {
break Ok(());
}
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect();
@ -532,15 +549,28 @@ impl ServerMux {
self.muxstream_recv.next().await
}
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut(reason))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
/// Close all streams.
///
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
/// this function is called.
pub async fn close(&mut self) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut)
self.close_internal(None).await
}
/// Close all streams and send an extension incompatibility error to the client.
///
/// Also terminates the multiplexor future. Waiting for a new stream will never succed after
/// this function is called.
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
}
/// Client side multiplexor.
@ -600,7 +630,7 @@ impl ClientMux {
x = read.wisp_read_frame(&write).fuse() => Some(x?),
_ = Delay::new(Duration::from_secs(5)).fuse() => None
} {
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?;
if let PacketType::Info(info) = packet.packet_type {
supported_extensions = info
.extensions
@ -671,14 +701,27 @@ impl ClientMux {
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
}
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut(reason))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
/// Close all streams.
///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function.
pub async fn close(&mut self) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut)
self.close_internal(None).await
}
/// Close all streams and send an extension incompatibility error to the client.
///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function.
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
}

View file

@ -446,6 +446,10 @@ impl Packet {
minor: bytes.get_u8(),
};
if version.major != WISP_VERSION.major {
return Err(WispError::IncompatibleProtocolVersion);
}
let mut extensions = Vec::new();
while bytes.remaining() > 4 {

View file

@ -27,7 +27,7 @@ pub(crate) enum WsEvent {
u16,
oneshot::Sender<Result<MuxStream, WispError>>,
),
EndFut,
EndFut(Option<CloseReason>),
}
/// Read side of a multiplexor stream.