diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 49ee3fc..e978029 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -381,7 +381,7 @@ impl MuxInner { } Data(data) => { if let Some(stream) = self.stream_map.get(&packet.stream_id) { - let _ = stream.stream.send_async(data).await; + let _ = stream.stream.try_send(data); if stream.stream_type == StreamType::Tcp { stream.flow_control.store( stream diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 62ee9f1..919b443 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -92,6 +92,10 @@ impl TryFrom for CloseReason { } } +trait Encode { + fn encode(self, bytes: &mut BytesMut); +} + /// Packet used to create a new stream. /// /// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---connect). @@ -120,9 +124,9 @@ impl ConnectPacket { } } -impl TryFrom for ConnectPacket { +impl TryFrom for ConnectPacket { type Error = WispError; - fn try_from(mut bytes: Bytes) -> Result { + fn try_from(mut bytes: BytesMut) -> Result { if bytes.remaining() < (1 + 2) { return Err(Self::Error::PacketTooSmall); } @@ -134,13 +138,11 @@ impl TryFrom for ConnectPacket { } } -impl From for Bytes { - fn from(packet: ConnectPacket) -> Self { - let mut encoded = BytesMut::with_capacity(1 + 2 + packet.destination_hostname.len()); - encoded.put_u8(packet.stream_type.into()); - encoded.put_u16_le(packet.destination_port); - encoded.extend(packet.destination_hostname.bytes()); - encoded.freeze() +impl Encode for ConnectPacket { + fn encode(self, bytes: &mut BytesMut) { + bytes.put_u8(self.stream_type.into()); + bytes.put_u16_le(self.destination_port); + bytes.extend(self.destination_hostname.bytes()); } } @@ -160,9 +162,9 @@ impl ContinuePacket { } } -impl TryFrom for ContinuePacket { +impl TryFrom for ContinuePacket { type Error = WispError; - fn try_from(mut bytes: Bytes) -> Result { + fn try_from(mut bytes: BytesMut) -> Result { if bytes.remaining() < 4 { return Err(Self::Error::PacketTooSmall); } @@ -172,11 +174,9 @@ impl TryFrom for ContinuePacket { } } -impl From for Bytes { - fn from(packet: ContinuePacket) -> Self { - let mut encoded = BytesMut::with_capacity(4); - encoded.put_u32_le(packet.buffer_remaining); - encoded.freeze() +impl Encode for ContinuePacket { + fn encode(self, bytes: &mut BytesMut) { + bytes.put_u32_le(self.buffer_remaining); } } @@ -197,9 +197,9 @@ impl ClosePacket { } } -impl TryFrom for ClosePacket { +impl TryFrom for ClosePacket { type Error = WispError; - fn try_from(mut bytes: Bytes) -> Result { + fn try_from(mut bytes: BytesMut) -> Result { if bytes.remaining() < 1 { return Err(Self::Error::PacketTooSmall); } @@ -209,11 +209,9 @@ impl TryFrom for ClosePacket { } } -impl From for Bytes { - fn from(packet: ClosePacket) -> Self { - let mut encoded = BytesMut::with_capacity(1); - encoded.put_u8(packet.reason as u8); - encoded.freeze() +impl Encode for ClosePacket { + fn encode(self, bytes: &mut BytesMut) { + bytes.put_u8(self.reason as u8); } } @@ -237,15 +235,13 @@ pub struct InfoPacket { pub extensions: Vec, } -impl From for Bytes { - fn from(value: InfoPacket) -> Self { - let mut bytes = BytesMut::with_capacity(2); - bytes.put_u8(value.version.major); - bytes.put_u8(value.version.minor); - for extension in value.extensions { +impl Encode for InfoPacket { + fn encode(self, bytes: &mut BytesMut) { + bytes.put_u8(self.version.major); + bytes.put_u8(self.version.minor); + for extension in self.extensions { bytes.extend(Bytes::from(extension)); } - bytes.freeze() } } @@ -276,18 +272,29 @@ impl PacketType { P::Info(_) => 0x05, } } + + pub(crate) fn get_packet_size(&self) -> usize { + use PacketType as P; + match self { + P::Connect(p) => 1 + 2 + p.destination_hostname.len(), + P::Data(p) => p.len(), + P::Continue(_) => 4, + P::Close(_) => 1, + P::Info(_) => 2, + } + } } -impl From for Bytes { - fn from(packet: PacketType) -> Self { +impl Encode for PacketType { + fn encode(self, bytes: &mut BytesMut) { use PacketType as P; - match packet { - P::Connect(x) => x.into(), - P::Data(x) => x, - P::Continue(x) => x.into(), - P::Close(x) => x.into(), - P::Info(x) => x.into(), - } + match self { + P::Connect(x) => x.encode(bytes), + P::Data(x) => bytes.extend(x), + P::Continue(x) => x.encode(bytes), + P::Close(x) => x.encode(bytes), + P::Info(x) => x.encode(bytes), + }; } } @@ -362,21 +369,13 @@ impl Packet { } } - pub(crate) fn raw_encode(packet_type: u8, stream_id: u32, bytes: Bytes) -> BytesMut { - let mut encoded = BytesMut::with_capacity(1 + 4 + bytes.len()); - encoded.put_u8(packet_type); - encoded.put_u32_le(stream_id); - encoded.extend(bytes); - encoded - } - - fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result { + fn parse_packet(packet_type: u8, mut bytes: BytesMut) -> Result { use PacketType as P; Ok(Self { stream_id: bytes.get_u32_le(), packet_type: match packet_type { 0x01 => P::Connect(ConnectPacket::try_from(bytes)?), - 0x02 => P::Data(bytes), + 0x02 => P::Data(bytes.freeze()), 0x03 => P::Continue(ContinuePacket::try_from(bytes)?), 0x04 => P::Close(ClosePacket::try_from(bytes)?), // 0x05 is handled seperately @@ -396,7 +395,7 @@ impl Packet { if frame.opcode != OpCode::Binary { return Err(WispError::WsFrameInvalidType); } - let mut bytes = frame.payload.freeze(); + let mut bytes = frame.payload; if bytes.remaining() < 1 { return Err(WispError::PacketTooSmall); } @@ -420,8 +419,8 @@ impl Packet { if frame.opcode != OpCode::Binary { return Err(WispError::WsFrameInvalidType); } - let mut bytes = frame.payload.freeze(); - if bytes.remaining() < 1 { + let mut bytes = frame.payload; + if bytes.remaining() < 5 { return Err(WispError::PacketTooSmall); } let packet_type = bytes.get_u8(); @@ -432,7 +431,7 @@ impl Packet { })), 0x02 => Ok(Some(Self { stream_id: bytes.get_u32_le(), - packet_type: PacketType::Data(bytes), + packet_type: PacketType::Data(bytes.freeze()), })), 0x03 => Ok(Some(Self { stream_id: bytes.get_u32_le(), @@ -448,7 +447,7 @@ impl Packet { .iter_mut() .find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type)) { - extension.handle_packet(bytes, read, write).await?; + extension.handle_packet(bytes.freeze(), read, write).await?; Ok(None) } else { Err(WispError::InvalidPacketType) @@ -458,7 +457,7 @@ impl Packet { } fn parse_info( - mut bytes: Bytes, + mut bytes: BytesMut, role: Role, extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], ) -> Result { @@ -507,9 +506,17 @@ impl Packet { } } -impl TryFrom for Packet { +impl Encode for Packet { + fn encode(self, bytes: &mut BytesMut) { + bytes.put_u8(self.packet_type.as_u8()); + bytes.put_u32_le(self.stream_id); + self.packet_type.encode(bytes); + } +} + +impl TryFrom for Packet { type Error = WispError; - fn try_from(mut bytes: Bytes) -> Result { + fn try_from(mut bytes: BytesMut) -> Result { if bytes.remaining() < 1 { return Err(Self::Error::PacketTooSmall); } @@ -520,11 +527,9 @@ impl TryFrom for Packet { impl From for BytesMut { fn from(packet: Packet) -> Self { - Packet::raw_encode( - packet.packet_type.as_u8(), - packet.stream_id, - packet.packet_type.into(), - ) + let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size()); + packet.encode(&mut encoded); + encoded } } @@ -537,12 +542,12 @@ impl TryFrom for Packet { if frame.opcode != ws::OpCode::Binary { return Err(Self::Error::WsFrameInvalidType); } - frame.payload.freeze().try_into() + Packet::try_from(frame.payload) } } impl From for ws::Frame { fn from(packet: Packet) -> Self { - Self::binary(packet.into()) + Self::binary(BytesMut::from(packet)) } } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 3918edb..1dc792b 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -5,7 +5,7 @@ use crate::{ }; pub use async_io_stream::IoStream; -use bytes::Bytes; +use bytes::{BufMut, Bytes, BytesMut}; use event_listener::Event; use flume as mpsc; use futures::{ @@ -114,7 +114,7 @@ impl MuxStreamWrite { } self.tx - .write_frame(Packet::new_data(self.stream_id, data).into()) + .write_frame(Frame::from(Packet::new_data(self.stream_id, data))) .await?; if self.role == Role::Client && self.stream_type == StreamType::Tcp { @@ -348,13 +348,11 @@ impl MuxProtocolExtensionStream { if self.is_closed.load(Ordering::Acquire) { return Err(WispError::StreamAlreadyClosed); } - self.tx - .write_frame(Frame::binary(Packet::raw_encode( - packet_type, - self.stream_id, - data, - ))) - .await + let mut encoded = BytesMut::with_capacity(1 + 4 + data.len()); + encoded.put_u8(packet_type); + encoded.put_u32_le(self.stream_id); + encoded.extend(data); + self.tx.write_frame(Frame::binary(encoded)).await } }