use knockoff dynosaur to remove async_trait on wsr/wsw

This commit is contained in:
Toshit Chawda 2024-11-23 15:00:12 -08:00
parent 5e54465e58
commit 9129d767f8
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
31 changed files with 692 additions and 258 deletions

9
Cargo.lock generated
View file

@ -698,7 +698,6 @@ name = "epoxy-client"
version = "2.1.15"
dependencies = [
"async-compression",
"async-trait",
"bytes",
"cfg-if",
"event-listener",
@ -755,6 +754,7 @@ dependencies = [
"libc",
"log",
"nix",
"pin-project-lite",
"pty-process",
"regex",
"rustls-pemfile",
@ -1850,6 +1850,12 @@ dependencies = [
"quick-error",
]
[[package]]
name = "reusable-box-future"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e0e61cd21fbddd85fbd9367b775660a01d388c08a61c6d2824af480b0309bb9"
[[package]]
name = "ring"
version = "0.17.8"
@ -3015,6 +3021,7 @@ dependencies = [
"getrandom",
"nohash-hasher",
"pin-project-lite",
"reusable-box-future",
"thiserror",
"tokio",
]

View file

@ -8,7 +8,6 @@ crate-type = ["cdylib"]
[dependencies]
async-compression = { version = "0.4.12", features = ["futures-io", "gzip", "brotli"], optional = true }
async-trait = "0.1.81"
bytes = "1.7.1"
cfg-if = "1.0.0"
event-listener = "5.3.1"

View file

@ -1,5 +1,5 @@
#![feature(let_chains, impl_trait_in_assoc_type)]
use std::{error::Error, str::FromStr, sync::Arc};
use std::{error::Error, pin::Pin, str::FromStr, sync::Arc};
#[cfg(feature = "full")]
use async_compression::futures::bufread as async_comp;
@ -7,7 +7,7 @@ use bytes::{Bytes, BytesMut};
use cfg_if::cfg_if;
#[cfg(feature = "full")]
use futures_util::future::Either;
use futures_util::{StreamExt, TryStreamExt};
use futures_util::{Stream, StreamExt, TryStreamExt};
use http::{
header::{
InvalidHeaderName, InvalidHeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH,
@ -41,7 +41,7 @@ use websocket::EpoxyWebSocket;
use wisp_mux::StreamType;
use wisp_mux::{
generic::GenericWebSocketRead,
ws::{WebSocketRead, WebSocketWrite},
ws::{EitherWebSocketRead, EitherWebSocketWrite},
CloseReason,
};
use ws_wrapper::WebSocketWrapper;
@ -343,7 +343,7 @@ fn create_wisp_transport(function: Function) -> ProviderWispTransportGenerator {
}
.into();
let read = GenericWebSocketRead::new(SendWrapper::new(
let read = GenericWebSocketRead::new(Box::pin(SendWrapper::new(
wasm_streams::ReadableStream::from_raw(object_get(&transport, "read").into())
.into_stream()
.map(|x| {
@ -355,15 +355,16 @@ fn create_wisp_transport(function: Function) -> ProviderWispTransportGenerator {
Uint8Array::new(&arr).to_vec().as_slice(),
))
}),
));
))
as Pin<Box<dyn Stream<Item = Result<BytesMut, EpoxyError>> + Send>>);
let write: WritableStream = object_get(&transport, "write").into();
let write = WispTransportWrite {
inner: SendWrapper::new(write.get_writer().map_err(EpoxyError::wisp_transport)?),
};
Ok((
Box::new(read) as Box<dyn WebSocketRead + Send>,
Box::new(write) as Box<dyn WebSocketWrite + Send>,
EitherWebSocketRead::Right(read),
EitherWebSocketWrite::Right(write),
))
}))
})
@ -421,8 +422,8 @@ impl EpoxyClient {
}
}
Ok((
Box::new(read) as Box<dyn WebSocketRead + Send>,
Box::new(write) as Box<dyn WebSocketWrite + Send>,
EitherWebSocketRead::Left(read),
EitherWebSocketWrite::Left(write),
))
})
}),

View file

@ -1,5 +1,6 @@
use std::{io::ErrorKind, pin::Pin, sync::Arc, task::Poll};
use bytes::BytesMut;
use cfg_if::cfg_if;
use futures_rustls::{
rustls::{ClientConfig, RootCertStore},
@ -8,7 +9,7 @@ use futures_rustls::{
use futures_util::{
future::Either,
lock::{Mutex, MutexGuard},
AsyncRead, AsyncWrite, Future,
AsyncRead, AsyncWrite, Future, Stream,
};
use hyper_util_wasm::client::legacy::connect::{ConnectSvc, Connected, Connection};
use pin_project_lite::pin_project;
@ -16,18 +17,30 @@ use wasm_bindgen_futures::spawn_local;
use webpki_roots::TLS_SERVER_ROOTS;
use wisp_mux::{
extensions::{udp::UdpProtocolExtensionBuilder, AnyProtocolExtensionBuilder},
ws::{WebSocketRead, WebSocketWrite},
generic::GenericWebSocketRead,
ws::{EitherWebSocketRead, EitherWebSocketWrite},
ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, WispV2Handshake,
};
use crate::{
console_error, console_log, utils::{IgnoreCloseNotify, NoCertificateVerification}, EpoxyClientOptions, EpoxyError
console_error, console_log,
utils::{IgnoreCloseNotify, NoCertificateVerification, WispTransportWrite},
ws_wrapper::{WebSocketReader, WebSocketWrapper},
EpoxyClientOptions, EpoxyError,
};
pub type ProviderUnencryptedStream = MuxStreamIo;
pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW;
pub type ProviderTlsAsyncRW = IgnoreCloseNotify;
pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>;
pub type ProviderWispTransportRead = EitherWebSocketRead<
WebSocketReader,
GenericWebSocketRead<
Pin<Box<dyn Stream<Item = Result<BytesMut, EpoxyError>> + Send>>,
EpoxyError,
>,
>;
pub type ProviderWispTransportWrite = EitherWebSocketWrite<WebSocketWrapper, WispTransportWrite>;
pub type ProviderWispTransportGenerator = Box<
dyn Fn(
bool,
@ -35,10 +48,7 @@ pub type ProviderWispTransportGenerator = Box<
Box<
dyn Future<
Output = Result<
(
Box<dyn WebSocketRead + Send>,
Box<dyn WebSocketWrite + Send>,
),
(ProviderWispTransportRead, ProviderWispTransportWrite),
EpoxyError,
>,
> + Sync
@ -54,7 +64,7 @@ pub struct StreamProvider {
wisp_v2: bool,
udp_extension: bool,
current_client: Arc<Mutex<Option<ClientMux>>>,
current_client: Arc<Mutex<Option<ClientMux<ProviderWispTransportWrite>>>>,
h2_config: Arc<ClientConfig>,
client_config: Arc<ClientConfig>,
@ -115,7 +125,7 @@ impl StreamProvider {
async fn create_client(
&self,
mut locked: MutexGuard<'_, Option<ClientMux>>,
mut locked: MutexGuard<'_, Option<ClientMux<ProviderWispTransportWrite>>>,
) -> Result<(), EpoxyError> {
let extensions_vec: Vec<AnyProtocolExtensionBuilder> =
vec![AnyProtocolExtensionBuilder::new(
@ -140,7 +150,11 @@ impl StreamProvider {
spawn_local(async move {
match fut.await {
Ok(_) => console_log!("epoxy: wisp multiplexor task ended successfully"),
Err(x) => console_error!("epoxy: wisp multiplexor task ended with an error: {} {:?}", x, x),
Err(x) => console_error!(
"epoxy: wisp multiplexor task ended with an error: {} {:?}",
x,
x
),
}
current_client.lock().await.take();
});

View file

@ -1,4 +1,7 @@
use std::{pin::Pin, task::{Context, Poll}};
use std::{
pin::Pin,
task::{Context, Poll},
};
use bytes::Bytes;
use futures_util::{AsyncRead, Stream, StreamExt, TryStreamExt};
@ -12,7 +15,6 @@ use crate::{console_error, EpoxyError};
use super::ReaderStream;
#[wasm_bindgen(inline_js = r#"
export function ws_protocol() {
return (

View file

@ -8,7 +8,6 @@ use std::{
task::{Context, Poll},
};
use async_trait::async_trait;
use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut};
use futures_util::{ready, AsyncRead, Future, Stream};
use http::{HeaderValue, Uri};
@ -179,7 +178,6 @@ pub struct WispTransportWrite {
pub inner: SendWrapper<WritableStreamDefaultWriter>,
}
#[async_trait]
impl WebSocketWrite for WispTransportWrite {
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
SendWrapper::new(async {

View file

@ -1,5 +1,8 @@
use std::{
io::ErrorKind, pin::Pin, sync::Arc, task::{Context, Poll}
io::ErrorKind,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures_rustls::{

View file

@ -3,7 +3,6 @@ use std::sync::{
Arc,
};
use async_trait::async_trait;
use bytes::BytesMut;
use event_listener::Event;
use flume::Receiver;
@ -14,7 +13,7 @@ use thiserror::Error;
use wasm_bindgen::{closure::Closure, JsCast, JsValue};
use web_sys::{BinaryType, MessageEvent, WebSocket};
use wisp_mux::{
ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
ws::{Frame, LockingWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
WispError,
};
@ -66,11 +65,10 @@ pub struct WebSocketReader {
close_event: Arc<Event>,
}
#[async_trait]
impl WebSocketRead for WebSocketReader {
async fn wisp_read_frame(
&mut self,
_: &LockedWebSocketWrite,
_: &dyn LockingWebSocketWrite,
) -> Result<Frame<'static>, WispError> {
use WebSocketMessage as M;
if self.closed.load(Ordering::Acquire) {
@ -185,7 +183,6 @@ impl WebSocketWrapper {
}
}
#[async_trait]
impl WebSocketWrite for WebSocketWrapper {
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
use wisp_mux::ws::OpCode::*;

View file

@ -24,6 +24,7 @@ lazy_static = "1.5.0"
libc = { version = "0.2.158", optional = true }
log = { version = "0.4.22", features = ["serde", "std"] }
nix = { version = "0.29.0", features = ["term"] }
pin-project-lite = "0.2.15"
pty-process = { version = "0.4.0", features = ["async", "tokio"], optional = true }
regex = "1.10.6"
rustls-pemfile = "2.1.3"

View file

@ -25,7 +25,7 @@ use wisp_mux::{
};
use crate::{
route::WispResult,
route::{WispResult, WispStreamWrite},
stream::{ClientStream, ResolvedPacket},
CLIENTS, CONFIG,
};
@ -58,7 +58,7 @@ async fn copy_read_fast(
}
async fn copy_write_fast(
muxtx: MuxStreamWrite,
muxtx: MuxStreamWrite<WispStreamWrite>,
tcprx: OwnedReadHalf,
#[cfg(feature = "speed-limit")] limiter: async_speed_limit::Limiter<
async_speed_limit::clock::StandardClock,
@ -83,7 +83,7 @@ async fn copy_write_fast(
async fn handle_stream(
connect: ConnectPacket,
muxstream: MuxStream,
muxstream: MuxStream<WispStreamWrite>,
id: String,
event: Arc<Event>,
#[cfg(feature = "twisp")] twisp_map: twisp::TwispMap,

View file

@ -37,6 +37,8 @@ mod route;
mod stats;
#[doc(hidden)]
mod stream;
#[doc(hidden)]
mod util_chain;
#[doc(hidden)]
type Client = (DashMap<Uuid, (ConnectPacket, ConnectPacket)>, bool);

View file

@ -2,7 +2,7 @@ use std::{fmt::Display, future::Future, io::Cursor};
use anyhow::Context;
use bytes::Bytes;
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector};
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector, WebSocketRead, WebSocketWrite};
use http_body_util::Full;
use hyper::{
body::Incoming, header::SEC_WEBSOCKET_PROTOCOL, server::conn::http1::Builder,
@ -10,25 +10,30 @@ use hyper::{
};
use hyper_util::rt::TokioIo;
use log::{debug, error, trace};
use tokio::io::AsyncReadExt;
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
use wisp_mux::{
generic::{GenericWebSocketRead, GenericWebSocketWrite},
ws::{WebSocketRead, WebSocketWrite},
ws::{EitherWebSocketRead, EitherWebSocketWrite},
};
use crate::{
config::SocketTransport,
generate_stats,
listener::{ServerStream, ServerStreamExt},
listener::{ServerStream, ServerStreamExt, ServerStreamRead, ServerStreamWrite},
stream::WebSocketStreamWrapper,
util_chain::{chain, Chain},
CONFIG,
};
pub type WispResult = (
Box<dyn WebSocketRead + Send>,
Box<dyn WebSocketWrite + Send>,
);
pub type WispStreamRead = EitherWebSocketRead<
WebSocketRead<Chain<Cursor<Bytes>, ServerStreamRead>>,
GenericWebSocketRead<FramedRead<ServerStreamRead, LengthDelimitedCodec>, std::io::Error>,
>;
pub type WispStreamWrite = EitherWebSocketWrite<
WebSocketWrite<ServerStreamWrite>,
GenericWebSocketWrite<FramedWrite<ServerStreamWrite, LengthDelimitedCodec>, std::io::Error>,
>;
pub type WispResult = (WispStreamRead, WispStreamWrite);
pub enum ServerRouteResult {
Wisp(WispResult, bool),
@ -190,12 +195,15 @@ pub async fn route(
.downcast::<TokioIo<ServerStream>>()
.unwrap();
let (r, w) = parts.io.into_inner().split();
(Cursor::new(parts.read_buf).chain(r), w)
(chain(Cursor::new(parts.read_buf), r), w)
});
(callback)(
ServerRouteResult::Wisp(
(Box::new(read), Box::new(write)),
(
EitherWebSocketRead::Left(read),
EitherWebSocketWrite::Left(write),
),
is_v2,
),
maybe_ip,
@ -229,7 +237,13 @@ pub async fn route(
let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec));
(callback)(
ServerRouteResult::Wisp((Box::new(read), Box::new(write)), true),
ServerRouteResult::Wisp(
(
EitherWebSocketRead::Right(read),
EitherWebSocketWrite::Right(write),
),
true,
),
None,
);
}

View file

@ -44,7 +44,6 @@ pub enum ClientStream {
Invalid,
}
// taken from rust 1.82.0
fn ipv4_is_global(addr: &Ipv4Addr) -> bool {
!(addr.octets()[0] == 0 // "This network"

100
server/src/util_chain.rs Normal file
View file

@ -0,0 +1,100 @@
// taken from tokio io util
use std::{
fmt, io,
pin::Pin,
task::{Context, Poll},
};
use futures_util::ready;
use pin_project_lite::pin_project;
use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
pin_project! {
pub struct Chain<T, U> {
#[pin]
first: T,
#[pin]
second: U,
done_first: bool,
}
}
pub fn chain<T, U>(first: T, second: U) -> Chain<T, U>
where
T: AsyncRead,
U: AsyncRead,
{
Chain {
first,
second,
done_first: false,
}
}
impl<T, U> fmt::Debug for Chain<T, U>
where
T: fmt::Debug,
U: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Chain")
.field("t", &self.first)
.field("u", &self.second)
.finish()
}
}
impl<T, U> AsyncRead for Chain<T, U>
where
T: AsyncRead,
U: AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let me = self.project();
if !*me.done_first {
let rem = buf.remaining();
ready!(me.first.poll_read(cx, buf))?;
if buf.remaining() == rem {
*me.done_first = true;
} else {
return Poll::Ready(Ok(()));
}
}
me.second.poll_read(cx, buf)
}
}
impl<T, U> AsyncBufRead for Chain<T, U>
where
T: AsyncBufRead,
U: AsyncBufRead,
{
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
let me = self.project();
if !*me.done_first {
match ready!(me.first.poll_fill_buf(cx)?) {
[] => {
*me.done_first = true;
}
buf => return Poll::Ready(Ok(buf)),
}
}
me.second.poll_fill_buf(cx)
}
fn consume(self: Pin<&mut Self>, amt: usize) {
let me = self.project();
if !*me.done_first {
me.first.consume(amt)
} else {
me.second.consume(amt)
}
}
}

View file

@ -245,7 +245,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
}));
threads.push(tokio::spawn(async move {
loop {
cr.read().await;
let _ = cr.read().await;
}
}));
}

View file

@ -23,6 +23,7 @@ futures = "0.3.30"
getrandom = { version = "0.2.15", features = ["std"], optional = true }
nohash-hasher = "0.2.0"
pin-project-lite = "0.2.14"
reusable-box-future = "0.2.0"
thiserror = "1.0.65"
tokio = { version = "1.39.3", optional = true, default-features = false }

View file

@ -11,7 +11,7 @@ use ed25519::{
};
use crate::{
ws::{LockedWebSocketWrite, WebSocketRead},
ws::{DynWebSocketRead, LockingWebSocketWrite},
Role, WispError,
};
@ -183,8 +183,8 @@ impl ProtocolExtension for CertAuthProtocolExtension {
async fn handle_handshake(
&mut self,
_: &mut dyn WebSocketRead,
_: &LockedWebSocketWrite,
_: &mut DynWebSocketRead,
_: &dyn LockingWebSocketWrite,
) -> Result<(), WispError> {
Ok(())
}
@ -192,8 +192,8 @@ impl ProtocolExtension for CertAuthProtocolExtension {
async fn handle_packet(
&mut self,
_: Bytes,
_: &mut dyn WebSocketRead,
_: &LockedWebSocketWrite,
_: &mut DynWebSocketRead,
_: &dyn LockingWebSocketWrite,
) -> Result<(), WispError> {
Ok(())
}

View file

@ -14,7 +14,7 @@ use async_trait::async_trait;
use bytes::{BufMut, Bytes, BytesMut};
use crate::{
ws::{LockedWebSocketWrite, WebSocketRead},
ws::{DynWebSocketRead, LockingWebSocketWrite},
Role, WispError,
};
@ -105,16 +105,16 @@ pub trait ProtocolExtension: std::fmt::Debug + Sync + Send + 'static {
/// This should be used to send or receive data before any streams are created.
async fn handle_handshake(
&mut self,
read: &mut dyn WebSocketRead,
write: &LockedWebSocketWrite,
read: &mut DynWebSocketRead,
write: &dyn LockingWebSocketWrite,
) -> Result<(), WispError>;
/// Handle receiving a packet.
async fn handle_packet(
&mut self,
packet: Bytes,
read: &mut dyn WebSocketRead,
write: &LockedWebSocketWrite,
read: &mut DynWebSocketRead,
write: &dyn LockingWebSocketWrite,
) -> Result<(), WispError>;
/// Clone the protocol extension.

View file

@ -6,7 +6,7 @@ use async_trait::async_trait;
use bytes::Bytes;
use crate::{
ws::{LockedWebSocketWrite, WebSocketRead},
ws::{DynWebSocketRead, LockingWebSocketWrite},
Role, WispError,
};
@ -48,8 +48,8 @@ impl ProtocolExtension for MotdProtocolExtension {
async fn handle_handshake(
&mut self,
_: &mut dyn WebSocketRead,
_: &LockedWebSocketWrite,
_: &mut DynWebSocketRead,
_: &dyn LockingWebSocketWrite,
) -> Result<(), WispError> {
Ok(())
}
@ -57,8 +57,8 @@ impl ProtocolExtension for MotdProtocolExtension {
async fn handle_packet(
&mut self,
_: Bytes,
_: &mut dyn WebSocketRead,
_: &LockedWebSocketWrite,
_: &mut DynWebSocketRead,
_: &dyn LockingWebSocketWrite,
) -> Result<(), WispError> {
Ok(())
}

View file

@ -9,7 +9,7 @@ use async_trait::async_trait;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use crate::{
ws::{LockedWebSocketWrite, WebSocketRead},
ws::{DynWebSocketRead, LockingWebSocketWrite},
Role, WispError,
};
@ -94,17 +94,17 @@ impl ProtocolExtension for PasswordProtocolExtension {
async fn handle_handshake(
&mut self,
_read: &mut dyn WebSocketRead,
_write: &LockedWebSocketWrite,
_: &mut DynWebSocketRead,
_: &dyn LockingWebSocketWrite,
) -> Result<(), WispError> {
Ok(())
}
async fn handle_packet(
&mut self,
_packet: Bytes,
_read: &mut dyn WebSocketRead,
_write: &LockedWebSocketWrite,
_: Bytes,
_: &mut DynWebSocketRead,
_: &dyn LockingWebSocketWrite,
) -> Result<(), WispError> {
Err(WispError::ExtensionImplNotSupported)
}

View file

@ -5,7 +5,7 @@ use async_trait::async_trait;
use bytes::Bytes;
use crate::{
ws::{LockedWebSocketWrite, WebSocketRead},
ws::{DynWebSocketRead, LockingWebSocketWrite},
WispError,
};
@ -40,8 +40,8 @@ impl ProtocolExtension for UdpProtocolExtension {
async fn handle_handshake(
&mut self,
_: &mut dyn WebSocketRead,
_: &LockedWebSocketWrite,
_: &mut DynWebSocketRead,
_: &dyn LockingWebSocketWrite,
) -> Result<(), WispError> {
Ok(())
}
@ -49,8 +49,8 @@ impl ProtocolExtension for UdpProtocolExtension {
async fn handle_packet(
&mut self,
_: Bytes,
_: &mut dyn WebSocketRead,
_: &LockedWebSocketWrite,
_: &mut DynWebSocketRead,
_: &dyn LockingWebSocketWrite,
) -> Result<(), WispError> {
Ok(())
}

View file

@ -2,7 +2,6 @@
use std::ops::Deref;
use async_trait::async_trait;
use bytes::BytesMut;
use fastwebsockets::{
CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketRead,
@ -10,7 +9,7 @@ use fastwebsockets::{
};
use tokio::io::{AsyncRead, AsyncWrite};
use crate::{ws::LockedWebSocketWrite, WispError};
use crate::{ws::LockingWebSocketWrite, WispError};
fn match_payload(payload: Payload<'_>) -> crate::ws::Payload<'_> {
match payload {
@ -87,27 +86,25 @@ impl From<WebSocketError> for crate::WispError {
}
}
#[async_trait]
impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
async fn wisp_read_frame(
&mut self,
tx: &LockedWebSocketWrite,
tx: &dyn LockingWebSocketWrite,
) -> Result<crate::ws::Frame<'static>, WispError> {
Ok(self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
.read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await })
.await?
.into())
}
}
#[async_trait]
impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for WebSocketRead<S> {
async fn wisp_read_frame(
&mut self,
tx: &LockedWebSocketWrite,
tx: &dyn LockingWebSocketWrite,
) -> Result<crate::ws::Frame<'static>, WispError> {
let mut frame = self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
.read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await })
.await?;
if frame.opcode == OpCode::Continuation {
@ -121,7 +118,7 @@ impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for WebSocketRead<S>
while !frame.fin {
frame = self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
.read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await })
.await?;
if frame.opcode != OpCode::Continuation {
@ -142,11 +139,11 @@ impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for WebSocketRead<S>
async fn wisp_read_split(
&mut self,
tx: &LockedWebSocketWrite,
tx: &dyn LockingWebSocketWrite,
) -> Result<(crate::ws::Frame<'static>, Option<crate::ws::Frame<'static>>), WispError> {
let mut frame_cnt = 1;
let mut frame = self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
.read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await })
.await?;
let mut extra_frame = None;
@ -161,7 +158,7 @@ impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for WebSocketRead<S>
while !frame.fin {
frame = self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
.read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await })
.await?;
if frame.opcode != OpCode::Continuation {
@ -197,7 +194,6 @@ impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for WebSocketRead<S>
}
}
#[async_trait]
impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame<'_>) -> Result<(), WispError> {
self.write_frame(frame.into()).await.map_err(|e| e.into())

View file

@ -1,12 +1,11 @@
//! WebSocketRead + WebSocketWrite implementation for generic `Stream + Sink`s.
use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
use futures::{Sink, SinkExt, Stream, StreamExt};
use std::error::Error;
use crate::{
ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead, WebSocketWrite},
ws::{Frame, LockingWebSocketWrite, OpCode, Payload, WebSocketRead, WebSocketWrite},
WispError,
};
@ -30,13 +29,12 @@ impl<T: Stream<Item = Result<BytesMut, E>> + Send + Unpin, E: Error + Sync + Sen
}
}
#[async_trait]
impl<T: Stream<Item = Result<BytesMut, E>> + Send + Unpin, E: Error + Sync + Send + 'static>
WebSocketRead for GenericWebSocketRead<T, E>
{
async fn wisp_read_frame(
&mut self,
_tx: &LockedWebSocketWrite,
_tx: &dyn LockingWebSocketWrite,
) -> Result<Frame<'static>, WispError> {
match self.0.next().await {
Some(data) => Ok(Frame::binary(Payload::Bytes(
@ -67,7 +65,6 @@ impl<T: Sink<Bytes, Error = E> + Send + Unpin, E: Error + Sync + Send + 'static>
}
}
#[async_trait]
impl<T: Sink<Bytes, Error = E> + Send + Unpin, E: Error + Sync + Send + 'static> WebSocketWrite
for GenericWebSocketWrite<T, E>
{

View file

@ -12,7 +12,7 @@ use futures::channel::oneshot;
use crate::{
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
mux::send_info_packet,
ws::{LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
ws::{DynWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType,
WispError,
};
@ -24,9 +24,9 @@ use super::{
WispV2Handshake,
};
async fn handshake<R: WebSocketRead>(
async fn handshake<R: WebSocketRead + 'static, W: WebSocketWrite>(
rx: &mut R,
tx: &LockedWebSocketWrite,
tx: &LockedWebSocketWrite<W>,
v2_info: Option<WispV2Handshake>,
) -> Result<(WispHandshakeResult, u32), WispError> {
if let Some(WispV2Handshake {
@ -47,7 +47,9 @@ async fn handshake<R: WebSocketRead>(
let mut supported_extensions = get_supported_extensions(info.extensions, &mut builders);
for extension in supported_extensions.iter_mut() {
extension.handle_handshake(rx, tx).await?;
extension
.handle_handshake(DynWebSocketRead::from_mut(rx), tx)
.await?;
}
Ok((
@ -86,34 +88,36 @@ async fn handshake<R: WebSocketRead>(
}
/// Client side multiplexor.
pub struct ClientMux {
pub struct ClientMux<W: WebSocketWrite + 'static> {
/// Whether the connection was downgraded to Wisp v1.
///
/// If this variable is true you must assume no extensions are supported.
pub downgraded: bool,
/// Extensions that are supported by both sides.
pub supported_extensions: Vec<AnyProtocolExtension>,
actor_tx: mpsc::Sender<WsEvent>,
tx: LockedWebSocketWrite,
actor_tx: mpsc::Sender<WsEvent<W>>,
tx: LockedWebSocketWrite<W>,
actor_exited: Arc<AtomicBool>,
}
impl ClientMux {
impl<W: WebSocketWrite + 'static> ClientMux<W> {
/// Create a new client side multiplexor.
///
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created.
pub async fn create<R, W>(
pub async fn create<R>(
mut rx: R,
tx: W,
wisp_v2: Option<WispV2Handshake>,
) -> Result<MuxResult<ClientMux, impl Future<Output = Result<(), WispError>> + Send>, WispError>
) -> Result<
MuxResult<ClientMux<W>, impl Future<Output = Result<(), WispError>> + Send>,
WispError,
>
where
R: WebSocketRead + Send,
W: WebSocketWrite + Send + 'static,
R: WebSocketRead + 'static,
{
let tx = LockedWebSocketWrite::new(Box::new(tx));
let tx = LockedWebSocketWrite::new(tx);
let (handshake_result, buffer_size) = handshake(&mut rx, &tx, wisp_v2).await?;
let (extensions, extra_packet) = handshake_result.kind.into_parts();
@ -146,7 +150,7 @@ impl ClientMux {
stream_type: StreamType,
host: String,
port: u16,
) -> Result<MuxStream, WispError> {
) -> Result<MuxStream<W>, WispError> {
if self.actor_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded);
}
@ -206,7 +210,7 @@ impl ClientMux {
}
/// Get a protocol extension stream for sending packets with stream id 0.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream<W> {
MuxProtocolExtensionStream {
stream_id: 0,
tx: self.tx.clone(),
@ -215,13 +219,13 @@ impl ClientMux {
}
}
impl Drop for ClientMux {
impl<W: WebSocketWrite + 'static> Drop for ClientMux<W> {
fn drop(&mut self) {
let _ = self.actor_tx.send(WsEvent::EndFut(None));
}
}
impl Multiplexor for ClientMux {
impl<W: WebSocketWrite + 'static> Multiplexor for ClientMux<W> {
fn has_extension(&self, extension_id: u8) -> bool {
self.supported_extensions
.iter()

View file

@ -5,23 +5,23 @@ use std::sync::{
use crate::{
extensions::AnyProtocolExtension,
ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead, WebSocketWrite},
AtomicCloseReason, ClosePacket, CloseReason, ConnectPacket, MuxStream, Packet, PacketType,
Role, StreamType, WispError,
};
use bytes::{Bytes, BytesMut};
use event_listener::Event;
use flume as mpsc;
use futures::{channel::oneshot, select, FutureExt};
use futures::{channel::oneshot, select, stream::unfold, FutureExt, StreamExt};
use nohash_hasher::IntMap;
pub(crate) enum WsEvent {
pub(crate) enum WsEvent<W: WebSocketWrite + 'static> {
Close(Packet<'static>, oneshot::Sender<Result<(), WispError>>),
CreateStream(
StreamType,
String,
u16,
oneshot::Sender<Result<MuxStream, WispError>>,
oneshot::Sender<Result<MuxStream<W>, WispError>>,
),
SendPing(Payload<'static>, oneshot::Sender<Result<(), WispError>>),
SendPong(Payload<'static>),
@ -43,20 +43,21 @@ struct MuxMapValue {
is_closed_event: Arc<Event>,
}
pub struct MuxInner<R: WebSocketRead + Send> {
pub struct MuxInner<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> {
// gets taken by the mux task
rx: Option<R>,
// gets taken by the mux task
maybe_downgrade_packet: Option<Packet<'static>>,
tx: LockedWebSocketWrite,
extensions: Vec<AnyProtocolExtension>,
tx: LockedWebSocketWrite<W>,
// gets taken by the mux task
extensions: Option<Vec<AnyProtocolExtension>>,
tcp_extensions: Vec<u8>,
role: Role,
// gets taken by the mux task
actor_rx: Option<mpsc::Receiver<WsEvent>>,
actor_tx: mpsc::Sender<WsEvent>,
actor_rx: Option<mpsc::Receiver<WsEvent<W>>>,
actor_tx: mpsc::Sender<WsEvent<W>>,
fut_exited: Arc<AtomicBool>,
stream_map: IntMap<u32, MuxMapValue>,
@ -64,16 +65,16 @@ pub struct MuxInner<R: WebSocketRead + Send> {
buffer_size: u32,
target_buffer_size: u32,
server_tx: mpsc::Sender<(ConnectPacket, MuxStream)>,
server_tx: mpsc::Sender<(ConnectPacket, MuxStream<W>)>,
}
pub struct MuxInnerResult<R: WebSocketRead + Send> {
pub mux: MuxInner<R>,
pub struct MuxInnerResult<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> {
pub mux: MuxInner<R, W>,
pub actor_exited: Arc<AtomicBool>,
pub actor_tx: mpsc::Sender<WsEvent>,
pub actor_tx: mpsc::Sender<WsEvent<W>>,
}
impl<R: WebSocketRead + Send> MuxInner<R> {
impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
fn get_tcp_extensions(extensions: &[AnyProtocolExtension]) -> Vec<u8> {
extensions
.iter()
@ -83,18 +84,19 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
.collect()
}
#[allow(clippy::type_complexity)]
pub fn new_server(
rx: R,
maybe_downgrade_packet: Option<Packet<'static>>,
tx: LockedWebSocketWrite,
tx: LockedWebSocketWrite<W>,
extensions: Vec<AnyProtocolExtension>,
buffer_size: u32,
) -> (
MuxInnerResult<R>,
mpsc::Receiver<(ConnectPacket, MuxStream)>,
MuxInnerResult<R, W>,
mpsc::Receiver<(ConnectPacket, MuxStream<W>)>,
) {
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent>(256);
let (server_tx, server_rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent<W>>(256);
let (server_tx, server_rx) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>();
let ret_fut_tx = fut_tx.clone();
let fut_exited = Arc::new(AtomicBool::new(false));
@ -110,7 +112,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
fut_exited: fut_exited.clone(),
tcp_extensions: Self::get_tcp_extensions(&extensions),
extensions,
extensions: Some(extensions),
buffer_size,
target_buffer_size: ((buffer_size as u64 * 90) / 100) as u32,
@ -130,12 +132,12 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
pub fn new_client(
rx: R,
maybe_downgrade_packet: Option<Packet<'static>>,
tx: LockedWebSocketWrite,
tx: LockedWebSocketWrite<W>,
extensions: Vec<AnyProtocolExtension>,
buffer_size: u32,
) -> MuxInnerResult<R> {
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent>(256);
let (server_tx, _) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
) -> MuxInnerResult<R, W> {
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent<W>>(256);
let (server_tx, _) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>();
let ret_fut_tx = fut_tx.clone();
let fut_exited = Arc::new(AtomicBool::new(false));
@ -150,7 +152,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
fut_exited: fut_exited.clone(),
tcp_extensions: Self::get_tcp_extensions(&extensions),
extensions,
extensions: Some(extensions),
buffer_size,
target_buffer_size: 0,
@ -183,7 +185,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
&mut self,
stream_id: u32,
stream_type: StreamType,
) -> Result<(MuxMapValue, MuxStream), WispError> {
) -> Result<(MuxMapValue, MuxStream<W>), WispError> {
let (ch_tx, ch_rx) = mpsc::bounded(if self.role == Role::Server {
self.buffer_size as usize
} else {
@ -241,11 +243,12 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
}
async fn process_wisp_message(
&mut self,
rx: &mut R,
msg: Result<(Frame<'static>, Option<Frame<'static>>), WispError>,
) -> Result<Option<WsEvent>, WispError> {
let (mut frame, optional_frame) = msg?;
tx: &LockedWebSocketWrite<W>,
extensions: &mut [AnyProtocolExtension],
msg: (Frame<'static>, Option<Frame<'static>>),
) -> Result<Option<WsEvent<W>>, WispError> {
let (mut frame, optional_frame) = msg;
if frame.opcode == OpCode::Close {
return Ok(None);
} else if frame.opcode == OpCode::Ping {
@ -262,8 +265,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
}
}
let packet =
Packet::maybe_handle_extension(frame, &mut self.extensions, rx, &self.tx).await?;
let packet = Packet::maybe_handle_extension(frame, extensions, rx, tx).await?;
Ok(Some(WsEvent::WispMessage(packet, optional_frame)))
}
@ -271,36 +273,47 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
async fn stream_loop(&mut self) -> Result<(), WispError> {
let mut next_free_stream_id: u32 = 1;
let mut rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?;
let rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?;
let maybe_downgrade_packet = self.maybe_downgrade_packet.take();
let tx = self.tx.clone();
let fut_rx = self.actor_rx.take().ok_or(WispError::MuxTaskStarted)?;
let extensions = self.extensions.take().ok_or(WispError::MuxTaskStarted)?;
if let Some(downgrade_packet) = maybe_downgrade_packet {
if self.handle_packet(downgrade_packet, None).await? {
return Ok(());
}
}
let mut read_stream = Box::pin(unfold(
(rx, tx.clone(), extensions),
|(mut rx, tx, mut extensions)| async {
let ret = async {
let msg = rx.wisp_read_split(&tx).await?;
Self::process_wisp_message(&mut rx, &tx, &mut extensions, msg).await
}
.await;
ret.transpose().map(|x| (x, (rx, tx, extensions)))
},
))
.fuse();
let mut recv_fut = fut_rx.recv_async().fuse();
let mut read_fut = rx.wisp_read_split(&tx).fuse();
while let Some(msg) = select! {
x = recv_fut => {
drop(recv_fut);
recv_fut = fut_rx.recv_async().fuse();
Ok(x.ok())
},
x = read_fut => {
drop(read_fut);
let ret = self.process_wisp_message(&mut rx, x).await;
read_fut = rx.wisp_read_split(&tx).fuse();
ret
x = read_stream.next() => {
x.transpose()
}
}? {
match msg {
WsEvent::CreateStream(stream_type, host, port, channel) => {
let ret: Result<MuxStream, WispError> = async {
let ret: Result<MuxStream<W>, WispError> = async {
let stream_id = next_free_stream_id;
let next_stream_id = next_free_stream_id
.checked_add(1)

View file

@ -8,7 +8,7 @@ pub use server::ServerMux;
use crate::{
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder},
ws::LockedWebSocketWrite,
ws::{LockedWebSocketWrite, WebSocketWrite},
CloseReason, Packet, PacketType, Role, WispError,
};
@ -35,8 +35,8 @@ impl WispHandshakeResultKind {
}
}
async fn send_info_packet(
write: &LockedWebSocketWrite,
async fn send_info_packet<W: WebSocketWrite>(
write: &LockedWebSocketWrite<W>,
builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<(), WispError> {
write

View file

@ -11,7 +11,7 @@ use futures::channel::oneshot;
use crate::{
extensions::AnyProtocolExtension,
ws::{LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
ws::{DynWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role,
WispError,
};
@ -23,9 +23,9 @@ use super::{
WispV2Handshake,
};
async fn handshake<R: WebSocketRead>(
async fn handshake<R: WebSocketRead + 'static, W: WebSocketWrite>(
rx: &mut R,
tx: &LockedWebSocketWrite,
tx: &LockedWebSocketWrite<W>,
buffer_size: u32,
v2_info: Option<WispV2Handshake>,
) -> Result<WispHandshakeResult, WispError> {
@ -47,7 +47,9 @@ async fn handshake<R: WebSocketRead>(
let mut supported_extensions = get_supported_extensions(info.extensions, &mut builders);
for extension in supported_extensions.iter_mut() {
extension.handle_handshake(rx, tx).await?;
extension
.handle_handshake(DynWebSocketRead::from_mut(rx), tx)
.await?;
}
// v2 client
@ -79,36 +81,38 @@ async fn handshake<R: WebSocketRead>(
}
/// Server-side multiplexor.
pub struct ServerMux {
pub struct ServerMux<W: WebSocketWrite + 'static> {
/// Whether the connection was downgraded to Wisp v1.
///
/// If this variable is true you must assume no extensions are supported.
pub downgraded: bool,
/// Extensions that are supported by both sides.
pub supported_extensions: Vec<AnyProtocolExtension>,
actor_tx: mpsc::Sender<WsEvent>,
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
tx: LockedWebSocketWrite,
actor_tx: mpsc::Sender<WsEvent<W>>,
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream<W>)>,
tx: LockedWebSocketWrite<W>,
actor_exited: Arc<AtomicBool>,
}
impl ServerMux {
impl<W: WebSocketWrite + 'static> ServerMux<W> {
/// Create a new server-side multiplexor.
///
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created.
pub async fn create<R, W>(
pub async fn create<R>(
mut rx: R,
tx: W,
buffer_size: u32,
wisp_v2: Option<WispV2Handshake>,
) -> Result<MuxResult<ServerMux, impl Future<Output = Result<(), WispError>> + Send>, WispError>
) -> Result<
MuxResult<ServerMux<W>, impl Future<Output = Result<(), WispError>> + Send>,
WispError,
>
where
R: WebSocketRead + Send,
W: WebSocketWrite + Send + 'static,
R: WebSocketRead + Send + 'static,
{
let tx = LockedWebSocketWrite::new(Box::new(tx));
let tx = LockedWebSocketWrite::new(tx);
let ret_tx = tx.clone();
let ret = async {
let handshake_result = handshake(&mut rx, &tx, buffer_size, wisp_v2).await?;
@ -165,7 +169,7 @@ impl ServerMux {
}
/// Wait for a stream to be created.
pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream<W>)> {
if self.actor_exited.load(Ordering::Acquire) {
return None;
}
@ -210,7 +214,7 @@ impl ServerMux {
}
/// Get a protocol extension stream for sending packets with stream id 0.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream<W> {
MuxProtocolExtensionStream {
stream_id: 0,
tx: self.tx.clone(),
@ -219,13 +223,13 @@ impl ServerMux {
}
}
impl Drop for ServerMux {
impl<W: WebSocketWrite + 'static> Drop for ServerMux<W> {
fn drop(&mut self) {
let _ = self.actor_tx.send(WsEvent::EndFut(None));
}
}
impl Multiplexor for ServerMux {
impl<W: WebSocketWrite + 'static> Multiplexor for ServerMux<W> {
fn has_extension(&self, extension_id: u8) -> bool {
self.supported_extensions
.iter()

View file

@ -2,7 +2,10 @@ use std::fmt::Display;
use crate::{
extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder},
ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
ws::{
self, DynWebSocketRead, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead,
WebSocketWrite,
},
Role, WispError, WISP_VERSION,
};
use bytes::{Buf, BufMut, Bytes, BytesMut};
@ -527,11 +530,11 @@ impl<'a> Packet<'a> {
}
}
pub(crate) async fn maybe_handle_extension(
pub(crate) async fn maybe_handle_extension<R: WebSocketRead + 'static, W: WebSocketWrite>(
frame: Frame<'a>,
extensions: &mut [AnyProtocolExtension],
read: &mut (dyn WebSocketRead + Send),
write: &LockedWebSocketWrite,
read: &mut R,
write: &LockedWebSocketWrite<W>,
) -> Result<Option<Self>, WispError> {
if !frame.finished {
return Err(WispError::WsFrameNotFinished);
@ -568,7 +571,11 @@ impl<'a> Packet<'a> {
.find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
{
extension
.handle_packet(BytesMut::from(bytes).freeze(), read, write)
.handle_packet(
BytesMut::from(bytes).freeze(),
DynWebSocketRead::from_mut(read),
write,
)
.await?;
Ok(None)
} else {

View file

@ -98,7 +98,10 @@ impl MuxStreamIoStream {
impl Stream for MuxStreamIoStream {
type Item = Result<Bytes, std::io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().rx.poll_next(cx).map_err(std::io::Error::other)
self.project()
.rx
.poll_next(cx)
.map_err(std::io::Error::other)
}
}

View file

@ -4,7 +4,7 @@ pub use compat::*;
use crate::{
inner::WsEvent,
ws::{Frame, LockedWebSocketWrite, Payload},
ws::{Frame, LockedWebSocketWrite, Payload, WebSocketWrite},
AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError,
};
@ -21,7 +21,7 @@ use std::{
};
/// Read side of a multiplexor stream.
pub struct MuxStreamRead {
pub struct MuxStreamRead<W: WebSocketWrite + 'static> {
/// ID of the stream.
pub stream_id: u32,
/// Type of the stream.
@ -29,7 +29,7 @@ pub struct MuxStreamRead {
role: Role,
tx: LockedWebSocketWrite,
tx: LockedWebSocketWrite<W>,
rx: mpsc::Receiver<Bytes>,
is_closed: Arc<AtomicBool>,
@ -42,7 +42,7 @@ pub struct MuxStreamRead {
target_flow_control: u32,
}
impl MuxStreamRead {
impl<W: WebSocketWrite + 'static> MuxStreamRead<W> {
/// Read an event from the stream.
pub async fn read(&self) -> Result<Option<Bytes>, WispError> {
if self.rx.is_empty() && self.is_closed.load(Ordering::Acquire) {
@ -98,15 +98,15 @@ impl MuxStreamRead {
}
/// Write side of a multiplexor stream.
pub struct MuxStreamWrite {
pub struct MuxStreamWrite<W: WebSocketWrite + 'static> {
/// ID of the stream.
pub stream_id: u32,
/// Type of the stream.
pub stream_type: StreamType,
role: Role,
mux_tx: mpsc::Sender<WsEvent>,
tx: LockedWebSocketWrite,
mux_tx: mpsc::Sender<WsEvent<W>>,
tx: LockedWebSocketWrite<W>,
is_closed: Arc<AtomicBool>,
close_reason: Arc<AtomicCloseReason>,
@ -116,7 +116,7 @@ pub struct MuxStreamWrite {
flow_control: Arc<AtomicU32>,
}
impl MuxStreamWrite {
impl<W: WebSocketWrite + 'static> MuxStreamWrite<W> {
pub(crate) async fn write_payload_internal<'a>(
&self,
header: Frame<'static>,
@ -169,7 +169,7 @@ impl MuxStreamWrite {
/// handle.close(0x01);
/// }
/// ```
pub fn get_close_handle(&self) -> MuxStreamCloser {
pub fn get_close_handle(&self) -> MuxStreamCloser<W> {
MuxStreamCloser {
stream_id: self.stream_id,
close_channel: self.mux_tx.clone(),
@ -179,7 +179,7 @@ impl MuxStreamWrite {
}
/// Get a protocol extension stream to send protocol extension packets.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream<W> {
MuxProtocolExtensionStream {
stream_id: self.stream_id,
tx: self.tx.clone(),
@ -244,7 +244,7 @@ impl MuxStreamWrite {
}
}
impl Drop for MuxStreamWrite {
impl<W: WebSocketWrite + 'static> Drop for MuxStreamWrite<W> {
fn drop(&mut self) {
if !self.is_closed.load(Ordering::Acquire) {
self.is_closed.store(true, Ordering::Release);
@ -258,22 +258,22 @@ impl Drop for MuxStreamWrite {
}
/// Multiplexor stream.
pub struct MuxStream {
pub struct MuxStream<W: WebSocketWrite + 'static> {
/// ID of the stream.
pub stream_id: u32,
rx: MuxStreamRead,
tx: MuxStreamWrite,
rx: MuxStreamRead<W>,
tx: MuxStreamWrite<W>,
}
impl MuxStream {
impl<W: WebSocketWrite + 'static> MuxStream<W> {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
stream_id: u32,
role: Role,
stream_type: StreamType,
rx: mpsc::Receiver<Bytes>,
mux_tx: mpsc::Sender<WsEvent>,
tx: LockedWebSocketWrite,
mux_tx: mpsc::Sender<WsEvent<W>>,
tx: LockedWebSocketWrite<W>,
is_closed: Arc<AtomicBool>,
is_closed_event: Arc<Event>,
close_reason: Arc<AtomicCloseReason>,
@ -339,12 +339,12 @@ impl MuxStream {
/// handle.close(0x01);
/// }
/// ```
pub fn get_close_handle(&self) -> MuxStreamCloser {
pub fn get_close_handle(&self) -> MuxStreamCloser<W> {
self.tx.get_close_handle()
}
/// Get a protocol extension stream to send protocol extension packets.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream<W> {
self.tx.get_protocol_extension_stream()
}
@ -359,7 +359,7 @@ impl MuxStream {
}
/// Split the stream into read and write parts, consuming it.
pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) {
pub fn into_split(self) -> (MuxStreamRead<W>, MuxStreamWrite<W>) {
(self.rx, self.tx)
}
@ -374,15 +374,15 @@ impl MuxStream {
/// Close handle for a multiplexor stream.
#[derive(Clone)]
pub struct MuxStreamCloser {
pub struct MuxStreamCloser<W: WebSocketWrite + 'static> {
/// ID of the stream.
pub stream_id: u32,
close_channel: mpsc::Sender<WsEvent>,
close_channel: mpsc::Sender<WsEvent<W>>,
is_closed: Arc<AtomicBool>,
close_reason: Arc<AtomicCloseReason>,
}
impl MuxStreamCloser {
impl<W: WebSocketWrite + 'static> MuxStreamCloser<W> {
/// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
@ -414,14 +414,14 @@ impl MuxStreamCloser {
}
/// Stream for sending arbitrary protocol extension packets.
pub struct MuxProtocolExtensionStream {
pub struct MuxProtocolExtensionStream<W: WebSocketWrite + 'static> {
/// ID of the stream.
pub stream_id: u32,
pub(crate) tx: LockedWebSocketWrite,
pub(crate) tx: LockedWebSocketWrite<W>,
pub(crate) is_closed: Arc<AtomicBool>,
}
impl MuxProtocolExtensionStream {
impl<W: WebSocketWrite + 'static> MuxProtocolExtensionStream<W> {
/// Send a protocol extension packet with this stream's ID.
pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {

View file

@ -4,12 +4,11 @@
//! for other WebSocket implementations.
//!
//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs
use std::{ops::Deref, sync::Arc};
use std::{future::Future, ops::Deref, pin::Pin, sync::Arc};
use crate::WispError;
use async_trait::async_trait;
use bytes::{Buf, BytesMut};
use futures::lock::Mutex;
use futures::{lock::Mutex, TryFutureExt};
/// Payload of the websocket frame.
#[derive(Debug)]
@ -158,55 +157,130 @@ impl<'a> Frame<'a> {
}
/// Generic WebSocket read trait.
#[async_trait]
pub trait WebSocketRead {
pub trait WebSocketRead: Send {
/// Read a frame from the socket.
async fn wisp_read_frame(
fn wisp_read_frame(
&mut self,
tx: &LockedWebSocketWrite,
) -> Result<Frame<'static>, WispError>;
tx: &dyn LockingWebSocketWrite,
) -> impl Future<Output = Result<Frame<'static>, WispError>> + Send;
/// Read a split frame from the socket.
async fn wisp_read_split(
fn wisp_read_split(
&mut self,
tx: &LockedWebSocketWrite,
) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
self.wisp_read_frame(tx).await.map(|x| (x, None))
tx: &dyn LockingWebSocketWrite,
) -> impl Future<Output = Result<(Frame<'static>, Option<Frame<'static>>), WispError>> + Send {
self.wisp_read_frame(tx).map_ok(|x| (x, None))
}
}
#[async_trait]
impl WebSocketRead for Box<dyn WebSocketRead + Send> {
// similar to what dynosaur does
mod wsr_inner {
use std::{future::Future, pin::Pin};
use crate::WispError;
use super::{Frame, LockingWebSocketWrite, WebSocketRead};
trait ErasedWebSocketRead: Send {
fn wisp_read_frame<'a>(
&'a mut self,
tx: &'a dyn LockingWebSocketWrite,
) -> Pin<Box<dyn Future<Output = Result<Frame<'static>, WispError>> + Send + 'a>>;
#[allow(clippy::type_complexity)]
fn wisp_read_split<'a>(
&'a mut self,
tx: &'a dyn LockingWebSocketWrite,
) -> Pin<
Box<
dyn Future<Output = Result<(Frame<'static>, Option<Frame<'static>>), WispError>>
+ Send
+ 'a,
>,
>;
}
impl<T: WebSocketRead> ErasedWebSocketRead for T {
fn wisp_read_frame<'a>(
&'a mut self,
tx: &'a dyn LockingWebSocketWrite,
) -> Pin<Box<dyn Future<Output = Result<Frame<'static>, WispError>> + Send + 'a>> {
Box::pin(self.wisp_read_frame(tx))
}
fn wisp_read_split<'a>(
&'a mut self,
tx: &'a dyn LockingWebSocketWrite,
) -> Pin<
Box<
dyn Future<Output = Result<(Frame<'static>, Option<Frame<'static>>), WispError>>
+ Send
+ 'a,
>,
> {
Box::pin(self.wisp_read_split(tx))
}
}
/// WebSocketRead trait object.
#[repr(transparent)]
pub struct DynWebSocketRead {
ptr: dyn ErasedWebSocketRead + 'static,
}
impl WebSocketRead for DynWebSocketRead {
async fn wisp_read_frame(
&mut self,
tx: &LockedWebSocketWrite,
tx: &dyn LockingWebSocketWrite,
) -> Result<Frame<'static>, WispError> {
self.as_mut().wisp_read_frame(tx).await
self.ptr.wisp_read_frame(tx).await
}
async fn wisp_read_split(
&mut self,
tx: &LockedWebSocketWrite,
tx: &dyn LockingWebSocketWrite,
) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
self.as_mut().wisp_read_split(tx).await
self.ptr.wisp_read_split(tx).await
}
}
impl DynWebSocketRead {
/// Create a WebSocketRead trait object from a boxed WebSocketRead.
pub fn new(val: Box<impl WebSocketRead + 'static>) -> Box<Self> {
let val: Box<dyn ErasedWebSocketRead + 'static> = val;
unsafe { std::mem::transmute(val) }
}
/// Create a WebSocketRead trait object from a WebSocketRead.
pub fn boxed(val: impl WebSocketRead + 'static) -> Box<Self> {
Self::new(Box::new(val))
}
/// Create a WebSocketRead trait object from a WebSocketRead reference.
pub fn from_ref(val: &(impl WebSocketRead + 'static)) -> &Self {
let val: &(dyn ErasedWebSocketRead + 'static) = val;
unsafe { std::mem::transmute(val) }
}
/// Create a WebSocketRead trait object from a mutable WebSocketRead reference.
pub fn from_mut(val: &mut (impl WebSocketRead + 'static)) -> &mut Self {
let val: &mut (dyn ErasedWebSocketRead + 'static) = &mut *val;
unsafe { std::mem::transmute(val) }
}
}
}
pub use wsr_inner::DynWebSocketRead;
/// Generic WebSocket write trait.
#[async_trait]
pub trait WebSocketWrite {
pub trait WebSocketWrite: Send {
/// Write a frame to the socket.
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError>;
/// Close the socket.
async fn wisp_close(&mut self) -> Result<(), WispError>;
fn wisp_write_frame(
&mut self,
frame: Frame<'_>,
) -> impl Future<Output = Result<(), WispError>> + Send;
/// Write a split frame to the socket.
async fn wisp_write_split(
fn wisp_write_split(
&mut self,
header: Frame<'_>,
body: Frame<'_>,
) -> Result<(), WispError> {
) -> impl Future<Output = Result<(), WispError>> + Send {
async move {
let mut payload = BytesMut::from(header.payload);
payload.extend_from_slice(&body.payload);
self.wisp_write_frame(Frame::binary(Payload::Bytes(payload)))
@ -214,14 +288,66 @@ pub trait WebSocketWrite {
}
}
#[async_trait]
impl WebSocketWrite for Box<dyn WebSocketWrite + Send> {
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
self.as_mut().wisp_write_frame(frame).await
/// Close the socket.
fn wisp_close(&mut self) -> impl Future<Output = Result<(), WispError>> + Send;
}
async fn wisp_close(&mut self) -> Result<(), WispError> {
self.as_mut().wisp_close().await
// similar to what dynosaur does
mod wsw_inner {
use std::{future::Future, pin::Pin};
use crate::WispError;
use super::{Frame, WebSocketWrite};
trait ErasedWebSocketWrite: Send {
fn wisp_write_frame<'a>(
&'a mut self,
frame: Frame<'a>,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>>;
fn wisp_write_split<'a>(
&'a mut self,
header: Frame<'a>,
body: Frame<'a>,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>>;
fn wisp_close<'a>(
&'a mut self,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>>;
}
impl<T: WebSocketWrite> ErasedWebSocketWrite for T {
fn wisp_write_frame<'a>(
&'a mut self,
frame: Frame<'a>,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>> {
Box::pin(self.wisp_write_frame(frame))
}
fn wisp_write_split<'a>(
&'a mut self,
header: Frame<'a>,
body: Frame<'a>,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>> {
Box::pin(self.wisp_write_split(header, body))
}
fn wisp_close<'a>(
&'a mut self,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>> {
Box::pin(self.wisp_close())
}
}
/// WebSocketWrite trait object.
#[repr(transparent)]
pub struct DynWebSocketWrite {
ptr: dyn ErasedWebSocketWrite + 'static,
}
impl WebSocketWrite for DynWebSocketWrite {
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
self.ptr.wisp_write_frame(frame).await
}
async fn wisp_write_split(
@ -229,30 +355,88 @@ impl WebSocketWrite for Box<dyn WebSocketWrite + Send> {
header: Frame<'_>,
body: Frame<'_>,
) -> Result<(), WispError> {
self.as_mut().wisp_write_split(header, body).await
self.ptr.wisp_write_split(header, body).await
}
async fn wisp_close(&mut self) -> Result<(), WispError> {
self.ptr.wisp_close().await
}
}
impl DynWebSocketWrite {
/// Create a new WebSocketWrite trait object from a boxed WebSocketWrite.
pub fn new(val: Box<impl WebSocketWrite + 'static>) -> Box<Self> {
let val: Box<dyn ErasedWebSocketWrite + 'static> = val;
unsafe { std::mem::transmute(val) }
}
/// Create a new WebSocketWrite trait object from a WebSocketWrite.
pub fn boxed(val: impl WebSocketWrite + 'static) -> Box<Self> {
Self::new(Box::new(val))
}
/// Create a new WebSocketWrite trait object from a WebSocketWrite reference.
pub fn from_ref(val: &(impl WebSocketWrite + 'static)) -> &Self {
let val: &(dyn ErasedWebSocketWrite + 'static) = val;
unsafe { std::mem::transmute(val) }
}
/// Create a new WebSocketWrite trait object from a mutable WebSocketWrite reference.
pub fn from_mut(val: &mut (impl WebSocketWrite + 'static)) -> &mut Self {
let val: &mut (dyn ErasedWebSocketWrite + 'static) = &mut *val;
unsafe { std::mem::transmute(val) }
}
}
}
pub use wsw_inner::DynWebSocketWrite;
mod private {
pub trait Sealed {}
}
/// Helper trait object for LockedWebSocketWrite.
pub trait LockingWebSocketWrite: private::Sealed + Sync {
/// Write a frame to the websocket.
fn wisp_write_frame<'a>(
&'a self,
frame: Frame<'a>,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>>;
/// Write a split frame to the websocket.
fn wisp_write_split<'a>(
&'a self,
header: Frame<'a>,
body: Frame<'a>,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>>;
/// Close the websocket.
fn wisp_close<'a>(&'a self)
-> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>>;
}
/// Locked WebSocket.
#[derive(Clone)]
pub struct LockedWebSocketWrite(Arc<Mutex<Box<dyn WebSocketWrite + Send>>>);
pub struct LockedWebSocketWrite<T: WebSocketWrite>(Arc<Mutex<T>>);
impl LockedWebSocketWrite {
impl<T: WebSocketWrite> Clone for LockedWebSocketWrite<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<T: WebSocketWrite> LockedWebSocketWrite<T> {
/// Create a new locked websocket.
pub fn new(ws: Box<dyn WebSocketWrite + Send>) -> Self {
pub fn new(ws: T) -> Self {
Self(Mutex::new(ws).into())
}
/// Create a new locked websocket from an existing mutex.
pub fn from_locked(locked: Arc<Mutex<T>>) -> Self {
Self(locked)
}
/// Write a frame to the websocket.
pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WispError> {
self.0.lock().await.wisp_write_frame(frame).await
}
pub(crate) async fn write_split(
&self,
header: Frame<'_>,
body: Frame<'_>,
) -> Result<(), WispError> {
/// Write a split frame to the websocket.
pub async fn write_split(&self, header: Frame<'_>, body: Frame<'_>) -> Result<(), WispError> {
self.0.lock().await.wisp_write_split(header, body).await
}
@ -261,3 +445,91 @@ impl LockedWebSocketWrite {
self.0.lock().await.wisp_close().await
}
}
impl<T: WebSocketWrite> private::Sealed for LockedWebSocketWrite<T> {}
impl<T: WebSocketWrite> LockingWebSocketWrite for LockedWebSocketWrite<T> {
fn wisp_write_frame<'a>(
&'a self,
frame: Frame<'a>,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>> {
Box::pin(self.write_frame(frame))
}
fn wisp_write_split<'a>(
&'a self,
header: Frame<'a>,
body: Frame<'a>,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>> {
Box::pin(self.write_split(header, body))
}
fn wisp_close<'a>(
&'a self,
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>> {
Box::pin(self.close())
}
}
/// Combines two different WebSocketReads together.
pub enum EitherWebSocketRead<A: WebSocketRead, B: WebSocketRead> {
/// First WebSocketRead variant.
Left(A),
/// Second WebSocketRead variant.
Right(B),
}
impl<A: WebSocketRead, B: WebSocketRead> WebSocketRead for EitherWebSocketRead<A, B> {
async fn wisp_read_frame(
&mut self,
tx: &dyn LockingWebSocketWrite,
) -> Result<Frame<'static>, WispError> {
match self {
Self::Left(x) => x.wisp_read_frame(tx).await,
Self::Right(x) => x.wisp_read_frame(tx).await,
}
}
async fn wisp_read_split(
&mut self,
tx: &dyn LockingWebSocketWrite,
) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
match self {
Self::Left(x) => x.wisp_read_split(tx).await,
Self::Right(x) => x.wisp_read_split(tx).await,
}
}
}
/// Combines two different WebSocketWrites together.
pub enum EitherWebSocketWrite<A: WebSocketWrite, B: WebSocketWrite> {
/// First WebSocketWrite variant.
Left(A),
/// Second WebSocketWrite variant.
Right(B),
}
impl<A: WebSocketWrite, B: WebSocketWrite> WebSocketWrite for EitherWebSocketWrite<A, B> {
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
match self {
Self::Left(x) => x.wisp_write_frame(frame).await,
Self::Right(x) => x.wisp_write_frame(frame).await,
}
}
async fn wisp_write_split(
&mut self,
header: Frame<'_>,
body: Frame<'_>,
) -> Result<(), WispError> {
match self {
Self::Left(x) => x.wisp_write_split(header, body).await,
Self::Right(x) => x.wisp_write_split(header, body).await,
}
}
async fn wisp_close(&mut self) -> Result<(), WispError> {
match self {
Self::Left(x) => x.wisp_close().await,
Self::Right(x) => x.wisp_close().await,
}
}
}