remove invalidstreamtype to allow for custom protocol extension streams

This commit is contained in:
Toshit Chawda 2024-04-13 16:49:07 -07:00
parent b8eb13903b
commit 397fd43dc5
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 49 additions and 36 deletions

View file

@ -247,6 +247,10 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result<bool
} }
} }
} }
StreamType::Unknown(_) => {
stream.close(CloseReason::ServerStreamInvalidInfo).await?;
return Ok(false);
}
} }
Ok(true) Ok(true)
} }

View file

@ -52,8 +52,6 @@ pub enum WispError {
PacketTooSmall, PacketTooSmall,
/// The packet received had an invalid type. /// The packet received had an invalid type.
InvalidPacketType, InvalidPacketType,
/// The stream had an invalid type.
InvalidStreamType,
/// The stream had an invalid ID. /// The stream had an invalid ID.
InvalidStreamId, InvalidStreamId,
/// The close packet had an invalid reason. /// The close packet had an invalid reason.
@ -113,7 +111,6 @@ impl std::fmt::Display for WispError {
match self { match self {
Self::PacketTooSmall => write!(f, "Packet too small"), Self::PacketTooSmall => write!(f, "Packet too small"),
Self::InvalidPacketType => write!(f, "Invalid packet type"), Self::InvalidPacketType => write!(f, "Invalid packet type"),
Self::InvalidStreamType => write!(f, "Invalid stream type"),
Self::InvalidStreamId => write!(f, "Invalid stream id"), Self::InvalidStreamId => write!(f, "Invalid stream id"),
Self::InvalidCloseReason => write!(f, "Invalid close reason"), Self::InvalidCloseReason => write!(f, "Invalid close reason"),
Self::InvalidUri => write!(f, "Invalid URI"), Self::InvalidUri => write!(f, "Invalid URI"),

View file

@ -9,19 +9,31 @@ use bytes::{Buf, BufMut, Bytes, BytesMut};
#[derive(Debug, PartialEq, Copy, Clone)] #[derive(Debug, PartialEq, Copy, Clone)]
pub enum StreamType { pub enum StreamType {
/// TCP Wisp stream. /// TCP Wisp stream.
Tcp = 0x01, Tcp,
/// UDP Wisp stream. /// UDP Wisp stream.
Udp = 0x02, Udp,
/// Unknown Wisp stream type used for custom streams by protocol extensions.
Unknown(u8),
} }
impl TryFrom<u8> for StreamType { impl From<u8> for StreamType {
type Error = WispError; fn from(value: u8) -> Self {
fn try_from(stream_type: u8) -> Result<Self, Self::Error> { use StreamType as S;
use StreamType::*; match value {
match stream_type { 0x01 => S::Tcp,
0x01 => Ok(Tcp), 0x02 => S::Udp,
0x02 => Ok(Udp), x => S::Unknown(x),
_ => Err(Self::Error::InvalidStreamType), }
}
}
impl From<StreamType> for u8 {
fn from(value: StreamType) -> Self {
use StreamType as S;
match value {
S::Tcp => 0x01,
S::Udp => 0x02,
S::Unknown(x) => x,
} }
} }
} }
@ -60,9 +72,9 @@ pub enum CloseReason {
impl TryFrom<u8> for CloseReason { impl TryFrom<u8> for CloseReason {
type Error = WispError; type Error = WispError;
fn try_from(stream_type: u8) -> Result<Self, Self::Error> { fn try_from(close_reason: u8) -> Result<Self, Self::Error> {
use CloseReason as R; use CloseReason as R;
match stream_type { match close_reason {
0x01 => Ok(R::Unknown), 0x01 => Ok(R::Unknown),
0x02 => Ok(R::Voluntary), 0x02 => Ok(R::Voluntary),
0x03 => Ok(R::Unexpected), 0x03 => Ok(R::Unexpected),
@ -75,7 +87,7 @@ impl TryFrom<u8> for CloseReason {
0x48 => Ok(R::ServerStreamBlockedAddress), 0x48 => Ok(R::ServerStreamBlockedAddress),
0x49 => Ok(R::ServerStreamThrottled), 0x49 => Ok(R::ServerStreamThrottled),
0x81 => Ok(R::ClientUnexpected), 0x81 => Ok(R::ClientUnexpected),
_ => Err(Self::Error::InvalidStreamType), _ => Err(Self::Error::InvalidCloseReason),
} }
} }
} }
@ -115,7 +127,7 @@ impl TryFrom<Bytes> for ConnectPacket {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
Ok(Self { Ok(Self {
stream_type: bytes.get_u8().try_into()?, stream_type: bytes.get_u8().into(),
destination_port: bytes.get_u16_le(), destination_port: bytes.get_u16_le(),
destination_hostname: std::str::from_utf8(&bytes)?.to_string(), destination_hostname: std::str::from_utf8(&bytes)?.to_string(),
}) })
@ -125,7 +137,7 @@ impl TryFrom<Bytes> for ConnectPacket {
impl From<ConnectPacket> for Bytes { impl From<ConnectPacket> for Bytes {
fn from(packet: ConnectPacket) -> Self { fn from(packet: ConnectPacket) -> Self {
let mut encoded = BytesMut::with_capacity(1 + 2 + packet.destination_hostname.len()); let mut encoded = BytesMut::with_capacity(1 + 2 + packet.destination_hostname.len());
encoded.put_u8(packet.stream_type as u8); encoded.put_u8(packet.stream_type.into());
encoded.put_u16_le(packet.destination_port); encoded.put_u16_le(packet.destination_port);
encoded.extend(packet.destination_hostname.bytes()); encoded.extend(packet.destination_hostname.bytes());
encoded.freeze() encoded.freeze()
@ -255,26 +267,26 @@ pub enum PacketType {
impl PacketType { impl PacketType {
/// Get the packet type used in the protocol. /// Get the packet type used in the protocol.
pub fn as_u8(&self) -> u8 { pub fn as_u8(&self) -> u8 {
use PacketType::*; use PacketType as P;
match self { match self {
Connect(_) => 0x01, P::Connect(_) => 0x01,
Data(_) => 0x02, P::Data(_) => 0x02,
Continue(_) => 0x03, P::Continue(_) => 0x03,
Close(_) => 0x04, P::Close(_) => 0x04,
Info(_) => 0x05, P::Info(_) => 0x05,
} }
} }
} }
impl From<PacketType> for Bytes { impl From<PacketType> for Bytes {
fn from(packet: PacketType) -> Self { fn from(packet: PacketType) -> Self {
use PacketType::*; use PacketType as P;
match packet { match packet {
Connect(x) => x.into(), P::Connect(x) => x.into(),
Data(x) => x, P::Data(x) => x,
Continue(x) => x.into(), P::Continue(x) => x.into(),
Close(x) => x.into(), P::Close(x) => x.into(),
Info(x) => x.into(), P::Info(x) => x.into(),
} }
} }
} }
@ -351,14 +363,14 @@ impl Packet {
} }
fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result<Self, WispError> { fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result<Self, WispError> {
use PacketType::*; use PacketType as P;
Ok(Self { Ok(Self {
stream_id: bytes.get_u32_le(), stream_id: bytes.get_u32_le(),
packet_type: match packet_type { packet_type: match packet_type {
0x01 => Connect(ConnectPacket::try_from(bytes)?), 0x01 => P::Connect(ConnectPacket::try_from(bytes)?),
0x02 => Data(bytes), 0x02 => P::Data(bytes),
0x03 => Continue(ContinuePacket::try_from(bytes)?), 0x03 => P::Continue(ContinuePacket::try_from(bytes)?),
0x04 => Close(ClosePacket::try_from(bytes)?), 0x04 => P::Close(ClosePacket::try_from(bytes)?),
// 0x05 is handled seperately // 0x05 is handled seperately
_ => return Err(WispError::InvalidPacketType), _ => return Err(WispError::InvalidPacketType),
}, },
@ -465,7 +477,7 @@ impl Packet {
impl TryFrom<Bytes> for Packet { impl TryFrom<Bytes> for Packet {
type Error = WispError; type Error = WispError;
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> { fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
if bytes.remaining() < 5 { if bytes.remaining() < 1 {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
let packet_type = bytes.get_u8(); let packet_type = bytes.get_u8();