mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 14:00:01 -04:00
use knockoff dynosaur to remove async_trait on wsr/wsw
This commit is contained in:
parent
5e54465e58
commit
9129d767f8
31 changed files with 692 additions and 258 deletions
9
Cargo.lock
generated
9
Cargo.lock
generated
|
@ -698,7 +698,6 @@ name = "epoxy-client"
|
|||
version = "2.1.15"
|
||||
dependencies = [
|
||||
"async-compression",
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"cfg-if",
|
||||
"event-listener",
|
||||
|
@ -755,6 +754,7 @@ dependencies = [
|
|||
"libc",
|
||||
"log",
|
||||
"nix",
|
||||
"pin-project-lite",
|
||||
"pty-process",
|
||||
"regex",
|
||||
"rustls-pemfile",
|
||||
|
@ -1850,6 +1850,12 @@ dependencies = [
|
|||
"quick-error",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reusable-box-future"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e0e61cd21fbddd85fbd9367b775660a01d388c08a61c6d2824af480b0309bb9"
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
version = "0.17.8"
|
||||
|
@ -3015,6 +3021,7 @@ dependencies = [
|
|||
"getrandom",
|
||||
"nohash-hasher",
|
||||
"pin-project-lite",
|
||||
"reusable-box-future",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
]
|
||||
|
|
|
@ -8,7 +8,6 @@ crate-type = ["cdylib"]
|
|||
|
||||
[dependencies]
|
||||
async-compression = { version = "0.4.12", features = ["futures-io", "gzip", "brotli"], optional = true }
|
||||
async-trait = "0.1.81"
|
||||
bytes = "1.7.1"
|
||||
cfg-if = "1.0.0"
|
||||
event-listener = "5.3.1"
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#![feature(let_chains, impl_trait_in_assoc_type)]
|
||||
use std::{error::Error, str::FromStr, sync::Arc};
|
||||
use std::{error::Error, pin::Pin, str::FromStr, sync::Arc};
|
||||
|
||||
#[cfg(feature = "full")]
|
||||
use async_compression::futures::bufread as async_comp;
|
||||
|
@ -7,7 +7,7 @@ use bytes::{Bytes, BytesMut};
|
|||
use cfg_if::cfg_if;
|
||||
#[cfg(feature = "full")]
|
||||
use futures_util::future::Either;
|
||||
use futures_util::{StreamExt, TryStreamExt};
|
||||
use futures_util::{Stream, StreamExt, TryStreamExt};
|
||||
use http::{
|
||||
header::{
|
||||
InvalidHeaderName, InvalidHeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH,
|
||||
|
@ -41,7 +41,7 @@ use websocket::EpoxyWebSocket;
|
|||
use wisp_mux::StreamType;
|
||||
use wisp_mux::{
|
||||
generic::GenericWebSocketRead,
|
||||
ws::{WebSocketRead, WebSocketWrite},
|
||||
ws::{EitherWebSocketRead, EitherWebSocketWrite},
|
||||
CloseReason,
|
||||
};
|
||||
use ws_wrapper::WebSocketWrapper;
|
||||
|
@ -343,7 +343,7 @@ fn create_wisp_transport(function: Function) -> ProviderWispTransportGenerator {
|
|||
}
|
||||
.into();
|
||||
|
||||
let read = GenericWebSocketRead::new(SendWrapper::new(
|
||||
let read = GenericWebSocketRead::new(Box::pin(SendWrapper::new(
|
||||
wasm_streams::ReadableStream::from_raw(object_get(&transport, "read").into())
|
||||
.into_stream()
|
||||
.map(|x| {
|
||||
|
@ -355,15 +355,16 @@ fn create_wisp_transport(function: Function) -> ProviderWispTransportGenerator {
|
|||
Uint8Array::new(&arr).to_vec().as_slice(),
|
||||
))
|
||||
}),
|
||||
));
|
||||
))
|
||||
as Pin<Box<dyn Stream<Item = Result<BytesMut, EpoxyError>> + Send>>);
|
||||
let write: WritableStream = object_get(&transport, "write").into();
|
||||
let write = WispTransportWrite {
|
||||
inner: SendWrapper::new(write.get_writer().map_err(EpoxyError::wisp_transport)?),
|
||||
};
|
||||
|
||||
Ok((
|
||||
Box::new(read) as Box<dyn WebSocketRead + Send>,
|
||||
Box::new(write) as Box<dyn WebSocketWrite + Send>,
|
||||
EitherWebSocketRead::Right(read),
|
||||
EitherWebSocketWrite::Right(write),
|
||||
))
|
||||
}))
|
||||
})
|
||||
|
@ -421,8 +422,8 @@ impl EpoxyClient {
|
|||
}
|
||||
}
|
||||
Ok((
|
||||
Box::new(read) as Box<dyn WebSocketRead + Send>,
|
||||
Box::new(write) as Box<dyn WebSocketWrite + Send>,
|
||||
EitherWebSocketRead::Left(read),
|
||||
EitherWebSocketWrite::Left(write),
|
||||
))
|
||||
})
|
||||
}),
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use std::{io::ErrorKind, pin::Pin, sync::Arc, task::Poll};
|
||||
|
||||
use bytes::BytesMut;
|
||||
use cfg_if::cfg_if;
|
||||
use futures_rustls::{
|
||||
rustls::{ClientConfig, RootCertStore},
|
||||
|
@ -8,7 +9,7 @@ use futures_rustls::{
|
|||
use futures_util::{
|
||||
future::Either,
|
||||
lock::{Mutex, MutexGuard},
|
||||
AsyncRead, AsyncWrite, Future,
|
||||
AsyncRead, AsyncWrite, Future, Stream,
|
||||
};
|
||||
use hyper_util_wasm::client::legacy::connect::{ConnectSvc, Connected, Connection};
|
||||
use pin_project_lite::pin_project;
|
||||
|
@ -16,18 +17,30 @@ use wasm_bindgen_futures::spawn_local;
|
|||
use webpki_roots::TLS_SERVER_ROOTS;
|
||||
use wisp_mux::{
|
||||
extensions::{udp::UdpProtocolExtensionBuilder, AnyProtocolExtensionBuilder},
|
||||
ws::{WebSocketRead, WebSocketWrite},
|
||||
generic::GenericWebSocketRead,
|
||||
ws::{EitherWebSocketRead, EitherWebSocketWrite},
|
||||
ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, WispV2Handshake,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
console_error, console_log, utils::{IgnoreCloseNotify, NoCertificateVerification}, EpoxyClientOptions, EpoxyError
|
||||
console_error, console_log,
|
||||
utils::{IgnoreCloseNotify, NoCertificateVerification, WispTransportWrite},
|
||||
ws_wrapper::{WebSocketReader, WebSocketWrapper},
|
||||
EpoxyClientOptions, EpoxyError,
|
||||
};
|
||||
|
||||
pub type ProviderUnencryptedStream = MuxStreamIo;
|
||||
pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW;
|
||||
pub type ProviderTlsAsyncRW = IgnoreCloseNotify;
|
||||
pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>;
|
||||
pub type ProviderWispTransportRead = EitherWebSocketRead<
|
||||
WebSocketReader,
|
||||
GenericWebSocketRead<
|
||||
Pin<Box<dyn Stream<Item = Result<BytesMut, EpoxyError>> + Send>>,
|
||||
EpoxyError,
|
||||
>,
|
||||
>;
|
||||
pub type ProviderWispTransportWrite = EitherWebSocketWrite<WebSocketWrapper, WispTransportWrite>;
|
||||
pub type ProviderWispTransportGenerator = Box<
|
||||
dyn Fn(
|
||||
bool,
|
||||
|
@ -35,10 +48,7 @@ pub type ProviderWispTransportGenerator = Box<
|
|||
Box<
|
||||
dyn Future<
|
||||
Output = Result<
|
||||
(
|
||||
Box<dyn WebSocketRead + Send>,
|
||||
Box<dyn WebSocketWrite + Send>,
|
||||
),
|
||||
(ProviderWispTransportRead, ProviderWispTransportWrite),
|
||||
EpoxyError,
|
||||
>,
|
||||
> + Sync
|
||||
|
@ -54,7 +64,7 @@ pub struct StreamProvider {
|
|||
wisp_v2: bool,
|
||||
udp_extension: bool,
|
||||
|
||||
current_client: Arc<Mutex<Option<ClientMux>>>,
|
||||
current_client: Arc<Mutex<Option<ClientMux<ProviderWispTransportWrite>>>>,
|
||||
|
||||
h2_config: Arc<ClientConfig>,
|
||||
client_config: Arc<ClientConfig>,
|
||||
|
@ -115,7 +125,7 @@ impl StreamProvider {
|
|||
|
||||
async fn create_client(
|
||||
&self,
|
||||
mut locked: MutexGuard<'_, Option<ClientMux>>,
|
||||
mut locked: MutexGuard<'_, Option<ClientMux<ProviderWispTransportWrite>>>,
|
||||
) -> Result<(), EpoxyError> {
|
||||
let extensions_vec: Vec<AnyProtocolExtensionBuilder> =
|
||||
vec![AnyProtocolExtensionBuilder::new(
|
||||
|
@ -140,7 +150,11 @@ impl StreamProvider {
|
|||
spawn_local(async move {
|
||||
match fut.await {
|
||||
Ok(_) => console_log!("epoxy: wisp multiplexor task ended successfully"),
|
||||
Err(x) => console_error!("epoxy: wisp multiplexor task ended with an error: {} {:?}", x, x),
|
||||
Err(x) => console_error!(
|
||||
"epoxy: wisp multiplexor task ended with an error: {} {:?}",
|
||||
x,
|
||||
x
|
||||
),
|
||||
}
|
||||
current_client.lock().await.take();
|
||||
});
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
use std::{pin::Pin, task::{Context, Poll}};
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures_util::{AsyncRead, Stream, StreamExt, TryStreamExt};
|
||||
|
@ -12,7 +15,6 @@ use crate::{console_error, EpoxyError};
|
|||
|
||||
use super::ReaderStream;
|
||||
|
||||
|
||||
#[wasm_bindgen(inline_js = r#"
|
||||
export function ws_protocol() {
|
||||
return (
|
||||
|
|
|
@ -8,7 +8,6 @@ use std::{
|
|||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut};
|
||||
use futures_util::{ready, AsyncRead, Future, Stream};
|
||||
use http::{HeaderValue, Uri};
|
||||
|
@ -179,7 +178,6 @@ pub struct WispTransportWrite {
|
|||
pub inner: SendWrapper<WritableStreamDefaultWriter>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WebSocketWrite for WispTransportWrite {
|
||||
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
|
||||
SendWrapper::new(async {
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
use std::{
|
||||
io::ErrorKind, pin::Pin, sync::Arc, task::{Context, Poll}
|
||||
io::ErrorKind,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures_rustls::{
|
||||
|
|
|
@ -3,7 +3,6 @@ use std::sync::{
|
|||
Arc,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::BytesMut;
|
||||
use event_listener::Event;
|
||||
use flume::Receiver;
|
||||
|
@ -14,7 +13,7 @@ use thiserror::Error;
|
|||
use wasm_bindgen::{closure::Closure, JsCast, JsValue};
|
||||
use web_sys::{BinaryType, MessageEvent, WebSocket};
|
||||
use wisp_mux::{
|
||||
ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||
ws::{Frame, LockingWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||
WispError,
|
||||
};
|
||||
|
||||
|
@ -66,11 +65,10 @@ pub struct WebSocketReader {
|
|||
close_event: Arc<Event>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WebSocketRead for WebSocketReader {
|
||||
async fn wisp_read_frame(
|
||||
&mut self,
|
||||
_: &LockedWebSocketWrite,
|
||||
_: &dyn LockingWebSocketWrite,
|
||||
) -> Result<Frame<'static>, WispError> {
|
||||
use WebSocketMessage as M;
|
||||
if self.closed.load(Ordering::Acquire) {
|
||||
|
@ -185,7 +183,6 @@ impl WebSocketWrapper {
|
|||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WebSocketWrite for WebSocketWrapper {
|
||||
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
|
||||
use wisp_mux::ws::OpCode::*;
|
||||
|
|
|
@ -24,6 +24,7 @@ lazy_static = "1.5.0"
|
|||
libc = { version = "0.2.158", optional = true }
|
||||
log = { version = "0.4.22", features = ["serde", "std"] }
|
||||
nix = { version = "0.29.0", features = ["term"] }
|
||||
pin-project-lite = "0.2.15"
|
||||
pty-process = { version = "0.4.0", features = ["async", "tokio"], optional = true }
|
||||
regex = "1.10.6"
|
||||
rustls-pemfile = "2.1.3"
|
||||
|
|
|
@ -25,7 +25,7 @@ use wisp_mux::{
|
|||
};
|
||||
|
||||
use crate::{
|
||||
route::WispResult,
|
||||
route::{WispResult, WispStreamWrite},
|
||||
stream::{ClientStream, ResolvedPacket},
|
||||
CLIENTS, CONFIG,
|
||||
};
|
||||
|
@ -58,7 +58,7 @@ async fn copy_read_fast(
|
|||
}
|
||||
|
||||
async fn copy_write_fast(
|
||||
muxtx: MuxStreamWrite,
|
||||
muxtx: MuxStreamWrite<WispStreamWrite>,
|
||||
tcprx: OwnedReadHalf,
|
||||
#[cfg(feature = "speed-limit")] limiter: async_speed_limit::Limiter<
|
||||
async_speed_limit::clock::StandardClock,
|
||||
|
@ -83,7 +83,7 @@ async fn copy_write_fast(
|
|||
|
||||
async fn handle_stream(
|
||||
connect: ConnectPacket,
|
||||
muxstream: MuxStream,
|
||||
muxstream: MuxStream<WispStreamWrite>,
|
||||
id: String,
|
||||
event: Arc<Event>,
|
||||
#[cfg(feature = "twisp")] twisp_map: twisp::TwispMap,
|
||||
|
|
|
@ -37,6 +37,8 @@ mod route;
|
|||
mod stats;
|
||||
#[doc(hidden)]
|
||||
mod stream;
|
||||
#[doc(hidden)]
|
||||
mod util_chain;
|
||||
|
||||
#[doc(hidden)]
|
||||
type Client = (DashMap<Uuid, (ConnectPacket, ConnectPacket)>, bool);
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::{fmt::Display, future::Future, io::Cursor};
|
|||
|
||||
use anyhow::Context;
|
||||
use bytes::Bytes;
|
||||
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector};
|
||||
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector, WebSocketRead, WebSocketWrite};
|
||||
use http_body_util::Full;
|
||||
use hyper::{
|
||||
body::Incoming, header::SEC_WEBSOCKET_PROTOCOL, server::conn::http1::Builder,
|
||||
|
@ -10,25 +10,30 @@ use hyper::{
|
|||
};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use log::{debug, error, trace};
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
|
||||
use wisp_mux::{
|
||||
generic::{GenericWebSocketRead, GenericWebSocketWrite},
|
||||
ws::{WebSocketRead, WebSocketWrite},
|
||||
ws::{EitherWebSocketRead, EitherWebSocketWrite},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
config::SocketTransport,
|
||||
generate_stats,
|
||||
listener::{ServerStream, ServerStreamExt},
|
||||
listener::{ServerStream, ServerStreamExt, ServerStreamRead, ServerStreamWrite},
|
||||
stream::WebSocketStreamWrapper,
|
||||
util_chain::{chain, Chain},
|
||||
CONFIG,
|
||||
};
|
||||
|
||||
pub type WispResult = (
|
||||
Box<dyn WebSocketRead + Send>,
|
||||
Box<dyn WebSocketWrite + Send>,
|
||||
);
|
||||
pub type WispStreamRead = EitherWebSocketRead<
|
||||
WebSocketRead<Chain<Cursor<Bytes>, ServerStreamRead>>,
|
||||
GenericWebSocketRead<FramedRead<ServerStreamRead, LengthDelimitedCodec>, std::io::Error>,
|
||||
>;
|
||||
pub type WispStreamWrite = EitherWebSocketWrite<
|
||||
WebSocketWrite<ServerStreamWrite>,
|
||||
GenericWebSocketWrite<FramedWrite<ServerStreamWrite, LengthDelimitedCodec>, std::io::Error>,
|
||||
>;
|
||||
pub type WispResult = (WispStreamRead, WispStreamWrite);
|
||||
|
||||
pub enum ServerRouteResult {
|
||||
Wisp(WispResult, bool),
|
||||
|
@ -190,12 +195,15 @@ pub async fn route(
|
|||
.downcast::<TokioIo<ServerStream>>()
|
||||
.unwrap();
|
||||
let (r, w) = parts.io.into_inner().split();
|
||||
(Cursor::new(parts.read_buf).chain(r), w)
|
||||
(chain(Cursor::new(parts.read_buf), r), w)
|
||||
});
|
||||
|
||||
(callback)(
|
||||
ServerRouteResult::Wisp(
|
||||
(Box::new(read), Box::new(write)),
|
||||
(
|
||||
EitherWebSocketRead::Left(read),
|
||||
EitherWebSocketWrite::Left(write),
|
||||
),
|
||||
is_v2,
|
||||
),
|
||||
maybe_ip,
|
||||
|
@ -229,7 +237,13 @@ pub async fn route(
|
|||
let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec));
|
||||
|
||||
(callback)(
|
||||
ServerRouteResult::Wisp((Box::new(read), Box::new(write)), true),
|
||||
ServerRouteResult::Wisp(
|
||||
(
|
||||
EitherWebSocketRead::Right(read),
|
||||
EitherWebSocketWrite::Right(write),
|
||||
),
|
||||
true,
|
||||
),
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
|
|
@ -44,7 +44,6 @@ pub enum ClientStream {
|
|||
Invalid,
|
||||
}
|
||||
|
||||
|
||||
// taken from rust 1.82.0
|
||||
fn ipv4_is_global(addr: &Ipv4Addr) -> bool {
|
||||
!(addr.octets()[0] == 0 // "This network"
|
||||
|
|
100
server/src/util_chain.rs
Normal file
100
server/src/util_chain.rs
Normal file
|
@ -0,0 +1,100 @@
|
|||
// taken from tokio io util
|
||||
|
||||
use std::{
|
||||
fmt, io,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures_util::ready;
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
|
||||
|
||||
pin_project! {
|
||||
pub struct Chain<T, U> {
|
||||
#[pin]
|
||||
first: T,
|
||||
#[pin]
|
||||
second: U,
|
||||
done_first: bool,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn chain<T, U>(first: T, second: U) -> Chain<T, U>
|
||||
where
|
||||
T: AsyncRead,
|
||||
U: AsyncRead,
|
||||
{
|
||||
Chain {
|
||||
first,
|
||||
second,
|
||||
done_first: false,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> fmt::Debug for Chain<T, U>
|
||||
where
|
||||
T: fmt::Debug,
|
||||
U: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Chain")
|
||||
.field("t", &self.first)
|
||||
.field("u", &self.second)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> AsyncRead for Chain<T, U>
|
||||
where
|
||||
T: AsyncRead,
|
||||
U: AsyncRead,
|
||||
{
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let me = self.project();
|
||||
|
||||
if !*me.done_first {
|
||||
let rem = buf.remaining();
|
||||
ready!(me.first.poll_read(cx, buf))?;
|
||||
if buf.remaining() == rem {
|
||||
*me.done_first = true;
|
||||
} else {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
me.second.poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> AsyncBufRead for Chain<T, U>
|
||||
where
|
||||
T: AsyncBufRead,
|
||||
U: AsyncBufRead,
|
||||
{
|
||||
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
|
||||
let me = self.project();
|
||||
|
||||
if !*me.done_first {
|
||||
match ready!(me.first.poll_fill_buf(cx)?) {
|
||||
[] => {
|
||||
*me.done_first = true;
|
||||
}
|
||||
buf => return Poll::Ready(Ok(buf)),
|
||||
}
|
||||
}
|
||||
me.second.poll_fill_buf(cx)
|
||||
}
|
||||
|
||||
fn consume(self: Pin<&mut Self>, amt: usize) {
|
||||
let me = self.project();
|
||||
if !*me.done_first {
|
||||
me.first.consume(amt)
|
||||
} else {
|
||||
me.second.consume(amt)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -245,7 +245,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
|||
}));
|
||||
threads.push(tokio::spawn(async move {
|
||||
loop {
|
||||
cr.read().await;
|
||||
let _ = cr.read().await;
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ futures = "0.3.30"
|
|||
getrandom = { version = "0.2.15", features = ["std"], optional = true }
|
||||
nohash-hasher = "0.2.0"
|
||||
pin-project-lite = "0.2.14"
|
||||
reusable-box-future = "0.2.0"
|
||||
thiserror = "1.0.65"
|
||||
tokio = { version = "1.39.3", optional = true, default-features = false }
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ use ed25519::{
|
|||
};
|
||||
|
||||
use crate::{
|
||||
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||
ws::{DynWebSocketRead, LockingWebSocketWrite},
|
||||
Role, WispError,
|
||||
};
|
||||
|
||||
|
@ -183,8 +183,8 @@ impl ProtocolExtension for CertAuthProtocolExtension {
|
|||
|
||||
async fn handle_handshake(
|
||||
&mut self,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
_: &mut DynWebSocketRead,
|
||||
_: &dyn LockingWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
@ -192,8 +192,8 @@ impl ProtocolExtension for CertAuthProtocolExtension {
|
|||
async fn handle_packet(
|
||||
&mut self,
|
||||
_: Bytes,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
_: &mut DynWebSocketRead,
|
||||
_: &dyn LockingWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ use async_trait::async_trait;
|
|||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
|
||||
use crate::{
|
||||
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||
ws::{DynWebSocketRead, LockingWebSocketWrite},
|
||||
Role, WispError,
|
||||
};
|
||||
|
||||
|
@ -105,16 +105,16 @@ pub trait ProtocolExtension: std::fmt::Debug + Sync + Send + 'static {
|
|||
/// This should be used to send or receive data before any streams are created.
|
||||
async fn handle_handshake(
|
||||
&mut self,
|
||||
read: &mut dyn WebSocketRead,
|
||||
write: &LockedWebSocketWrite,
|
||||
read: &mut DynWebSocketRead,
|
||||
write: &dyn LockingWebSocketWrite,
|
||||
) -> Result<(), WispError>;
|
||||
|
||||
/// Handle receiving a packet.
|
||||
async fn handle_packet(
|
||||
&mut self,
|
||||
packet: Bytes,
|
||||
read: &mut dyn WebSocketRead,
|
||||
write: &LockedWebSocketWrite,
|
||||
read: &mut DynWebSocketRead,
|
||||
write: &dyn LockingWebSocketWrite,
|
||||
) -> Result<(), WispError>;
|
||||
|
||||
/// Clone the protocol extension.
|
||||
|
|
|
@ -6,7 +6,7 @@ use async_trait::async_trait;
|
|||
use bytes::Bytes;
|
||||
|
||||
use crate::{
|
||||
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||
ws::{DynWebSocketRead, LockingWebSocketWrite},
|
||||
Role, WispError,
|
||||
};
|
||||
|
||||
|
@ -48,8 +48,8 @@ impl ProtocolExtension for MotdProtocolExtension {
|
|||
|
||||
async fn handle_handshake(
|
||||
&mut self,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
_: &mut DynWebSocketRead,
|
||||
_: &dyn LockingWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
@ -57,8 +57,8 @@ impl ProtocolExtension for MotdProtocolExtension {
|
|||
async fn handle_packet(
|
||||
&mut self,
|
||||
_: Bytes,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
_: &mut DynWebSocketRead,
|
||||
_: &dyn LockingWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ use async_trait::async_trait;
|
|||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
|
||||
use crate::{
|
||||
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||
ws::{DynWebSocketRead, LockingWebSocketWrite},
|
||||
Role, WispError,
|
||||
};
|
||||
|
||||
|
@ -94,17 +94,17 @@ impl ProtocolExtension for PasswordProtocolExtension {
|
|||
|
||||
async fn handle_handshake(
|
||||
&mut self,
|
||||
_read: &mut dyn WebSocketRead,
|
||||
_write: &LockedWebSocketWrite,
|
||||
_: &mut DynWebSocketRead,
|
||||
_: &dyn LockingWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_packet(
|
||||
&mut self,
|
||||
_packet: Bytes,
|
||||
_read: &mut dyn WebSocketRead,
|
||||
_write: &LockedWebSocketWrite,
|
||||
_: Bytes,
|
||||
_: &mut DynWebSocketRead,
|
||||
_: &dyn LockingWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Err(WispError::ExtensionImplNotSupported)
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ use async_trait::async_trait;
|
|||
use bytes::Bytes;
|
||||
|
||||
use crate::{
|
||||
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||
ws::{DynWebSocketRead, LockingWebSocketWrite},
|
||||
WispError,
|
||||
};
|
||||
|
||||
|
@ -40,8 +40,8 @@ impl ProtocolExtension for UdpProtocolExtension {
|
|||
|
||||
async fn handle_handshake(
|
||||
&mut self,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
_: &mut DynWebSocketRead,
|
||||
_: &dyn LockingWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
@ -49,8 +49,8 @@ impl ProtocolExtension for UdpProtocolExtension {
|
|||
async fn handle_packet(
|
||||
&mut self,
|
||||
_: Bytes,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
_: &mut DynWebSocketRead,
|
||||
_: &dyn LockingWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
use std::ops::Deref;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::BytesMut;
|
||||
use fastwebsockets::{
|
||||
CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketRead,
|
||||
|
@ -10,7 +9,7 @@ use fastwebsockets::{
|
|||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::{ws::LockedWebSocketWrite, WispError};
|
||||
use crate::{ws::LockingWebSocketWrite, WispError};
|
||||
|
||||
fn match_payload(payload: Payload<'_>) -> crate::ws::Payload<'_> {
|
||||
match payload {
|
||||
|
@ -87,27 +86,25 @@ impl From<WebSocketError> for crate::WispError {
|
|||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
|
||||
async fn wisp_read_frame(
|
||||
&mut self,
|
||||
tx: &LockedWebSocketWrite,
|
||||
tx: &dyn LockingWebSocketWrite,
|
||||
) -> Result<crate::ws::Frame<'static>, WispError> {
|
||||
Ok(self
|
||||
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
|
||||
.read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await })
|
||||
.await?
|
||||
.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for WebSocketRead<S> {
|
||||
async fn wisp_read_frame(
|
||||
&mut self,
|
||||
tx: &LockedWebSocketWrite,
|
||||
tx: &dyn LockingWebSocketWrite,
|
||||
) -> Result<crate::ws::Frame<'static>, WispError> {
|
||||
let mut frame = self
|
||||
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
|
||||
.read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await })
|
||||
.await?;
|
||||
|
||||
if frame.opcode == OpCode::Continuation {
|
||||
|
@ -121,7 +118,7 @@ impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for WebSocketRead<S>
|
|||
|
||||
while !frame.fin {
|
||||
frame = self
|
||||
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
|
||||
.read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await })
|
||||
.await?;
|
||||
|
||||
if frame.opcode != OpCode::Continuation {
|
||||
|
@ -142,11 +139,11 @@ impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for WebSocketRead<S>
|
|||
|
||||
async fn wisp_read_split(
|
||||
&mut self,
|
||||
tx: &LockedWebSocketWrite,
|
||||
tx: &dyn LockingWebSocketWrite,
|
||||
) -> Result<(crate::ws::Frame<'static>, Option<crate::ws::Frame<'static>>), WispError> {
|
||||
let mut frame_cnt = 1;
|
||||
let mut frame = self
|
||||
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
|
||||
.read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await })
|
||||
.await?;
|
||||
let mut extra_frame = None;
|
||||
|
||||
|
@ -161,7 +158,7 @@ impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for WebSocketRead<S>
|
|||
|
||||
while !frame.fin {
|
||||
frame = self
|
||||
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
|
||||
.read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await })
|
||||
.await?;
|
||||
|
||||
if frame.opcode != OpCode::Continuation {
|
||||
|
@ -197,7 +194,6 @@ impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for WebSocketRead<S>
|
|||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
|
||||
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame<'_>) -> Result<(), WispError> {
|
||||
self.write_frame(frame.into()).await.map_err(|e| e.into())
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
//! WebSocketRead + WebSocketWrite implementation for generic `Stream + Sink`s.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::{Sink, SinkExt, Stream, StreamExt};
|
||||
use std::error::Error;
|
||||
|
||||
use crate::{
|
||||
ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead, WebSocketWrite},
|
||||
ws::{Frame, LockingWebSocketWrite, OpCode, Payload, WebSocketRead, WebSocketWrite},
|
||||
WispError,
|
||||
};
|
||||
|
||||
|
@ -30,13 +29,12 @@ impl<T: Stream<Item = Result<BytesMut, E>> + Send + Unpin, E: Error + Sync + Sen
|
|||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: Stream<Item = Result<BytesMut, E>> + Send + Unpin, E: Error + Sync + Send + 'static>
|
||||
WebSocketRead for GenericWebSocketRead<T, E>
|
||||
{
|
||||
async fn wisp_read_frame(
|
||||
&mut self,
|
||||
_tx: &LockedWebSocketWrite,
|
||||
_tx: &dyn LockingWebSocketWrite,
|
||||
) -> Result<Frame<'static>, WispError> {
|
||||
match self.0.next().await {
|
||||
Some(data) => Ok(Frame::binary(Payload::Bytes(
|
||||
|
@ -67,7 +65,6 @@ impl<T: Sink<Bytes, Error = E> + Send + Unpin, E: Error + Sync + Send + 'static>
|
|||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: Sink<Bytes, Error = E> + Send + Unpin, E: Error + Sync + Send + 'static> WebSocketWrite
|
||||
for GenericWebSocketWrite<T, E>
|
||||
{
|
||||
|
|
|
@ -12,7 +12,7 @@ use futures::channel::oneshot;
|
|||
use crate::{
|
||||
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
|
||||
mux::send_info_packet,
|
||||
ws::{LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||
ws::{DynWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||
CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType,
|
||||
WispError,
|
||||
};
|
||||
|
@ -24,9 +24,9 @@ use super::{
|
|||
WispV2Handshake,
|
||||
};
|
||||
|
||||
async fn handshake<R: WebSocketRead>(
|
||||
async fn handshake<R: WebSocketRead + 'static, W: WebSocketWrite>(
|
||||
rx: &mut R,
|
||||
tx: &LockedWebSocketWrite,
|
||||
tx: &LockedWebSocketWrite<W>,
|
||||
v2_info: Option<WispV2Handshake>,
|
||||
) -> Result<(WispHandshakeResult, u32), WispError> {
|
||||
if let Some(WispV2Handshake {
|
||||
|
@ -47,7 +47,9 @@ async fn handshake<R: WebSocketRead>(
|
|||
let mut supported_extensions = get_supported_extensions(info.extensions, &mut builders);
|
||||
|
||||
for extension in supported_extensions.iter_mut() {
|
||||
extension.handle_handshake(rx, tx).await?;
|
||||
extension
|
||||
.handle_handshake(DynWebSocketRead::from_mut(rx), tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok((
|
||||
|
@ -86,34 +88,36 @@ async fn handshake<R: WebSocketRead>(
|
|||
}
|
||||
|
||||
/// Client side multiplexor.
|
||||
pub struct ClientMux {
|
||||
pub struct ClientMux<W: WebSocketWrite + 'static> {
|
||||
/// Whether the connection was downgraded to Wisp v1.
|
||||
///
|
||||
/// If this variable is true you must assume no extensions are supported.
|
||||
pub downgraded: bool,
|
||||
/// Extensions that are supported by both sides.
|
||||
pub supported_extensions: Vec<AnyProtocolExtension>,
|
||||
actor_tx: mpsc::Sender<WsEvent>,
|
||||
tx: LockedWebSocketWrite,
|
||||
actor_tx: mpsc::Sender<WsEvent<W>>,
|
||||
tx: LockedWebSocketWrite<W>,
|
||||
actor_exited: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl ClientMux {
|
||||
impl<W: WebSocketWrite + 'static> ClientMux<W> {
|
||||
/// Create a new client side multiplexor.
|
||||
///
|
||||
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||
/// if the extensions you need are available after the multiplexor has been created.
|
||||
pub async fn create<R, W>(
|
||||
pub async fn create<R>(
|
||||
mut rx: R,
|
||||
tx: W,
|
||||
wisp_v2: Option<WispV2Handshake>,
|
||||
) -> Result<MuxResult<ClientMux, impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||
) -> Result<
|
||||
MuxResult<ClientMux<W>, impl Future<Output = Result<(), WispError>> + Send>,
|
||||
WispError,
|
||||
>
|
||||
where
|
||||
R: WebSocketRead + Send,
|
||||
W: WebSocketWrite + Send + 'static,
|
||||
R: WebSocketRead + 'static,
|
||||
{
|
||||
let tx = LockedWebSocketWrite::new(Box::new(tx));
|
||||
let tx = LockedWebSocketWrite::new(tx);
|
||||
|
||||
let (handshake_result, buffer_size) = handshake(&mut rx, &tx, wisp_v2).await?;
|
||||
let (extensions, extra_packet) = handshake_result.kind.into_parts();
|
||||
|
@ -146,7 +150,7 @@ impl ClientMux {
|
|||
stream_type: StreamType,
|
||||
host: String,
|
||||
port: u16,
|
||||
) -> Result<MuxStream, WispError> {
|
||||
) -> Result<MuxStream<W>, WispError> {
|
||||
if self.actor_exited.load(Ordering::Acquire) {
|
||||
return Err(WispError::MuxTaskEnded);
|
||||
}
|
||||
|
@ -206,7 +210,7 @@ impl ClientMux {
|
|||
}
|
||||
|
||||
/// Get a protocol extension stream for sending packets with stream id 0.
|
||||
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
|
||||
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream<W> {
|
||||
MuxProtocolExtensionStream {
|
||||
stream_id: 0,
|
||||
tx: self.tx.clone(),
|
||||
|
@ -215,13 +219,13 @@ impl ClientMux {
|
|||
}
|
||||
}
|
||||
|
||||
impl Drop for ClientMux {
|
||||
impl<W: WebSocketWrite + 'static> Drop for ClientMux<W> {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.actor_tx.send(WsEvent::EndFut(None));
|
||||
}
|
||||
}
|
||||
|
||||
impl Multiplexor for ClientMux {
|
||||
impl<W: WebSocketWrite + 'static> Multiplexor for ClientMux<W> {
|
||||
fn has_extension(&self, extension_id: u8) -> bool {
|
||||
self.supported_extensions
|
||||
.iter()
|
||||
|
|
|
@ -5,23 +5,23 @@ use std::sync::{
|
|||
|
||||
use crate::{
|
||||
extensions::AnyProtocolExtension,
|
||||
ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
|
||||
ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead, WebSocketWrite},
|
||||
AtomicCloseReason, ClosePacket, CloseReason, ConnectPacket, MuxStream, Packet, PacketType,
|
||||
Role, StreamType, WispError,
|
||||
};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use event_listener::Event;
|
||||
use flume as mpsc;
|
||||
use futures::{channel::oneshot, select, FutureExt};
|
||||
use futures::{channel::oneshot, select, stream::unfold, FutureExt, StreamExt};
|
||||
use nohash_hasher::IntMap;
|
||||
|
||||
pub(crate) enum WsEvent {
|
||||
pub(crate) enum WsEvent<W: WebSocketWrite + 'static> {
|
||||
Close(Packet<'static>, oneshot::Sender<Result<(), WispError>>),
|
||||
CreateStream(
|
||||
StreamType,
|
||||
String,
|
||||
u16,
|
||||
oneshot::Sender<Result<MuxStream, WispError>>,
|
||||
oneshot::Sender<Result<MuxStream<W>, WispError>>,
|
||||
),
|
||||
SendPing(Payload<'static>, oneshot::Sender<Result<(), WispError>>),
|
||||
SendPong(Payload<'static>),
|
||||
|
@ -43,20 +43,21 @@ struct MuxMapValue {
|
|||
is_closed_event: Arc<Event>,
|
||||
}
|
||||
|
||||
pub struct MuxInner<R: WebSocketRead + Send> {
|
||||
pub struct MuxInner<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> {
|
||||
// gets taken by the mux task
|
||||
rx: Option<R>,
|
||||
// gets taken by the mux task
|
||||
maybe_downgrade_packet: Option<Packet<'static>>,
|
||||
|
||||
tx: LockedWebSocketWrite,
|
||||
extensions: Vec<AnyProtocolExtension>,
|
||||
tx: LockedWebSocketWrite<W>,
|
||||
// gets taken by the mux task
|
||||
extensions: Option<Vec<AnyProtocolExtension>>,
|
||||
tcp_extensions: Vec<u8>,
|
||||
role: Role,
|
||||
|
||||
// gets taken by the mux task
|
||||
actor_rx: Option<mpsc::Receiver<WsEvent>>,
|
||||
actor_tx: mpsc::Sender<WsEvent>,
|
||||
actor_rx: Option<mpsc::Receiver<WsEvent<W>>>,
|
||||
actor_tx: mpsc::Sender<WsEvent<W>>,
|
||||
fut_exited: Arc<AtomicBool>,
|
||||
|
||||
stream_map: IntMap<u32, MuxMapValue>,
|
||||
|
@ -64,16 +65,16 @@ pub struct MuxInner<R: WebSocketRead + Send> {
|
|||
buffer_size: u32,
|
||||
target_buffer_size: u32,
|
||||
|
||||
server_tx: mpsc::Sender<(ConnectPacket, MuxStream)>,
|
||||
server_tx: mpsc::Sender<(ConnectPacket, MuxStream<W>)>,
|
||||
}
|
||||
|
||||
pub struct MuxInnerResult<R: WebSocketRead + Send> {
|
||||
pub mux: MuxInner<R>,
|
||||
pub struct MuxInnerResult<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> {
|
||||
pub mux: MuxInner<R, W>,
|
||||
pub actor_exited: Arc<AtomicBool>,
|
||||
pub actor_tx: mpsc::Sender<WsEvent>,
|
||||
pub actor_tx: mpsc::Sender<WsEvent<W>>,
|
||||
}
|
||||
|
||||
impl<R: WebSocketRead + Send> MuxInner<R> {
|
||||
impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
|
||||
fn get_tcp_extensions(extensions: &[AnyProtocolExtension]) -> Vec<u8> {
|
||||
extensions
|
||||
.iter()
|
||||
|
@ -83,18 +84,19 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
|||
.collect()
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub fn new_server(
|
||||
rx: R,
|
||||
maybe_downgrade_packet: Option<Packet<'static>>,
|
||||
tx: LockedWebSocketWrite,
|
||||
tx: LockedWebSocketWrite<W>,
|
||||
extensions: Vec<AnyProtocolExtension>,
|
||||
buffer_size: u32,
|
||||
) -> (
|
||||
MuxInnerResult<R>,
|
||||
mpsc::Receiver<(ConnectPacket, MuxStream)>,
|
||||
MuxInnerResult<R, W>,
|
||||
mpsc::Receiver<(ConnectPacket, MuxStream<W>)>,
|
||||
) {
|
||||
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent>(256);
|
||||
let (server_tx, server_rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
|
||||
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent<W>>(256);
|
||||
let (server_tx, server_rx) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>();
|
||||
let ret_fut_tx = fut_tx.clone();
|
||||
let fut_exited = Arc::new(AtomicBool::new(false));
|
||||
|
||||
|
@ -110,7 +112,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
|||
fut_exited: fut_exited.clone(),
|
||||
|
||||
tcp_extensions: Self::get_tcp_extensions(&extensions),
|
||||
extensions,
|
||||
extensions: Some(extensions),
|
||||
buffer_size,
|
||||
target_buffer_size: ((buffer_size as u64 * 90) / 100) as u32,
|
||||
|
||||
|
@ -130,12 +132,12 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
|||
pub fn new_client(
|
||||
rx: R,
|
||||
maybe_downgrade_packet: Option<Packet<'static>>,
|
||||
tx: LockedWebSocketWrite,
|
||||
tx: LockedWebSocketWrite<W>,
|
||||
extensions: Vec<AnyProtocolExtension>,
|
||||
buffer_size: u32,
|
||||
) -> MuxInnerResult<R> {
|
||||
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent>(256);
|
||||
let (server_tx, _) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
|
||||
) -> MuxInnerResult<R, W> {
|
||||
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent<W>>(256);
|
||||
let (server_tx, _) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>();
|
||||
let ret_fut_tx = fut_tx.clone();
|
||||
let fut_exited = Arc::new(AtomicBool::new(false));
|
||||
|
||||
|
@ -150,7 +152,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
|||
fut_exited: fut_exited.clone(),
|
||||
|
||||
tcp_extensions: Self::get_tcp_extensions(&extensions),
|
||||
extensions,
|
||||
extensions: Some(extensions),
|
||||
buffer_size,
|
||||
target_buffer_size: 0,
|
||||
|
||||
|
@ -183,7 +185,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
|||
&mut self,
|
||||
stream_id: u32,
|
||||
stream_type: StreamType,
|
||||
) -> Result<(MuxMapValue, MuxStream), WispError> {
|
||||
) -> Result<(MuxMapValue, MuxStream<W>), WispError> {
|
||||
let (ch_tx, ch_rx) = mpsc::bounded(if self.role == Role::Server {
|
||||
self.buffer_size as usize
|
||||
} else {
|
||||
|
@ -241,11 +243,12 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
|||
}
|
||||
|
||||
async fn process_wisp_message(
|
||||
&mut self,
|
||||
rx: &mut R,
|
||||
msg: Result<(Frame<'static>, Option<Frame<'static>>), WispError>,
|
||||
) -> Result<Option<WsEvent>, WispError> {
|
||||
let (mut frame, optional_frame) = msg?;
|
||||
tx: &LockedWebSocketWrite<W>,
|
||||
extensions: &mut [AnyProtocolExtension],
|
||||
msg: (Frame<'static>, Option<Frame<'static>>),
|
||||
) -> Result<Option<WsEvent<W>>, WispError> {
|
||||
let (mut frame, optional_frame) = msg;
|
||||
if frame.opcode == OpCode::Close {
|
||||
return Ok(None);
|
||||
} else if frame.opcode == OpCode::Ping {
|
||||
|
@ -262,8 +265,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
|||
}
|
||||
}
|
||||
|
||||
let packet =
|
||||
Packet::maybe_handle_extension(frame, &mut self.extensions, rx, &self.tx).await?;
|
||||
let packet = Packet::maybe_handle_extension(frame, extensions, rx, tx).await?;
|
||||
|
||||
Ok(Some(WsEvent::WispMessage(packet, optional_frame)))
|
||||
}
|
||||
|
@ -271,36 +273,47 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
|||
async fn stream_loop(&mut self) -> Result<(), WispError> {
|
||||
let mut next_free_stream_id: u32 = 1;
|
||||
|
||||
let mut rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?;
|
||||
let rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?;
|
||||
let maybe_downgrade_packet = self.maybe_downgrade_packet.take();
|
||||
|
||||
let tx = self.tx.clone();
|
||||
let fut_rx = self.actor_rx.take().ok_or(WispError::MuxTaskStarted)?;
|
||||
|
||||
let extensions = self.extensions.take().ok_or(WispError::MuxTaskStarted)?;
|
||||
|
||||
if let Some(downgrade_packet) = maybe_downgrade_packet {
|
||||
if self.handle_packet(downgrade_packet, None).await? {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let mut read_stream = Box::pin(unfold(
|
||||
(rx, tx.clone(), extensions),
|
||||
|(mut rx, tx, mut extensions)| async {
|
||||
let ret = async {
|
||||
let msg = rx.wisp_read_split(&tx).await?;
|
||||
Self::process_wisp_message(&mut rx, &tx, &mut extensions, msg).await
|
||||
}
|
||||
.await;
|
||||
ret.transpose().map(|x| (x, (rx, tx, extensions)))
|
||||
},
|
||||
))
|
||||
.fuse();
|
||||
|
||||
let mut recv_fut = fut_rx.recv_async().fuse();
|
||||
let mut read_fut = rx.wisp_read_split(&tx).fuse();
|
||||
while let Some(msg) = select! {
|
||||
x = recv_fut => {
|
||||
drop(recv_fut);
|
||||
recv_fut = fut_rx.recv_async().fuse();
|
||||
Ok(x.ok())
|
||||
},
|
||||
x = read_fut => {
|
||||
drop(read_fut);
|
||||
let ret = self.process_wisp_message(&mut rx, x).await;
|
||||
read_fut = rx.wisp_read_split(&tx).fuse();
|
||||
ret
|
||||
x = read_stream.next() => {
|
||||
x.transpose()
|
||||
}
|
||||
}? {
|
||||
match msg {
|
||||
WsEvent::CreateStream(stream_type, host, port, channel) => {
|
||||
let ret: Result<MuxStream, WispError> = async {
|
||||
let ret: Result<MuxStream<W>, WispError> = async {
|
||||
let stream_id = next_free_stream_id;
|
||||
let next_stream_id = next_free_stream_id
|
||||
.checked_add(1)
|
||||
|
|
|
@ -8,7 +8,7 @@ pub use server::ServerMux;
|
|||
|
||||
use crate::{
|
||||
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder},
|
||||
ws::LockedWebSocketWrite,
|
||||
ws::{LockedWebSocketWrite, WebSocketWrite},
|
||||
CloseReason, Packet, PacketType, Role, WispError,
|
||||
};
|
||||
|
||||
|
@ -35,8 +35,8 @@ impl WispHandshakeResultKind {
|
|||
}
|
||||
}
|
||||
|
||||
async fn send_info_packet(
|
||||
write: &LockedWebSocketWrite,
|
||||
async fn send_info_packet<W: WebSocketWrite>(
|
||||
write: &LockedWebSocketWrite<W>,
|
||||
builders: &mut [AnyProtocolExtensionBuilder],
|
||||
) -> Result<(), WispError> {
|
||||
write
|
||||
|
|
|
@ -11,7 +11,7 @@ use futures::channel::oneshot;
|
|||
|
||||
use crate::{
|
||||
extensions::AnyProtocolExtension,
|
||||
ws::{LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||
ws::{DynWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||
CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role,
|
||||
WispError,
|
||||
};
|
||||
|
@ -23,9 +23,9 @@ use super::{
|
|||
WispV2Handshake,
|
||||
};
|
||||
|
||||
async fn handshake<R: WebSocketRead>(
|
||||
async fn handshake<R: WebSocketRead + 'static, W: WebSocketWrite>(
|
||||
rx: &mut R,
|
||||
tx: &LockedWebSocketWrite,
|
||||
tx: &LockedWebSocketWrite<W>,
|
||||
buffer_size: u32,
|
||||
v2_info: Option<WispV2Handshake>,
|
||||
) -> Result<WispHandshakeResult, WispError> {
|
||||
|
@ -47,7 +47,9 @@ async fn handshake<R: WebSocketRead>(
|
|||
let mut supported_extensions = get_supported_extensions(info.extensions, &mut builders);
|
||||
|
||||
for extension in supported_extensions.iter_mut() {
|
||||
extension.handle_handshake(rx, tx).await?;
|
||||
extension
|
||||
.handle_handshake(DynWebSocketRead::from_mut(rx), tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// v2 client
|
||||
|
@ -79,36 +81,38 @@ async fn handshake<R: WebSocketRead>(
|
|||
}
|
||||
|
||||
/// Server-side multiplexor.
|
||||
pub struct ServerMux {
|
||||
pub struct ServerMux<W: WebSocketWrite + 'static> {
|
||||
/// Whether the connection was downgraded to Wisp v1.
|
||||
///
|
||||
/// If this variable is true you must assume no extensions are supported.
|
||||
pub downgraded: bool,
|
||||
/// Extensions that are supported by both sides.
|
||||
pub supported_extensions: Vec<AnyProtocolExtension>,
|
||||
actor_tx: mpsc::Sender<WsEvent>,
|
||||
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
|
||||
tx: LockedWebSocketWrite,
|
||||
actor_tx: mpsc::Sender<WsEvent<W>>,
|
||||
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream<W>)>,
|
||||
tx: LockedWebSocketWrite<W>,
|
||||
actor_exited: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl ServerMux {
|
||||
impl<W: WebSocketWrite + 'static> ServerMux<W> {
|
||||
/// Create a new server-side multiplexor.
|
||||
///
|
||||
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||
/// if the extensions you need are available after the multiplexor has been created.
|
||||
pub async fn create<R, W>(
|
||||
pub async fn create<R>(
|
||||
mut rx: R,
|
||||
tx: W,
|
||||
buffer_size: u32,
|
||||
wisp_v2: Option<WispV2Handshake>,
|
||||
) -> Result<MuxResult<ServerMux, impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||
) -> Result<
|
||||
MuxResult<ServerMux<W>, impl Future<Output = Result<(), WispError>> + Send>,
|
||||
WispError,
|
||||
>
|
||||
where
|
||||
R: WebSocketRead + Send,
|
||||
W: WebSocketWrite + Send + 'static,
|
||||
R: WebSocketRead + Send + 'static,
|
||||
{
|
||||
let tx = LockedWebSocketWrite::new(Box::new(tx));
|
||||
let tx = LockedWebSocketWrite::new(tx);
|
||||
let ret_tx = tx.clone();
|
||||
let ret = async {
|
||||
let handshake_result = handshake(&mut rx, &tx, buffer_size, wisp_v2).await?;
|
||||
|
@ -165,7 +169,7 @@ impl ServerMux {
|
|||
}
|
||||
|
||||
/// Wait for a stream to be created.
|
||||
pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
|
||||
pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream<W>)> {
|
||||
if self.actor_exited.load(Ordering::Acquire) {
|
||||
return None;
|
||||
}
|
||||
|
@ -210,7 +214,7 @@ impl ServerMux {
|
|||
}
|
||||
|
||||
/// Get a protocol extension stream for sending packets with stream id 0.
|
||||
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
|
||||
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream<W> {
|
||||
MuxProtocolExtensionStream {
|
||||
stream_id: 0,
|
||||
tx: self.tx.clone(),
|
||||
|
@ -219,13 +223,13 @@ impl ServerMux {
|
|||
}
|
||||
}
|
||||
|
||||
impl Drop for ServerMux {
|
||||
impl<W: WebSocketWrite + 'static> Drop for ServerMux<W> {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.actor_tx.send(WsEvent::EndFut(None));
|
||||
}
|
||||
}
|
||||
|
||||
impl Multiplexor for ServerMux {
|
||||
impl<W: WebSocketWrite + 'static> Multiplexor for ServerMux<W> {
|
||||
fn has_extension(&self, extension_id: u8) -> bool {
|
||||
self.supported_extensions
|
||||
.iter()
|
||||
|
|
|
@ -2,7 +2,10 @@ use std::fmt::Display;
|
|||
|
||||
use crate::{
|
||||
extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder},
|
||||
ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
|
||||
ws::{
|
||||
self, DynWebSocketRead, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead,
|
||||
WebSocketWrite,
|
||||
},
|
||||
Role, WispError, WISP_VERSION,
|
||||
};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
|
@ -527,11 +530,11 @@ impl<'a> Packet<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn maybe_handle_extension(
|
||||
pub(crate) async fn maybe_handle_extension<R: WebSocketRead + 'static, W: WebSocketWrite>(
|
||||
frame: Frame<'a>,
|
||||
extensions: &mut [AnyProtocolExtension],
|
||||
read: &mut (dyn WebSocketRead + Send),
|
||||
write: &LockedWebSocketWrite,
|
||||
read: &mut R,
|
||||
write: &LockedWebSocketWrite<W>,
|
||||
) -> Result<Option<Self>, WispError> {
|
||||
if !frame.finished {
|
||||
return Err(WispError::WsFrameNotFinished);
|
||||
|
@ -568,7 +571,11 @@ impl<'a> Packet<'a> {
|
|||
.find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
|
||||
{
|
||||
extension
|
||||
.handle_packet(BytesMut::from(bytes).freeze(), read, write)
|
||||
.handle_packet(
|
||||
BytesMut::from(bytes).freeze(),
|
||||
DynWebSocketRead::from_mut(read),
|
||||
write,
|
||||
)
|
||||
.await?;
|
||||
Ok(None)
|
||||
} else {
|
||||
|
|
|
@ -98,7 +98,10 @@ impl MuxStreamIoStream {
|
|||
impl Stream for MuxStreamIoStream {
|
||||
type Item = Result<Bytes, std::io::Error>;
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
self.project().rx.poll_next(cx).map_err(std::io::Error::other)
|
||||
self.project()
|
||||
.rx
|
||||
.poll_next(cx)
|
||||
.map_err(std::io::Error::other)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ pub use compat::*;
|
|||
|
||||
use crate::{
|
||||
inner::WsEvent,
|
||||
ws::{Frame, LockedWebSocketWrite, Payload},
|
||||
ws::{Frame, LockedWebSocketWrite, Payload, WebSocketWrite},
|
||||
AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError,
|
||||
};
|
||||
|
||||
|
@ -21,7 +21,7 @@ use std::{
|
|||
};
|
||||
|
||||
/// Read side of a multiplexor stream.
|
||||
pub struct MuxStreamRead {
|
||||
pub struct MuxStreamRead<W: WebSocketWrite + 'static> {
|
||||
/// ID of the stream.
|
||||
pub stream_id: u32,
|
||||
/// Type of the stream.
|
||||
|
@ -29,7 +29,7 @@ pub struct MuxStreamRead {
|
|||
|
||||
role: Role,
|
||||
|
||||
tx: LockedWebSocketWrite,
|
||||
tx: LockedWebSocketWrite<W>,
|
||||
rx: mpsc::Receiver<Bytes>,
|
||||
|
||||
is_closed: Arc<AtomicBool>,
|
||||
|
@ -42,7 +42,7 @@ pub struct MuxStreamRead {
|
|||
target_flow_control: u32,
|
||||
}
|
||||
|
||||
impl MuxStreamRead {
|
||||
impl<W: WebSocketWrite + 'static> MuxStreamRead<W> {
|
||||
/// Read an event from the stream.
|
||||
pub async fn read(&self) -> Result<Option<Bytes>, WispError> {
|
||||
if self.rx.is_empty() && self.is_closed.load(Ordering::Acquire) {
|
||||
|
@ -98,15 +98,15 @@ impl MuxStreamRead {
|
|||
}
|
||||
|
||||
/// Write side of a multiplexor stream.
|
||||
pub struct MuxStreamWrite {
|
||||
pub struct MuxStreamWrite<W: WebSocketWrite + 'static> {
|
||||
/// ID of the stream.
|
||||
pub stream_id: u32,
|
||||
/// Type of the stream.
|
||||
pub stream_type: StreamType,
|
||||
|
||||
role: Role,
|
||||
mux_tx: mpsc::Sender<WsEvent>,
|
||||
tx: LockedWebSocketWrite,
|
||||
mux_tx: mpsc::Sender<WsEvent<W>>,
|
||||
tx: LockedWebSocketWrite<W>,
|
||||
|
||||
is_closed: Arc<AtomicBool>,
|
||||
close_reason: Arc<AtomicCloseReason>,
|
||||
|
@ -116,7 +116,7 @@ pub struct MuxStreamWrite {
|
|||
flow_control: Arc<AtomicU32>,
|
||||
}
|
||||
|
||||
impl MuxStreamWrite {
|
||||
impl<W: WebSocketWrite + 'static> MuxStreamWrite<W> {
|
||||
pub(crate) async fn write_payload_internal<'a>(
|
||||
&self,
|
||||
header: Frame<'static>,
|
||||
|
@ -169,7 +169,7 @@ impl MuxStreamWrite {
|
|||
/// handle.close(0x01);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn get_close_handle(&self) -> MuxStreamCloser {
|
||||
pub fn get_close_handle(&self) -> MuxStreamCloser<W> {
|
||||
MuxStreamCloser {
|
||||
stream_id: self.stream_id,
|
||||
close_channel: self.mux_tx.clone(),
|
||||
|
@ -179,7 +179,7 @@ impl MuxStreamWrite {
|
|||
}
|
||||
|
||||
/// Get a protocol extension stream to send protocol extension packets.
|
||||
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
|
||||
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream<W> {
|
||||
MuxProtocolExtensionStream {
|
||||
stream_id: self.stream_id,
|
||||
tx: self.tx.clone(),
|
||||
|
@ -244,7 +244,7 @@ impl MuxStreamWrite {
|
|||
}
|
||||
}
|
||||
|
||||
impl Drop for MuxStreamWrite {
|
||||
impl<W: WebSocketWrite + 'static> Drop for MuxStreamWrite<W> {
|
||||
fn drop(&mut self) {
|
||||
if !self.is_closed.load(Ordering::Acquire) {
|
||||
self.is_closed.store(true, Ordering::Release);
|
||||
|
@ -258,22 +258,22 @@ impl Drop for MuxStreamWrite {
|
|||
}
|
||||
|
||||
/// Multiplexor stream.
|
||||
pub struct MuxStream {
|
||||
pub struct MuxStream<W: WebSocketWrite + 'static> {
|
||||
/// ID of the stream.
|
||||
pub stream_id: u32,
|
||||
rx: MuxStreamRead,
|
||||
tx: MuxStreamWrite,
|
||||
rx: MuxStreamRead<W>,
|
||||
tx: MuxStreamWrite<W>,
|
||||
}
|
||||
|
||||
impl MuxStream {
|
||||
impl<W: WebSocketWrite + 'static> MuxStream<W> {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
stream_id: u32,
|
||||
role: Role,
|
||||
stream_type: StreamType,
|
||||
rx: mpsc::Receiver<Bytes>,
|
||||
mux_tx: mpsc::Sender<WsEvent>,
|
||||
tx: LockedWebSocketWrite,
|
||||
mux_tx: mpsc::Sender<WsEvent<W>>,
|
||||
tx: LockedWebSocketWrite<W>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
is_closed_event: Arc<Event>,
|
||||
close_reason: Arc<AtomicCloseReason>,
|
||||
|
@ -339,12 +339,12 @@ impl MuxStream {
|
|||
/// handle.close(0x01);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn get_close_handle(&self) -> MuxStreamCloser {
|
||||
pub fn get_close_handle(&self) -> MuxStreamCloser<W> {
|
||||
self.tx.get_close_handle()
|
||||
}
|
||||
|
||||
/// Get a protocol extension stream to send protocol extension packets.
|
||||
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
|
||||
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream<W> {
|
||||
self.tx.get_protocol_extension_stream()
|
||||
}
|
||||
|
||||
|
@ -359,7 +359,7 @@ impl MuxStream {
|
|||
}
|
||||
|
||||
/// Split the stream into read and write parts, consuming it.
|
||||
pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) {
|
||||
pub fn into_split(self) -> (MuxStreamRead<W>, MuxStreamWrite<W>) {
|
||||
(self.rx, self.tx)
|
||||
}
|
||||
|
||||
|
@ -374,15 +374,15 @@ impl MuxStream {
|
|||
|
||||
/// Close handle for a multiplexor stream.
|
||||
#[derive(Clone)]
|
||||
pub struct MuxStreamCloser {
|
||||
pub struct MuxStreamCloser<W: WebSocketWrite + 'static> {
|
||||
/// ID of the stream.
|
||||
pub stream_id: u32,
|
||||
close_channel: mpsc::Sender<WsEvent>,
|
||||
close_channel: mpsc::Sender<WsEvent<W>>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
close_reason: Arc<AtomicCloseReason>,
|
||||
}
|
||||
|
||||
impl MuxStreamCloser {
|
||||
impl<W: WebSocketWrite + 'static> MuxStreamCloser<W> {
|
||||
/// Close the stream. You will no longer be able to write or read after this has been called.
|
||||
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
|
||||
if self.is_closed.load(Ordering::Acquire) {
|
||||
|
@ -414,14 +414,14 @@ impl MuxStreamCloser {
|
|||
}
|
||||
|
||||
/// Stream for sending arbitrary protocol extension packets.
|
||||
pub struct MuxProtocolExtensionStream {
|
||||
pub struct MuxProtocolExtensionStream<W: WebSocketWrite + 'static> {
|
||||
/// ID of the stream.
|
||||
pub stream_id: u32,
|
||||
pub(crate) tx: LockedWebSocketWrite,
|
||||
pub(crate) tx: LockedWebSocketWrite<W>,
|
||||
pub(crate) is_closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl MuxProtocolExtensionStream {
|
||||
impl<W: WebSocketWrite + 'static> MuxProtocolExtensionStream<W> {
|
||||
/// Send a protocol extension packet with this stream's ID.
|
||||
pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> {
|
||||
if self.is_closed.load(Ordering::Acquire) {
|
||||
|
|
386
wisp/src/ws.rs
386
wisp/src/ws.rs
|
@ -4,12 +4,11 @@
|
|||
//! for other WebSocket implementations.
|
||||
//!
|
||||
//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
use std::{future::Future, ops::Deref, pin::Pin, sync::Arc};
|
||||
|
||||
use crate::WispError;
|
||||
use async_trait::async_trait;
|
||||
use bytes::{Buf, BytesMut};
|
||||
use futures::lock::Mutex;
|
||||
use futures::{lock::Mutex, TryFutureExt};
|
||||
|
||||
/// Payload of the websocket frame.
|
||||
#[derive(Debug)]
|
||||
|
@ -158,101 +157,286 @@ impl<'a> Frame<'a> {
|
|||
}
|
||||
|
||||
/// Generic WebSocket read trait.
|
||||
#[async_trait]
|
||||
pub trait WebSocketRead {
|
||||
pub trait WebSocketRead: Send {
|
||||
/// Read a frame from the socket.
|
||||
async fn wisp_read_frame(
|
||||
fn wisp_read_frame(
|
||||
&mut self,
|
||||
tx: &LockedWebSocketWrite,
|
||||
) -> Result<Frame<'static>, WispError>;
|
||||
tx: &dyn LockingWebSocketWrite,
|
||||
) -> impl Future<Output = Result<Frame<'static>, WispError>> + Send;
|
||||
|
||||
/// Read a split frame from the socket.
|
||||
async fn wisp_read_split(
|
||||
fn wisp_read_split(
|
||||
&mut self,
|
||||
tx: &LockedWebSocketWrite,
|
||||
) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
|
||||
self.wisp_read_frame(tx).await.map(|x| (x, None))
|
||||
tx: &dyn LockingWebSocketWrite,
|
||||
) -> impl Future<Output = Result<(Frame<'static>, Option<Frame<'static>>), WispError>> + Send {
|
||||
self.wisp_read_frame(tx).map_ok(|x| (x, None))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WebSocketRead for Box<dyn WebSocketRead + Send> {
|
||||
async fn wisp_read_frame(
|
||||
&mut self,
|
||||
tx: &LockedWebSocketWrite,
|
||||
) -> Result<Frame<'static>, WispError> {
|
||||
self.as_mut().wisp_read_frame(tx).await
|
||||
// similar to what dynosaur does
|
||||
mod wsr_inner {
|
||||
use std::{future::Future, pin::Pin};
|
||||
|
||||
use crate::WispError;
|
||||
|
||||
use super::{Frame, LockingWebSocketWrite, WebSocketRead};
|
||||
|
||||
trait ErasedWebSocketRead: Send {
|
||||
fn wisp_read_frame<'a>(
|
||||
&'a mut self,
|
||||
tx: &'a dyn LockingWebSocketWrite,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Frame<'static>, WispError>> + Send + 'a>>;
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn wisp_read_split<'a>(
|
||||
&'a mut self,
|
||||
tx: &'a dyn LockingWebSocketWrite,
|
||||
) -> Pin<
|
||||
Box<
|
||||
dyn Future<Output = Result<(Frame<'static>, Option<Frame<'static>>), WispError>>
|
||||
+ Send
|
||||
+ 'a,
|
||||
>,
|
||||
>;
|
||||
}
|
||||
|
||||
async fn wisp_read_split(
|
||||
&mut self,
|
||||
tx: &LockedWebSocketWrite,
|
||||
) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
|
||||
self.as_mut().wisp_read_split(tx).await
|
||||
impl<T: WebSocketRead> ErasedWebSocketRead for T {
|
||||
fn wisp_read_frame<'a>(
|
||||
&'a mut self,
|
||||
tx: &'a dyn LockingWebSocketWrite,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Frame<'static>, WispError>> + Send + 'a>> {
|
||||
Box::pin(self.wisp_read_frame(tx))
|
||||
}
|
||||
|
||||
fn wisp_read_split<'a>(
|
||||
&'a mut self,
|
||||
tx: &'a dyn LockingWebSocketWrite,
|
||||
) -> Pin<
|
||||
Box<
|
||||
dyn Future<Output = Result<(Frame<'static>, Option<Frame<'static>>), WispError>>
|
||||
+ Send
|
||||
+ 'a,
|
||||
>,
|
||||
> {
|
||||
Box::pin(self.wisp_read_split(tx))
|
||||
}
|
||||
}
|
||||
|
||||
/// WebSocketRead trait object.
|
||||
#[repr(transparent)]
|
||||
pub struct DynWebSocketRead {
|
||||
ptr: dyn ErasedWebSocketRead + 'static,
|
||||
}
|
||||
impl WebSocketRead for DynWebSocketRead {
|
||||
async fn wisp_read_frame(
|
||||
&mut self,
|
||||
tx: &dyn LockingWebSocketWrite,
|
||||
) -> Result<Frame<'static>, WispError> {
|
||||
self.ptr.wisp_read_frame(tx).await
|
||||
}
|
||||
|
||||
async fn wisp_read_split(
|
||||
&mut self,
|
||||
tx: &dyn LockingWebSocketWrite,
|
||||
) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
|
||||
self.ptr.wisp_read_split(tx).await
|
||||
}
|
||||
}
|
||||
impl DynWebSocketRead {
|
||||
/// Create a WebSocketRead trait object from a boxed WebSocketRead.
|
||||
pub fn new(val: Box<impl WebSocketRead + 'static>) -> Box<Self> {
|
||||
let val: Box<dyn ErasedWebSocketRead + 'static> = val;
|
||||
unsafe { std::mem::transmute(val) }
|
||||
}
|
||||
/// Create a WebSocketRead trait object from a WebSocketRead.
|
||||
pub fn boxed(val: impl WebSocketRead + 'static) -> Box<Self> {
|
||||
Self::new(Box::new(val))
|
||||
}
|
||||
/// Create a WebSocketRead trait object from a WebSocketRead reference.
|
||||
pub fn from_ref(val: &(impl WebSocketRead + 'static)) -> &Self {
|
||||
let val: &(dyn ErasedWebSocketRead + 'static) = val;
|
||||
unsafe { std::mem::transmute(val) }
|
||||
}
|
||||
/// Create a WebSocketRead trait object from a mutable WebSocketRead reference.
|
||||
pub fn from_mut(val: &mut (impl WebSocketRead + 'static)) -> &mut Self {
|
||||
let val: &mut (dyn ErasedWebSocketRead + 'static) = &mut *val;
|
||||
unsafe { std::mem::transmute(val) }
|
||||
}
|
||||
}
|
||||
}
|
||||
pub use wsr_inner::DynWebSocketRead;
|
||||
|
||||
/// Generic WebSocket write trait.
|
||||
#[async_trait]
|
||||
pub trait WebSocketWrite {
|
||||
pub trait WebSocketWrite: Send {
|
||||
/// Write a frame to the socket.
|
||||
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError>;
|
||||
|
||||
/// Close the socket.
|
||||
async fn wisp_close(&mut self) -> Result<(), WispError>;
|
||||
fn wisp_write_frame(
|
||||
&mut self,
|
||||
frame: Frame<'_>,
|
||||
) -> impl Future<Output = Result<(), WispError>> + Send;
|
||||
|
||||
/// Write a split frame to the socket.
|
||||
async fn wisp_write_split(
|
||||
fn wisp_write_split(
|
||||
&mut self,
|
||||
header: Frame<'_>,
|
||||
body: Frame<'_>,
|
||||
) -> Result<(), WispError> {
|
||||
let mut payload = BytesMut::from(header.payload);
|
||||
payload.extend_from_slice(&body.payload);
|
||||
self.wisp_write_frame(Frame::binary(Payload::Bytes(payload)))
|
||||
.await
|
||||
) -> impl Future<Output = Result<(), WispError>> + Send {
|
||||
async move {
|
||||
let mut payload = BytesMut::from(header.payload);
|
||||
payload.extend_from_slice(&body.payload);
|
||||
self.wisp_write_frame(Frame::binary(Payload::Bytes(payload)))
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// Close the socket.
|
||||
fn wisp_close(&mut self) -> impl Future<Output = Result<(), WispError>> + Send;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WebSocketWrite for Box<dyn WebSocketWrite + Send> {
|
||||
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
|
||||
self.as_mut().wisp_write_frame(frame).await
|
||||
// similar to what dynosaur does
|
||||
mod wsw_inner {
|
||||
use std::{future::Future, pin::Pin};
|
||||
|
||||
use crate::WispError;
|
||||
|
||||
use super::{Frame, WebSocketWrite};
|
||||
|
||||
trait ErasedWebSocketWrite: Send {
|
||||
fn wisp_write_frame<'a>(
|
||||
&'a mut self,
|
||||
frame: Frame<'a>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>>;
|
||||
|
||||
fn wisp_write_split<'a>(
|
||||
&'a mut self,
|
||||
header: Frame<'a>,
|
||||
body: Frame<'a>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>>;
|
||||
|
||||
fn wisp_close<'a>(
|
||||
&'a mut self,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>>;
|
||||
}
|
||||
|
||||
async fn wisp_close(&mut self) -> Result<(), WispError> {
|
||||
self.as_mut().wisp_close().await
|
||||
impl<T: WebSocketWrite> ErasedWebSocketWrite for T {
|
||||
fn wisp_write_frame<'a>(
|
||||
&'a mut self,
|
||||
frame: Frame<'a>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>> {
|
||||
Box::pin(self.wisp_write_frame(frame))
|
||||
}
|
||||
|
||||
fn wisp_write_split<'a>(
|
||||
&'a mut self,
|
||||
header: Frame<'a>,
|
||||
body: Frame<'a>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>> {
|
||||
Box::pin(self.wisp_write_split(header, body))
|
||||
}
|
||||
|
||||
fn wisp_close<'a>(
|
||||
&'a mut self,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>> {
|
||||
Box::pin(self.wisp_close())
|
||||
}
|
||||
}
|
||||
|
||||
async fn wisp_write_split(
|
||||
&mut self,
|
||||
header: Frame<'_>,
|
||||
body: Frame<'_>,
|
||||
) -> Result<(), WispError> {
|
||||
self.as_mut().wisp_write_split(header, body).await
|
||||
/// WebSocketWrite trait object.
|
||||
#[repr(transparent)]
|
||||
pub struct DynWebSocketWrite {
|
||||
ptr: dyn ErasedWebSocketWrite + 'static,
|
||||
}
|
||||
impl WebSocketWrite for DynWebSocketWrite {
|
||||
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
|
||||
self.ptr.wisp_write_frame(frame).await
|
||||
}
|
||||
|
||||
async fn wisp_write_split(
|
||||
&mut self,
|
||||
header: Frame<'_>,
|
||||
body: Frame<'_>,
|
||||
) -> Result<(), WispError> {
|
||||
self.ptr.wisp_write_split(header, body).await
|
||||
}
|
||||
|
||||
async fn wisp_close(&mut self) -> Result<(), WispError> {
|
||||
self.ptr.wisp_close().await
|
||||
}
|
||||
}
|
||||
impl DynWebSocketWrite {
|
||||
/// Create a new WebSocketWrite trait object from a boxed WebSocketWrite.
|
||||
pub fn new(val: Box<impl WebSocketWrite + 'static>) -> Box<Self> {
|
||||
let val: Box<dyn ErasedWebSocketWrite + 'static> = val;
|
||||
unsafe { std::mem::transmute(val) }
|
||||
}
|
||||
/// Create a new WebSocketWrite trait object from a WebSocketWrite.
|
||||
pub fn boxed(val: impl WebSocketWrite + 'static) -> Box<Self> {
|
||||
Self::new(Box::new(val))
|
||||
}
|
||||
/// Create a new WebSocketWrite trait object from a WebSocketWrite reference.
|
||||
pub fn from_ref(val: &(impl WebSocketWrite + 'static)) -> &Self {
|
||||
let val: &(dyn ErasedWebSocketWrite + 'static) = val;
|
||||
unsafe { std::mem::transmute(val) }
|
||||
}
|
||||
/// Create a new WebSocketWrite trait object from a mutable WebSocketWrite reference.
|
||||
pub fn from_mut(val: &mut (impl WebSocketWrite + 'static)) -> &mut Self {
|
||||
let val: &mut (dyn ErasedWebSocketWrite + 'static) = &mut *val;
|
||||
unsafe { std::mem::transmute(val) }
|
||||
}
|
||||
}
|
||||
}
|
||||
pub use wsw_inner::DynWebSocketWrite;
|
||||
|
||||
mod private {
|
||||
pub trait Sealed {}
|
||||
}
|
||||
|
||||
/// Helper trait object for LockedWebSocketWrite.
|
||||
pub trait LockingWebSocketWrite: private::Sealed + Sync {
|
||||
/// Write a frame to the websocket.
|
||||
fn wisp_write_frame<'a>(
|
||||
&'a self,
|
||||
frame: Frame<'a>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>>;
|
||||
|
||||
/// Write a split frame to the websocket.
|
||||
fn wisp_write_split<'a>(
|
||||
&'a self,
|
||||
header: Frame<'a>,
|
||||
body: Frame<'a>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>>;
|
||||
|
||||
/// Close the websocket.
|
||||
fn wisp_close<'a>(&'a self)
|
||||
-> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>>;
|
||||
}
|
||||
|
||||
/// Locked WebSocket.
|
||||
#[derive(Clone)]
|
||||
pub struct LockedWebSocketWrite(Arc<Mutex<Box<dyn WebSocketWrite + Send>>>);
|
||||
pub struct LockedWebSocketWrite<T: WebSocketWrite>(Arc<Mutex<T>>);
|
||||
|
||||
impl LockedWebSocketWrite {
|
||||
impl<T: WebSocketWrite> Clone for LockedWebSocketWrite<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: WebSocketWrite> LockedWebSocketWrite<T> {
|
||||
/// Create a new locked websocket.
|
||||
pub fn new(ws: Box<dyn WebSocketWrite + Send>) -> Self {
|
||||
pub fn new(ws: T) -> Self {
|
||||
Self(Mutex::new(ws).into())
|
||||
}
|
||||
|
||||
/// Create a new locked websocket from an existing mutex.
|
||||
pub fn from_locked(locked: Arc<Mutex<T>>) -> Self {
|
||||
Self(locked)
|
||||
}
|
||||
|
||||
/// Write a frame to the websocket.
|
||||
pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WispError> {
|
||||
self.0.lock().await.wisp_write_frame(frame).await
|
||||
}
|
||||
|
||||
pub(crate) async fn write_split(
|
||||
&self,
|
||||
header: Frame<'_>,
|
||||
body: Frame<'_>,
|
||||
) -> Result<(), WispError> {
|
||||
/// Write a split frame to the websocket.
|
||||
pub async fn write_split(&self, header: Frame<'_>, body: Frame<'_>) -> Result<(), WispError> {
|
||||
self.0.lock().await.wisp_write_split(header, body).await
|
||||
}
|
||||
|
||||
|
@ -261,3 +445,91 @@ impl LockedWebSocketWrite {
|
|||
self.0.lock().await.wisp_close().await
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: WebSocketWrite> private::Sealed for LockedWebSocketWrite<T> {}
|
||||
|
||||
impl<T: WebSocketWrite> LockingWebSocketWrite for LockedWebSocketWrite<T> {
|
||||
fn wisp_write_frame<'a>(
|
||||
&'a self,
|
||||
frame: Frame<'a>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>> {
|
||||
Box::pin(self.write_frame(frame))
|
||||
}
|
||||
|
||||
fn wisp_write_split<'a>(
|
||||
&'a self,
|
||||
header: Frame<'a>,
|
||||
body: Frame<'a>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>> {
|
||||
Box::pin(self.write_split(header, body))
|
||||
}
|
||||
|
||||
fn wisp_close<'a>(
|
||||
&'a self,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Send + 'a>> {
|
||||
Box::pin(self.close())
|
||||
}
|
||||
}
|
||||
|
||||
/// Combines two different WebSocketReads together.
|
||||
pub enum EitherWebSocketRead<A: WebSocketRead, B: WebSocketRead> {
|
||||
/// First WebSocketRead variant.
|
||||
Left(A),
|
||||
/// Second WebSocketRead variant.
|
||||
Right(B),
|
||||
}
|
||||
impl<A: WebSocketRead, B: WebSocketRead> WebSocketRead for EitherWebSocketRead<A, B> {
|
||||
async fn wisp_read_frame(
|
||||
&mut self,
|
||||
tx: &dyn LockingWebSocketWrite,
|
||||
) -> Result<Frame<'static>, WispError> {
|
||||
match self {
|
||||
Self::Left(x) => x.wisp_read_frame(tx).await,
|
||||
Self::Right(x) => x.wisp_read_frame(tx).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn wisp_read_split(
|
||||
&mut self,
|
||||
tx: &dyn LockingWebSocketWrite,
|
||||
) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
|
||||
match self {
|
||||
Self::Left(x) => x.wisp_read_split(tx).await,
|
||||
Self::Right(x) => x.wisp_read_split(tx).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Combines two different WebSocketWrites together.
|
||||
pub enum EitherWebSocketWrite<A: WebSocketWrite, B: WebSocketWrite> {
|
||||
/// First WebSocketWrite variant.
|
||||
Left(A),
|
||||
/// Second WebSocketWrite variant.
|
||||
Right(B),
|
||||
}
|
||||
impl<A: WebSocketWrite, B: WebSocketWrite> WebSocketWrite for EitherWebSocketWrite<A, B> {
|
||||
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
|
||||
match self {
|
||||
Self::Left(x) => x.wisp_write_frame(frame).await,
|
||||
Self::Right(x) => x.wisp_write_frame(frame).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn wisp_write_split(
|
||||
&mut self,
|
||||
header: Frame<'_>,
|
||||
body: Frame<'_>,
|
||||
) -> Result<(), WispError> {
|
||||
match self {
|
||||
Self::Left(x) => x.wisp_write_split(header, body).await,
|
||||
Self::Right(x) => x.wisp_write_split(header, body).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn wisp_close(&mut self) -> Result<(), WispError> {
|
||||
match self {
|
||||
Self::Left(x) => x.wisp_close().await,
|
||||
Self::Right(x) => x.wisp_close().await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue