add autoreconnect, wisp_mux 1.2.0

This commit is contained in:
Toshit Chawda 2024-03-08 22:40:15 -08:00
parent 5b4fb1392a
commit a8709255b2
20 changed files with 404 additions and 333 deletions

View file

@ -28,11 +28,10 @@ impl From<Frame<'_>> for crate::ws::Frame {
}
}
impl TryFrom<crate::ws::Frame> for Frame<'_> {
type Error = crate::WispError;
fn try_from(frame: crate::ws::Frame) -> Result<Self, Self::Error> {
impl From<crate::ws::Frame> for Frame<'_> {
fn from(frame: crate::ws::Frame) -> Self {
use crate::ws::OpCode::*;
Ok(match frame.opcode {
match frame.opcode {
Text => Self::text(Payload::Owned(frame.payload.to_vec())),
Binary => Self::binary(Payload::Owned(frame.payload.to_vec())),
Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())),
@ -43,13 +42,17 @@ impl TryFrom<crate::ws::Frame> for Frame<'_> {
Payload::Owned(frame.payload.to_vec()),
),
Pong => Self::pong(Payload::Owned(frame.payload.to_vec())),
})
}
}
}
impl From<WebSocketError> for crate::WispError {
fn from(err: WebSocketError) -> Self {
Self::WsImplError(Box::new(err))
if let WebSocketError::ConnectionClosed = err {
Self::WsImplSocketClosed
} else {
Self::WsImplError(Box::new(err))
}
}
}
@ -67,6 +70,8 @@ impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollector
impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
self.write_frame(frame.try_into()?).await.map_err(|e| e.into())
self.write_frame(frame.into())
.await
.map_err(|e| e.into())
}
}

View file

