add ability to send protocol extension packets

This commit is contained in:
Toshit Chawda 2024-04-16 21:57:27 -07:00
parent fd94f1245a
commit 6c41c54cf9
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 84 additions and 33 deletions

View file

@ -253,7 +253,7 @@ async fn accept_http(
}
}
async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result<bool, WispError> {
async fn handle_mux(packet: ConnectPacket, stream: MuxStream) -> Result<bool, WispError> {
let uri = format!(
"{}:{}",
packet.destination_hostname, packet.destination_port
@ -318,8 +318,8 @@ async fn accept_ws(
println!("{:?}: connected", addr);
// to prevent memory ""leaks"" because users are sending in packets way too fast the buffer
// size is set to 128
let (mut mux, fut) = if mux_options.enforce_auth {
let (mut mux, fut) = ServerMux::new(rx, tx, 128, Some(mux_options.auth.as_slice())).await?;
let (mux, fut) = if mux_options.enforce_auth {
let (mux, fut) = ServerMux::new(rx, tx, 128, Some(mux_options.auth.as_slice())).await?;
if !mux
.supported_extension_ids
.iter()
@ -354,7 +354,7 @@ async fn accept_ws(
}
});
while let Some((packet, mut stream)) = mux.server_new_stream().await {
while let Some((packet, stream)) = mux.server_new_stream().await {
tokio::spawn(async move {
if (mux_options.block_non_http
&& !(packet.destination_port == 80 || packet.destination_port == 443))
@ -386,8 +386,8 @@ async fn accept_ws(
}
}
}
let mut close_err = stream.get_close_handle();
let mut close_ok = stream.get_close_handle();
let close_err = stream.get_close_handle();
let close_ok = stream.get_close_handle();
let _ = handle_mux(packet, stream)
.or_else(|err| async move {
let _ = close_err.close(CloseReason::Unexpected).await;

View file

@ -164,7 +164,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
extensions.push(Box::new(auth));
}
let (mut mux, fut) = if opts.wisp_v1 {
let (mux, fut) = if opts.wisp_v1 {
ClientMux::new(rx, tx, None).await?
} else {
ClientMux::new(rx, tx, Some(extensions.as_slice())).await?
@ -212,7 +212,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let start_time = Instant::now();
for _ in 0..opts.streams {
let (mut cr, mut cw) = mux
let (cr, cw) = mux
.client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port)
.await?
.into_split();

View file

@ -272,6 +272,9 @@ impl MuxInner {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::SendBytes(packet, channel) => {
let _ = channel.send(self.tx.write_frame(ws::Frame::binary(packet)).await);
}
WsEvent::CreateStream(stream_type, host, port, channel) => {
let ret: Result<MuxStream, WispError> = async {
let stream_id = next_free_stream_id;
@ -552,11 +555,11 @@ impl ServerMux {
}
/// Wait for a stream to be created.
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> {
pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
self.muxstream_recv.recv_async().await.ok()
}
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.close_tx
.send_async(WsEvent::EndFut(reason))
.await
@ -567,7 +570,7 @@ impl ServerMux {
///
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
/// this function is called.
pub async fn close(&mut self) -> Result<(), WispError> {
pub async fn close(&self) -> Result<(), WispError> {
self.close_internal(None).await
}
@ -575,7 +578,7 @@ impl ServerMux {
///
/// Also terminates the multiplexor future. Waiting for a new stream will never succed after
/// this function is called.
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> {
pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await
}
@ -696,7 +699,7 @@ impl ClientMux {
/// Create a new stream, multiplexed through Wisp.
pub async fn client_new_stream(
&mut self,
&self,
stream_type: StreamType,
host: String,
port: u16,
@ -717,7 +720,7 @@ impl ClientMux {
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
}
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.stream_tx
.send_async(WsEvent::EndFut(reason))
.await
@ -728,7 +731,7 @@ impl ClientMux {
///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function.
pub async fn close(&mut self) -> Result<(), WispError> {
pub async fn close(&self) -> Result<(), WispError> {
self.close_internal(None).await
}
@ -736,7 +739,7 @@ impl ClientMux {
///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function.
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> {
pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await
}

View file

@ -362,6 +362,14 @@ impl Packet {
}
}
pub(crate) fn raw_encode(packet_type: u8, stream_id: u32, bytes: Bytes) -> Bytes {
let mut encoded = BytesMut::with_capacity(1 + 4 + bytes.len());
encoded.put_u8(packet_type);
encoded.put_u32_le(stream_id);
encoded.extend(bytes);
encoded.freeze()
}
fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result<Self, WispError> {
use PacketType as P;
Ok(Self {
@ -494,13 +502,11 @@ impl TryFrom<Bytes> for Packet {
impl From<Packet> for Bytes {
fn from(packet: Packet) -> Self {
let inner_u8 = packet.packet_type.as_u8();
let inner = Bytes::from(packet.packet_type);
let mut encoded = BytesMut::with_capacity(1 + 4 + inner.len());
encoded.put_u8(inner_u8);
encoded.put_u32_le(packet.stream_id);
encoded.extend(inner);
encoded.freeze()
Packet::raw_encode(
packet.packet_type.as_u8(),
packet.stream_id,
packet.packet_type.into(),
)
}
}

View file

@ -21,6 +21,7 @@ use std::{
pub(crate) enum WsEvent {
SendPacket(Packet, oneshot::Sender<Result<(), WispError>>),
SendBytes(Bytes, oneshot::Sender<Result<(), WispError>>),
Close(Packet, oneshot::Sender<Result<(), WispError>>),
CreateStream(
StreamType,
@ -49,7 +50,7 @@ pub struct MuxStreamRead {
impl MuxStreamRead {
/// Read an event from the stream.
pub async fn read(&mut self) -> Option<Bytes> {
pub async fn read(&self) -> Option<Bytes> {
if self.is_closed.load(Ordering::Acquire) {
return None;
}
@ -79,7 +80,7 @@ impl MuxStreamRead {
}
pub(crate) fn into_stream(self) -> Pin<Box<dyn Stream<Item = Bytes> + Send>> {
Box::pin(stream::unfold(self, |mut rx| async move {
Box::pin(stream::unfold(self, |rx| async move {
Some((rx.read().await?, rx))
}))
}
@ -100,7 +101,7 @@ pub struct MuxStreamWrite {
impl MuxStreamWrite {
/// Write data to the stream.
pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> {
pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
@ -147,8 +148,17 @@ impl MuxStreamWrite {
}
}
/// Get a protocol extension stream to send protocol extension packets.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
MuxProtocolExtensionStream {
stream_id: self.stream_id,
tx: self.tx.clone(),
is_closed: self.is_closed.clone(),
}
}
/// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> {
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
@ -171,12 +181,12 @@ impl MuxStreamWrite {
let handle = self.get_close_handle();
Box::pin(sink_unfold::unfold(
self,
|mut tx, data| async move {
|tx, data| async move {
tx.write(data).await?;
Ok(tx)
},
handle,
move |mut handle| async {
move |handle| async {
handle.close(CloseReason::Unknown).await?;
Ok(handle)
},
@ -246,12 +256,12 @@ impl MuxStream {
}
/// Read an event from the stream.
pub async fn read(&mut self) -> Option<Bytes> {
pub async fn read(&self) -> Option<Bytes> {
self.rx.read().await
}
/// Write data to the stream.
pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> {
pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
self.tx.write(data).await
}
@ -270,8 +280,13 @@ impl MuxStream {
self.tx.get_close_handle()
}
/// Get a protocol extension stream to send protocol extension packets.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
self.tx.get_protocol_extension_stream()
}
/// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> {
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
self.tx.close(reason).await
}
@ -300,7 +315,7 @@ pub struct MuxStreamCloser {
impl MuxStreamCloser {
/// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> {
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
@ -320,6 +335,33 @@ impl MuxStreamCloser {
}
}
/// Stream for sending arbitrary protocol extension packets.
pub struct MuxProtocolExtensionStream {
/// ID of the stream.
pub stream_id: u32,
tx: mpsc::Sender<WsEvent>,
is_closed: Arc<AtomicBool>,
}
impl MuxProtocolExtensionStream {
/// Send a protocol extension packet.
pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx
.send_async(WsEvent::SendBytes(
Packet::raw_encode(packet_type, self.stream_id, data),
tx,
))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
Ok(())
}
}
pin_project! {
/// Multiplexor stream that implements futures `Stream + Sink`.
pub struct MuxStreamIo {