add wisp to client

This commit is contained in:
Toshit Chawda 2024-01-30 21:15:17 -08:00
parent be7d92b4c5
commit c5cf95fcb1
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
12 changed files with 210 additions and 320 deletions

95
Cargo.lock generated
View file

@ -133,12 +133,6 @@ version = "3.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec"
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "bytes"
version = "1.5.0"
@ -248,6 +242,7 @@ name = "epoxy-client"
version = "1.0.0"
dependencies = [
"async-compression",
"async_io_stream",
"base64",
"bytes",
"console_error_panic_hook",
@ -255,11 +250,10 @@ dependencies = [
"fastwebsockets",
"futures-util",
"getrandom",
"http 1.0.0",
"http",
"http-body-util",
"hyper",
"js-sys",
"penguin-mux-wasm",
"pin-project-lite",
"rand",
"ring",
@ -488,17 +482,6 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7"
[[package]]
name = "http"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]]
name = "http"
version = "1.0.0"
@ -517,7 +500,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643"
dependencies = [
"bytes",
"http 1.0.0",
"http",
]
[[package]]
@ -528,7 +511,7 @@ checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840"
dependencies = [
"bytes",
"futures-util",
"http 1.0.0",
"http",
"http-body",
"pin-project-lite",
]
@ -554,7 +537,7 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"http 1.0.0",
"http",
"http-body",
"httparse",
"httpdate",
@ -573,7 +556,7 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"http 1.0.0",
"http",
"http-body",
"hyper",
"pin-project-lite",
@ -744,16 +727,6 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "parking_lot"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
dependencies = [
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.9.9"
@ -767,23 +740,6 @@ dependencies = [
"windows-targets 0.48.5",
]
[[package]]
name = "penguin-mux-wasm"
version = "0.1.0"
source = "git+https://github.com/r58Playz/penguin-mux-wasm#69b413aedb6f50f55eac646fda361abe430eb022"
dependencies = [
"bytes",
"futures-util",
"http 0.2.11",
"parking_lot",
"rand",
"thiserror",
"tokio",
"tokio-tungstenite",
"tracing",
"wasm-bindgen-futures",
]
[[package]]
name = "pharos"
version = "0.5.3"
@ -1129,7 +1085,6 @@ dependencies = [
"libc",
"mio",
"num_cpus",
"parking_lot",
"pin-project-lite",
"socket2",
"tokio-macros",
@ -1168,18 +1123,6 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]]
name = "tokio-util"
version = "0.7.10"
@ -1201,21 +1144,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-attributes"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tracing-core"
version = "0.1.32"
@ -1231,20 +1162,6 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
dependencies = [
"byteorder",
"bytes",
"log",
"rand",
"thiserror",
"utf-8",
]
[[package]]
name = "typenum"
version = "1.17.0"

View file

@ -16,7 +16,6 @@ http = "1.0.0"
http-body-util = "0.1.0"
hyper = { version = "1.1.0", features = ["client", "http1"] }
pin-project-lite = "0.2.13"
penguin-mux-wasm = { git = "https://github.com/r58Playz/penguin-mux-wasm" }
tokio = { version = "1.35.1", default_features = false }
wasm-bindgen = "0.2"
wasm-bindgen-futures = "0.4.39"
@ -33,7 +32,8 @@ async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"]
fastwebsockets = { version = "0.6.0", features = ["simdutf8", "unstable-split"] }
rand = "0.8.5"
base64 = "0.21.7"
wisp-mux = { path = "../wisp", features = ["ws_stream_wasm"] }
wisp-mux = { path = "../wisp", features = ["ws_stream_wasm", "tokio_io"] }
async_io_stream = { version = "0.3.3", features = ["tokio_io"] }
[dependencies.getrandom]
features = ["js"]

View file

@ -8,17 +8,20 @@ mod wrappers;
use tokioio::TokioIo;
use utils::{ReplaceErr, UriExt};
use websocket::EpxWebSocket;
use wrappers::{IncomingBody, WsStreamWrapper};
use wrappers::IncomingBody;
use std::sync::Arc;
use async_compression::tokio::bufread as async_comp;
use async_io_stream::IoStream;
use bytes::Bytes;
use futures_util::StreamExt;
use futures_util::{
stream::SplitSink,
StreamExt,
};
use http::{uri, HeaderName, HeaderValue, Request, Response};
use hyper::{body::Incoming, client::conn::http1::Builder, Uri};
use js_sys::{Array, Function, Object, Reflect, Uint8Array};
use penguin_mux_wasm::{Multiplexor, MuxStream};
use tokio_rustls::{client::TlsStream, rustls, rustls::RootCertStore, TlsConnector};
use tokio_util::{
either::Either,
@ -26,6 +29,8 @@ use tokio_util::{
};
use wasm_bindgen::prelude::*;
use web_sys::TextEncoder;
use wisp_mux::{ClientMux, MuxStreamIo, StreamType};
use ws_stream_wasm::{WsMeta, WsStream, WsMessage};
type HttpBody = http_body_util::Full<Bytes>;
@ -40,8 +45,8 @@ enum EpxCompression {
Gzip,
}
type EpxTlsStream = TlsStream<MuxStream<WsStreamWrapper>>;
type EpxUnencryptedStream = MuxStream<WsStreamWrapper>;
type EpxTlsStream = TlsStream<IoStream<MuxStreamIo, Vec<u8>>>;
type EpxUnencryptedStream = IoStream<MuxStreamIo, Vec<u8>>;
type EpxStream = Either<EpxTlsStream, EpxUnencryptedStream>;
async fn send_req(
@ -113,7 +118,7 @@ async fn start() {
#[wasm_bindgen]
pub struct EpoxyClient {
rustls_config: Arc<rustls::ClientConfig>,
mux: Multiplexor<WsStreamWrapper>,
mux: ClientMux<SplitSink<WsStream, WsMessage>>,
useragent: String,
redirect_limit: usize,
}
@ -138,11 +143,18 @@ impl EpoxyClient {
}
debug!("connecting to ws {:?}", ws_url);
let ws = WsStreamWrapper::connect(ws_url, None)
let (_, ws) = WsMeta::connect(ws_url, vec!["wisp-v1"])
.await
.replace_err("Failed to connect to websocket")?;
debug!("connected!");
let mux = Multiplexor::new(ws, penguin_mux_wasm::Role::Client, None, None);
let (wtx, wrx) = ws.split();
let (mux, fut) = ClientMux::new(wrx, wtx);
wasm_bindgen_futures::spawn_local(async move {
if let Err(err) = fut.await {
error!("epoxy: error in mux future! {:?}", err);
}
});
let mut certstore = RootCertStore::empty();
certstore.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
@ -161,14 +173,16 @@ impl EpoxyClient {
})
}
async fn get_http_io(&self, url: &Uri) -> Result<EpxStream, JsError> {
async fn get_http_io(&mut self, url: &Uri) -> Result<EpxStream, JsError> {
let url_host = url.host().replace_err("URL must have a host")?;
let url_port = utils::get_url_port(url)?;
let channel = self
.mux
.client_new_stream_channel(url_host.as_bytes(), url_port)
.client_new_stream(StreamType::Tcp, url_host.to_string(), url_port)
.await
.replace_err("Failed to create multiplexor channel")?;
.replace_err("Failed to create multiplexor channel")?
.into_io()
.into_asyncrw();
if utils::get_is_secure(url)? {
let cloned_uri = url_host.to_string().clone();
@ -189,7 +203,7 @@ impl EpoxyClient {
}
async fn send_req(
&self,
&mut self,
req: http::Request<HttpBody>,
should_redirect: bool,
) -> Result<(hyper::Response<Incoming>, Uri, bool), JsError> {
@ -217,7 +231,7 @@ impl EpoxyClient {
// shut up
#[allow(clippy::too_many_arguments)]
pub async fn connect_ws(
&self,
&mut self,
onopen: Function,
onclose: Function,
onerror: Function,
@ -232,7 +246,11 @@ impl EpoxyClient {
.await
}
pub async fn fetch(&self, url: String, options: Object) -> Result<web_sys::Response, JsError> {
pub async fn fetch(
&mut self,
url: String,
options: Object,
) -> Result<web_sys::Response, JsError> {
let uri = url.parse::<uri::Uri>().replace_err("Failed to parse URL")?;
let uri_scheme = uri.scheme().replace_err("URL must have a scheme")?;
if *uri_scheme != uri::Scheme::HTTP && *uri_scheme != uri::Scheme::HTTPS {

View file

@ -30,7 +30,7 @@ impl EpxWebSocket {
// shut up
#[allow(clippy::too_many_arguments)]
pub async fn connect(
tcp: &EpoxyClient,
tcp: &mut EpoxyClient,
onopen: Function,
onclose: Function,
onerror: Function,

View file

@ -4,117 +4,9 @@ use std::{
task::{Context, Poll},
};
use futures_util::{Sink, Stream};
use futures_util::Stream;
use hyper::body::Body;
use penguin_mux_wasm::ws;
use pin_project_lite::pin_project;
use ws_stream_wasm::{WsErr, WsMessage, WsMeta, WsStream};
pin_project! {
pub struct WsStreamWrapper {
#[pin]
ws: WsStream,
}
}
impl WsStreamWrapper {
pub async fn connect(
url: impl AsRef<str>,
protocols: impl Into<Option<Vec<&str>>>,
) -> Result<Self, WsErr> {
let (_, wsstream) = WsMeta::connect(url, protocols).await?;
Ok(WsStreamWrapper { ws: wsstream })
}
}
impl Stream for WsStreamWrapper {
type Item = Result<ws::Message, ws::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
let ret = this.ws.poll_next(cx);
match ret {
Poll::Ready(item) => Poll::<Option<Self::Item>>::Ready(item.map(|x| {
Ok(match x {
WsMessage::Text(txt) => ws::Message::Text(txt),
WsMessage::Binary(bin) => ws::Message::Binary(bin),
})
})),
Poll::Pending => Poll::<Option<Self::Item>>::Pending,
}
}
}
fn wserr_to_ws_err(err: WsErr) -> ws::Error {
debug!("err: {:?}", err);
match err {
WsErr::ConnectionNotOpen => ws::Error::AlreadyClosed,
_ => ws::Error::Io(std::io::Error::other(format!("{:?}", err))),
}
}
impl Sink<ws::Message> for WsStreamWrapper {
type Error = ws::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let this = self.project();
let ret = this.ws.poll_ready(cx);
match ret {
Poll::Ready(item) => Poll::<Result<(), Self::Error>>::Ready(match item {
Ok(_) => Ok(()),
Err(err) => Err(wserr_to_ws_err(err)),
}),
Poll::Pending => Poll::<Result<(), Self::Error>>::Pending,
}
}
fn start_send(self: Pin<&mut Self>, item: ws::Message) -> Result<(), Self::Error> {
use ws::Message::*;
let item = match item {
Text(txt) => WsMessage::Text(txt),
Binary(bin) => WsMessage::Binary(bin),
Close(_) => {
debug!("closing");
return match self.ws.wrapped().close() {
Ok(_) => Ok(()),
Err(err) => Err(ws::Error::Io(std::io::Error::other(format!(
"ws close err: {:?}",
err
)))),
};
}
Ping(_) | Pong(_) | Frame(_) => return Ok(()),
};
let this = self.project();
let ret = this.ws.start_send(item);
match ret {
Ok(_) => Ok(()),
Err(err) => Err(wserr_to_ws_err(err)),
}
}
// no point wrapping this as it's not going to do anything
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Ok(()).into()
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let this = self.project();
let ret = this.ws.poll_close(cx);
match ret {
Poll::Ready(item) => Poll::<Result<(), Self::Error>>::Ready(match item {
Ok(_) => Ok(()),
Err(err) => Err(wserr_to_ws_err(err)),
}),
Poll::Pending => Poll::<Result<(), Self::Error>>::Pending,
}
}
}
impl ws::WebSocketStream for WsStreamWrapper {
fn ping_auto_pong(&self) -> bool {
true
}
}
pin_project! {
pub struct IncomingBody {

View file

@ -101,7 +101,7 @@ async fn accept_http(
async fn handle_mux(
packet: ConnectPacket,
mut stream: MuxStream<impl ws::WebSocketWrite>,
mut stream: MuxStream<impl ws::WebSocketWrite + Send + 'static>,
) -> Result<bool, WispError> {
let uri = format!(
"{}:{}",
@ -174,9 +174,7 @@ async fn accept_ws(
println!("{:?}: connected", addr);
let mut mux = ServerMux::new(rx, tx);
mux.server_loop(&mut |packet, stream| async move {
ServerMux::handle(rx, tx, &mut |packet, stream| async move {
let mut close_err = stream.get_close_handle();
let mut close_ok = stream.get_close_handle();
tokio::spawn(async move {

View file

@ -17,3 +17,4 @@ ws_stream_wasm = { version = "0.7.4", optional = true }
[features]
fastwebsockets = ["dep:fastwebsockets", "dep:tokio"]
ws_stream_wasm = ["dep:ws_stream_wasm"]
tokio_io = ["async_io_stream/tokio_io"]

View file

@ -53,10 +53,10 @@ impl From<WebSocketError> for crate::WispError {
}
}
impl<S: AsyncRead + Unpin> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
async fn wisp_read_frame(
&mut self,
tx: &mut crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite + Send>,
) -> Result<crate::ws::Frame, crate::WispError> {
Ok(self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
@ -65,7 +65,7 @@ impl<S: AsyncRead + Unpin> crate::ws::WebSocketRead for FragmentCollectorRead<S>
}
}
impl<S: AsyncWrite + Unpin> crate::ws::WebSocketWrite for WebSocketWrite<S> {
impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
self.write_frame(frame.try_into()?).await.map_err(|e| e.into())
}

View file

@ -10,7 +10,7 @@ pub use crate::packet::*;
pub use crate::stream::*;
use dashmap::DashMap;
use futures::{channel::mpsc, StreamExt};
use futures::{channel::mpsc, Future, StreamExt};
use std::sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
@ -68,38 +68,66 @@ impl std::fmt::Display for WispError {
impl std::error::Error for WispError {}
pub struct ServerMux<R, W>
pub struct ServerMux<W>
where
R: ws::WebSocketRead,
W: ws::WebSocketWrite,
{
rx: R,
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
close_rx: mpsc::UnboundedReceiver<MuxEvent>,
close_tx: mpsc::UnboundedSender<MuxEvent>,
}
impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ServerMux<R, W> {
pub fn new(read: R, write: W) -> Self {
impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
pub fn handle<'a, FR, R>(
read: R,
write: W,
handler_fn: &'a mut impl Fn(ConnectPacket, MuxStream<W>) -> FR,
) -> impl Future<Output = Result<(), WispError>> + 'a
where
FR: std::future::Future<Output = Result<(), WispError>> + 'a,
R: ws::WebSocketRead + 'a,
W: ws::WebSocketWrite + 'a,
{
let (tx, rx) = mpsc::unbounded::<MuxEvent>();
Self {
rx: read,
tx: ws::LockedWebSocketWrite::new(write),
stream_map: Arc::new(DashMap::new()),
close_rx: rx,
let write = ws::LockedWebSocketWrite::new(write);
let map = Arc::new(DashMap::new());
let inner = ServerMux {
stream_map: map.clone(),
tx: write.clone(),
close_tx: tx,
}
};
inner.into_future(read, rx, handler_fn)
}
pub async fn server_bg_loop(&mut self) {
while let Some(msg) = self.close_rx.next().await {
async fn into_future<R, FR>(
self,
rx: R,
close_rx: mpsc::UnboundedReceiver<MuxEvent>,
handler_fn: &mut impl Fn(ConnectPacket, MuxStream<W>) -> FR,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
FR: std::future::Future<Output = Result<(), WispError>>,
{
futures::try_join! {
self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()),
self.server_msg_loop(rx, handler_fn)
}
.map(|_| ())
}
async fn server_close_loop(
&self,
mut close_rx: mpsc::UnboundedReceiver<MuxEvent>,
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
tx: ws::LockedWebSocketWrite<W>,
) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await {
match msg {
MuxEvent::Close(stream_id, reason, channel) => {
if self.stream_map.clone().remove(&stream_id).is_some() {
if stream_map.clone().remove(&stream_id).is_some() {
let _ = channel.send(
self.tx
.write_frame(Packet::new_close(stream_id, reason).into())
tx.write_frame(Packet::new_close(stream_id, reason).into())
.await,
);
} else {
@ -108,20 +136,23 @@ impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ServerMux<R, W> {
}
}
}
Ok(())
}
pub async fn server_loop<FR>(
&mut self,
async fn server_msg_loop<R, FR>(
&self,
mut rx: R,
handler_fn: &mut impl Fn(ConnectPacket, MuxStream<W>) -> FR,
) -> Result<(), WispError>
where
FR: std::future::Future<Output = Result<(), crate::WispError>>,
R: ws::WebSocketRead,
FR: std::future::Future<Output = Result<(), WispError>>,
{
self.tx
.write_frame(Packet::new_continue(0, u32::MAX).into())
.await?;
while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await {
while let Ok(frame) = rx.wisp_read_frame(&self.tx).await {
if let Ok(packet) = Packet::try_from(frame) {
use PacketType::*;
match packet.packet {
@ -164,34 +195,31 @@ impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ServerMux<R, W> {
}
}
pub struct ClientMux<R, W>
pub struct ClientMuxInner<W>
where
R: ws::WebSocketRead,
W: ws::WebSocketWrite,
{
rx: R,
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
next_free_stream_id: AtomicU32,
}
impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
pub async fn into_future<R>(
self,
rx: R,
close_rx: mpsc::UnboundedReceiver<MuxEvent>,
close_tx: mpsc::UnboundedSender<MuxEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
futures::try_join!(self.client_bg_loop(close_rx), self.client_loop(rx)).map(|_| ())
}
impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ClientMux<R, W> {
pub fn new(read: R, write: W) -> Self {
let (tx, rx) = mpsc::unbounded::<MuxEvent>();
Self {
rx: read,
tx: ws::LockedWebSocketWrite::new(write),
stream_map: Arc::new(DashMap::new()),
next_free_stream_id: AtomicU32::new(1),
close_rx: rx,
close_tx: tx,
}
}
pub async fn client_bg_loop(&mut self) {
while let Some(msg) = self.close_rx.next().await {
async fn client_bg_loop(
&self,
mut close_rx: mpsc::UnboundedReceiver<MuxEvent>,
) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await {
match msg {
MuxEvent::Close(stream_id, reason, channel) => {
if self.stream_map.clone().remove(&stream_id).is_some() {
@ -206,14 +234,14 @@ impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ClientMux<R, W> {
}
}
}
Ok(())
}
pub async fn client_loop(&mut self) -> Result<(), WispError> {
self.tx
.write_frame(Packet::new_continue(0, u32::MAX).into())
.await?;
while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await {
async fn client_loop<R>(&self, mut rx: R) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
while let Ok(frame) = rx.wisp_read_frame(&self.tx).await {
if let Ok(packet) = Packet::try_from(frame) {
use PacketType::*;
match packet.packet {
@ -235,12 +263,52 @@ impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ClientMux<R, W> {
}
Ok(())
}
}
pub struct ClientMux<W>
where
W: ws::WebSocketWrite,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
next_free_stream_id: AtomicU32,
close_tx: mpsc::UnboundedSender<MuxEvent>,
}
impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
pub fn new<R>(read: R, write: W) -> (Self, impl Future<Output = Result<(), WispError>>)
where
R: ws::WebSocketRead,
{
let (tx, rx) = mpsc::unbounded::<MuxEvent>();
let map = Arc::new(DashMap::new());
let write = ws::LockedWebSocketWrite::new(write);
(
Self {
tx: write.clone(),
stream_map: map.clone(),
next_free_stream_id: AtomicU32::new(1),
close_tx: tx,
},
ClientMuxInner {
tx: write.clone(),
stream_map: map.clone(),
}
.into_future(read, rx),
)
}
pub async fn client_new_stream(
&mut self,
stream_type: StreamType,
host: String,
port: u16,
) -> Result<MuxStream<impl ws::WebSocketWrite>, WispError> {
let (ch_tx, ch_rx) = mpsc::unbounded();
let stream_id = self.next_free_stream_id.load(Ordering::Acquire);
self.tx
.write_frame(Packet::new_connect(stream_id, stream_type, port, host).into())
.await?;
self.next_free_stream_id.store(
stream_id
.checked_add(1)

View file

@ -4,7 +4,7 @@ use futures::{
channel::{mpsc, oneshot},
sink, stream,
task::{Context, Poll},
AsyncRead, AsyncWrite, Sink, Stream, StreamExt,
Sink, Stream, StreamExt,
};
use pin_project_lite::pin_project;
use std::{
@ -44,7 +44,7 @@ impl MuxStreamRead {
}
}
pub(crate) fn into_stream(self) -> Pin<Box<dyn Stream<Item = Bytes>>> {
pub(crate) fn into_stream(self) -> Pin<Box<dyn Stream<Item = Bytes> + Send>> {
Box::pin(stream::unfold(self, |mut rx| async move {
let evt = rx.read().await?;
Some((
@ -68,7 +68,7 @@ where
is_closed: Arc<AtomicBool>,
}
impl<W: crate::ws::WebSocketWrite> MuxStreamWrite<W> {
impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamWrite<W> {
pub async fn write(&mut self, data: Bytes) -> Result<(), crate::WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(crate::WispError::StreamAlreadyClosed);
@ -101,10 +101,7 @@ impl<W: crate::ws::WebSocketWrite> MuxStreamWrite<W> {
Ok(())
}
pub(crate) fn into_sink<'a>(self) -> Pin<Box<dyn Sink<Bytes, Error = crate::WispError> + 'a>>
where
W: 'a,
{
pub(crate) fn into_sink(self) -> Pin<Box<dyn Sink<Bytes, Error = crate::WispError> + Send>> {
Box::pin(sink::unfold(self, |mut tx, data| async move {
tx.write(data).await?;
Ok(tx)
@ -130,7 +127,7 @@ where
tx: MuxStreamWrite<W>,
}
impl<W: crate::ws::WebSocketWrite> MuxStream<W> {
impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStream<W> {
pub(crate) fn new(
stream_id: u32,
rx: mpsc::UnboundedReceiver<WsEvent>,
@ -174,10 +171,7 @@ impl<W: crate::ws::WebSocketWrite> MuxStream<W> {
(self.rx, self.tx)
}
pub fn into_io<'a>(self) -> MuxStreamIo<'a>
where
W: 'a,
{
pub fn into_io(self) -> MuxStreamIo {
MuxStreamIo {
rx: self.rx.into_stream(),
tx: self.tx.into_sink(),
@ -208,55 +202,54 @@ impl MuxStreamCloser {
}
pin_project! {
pub struct MuxStreamIo<'a> {
pub struct MuxStreamIo {
#[pin]
rx: Pin<Box<dyn Stream<Item = Bytes> + 'a>>,
rx: Pin<Box<dyn Stream<Item = Bytes> + Send>>,
#[pin]
tx: Pin<Box<dyn Sink<Bytes, Error = crate::WispError> + 'a>>,
tx: Pin<Box<dyn Sink<Bytes, Error = crate::WispError> + Send>>,
}
}
impl<'a> MuxStreamIo<'a> {
pub fn into_asyncrw(self) -> impl AsyncRead + AsyncWrite + 'a {
IoStream::new(self.map(|x| Ok::<Vec<u8>, std::io::Error>(x.to_vec())))
impl MuxStreamIo {
pub fn into_asyncrw(self) -> IoStream<MuxStreamIo, Vec<u8>> {
IoStream::new(self)
}
}
impl Stream for MuxStreamIo<'_> {
type Item = Bytes;
impl Stream for MuxStreamIo {
type Item = Result<Vec<u8>, std::io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().rx.poll_next(cx)
self.project()
.rx
.poll_next(cx)
.map(|x| x.map(|x| Ok(x.to_vec())))
}
}
impl Sink<Bytes> for MuxStreamIo<'_> {
type Error = crate::WispError;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().tx.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
self.project().tx.start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().tx.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().tx.poll_close(cx)
}
}
impl Sink<Vec<u8>> for MuxStreamIo<'_> {
impl Sink<Vec<u8>> for MuxStreamIo {
type Error = std::io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().tx.poll_ready(cx).map_err(std::io::Error::other)
self.project()
.tx
.poll_ready(cx)
.map_err(std::io::Error::other)
}
fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
self.project().tx.start_send(item.into()).map_err(std::io::Error::other)
self.project()
.tx
.start_send(item.into())
.map_err(std::io::Error::other)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().tx.poll_flush(cx).map_err(std::io::Error::other)
self.project()
.tx
.poll_flush(cx)
.map_err(std::io::Error::other)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().tx.poll_close(cx).map_err(std::io::Error::other)
self.project()
.tx
.poll_close(cx)
.map_err(std::io::Error::other)
}
}

View file

@ -46,20 +46,20 @@ impl Frame {
pub trait WebSocketRead {
fn wisp_read_frame(
&mut self,
tx: &mut crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
) -> impl std::future::Future<Output = Result<Frame, crate::WispError>>;
tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite + Send>,
) -> impl std::future::Future<Output = Result<Frame, crate::WispError>> + Send;
}
pub trait WebSocketWrite {
fn wisp_write_frame(
&mut self,
frame: Frame,
) -> impl std::future::Future<Output = Result<(), crate::WispError>>;
) -> impl std::future::Future<Output = Result<(), crate::WispError>> + Send;
}
pub struct LockedWebSocketWrite<S>(Arc<Mutex<S>>);
impl<S: WebSocketWrite> LockedWebSocketWrite<S> {
impl<S: WebSocketWrite + Send> LockedWebSocketWrite<S> {
pub fn new(ws: S) -> Self {
Self(Arc::new(Mutex::new(ws)))
}

View file

@ -1,4 +1,4 @@
use futures::{SinkExt, StreamExt};
use futures::{stream::{SplitStream, SplitSink}, SinkExt, StreamExt};
use ws_stream_wasm::{WsErr, WsMessage, WsStream};
impl From<WsMessage> for crate::ws::Frame {
@ -37,10 +37,10 @@ impl From<WsErr> for crate::WispError {
}
}
impl crate::ws::WebSocketRead for WsStream {
impl crate::ws::WebSocketRead for SplitStream<WsStream> {
async fn wisp_read_frame(
&mut self,
_: &mut crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
_: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
) -> Result<crate::ws::Frame, crate::WispError> {
Ok(self
.next()
@ -50,8 +50,11 @@ impl crate::ws::WebSocketRead for WsStream {
}
}
impl crate::ws::WebSocketWrite for WsStream {
impl crate::ws::WebSocketWrite for SplitSink<WsStream, WsMessage> {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
self.send(frame.try_into()?).await.map_err(|e| e.into())
self
.send(frame.try_into()?)
.await
.map_err(|e| e.into())
}
}