@ -11,12 +11,6 @@ mod fastwebsockets;
mod packet;
mod sink_unfold;
mod stream;
#[cfg(feature = "hyper_tower")]
#[cfg_attr(docsrs, doc(cfg(feature = "hyper_tower")))]
pub mod tokioio;
#[cfg(feature = "hyper_tower")]
#[cfg_attr(docsrs, doc(cfg(feature = "hyper_tower")))]
pub mod tower;
pub mod ws;
pub use crate::packet::*;
@ -140,10 +134,10 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
R: ws::WebSocketRead,
{
let ret = futures::select! {
x = self.server_close_loop(close_rx).fuse() => x,
x = self.server_bg_loop(close_rx).fuse() => x,
x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x
};
self.stream_map.lock().await.iter().for_each(|x| {
self.stream_map.lock().await.drain().for_each(|x| {
let _ =
x.1.stream
.unbounded_send(MuxEvent::Close(ClosePacket::new(CloseReason::Unknown)));
@ -151,7 +145,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
ret
}
async fn server_close_loop(
async fn server_bg_loop(
&self,
mut close_rx: mpsc::UnboundedReceiver<WsEvent>,
) -> Result<(), WispError> {
@ -168,6 +162,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::EndFut => break,
}
}
Ok(())
@ -186,66 +181,62 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
while let Ok(frame) = rx.wisp_read_frame(&self.tx).await {
if let Ok(packet) = Packet::try_from(frame) {
use PacketType::*;
match packet.packet {
Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::unbounded();
let stream_type = inner_packet.stream_type;
let flow_control: Arc<AtomicU32> = AtomicU32::new(buffer_size).into();
let flow_control_event: Arc<Event> = Event::new().into();
loop {
let packet: Packet = rx.wisp_read_frame(&self.tx).await?.try_into()?;
use PacketType::*;
match packet.packet_type {
Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::unbounded();
let stream_type = inner_packet.stream_type;
let flow_control: Arc<AtomicU32> = AtomicU32::new(buffer_size).into();
let flow_control_event: Arc<Event> = Event::new().into();
self.stream_map.lock().await.insert(
packet.stream_id,
MuxMapValue {
stream: ch_tx,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
},
self.stream_map.lock().await.insert(
packet.stream_id,
MuxMapValue {
stream: ch_tx,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
},
);
muxstream_sender
.unbounded_send((
inner_packet,
MuxStream::new(
packet.stream_id,
Role::Server,
stream_type,
ch_rx,
self.tx.clone(),
self.close_tx.clone(),
AtomicBool::new(false).into(),
flow_control,
flow_control_event,
),
))
.map_err(|x| WispError::Other(Box::new(x)))?;
}
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(MuxEvent::Send(data));
stream.flow_control.store(
stream
.flow_control
.load(Ordering::Acquire)
.saturating_sub(1),
Ordering::Release,
);
muxstream_sender
.unbounded_send((
inner_packet,
MuxStream::new(
packet.stream_id,
Role::Server,
stream_type,
ch_rx,
self.tx.clone(),
self.close_tx.clone(),
AtomicBool::new(false).into(),
flow_control,
flow_control_event,
),
))
.map_err(|x| WispError::Other(Box::new(x)))?;
}
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(MuxEvent::Send(data));
stream.flow_control.store(
stream.flow_control
.load(Ordering::Acquire)
.saturating_sub(1),
Ordering::Release,
);
}
}
Continue(_) => unreachable!(),
Close(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(MuxEvent::Close(inner_packet));
}
self.stream_map.lock().await.remove(&packet.stream_id);
}
}
} else {
break;
Continue(_) => unreachable!(),
Close(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(MuxEvent::Close(inner_packet));
}
self.stream_map.lock().await.remove(&packet.stream_id);
}
}
}
drop(muxstream_sender);
Ok(())
}
}
@ -272,6 +263,8 @@ pub struct ServerMux<W>
where
W: ws::WebSocketWrite + Send + 'static,
{
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
close_tx: mpsc::UnboundedSender<WsEvent>,
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream<W>)>,
}
@ -290,7 +283,11 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
let write = ws::LockedWebSocketWrite::new(write);
let map = Arc::new(Mutex::new(HashMap::new()));
(
Self { muxstream_recv: rx },
Self {
muxstream_recv: rx,
close_tx: close_tx.clone(),
stream_map: map.clone(),
},
ServerMuxInner {
tx: write,
close_tx,
@ -304,6 +301,19 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream<W>)> {
self.muxstream_recv.next().await
}
/// Close all streams.
///
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
/// this function is called.
pub async fn close(&self, reason: CloseReason) {
self.stream_map.lock().await.drain().for_each(|x| {
let _ =
x.1.stream
.unbounded_send(MuxEvent::Close(ClosePacket::new(reason)));
});
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
}
}
struct ClientMuxInner<W>
@ -346,6 +356,7 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::EndFut => break,
}
}
Ok(())
@ -355,10 +366,11 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
where
R: ws::WebSocketRead,
{
while let Ok(frame) = rx.wisp_read_frame(&self.tx).await {
loop {
let frame = rx.wisp_read_frame(&self.tx).await?;
if let Ok(packet) = Packet::try_from(frame) {
use PacketType::*;
match packet.packet {
match packet.packet_type {
Connect(_) => unreachable!(),
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
@ -382,7 +394,6 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
}
}
}
Ok(())
}
}
@ -425,7 +436,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
if first_packet.stream_id != 0 {
return Err(WispError::InvalidStreamId);
}
if let PacketType::Continue(packet) = first_packet.packet {
if let PacketType::Continue(packet) = first_packet.packet_type {
let (tx, rx) = mpsc::unbounded::<WsEvent>();
let map = Arc::new(Mutex::new(HashMap::new()));
Ok((
@ -487,4 +498,17 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
evt,
))
}
/// Close all streams.
///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function.
pub async fn close(&self, reason: CloseReason) {
self.stream_map.lock().await.drain().for_each(|x| {
let _ =
x.1.stream
.unbounded_send(MuxEvent::Close(ClosePacket::new(reason)));
});
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
}
}

View file

@ -242,8 +242,8 @@ impl From<PacketType> for Vec<u8> {
pub struct Packet {
/// Stream this packet is associated with.
pub stream_id: u32,
/// Packet recieved.
pub packet: PacketType,
/// Packet type recieved.
pub packet_type: PacketType,
}
impl Packet {
@ -251,7 +251,7 @@ impl Packet {
///
/// The helper functions should be used for most use cases.
pub fn new(stream_id: u32, packet: PacketType) -> Self {
Self { stream_id, packet }
Self { stream_id, packet_type: packet }
}
/// Create a new connect packet.
@ -263,7 +263,7 @@ impl Packet {
) -> Self {
Self {
stream_id,
packet: PacketType::Connect(ConnectPacket::new(
packet_type: PacketType::Connect(ConnectPacket::new(
stream_type,
destination_port,
destination_hostname,
@ -275,7 +275,7 @@ impl Packet {
pub fn new_data(stream_id: u32, data: Bytes) -> Self {
Self {
stream_id,
packet: PacketType::Data(data),
packet_type: PacketType::Data(data),
}
}
@ -283,7 +283,7 @@ impl Packet {
pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self {
Self {
stream_id,
packet: PacketType::Continue(ContinuePacket::new(buffer_remaining)),
packet_type: PacketType::Continue(ContinuePacket::new(buffer_remaining)),
}
}
@ -291,7 +291,7 @@ impl Packet {
pub fn new_close(stream_id: u32, reason: CloseReason) -> Self {
Self {
stream_id,
packet: PacketType::Close(ClosePacket::new(reason)),
packet_type: PacketType::Close(ClosePacket::new(reason)),
}
}
}
@ -306,7 +306,7 @@ impl TryFrom<Bytes> for Packet {
use PacketType::*;
Ok(Self {
stream_id: bytes.get_u32_le(),
packet: match packet_type {
packet_type: match packet_type {
0x01 => Connect(ConnectPacket::try_from(bytes)?),
0x02 => Data(bytes),
0x03 => Continue(ContinuePacket::try_from(bytes)?),
@ -320,9 +320,9 @@ impl TryFrom<Bytes> for Packet {
impl From<Packet> for Vec<u8> {
fn from(packet: Packet) -> Self {
let mut encoded = Self::with_capacity(1 + 4);
encoded.push(packet.packet.as_u8());
encoded.push(packet.packet_type.as_u8());
encoded.put_u32_le(packet.stream_id);
encoded.extend(Vec::<u8>::from(packet.packet));
encoded.extend(Vec::<u8>::from(packet.packet_type));
encoded
}
}

View file

@ -26,6 +26,7 @@ pub enum MuxEvent {
pub(crate) enum WsEvent {
Close(u32, crate::CloseReason, oneshot::Sender<Result<(), crate::WispError>>),
EndFut,
}
/// Read side of a multiplexor stream.

View file

@ -1,175 +0,0 @@
#![allow(dead_code)]
//! hyper_util::rt::tokio::TokioIo
use std::{
pin::Pin,
task::{Context, Poll},
};
use pin_project_lite::pin_project;
pin_project! {
/// A wrapping implementing hyper IO traits for a type that
/// implements Tokio's IO traits.
#[derive(Debug)]
pub struct TokioIo<T> {
#[pin]
inner: T,
}
}
impl<T> TokioIo<T> {
/// Wrap a type implementing Tokio's IO traits.
pub fn new(inner: T) -> Self {
Self { inner }
}
/// Borrow the inner type.
pub fn inner(&self) -> &T {
&self.inner
}
/// Mut borrow the inner type.
pub fn inner_mut(&mut self) -> &mut T {
&mut self.inner
}
/// Consume this wrapper and get the inner type.
pub fn into_inner(self) -> T {
self.inner
}
}
impl<T> hyper::rt::Read for TokioIo<T>
where
T: tokio::io::AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut buf: hyper::rt::ReadBufCursor<'_>,
) -> Poll<Result<(), std::io::Error>> {
let n = unsafe {
let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
Poll::Ready(Ok(())) => tbuf.filled().len(),
other => return other,
}
};
unsafe {
buf.advance(n);
}
Poll::Ready(Ok(()))
}
}
impl<T> hyper::rt::Write for TokioIo<T>
where
T: tokio::io::AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
}
fn is_write_vectored(&self) -> bool {
tokio::io::AsyncWrite::is_write_vectored(&self.inner)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
}
}
impl<T> tokio::io::AsyncRead for TokioIo<T>
where
T: hyper::rt::Read,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
tbuf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<Result<(), std::io::Error>> {
//let init = tbuf.initialized().len();
let filled = tbuf.filled().len();
let sub_filled = unsafe {
let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());
match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) {
Poll::Ready(Ok(())) => buf.filled().len(),
other => return other,
}
};
let n_filled = filled + sub_filled;
// At least sub_filled bytes had to have been initialized.
let n_init = sub_filled;
unsafe {
tbuf.assume_init(n_init);
tbuf.set_filled(n_filled);
}
Poll::Ready(Ok(()))
}
}
impl<T> tokio::io::AsyncWrite for TokioIo<T>
where
T: hyper::rt::Write,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
hyper::rt::Write::poll_write(self.project().inner, cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
hyper::rt::Write::poll_flush(self.project().inner, cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
hyper::rt::Write::poll_shutdown(self.project().inner, cx)
}
fn is_write_vectored(&self) -> bool {
hyper::rt::Write::is_write_vectored(&self.inner)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
}
}
impl<T> hyper_util_wasm::client::legacy::connect::Connection for TokioIo<T> {
fn connected(&self) -> hyper_util_wasm::client::legacy::connect::Connected {
hyper_util_wasm::client::legacy::connect::Connected::new()
}
}

View file

@ -1,43 +0,0 @@
//! Helper that implements a Tower Service for a client multiplexor.
use crate::{tokioio::TokioIo, ws::WebSocketWrite, ClientMux, MuxStreamIo, StreamType, WispError};
use async_io_stream::IoStream;
use futures::{
task::{Context, Poll},
Future,
};
use std::sync::Arc;
/// Wrapper struct that implements a Tower Service sfor a client multiplexor.
pub struct ServiceWrapper<W: WebSocketWrite + Send + 'static>(pub Arc<ClientMux<W>>);
impl<W: WebSocketWrite + Send + 'static> tower_service::Service<hyper::Uri> for ServiceWrapper<W> {
type Response = TokioIo<IoStream<MuxStreamIo, Vec<u8>>>;
type Error = WispError;
type Future = impl Future<Output = Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: hyper::Uri) -> Self::Future {
let mux = self.0.clone();
async move {
Ok(TokioIo::new(
mux.client_new_stream(
StreamType::Tcp,
req.host().ok_or(WispError::UriHasNoHost)?.to_string(),
req.port().ok_or(WispError::UriHasNoPort)?.into(),
)
.await?
.into_io()
.into_asyncrw(),
))
}
}
}
impl<W: WebSocketWrite + Send + 'static> Clone for ServiceWrapper<W> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}