From 397fd43dc57ef7da8a7b78b884a3109a7dec6354 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 13 Apr 2024 16:49:07 -0700 Subject: [PATCH] remove invalidstreamtype to allow for custom protocol extension streams --- server/src/main.rs | 4 +++ wisp/src/lib.rs | 3 -- wisp/src/packet.rs | 78 ++++++++++++++++++++++++++-------------------- 3 files changed, 49 insertions(+), 36 deletions(-) diff --git a/server/src/main.rs b/server/src/main.rs index 9644a22..a191a01 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -247,6 +247,10 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result { + stream.close(CloseReason::ServerStreamInvalidInfo).await?; + return Ok(false); + } } Ok(true) } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 40b21f7..3c77584 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -52,8 +52,6 @@ pub enum WispError { PacketTooSmall, /// The packet received had an invalid type. InvalidPacketType, - /// The stream had an invalid type. - InvalidStreamType, /// The stream had an invalid ID. InvalidStreamId, /// The close packet had an invalid reason. @@ -113,7 +111,6 @@ impl std::fmt::Display for WispError { match self { Self::PacketTooSmall => write!(f, "Packet too small"), Self::InvalidPacketType => write!(f, "Invalid packet type"), - Self::InvalidStreamType => write!(f, "Invalid stream type"), Self::InvalidStreamId => write!(f, "Invalid stream id"), Self::InvalidCloseReason => write!(f, "Invalid close reason"), Self::InvalidUri => write!(f, "Invalid URI"), diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 0017307..9ff6a3c 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -9,19 +9,31 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; #[derive(Debug, PartialEq, Copy, Clone)] pub enum StreamType { /// TCP Wisp stream. - Tcp = 0x01, + Tcp, /// UDP Wisp stream. - Udp = 0x02, + Udp, + /// Unknown Wisp stream type used for custom streams by protocol extensions. + Unknown(u8), } -impl TryFrom for StreamType { - type Error = WispError; - fn try_from(stream_type: u8) -> Result { - use StreamType::*; - match stream_type { - 0x01 => Ok(Tcp), - 0x02 => Ok(Udp), - _ => Err(Self::Error::InvalidStreamType), +impl From for StreamType { + fn from(value: u8) -> Self { + use StreamType as S; + match value { + 0x01 => S::Tcp, + 0x02 => S::Udp, + x => S::Unknown(x), + } + } +} + +impl From for u8 { + fn from(value: StreamType) -> Self { + use StreamType as S; + match value { + S::Tcp => 0x01, + S::Udp => 0x02, + S::Unknown(x) => x, } } } @@ -60,9 +72,9 @@ pub enum CloseReason { impl TryFrom for CloseReason { type Error = WispError; - fn try_from(stream_type: u8) -> Result { + fn try_from(close_reason: u8) -> Result { use CloseReason as R; - match stream_type { + match close_reason { 0x01 => Ok(R::Unknown), 0x02 => Ok(R::Voluntary), 0x03 => Ok(R::Unexpected), @@ -75,7 +87,7 @@ impl TryFrom for CloseReason { 0x48 => Ok(R::ServerStreamBlockedAddress), 0x49 => Ok(R::ServerStreamThrottled), 0x81 => Ok(R::ClientUnexpected), - _ => Err(Self::Error::InvalidStreamType), + _ => Err(Self::Error::InvalidCloseReason), } } } @@ -115,7 +127,7 @@ impl TryFrom for ConnectPacket { return Err(Self::Error::PacketTooSmall); } Ok(Self { - stream_type: bytes.get_u8().try_into()?, + stream_type: bytes.get_u8().into(), destination_port: bytes.get_u16_le(), destination_hostname: std::str::from_utf8(&bytes)?.to_string(), }) @@ -125,7 +137,7 @@ impl TryFrom for ConnectPacket { impl From for Bytes { fn from(packet: ConnectPacket) -> Self { let mut encoded = BytesMut::with_capacity(1 + 2 + packet.destination_hostname.len()); - encoded.put_u8(packet.stream_type as u8); + encoded.put_u8(packet.stream_type.into()); encoded.put_u16_le(packet.destination_port); encoded.extend(packet.destination_hostname.bytes()); encoded.freeze() @@ -255,26 +267,26 @@ pub enum PacketType { impl PacketType { /// Get the packet type used in the protocol. pub fn as_u8(&self) -> u8 { - use PacketType::*; + use PacketType as P; match self { - Connect(_) => 0x01, - Data(_) => 0x02, - Continue(_) => 0x03, - Close(_) => 0x04, - Info(_) => 0x05, + P::Connect(_) => 0x01, + P::Data(_) => 0x02, + P::Continue(_) => 0x03, + P::Close(_) => 0x04, + P::Info(_) => 0x05, } } } impl From for Bytes { fn from(packet: PacketType) -> Self { - use PacketType::*; + use PacketType as P; match packet { - Connect(x) => x.into(), - Data(x) => x, - Continue(x) => x.into(), - Close(x) => x.into(), - Info(x) => x.into(), + P::Connect(x) => x.into(), + P::Data(x) => x, + P::Continue(x) => x.into(), + P::Close(x) => x.into(), + P::Info(x) => x.into(), } } } @@ -351,14 +363,14 @@ impl Packet { } fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result { - use PacketType::*; + use PacketType as P; Ok(Self { stream_id: bytes.get_u32_le(), packet_type: match packet_type { - 0x01 => Connect(ConnectPacket::try_from(bytes)?), - 0x02 => Data(bytes), - 0x03 => Continue(ContinuePacket::try_from(bytes)?), - 0x04 => Close(ClosePacket::try_from(bytes)?), + 0x01 => P::Connect(ConnectPacket::try_from(bytes)?), + 0x02 => P::Data(bytes), + 0x03 => P::Continue(ContinuePacket::try_from(bytes)?), + 0x04 => P::Close(ClosePacket::try_from(bytes)?), // 0x05 is handled seperately _ => return Err(WispError::InvalidPacketType), }, @@ -465,7 +477,7 @@ impl Packet { impl TryFrom for Packet { type Error = WispError; fn try_from(mut bytes: Bytes) -> Result { - if bytes.remaining() < 5 { + if bytes.remaining() < 1 { return Err(Self::Error::PacketTooSmall); } let packet_type = bytes.get_u8();