update fastwebsockets, revert custom asyncread impl

This commit is contained in:
Toshit Chawda 2024-07-10 21:59:17 -07:00
parent 1916a8e7c8
commit 5571a63f40
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
13 changed files with 646 additions and 703 deletions

70
Cargo.lock generated
View file

@ -143,6 +143,16 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "async_io_stream"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6d7b9decdf35d8908a7e3ef02f64c5e9b1695e230154c0e8de3969142d9b94c"
dependencies = [
"futures",
"rustc_version",
]
[[package]] [[package]]
name = "atomic-counter" name = "atomic-counter"
version = "1.0.1" version = "1.0.1"
@ -174,7 +184,7 @@ dependencies = [
"futures-util", "futures-util",
"http 0.2.12", "http 0.2.12",
"http-body 0.4.6", "http-body 0.4.6",
"hyper 0.14.29", "hyper 0.14.30",
"itoa", "itoa",
"matchit", "matchit",
"memchr", "memchr",
@ -295,9 +305,9 @@ checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.0.106" version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "066fce287b1d4eafef758e89e09d724a24808a9196fe9756b8ca90e86d0719a2" checksum = "eaff6f8ce506b9773fa786672d63fc7a191ffea1be33f72bbd4aeacefca9ffc8"
[[package]] [[package]]
name = "certs-grabber" name = "certs-grabber"
@ -315,9 +325,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.8" version = "4.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "84b3edb18336f4df585bc9aa31dd99c036dfa5dc5e9a2939a722a188f3a8970d" checksum = "64acc1846d54c1fe936a78dc189c34e28d3f5afc348403f28ecf53660b9b8462"
dependencies = [ dependencies = [
"clap_builder", "clap_builder",
"clap_derive", "clap_derive",
@ -325,9 +335,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_builder" name = "clap_builder"
version = "4.5.8" version = "4.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1c09dd5ada6c6c78075d6fd0da3f90d8080651e2d6cc8eb2f1aaa4034ced708" checksum = "6fb8393d67ba2e7bfaf28a23458e4e2b543cc73a99595511eb207fdb8aede942"
dependencies = [ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
@ -511,7 +521,7 @@ checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
[[package]] [[package]]
name = "epoxy-client" name = "epoxy-client"
version = "2.0.5" version = "2.0.6"
dependencies = [ dependencies = [
"async-compression", "async-compression",
"async-trait", "async-trait",
@ -526,7 +536,7 @@ dependencies = [
"getrandom", "getrandom",
"http 1.1.0", "http 1.1.0",
"http-body-util", "http-body-util",
"hyper 1.4.0", "hyper 1.4.1",
"hyper-util-wasm", "hyper-util-wasm",
"js-sys", "js-sys",
"parking_lot_core", "parking_lot_core",
@ -557,7 +567,7 @@ dependencies = [
"fastwebsockets", "fastwebsockets",
"futures-util", "futures-util",
"http-body-util", "http-body-util",
"hyper 1.4.0", "hyper 1.4.1",
"hyper-util", "hyper-util",
"tokio", "tokio",
"tokio-util", "tokio-util",
@ -599,14 +609,14 @@ checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a"
[[package]] [[package]]
name = "fastwebsockets" name = "fastwebsockets"
version = "0.7.2" version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93da8b19e29f202ef35ddd20ddea8c86166850fce5ba2a3c3f3e3174cdbb0620" checksum = "26da0c7b5cef45c521a6f9cdfffdfeb6c9f5804fbac332deb5ae254634c7a6be"
dependencies = [ dependencies = [
"base64 0.21.7", "base64 0.21.7",
"bytes", "bytes",
"http-body-util", "http-body-util",
"hyper 1.4.0", "hyper 1.4.1",
"hyper-util", "hyper-util",
"pin-project", "pin-project",
"rand", "rand",
@ -962,9 +972,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]] [[package]]
name = "hyper" name = "hyper"
version = "0.14.29" version = "0.14.30"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f361cde2f109281a220d4307746cdfd5ee3f410da58a70377762396775634b33" checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9"
dependencies = [ dependencies = [
"bytes", "bytes",
"futures-channel", "futures-channel",
@ -986,9 +996,9 @@ dependencies = [
[[package]] [[package]]
name = "hyper" name = "hyper"
version = "1.4.0" version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4fe55fb7a772d59a5ff1dfbff4fe0258d19b89fec4b233e75d35d5d2316badc" checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05"
dependencies = [ dependencies = [
"bytes", "bytes",
"futures-channel", "futures-channel",
@ -1011,7 +1021,7 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1"
dependencies = [ dependencies = [
"hyper 0.14.29", "hyper 0.14.30",
"pin-project-lite", "pin-project-lite",
"tokio", "tokio",
"tokio-io-timeout", "tokio-io-timeout",
@ -1027,7 +1037,7 @@ dependencies = [
"futures-util", "futures-util",
"http 1.1.0", "http 1.1.0",
"http-body 1.0.0", "http-body 1.0.0",
"hyper 1.4.0", "hyper 1.4.1",
"pin-project-lite", "pin-project-lite",
"tokio", "tokio",
] ]
@ -1043,7 +1053,7 @@ dependencies = [
"futures-util", "futures-util",
"http 1.1.0", "http 1.1.0",
"http-body 1.0.0", "http-body 1.0.0",
"hyper 1.4.0", "hyper 1.4.1",
"pin-project-lite", "pin-project-lite",
"tower", "tower",
"tower-service", "tower-service",
@ -1547,6 +1557,15 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc_version"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366"
dependencies = [
"semver",
]
[[package]] [[package]]
name = "rustix" name = "rustix"
version = "0.38.34" version = "0.38.34"
@ -1653,6 +1672,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "semver"
version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
[[package]] [[package]]
name = "send_wrapper" name = "send_wrapper"
version = "0.4.0" version = "0.4.0"
@ -1737,7 +1762,7 @@ dependencies = [
"futures", "futures",
"http-body-util", "http-body-util",
"humantime", "humantime",
"hyper 1.4.0", "hyper 1.4.1",
"simple_moving_average", "simple_moving_average",
"tokio", "tokio",
"tokio-native-tls", "tokio-native-tls",
@ -1959,7 +1984,7 @@ dependencies = [
"h2 0.3.26", "h2 0.3.26",
"http 0.2.12", "http 0.2.12",
"http-body 0.4.6", "http-body 0.4.6",
"hyper 0.14.29", "hyper 0.14.30",
"hyper-timeout", "hyper-timeout",
"percent-encoding", "percent-encoding",
"pin-project", "pin-project",
@ -2462,6 +2487,7 @@ name = "wisp-mux"
version = "5.0.0" version = "5.0.0"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"async_io_stream",
"bytes", "bytes",
"dashmap", "dashmap",
"event-listener", "event-listener",

View file

@ -1,6 +1,6 @@
[package] [package]
name = "epoxy-client" name = "epoxy-client"
version = "2.0.5" version = "2.0.6"
edition = "2021" edition = "2021"
[lib] [lib]
@ -13,7 +13,7 @@ base64 = { version = "0.22.1", optional = true }
bytes = "1.6.0" bytes = "1.6.0"
cfg-if = "1.0.0" cfg-if = "1.0.0"
event-listener = "5.3.1" event-listener = "5.3.1"
fastwebsockets = { version = "0.7.2", features = ["unstable-split"], optional = true } fastwebsockets = { version = "0.8.0", features = ["unstable-split"], optional = true }
flume = "0.11.0" flume = "0.11.0"
futures-rustls = { version = "0.26.0", default-features = false, features = ["tls12", "ring"] } futures-rustls = { version = "0.26.0", default-features = false, features = ["tls12", "ring"] }
futures-util = { version = "0.3.30", features = ["sink"] } futures-util = { version = "0.3.30", features = ["sink"] }

View file

@ -196,9 +196,11 @@ onmessage = async (msg) => {
[], [],
{ "x-header": "abc" }, { "x-header": "abc" },
); );
let i = 0;
while (true) { while (true) {
log("sending `data`"); log(`sending \`data${i}\``);
await ws.send("data"); await ws.send("data"+i);
i++;
await (new Promise((res, _) => setTimeout(res, 10))); await (new Promise((res, _) => setTimeout(res, 10)));
} }
} else if (should_tls_test) { } else if (should_tls_test) {

View file

@ -1,6 +1,6 @@
{ {
"name": "@mercuryworkshop/epoxy-tls", "name": "@mercuryworkshop/epoxy-tls",
"version": "2.0.5-1", "version": "2.0.6-1",
"description": "A wasm library for using raw encrypted tls/ssl/https/websocket streams on the browser", "description": "A wasm library for using raw encrypted tls/ssl/https/websocket streams on the browser",
"scripts": { "scripts": {
"build": "./build.sh" "build": "./build.sh"

View file

@ -1,13 +1,12 @@
use std::{pin::Pin, sync::Arc, task::Poll}; use std::{pin::Pin, sync::Arc, task::Poll};
use bytes::Bytes;
use futures_rustls::{ use futures_rustls::{
rustls::{ClientConfig, RootCertStore}, rustls::{ClientConfig, RootCertStore},
TlsConnector, TlsStream, TlsConnector, TlsStream,
}; };
use futures_util::{ use futures_util::{
future::Either, future::Either, lock::{Mutex, MutexGuard}, AsyncRead, AsyncWrite, Future
lock::{Mutex, MutexGuard},
AsyncRead, AsyncWrite, Future,
}; };
use hyper_util_wasm::client::legacy::connect::{Connected, Connection}; use hyper_util_wasm::client::legacy::connect::{Connected, Connection};
use js_sys::{Array, Reflect, Uint8Array}; use js_sys::{Array, Reflect, Uint8Array};
@ -17,246 +16,247 @@ use tower_service::Service;
use wasm_bindgen::{JsCast, JsValue}; use wasm_bindgen::{JsCast, JsValue};
use wasm_bindgen_futures::spawn_local; use wasm_bindgen_futures::spawn_local;
use wisp_mux::{ use wisp_mux::{
extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder}, ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder},
ClientMux, IoStream, MuxStreamIo, StreamType,
}; };
use crate::{ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError}; use crate::{ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError};
fn object_to_trustanchor(obj: JsValue) -> Result<TrustAnchor<'static>, JsValue> { fn object_to_trustanchor(obj: JsValue) -> Result<TrustAnchor<'static>, JsValue> {
let subject: Uint8Array = Reflect::get(&obj, &"subject".into())?.dyn_into()?; let subject: Uint8Array = Reflect::get(&obj, &"subject".into())?.dyn_into()?;
let pub_key_info: Uint8Array = let pub_key_info: Uint8Array =
Reflect::get(&obj, &"subject_public_key_info".into())?.dyn_into()?; Reflect::get(&obj, &"subject_public_key_info".into())?.dyn_into()?;
let name_constraints: Option<Uint8Array> = Reflect::get(&obj, &"name_constraints".into()) let name_constraints: Option<Uint8Array> = Reflect::get(&obj, &"name_constraints".into())
.and_then(|x| x.dyn_into()) .and_then(|x| x.dyn_into())
.ok(); .ok();
Ok(TrustAnchor { Ok(TrustAnchor {
subject: Der::from(subject.to_vec()), subject: Der::from(subject.to_vec()),
subject_public_key_info: Der::from(pub_key_info.to_vec()), subject_public_key_info: Der::from(pub_key_info.to_vec()),
name_constraints: name_constraints.map(|x| Der::from(x.to_vec())), name_constraints: name_constraints.map(|x| Der::from(x.to_vec())),
}) })
} }
pub struct StreamProvider { pub struct StreamProvider {
wisp_url: String, wisp_url: String,
wisp_v2: bool, wisp_v2: bool,
udp_extension: bool, udp_extension: bool,
websocket_protocols: Vec<String>, websocket_protocols: Vec<String>,
client_config: Arc<ClientConfig>, client_config: Arc<ClientConfig>,
current_client: Arc<Mutex<Option<ClientMux>>>, current_client: Arc<Mutex<Option<ClientMux>>>,
} }
pub type ProviderUnencryptedStream = MuxStreamIo; pub type ProviderUnencryptedStream = MuxStreamIo;
pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW; pub type ProviderUnencryptedAsyncRW = IoStream<ProviderUnencryptedStream, Bytes>;
pub type ProviderTlsAsyncRW = TlsStream<ProviderUnencryptedAsyncRW>; pub type ProviderTlsAsyncRW = TlsStream<ProviderUnencryptedAsyncRW>;
pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>; pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>;
impl StreamProvider { impl StreamProvider {
pub fn new( pub fn new(
wisp_url: String, wisp_url: String,
certs: Array, certs: Array,
options: &EpoxyClientOptions, options: &EpoxyClientOptions,
) -> Result<Self, EpoxyError> { ) -> Result<Self, EpoxyError> {
let certs: Result<Vec<TrustAnchor>, JsValue> = let certs: Result<Vec<TrustAnchor>, JsValue> =
certs.iter().map(object_to_trustanchor).collect(); certs.iter().map(object_to_trustanchor).collect();
let certstore = RootCertStore::from_iter(certs.map_err(|_| EpoxyError::InvalidCertStore)?); let certstore = RootCertStore::from_iter(certs.map_err(|_| EpoxyError::InvalidCertStore)?);
let client_config = Arc::new( let client_config = Arc::new(
ClientConfig::builder() ClientConfig::builder()
.with_root_certificates(certstore) .with_root_certificates(certstore)
.with_no_client_auth(), .with_no_client_auth(),
); );
Ok(Self { Ok(Self {
wisp_url, wisp_url,
current_client: Arc::new(Mutex::new(None)), current_client: Arc::new(Mutex::new(None)),
wisp_v2: options.wisp_v2, wisp_v2: options.wisp_v2,
udp_extension: options.udp_extension_required, udp_extension: options.udp_extension_required,
websocket_protocols: options.websocket_protocols.clone(), websocket_protocols: options.websocket_protocols.clone(),
client_config, client_config,
}) })
} }
async fn create_client( async fn create_client(
&self, &self,
mut locked: MutexGuard<'_, Option<ClientMux>>, mut locked: MutexGuard<'_, Option<ClientMux>>,
) -> Result<(), EpoxyError> { ) -> Result<(), EpoxyError> {
let extensions_vec: Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>> = let extensions_vec: Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>> =
vec![Box::new(UdpProtocolExtensionBuilder())]; vec![Box::new(UdpProtocolExtensionBuilder())];
let extensions = if self.wisp_v2 { let extensions = if self.wisp_v2 {
Some(extensions_vec.as_slice()) Some(extensions_vec.as_slice())
} else { } else {
None None
}; };
let (write, read) = WebSocketWrapper::connect(&self.wisp_url, &self.websocket_protocols)?; let (write, read) = WebSocketWrapper::connect(&self.wisp_url, &self.websocket_protocols)?;
if !write.wait_for_open().await { if !write.wait_for_open().await {
return Err(EpoxyError::WebSocketConnectFailed); return Err(EpoxyError::WebSocketConnectFailed);
} }
let client = ClientMux::create(read, write, extensions).await?; let client = ClientMux::create(read, write, extensions).await?;
let (mux, fut) = if self.udp_extension { let (mux, fut) = if self.udp_extension {
client.with_udp_extension_required().await? client.with_udp_extension_required().await?
} else { } else {
client.with_no_required_extensions() client.with_no_required_extensions()
}; };
locked.replace(mux); locked.replace(mux);
let current_client = self.current_client.clone(); let current_client = self.current_client.clone();
spawn_local(async move { spawn_local(async move {
fut.await; fut.await;
current_client.lock().await.take(); current_client.lock().await.take();
}); });
Ok(()) Ok(())
} }
pub async fn replace_client(&self) -> Result<(), EpoxyError> { pub async fn replace_client(&self) -> Result<(), EpoxyError> {
self.create_client(self.current_client.lock().await).await self.create_client(self.current_client.lock().await).await
} }
pub async fn get_stream( pub async fn get_stream(
&self, &self,
stream_type: StreamType, stream_type: StreamType,
host: String, host: String,
port: u16, port: u16,
) -> Result<ProviderUnencryptedStream, EpoxyError> { ) -> Result<ProviderUnencryptedStream, EpoxyError> {
Box::pin(async { Box::pin(async {
let locked = self.current_client.lock().await; let locked = self.current_client.lock().await;
if let Some(mux) = locked.as_ref() { if let Some(mux) = locked.as_ref() {
Ok(mux Ok(mux
.client_new_stream(stream_type, host, port) .client_new_stream(stream_type, host, port)
.await? .await?
.into_io()) .into_io())
} else { } else {
self.create_client(locked).await?; self.create_client(locked).await?;
self.get_stream(stream_type, host, port).await self.get_stream(stream_type, host, port).await
} }
}) })
.await .await
} }
pub async fn get_asyncread( pub async fn get_asyncread(
&self, &self,
stream_type: StreamType, stream_type: StreamType,
host: String, host: String,
port: u16, port: u16,
) -> Result<ProviderUnencryptedAsyncRW, EpoxyError> { ) -> Result<ProviderUnencryptedAsyncRW, EpoxyError> {
Ok(self Ok(self
.get_stream(stream_type, host, port) .get_stream(stream_type, host, port)
.await? .await?
.into_asyncrw()) .into_asyncrw())
} }
pub async fn get_tls_stream( pub async fn get_tls_stream(
&self, &self,
host: String, host: String,
port: u16, port: u16,
) -> Result<ProviderTlsAsyncRW, EpoxyError> { ) -> Result<ProviderTlsAsyncRW, EpoxyError> {
let stream = self let stream = self
.get_asyncread(StreamType::Tcp, host.clone(), port) .get_asyncread(StreamType::Tcp, host.clone(), port)
.await?; .await?;
let connector = TlsConnector::from(self.client_config.clone()); let connector = TlsConnector::from(self.client_config.clone());
Ok(connector.connect(host.try_into()?, stream).await?.into()) Ok(connector.connect(host.try_into()?, stream).await?.into())
} }
} }
pin_project! { pin_project! {
pub struct HyperIo { pub struct HyperIo {
#[pin] #[pin]
inner: ProviderAsyncRW, inner: ProviderAsyncRW,
} }
} }
impl hyper::rt::Read for HyperIo { impl hyper::rt::Read for HyperIo {
fn poll_read( fn poll_read(
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
mut buf: hyper::rt::ReadBufCursor<'_>, mut buf: hyper::rt::ReadBufCursor<'_>,
) -> Poll<Result<(), std::io::Error>> { ) -> Poll<Result<(), std::io::Error>> {
let buf_slice: &mut [u8] = unsafe { std::mem::transmute(buf.as_mut()) }; let buf_slice: &mut [u8] = unsafe { std::mem::transmute(buf.as_mut()) };
match self.project().inner.poll_read(cx, buf_slice) { match self.project().inner.poll_read(cx, buf_slice) {
Poll::Ready(bytes_read) => { Poll::Ready(bytes_read) => {
let bytes_read = bytes_read?; let bytes_read = bytes_read?;
unsafe { unsafe {
buf.advance(bytes_read); buf.advance(bytes_read);
} }
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
Poll::Pending => Poll::Pending, Poll::Pending => Poll::Pending,
} }
} }
} }
impl hyper::rt::Write for HyperIo { impl hyper::rt::Write for HyperIo {
fn poll_write( fn poll_write(
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
buf: &[u8], buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> { ) -> Poll<Result<usize, std::io::Error>> {
self.project().inner.poll_write(cx, buf) self.project().inner.poll_write(cx, buf)
} }
fn poll_flush( fn poll_flush(
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> { ) -> Poll<Result<(), std::io::Error>> {
self.project().inner.poll_flush(cx) self.project().inner.poll_flush(cx)
} }
fn poll_shutdown( fn poll_shutdown(
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> { ) -> Poll<Result<(), std::io::Error>> {
self.project().inner.poll_close(cx) self.project().inner.poll_close(cx)
} }
fn poll_write_vectored( fn poll_write_vectored(
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>], bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> { ) -> Poll<Result<usize, std::io::Error>> {
self.project().inner.poll_write_vectored(cx, bufs) self.project().inner.poll_write_vectored(cx, bufs)
} }
} }
impl Connection for HyperIo { impl Connection for HyperIo {
fn connected(&self) -> Connected { fn connected(&self) -> Connected {
Connected::new() Connected::new()
} }
} }
#[derive(Clone)] #[derive(Clone)]
pub struct StreamProviderService(pub Arc<StreamProvider>); pub struct StreamProviderService(pub Arc<StreamProvider>);
impl Service<hyper::Uri> for StreamProviderService { impl Service<hyper::Uri> for StreamProviderService {
type Response = HyperIo; type Response = HyperIo;
type Error = EpoxyError; type Error = EpoxyError;
type Future = Pin<Box<impl Future<Output = Result<Self::Response, Self::Error>>>>; type Future = Pin<Box<impl Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready( fn poll_ready(
&mut self, &mut self,
_: &mut std::task::Context<'_>, _: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> { ) -> std::task::Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, req: hyper::Uri) -> Self::Future { fn call(&mut self, req: hyper::Uri) -> Self::Future {
let provider = self.0.clone(); let provider = self.0.clone();
Box::pin(async move { Box::pin(async move {
let scheme = req.scheme_str().ok_or(EpoxyError::InvalidUrlScheme)?; let scheme = req.scheme_str().ok_or(EpoxyError::InvalidUrlScheme)?;
let host = req.host().ok_or(EpoxyError::NoUrlHost)?.to_string(); let host = req.host().ok_or(EpoxyError::NoUrlHost)?.to_string();
let port = req.port_u16().map(Ok).unwrap_or_else(|| match scheme { let port = req.port_u16().map(Ok).unwrap_or_else(|| match scheme {
"https" | "wss" => Ok(443), "https" | "wss" => Ok(443),
"http" | "ws" => Ok(80), "http" | "ws" => Ok(80),
_ => Err(EpoxyError::NoUrlPort), _ => Err(EpoxyError::NoUrlPort),
})?; })?;
Ok(HyperIo { Ok(HyperIo {
inner: match scheme { inner: match scheme {
"https" | "wss" => Either::Left(provider.get_tls_stream(host, port).await?), "https" | "wss" => Either::Left(provider.get_tls_stream(host, port).await?),
"http" | "ws" => { "http" | "ws" => {
Either::Right(provider.get_asyncread(StreamType::Tcp, host, port).await?) Either::Right(provider.get_asyncread(StreamType::Tcp, host, port).await?)
} }
_ => return Err(EpoxyError::InvalidUrlScheme), _ => return Err(EpoxyError::InvalidUrlScheme),
}, },
}) })
}) })
} }
} }

View file

@ -9,11 +9,24 @@ use http::{HeaderValue, Uri};
use hyper::{body::Body, rt::Executor}; use hyper::{body::Body, rt::Executor};
use js_sys::{Array, ArrayBuffer, Object, Reflect, Uint8Array}; use js_sys::{Array, ArrayBuffer, Object, Reflect, Uint8Array};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use wasm_bindgen::{JsCast, JsValue}; use wasm_bindgen::{prelude::*, JsCast, JsValue};
use wasm_bindgen_futures::JsFuture; use wasm_bindgen_futures::JsFuture;
use crate::EpoxyError; use crate::EpoxyError;
#[wasm_bindgen]
extern "C" {
#[wasm_bindgen(js_namespace = console, js_name = log)]
pub fn js_console_log(s: &str);
}
#[macro_export]
macro_rules! console_log {
($($expr:expr),*) => {
$crate::utils::js_console_log(&format!($($expr),*));
};
}
pub trait UriExt { pub trait UriExt {
fn get_redirect(&self, location: &HeaderValue) -> Result<Uri, EpoxyError>; fn get_redirect(&self, location: &HeaderValue) -> Result<Uri, EpoxyError>;
} }

View file

@ -118,7 +118,6 @@ impl EpoxyWebSocket {
); );
} }
OpCode::Close => { OpCode::Close => {
let _ = onclose.call0(&JsValue::null());
break; break;
} }
// ping/pong/continue // ping/pong/continue

View file

@ -10,7 +10,7 @@ clap = { version = "4.4.18", features = ["derive", "help", "usage", "color", "wr
clio = { version = "0.3.5", features = ["clap-parse"] } clio = { version = "0.3.5", features = ["clap-parse"] }
console-subscriber = { version = "0.2.0", optional = true } console-subscriber = { version = "0.2.0", optional = true }
dashmap = "5.5.3" dashmap = "5.5.3"
fastwebsockets = { version = "0.7.1", features = ["upgrade", "simdutf8", "unstable-split"] } fastwebsockets = { version = "0.8.0", features = ["upgrade", "simdutf8", "unstable-split"] }
futures-util = { version = "0.3.30", features = ["sink"] } futures-util = { version = "0.3.30", features = ["sink"] }
http-body-util = "0.1.0" http-body-util = "0.1.0"
hyper = { version = "1.1.0", features = ["server", "http1"] } hyper = { version = "1.1.0", features = ["server", "http1"] }

View file

@ -17,7 +17,7 @@ use hyper_util::rt::TokioIo;
#[cfg(unix)] #[cfg(unix)]
use tokio::net::{UnixListener, UnixStream}; use tokio::net::{UnixListener, UnixStream};
use tokio::{ use tokio::{
io::{copy, AsyncBufReadExt, AsyncWriteExt}, io::{copy, copy_bidirectional, AsyncBufReadExt, AsyncWriteExt},
net::{lookup_host, TcpListener, TcpStream, UdpSocket}, net::{lookup_host, TcpListener, TcpStream, UdpSocket},
select, select,
}; };
@ -34,7 +34,7 @@ use wisp_mux::{
udp::UdpProtocolExtensionBuilder, udp::UdpProtocolExtensionBuilder,
ProtocolExtensionBuilder, ProtocolExtensionBuilder,
}, },
CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRW, ServerMux, StreamType, WispError, CloseReason, ConnectPacket, MuxStream, IoStream, ServerMux, StreamType, WispError,
}; };
type HttpBody = http_body_util::Full<hyper::body::Bytes>; type HttpBody = http_body_util::Full<hyper::body::Bytes>;
@ -269,6 +269,8 @@ async fn accept_http(
} }
} }
// re-enable once MuxStreamAsyncRW is fixed
/*
async fn copy_buf(mux: MuxStreamAsyncRW, tcp: TcpStream) -> std::io::Result<()> { async fn copy_buf(mux: MuxStreamAsyncRW, tcp: TcpStream) -> std::io::Result<()> {
let (muxrx, muxtx) = mux.into_split(); let (muxrx, muxtx) = mux.into_split();
let mut muxrx = muxrx.compat(); let mut muxrx = muxrx.compat();
@ -300,6 +302,7 @@ async fn copy_buf(mux: MuxStreamAsyncRW, tcp: TcpStream) -> std::io::Result<()>
x = slow_fut => x.map(|_| ()), x = slow_fut => x.map(|_| ()),
} }
} }
*/
async fn handle_mux( async fn handle_mux(
packet: ConnectPacket, packet: ConnectPacket,
@ -311,9 +314,9 @@ async fn handle_mux(
); );
match packet.stream_type { match packet.stream_type {
StreamType::Tcp => { StreamType::Tcp => {
let tcp_stream = TcpStream::connect(uri).await?; let mut tcp_stream = TcpStream::connect(uri).await?;
let mux = stream.into_io().into_asyncrw(); let mut mux = stream.into_io().into_asyncrw().compat();
copy_buf(mux, tcp_stream).await?; copy_bidirectional(&mut mux, &mut tcp_stream).await?;
} }
StreamType::Udp => { StreamType::Udp => {
let uri = lookup_host(uri) let uri = lookup_host(uri)

View file

@ -8,7 +8,7 @@ atomic-counter = "1.0.1"
bytes = "1.5.0" bytes = "1.5.0"
clap = { version = "4.5.4", features = ["cargo", "derive"] } clap = { version = "4.5.4", features = ["cargo", "derive"] }
console-subscriber = { version = "0.2.0", optional = true } console-subscriber = { version = "0.2.0", optional = true }
fastwebsockets = { version = "0.7.1", features = ["unstable-split", "upgrade"] } fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] }
futures = "0.3.30" futures = "0.3.30"
http-body-util = "0.1.0" http-body-util = "0.1.0"
humantime = "2.1.0" humantime = "2.1.0"

View file

@ -10,10 +10,11 @@ edition = "2021"
[dependencies] [dependencies]
async-trait = "0.1.79" async-trait = "0.1.79"
async_io_stream = "0.3.3"
bytes = "1.5.0" bytes = "1.5.0"
dashmap = { version = "5.5.3", features = ["inline"] } dashmap = { version = "5.5.3", features = ["inline"] }
event-listener = "5.0.0" event-listener = "5.0.0"
fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true } fastwebsockets = { version = "0.8.0", features = ["unstable-split"], optional = true }
flume = "0.11.0" flume = "0.11.0"
futures = "0.3.30" futures = "0.3.30"
futures-timer = "3.0.3" futures-timer = "3.0.3"

View file

@ -14,6 +14,7 @@ mod stream;
pub mod ws; pub mod ws;
pub use crate::{packet::*, stream::*}; pub use crate::{packet::*, stream::*};
pub use async_io_stream::IoStream;
use bytes::Bytes; use bytes::Bytes;
use dashmap::DashMap; use dashmap::DashMap;

View file

@ -1,559 +1,457 @@
use crate::{ use crate::{
sink_unfold, sink_unfold,
ws::{Frame, LockedWebSocketWrite}, ws::{Frame, LockedWebSocketWrite},
CloseReason, Packet, Role, StreamType, WispError, CloseReason, Packet, Role, StreamType, WispError,
}; };
use async_io_stream::IoStream;
use bytes::{BufMut, Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use event_listener::Event; use event_listener::Event;
use flume as mpsc; use flume as mpsc;
use futures::{ use futures::{
channel::oneshot, channel::oneshot,
select, select, stream,
stream::{self, IntoAsyncRead, SplitSink, SplitStream}, task::{Context, Poll},
task::{Context, Poll}, FutureExt, Sink, Stream,
AsyncBufRead, AsyncRead, AsyncWrite, FutureExt, Sink, Stream, StreamExt, TryStreamExt,
}; };
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use std::{ use std::{
pin::Pin, pin::Pin,
sync::{ sync::{
atomic::{AtomicBool, AtomicU32, Ordering}, atomic::{AtomicBool, AtomicU32, Ordering},
Arc, Arc,
}, },
task::ready,
}; };
pub(crate) enum WsEvent { pub(crate) enum WsEvent {
Close(Packet, oneshot::Sender<Result<(), WispError>>), Close(Packet, oneshot::Sender<Result<(), WispError>>),
CreateStream( CreateStream(
StreamType, StreamType,
String, String,
u16, u16,
oneshot::Sender<Result<MuxStream, WispError>>, oneshot::Sender<Result<MuxStream, WispError>>,
), ),
EndFut(Option<CloseReason>), EndFut(Option<CloseReason>),
} }
/// Read side of a multiplexor stream. /// Read side of a multiplexor stream.
pub struct MuxStreamRead { pub struct MuxStreamRead {
/// ID of the stream. /// ID of the stream.
pub stream_id: u32, pub stream_id: u32,
/// Type of the stream. /// Type of the stream.
pub stream_type: StreamType, pub stream_type: StreamType,
role: Role, role: Role,
tx: LockedWebSocketWrite, tx: LockedWebSocketWrite,
rx: mpsc::Receiver<Bytes>, rx: mpsc::Receiver<Bytes>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
is_closed_event: Arc<Event>, is_closed_event: Arc<Event>,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
flow_control_read: AtomicU32, flow_control_read: AtomicU32,
target_flow_control: u32, target_flow_control: u32,
} }
impl MuxStreamRead { impl MuxStreamRead {
/// Read an event from the stream. /// Read an event from the stream.
pub async fn read(&self) -> Option<Bytes> { pub async fn read(&self) -> Option<Bytes> {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return None; return None;
} }
let bytes = select! { let bytes = select! {
x = self.rx.recv_async() => x.ok()?, x = self.rx.recv_async() => x.ok()?,
_ = self.is_closed_event.listen().fuse() => return None _ = self.is_closed_event.listen().fuse() => return None
}; };
if self.role == Role::Server && self.stream_type == StreamType::Tcp { if self.role == Role::Server && self.stream_type == StreamType::Tcp {
let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1; let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1;
if val > self.target_flow_control && !self.is_closed.load(Ordering::Acquire) { if val > self.target_flow_control && !self.is_closed.load(Ordering::Acquire) {
self.tx self.tx
.write_frame( .write_frame(
Packet::new_continue( Packet::new_continue(
self.stream_id, self.stream_id,
self.flow_control.fetch_add(val, Ordering::AcqRel) + val, self.flow_control.fetch_add(val, Ordering::AcqRel) + val,
) )
.into(), .into(),
) )
.await .await
.ok()?; .ok()?;
self.flow_control_read.store(0, Ordering::Release); self.flow_control_read.store(0, Ordering::Release);
} }
} }
Some(bytes) Some(bytes)
} }
pub(crate) fn into_stream(self) -> Pin<Box<dyn Stream<Item = Bytes> + Send>> { pub(crate) fn into_stream(self) -> Pin<Box<dyn Stream<Item = Bytes> + Send>> {
Box::pin(stream::unfold(self, |rx| async move { Box::pin(stream::unfold(self, |rx| async move {
Some((rx.read().await?, rx)) Some((rx.read().await?, rx))
})) }))
} }
} }
/// Write side of a multiplexor stream. /// Write side of a multiplexor stream.
pub struct MuxStreamWrite { pub struct MuxStreamWrite {
/// ID of the stream. /// ID of the stream.
pub stream_id: u32, pub stream_id: u32,
/// Type of the stream. /// Type of the stream.
pub stream_type: StreamType, pub stream_type: StreamType,
role: Role, role: Role,
mux_tx: mpsc::Sender<WsEvent>, mux_tx: mpsc::Sender<WsEvent>,
tx: LockedWebSocketWrite, tx: LockedWebSocketWrite,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
continue_recieved: Arc<Event>, continue_recieved: Arc<Event>,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
} }
impl MuxStreamWrite { impl MuxStreamWrite {
/// Write data to the stream. /// Write data to the stream.
pub async fn write(&self, data: Bytes) -> Result<(), WispError> { pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
if self.role == Role::Client if self.role == Role::Client
&& self.stream_type == StreamType::Tcp && self.stream_type == StreamType::Tcp
&& self.flow_control.load(Ordering::Acquire) == 0 && self.flow_control.load(Ordering::Acquire) == 0
{ {
self.continue_recieved.listen().await; self.continue_recieved.listen().await;
} }
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed); return Err(WispError::StreamAlreadyClosed);
} }
self.tx self.tx
.write_frame(Frame::from(Packet::new_data(self.stream_id, data))) .write_frame(Frame::from(Packet::new_data(self.stream_id, data)))
.await?; .await?;
if self.role == Role::Client && self.stream_type == StreamType::Tcp { if self.role == Role::Client && self.stream_type == StreamType::Tcp {
self.flow_control.store( self.flow_control.store(
self.flow_control.load(Ordering::Acquire).saturating_sub(1), self.flow_control.load(Ordering::Acquire).saturating_sub(1),
Ordering::Release, Ordering::Release,
); );
} }
Ok(()) Ok(())
} }
/// Get a handle to close the connection. /// Get a handle to close the connection.
/// ///
/// Useful to close the connection without having access to the stream. /// Useful to close the connection without having access to the stream.
/// ///
/// # Example /// # Example
/// ``` /// ```
/// let handle = stream.get_close_handle(); /// let handle = stream.get_close_handle();
/// if let Err(error) = handle_stream(stream) { /// if let Err(error) = handle_stream(stream) {
/// handle.close(0x01); /// handle.close(0x01);
/// } /// }
/// ``` /// ```
pub fn get_close_handle(&self) -> MuxStreamCloser { pub fn get_close_handle(&self) -> MuxStreamCloser {
MuxStreamCloser { MuxStreamCloser {
stream_id: self.stream_id, stream_id: self.stream_id,
close_channel: self.mux_tx.clone(), close_channel: self.mux_tx.clone(),
is_closed: self.is_closed.clone(), is_closed: self.is_closed.clone(),
} }
} }
/// Get a protocol extension stream to send protocol extension packets. /// Get a protocol extension stream to send protocol extension packets.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
MuxProtocolExtensionStream { MuxProtocolExtensionStream {
stream_id: self.stream_id, stream_id: self.stream_id,
tx: self.tx.clone(), tx: self.tx.clone(),
is_closed: self.is_closed.clone(), is_closed: self.is_closed.clone(),
} }
} }
/// Close the stream. You will no longer be able to write or read after this has been called. /// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed); return Err(WispError::StreamAlreadyClosed);
} }
self.is_closed.store(true, Ordering::Release); self.is_closed.store(true, Ordering::Release);
let (tx, rx) = oneshot::channel::<Result<(), WispError>>(); let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.mux_tx self.mux_tx
.send_async(WsEvent::Close( .send_async(WsEvent::Close(
Packet::new_close(self.stream_id, reason), Packet::new_close(self.stream_id, reason),
tx, tx,
)) ))
.await .await
.map_err(|_| WispError::MuxMessageFailedToSend)?; .map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
Ok(()) Ok(())
} }
pub(crate) fn into_sink(self) -> Pin<Box<dyn Sink<Bytes, Error = WispError> + Send>> { pub(crate) fn into_sink(self) -> Pin<Box<dyn Sink<Bytes, Error = WispError> + Send>> {
let handle = self.get_close_handle(); let handle = self.get_close_handle();
Box::pin(sink_unfold::unfold( Box::pin(sink_unfold::unfold(
self, self,
|tx, data| async move { |tx, data| async move {
tx.write(data).await?; tx.write(data).await?;
Ok(tx) Ok(tx)
}, },
handle, handle,
move |handle| async { move |handle| async {
handle.close(CloseReason::Unknown).await?; handle.close(CloseReason::Unknown).await?;
Ok(handle) Ok(handle)
}, },
)) ))
} }
} }
impl Drop for MuxStreamWrite { impl Drop for MuxStreamWrite {
fn drop(&mut self) { fn drop(&mut self) {
if !self.is_closed.load(Ordering::Acquire) { if !self.is_closed.load(Ordering::Acquire) {
self.is_closed.store(true, Ordering::Release); self.is_closed.store(true, Ordering::Release);
let (tx, _) = oneshot::channel(); let (tx, _) = oneshot::channel();
let _ = self.mux_tx.send(WsEvent::Close( let _ = self.mux_tx.send(WsEvent::Close(
Packet::new_close(self.stream_id, CloseReason::Unknown), Packet::new_close(self.stream_id, CloseReason::Unknown),
tx, tx,
)); ));
} }
} }
} }
/// Multiplexor stream. /// Multiplexor stream.
pub struct MuxStream { pub struct MuxStream {
/// ID of the stream. /// ID of the stream.
pub stream_id: u32, pub stream_id: u32,
rx: MuxStreamRead, rx: MuxStreamRead,
tx: MuxStreamWrite, tx: MuxStreamWrite,
} }
impl MuxStream { impl MuxStream {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub(crate) fn new( pub(crate) fn new(
stream_id: u32, stream_id: u32,
role: Role, role: Role,
stream_type: StreamType, stream_type: StreamType,
rx: mpsc::Receiver<Bytes>, rx: mpsc::Receiver<Bytes>,
mux_tx: mpsc::Sender<WsEvent>, mux_tx: mpsc::Sender<WsEvent>,
tx: LockedWebSocketWrite, tx: LockedWebSocketWrite,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
is_closed_event: Arc<Event>, is_closed_event: Arc<Event>,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
continue_recieved: Arc<Event>, continue_recieved: Arc<Event>,
target_flow_control: u32, target_flow_control: u32,
) -> Self { ) -> Self {
Self { Self {
stream_id, stream_id,
rx: MuxStreamRead { rx: MuxStreamRead {
stream_id, stream_id,
stream_type, stream_type,
role, role,
tx: tx.clone(), tx: tx.clone(),
rx, rx,
is_closed: is_closed.clone(), is_closed: is_closed.clone(),
is_closed_event: is_closed_event.clone(), is_closed_event: is_closed_event.clone(),
flow_control: flow_control.clone(), flow_control: flow_control.clone(),
flow_control_read: AtomicU32::new(0), flow_control_read: AtomicU32::new(0),
target_flow_control, target_flow_control,
}, },
tx: MuxStreamWrite { tx: MuxStreamWrite {
stream_id, stream_id,
stream_type, stream_type,
role, role,
mux_tx, mux_tx,
tx, tx,
is_closed: is_closed.clone(), is_closed: is_closed.clone(),
flow_control: flow_control.clone(), flow_control: flow_control.clone(),
continue_recieved: continue_recieved.clone(), continue_recieved: continue_recieved.clone(),
}, },
} }
} }
/// Read an event from the stream. /// Read an event from the stream.
pub async fn read(&self) -> Option<Bytes> { pub async fn read(&self) -> Option<Bytes> {
self.rx.read().await self.rx.read().await
} }
/// Write data to the stream. /// Write data to the stream.
pub async fn write(&self, data: Bytes) -> Result<(), WispError> { pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
self.tx.write(data).await self.tx.write(data).await
} }
/// Get a handle to close the connection. /// Get a handle to close the connection.
/// ///
/// Useful to close the connection without having access to the stream. /// Useful to close the connection without having access to the stream.
/// ///
/// # Example /// # Example
/// ``` /// ```
/// let handle = stream.get_close_handle(); /// let handle = stream.get_close_handle();
/// if let Err(error) = handle_stream(stream) { /// if let Err(error) = handle_stream(stream) {
/// handle.close(0x01); /// handle.close(0x01);
/// } /// }
/// ``` /// ```
pub fn get_close_handle(&self) -> MuxStreamCloser { pub fn get_close_handle(&self) -> MuxStreamCloser {
self.tx.get_close_handle() self.tx.get_close_handle()
} }
/// Get a protocol extension stream to send protocol extension packets. /// Get a protocol extension stream to send protocol extension packets.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
self.tx.get_protocol_extension_stream() self.tx.get_protocol_extension_stream()
} }
/// Close the stream. You will no longer be able to write or read after this has been called. /// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
self.tx.close(reason).await self.tx.close(reason).await
} }
/// Split the stream into read and write parts, consuming it. /// Split the stream into read and write parts, consuming it.
pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) { pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) {
(self.rx, self.tx) (self.rx, self.tx)
} }
/// Turn the stream into one that implements futures `Stream + Sink`, consuming it. /// Turn the stream into one that implements futures `Stream + Sink`, consuming it.
pub fn into_io(self) -> MuxStreamIo { pub fn into_io(self) -> MuxStreamIo {
MuxStreamIo { MuxStreamIo {
rx: self.rx.into_stream(), rx: MuxStreamIoStream {
tx: self.tx.into_sink(), rx: self.rx.into_stream(),
} },
} tx: MuxStreamIoSink {
tx: self.tx.into_sink(),
},
}
}
} }
/// Close handle for a multiplexor stream. /// Close handle for a multiplexor stream.
#[derive(Clone)] #[derive(Clone)]
pub struct MuxStreamCloser { pub struct MuxStreamCloser {
/// ID of the stream. /// ID of the stream.
pub stream_id: u32, pub stream_id: u32,
close_channel: mpsc::Sender<WsEvent>, close_channel: mpsc::Sender<WsEvent>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
} }
impl MuxStreamCloser { impl MuxStreamCloser {
/// Close the stream. You will no longer be able to write or read after this has been called. /// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed); return Err(WispError::StreamAlreadyClosed);
} }
self.is_closed.store(true, Ordering::Release); self.is_closed.store(true, Ordering::Release);
let (tx, rx) = oneshot::channel::<Result<(), WispError>>(); let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.close_channel self.close_channel
.send_async(WsEvent::Close( .send_async(WsEvent::Close(
Packet::new_close(self.stream_id, reason), Packet::new_close(self.stream_id, reason),
tx, tx,
)) ))
.await .await
.map_err(|_| WispError::MuxMessageFailedToSend)?; .map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
Ok(()) Ok(())
} }
} }
/// Stream for sending arbitrary protocol extension packets. /// Stream for sending arbitrary protocol extension packets.
pub struct MuxProtocolExtensionStream { pub struct MuxProtocolExtensionStream {
/// ID of the stream. /// ID of the stream.
pub stream_id: u32, pub stream_id: u32,
pub(crate) tx: LockedWebSocketWrite, pub(crate) tx: LockedWebSocketWrite,
pub(crate) is_closed: Arc<AtomicBool>, pub(crate) is_closed: Arc<AtomicBool>,
} }
impl MuxProtocolExtensionStream { impl MuxProtocolExtensionStream {
/// Send a protocol extension packet with this stream's ID. /// Send a protocol extension packet with this stream's ID.
pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> { pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed); return Err(WispError::StreamAlreadyClosed);
} }
let mut encoded = BytesMut::with_capacity(1 + 4 + data.len()); let mut encoded = BytesMut::with_capacity(1 + 4 + data.len());
encoded.put_u8(packet_type); encoded.put_u8(packet_type);
encoded.put_u32_le(self.stream_id); encoded.put_u32_le(self.stream_id);
encoded.extend(data); encoded.extend(data);
self.tx.write_frame(Frame::binary(encoded)).await self.tx.write_frame(Frame::binary(encoded)).await
} }
} }
pin_project! { pin_project! {
/// Multiplexor stream that implements futures `Stream + Sink`. /// Multiplexor stream that implements futures `Stream + Sink`.
pub struct MuxStreamIo { pub struct MuxStreamIo {
#[pin] #[pin]
rx: Pin<Box<dyn Stream<Item = Bytes> + Send>>, rx: MuxStreamIoStream,
#[pin] #[pin]
tx: Pin<Box<dyn Sink<Bytes, Error = WispError> + Send>>, tx: MuxStreamIoSink,
} }
} }
impl MuxStreamIo { impl MuxStreamIo {
/// Turn the stream into one that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`. /// Turn the stream into one that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`.
pub fn into_asyncrw(self) -> MuxStreamAsyncRW { pub fn into_asyncrw(self) -> IoStream<MuxStreamIo, Bytes> {
let (tx, rx) = self.split(); IoStream::new(self)
MuxStreamAsyncRW { }
rx: MuxStreamAsyncRead::new(rx),
tx: MuxStreamAsyncWrite::new(tx), /// Split the stream into read and write parts, consuming it.
} pub fn into_split(self) -> (MuxStreamIoStream, MuxStreamIoSink) {
} (self.rx, self.tx)
}
} }
impl Stream for MuxStreamIo { impl Stream for MuxStreamIo {
type Item = Result<Bytes, std::io::Error>; type Item = Result<Bytes, std::io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().rx.poll_next(cx).map(|x| x.map(Ok)) self.project().rx.poll_next(cx)
} }
} }
impl Sink<Bytes> for MuxStreamIo { impl Sink<Bytes> for MuxStreamIo {
type Error = std::io::Error; type Error = std::io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project() self.project().tx.poll_ready(cx)
.tx }
.poll_ready(cx) fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
.map_err(std::io::Error::other) self.project().tx.start_send(item)
} }
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project() self.project().tx.poll_flush(cx)
.tx }
.start_send(item) fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
.map_err(std::io::Error::other) self.project().tx.poll_close(cx)
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_flush(cx)
.map_err(std::io::Error::other)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_close(cx)
.map_err(std::io::Error::other)
}
} }
pin_project! { pin_project! {
/// Multiplexor stream that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`. /// Read side of a multiplexor stream that implements futures `Stream`.
pub struct MuxStreamAsyncRW { pub struct MuxStreamIoStream {
#[pin] #[pin]
rx: MuxStreamAsyncRead, rx: Pin<Box<dyn Stream<Item = Bytes> + Send>>,
#[pin] }
tx: MuxStreamAsyncWrite,
}
} }
impl MuxStreamAsyncRW { impl Stream for MuxStreamIoStream {
/// Split the stream into read and write parts, consuming it. type Item = Result<Bytes, std::io::Error>;
pub fn into_split(self) -> (MuxStreamAsyncRead, MuxStreamAsyncWrite) { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
(self.rx, self.tx) self.project().rx.poll_next(cx).map(|x| x.map(Ok))
} }
}
impl AsyncRead for MuxStreamAsyncRW {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
self.project().rx.poll_read(cx, buf)
}
fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [std::io::IoSliceMut<'_>],
) -> Poll<std::io::Result<usize>> {
self.project().rx.poll_read_vectored(cx, bufs)
}
}
impl AsyncBufRead for MuxStreamAsyncRW {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
self.project().rx.poll_fill_buf(cx)
}
fn consume(self: Pin<&mut Self>, amt: usize) {
self.project().rx.consume(amt)
}
}
impl AsyncWrite for MuxStreamAsyncRW {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.project().tx.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.project().tx.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.project().tx.poll_close(cx)
}
} }
pin_project! { pin_project! {
/// Read side of a multiplexor stream that implements futures `AsyncRead + AsyncBufRead`. /// Write side of a multiplexor stream that implements futures `Sink`.
pub struct MuxStreamAsyncRead { pub struct MuxStreamIoSink {
#[pin] #[pin]
rx: IntoAsyncRead<SplitStream<MuxStreamIo>>, tx: Pin<Box<dyn Sink<Bytes, Error = WispError> + Send>>,
} }
} }
impl MuxStreamAsyncRead { impl Sink<Bytes> for MuxStreamIoSink {
pub(crate) fn new(stream: SplitStream<MuxStreamIo>) -> Self { type Error = std::io::Error;
Self { fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
rx: stream.into_async_read(), self.project()
} .tx
} .poll_ready(cx)
} .map_err(std::io::Error::other)
}
impl AsyncRead for MuxStreamAsyncRead { fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
fn poll_read( self.project()
self: Pin<&mut Self>, .tx
cx: &mut Context<'_>, .start_send(item)
buf: &mut [u8], .map_err(std::io::Error::other)
) -> Poll<std::io::Result<usize>> { }
self.project().rx.poll_read(cx, buf) fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
} self.project()
.tx
fn poll_read_vectored( .poll_flush(cx)
self: Pin<&mut Self>, .map_err(std::io::Error::other)
cx: &mut Context<'_>, }
bufs: &mut [std::io::IoSliceMut<'_>], fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
) -> Poll<std::io::Result<usize>> { self.project()
self.project().rx.poll_read_vectored(cx, bufs) .tx
} .poll_close(cx)
} .map_err(std::io::Error::other)
}
impl AsyncBufRead for MuxStreamAsyncRead {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
self.project().rx.poll_fill_buf(cx)
}
fn consume(self: Pin<&mut Self>, amt: usize) {
self.project().rx.consume(amt)
}
}
pin_project! {
/// Write side of a multiplexor stream that implements futures `AsyncWrite`.
pub struct MuxStreamAsyncWrite {
#[pin]
tx: SplitSink<MuxStreamIo, Bytes>,
}
}
impl MuxStreamAsyncWrite {
pub(crate) fn new(sink: SplitSink<MuxStreamIo, Bytes>) -> Self {
Self { tx: sink }
}
}
impl AsyncWrite for MuxStreamAsyncWrite {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let mut this = self.project();
ready!(this.tx.as_mut().poll_ready(cx))?;
match this.tx.start_send(Bytes::copy_from_slice(buf)) {
Ok(()) => Poll::Ready(Ok(buf.len())),
Err(e) => Poll::Ready(Err(e)),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.project().tx.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.project().tx.poll_close(cx)
}
} }