mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
remove a bunch of allocations from packet encode, drop rogue clients' packets
This commit is contained in:
parent
ce2660943a
commit
855fa610ed
3 changed files with 77 additions and 74 deletions
|
@ -381,7 +381,7 @@ impl MuxInner {
|
|||
}
|
||||
Data(data) => {
|
||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||
let _ = stream.stream.send_async(data).await;
|
||||
let _ = stream.stream.try_send(data);
|
||||
if stream.stream_type == StreamType::Tcp {
|
||||
stream.flow_control.store(
|
||||
stream
|
||||
|
|
|
@ -92,6 +92,10 @@ impl TryFrom<u8> for CloseReason {
|
|||
}
|
||||
}
|
||||
|
||||
trait Encode {
|
||||
fn encode(self, bytes: &mut BytesMut);
|
||||
}
|
||||
|
||||
/// Packet used to create a new stream.
|
||||
///
|
||||
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---connect).
|
||||
|
@ -120,9 +124,9 @@ impl ConnectPacket {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Bytes> for ConnectPacket {
|
||||
impl TryFrom<BytesMut> for ConnectPacket {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
|
||||
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < (1 + 2) {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
|
@ -134,13 +138,11 @@ impl TryFrom<Bytes> for ConnectPacket {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<ConnectPacket> for Bytes {
|
||||
fn from(packet: ConnectPacket) -> Self {
|
||||
let mut encoded = BytesMut::with_capacity(1 + 2 + packet.destination_hostname.len());
|
||||
encoded.put_u8(packet.stream_type.into());
|
||||
encoded.put_u16_le(packet.destination_port);
|
||||
encoded.extend(packet.destination_hostname.bytes());
|
||||
encoded.freeze()
|
||||
impl Encode for ConnectPacket {
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
bytes.put_u8(self.stream_type.into());
|
||||
bytes.put_u16_le(self.destination_port);
|
||||
bytes.extend(self.destination_hostname.bytes());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -160,9 +162,9 @@ impl ContinuePacket {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Bytes> for ContinuePacket {
|
||||
impl TryFrom<BytesMut> for ContinuePacket {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
|
||||
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < 4 {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
|
@ -172,11 +174,9 @@ impl TryFrom<Bytes> for ContinuePacket {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<ContinuePacket> for Bytes {
|
||||
fn from(packet: ContinuePacket) -> Self {
|
||||
let mut encoded = BytesMut::with_capacity(4);
|
||||
encoded.put_u32_le(packet.buffer_remaining);
|
||||
encoded.freeze()
|
||||
impl Encode for ContinuePacket {
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
bytes.put_u32_le(self.buffer_remaining);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -197,9 +197,9 @@ impl ClosePacket {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Bytes> for ClosePacket {
|
||||
impl TryFrom<BytesMut> for ClosePacket {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
|
||||
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < 1 {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
|
@ -209,11 +209,9 @@ impl TryFrom<Bytes> for ClosePacket {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<ClosePacket> for Bytes {
|
||||
fn from(packet: ClosePacket) -> Self {
|
||||
let mut encoded = BytesMut::with_capacity(1);
|
||||
encoded.put_u8(packet.reason as u8);
|
||||
encoded.freeze()
|
||||
impl Encode for ClosePacket {
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
bytes.put_u8(self.reason as u8);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -237,15 +235,13 @@ pub struct InfoPacket {
|
|||
pub extensions: Vec<AnyProtocolExtension>,
|
||||
}
|
||||
|
||||
impl From<InfoPacket> for Bytes {
|
||||
fn from(value: InfoPacket) -> Self {
|
||||
let mut bytes = BytesMut::with_capacity(2);
|
||||
bytes.put_u8(value.version.major);
|
||||
bytes.put_u8(value.version.minor);
|
||||
for extension in value.extensions {
|
||||
impl Encode for InfoPacket {
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
bytes.put_u8(self.version.major);
|
||||
bytes.put_u8(self.version.minor);
|
||||
for extension in self.extensions {
|
||||
bytes.extend(Bytes::from(extension));
|
||||
}
|
||||
bytes.freeze()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -276,18 +272,29 @@ impl PacketType {
|
|||
P::Info(_) => 0x05,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_packet_size(&self) -> usize {
|
||||
use PacketType as P;
|
||||
match self {
|
||||
P::Connect(p) => 1 + 2 + p.destination_hostname.len(),
|
||||
P::Data(p) => p.len(),
|
||||
P::Continue(_) => 4,
|
||||
P::Close(_) => 1,
|
||||
P::Info(_) => 2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PacketType> for Bytes {
|
||||
fn from(packet: PacketType) -> Self {
|
||||
impl Encode for PacketType {
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
use PacketType as P;
|
||||
match packet {
|
||||
P::Connect(x) => x.into(),
|
||||
P::Data(x) => x,
|
||||
P::Continue(x) => x.into(),
|
||||
P::Close(x) => x.into(),
|
||||
P::Info(x) => x.into(),
|
||||
}
|
||||
match self {
|
||||
P::Connect(x) => x.encode(bytes),
|
||||
P::Data(x) => bytes.extend(x),
|
||||
P::Continue(x) => x.encode(bytes),
|
||||
P::Close(x) => x.encode(bytes),
|
||||
P::Info(x) => x.encode(bytes),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -362,21 +369,13 @@ impl Packet {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn raw_encode(packet_type: u8, stream_id: u32, bytes: Bytes) -> BytesMut {
|
||||
let mut encoded = BytesMut::with_capacity(1 + 4 + bytes.len());
|
||||
encoded.put_u8(packet_type);
|
||||
encoded.put_u32_le(stream_id);
|
||||
encoded.extend(bytes);
|
||||
encoded
|
||||
}
|
||||
|
||||
fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result<Self, WispError> {
|
||||
fn parse_packet(packet_type: u8, mut bytes: BytesMut) -> Result<Self, WispError> {
|
||||
use PacketType as P;
|
||||
Ok(Self {
|
||||
stream_id: bytes.get_u32_le(),
|
||||
packet_type: match packet_type {
|
||||
0x01 => P::Connect(ConnectPacket::try_from(bytes)?),
|
||||
0x02 => P::Data(bytes),
|
||||
0x02 => P::Data(bytes.freeze()),
|
||||
0x03 => P::Continue(ContinuePacket::try_from(bytes)?),
|
||||
0x04 => P::Close(ClosePacket::try_from(bytes)?),
|
||||
// 0x05 is handled seperately
|
||||
|
@ -396,7 +395,7 @@ impl Packet {
|
|||
if frame.opcode != OpCode::Binary {
|
||||
return Err(WispError::WsFrameInvalidType);
|
||||
}
|
||||
let mut bytes = frame.payload.freeze();
|
||||
let mut bytes = frame.payload;
|
||||
if bytes.remaining() < 1 {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
|
@ -420,8 +419,8 @@ impl Packet {
|
|||
if frame.opcode != OpCode::Binary {
|
||||
return Err(WispError::WsFrameInvalidType);
|
||||
}
|
||||
let mut bytes = frame.payload.freeze();
|
||||
if bytes.remaining() < 1 {
|
||||
let mut bytes = frame.payload;
|
||||
if bytes.remaining() < 5 {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
let packet_type = bytes.get_u8();
|
||||
|
@ -432,7 +431,7 @@ impl Packet {
|
|||
})),
|
||||
0x02 => Ok(Some(Self {
|
||||
stream_id: bytes.get_u32_le(),
|
||||
packet_type: PacketType::Data(bytes),
|
||||
packet_type: PacketType::Data(bytes.freeze()),
|
||||
})),
|
||||
0x03 => Ok(Some(Self {
|
||||
stream_id: bytes.get_u32_le(),
|
||||
|
@ -448,7 +447,7 @@ impl Packet {
|
|||
.iter_mut()
|
||||
.find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
|
||||
{
|
||||
extension.handle_packet(bytes, read, write).await?;
|
||||
extension.handle_packet(bytes.freeze(), read, write).await?;
|
||||
Ok(None)
|
||||
} else {
|
||||
Err(WispError::InvalidPacketType)
|
||||
|
@ -458,7 +457,7 @@ impl Packet {
|
|||
}
|
||||
|
||||
fn parse_info(
|
||||
mut bytes: Bytes,
|
||||
mut bytes: BytesMut,
|
||||
role: Role,
|
||||
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
||||
) -> Result<Self, WispError> {
|
||||
|
@ -507,9 +506,17 @@ impl Packet {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Bytes> for Packet {
|
||||
impl Encode for Packet {
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
bytes.put_u8(self.packet_type.as_u8());
|
||||
bytes.put_u32_le(self.stream_id);
|
||||
self.packet_type.encode(bytes);
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<BytesMut> for Packet {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
|
||||
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < 1 {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
|
@ -520,11 +527,9 @@ impl TryFrom<Bytes> for Packet {
|
|||
|
||||
impl From<Packet> for BytesMut {
|
||||
fn from(packet: Packet) -> Self {
|
||||
Packet::raw_encode(
|
||||
packet.packet_type.as_u8(),
|
||||
packet.stream_id,
|
||||
packet.packet_type.into(),
|
||||
)
|
||||
let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size());
|
||||
packet.encode(&mut encoded);
|
||||
encoded
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -537,12 +542,12 @@ impl TryFrom<ws::Frame> for Packet {
|
|||
if frame.opcode != ws::OpCode::Binary {
|
||||
return Err(Self::Error::WsFrameInvalidType);
|
||||
}
|
||||
frame.payload.freeze().try_into()
|
||||
Packet::try_from(frame.payload)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Packet> for ws::Frame {
|
||||
fn from(packet: Packet) -> Self {
|
||||
Self::binary(packet.into())
|
||||
Self::binary(BytesMut::from(packet))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ use crate::{
|
|||
};
|
||||
|
||||
pub use async_io_stream::IoStream;
|
||||
use bytes::Bytes;
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use event_listener::Event;
|
||||
use flume as mpsc;
|
||||
use futures::{
|
||||
|
@ -114,7 +114,7 @@ impl MuxStreamWrite {
|
|||
}
|
||||
|
||||
self.tx
|
||||
.write_frame(Packet::new_data(self.stream_id, data).into())
|
||||
.write_frame(Frame::from(Packet::new_data(self.stream_id, data)))
|
||||
.await?;
|
||||
|
||||
if self.role == Role::Client && self.stream_type == StreamType::Tcp {
|
||||
|
@ -348,13 +348,11 @@ impl MuxProtocolExtensionStream {
|
|||
if self.is_closed.load(Ordering::Acquire) {
|
||||
return Err(WispError::StreamAlreadyClosed);
|
||||
}
|
||||
self.tx
|
||||
.write_frame(Frame::binary(Packet::raw_encode(
|
||||
packet_type,
|
||||
self.stream_id,
|
||||
data,
|
||||
)))
|
||||
.await
|
||||
let mut encoded = BytesMut::with_capacity(1 + 4 + data.len());
|
||||
encoded.put_u8(packet_type);
|
||||
encoded.put_u32_le(self.stream_id);
|
||||
encoded.extend(data);
|
||||
self.tx.write_frame(Frame::binary(encoded)).await
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue