custom wisp transport support

This commit is contained in:
Toshit Chawda 2024-08-16 23:29:33 -07:00
parent 80b68f1cee
commit 16268905fc
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
10 changed files with 313 additions and 135 deletions

13
Cargo.lock generated
View file

@ -535,7 +535,7 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
"ring", "ring",
"rustls-pki-types", "rustls-pki-types",
"send_wrapper", "send_wrapper 0.6.0",
"thiserror", "thiserror",
"tokio", "tokio",
"wasm-bindgen", "wasm-bindgen",
@ -726,7 +726,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
dependencies = [ dependencies = [
"gloo-timers", "gloo-timers",
"send_wrapper", "send_wrapper 0.4.0",
] ]
[[package]] [[package]]
@ -1488,6 +1488,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f638d531eccd6e23b980caf34876660d38e265409d8e99b397ab71eb3612fad0" checksum = "f638d531eccd6e23b980caf34876660d38e265409d8e99b397ab71eb3612fad0"
[[package]]
name = "send_wrapper"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd0b0ec5f1c1ca621c432a25813d8d60c88abe6d3e08a3eb9cf37d97a0fe3d73"
dependencies = [
"futures-core",
]
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.204" version = "1.0.204"

View file

@ -2,6 +2,6 @@ build.sh
Cargo.toml Cargo.toml
serve.py serve.py
src src
tests
test.sh
.cargo .cargo
index.html
demo.js

View file

