remove a bunch of allocations from packet encode, drop rogue clients' packets

This commit is contained in:
Toshit Chawda 2024-04-27 22:05:25 -07:00
parent ce2660943a
commit 855fa610ed
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 77 additions and 74 deletions

View file

@ -381,7 +381,7 @@ impl MuxInner {
} }
Data(data) => { Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) { if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.send_async(data).await; let _ = stream.stream.try_send(data);
if stream.stream_type == StreamType::Tcp { if stream.stream_type == StreamType::Tcp {
stream.flow_control.store( stream.flow_control.store(
stream stream

View file

@ -92,6 +92,10 @@ impl TryFrom<u8> for CloseReason {
} }
} }
trait Encode {
fn encode(self, bytes: &mut BytesMut);
}
/// Packet used to create a new stream. /// Packet used to create a new stream.
/// ///
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---connect). /// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---connect).
@ -120,9 +124,9 @@ impl ConnectPacket {
} }
} }
impl TryFrom<Bytes> for ConnectPacket { impl TryFrom<BytesMut> for ConnectPacket {
type Error = WispError; type Error = WispError;
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> { fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
if bytes.remaining() < (1 + 2) { if bytes.remaining() < (1 + 2) {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
@ -134,13 +138,11 @@ impl TryFrom<Bytes> for ConnectPacket {
} }
} }
impl From<ConnectPacket> for Bytes { impl Encode for ConnectPacket {
fn from(packet: ConnectPacket) -> Self { fn encode(self, bytes: &mut BytesMut) {
let mut encoded = BytesMut::with_capacity(1 + 2 + packet.destination_hostname.len()); bytes.put_u8(self.stream_type.into());
encoded.put_u8(packet.stream_type.into()); bytes.put_u16_le(self.destination_port);
encoded.put_u16_le(packet.destination_port); bytes.extend(self.destination_hostname.bytes());
encoded.extend(packet.destination_hostname.bytes());
encoded.freeze()
} }
} }
@ -160,9 +162,9 @@ impl ContinuePacket {
} }
} }
impl TryFrom<Bytes> for ContinuePacket { impl TryFrom<BytesMut> for ContinuePacket {
type Error = WispError; type Error = WispError;
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> { fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
if bytes.remaining() < 4 { if bytes.remaining() < 4 {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
@ -172,11 +174,9 @@ impl TryFrom<Bytes> for ContinuePacket {
} }
} }
impl From<ContinuePacket> for Bytes { impl Encode for ContinuePacket {
fn from(packet: ContinuePacket) -> Self { fn encode(self, bytes: &mut BytesMut) {
let mut encoded = BytesMut::with_capacity(4); bytes.put_u32_le(self.buffer_remaining);
encoded.put_u32_le(packet.buffer_remaining);
encoded.freeze()
} }
} }
@ -197,9 +197,9 @@ impl ClosePacket {
} }
} }
impl TryFrom<Bytes> for ClosePacket { impl TryFrom<BytesMut> for ClosePacket {
type Error = WispError; type Error = WispError;
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> { fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
if bytes.remaining() < 1 { if bytes.remaining() < 1 {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
@ -209,11 +209,9 @@ impl TryFrom<Bytes> for ClosePacket {
} }
} }
impl From<ClosePacket> for Bytes { impl Encode for ClosePacket {
fn from(packet: ClosePacket) -> Self { fn encode(self, bytes: &mut BytesMut) {
let mut encoded = BytesMut::with_capacity(1); bytes.put_u8(self.reason as u8);
encoded.put_u8(packet.reason as u8);
encoded.freeze()
} }
} }
@ -237,15 +235,13 @@ pub struct InfoPacket {
pub extensions: Vec<AnyProtocolExtension>, pub extensions: Vec<AnyProtocolExtension>,
} }
impl From<InfoPacket> for Bytes { impl Encode for InfoPacket {
fn from(value: InfoPacket) -> Self { fn encode(self, bytes: &mut BytesMut) {
let mut bytes = BytesMut::with_capacity(2); bytes.put_u8(self.version.major);
bytes.put_u8(value.version.major); bytes.put_u8(self.version.minor);
bytes.put_u8(value.version.minor); for extension in self.extensions {
for extension in value.extensions {
bytes.extend(Bytes::from(extension)); bytes.extend(Bytes::from(extension));
} }
bytes.freeze()
} }
} }
@ -276,18 +272,29 @@ impl PacketType {
P::Info(_) => 0x05, P::Info(_) => 0x05,
} }
} }
pub(crate) fn get_packet_size(&self) -> usize {
use PacketType as P;
match self {
P::Connect(p) => 1 + 2 + p.destination_hostname.len(),
P::Data(p) => p.len(),
P::Continue(_) => 4,
P::Close(_) => 1,
P::Info(_) => 2,
}
}
} }
impl From<PacketType> for Bytes { impl Encode for PacketType {
fn from(packet: PacketType) -> Self { fn encode(self, bytes: &mut BytesMut) {
use PacketType as P; use PacketType as P;
match packet { match self {
P::Connect(x) => x.into(), P::Connect(x) => x.encode(bytes),
P::Data(x) => x, P::Data(x) => bytes.extend(x),
P::Continue(x) => x.into(), P::Continue(x) => x.encode(bytes),
P::Close(x) => x.into(), P::Close(x) => x.encode(bytes),
P::Info(x) => x.into(), P::Info(x) => x.encode(bytes),
} };
} }
} }
@ -362,21 +369,13 @@ impl Packet {
} }
} }
pub(crate) fn raw_encode(packet_type: u8, stream_id: u32, bytes: Bytes) -> BytesMut { fn parse_packet(packet_type: u8, mut bytes: BytesMut) -> Result<Self, WispError> {
let mut encoded = BytesMut::with_capacity(1 + 4 + bytes.len());
encoded.put_u8(packet_type);
encoded.put_u32_le(stream_id);
encoded.extend(bytes);
encoded
}
fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result<Self, WispError> {
use PacketType as P; use PacketType as P;
Ok(Self { Ok(Self {
stream_id: bytes.get_u32_le(), stream_id: bytes.get_u32_le(),
packet_type: match packet_type { packet_type: match packet_type {
0x01 => P::Connect(ConnectPacket::try_from(bytes)?), 0x01 => P::Connect(ConnectPacket::try_from(bytes)?),
0x02 => P::Data(bytes), 0x02 => P::Data(bytes.freeze()),
0x03 => P::Continue(ContinuePacket::try_from(bytes)?), 0x03 => P::Continue(ContinuePacket::try_from(bytes)?),
0x04 => P::Close(ClosePacket::try_from(bytes)?), 0x04 => P::Close(ClosePacket::try_from(bytes)?),
// 0x05 is handled seperately // 0x05 is handled seperately
@ -396,7 +395,7 @@ impl Packet {
if frame.opcode != OpCode::Binary { if frame.opcode != OpCode::Binary {
return Err(WispError::WsFrameInvalidType); return Err(WispError::WsFrameInvalidType);
} }
let mut bytes = frame.payload.freeze(); let mut bytes = frame.payload;
if bytes.remaining() < 1 { if bytes.remaining() < 1 {
return Err(WispError::PacketTooSmall); return Err(WispError::PacketTooSmall);
} }
@ -420,8 +419,8 @@ impl Packet {
if frame.opcode != OpCode::Binary { if frame.opcode != OpCode::Binary {
return Err(WispError::WsFrameInvalidType); return Err(WispError::WsFrameInvalidType);
} }
let mut bytes = frame.payload.freeze(); let mut bytes = frame.payload;
if bytes.remaining() < 1 { if bytes.remaining() < 5 {
return Err(WispError::PacketTooSmall); return Err(WispError::PacketTooSmall);
} }
let packet_type = bytes.get_u8(); let packet_type = bytes.get_u8();
@ -432,7 +431,7 @@ impl Packet {
})), })),
0x02 => Ok(Some(Self { 0x02 => Ok(Some(Self {
stream_id: bytes.get_u32_le(), stream_id: bytes.get_u32_le(),
packet_type: PacketType::Data(bytes), packet_type: PacketType::Data(bytes.freeze()),
})), })),
0x03 => Ok(Some(Self { 0x03 => Ok(Some(Self {
stream_id: bytes.get_u32_le(), stream_id: bytes.get_u32_le(),
@ -448,7 +447,7 @@ impl Packet {
.iter_mut() .iter_mut()
.find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type)) .find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
{ {
extension.handle_packet(bytes, read, write).await?; extension.handle_packet(bytes.freeze(), read, write).await?;
Ok(None) Ok(None)
} else { } else {
Err(WispError::InvalidPacketType) Err(WispError::InvalidPacketType)
@ -458,7 +457,7 @@ impl Packet {
} }
fn parse_info( fn parse_info(
mut bytes: Bytes, mut bytes: BytesMut,
role: Role, role: Role,
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
) -> Result<Self, WispError> { ) -> Result<Self, WispError> {
@ -507,9 +506,17 @@ impl Packet {
} }
} }
impl TryFrom<Bytes> for Packet { impl Encode for Packet {
fn encode(self, bytes: &mut BytesMut) {
bytes.put_u8(self.packet_type.as_u8());
bytes.put_u32_le(self.stream_id);
self.packet_type.encode(bytes);
}
}
impl TryFrom<BytesMut> for Packet {
type Error = WispError; type Error = WispError;
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> { fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
if bytes.remaining() < 1 { if bytes.remaining() < 1 {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
@ -520,11 +527,9 @@ impl TryFrom<Bytes> for Packet {
impl From<Packet> for BytesMut { impl From<Packet> for BytesMut {
fn from(packet: Packet) -> Self { fn from(packet: Packet) -> Self {
Packet::raw_encode( let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size());
packet.packet_type.as_u8(), packet.encode(&mut encoded);
packet.stream_id, encoded
packet.packet_type.into(),
)
} }
} }
@ -537,12 +542,12 @@ impl TryFrom<ws::Frame> for Packet {
if frame.opcode != ws::OpCode::Binary { if frame.opcode != ws::OpCode::Binary {
return Err(Self::Error::WsFrameInvalidType); return Err(Self::Error::WsFrameInvalidType);
} }
frame.payload.freeze().try_into() Packet::try_from(frame.payload)
} }
} }
impl From<Packet> for ws::Frame { impl From<Packet> for ws::Frame {
fn from(packet: Packet) -> Self { fn from(packet: Packet) -> Self {
Self::binary(packet.into()) Self::binary(BytesMut::from(packet))
} }
} }

View file

@ -5,7 +5,7 @@ use crate::{
}; };
pub use async_io_stream::IoStream; pub use async_io_stream::IoStream;
use bytes::Bytes; use bytes::{BufMut, Bytes, BytesMut};
use event_listener::Event; use event_listener::Event;
use flume as mpsc; use flume as mpsc;
use futures::{ use futures::{
@ -114,7 +114,7 @@ impl MuxStreamWrite {
} }
self.tx self.tx
.write_frame(Packet::new_data(self.stream_id, data).into()) .write_frame(Frame::from(Packet::new_data(self.stream_id, data)))
.await?; .await?;
if self.role == Role::Client && self.stream_type == StreamType::Tcp { if self.role == Role::Client && self.stream_type == StreamType::Tcp {
@ -348,13 +348,11 @@ impl MuxProtocolExtensionStream {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed); return Err(WispError::StreamAlreadyClosed);
} }
self.tx let mut encoded = BytesMut::with_capacity(1 + 4 + data.len());
.write_frame(Frame::binary(Packet::raw_encode( encoded.put_u8(packet_type);
packet_type, encoded.put_u32_le(self.stream_id);
self.stream_id, encoded.extend(data);
data, self.tx.write_frame(Frame::binary(encoded)).await
)))
.await
} }
} }