fix password protocol extension, respect stream id 0 close packets, allow sending stream id 0 close packets

This commit is contained in:
Toshit Chawda 2024-04-13 22:34:26 -07:00
parent 4d433b60c4
commit d10b7691e4
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
6 changed files with 220 additions and 50 deletions

View file

@ -64,6 +64,8 @@ pub enum WispError {
UriHasNoPort,
/// The max stream count was reached.
MaxStreamCountReached,
/// The Wisp protocol version was incompatible.
IncompatibleProtocolVersion,
/// The stream had already been closed.
StreamAlreadyClosed,
/// The websocket frame received had an invalid type.
@ -117,6 +119,7 @@ impl std::fmt::Display for WispError {
Self::UriHasNoHost => write!(f, "URI has no host"),
Self::UriHasNoPort => write!(f, "URI has no port"),
Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
Self::IncompatibleProtocolVersion => write!(f, "Incompatible Wisp protocol version"),
Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
@ -286,7 +289,15 @@ impl MuxInner {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::EndFut => break,
WsEvent::EndFut(x) => {
if let Some(reason) = x {
let _ = self
.tx
.write_frame(Packet::new_close(0, reason).into())
.await;
}
break;
}
}
}
}
@ -364,6 +375,9 @@ impl MuxInner {
}
Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
Close(_) => {
if packet.stream_id == 0 {
break Ok(());
}
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect();
@ -410,6 +424,9 @@ impl MuxInner {
}
}
Close(_) => {
if packet.stream_id == 0 {
break Ok(());
}
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect();
@ -532,15 +549,28 @@ impl ServerMux {
self.muxstream_recv.next().await
}
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut(reason))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
/// Close all streams.
///
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
/// this function is called.
pub async fn close(&mut self) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut)
self.close_internal(None).await
}
/// Close all streams and send an extension incompatibility error to the client.
///
/// Also terminates the multiplexor future. Waiting for a new stream will never succed after
/// this function is called.
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
}
/// Client side multiplexor.
@ -600,7 +630,7 @@ impl ClientMux {
x = read.wisp_read_frame(&write).fuse() => Some(x?),
_ = Delay::new(Duration::from_secs(5)).fuse() => None
} {
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?;
if let PacketType::Info(info) = packet.packet_type {
supported_extensions = info
.extensions
@ -671,14 +701,27 @@ impl ClientMux {
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
}
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut(reason))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
/// Close all streams.
///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function.
pub async fn close(&mut self) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut)
self.close_internal(None).await
}
/// Close all streams and send an extension incompatibility error to the client.
///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function.
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
}