use crate::{ extensions::{AnyProtocolExtension, ProtocolExtensionBuilder}, ws::{self, Frame, LockedWebSocketWrite, OpCode, WebSocketRead}, Role, WispError, WISP_VERSION, }; use bytes::{Buf, BufMut, Bytes, BytesMut}; /// Wisp stream type. #[derive(Debug, PartialEq, Copy, Clone)] pub enum StreamType { /// TCP Wisp stream. Tcp, /// UDP Wisp stream. Udp, /// Unknown Wisp stream type used for custom streams by protocol extensions. Unknown(u8), } impl From for StreamType { fn from(value: u8) -> Self { use StreamType as S; match value { 0x01 => S::Tcp, 0x02 => S::Udp, x => S::Unknown(x), } } } impl From for u8 { fn from(value: StreamType) -> Self { use StreamType as S; match value { S::Tcp => 0x01, S::Udp => 0x02, S::Unknown(x) => x, } } } /// Close reason. /// /// See [the /// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#clientserver-close-reasons) #[derive(Debug, PartialEq, Copy, Clone)] pub enum CloseReason { /// Reason unspecified or unknown. Unknown = 0x01, /// Voluntary stream closure. Voluntary = 0x02, /// Unexpected stream closure due to a network error. Unexpected = 0x03, /// Incompatible extensions. Only used during the handshake. IncompatibleExtensions = 0x04, /// Stream creation failed due to invalid information. ServerStreamInvalidInfo = 0x41, /// Stream creation failed due to an unreachable destination host. ServerStreamUnreachable = 0x42, /// Stream creation timed out due to the destination server not responding. ServerStreamConnectionTimedOut = 0x43, /// Stream creation failed due to the destination server refusing the connection. ServerStreamConnectionRefused = 0x44, /// TCP data transfer timed out. ServerStreamTimedOut = 0x47, /// Stream destination address/domain is intentionally blocked by the proxy server. ServerStreamBlockedAddress = 0x48, /// Connection throttled by the server. ServerStreamThrottled = 0x49, /// The client has encountered an unexpected error. ClientUnexpected = 0x81, } impl TryFrom for CloseReason { type Error = WispError; fn try_from(close_reason: u8) -> Result { use CloseReason as R; match close_reason { 0x01 => Ok(R::Unknown), 0x02 => Ok(R::Voluntary), 0x03 => Ok(R::Unexpected), 0x04 => Ok(R::IncompatibleExtensions), 0x41 => Ok(R::ServerStreamInvalidInfo), 0x42 => Ok(R::ServerStreamUnreachable), 0x43 => Ok(R::ServerStreamConnectionTimedOut), 0x44 => Ok(R::ServerStreamConnectionRefused), 0x47 => Ok(R::ServerStreamTimedOut), 0x48 => Ok(R::ServerStreamBlockedAddress), 0x49 => Ok(R::ServerStreamThrottled), 0x81 => Ok(R::ClientUnexpected), _ => Err(Self::Error::InvalidCloseReason), } } } /// Packet used to create a new stream. /// /// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---connect). #[derive(Debug, Clone)] pub struct ConnectPacket { /// Whether the new stream should use a TCP or UDP socket. pub stream_type: StreamType, /// Destination TCP/UDP port for the new stream. pub destination_port: u16, /// Destination hostname, in a UTF-8 string. pub destination_hostname: String, } impl ConnectPacket { /// Create a new connect packet. pub fn new( stream_type: StreamType, destination_port: u16, destination_hostname: String, ) -> Self { Self { stream_type, destination_port, destination_hostname, } } } impl TryFrom for ConnectPacket { type Error = WispError; fn try_from(mut bytes: Bytes) -> Result { if bytes.remaining() < (1 + 2) { return Err(Self::Error::PacketTooSmall); } Ok(Self { stream_type: bytes.get_u8().into(), destination_port: bytes.get_u16_le(), destination_hostname: std::str::from_utf8(&bytes)?.to_string(), }) } } impl From for Bytes { fn from(packet: ConnectPacket) -> Self { let mut encoded = BytesMut::with_capacity(1 + 2 + packet.destination_hostname.len()); encoded.put_u8(packet.stream_type.into()); encoded.put_u16_le(packet.destination_port); encoded.extend(packet.destination_hostname.bytes()); encoded.freeze() } } /// Packet used for Wisp TCP stream flow control. /// /// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x03---continue). #[derive(Debug, Copy, Clone)] pub struct ContinuePacket { /// Number of packets that the server can buffer for the current stream. pub buffer_remaining: u32, } impl ContinuePacket { /// Create a new continue packet. pub fn new(buffer_remaining: u32) -> Self { Self { buffer_remaining } } } impl TryFrom for ContinuePacket { type Error = WispError; fn try_from(mut bytes: Bytes) -> Result { if bytes.remaining() < 4 { return Err(Self::Error::PacketTooSmall); } Ok(Self { buffer_remaining: bytes.get_u32_le(), }) } } impl From for Bytes { fn from(packet: ContinuePacket) -> Self { let mut encoded = BytesMut::with_capacity(4); encoded.put_u32_le(packet.buffer_remaining); encoded.freeze() } } /// Packet used to close a stream. /// /// See [the /// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x04---close). #[derive(Debug, Copy, Clone)] pub struct ClosePacket { /// The close reason. pub reason: CloseReason, } impl ClosePacket { /// Create a new close packet. pub fn new(reason: CloseReason) -> Self { Self { reason } } } impl TryFrom for ClosePacket { type Error = WispError; fn try_from(mut bytes: Bytes) -> Result { if bytes.remaining() < 1 { return Err(Self::Error::PacketTooSmall); } Ok(Self { reason: bytes.get_u8().try_into()?, }) } } impl From for Bytes { fn from(packet: ClosePacket) -> Self { let mut encoded = BytesMut::with_capacity(1); encoded.put_u8(packet.reason as u8); encoded.freeze() } } /// Wisp version sent in the handshake. #[derive(Debug, Clone)] pub struct WispVersion { /// Major Wisp version according to semver. pub major: u8, /// Minor Wisp version according to semver. pub minor: u8, } /// Packet used in the initial handshake. /// /// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x05---info) #[derive(Debug, Clone)] pub struct InfoPacket { /// Wisp version sent in the packet. pub version: WispVersion, /// List of protocol extensions sent in the packet. pub extensions: Vec, } impl From for Bytes { fn from(value: InfoPacket) -> Self { let mut bytes = BytesMut::with_capacity(2); bytes.put_u8(value.version.major); bytes.put_u8(value.version.minor); for extension in value.extensions { bytes.extend(Bytes::from(extension)); } bytes.freeze() } } #[derive(Debug, Clone)] /// Type of packet recieved. pub enum PacketType { /// Connect packet. Connect(ConnectPacket), /// Data packet. Data(Bytes), /// Continue packet. Continue(ContinuePacket), /// Close packet. Close(ClosePacket), /// Info packet. Info(InfoPacket), } impl PacketType { /// Get the packet type used in the protocol. pub fn as_u8(&self) -> u8 { use PacketType as P; match self { P::Connect(_) => 0x01, P::Data(_) => 0x02, P::Continue(_) => 0x03, P::Close(_) => 0x04, P::Info(_) => 0x05, } } } impl From for Bytes { fn from(packet: PacketType) -> Self { use PacketType as P; match packet { P::Connect(x) => x.into(), P::Data(x) => x, P::Continue(x) => x.into(), P::Close(x) => x.into(), P::Info(x) => x.into(), } } } /// Wisp protocol packet. #[derive(Debug, Clone)] pub struct Packet { /// Stream this packet is associated with. pub stream_id: u32, /// Packet type recieved. pub packet_type: PacketType, } impl Packet { /// Create a new packet. /// /// The helper functions should be used for most use cases. pub fn new(stream_id: u32, packet: PacketType) -> Self { Self { stream_id, packet_type: packet, } } /// Create a new connect packet. pub fn new_connect( stream_id: u32, stream_type: StreamType, destination_port: u16, destination_hostname: String, ) -> Self { Self { stream_id, packet_type: PacketType::Connect(ConnectPacket::new( stream_type, destination_port, destination_hostname, )), } } /// Create a new data packet. pub fn new_data(stream_id: u32, data: Bytes) -> Self { Self { stream_id, packet_type: PacketType::Data(data), } } /// Create a new continue packet. pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self { Self { stream_id, packet_type: PacketType::Continue(ContinuePacket::new(buffer_remaining)), } } /// Create a new close packet. pub fn new_close(stream_id: u32, reason: CloseReason) -> Self { Self { stream_id, packet_type: PacketType::Close(ClosePacket::new(reason)), } } pub(crate) fn new_info(extensions: Vec) -> Self { Self { stream_id: 0, packet_type: PacketType::Info(InfoPacket { version: WISP_VERSION, extensions, }), } } fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result { use PacketType as P; Ok(Self { stream_id: bytes.get_u32_le(), packet_type: match packet_type { 0x01 => P::Connect(ConnectPacket::try_from(bytes)?), 0x02 => P::Data(bytes), 0x03 => P::Continue(ContinuePacket::try_from(bytes)?), 0x04 => P::Close(ClosePacket::try_from(bytes)?), // 0x05 is handled seperately _ => return Err(WispError::InvalidPacketType), }, }) } pub(crate) fn maybe_parse_info( frame: Frame, role: Role, extension_builders: &[&(dyn ProtocolExtensionBuilder + Sync)], ) -> Result { if !frame.finished { return Err(WispError::WsFrameNotFinished); } if frame.opcode != OpCode::Binary { return Err(WispError::WsFrameInvalidType); } let mut bytes = frame.payload; if bytes.remaining() < 1 { return Err(WispError::PacketTooSmall); } let packet_type = bytes.get_u8(); if packet_type == 0x05 { Self::parse_info(bytes, role, extension_builders) } else { Self::parse_packet(packet_type, bytes) } } pub(crate) async fn maybe_handle_extension( frame: Frame, extensions: &mut [AnyProtocolExtension], read: &mut (dyn WebSocketRead + Send), write: &LockedWebSocketWrite, ) -> Result, WispError> { if !frame.finished { return Err(WispError::WsFrameNotFinished); } if frame.opcode != OpCode::Binary { return Err(WispError::WsFrameInvalidType); } let mut bytes = frame.payload; if bytes.remaining() < 1 { return Err(WispError::PacketTooSmall); } let packet_type = bytes.get_u8(); if let Some(extension) = extensions .iter_mut() .find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type)) { extension.handle_packet(bytes, read, write).await?; Ok(None) } else { Ok(Some(Self::parse_packet(packet_type, bytes)?)) } } fn parse_info( mut bytes: Bytes, role: Role, extension_builders: &[&(dyn ProtocolExtensionBuilder + Sync)], ) -> Result { // packet type is already read by code that calls this if bytes.remaining() < 4 + 2 { return Err(WispError::PacketTooSmall); } if bytes.get_u32_le() != 0 { return Err(WispError::InvalidStreamId); } let version = WispVersion { major: bytes.get_u8(), minor: bytes.get_u8(), }; if version.major != WISP_VERSION.major { return Err(WispError::IncompatibleProtocolVersion); } let mut extensions = Vec::new(); while bytes.remaining() > 4 { // We have some extensions let id = bytes.get_u8(); let length = usize::try_from(bytes.get_u32_le())?; if bytes.remaining() < length { return Err(WispError::PacketTooSmall); } if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) { if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role) { extensions.push(extension) } } else { bytes.advance(length) } } Ok(Self { stream_id: 0, packet_type: PacketType::Info(InfoPacket { version, extensions, }), }) } } impl TryFrom for Packet { type Error = WispError; fn try_from(mut bytes: Bytes) -> Result { if bytes.remaining() < 1 { return Err(Self::Error::PacketTooSmall); } let packet_type = bytes.get_u8(); Self::parse_packet(packet_type, bytes) } } impl From for Bytes { fn from(packet: Packet) -> Self { let inner_u8 = packet.packet_type.as_u8(); let inner = Bytes::from(packet.packet_type); let mut encoded = BytesMut::with_capacity(1 + 4 + inner.len()); encoded.put_u8(inner_u8); encoded.put_u32_le(packet.stream_id); encoded.extend(inner); encoded.freeze() } } impl TryFrom for Packet { type Error = WispError; fn try_from(frame: ws::Frame) -> Result { if !frame.finished { return Err(Self::Error::WsFrameNotFinished); } if frame.opcode != ws::OpCode::Binary { return Err(Self::Error::WsFrameInvalidType); } frame.payload.try_into() } } impl From for ws::Frame { fn from(packet: Packet) -> Self { Self::binary(packet.into()) } }