@ -23,7 +23,7 @@ hyper-util-wasm = { git = "https://github.com/r58Playz/hyper-util-wasm", branch
js-sys = "0.3.69" js-sys = "0.3.69"
lazy_static = "1.5.0" lazy_static = "1.5.0"
pin-project-lite = "0.2.14" pin-project-lite = "0.2.14"
send_wrapper = "0.4.0" send_wrapper = { version = "0.6.0", features = ["futures"] }
thiserror = "1.0.61" thiserror = "1.0.61"
tokio = "1.38.0" tokio = "1.38.0"
wasm-bindgen = "0.2.92" wasm-bindgen = "0.2.92"

View file

@ -254,7 +254,6 @@ import initEpoxy, { EpoxyClient, EpoxyClientOptions, EpoxyHandlers, info as epox
} }
total_mux_multi = total_mux_multi / num_outer_tests; total_mux_multi = total_mux_multi / num_outer_tests;
log(`total avg mux (${num_outer_tests} tests of ${num_inner_tests} reqs): ${total_mux_multi} ms or ${total_mux_multi / 1000} s`); log(`total avg mux (${num_outer_tests} tests of ${num_inner_tests} reqs): ${total_mux_multi} ms or ${total_mux_multi / 1000} s`);
} else { } else {
console.time(); console.time();
let resp = await epoxy_client.fetch("https://www.example.com/"); let resp = await epoxy_client.fetch("https://www.example.com/");

View file

@ -111,7 +111,7 @@ pub struct EpoxyUdpStream {
#[wasm_bindgen] #[wasm_bindgen]
impl EpoxyUdpStream { impl EpoxyUdpStream {
pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self { pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self {
let (mut rx, tx) = stream.into_inner().into_split(); let (mut rx, tx) = stream.into_split();
let EpoxyHandlers { let EpoxyHandlers {
onopen, onopen,

View file

@ -18,21 +18,28 @@ use hyper::{body::Incoming, Uri};
use hyper_util_wasm::client::legacy::Client; use hyper_util_wasm::client::legacy::Client;
#[cfg(feature = "full")] #[cfg(feature = "full")]
use io_stream::{EpoxyIoStream, EpoxyUdpStream}; use io_stream::{EpoxyIoStream, EpoxyUdpStream};
use js_sys::{Array, Function, Object}; use js_sys::{Array, Function, Object, Promise};
use send_wrapper::SendWrapper;
use stream_provider::{StreamProvider, StreamProviderService}; use stream_provider::{StreamProvider, StreamProviderService};
use thiserror::Error; use thiserror::Error;
use utils::{ use utils::{
asyncread_to_readablestream_stream, convert_body, entries_of_object, is_null_body, is_redirect, asyncread_to_readablestream_stream, convert_body, entries_of_object, is_null_body, is_redirect,
object_get, object_set, object_truthy, IncomingBody, UriExt, WasmExecutor, object_get, object_set, object_truthy, IncomingBody, UriExt, WasmExecutor, WispTransportRead,
WispTransportWrite,
}; };
use wasm_bindgen::prelude::*; use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
use wasm_streams::ReadableStream; use wasm_streams::ReadableStream;
use web_sys::ResponseInit; use web_sys::ResponseInit;
#[cfg(feature = "full")] #[cfg(feature = "full")]
use websocket::EpoxyWebSocket; use websocket::EpoxyWebSocket;
use wisp_mux::CloseReason;
#[cfg(feature = "full")] #[cfg(feature = "full")]
use wisp_mux::StreamType; use wisp_mux::StreamType;
use wisp_mux::{
ws::{WebSocketRead, WebSocketWrite},
CloseReason,
};
use ws_wrapper::WebSocketWrapper;
#[cfg(feature = "full")] #[cfg(feature = "full")]
mod io_stream; mod io_stream;
@ -67,6 +74,15 @@ pub enum EpoxyError {
#[error("Fastwebsockets: {0:?} ({0})")] #[error("Fastwebsockets: {0:?} ({0})")]
FastWebSockets(#[from] fastwebsockets::WebSocketError), FastWebSockets(#[from] fastwebsockets::WebSocketError),
#[error("Custom wisp transport: {0}")]
WispTransport(String),
#[error("Invalid Wisp transport")]
InvalidWispTransport,
#[error("Invalid Wisp transport packet")]
InvalidWispTransportPacket,
#[error("Wisp transport already closed")]
WispTransportClosed,
#[error("Invalid URL scheme")] #[error("Invalid URL scheme")]
InvalidUrlScheme, InvalidUrlScheme,
#[error("No URL host found")] #[error("No URL host found")]
@ -99,6 +115,12 @@ pub enum EpoxyError {
ResponseNewFailed, ResponseNewFailed,
} }
impl EpoxyError {
pub fn wisp_transport(value: JsValue) -> Self {
Self::WispTransport(format!("{:?}", value))
}
}
impl From<EpoxyError> for JsValue { impl From<EpoxyError> for JsValue {
fn from(value: EpoxyError) -> Self { fn from(value: EpoxyError) -> Self {
JsError::from(value).into() JsError::from(value).into()
@ -137,7 +159,7 @@ impl From<InvalidMethod> for EpoxyError {
impl From<CloseReason> for EpoxyError { impl From<CloseReason> for EpoxyError {
fn from(value: CloseReason) -> Self { fn from(value: CloseReason) -> Self {
EpoxyError::WispCloseReason(value) EpoxyError::WispCloseReason(value)
} }
} }
@ -224,13 +246,79 @@ pub struct EpoxyClient {
#[wasm_bindgen] #[wasm_bindgen]
impl EpoxyClient { impl EpoxyClient {
#[wasm_bindgen(constructor)] #[wasm_bindgen(constructor)]
pub fn new(wisp_url: String, options: EpoxyClientOptions) -> Result<EpoxyClient, EpoxyError> { pub fn new(wisp_url: JsValue, options: EpoxyClientOptions) -> Result<EpoxyClient, EpoxyError> {
let wisp_url: Uri = wisp_url.try_into()?; let stream_provider = if let Some(wisp_url) = wisp_url.as_string() {
if wisp_url.scheme_str() != Some("wss") && wisp_url.scheme_str() != Some("ws") { let wisp_uri: Uri = wisp_url.clone().try_into()?;
return Err(EpoxyError::InvalidUrlScheme); if wisp_uri.scheme_str() != Some("wss") && wisp_uri.scheme_str() != Some("ws") {
} return Err(EpoxyError::InvalidUrlScheme);
}
let stream_provider = Arc::new(StreamProvider::new(wisp_url.to_string(), &options)?); let ws_protocols = options.websocket_protocols.clone();
Arc::new(StreamProvider::new(
Box::new(move || {
let wisp_url = wisp_url.clone();
let ws_protocols = ws_protocols.clone();
Box::pin(async move {
let (write, read) = WebSocketWrapper::connect(&wisp_url, &ws_protocols)?;
if !write.wait_for_open().await {
return Err(EpoxyError::WebSocketConnectFailed);
}
Ok((
Box::new(read) as Box<dyn WebSocketRead + Send + Sync>,
Box::new(write) as Box<dyn WebSocketWrite + Send + Sync>,
))
})
}),
&options,
)?)
} else if let Ok(wisp_transport) = wisp_url.dyn_into::<Function>() {
let wisp_transport = SendWrapper::new(wisp_transport);
Arc::new(StreamProvider::new(
Box::new(move || {
let wisp_transport = wisp_transport.clone();
Box::pin(SendWrapper::new(async move {
let transport = wisp_transport
.call0(&JsValue::NULL)
.map_err(EpoxyError::wisp_transport)?;
let transport = match transport.dyn_into::<Promise>() {
Ok(transport) => {
let fut = JsFuture::from(transport);
fut.await.map_err(EpoxyError::wisp_transport)?
}
Err(transport) => transport,
}
.into();
let read = WispTransportRead {
inner: SendWrapper::new(
wasm_streams::ReadableStream::from_raw(
object_get(&transport, "read").into(),
)
.into_stream(),
),
};
let write = WispTransportWrite {
inner: Some(SendWrapper::new(
wasm_streams::WritableStream::from_raw(
object_get(&transport, "write").into(),
)
.into_sink(),
)),
};
Ok((
Box::new(read) as Box<dyn WebSocketRead + Send + Sync>,
Box::new(write) as Box<dyn WebSocketWrite + Send + Sync>,
))
}))
}),
&options,
)?)
} else {
return Err(EpoxyError::InvalidWispTransport);
};
let service = StreamProviderService(stream_provider.clone()); let service = StreamProviderService(stream_provider.clone());
let client = Client::builder(WasmExecutor) let client = Client::builder(WasmExecutor)

View file

@ -1,10 +1,4 @@
use std::{ use std::{io::ErrorKind, pin::Pin, sync::Arc, task::Poll};
io::ErrorKind,
ops::{Deref, DerefMut},
pin::Pin,
sync::Arc,
task::Poll,
};
use futures_rustls::{ use futures_rustls::{
rustls::{ClientConfig, RootCertStore}, rustls::{ClientConfig, RootCertStore},
@ -22,10 +16,11 @@ use wasm_bindgen_futures::spawn_local;
use webpki_roots::TLS_SERVER_ROOTS; use webpki_roots::TLS_SERVER_ROOTS;
use wisp_mux::{ use wisp_mux::{
extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder}, extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder},
ClientMux, MuxStreamAsyncRW, MuxStreamCloser, MuxStreamIo, StreamType, ws::{WebSocketRead, WebSocketWrite},
ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType,
}; };
use crate::{console_log, ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError}; use crate::{console_log, EpoxyClientOptions, EpoxyError};
lazy_static! { lazy_static! {
static ref CLIENT_CONFIG: Arc<ClientConfig> = { static ref CLIENT_CONFIG: Arc<ClientConfig> = {
@ -38,117 +33,45 @@ lazy_static! {
}; };
} }
pin_project! { pub type ProviderUnencryptedStream = MuxStreamIo;
pub struct CloserWrapper<T> { pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW;
#[pin] pub type ProviderTlsAsyncRW = TlsStream<ProviderUnencryptedAsyncRW>;
pub inner: T, pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>;
pub closer: MuxStreamCloser, pub type ProviderWispTransportGenerator = Box<
} dyn Fn() -> Pin<
} Box<
dyn Future<
impl<T> CloserWrapper<T> { Output = Result<
pub fn new(inner: T, closer: MuxStreamCloser) -> Self { (
Self { inner, closer } Box<dyn WebSocketRead + Sync + Send>,
} Box<dyn WebSocketWrite + Sync + Send>,
),
pub fn into_inner(self) -> T { EpoxyError,
self.inner >,
} > + Sync + Send,
} >,
> + Sync + Send,
impl<T> Deref for CloserWrapper<T> { >;
type Target = T;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<T> DerefMut for CloserWrapper<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl<T: AsyncRead> AsyncRead for CloserWrapper<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_read(cx, buf)
}
fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &mut [std::io::IoSliceMut<'_>],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_read_vectored(cx, bufs)
}
}
impl<T: AsyncWrite> AsyncWrite for CloserWrapper<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_write(cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_write_vectored(cx, bufs)
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.project().inner.poll_close(cx)
}
}
impl From<CloserWrapper<MuxStreamIo>> for CloserWrapper<MuxStreamAsyncRW> {
fn from(value: CloserWrapper<MuxStreamIo>) -> Self {
let CloserWrapper { inner, closer } = value;
CloserWrapper::new(inner.into_asyncrw(), closer)
}
}
pub struct StreamProvider { pub struct StreamProvider {
wisp_url: String, wisp_generator: ProviderWispTransportGenerator,
wisp_v2: bool, wisp_v2: bool,
udp_extension: bool, udp_extension: bool,
websocket_protocols: Vec<String>,
current_client: Arc<Mutex<Option<ClientMux>>>, current_client: Arc<Mutex<Option<ClientMux>>>,
} }
pub type ProviderUnencryptedStream = CloserWrapper<MuxStreamIo>;
pub type ProviderUnencryptedAsyncRW = CloserWrapper<MuxStreamAsyncRW>;
pub type ProviderTlsAsyncRW = TlsStream<ProviderUnencryptedAsyncRW>;
pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>;
impl StreamProvider { impl StreamProvider {
pub fn new(wisp_url: String, options: &EpoxyClientOptions) -> Result<Self, EpoxyError> { pub fn new(
wisp_generator: ProviderWispTransportGenerator,
options: &EpoxyClientOptions,
) -> Result<Self, EpoxyError> {
Ok(Self { Ok(Self {
wisp_url, wisp_generator,
current_client: Arc::new(Mutex::new(None)), current_client: Arc::new(Mutex::new(None)),
wisp_v2: options.wisp_v2, wisp_v2: options.wisp_v2,
udp_extension: options.udp_extension_required, udp_extension: options.udp_extension_required,
websocket_protocols: options.websocket_protocols.clone(),
}) })
} }
@ -163,10 +86,9 @@ impl StreamProvider {
} else { } else {
None None
}; };
let (write, read) = WebSocketWrapper::connect(&self.wisp_url, &self.websocket_protocols)?;
if !write.wait_for_open().await { let (read, write) = (self.wisp_generator)().await?;
return Err(EpoxyError::WebSocketConnectFailed);
}
let client = ClientMux::create(read, write, extensions).await?; let client = ClientMux::create(read, write, extensions).await?;
let (mux, fut) = if self.udp_extension { let (mux, fut) = if self.udp_extension {
client.with_udp_extension_required().await? client.with_udp_extension_required().await?
@ -196,8 +118,7 @@ impl StreamProvider {
let locked = self.current_client.lock().await; let locked = self.current_client.lock().await;
if let Some(mux) = locked.as_ref() { if let Some(mux) = locked.as_ref() {
let stream = mux.client_new_stream(stream_type, host, port).await?; let stream = mux.client_new_stream(stream_type, host, port).await?;
let closer = stream.get_close_handle(); Ok(stream.into_io())
Ok(CloserWrapper::new(stream.into_io(), closer))
} else { } else {
self.create_client(locked).await?; self.create_client(locked).await?;
self.get_stream(stream_type, host, port).await self.get_stream(stream_type, host, port).await
@ -212,7 +133,10 @@ impl StreamProvider {
host: String, host: String,
port: u16, port: u16,
) -> Result<ProviderUnencryptedAsyncRW, EpoxyError> { ) -> Result<ProviderUnencryptedAsyncRW, EpoxyError> {
Ok(self.get_stream(stream_type, host, port).await?.into()) Ok(self
.get_stream(stream_type, host, port)
.await?
.into_asyncrw())
} }
pub async fn get_tls_stream( pub async fn get_tls_stream(
@ -233,7 +157,7 @@ impl StreamProvider {
Err((err, stream)) => { Err((err, stream)) => {
if matches!(err.kind(), ErrorKind::UnexpectedEof) { if matches!(err.kind(), ErrorKind::UnexpectedEof) {
// maybe actually a wisp error? // maybe actually a wisp error?
if let Some(reason) = stream.closer.get_close_reason() { if let Some(reason) = stream.get_close_reason() {
return Err(reason.into()); return Err(reason.into());
} }
} }

View file

@ -3,13 +3,20 @@ use std::{
task::{Context, Poll}, task::{Context, Poll},
}; };
use async_trait::async_trait;
use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut}; use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut};
use futures_util::{ready, AsyncRead, Future, Stream, TryStreamExt}; use futures_util::{ready, AsyncRead, Future, SinkExt, Stream, StreamExt, TryStreamExt};
use http::{HeaderValue, Uri}; use http::{HeaderValue, Uri};
use hyper::{body::Body, rt::Executor}; use hyper::{body::Body, rt::Executor};
use js_sys::{Array, JsString, Object, Uint8Array}; use js_sys::{Array, ArrayBuffer, JsString, Object, Uint8Array};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use send_wrapper::SendWrapper;
use wasm_bindgen::{prelude::*, JsCast, JsValue}; use wasm_bindgen::{prelude::*, JsCast, JsValue};
use wasm_streams::{readable::IntoStream, writable::IntoSink};
use wisp_mux::{
ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
WispError,
};
use crate::EpoxyError; use crate::EpoxyError;
@ -168,6 +175,64 @@ impl<R: AsyncRead> Stream for ReaderStream<R> {
} }
} }
pub struct WispTransportRead {
pub inner: SendWrapper<IntoStream<'static>>,
}
#[async_trait]
impl WebSocketRead for WispTransportRead {
async fn wisp_read_frame(
&mut self,
_tx: &LockedWebSocketWrite,
) -> Result<Frame<'static>, wisp_mux::WispError> {
let obj = self.inner.next().await;
if let Some(pkt) = obj {
let pkt =
pkt.map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x))))?;
let arr: ArrayBuffer = pkt.dyn_into().map_err(|_| {
WispError::WsImplError(Box::new(EpoxyError::InvalidWispTransportPacket))
})?;
Ok(Frame::binary(Payload::Bytes(
Uint8Array::new(&arr).to_vec().as_slice().into(),
)))
} else {
Ok(Frame::close(Payload::Borrowed(&[])))
}
}
}
pub struct WispTransportWrite {
pub inner: Option<SendWrapper<IntoSink<'static>>>,
}
#[async_trait]
impl WebSocketWrite for WispTransportWrite {
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
SendWrapper::new(
self.inner
.as_mut()
.ok_or_else(|| WispError::WsImplError(Box::new(EpoxyError::WispTransportClosed)))?
.send(Uint8Array::from(frame.payload.as_ref()).into()),
)
.await
.map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x))))
}
async fn wisp_close(&mut self) -> Result<(), WispError> {
SendWrapper::new(
self.inner
.take()
.ok_or_else(|| WispError::WsImplError(Box::new(EpoxyError::WispTransportClosed)))?
.take()
.abort(),
)
.await
.map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x))))
}
}
pub fn is_redirect(code: u16) -> bool { pub fn is_redirect(code: u16) -> bool {
[301, 302, 303, 307, 308].contains(&code) [301, 302, 303, 307, 308].contains(&code)
} }

View file

@ -93,6 +93,8 @@ impl MuxStreamRead {
/// Turn the read half into one that implements futures `Stream`, consuming it. /// Turn the read half into one that implements futures `Stream`, consuming it.
pub fn into_stream(self) -> MuxStreamIoStream { pub fn into_stream(self) -> MuxStreamIoStream {
MuxStreamIoStream { MuxStreamIoStream {
close_reason: self.close_reason.clone(),
is_closed: self.is_closed.clone(),
rx: self.into_inner_stream(), rx: self.into_inner_stream(),
} }
} }
@ -246,6 +248,8 @@ impl MuxStreamWrite {
/// Turn the write half into one that implements futures `Sink`, consuming it. /// Turn the write half into one that implements futures `Sink`, consuming it.
pub fn into_sink(self) -> MuxStreamIoSink { pub fn into_sink(self) -> MuxStreamIoSink {
MuxStreamIoSink { MuxStreamIoSink {
close_reason: self.close_reason.clone(),
is_closed: self.is_closed.clone(),
tx: self.into_inner_sink(), tx: self.into_inner_sink(),
} }
} }
@ -352,6 +356,11 @@ impl MuxStream {
self.tx.get_protocol_extension_stream() self.tx.get_protocol_extension_stream()
} }
/// Get the stream's close reason, if it was closed.
pub fn get_close_reason(&self) -> Option<CloseReason> {
self.rx.get_close_reason()
}
/// Close the stream. You will no longer be able to write or read after this has been called. /// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
self.tx.close(reason).await self.tx.close(reason).await
@ -455,6 +464,11 @@ impl MuxStreamIo {
} }
} }
/// Get the stream's close reason, if it was closed.
pub fn get_close_reason(&self) -> Option<CloseReason> {
self.rx.get_close_reason()
}
/// Split the stream into read and write parts, consuming it. /// Split the stream into read and write parts, consuming it.
pub fn into_split(self) -> (MuxStreamIoStream, MuxStreamIoSink) { pub fn into_split(self) -> (MuxStreamIoStream, MuxStreamIoSink) {
(self.rx, self.tx) (self.rx, self.tx)
@ -489,6 +503,8 @@ pin_project! {
pub struct MuxStreamIoStream { pub struct MuxStreamIoStream {
#[pin] #[pin]
rx: Pin<Box<dyn Stream<Item = Bytes> + Send>>, rx: Pin<Box<dyn Stream<Item = Bytes> + Send>>,
is_closed: Arc<AtomicBool>,
close_reason: Arc<AtomicCloseReason>,
} }
} }
@ -497,6 +513,15 @@ impl MuxStreamIoStream {
pub fn into_asyncread(self) -> MuxStreamAsyncRead { pub fn into_asyncread(self) -> MuxStreamAsyncRead {
MuxStreamAsyncRead::new(self) MuxStreamAsyncRead::new(self)
} }
/// Get the stream's close reason, if it was closed.
pub fn get_close_reason(&self) -> Option<CloseReason> {
if self.is_closed.load(Ordering::Acquire) {
Some(self.close_reason.load(Ordering::Acquire))
} else {
None
}
}
} }
impl Stream for MuxStreamIoStream { impl Stream for MuxStreamIoStream {
@ -511,6 +536,8 @@ pin_project! {
pub struct MuxStreamIoSink { pub struct MuxStreamIoSink {
#[pin] #[pin]
tx: Pin<Box<dyn Sink<Payload<'static>, Error = WispError> + Send>>, tx: Pin<Box<dyn Sink<Payload<'static>, Error = WispError> + Send>>,
is_closed: Arc<AtomicBool>,
close_reason: Arc<AtomicCloseReason>,
} }
} }
@ -519,6 +546,15 @@ impl MuxStreamIoSink {
pub fn into_asyncwrite(self) -> MuxStreamAsyncWrite { pub fn into_asyncwrite(self) -> MuxStreamAsyncWrite {
MuxStreamAsyncWrite::new(self) MuxStreamAsyncWrite::new(self)
} }
/// Get the stream's close reason, if it was closed.
pub fn get_close_reason(&self) -> Option<CloseReason> {
if self.is_closed.load(Ordering::Acquire) {
Some(self.close_reason.load(Ordering::Acquire))
} else {
None
}
}
} }
impl Sink<&[u8]> for MuxStreamIoSink { impl Sink<&[u8]> for MuxStreamIoSink {
@ -560,6 +596,11 @@ pin_project! {
} }
impl MuxStreamAsyncRW { impl MuxStreamAsyncRW {
/// Get the stream's close reason, if it was closed.
pub fn get_close_reason(&self) -> Option<CloseReason> {
self.rx.get_close_reason()
}
/// Split the stream into read and write parts, consuming it. /// Split the stream into read and write parts, consuming it.
pub fn into_split(self) -> (MuxStreamAsyncRead, MuxStreamAsyncWrite) { pub fn into_split(self) -> (MuxStreamAsyncRead, MuxStreamAsyncWrite) {
(self.rx, self.tx) (self.rx, self.tx)
@ -617,15 +658,26 @@ pin_project! {
pub struct MuxStreamAsyncRead { pub struct MuxStreamAsyncRead {
#[pin] #[pin]
rx: IntoAsyncRead<MuxStreamIoStream>, rx: IntoAsyncRead<MuxStreamIoStream>,
// state: Option<MuxStreamAsyncReadState> is_closed: Arc<AtomicBool>,
close_reason: Arc<AtomicCloseReason>,
} }
} }
impl MuxStreamAsyncRead { impl MuxStreamAsyncRead {
pub(crate) fn new(stream: MuxStreamIoStream) -> Self { pub(crate) fn new(stream: MuxStreamIoStream) -> Self {
Self { Self {
is_closed: stream.is_closed.clone(),
close_reason: stream.close_reason.clone(),
rx: stream.into_async_read(), rx: stream.into_async_read(),
// state: None, }
}
/// Get the stream's close reason, if it was closed.
pub fn get_close_reason(&self) -> Option<CloseReason> {
if self.is_closed.load(Ordering::Acquire) {
Some(self.close_reason.load(Ordering::Acquire))
} else {
None
} }
} }
} }
@ -664,6 +716,11 @@ impl MuxStreamAsyncWrite {
error: None, error: None,
} }
} }
/// Get the stream's close reason, if it was closed.
pub fn get_close_reason(&self) -> Option<CloseReason> {
self.tx.get_close_reason()
}
} }
impl AsyncWrite for MuxStreamAsyncWrite { impl AsyncWrite for MuxStreamAsyncWrite {

View file

@ -166,6 +166,23 @@ pub trait WebSocketRead {
} }
} }
#[async_trait]
impl WebSocketRead for Box<dyn WebSocketRead + Send + Sync> {
async fn wisp_read_frame(
&mut self,
tx: &LockedWebSocketWrite,
) -> Result<Frame<'static>, WispError> {
self.as_mut().wisp_read_frame(tx).await
}
async fn wisp_read_split(
&mut self,
tx: &LockedWebSocketWrite,
) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
self.as_mut().wisp_read_split(tx).await
}
}
/// Generic WebSocket write trait. /// Generic WebSocket write trait.
#[async_trait] #[async_trait]
pub trait WebSocketWrite { pub trait WebSocketWrite {
@ -188,6 +205,25 @@ pub trait WebSocketWrite {
} }
} }
#[async_trait]
impl WebSocketWrite for Box<dyn WebSocketWrite + Send + Sync> {
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
self.as_mut().wisp_write_frame(frame).await
}
async fn wisp_close(&mut self) -> Result<(), WispError> {
self.as_mut().wisp_close().await
}
async fn wisp_write_split(
&mut self,
header: Frame<'_>,
body: Frame<'_>,
) -> Result<(), WispError> {
self.as_mut().wisp_write_split(header, body).await
}
}
/// Locked WebSocket. /// Locked WebSocket.
#[derive(Clone)] #[derive(Clone)]
pub struct LockedWebSocketWrite(Arc<Mutex<Box<dyn WebSocketWrite + Send>>>); pub struct LockedWebSocketWrite(Arc<Mutex<Box<dyn WebSocketWrite + Send>>>);