add ability to send protocol extension packets

This commit is contained in:
Toshit Chawda 2024-04-16 21:57:27 -07:00
parent fd94f1245a
commit 6c41c54cf9
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 84 additions and 33 deletions

View file

@ -253,7 +253,7 @@ async fn accept_http(
} }
} }
async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result<bool, WispError> { async fn handle_mux(packet: ConnectPacket, stream: MuxStream) -> Result<bool, WispError> {
let uri = format!( let uri = format!(
"{}:{}", "{}:{}",
packet.destination_hostname, packet.destination_port packet.destination_hostname, packet.destination_port
@ -318,8 +318,8 @@ async fn accept_ws(
println!("{:?}: connected", addr); println!("{:?}: connected", addr);
// to prevent memory ""leaks"" because users are sending in packets way too fast the buffer // to prevent memory ""leaks"" because users are sending in packets way too fast the buffer
// size is set to 128 // size is set to 128
let (mut mux, fut) = if mux_options.enforce_auth { let (mux, fut) = if mux_options.enforce_auth {
let (mut mux, fut) = ServerMux::new(rx, tx, 128, Some(mux_options.auth.as_slice())).await?; let (mux, fut) = ServerMux::new(rx, tx, 128, Some(mux_options.auth.as_slice())).await?;
if !mux if !mux
.supported_extension_ids .supported_extension_ids
.iter() .iter()
@ -354,7 +354,7 @@ async fn accept_ws(
} }
}); });
while let Some((packet, mut stream)) = mux.server_new_stream().await { while let Some((packet, stream)) = mux.server_new_stream().await {
tokio::spawn(async move { tokio::spawn(async move {
if (mux_options.block_non_http if (mux_options.block_non_http
&& !(packet.destination_port == 80 || packet.destination_port == 443)) && !(packet.destination_port == 80 || packet.destination_port == 443))
@ -386,8 +386,8 @@ async fn accept_ws(
} }
} }
} }
let mut close_err = stream.get_close_handle(); let close_err = stream.get_close_handle();
let mut close_ok = stream.get_close_handle(); let close_ok = stream.get_close_handle();
let _ = handle_mux(packet, stream) let _ = handle_mux(packet, stream)
.or_else(|err| async move { .or_else(|err| async move {
let _ = close_err.close(CloseReason::Unexpected).await; let _ = close_err.close(CloseReason::Unexpected).await;

View file

@ -164,7 +164,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
extensions.push(Box::new(auth)); extensions.push(Box::new(auth));
} }
let (mut mux, fut) = if opts.wisp_v1 { let (mux, fut) = if opts.wisp_v1 {
ClientMux::new(rx, tx, None).await? ClientMux::new(rx, tx, None).await?
} else { } else {
ClientMux::new(rx, tx, Some(extensions.as_slice())).await? ClientMux::new(rx, tx, Some(extensions.as_slice())).await?
@ -212,7 +212,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let start_time = Instant::now(); let start_time = Instant::now();
for _ in 0..opts.streams { for _ in 0..opts.streams {
let (mut cr, mut cw) = mux let (cr, cw) = mux
.client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port) .client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port)
.await? .await?
.into_split(); .into_split();

View file

@ -272,6 +272,9 @@ impl MuxInner {
let _ = channel.send(Err(WispError::InvalidStreamId)); let _ = channel.send(Err(WispError::InvalidStreamId));
} }
} }
WsEvent::SendBytes(packet, channel) => {
let _ = channel.send(self.tx.write_frame(ws::Frame::binary(packet)).await);
}
WsEvent::CreateStream(stream_type, host, port, channel) => { WsEvent::CreateStream(stream_type, host, port, channel) => {
let ret: Result<MuxStream, WispError> = async { let ret: Result<MuxStream, WispError> = async {
let stream_id = next_free_stream_id; let stream_id = next_free_stream_id;
@ -552,11 +555,11 @@ impl ServerMux {
} }
/// Wait for a stream to be created. /// Wait for a stream to be created.
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> { pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
self.muxstream_recv.recv_async().await.ok() self.muxstream_recv.recv_async().await.ok()
} }
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> { async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.close_tx self.close_tx
.send_async(WsEvent::EndFut(reason)) .send_async(WsEvent::EndFut(reason))
.await .await
@ -567,7 +570,7 @@ impl ServerMux {
/// ///
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after /// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
/// this function is called. /// this function is called.
pub async fn close(&mut self) -> Result<(), WispError> { pub async fn close(&self) -> Result<(), WispError> {
self.close_internal(None).await self.close_internal(None).await
} }
@ -575,7 +578,7 @@ impl ServerMux {
/// ///
/// Also terminates the multiplexor future. Waiting for a new stream will never succed after /// Also terminates the multiplexor future. Waiting for a new stream will never succed after
/// this function is called. /// this function is called.
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> { pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions)) self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await .await
} }
@ -696,7 +699,7 @@ impl ClientMux {
/// Create a new stream, multiplexed through Wisp. /// Create a new stream, multiplexed through Wisp.
pub async fn client_new_stream( pub async fn client_new_stream(
&mut self, &self,
stream_type: StreamType, stream_type: StreamType,
host: String, host: String,
port: u16, port: u16,
@ -717,7 +720,7 @@ impl ClientMux {
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
} }
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> { async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.stream_tx self.stream_tx
.send_async(WsEvent::EndFut(reason)) .send_async(WsEvent::EndFut(reason))
.await .await
@ -728,7 +731,7 @@ impl ClientMux {
/// ///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this /// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function. /// function.
pub async fn close(&mut self) -> Result<(), WispError> { pub async fn close(&self) -> Result<(), WispError> {
self.close_internal(None).await self.close_internal(None).await
} }
@ -736,7 +739,7 @@ impl ClientMux {
/// ///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this /// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function. /// function.
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> { pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions)) self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await .await
} }

View file

@ -362,6 +362,14 @@ impl Packet {
} }
} }
pub(crate) fn raw_encode(packet_type: u8, stream_id: u32, bytes: Bytes) -> Bytes {
let mut encoded = BytesMut::with_capacity(1 + 4 + bytes.len());
encoded.put_u8(packet_type);
encoded.put_u32_le(stream_id);
encoded.extend(bytes);
encoded.freeze()
}
fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result<Self, WispError> { fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result<Self, WispError> {
use PacketType as P; use PacketType as P;
Ok(Self { Ok(Self {
@ -494,13 +502,11 @@ impl TryFrom<Bytes> for Packet {
impl From<Packet> for Bytes { impl From<Packet> for Bytes {
fn from(packet: Packet) -> Self { fn from(packet: Packet) -> Self {
let inner_u8 = packet.packet_type.as_u8(); Packet::raw_encode(
let inner = Bytes::from(packet.packet_type); packet.packet_type.as_u8(),
let mut encoded = BytesMut::with_capacity(1 + 4 + inner.len()); packet.stream_id,
encoded.put_u8(inner_u8); packet.packet_type.into(),
encoded.put_u32_le(packet.stream_id); )
encoded.extend(inner);
encoded.freeze()
} }
} }

View file

@ -21,6 +21,7 @@ use std::{
pub(crate) enum WsEvent { pub(crate) enum WsEvent {
SendPacket(Packet, oneshot::Sender<Result<(), WispError>>), SendPacket(Packet, oneshot::Sender<Result<(), WispError>>),
SendBytes(Bytes, oneshot::Sender<Result<(), WispError>>),
Close(Packet, oneshot::Sender<Result<(), WispError>>), Close(Packet, oneshot::Sender<Result<(), WispError>>),
CreateStream( CreateStream(
StreamType, StreamType,
@ -49,7 +50,7 @@ pub struct MuxStreamRead {
impl MuxStreamRead { impl MuxStreamRead {
/// Read an event from the stream. /// Read an event from the stream.
pub async fn read(&mut self) -> Option<Bytes> { pub async fn read(&self) -> Option<Bytes> {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return None; return None;
} }
@ -79,7 +80,7 @@ impl MuxStreamRead {
} }
pub(crate) fn into_stream(self) -> Pin<Box<dyn Stream<Item = Bytes> + Send>> { pub(crate) fn into_stream(self) -> Pin<Box<dyn Stream<Item = Bytes> + Send>> {
Box::pin(stream::unfold(self, |mut rx| async move { Box::pin(stream::unfold(self, |rx| async move {
Some((rx.read().await?, rx)) Some((rx.read().await?, rx))
})) }))
} }
@ -100,7 +101,7 @@ pub struct MuxStreamWrite {
impl MuxStreamWrite { impl MuxStreamWrite {
/// Write data to the stream. /// Write data to the stream.
pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> { pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed); return Err(WispError::StreamAlreadyClosed);
} }
@ -147,8 +148,17 @@ impl MuxStreamWrite {
} }
} }
/// Get a protocol extension stream to send protocol extension packets.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
MuxProtocolExtensionStream {
stream_id: self.stream_id,
tx: self.tx.clone(),
is_closed: self.is_closed.clone(),
}
}
/// Close the stream. You will no longer be able to write or read after this has been called. /// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> { pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed); return Err(WispError::StreamAlreadyClosed);
} }
@ -171,12 +181,12 @@ impl MuxStreamWrite {
let handle = self.get_close_handle(); let handle = self.get_close_handle();
Box::pin(sink_unfold::unfold( Box::pin(sink_unfold::unfold(
self, self,
|mut tx, data| async move { |tx, data| async move {
tx.write(data).await?; tx.write(data).await?;
Ok(tx) Ok(tx)
}, },
handle, handle,
move |mut handle| async { move |handle| async {
handle.close(CloseReason::Unknown).await?; handle.close(CloseReason::Unknown).await?;
Ok(handle) Ok(handle)
}, },
@ -246,12 +256,12 @@ impl MuxStream {
} }
/// Read an event from the stream. /// Read an event from the stream.
pub async fn read(&mut self) -> Option<Bytes> { pub async fn read(&self) -> Option<Bytes> {
self.rx.read().await self.rx.read().await
} }
/// Write data to the stream. /// Write data to the stream.
pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> { pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
self.tx.write(data).await self.tx.write(data).await
} }
@ -270,8 +280,13 @@ impl MuxStream {
self.tx.get_close_handle() self.tx.get_close_handle()
} }
/// Get a protocol extension stream to send protocol extension packets.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
self.tx.get_protocol_extension_stream()
}
/// Close the stream. You will no longer be able to write or read after this has been called. /// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> { pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
self.tx.close(reason).await self.tx.close(reason).await
} }
@ -300,7 +315,7 @@ pub struct MuxStreamCloser {
impl MuxStreamCloser { impl MuxStreamCloser {
/// Close the stream. You will no longer be able to write or read after this has been called. /// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> { pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed); return Err(WispError::StreamAlreadyClosed);
} }
@ -320,6 +335,33 @@ impl MuxStreamCloser {
} }
} }
/// Stream for sending arbitrary protocol extension packets.
pub struct MuxProtocolExtensionStream {
/// ID of the stream.
pub stream_id: u32,
tx: mpsc::Sender<WsEvent>,
is_closed: Arc<AtomicBool>,
}
impl MuxProtocolExtensionStream {
/// Send a protocol extension packet.
pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx
.send_async(WsEvent::SendBytes(
Packet::raw_encode(packet_type, self.stream_id, data),
tx,
))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
Ok(())
}
}
pin_project! { pin_project! {
/// Multiplexor stream that implements futures `Stream + Sink`. /// Multiplexor stream that implements futures `Stream + Sink`.
pub struct MuxStreamIo { pub struct MuxStreamIo {