use std::fmt::Display; use bytes::{Buf, BufMut}; use num_enum::{FromPrimitive, IntoPrimitive}; use crate::{ extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder}, ws::{Payload, PayloadMut, PayloadRef, WebSocketRead, WebSocketWrite}, LockedWebSocketWrite, Role, WispError, WISP_VERSION, }; trait PacketCodec: Sized { fn size_hint(&self) -> usize; fn encode_into(&self, packet: &mut PayloadMut); fn decode(packet: &mut Payload) -> Result; } #[derive(FromPrimitive, IntoPrimitive, Debug, Copy, Clone, Eq, PartialEq)] #[repr(u8)] pub enum StreamType { Tcp = 0x01, Udp = 0x02, #[num_enum(catch_all)] Other(u8), } impl PacketCodec for StreamType { fn size_hint(&self) -> usize { size_of::() } fn encode_into(&self, packet: &mut PayloadMut) { packet.put_u8((*self).into()); } fn decode(packet: &mut Payload) -> Result { if packet.remaining() < size_of::() { return Err(WispError::PacketTooSmall); } Ok(Self::from(packet.get_u8())) } } #[derive(FromPrimitive, IntoPrimitive, Debug, Copy, Clone, Eq, PartialEq)] #[repr(u8)] pub enum CloseReason { /// Reason unspecified or unknown. Unknown = 0x01, /// Voluntary stream closure. Voluntary = 0x02, /// Unexpected stream closure due to a network error. Unexpected = 0x03, /// Incompatible extensions. ExtensionsIncompatible = 0x04, /// Stream creation failed due to invalid information. ServerStreamInvalidInfo = 0x41, /// Stream creation failed due to an unreachable destination host. ServerStreamUnreachable = 0x42, /// Stream creation timed out due to the destination server not responding. ServerStreamConnectionTimedOut = 0x43, /// Stream creation failed due to the destination server refusing the connection. ServerStreamConnectionRefused = 0x44, /// TCP data transfer timed out. ServerStreamTimedOut = 0x47, /// Stream destination address/domain is intentionally blocked by the proxy server. ServerStreamBlockedAddress = 0x48, /// Connection throttled by the server. ServerStreamThrottled = 0x49, /// The client has encountered an unexpected error and is unable to recieve any more data. ClientUnexpected = 0x81, /// Authentication failed due to invalid username/password. ExtensionsPasswordAuthFailed = 0xc0, /// Authentication failed due to invalid signature. ExtensionsCertAuthFailed = 0xc1, /// Authentication required but the client did not provide credentials. ExtensionsAuthRequired = 0xc2, #[num_enum(catch_all)] Other(u8), } impl Display for CloseReason { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use CloseReason as C; if let C::Other(x) = self { return write!(f, "Other: {x}"); } write!( f, "{}", match self { C::Unknown => "Unknown close reason", C::Voluntary => "Voluntarily closed", C::Unexpected => "Unexpectedly closed", C::ExtensionsIncompatible => "Incompatible protocol extensions", C::ServerStreamInvalidInfo => "Stream creation failed due to invalid information", C::ServerStreamUnreachable => "Stream creation failed due to an unreachable destination", C::ServerStreamConnectionTimedOut => "Stream creation failed due to destination not responding", C::ServerStreamConnectionRefused => "Stream creation failed due to destination refusing connection", C::ServerStreamTimedOut => "TCP timed out", C::ServerStreamBlockedAddress => "Destination address is blocked", C::ServerStreamThrottled => "Throttled", C::ClientUnexpected => "Client encountered unexpected error", C::ExtensionsPasswordAuthFailed => "Invalid username/password", C::ExtensionsCertAuthFailed => "Invalid signature", C::ExtensionsAuthRequired => "Authentication required", C::Other(_) => unreachable!(), } ) } } impl PacketCodec for CloseReason { fn size_hint(&self) -> usize { size_of::() } fn encode_into(&self, packet: &mut PayloadMut) { packet.put_u8((*self).into()); } fn decode(packet: &mut Payload) -> Result { if packet.remaining() < size_of::() { return Err(WispError::PacketTooSmall); } Ok(Self::from(packet.get_u8())) } } #[derive(Debug, Clone, Eq, PartialEq)] pub struct ConnectPacket { pub stream_type: StreamType, pub host: String, pub port: u16, } impl PacketCodec for ConnectPacket { fn size_hint(&self) -> usize { self.stream_type.size_hint() + self.host.len() + size_of::() } fn encode_into(&self, packet: &mut PayloadMut) { self.stream_type.encode_into(packet); packet.put_u16_le(self.port); packet.extend_from_slice(self.host.as_bytes()); } fn decode(packet: &mut Payload) -> Result { if packet.remaining() < (size_of::() + size_of::()) { return Err(WispError::PacketTooSmall); } let stream_type = StreamType::decode(packet)?; let port = packet.get_u16_le(); let host = String::from_utf8(packet.to_vec())?; Ok(Self { stream_type, host, port, }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct ContinuePacket { pub buffer_remaining: u32, } impl PacketCodec for ContinuePacket { fn size_hint(&self) -> usize { size_of::() } fn encode_into(&self, packet: &mut PayloadMut) { packet.put_u32_le(self.buffer_remaining); } fn decode(packet: &mut Payload) -> Result { if packet.remaining() < size_of::() { return Err(WispError::PacketTooSmall); } let buffer_remaining = packet.get_u32_le(); Ok(Self { buffer_remaining }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct ClosePacket { pub reason: CloseReason, } impl PacketCodec for ClosePacket { fn size_hint(&self) -> usize { self.reason.size_hint() } fn encode_into(&self, packet: &mut PayloadMut) { self.reason.encode_into(packet); } fn decode(packet: &mut Payload) -> Result { let reason = CloseReason::decode(packet)?; Ok(Self { reason }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct WispVersion { pub major: u8, pub minor: u8, } impl Display for WispVersion { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}.{}", self.major, self.minor) } } impl PacketCodec for WispVersion { fn size_hint(&self) -> usize { size_of::() * 2 } fn encode_into(&self, packet: &mut PayloadMut) { packet.put_u8(self.major); packet.put_u8(self.minor); } fn decode(packet: &mut Payload) -> Result { if packet.remaining() < 2 { return Err(WispError::PacketTooSmall); } Ok(Self { major: packet.get_u8(), minor: packet.get_u8(), }) } } #[derive(Debug, Clone)] pub struct InfoPacket { pub version: WispVersion, pub extensions: Vec, } impl InfoPacket { pub(crate) fn decode( packet: &mut Payload, builders: &mut [AnyProtocolExtensionBuilder], role: Role, ) -> Result { if packet.remaining() < (size_of::() * 2) { return Err(WispError::PacketTooSmall); } let version = WispVersion { major: packet.get_u8(), minor: packet.get_u8(), }; if version.major != WISP_VERSION.major { return Err(WispError::IncompatibleProtocolVersion( version, WISP_VERSION, )); } let mut extensions = Vec::new(); while packet.remaining() >= (size_of::() + size_of::()) { // We have some extensions let id = packet.get_u8(); let length = usize::try_from(packet.get_u32_le())?; if packet.remaining() < length { return Err(WispError::PacketTooSmall); } if let Some(builder) = builders.iter_mut().find(|x| x.get_id() == id) { extensions.push(builder.build_from_bytes(packet.split_to(length), role)?); } else { packet.advance(length); } } Ok(Self { version, extensions, }) } pub(crate) fn encode(&self) -> Payload { let mut packet = PayloadMut::with_capacity( size_of::() + size_of::() + self.version.size_hint(), ); packet.put_u8(0x05); packet.put_u32(0); self.version.encode_into(&mut packet); for extension in &self.extensions { extension.encode_into(&mut packet); } packet.freeze() } } #[derive(Debug, Clone, Eq, PartialEq)] pub enum PacketType<'a> { Connect(ConnectPacket), Data(PayloadRef<'a>), Continue(ContinuePacket), Close(ClosePacket), } impl PacketType<'_> { pub(crate) fn size_hint(&self) -> usize { match self { Self::Connect(x) => x.size_hint(), Self::Data(x) => x.len(), Self::Continue(x) => x.size_hint(), Self::Close(x) => x.size_hint(), } } pub(crate) fn get_type(&self) -> u8 { match self { Self::Connect(_) => 0x01, Self::Data(_) => 0x02, Self::Continue(_) => 0x03, Self::Close(_) => 0x04, } } pub(crate) fn encode(&self, packet: &mut PayloadMut) { match self { Self::Connect(x) => x.encode_into(packet), Self::Data(x) => packet.extend_from_slice(x), Self::Continue(x) => x.encode_into(packet), Self::Close(x) => x.encode_into(packet), } } pub(crate) fn decode(mut packet: Payload, ty: u8) -> Result, WispError> { Ok(match ty { 0x01 => PacketType::Connect(ConnectPacket::decode(&mut packet)?), 0x02 => PacketType::Data(packet.into()), 0x03 => PacketType::Continue(ContinuePacket::decode(&mut packet)?), 0x04 => PacketType::Close(ClosePacket::decode(&mut packet)?), x => return Err(WispError::InvalidPacketType(x)), }) } } pub(crate) enum MaybeInfoPacket<'a> { Packet(Packet<'a>), Info(InfoPacket), } impl MaybeInfoPacket<'static> { pub(crate) fn decode( mut packet: Payload, builders: &mut [AnyProtocolExtensionBuilder], role: Role, ) -> Result { if packet.remaining() < size_of::() + size_of::() { return Err(WispError::PacketTooSmall); } let ty = packet.get_u8(); let stream_id = packet.get_u32_le(); if ty == 0x05 { Ok(Self::Info(InfoPacket::decode(&mut packet, builders, role)?)) } else { Ok(Self::Packet(Packet { stream_id, packet_type: PacketType::decode(packet, ty)?, })) } } } pub(crate) enum MaybeExtensionPacket<'a> { Packet(Packet<'a>), ExtensionHandled, } impl MaybeExtensionPacket<'static> { pub(crate) async fn decode( mut packet: Payload, extensions: &mut [AnyProtocolExtension], rx: &mut dyn WebSocketRead, tx: &mut LockedWebSocketWrite, ) -> Result { if packet.remaining() < size_of::() + size_of::() { return Err(WispError::PacketTooSmall); } let ty = packet.get_u8(); let stream_id = packet.get_u32_le(); if (0x01..=0x04).contains(&ty) { Ok(Self::Packet(Packet { stream_id, packet_type: PacketType::decode(packet, ty)?, })) } else { tx.lock().await; let mut handle = tx.get_handle(); for extension in extensions { if extension.get_supported_packets().contains(&ty) { extension.handle_packet(ty, packet, rx, &mut handle).await?; return Ok(Self::ExtensionHandled); } } drop(handle); Err(WispError::InvalidPacketType(ty)) } } } #[derive(Debug, Clone, Eq, PartialEq)] pub struct Packet<'a> { pub stream_id: u32, pub packet_type: PacketType<'a>, } impl Packet<'_> { fn size_hint(&self) -> usize { size_of::() + size_of::() + self.packet_type.size_hint() } fn encode_into(&self, packet: &mut PayloadMut) { packet.put_u8(self.packet_type.get_type()); packet.put_u32_le(self.stream_id); self.packet_type.encode(packet); } pub(crate) fn encode(&self) -> Payload { let mut payload = PayloadMut::with_capacity(self.size_hint()); self.encode_into(&mut payload); payload.into() } pub(crate) fn decode(mut packet: Payload) -> Result, WispError> { if packet.remaining() < size_of::() + size_of::() { return Err(WispError::PacketTooSmall); } let ty = packet.get_u8(); let stream_id = packet.get_u32_le(); Ok(Packet { stream_id, packet_type: PacketType::decode(packet, ty)?, }) } pub fn new_data<'a>(stream_id: u32, data: impl Into>) -> Packet<'a> { Packet { stream_id, packet_type: PacketType::Data(data.into()), } } pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self { Self { stream_id, packet_type: PacketType::Continue(ContinuePacket { buffer_remaining }), } } pub fn new_close(stream_id: u32, reason: CloseReason) -> Self { Self { stream_id, packet_type: PacketType::Close(ClosePacket { reason }), } } }