mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
commit
1c22591817
26 changed files with 2403 additions and 459 deletions
759
Cargo.lock
generated
759
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -1,6 +1,6 @@
|
||||||
[workspace]
|
[workspace]
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
members = ["server", "client", "wisp", "simple-wisp-client"]
|
members = ["server", "client", "wisp", "simple-wisp-client", "certs-grabber"]
|
||||||
default-members = ["server"]
|
default-members = ["server"]
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
|
|
13
certs-grabber/Cargo.toml
Normal file
13
certs-grabber/Cargo.toml
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
[package]
|
||||||
|
name = "certs-grabber"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
hex = "0.4.3"
|
||||||
|
ring = "0.17.8"
|
||||||
|
rustls-pki-types = "1.4.1"
|
||||||
|
rustls-webpki = "0.102.2"
|
||||||
|
tokio = { version = "1.37.0", features = ["full"] }
|
||||||
|
webpki-ccadb = "0.1.0"
|
||||||
|
x509-parser = "0.16.0"
|
64
certs-grabber/src/main.rs
Normal file
64
certs-grabber/src/main.rs
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
use std::fmt::Write;
|
||||||
|
|
||||||
|
use ring::digest::{digest, SHA256};
|
||||||
|
use rustls_pki_types::{CertificateDer, TrustAnchor};
|
||||||
|
use webpki::anchor_from_trusted_cert;
|
||||||
|
use webpki_ccadb::fetch_ccadb_roots;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
let tls_roots_map = fetch_ccadb_roots().await;
|
||||||
|
let mut code = String::with_capacity(256 * 1_024);
|
||||||
|
code.push_str("const ROOTS = [");
|
||||||
|
for (_, root) in tls_roots_map {
|
||||||
|
// Verify the DER FP matches the metadata FP.
|
||||||
|
let der = root.der();
|
||||||
|
let calculated_fp = digest(&SHA256, &der);
|
||||||
|
let metadata_fp = hex::decode(&root.sha256_fingerprint).expect("malformed fingerprint");
|
||||||
|
assert_eq!(calculated_fp.as_ref(), metadata_fp.as_slice());
|
||||||
|
|
||||||
|
let ta_der = CertificateDer::from(der.as_ref());
|
||||||
|
let TrustAnchor {
|
||||||
|
subject,
|
||||||
|
subject_public_key_info,
|
||||||
|
name_constraints,
|
||||||
|
} = anchor_from_trusted_cert(&ta_der).expect("malformed trust anchor der");
|
||||||
|
|
||||||
|
/*
|
||||||
|
let (_, parsed_cert) =
|
||||||
|
x509_parser::parse_x509_certificate(&der).expect("malformed x509 der");
|
||||||
|
let issuer = name_to_string(parsed_cert.issuer());
|
||||||
|
let subject_str = name_to_string(parsed_cert.subject());
|
||||||
|
let label = root.common_name_or_certificate_name.clone();
|
||||||
|
let serial = root.serial().to_string();
|
||||||
|
let sha256_fp = root.sha256_fp();
|
||||||
|
*/
|
||||||
|
|
||||||
|
code.write_fmt(format_args!(
|
||||||
|
"{{subject:new Uint8Array([{}]),subject_public_key_info:new Uint8Array([{}]),name_constraints:{}}},",
|
||||||
|
subject
|
||||||
|
.as_ref()
|
||||||
|
.iter()
|
||||||
|
.map(|x| x.to_string())
|
||||||
|
.collect::<Vec<String>>().join(","),
|
||||||
|
subject_public_key_info
|
||||||
|
.as_ref()
|
||||||
|
.iter()
|
||||||
|
.map(|x| x.to_string())
|
||||||
|
.collect::<Vec<String>>().join(","),
|
||||||
|
if let Some(constraints) = name_constraints {
|
||||||
|
format!("new Uint8Array([{}])",constraints
|
||||||
|
.as_ref()
|
||||||
|
.iter()
|
||||||
|
.map(|x| x.to_string())
|
||||||
|
.collect::<Vec<String>>().join(","))
|
||||||
|
} else {
|
||||||
|
"null".into()
|
||||||
|
}
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
code.pop();
|
||||||
|
code.push_str("];");
|
||||||
|
println!("{}", code);
|
||||||
|
}
|
|
@ -17,7 +17,6 @@ wasm-bindgen = { version = "0.2.91", features = ["enable-interning"] }
|
||||||
wasm-bindgen-futures = "0.4.39"
|
wasm-bindgen-futures = "0.4.39"
|
||||||
futures-util = "0.3.30"
|
futures-util = "0.3.30"
|
||||||
js-sys = "0.3.66"
|
js-sys = "0.3.66"
|
||||||
webpki-roots = "0.26.0"
|
|
||||||
tokio-rustls = "0.25.0"
|
tokio-rustls = "0.25.0"
|
||||||
web-sys = { version = "0.3.66", features = ["Request", "RequestInit", "Headers", "Response", "ResponseInit", "WebSocket", "BinaryType", "MessageEvent"] }
|
web-sys = { version = "0.3.66", features = ["Request", "RequestInit", "Headers", "Response", "ResponseInit", "WebSocket", "BinaryType", "MessageEvent"] }
|
||||||
wasm-streams = "0.4.0"
|
wasm-streams = "0.4.0"
|
||||||
|
@ -25,7 +24,7 @@ tokio-util = { version = "0.7.10", features = ["io"] }
|
||||||
async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] }
|
async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] }
|
||||||
fastwebsockets = { version = "0.6.0", features = ["unstable-split"] }
|
fastwebsockets = { version = "0.6.0", features = ["unstable-split"] }
|
||||||
base64 = "0.21.7"
|
base64 = "0.21.7"
|
||||||
wisp-mux = { path = "../wisp", features = ["tokio_io"] }
|
wisp-mux = { path = "../wisp", features = ["tokio_io", "wasm"] }
|
||||||
async_io_stream = { version = "0.3.3", features = ["tokio_io"] }
|
async_io_stream = { version = "0.3.3", features = ["tokio_io"] }
|
||||||
getrandom = { version = "0.2.12", features = ["js"] }
|
getrandom = { version = "0.2.12", features = ["js"] }
|
||||||
hyper-util-wasm = { version = "0.1.3", features = ["client", "client-legacy", "http1", "http2"] }
|
hyper-util-wasm = { version = "0.1.3", features = ["client", "client-legacy", "http1", "http2"] }
|
||||||
|
@ -35,6 +34,7 @@ console_error_panic_hook = "0.1.7"
|
||||||
send_wrapper = "0.6.0"
|
send_wrapper = "0.6.0"
|
||||||
event-listener = "5.2.0"
|
event-listener = "5.2.0"
|
||||||
wasmtimer = "0.2.0"
|
wasmtimer = "0.2.0"
|
||||||
|
async-trait = "0.1.80"
|
||||||
|
|
||||||
[dependencies.ring]
|
[dependencies.ring]
|
||||||
features = ["wasm32_unknown_unknown_js"]
|
features = ["wasm32_unknown_unknown_js"]
|
||||||
|
@ -46,3 +46,4 @@ features = ["web"]
|
||||||
default-env = "0.1.1"
|
default-env = "0.1.1"
|
||||||
wasm-bindgen-test = "0.3.42"
|
wasm-bindgen-test = "0.3.42"
|
||||||
web-sys = { version = "0.3.69", features = ["FormData", "UrlSearchParams"] }
|
web-sys = { version = "0.3.69", features = ["FormData", "UrlSearchParams"] }
|
||||||
|
webpki-roots = "0.26.0"
|
||||||
|
|
|
@ -41,5 +41,14 @@ echo "}\ndeclare function epoxy(maybe_memory?: WebAssembly.Memory): Promise<type
|
||||||
cp out/epoxy_client.d.ts pkg/epoxy.d.ts
|
cp out/epoxy_client.d.ts pkg/epoxy.d.ts
|
||||||
cp out/epoxy_client_bg.wasm pkg/epoxy.wasm
|
cp out/epoxy_client_bg.wasm pkg/epoxy.wasm
|
||||||
|
|
||||||
|
echo "[epx] fetching certs"
|
||||||
|
(
|
||||||
|
cd ../certs-grabber
|
||||||
|
cargo run
|
||||||
|
) > pkg/certs.js
|
||||||
|
cat pkg/certs.js > pkg/certs-module.js
|
||||||
|
echo "export default ROOTS;" >> pkg/certs-module.js
|
||||||
|
echo "[epx] fetching certs finished"
|
||||||
|
|
||||||
rm -r out/
|
rm -r out/
|
||||||
echo "[epx] done!"
|
echo "[epx] done!"
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import epoxy from "./pkg/epoxy-module-bundled.js";
|
import epoxy from "./pkg/epoxy-module-bundled.js";
|
||||||
|
import CERTS from "./pkg/certs-module.js";
|
||||||
|
|
||||||
onmessage = async (msg) => {
|
onmessage = async (msg) => {
|
||||||
console.debug("recieved demo:", msg);
|
console.debug("recieved demo:", msg);
|
||||||
|
@ -29,13 +30,13 @@ onmessage = async (msg) => {
|
||||||
postMessage(JSON.stringify(str, null, 4));
|
postMessage(JSON.stringify(str, null, 4));
|
||||||
}
|
}
|
||||||
|
|
||||||
const { EpoxyClient, certs } = await epoxy();
|
const { EpoxyClient } = await epoxy();
|
||||||
|
|
||||||
console.log("certs:", certs());
|
console.log("certs:", CERTS);
|
||||||
|
|
||||||
const tconn0 = performance.now();
|
const tconn0 = performance.now();
|
||||||
// args: websocket url, user agent, redirect limit
|
// args: websocket url, user agent, redirect limit, certs
|
||||||
let epoxy_client = await new EpoxyClient("ws://localhost:4000", navigator.userAgent, 10);
|
let epoxy_client = await new EpoxyClient("ws://localhost:4000", navigator.userAgent, 10, CERTS);
|
||||||
const tconn1 = performance.now();
|
const tconn1 = performance.now();
|
||||||
log(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`);
|
log(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`);
|
||||||
|
|
||||||
|
@ -237,9 +238,9 @@ onmessage = async (msg) => {
|
||||||
log(`total avg mux (${num_outer_tests} tests of ${num_inner_tests} reqs): ${total_mux_multi} ms or ${total_mux_multi / 1000} s`);
|
log(`total avg mux (${num_outer_tests} tests of ${num_inner_tests} reqs): ${total_mux_multi} ms or ${total_mux_multi / 1000} s`);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
let resp = await epoxy_client.fetch("https://httpbin.org/get");
|
let resp = await epoxy_client.fetch("https://www.example.com/");
|
||||||
console.log(resp, Object.fromEntries(resp.headers));
|
console.log(resp, Object.fromEntries(resp.headers));
|
||||||
plog(await resp.json());
|
log(await resp.text());
|
||||||
}
|
}
|
||||||
log("done");
|
log("done");
|
||||||
};
|
};
|
||||||
|
|
|
@ -10,6 +10,7 @@ mod wrappers;
|
||||||
use tls_stream::EpxTlsStream;
|
use tls_stream::EpxTlsStream;
|
||||||
use tokioio::TokioIo;
|
use tokioio::TokioIo;
|
||||||
use udp_stream::EpxUdpStream;
|
use udp_stream::EpxUdpStream;
|
||||||
|
use utils::object_to_trustanchor;
|
||||||
pub use utils::{Boolinator, ReplaceErr, UriExt};
|
pub use utils::{Boolinator, ReplaceErr, UriExt};
|
||||||
use websocket::EpxWebSocket;
|
use websocket::EpxWebSocket;
|
||||||
use wrappers::{IncomingBody, ServiceWrapper, TlsWispService, WebSocketWrapper};
|
use wrappers::{IncomingBody, ServiceWrapper, TlsWispService, WebSocketWrapper};
|
||||||
|
@ -70,42 +71,10 @@ fn init() {
|
||||||
intern("Content-Type");
|
intern("Content-Type");
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cert_to_jval(cert: &TrustAnchor) -> Result<JsValue, JsValue> {
|
|
||||||
let val = Object::new();
|
|
||||||
Reflect::set(
|
|
||||||
&val,
|
|
||||||
&jval!("subject"),
|
|
||||||
&Uint8Array::from(cert.subject.as_ref()),
|
|
||||||
)?;
|
|
||||||
Reflect::set(
|
|
||||||
&val,
|
|
||||||
&jval!("subject_public_key_info"),
|
|
||||||
&Uint8Array::from(cert.subject_public_key_info.as_ref()),
|
|
||||||
)?;
|
|
||||||
Reflect::set(
|
|
||||||
&val,
|
|
||||||
&jval!("name_constraints"),
|
|
||||||
&jval!(cert
|
|
||||||
.name_constraints
|
|
||||||
.as_ref()
|
|
||||||
.map(|x| Uint8Array::from(x.as_ref()))),
|
|
||||||
)?;
|
|
||||||
Ok(val.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[wasm_bindgen]
|
|
||||||
pub fn certs() -> Result<JsValue, JsValue> {
|
|
||||||
Ok(webpki_roots::TLS_SERVER_ROOTS
|
|
||||||
.iter()
|
|
||||||
.map(cert_to_jval)
|
|
||||||
.collect::<Result<Array, JsValue>>()?
|
|
||||||
.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[wasm_bindgen(inspectable)]
|
#[wasm_bindgen(inspectable)]
|
||||||
pub struct EpoxyClient {
|
pub struct EpoxyClient {
|
||||||
rustls_config: Arc<rustls::ClientConfig>,
|
rustls_config: Arc<rustls::ClientConfig>,
|
||||||
mux: Arc<RwLock<ClientMux<WebSocketWrapper>>>,
|
mux: Arc<RwLock<ClientMux>>,
|
||||||
hyper_client: Client<TlsWispService, HttpBody>,
|
hyper_client: Client<TlsWispService, HttpBody>,
|
||||||
#[wasm_bindgen(getter_with_clone)]
|
#[wasm_bindgen(getter_with_clone)]
|
||||||
pub useragent: String,
|
pub useragent: String,
|
||||||
|
@ -120,6 +89,7 @@ impl EpoxyClient {
|
||||||
ws_url: String,
|
ws_url: String,
|
||||||
useragent: String,
|
useragent: String,
|
||||||
redirect_limit: usize,
|
redirect_limit: usize,
|
||||||
|
certs: Array,
|
||||||
) -> Result<EpoxyClient, JsError> {
|
) -> Result<EpoxyClient, JsError> {
|
||||||
let ws_uri = ws_url
|
let ws_uri = ws_url
|
||||||
.parse::<uri::Uri>()
|
.parse::<uri::Uri>()
|
||||||
|
@ -137,7 +107,13 @@ impl EpoxyClient {
|
||||||
utils::spawn_mux_fut(mux.clone(), fut, ws_url.clone());
|
utils::spawn_mux_fut(mux.clone(), fut, ws_url.clone());
|
||||||
|
|
||||||
let mut certstore = RootCertStore::empty();
|
let mut certstore = RootCertStore::empty();
|
||||||
certstore.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
let certs: Result<Vec<TrustAnchor>, JsValue> =
|
||||||
|
certs.iter().map(object_to_trustanchor).collect();
|
||||||
|
certstore.extend(
|
||||||
|
certs
|
||||||
|
.replace_err("Failed to get certificates from cert store")?
|
||||||
|
.into_iter(),
|
||||||
|
);
|
||||||
|
|
||||||
let rustls_config = Arc::new(
|
let rustls_config = Arc::new(
|
||||||
rustls::ClientConfig::builder()
|
rustls::ClientConfig::builder()
|
||||||
|
@ -164,7 +140,7 @@ impl EpoxyClient {
|
||||||
async fn get_tls_io(&self, url_host: &str, url_port: u16) -> Result<EpxIoTlsStream, JsError> {
|
async fn get_tls_io(&self, url_host: &str, url_port: u16) -> Result<EpxIoTlsStream, JsError> {
|
||||||
let channel = self
|
let channel = self
|
||||||
.mux
|
.mux
|
||||||
.read()
|
.write()
|
||||||
.await
|
.await
|
||||||
.client_new_stream(StreamType::Tcp, url_host.to_string(), url_port)
|
.client_new_stream(StreamType::Tcp, url_host.to_string(), url_port)
|
||||||
.await
|
.await
|
||||||
|
|
|
@ -33,7 +33,7 @@ impl EpxUdpStream {
|
||||||
|
|
||||||
let io = tcp
|
let io = tcp
|
||||||
.mux
|
.mux
|
||||||
.read()
|
.write()
|
||||||
.await
|
.await
|
||||||
.client_new_stream(StreamType::Udp, url_host.to_string(), url_port)
|
.client_new_stream(StreamType::Udp, url_host.to_string(), url_port)
|
||||||
.await
|
.await
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
use crate::*;
|
use crate::*;
|
||||||
|
|
||||||
|
use rustls_pki_types::Der;
|
||||||
use wasm_bindgen::prelude::*;
|
use wasm_bindgen::prelude::*;
|
||||||
use wasm_bindgen_futures::JsFuture;
|
use wasm_bindgen_futures::JsFuture;
|
||||||
|
|
||||||
use hyper::rt::Executor;
|
use hyper::rt::Executor;
|
||||||
use js_sys::ArrayBuffer;
|
use js_sys::ArrayBuffer;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use wisp_mux::WispError;
|
use wisp_mux::{extensions::udp::UdpProtocolExtensionBuilder, WispError};
|
||||||
|
|
||||||
#[wasm_bindgen]
|
#[wasm_bindgen]
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
@ -194,26 +195,24 @@ pub async fn make_mux(
|
||||||
url: &str,
|
url: &str,
|
||||||
) -> Result<
|
) -> Result<
|
||||||
(
|
(
|
||||||
ClientMux<WebSocketWrapper>,
|
ClientMux,
|
||||||
impl Future<Output = Result<(), WispError>>,
|
impl Future<Output = Result<(), WispError>> + Send,
|
||||||
),
|
),
|
||||||
WispError,
|
WispError,
|
||||||
> {
|
> {
|
||||||
let (wtx, wrx) = WebSocketWrapper::connect(url, vec![])
|
let (wtx, wrx) =
|
||||||
.await
|
WebSocketWrapper::connect(url, vec![]).map_err(|_| WispError::WsImplSocketClosed)?;
|
||||||
.map_err(|_| WispError::WsImplSocketClosed)?;
|
|
||||||
wtx.wait_for_open().await;
|
wtx.wait_for_open().await;
|
||||||
let mux = ClientMux::new(wrx, wtx).await?;
|
ClientMux::new(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await
|
||||||
|
|
||||||
Ok(mux)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn spawn_mux_fut(
|
pub fn spawn_mux_fut(
|
||||||
mux: Arc<RwLock<ClientMux<WebSocketWrapper>>>,
|
mux: Arc<RwLock<ClientMux>>,
|
||||||
fut: impl Future<Output = Result<(), WispError>> + 'static,
|
fut: impl Future<Output = Result<(), WispError>> + Send + 'static,
|
||||||
url: String,
|
url: String,
|
||||||
) {
|
) {
|
||||||
wasm_bindgen_futures::spawn_local(async move {
|
wasm_bindgen_futures::spawn_local(async move {
|
||||||
|
debug!("epoxy: mux future started");
|
||||||
if let Err(e) = fut.await {
|
if let Err(e) = fut.await {
|
||||||
log!("epoxy: error in mux future, restarting: {:?}", e);
|
log!("epoxy: error in mux future, restarting: {:?}", e);
|
||||||
while let Err(e) = replace_mux(mux.clone(), &url).await {
|
while let Err(e) = replace_mux(mux.clone(), &url).await {
|
||||||
|
@ -225,13 +224,10 @@ pub fn spawn_mux_fut(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn replace_mux(
|
pub async fn replace_mux(mux: Arc<RwLock<ClientMux>>, url: &str) -> Result<(), WispError> {
|
||||||
mux: Arc<RwLock<ClientMux<WebSocketWrapper>>>,
|
|
||||||
url: &str,
|
|
||||||
) -> Result<(), WispError> {
|
|
||||||
let (mux_replace, fut) = make_mux(url).await?;
|
let (mux_replace, fut) = make_mux(url).await?;
|
||||||
let mut mux_write = mux.write().await;
|
let mut mux_write = mux.write().await;
|
||||||
mux_write.close().await?;
|
let _ = mux_write.close().await;
|
||||||
*mux_write = mux_replace;
|
*mux_write = mux_replace;
|
||||||
drop(mux_write);
|
drop(mux_write);
|
||||||
spawn_mux_fut(mux, fut, url.into());
|
spawn_mux_fut(mux, fut, url.into());
|
||||||
|
@ -264,3 +260,17 @@ pub async fn jval_to_u8_array_req(val: JsValue) -> Result<(Uint8Array, web_sys::
|
||||||
req,
|
req,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn object_to_trustanchor(obj: JsValue) -> Result<TrustAnchor<'static>, JsValue> {
|
||||||
|
let subject: Uint8Array = Reflect::get(&obj, &jval!("subject"))?.dyn_into()?;
|
||||||
|
let pub_key_info: Uint8Array =
|
||||||
|
Reflect::get(&obj, &jval!("subject_public_key_info"))?.dyn_into()?;
|
||||||
|
let name_constraints: Option<Uint8Array> = Reflect::get(&obj, &jval!("name_constraints"))
|
||||||
|
.and_then(|x| x.dyn_into())
|
||||||
|
.ok();
|
||||||
|
Ok(TrustAnchor {
|
||||||
|
subject: Der::from(subject.to_vec()),
|
||||||
|
subject_public_key_info: Der::from(pub_key_info.to_vec()),
|
||||||
|
name_constraints: name_constraints.map(|x| Der::from(x.to_vec())),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -106,7 +106,7 @@ impl EpxWebSocket {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
// ping/pong/continue
|
// ping/pong/continue
|
||||||
_ => {},
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -115,7 +115,13 @@ impl EpxWebSocket {
|
||||||
.call0(&Object::default())
|
.call0(&Object::default())
|
||||||
.replace_err("Failed to call onopen")?;
|
.replace_err("Failed to call onopen")?;
|
||||||
|
|
||||||
Ok(Self { tx, onerror, origin, protocols, url: url.to_string() })
|
Ok(Self {
|
||||||
|
tx,
|
||||||
|
onerror,
|
||||||
|
origin,
|
||||||
|
protocols,
|
||||||
|
url: url.to_string(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
.await;
|
.await;
|
||||||
if let Err(ret) = ret {
|
if let Err(ret) = ret {
|
||||||
|
|
|
@ -53,7 +53,7 @@ impl Stream for IncomingBody {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct ServiceWrapper(pub Arc<RwLock<ClientMux<WebSocketWrapper>>>, pub String);
|
pub struct ServiceWrapper(pub Arc<RwLock<ClientMux>>, pub String);
|
||||||
|
|
||||||
impl tower_service::Service<hyper::Uri> for ServiceWrapper {
|
impl tower_service::Service<hyper::Uri> for ServiceWrapper {
|
||||||
type Response = TokioIo<EpxIoUnencryptedStream>;
|
type Response = TokioIo<EpxIoUnencryptedStream>;
|
||||||
|
@ -69,7 +69,7 @@ impl tower_service::Service<hyper::Uri> for ServiceWrapper {
|
||||||
let mux_url = self.1.clone();
|
let mux_url = self.1.clone();
|
||||||
async move {
|
async move {
|
||||||
let stream = mux
|
let stream = mux
|
||||||
.read()
|
.write()
|
||||||
.await
|
.await
|
||||||
.client_new_stream(
|
.client_new_stream(
|
||||||
StreamType::Tcp,
|
StreamType::Tcp,
|
||||||
|
@ -143,6 +143,7 @@ impl tower_service::Service<hyper::Uri> for TlsWispService {
|
||||||
pub enum WebSocketError {
|
pub enum WebSocketError {
|
||||||
Unknown,
|
Unknown,
|
||||||
SendFailed,
|
SendFailed,
|
||||||
|
CloseFailed,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Display for WebSocketError {
|
impl std::fmt::Display for WebSocketError {
|
||||||
|
@ -151,6 +152,7 @@ impl std::fmt::Display for WebSocketError {
|
||||||
match self {
|
match self {
|
||||||
Unknown => write!(f, "Unknown error"),
|
Unknown => write!(f, "Unknown error"),
|
||||||
SendFailed => write!(f, "Send failed"),
|
SendFailed => write!(f, "Send failed"),
|
||||||
|
CloseFailed => write!(f, "Close failed"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -193,11 +195,9 @@ pub struct WebSocketReader {
|
||||||
close_event: Arc<Event>,
|
close_event: Arc<Event>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
impl WebSocketRead for WebSocketReader {
|
impl WebSocketRead for WebSocketReader {
|
||||||
async fn wisp_read_frame(
|
async fn wisp_read_frame(&mut self, _: &LockedWebSocketWrite) -> Result<Frame, WispError> {
|
||||||
&mut self,
|
|
||||||
_: &LockedWebSocketWrite<impl WebSocketWrite>,
|
|
||||||
) -> Result<Frame, WispError> {
|
|
||||||
use WebSocketMessage::*;
|
use WebSocketMessage::*;
|
||||||
if self.closed.load(Ordering::Acquire) {
|
if self.closed.load(Ordering::Acquire) {
|
||||||
return Err(WispError::WsImplSocketClosed);
|
return Err(WispError::WsImplSocketClosed);
|
||||||
|
@ -215,10 +215,7 @@ impl WebSocketRead for WebSocketReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WebSocketWrapper {
|
impl WebSocketWrapper {
|
||||||
pub async fn connect(
|
pub fn connect(url: &str, protocols: Vec<String>) -> Result<(Self, WebSocketReader), JsValue> {
|
||||||
url: &str,
|
|
||||||
protocols: Vec<String>,
|
|
||||||
) -> Result<(Self, WebSocketReader), JsValue> {
|
|
||||||
let (read_tx, read_rx) = mpsc::unbounded_channel();
|
let (read_tx, read_rx) = mpsc::unbounded_channel();
|
||||||
let closed = Arc::new(AtomicBool::new(false));
|
let closed = Arc::new(AtomicBool::new(false));
|
||||||
|
|
||||||
|
@ -306,6 +303,7 @@ impl WebSocketWrapper {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
impl WebSocketWrite for WebSocketWrapper {
|
impl WebSocketWrite for WebSocketWrapper {
|
||||||
async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError> {
|
async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError> {
|
||||||
use wisp_mux::ws::OpCode::*;
|
use wisp_mux::ws::OpCode::*;
|
||||||
|
@ -328,6 +326,12 @@ impl WebSocketWrite for WebSocketWrapper {
|
||||||
_ => Err(WispError::WsImplNotSupported),
|
_ => Err(WispError::WsImplNotSupported),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn wisp_close(&mut self) -> Result<(), WispError> {
|
||||||
|
self.inner
|
||||||
|
.close()
|
||||||
|
.map_err(|_| WebSocketError::CloseFailed.into())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for WebSocketWrapper {
|
impl Drop for WebSocketWrapper {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use default_env::default_env;
|
use default_env::default_env;
|
||||||
use epoxy_client::EpoxyClient;
|
use epoxy_client::EpoxyClient;
|
||||||
use js_sys::{JsString, Object, Reflect, Uint8Array, JSON};
|
use js_sys::{Array, JsString, Object, Reflect, Uint8Array, JSON};
|
||||||
|
use rustls_pki_types::TrustAnchor;
|
||||||
use tokio::sync::OnceCell;
|
use tokio::sync::OnceCell;
|
||||||
use wasm_bindgen::JsValue;
|
use wasm_bindgen::JsValue;
|
||||||
use wasm_bindgen_futures::JsFuture;
|
use wasm_bindgen_futures::JsFuture;
|
||||||
|
@ -12,11 +13,40 @@ wasm_bindgen_test_configure!(run_in_dedicated_worker);
|
||||||
static USER_AGENT: &str = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36";
|
static USER_AGENT: &str = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36";
|
||||||
static EPOXY_CLIENT: OnceCell<EpoxyClient> = OnceCell::const_new();
|
static EPOXY_CLIENT: OnceCell<EpoxyClient> = OnceCell::const_new();
|
||||||
|
|
||||||
|
pub fn trustanchor_to_object(cert: &TrustAnchor) -> Result<JsValue, JsValue> {
|
||||||
|
let val = Object::new();
|
||||||
|
Reflect::set(
|
||||||
|
&val,
|
||||||
|
&JsValue::from("subject"),
|
||||||
|
&Uint8Array::from(cert.subject.as_ref()),
|
||||||
|
)?;
|
||||||
|
Reflect::set(
|
||||||
|
&val,
|
||||||
|
&JsValue::from("subject_public_key_info"),
|
||||||
|
&Uint8Array::from(cert.subject_public_key_info.as_ref()),
|
||||||
|
)?;
|
||||||
|
Reflect::set(
|
||||||
|
&val,
|
||||||
|
&JsValue::from("name_constraints"),
|
||||||
|
&JsValue::from(
|
||||||
|
cert.name_constraints
|
||||||
|
.as_ref()
|
||||||
|
.map(|x| Uint8Array::from(x.as_ref())),
|
||||||
|
),
|
||||||
|
)?;
|
||||||
|
Ok(val.into())
|
||||||
|
}
|
||||||
|
|
||||||
async fn get_client_w_ua(useragent: &str, redirect_limit: usize) -> EpoxyClient {
|
async fn get_client_w_ua(useragent: &str, redirect_limit: usize) -> EpoxyClient {
|
||||||
EpoxyClient::new(
|
EpoxyClient::new(
|
||||||
"ws://localhost:4000".into(),
|
"ws://localhost:4000".into(),
|
||||||
useragent.into(),
|
useragent.into(),
|
||||||
redirect_limit,
|
redirect_limit,
|
||||||
|
webpki_roots::TLS_SERVER_ROOTS
|
||||||
|
.iter()
|
||||||
|
.map(trustanchor_to_object)
|
||||||
|
.collect::<Result<Array, JsValue>>()
|
||||||
|
.expect("Failed to create certs"),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.ok()
|
.ok()
|
||||||
|
|
1
rustfmt.toml
Normal file
1
rustfmt.toml
Normal file
|
@ -0,0 +1 @@
|
||||||
|
imports_granularity = "Crate"
|
|
@ -1,26 +1,35 @@
|
||||||
#![feature(let_chains, ip)]
|
#![feature(let_chains, ip)]
|
||||||
use std::io::Error;
|
use std::{collections::HashMap, io::Error, path::PathBuf, sync::Arc};
|
||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use fastwebsockets::{
|
use fastwebsockets::{
|
||||||
upgrade::{self, UpgradeFut}, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload,
|
upgrade::{self, UpgradeFut},
|
||||||
WebSocketError,
|
CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError,
|
||||||
};
|
};
|
||||||
use futures_util::{SinkExt, StreamExt, TryFutureExt};
|
use futures_util::{SinkExt, StreamExt, TryFutureExt};
|
||||||
use hyper::{
|
use hyper::{
|
||||||
body::Incoming, server::conn::http1, service::service_fn, Request, Response,
|
body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode,
|
||||||
StatusCode,
|
|
||||||
};
|
};
|
||||||
use hyper_util::rt::TokioIo;
|
use hyper_util::rt::TokioIo;
|
||||||
use tokio::net::{lookup_host, TcpListener, TcpStream, UdpSocket};
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
use tokio::net::{UnixListener, UnixStream};
|
use tokio::net::{UnixListener, UnixStream};
|
||||||
|
use tokio::{
|
||||||
|
io::copy_bidirectional,
|
||||||
|
net::{lookup_host, TcpListener, TcpStream, UdpSocket},
|
||||||
|
};
|
||||||
use tokio_util::codec::{BytesCodec, Framed};
|
use tokio_util::codec::{BytesCodec, Framed};
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
use tokio_util::either::Either;
|
use tokio_util::either::Either;
|
||||||
|
|
||||||
use wisp_mux::{CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError};
|
use wisp_mux::{
|
||||||
|
extensions::{
|
||||||
|
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
|
||||||
|
udp::UdpProtocolExtensionBuilder,
|
||||||
|
ProtocolExtensionBuilder,
|
||||||
|
},
|
||||||
|
CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError,
|
||||||
|
};
|
||||||
|
|
||||||
type HttpBody = http_body_util::Full<hyper::body::Bytes>;
|
type HttpBody = http_body_util::Full<hyper::body::Bytes>;
|
||||||
|
|
||||||
|
@ -54,6 +63,20 @@ struct Cli {
|
||||||
/// Whether the server should block ports other than 80 or 443
|
/// Whether the server should block ports other than 80 or 443
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
block_non_http: bool,
|
block_non_http: bool,
|
||||||
|
/// Path to a file containing `user:password` separated by newlines. This is plaintext!!!
|
||||||
|
///
|
||||||
|
/// `user` cannot contain `:`. Whitespace will be trimmed.
|
||||||
|
#[arg(long)]
|
||||||
|
auth: Option<PathBuf>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct MuxOptions {
|
||||||
|
pub block_local: bool,
|
||||||
|
pub block_udp: bool,
|
||||||
|
pub block_non_http: bool,
|
||||||
|
pub enforce_auth: bool,
|
||||||
|
pub auth: Arc<Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(unix))]
|
#[cfg(not(unix))]
|
||||||
|
@ -136,19 +159,47 @@ async fn main() -> Result<(), Error> {
|
||||||
"/".to_string()
|
"/".to_string()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut auth = HashMap::new();
|
||||||
|
let enforce_auth = opt.auth.is_some();
|
||||||
|
if let Some(file) = opt.auth {
|
||||||
|
let file = std::fs::read_to_string(file)?;
|
||||||
|
for entry in file.split('\n').filter_map(|x| {
|
||||||
|
if x.contains(':') {
|
||||||
|
Some(x.trim())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
let split: Vec<_> = entry.split(':').collect();
|
||||||
|
let username = split[0];
|
||||||
|
let password = split[1..].join(":");
|
||||||
|
println!(
|
||||||
|
"adding username {:?} password {:?} to allowed auth",
|
||||||
|
username, password
|
||||||
|
);
|
||||||
|
auth.insert(username.to_string(), password.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let pw_ext = PasswordProtocolExtensionBuilder::new_server(auth);
|
||||||
|
|
||||||
|
let mux_options = MuxOptions {
|
||||||
|
block_local: opt.block_local,
|
||||||
|
block_non_http: opt.block_non_http,
|
||||||
|
block_udp: opt.block_udp,
|
||||||
|
auth: Arc::new(vec![
|
||||||
|
Box::new(UdpProtocolExtensionBuilder()),
|
||||||
|
Box::new(pw_ext),
|
||||||
|
]),
|
||||||
|
enforce_auth,
|
||||||
|
};
|
||||||
|
|
||||||
println!("listening on `{}` with prefix `{}`", addr, prefix);
|
println!("listening on `{}` with prefix `{}`", addr, prefix);
|
||||||
while let Ok((stream, addr)) = socket.accept().await {
|
while let Ok((stream, addr)) = socket.accept().await {
|
||||||
let prefix = prefix.clone();
|
let prefix = prefix.clone();
|
||||||
|
let mux_options = mux_options.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let service = service_fn(move |res| {
|
let service = service_fn(move |res| {
|
||||||
accept_http(
|
accept_http(res, addr.clone(), prefix.clone(), mux_options.clone())
|
||||||
res,
|
|
||||||
addr.clone(),
|
|
||||||
prefix.clone(),
|
|
||||||
opt.block_local,
|
|
||||||
opt.block_udp,
|
|
||||||
opt.block_non_http,
|
|
||||||
)
|
|
||||||
});
|
});
|
||||||
let conn = http1::Builder::new()
|
let conn = http1::Builder::new()
|
||||||
.serve_connection(TokioIo::new(stream), service)
|
.serve_connection(TokioIo::new(stream), service)
|
||||||
|
@ -166,9 +217,7 @@ async fn accept_http(
|
||||||
mut req: Request<Incoming>,
|
mut req: Request<Incoming>,
|
||||||
addr: String,
|
addr: String,
|
||||||
prefix: String,
|
prefix: String,
|
||||||
block_local: bool,
|
mux_options: MuxOptions,
|
||||||
block_udp: bool,
|
|
||||||
block_non_http: bool,
|
|
||||||
) -> Result<Response<HttpBody>, WebSocketError> {
|
) -> Result<Response<HttpBody>, WebSocketError> {
|
||||||
let uri = req.uri().path().to_string();
|
let uri = req.uri().path().to_string();
|
||||||
if upgrade::is_upgrade_request(&req)
|
if upgrade::is_upgrade_request(&req)
|
||||||
|
@ -177,12 +226,17 @@ async fn accept_http(
|
||||||
let (res, fut) = upgrade::upgrade(&mut req)?;
|
let (res, fut) = upgrade::upgrade(&mut req)?;
|
||||||
|
|
||||||
if uri.is_empty() {
|
if uri.is_empty() {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move { accept_ws(fut, addr.clone(), mux_options).await });
|
||||||
accept_ws(fut, addr.clone(), block_local, block_udp, block_non_http).await
|
|
||||||
});
|
|
||||||
} else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) {
|
} else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
accept_wsproxy(fut, uri, addr.clone(), block_local, block_non_http).await
|
accept_wsproxy(
|
||||||
|
fut,
|
||||||
|
uri,
|
||||||
|
addr.clone(),
|
||||||
|
mux_options.block_local,
|
||||||
|
mux_options.block_non_http,
|
||||||
|
)
|
||||||
|
.await
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,7 +264,7 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result<bool
|
||||||
.await
|
.await
|
||||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||||
let mut mux_stream = stream.into_io().into_asyncrw();
|
let mut mux_stream = stream.into_io().into_asyncrw();
|
||||||
tokio::io::copy_bidirectional(&mut tcp_stream, &mut mux_stream)
|
copy_bidirectional(&mut mux_stream, &mut tcp_stream)
|
||||||
.await
|
.await
|
||||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||||
}
|
}
|
||||||
|
@ -245,6 +299,10 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result<bool
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
StreamType::Unknown(_) => {
|
||||||
|
stream.close(CloseReason::ServerStreamInvalidInfo).await?;
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(true)
|
Ok(true)
|
||||||
}
|
}
|
||||||
|
@ -252,16 +310,43 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result<bool
|
||||||
async fn accept_ws(
|
async fn accept_ws(
|
||||||
ws: UpgradeFut,
|
ws: UpgradeFut,
|
||||||
addr: String,
|
addr: String,
|
||||||
block_local: bool,
|
mux_options: MuxOptions,
|
||||||
block_non_http: bool,
|
|
||||||
block_udp: bool,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
|
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
|
||||||
let (rx, tx) = ws.await?.split(tokio::io::split);
|
let (rx, tx) = ws.await?.split(tokio::io::split);
|
||||||
let rx = FragmentCollectorRead::new(rx);
|
let rx = FragmentCollectorRead::new(rx);
|
||||||
|
|
||||||
println!("{:?}: connected", addr);
|
println!("{:?}: connected", addr);
|
||||||
|
// to prevent memory ""leaks"" because users are sending in packets way too fast the buffer
|
||||||
|
// size is set to 128
|
||||||
|
let (mut mux, fut) = if mux_options.enforce_auth {
|
||||||
|
let (mut mux, fut) = ServerMux::new(rx, tx, 128, Some(mux_options.auth.as_slice())).await?;
|
||||||
|
if !mux
|
||||||
|
.supported_extension_ids
|
||||||
|
.iter()
|
||||||
|
.any(|x| *x == PasswordProtocolExtension::ID)
|
||||||
|
{
|
||||||
|
println!(
|
||||||
|
"{:?}: client did not support auth or password was invalid",
|
||||||
|
addr
|
||||||
|
);
|
||||||
|
mux.close_extension_incompat().await?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
(mux, fut)
|
||||||
|
} else {
|
||||||
|
ServerMux::new(
|
||||||
|
rx,
|
||||||
|
tx,
|
||||||
|
128,
|
||||||
|
Some(&[Box::new(UdpProtocolExtensionBuilder())]),
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
};
|
||||||
|
|
||||||
let (mut mux, fut) = ServerMux::new(rx, tx, u32::MAX);
|
println!(
|
||||||
|
"{:?}: downgraded: {} extensions supported: {:?}",
|
||||||
|
addr, mux.downgraded, mux.supported_extension_ids
|
||||||
|
);
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = fut.await {
|
if let Err(e) = fut.await {
|
||||||
|
@ -271,14 +356,14 @@ async fn accept_ws(
|
||||||
|
|
||||||
while let Some((packet, mut stream)) = mux.server_new_stream().await {
|
while let Some((packet, mut stream)) = mux.server_new_stream().await {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if (block_non_http
|
if (mux_options.block_non_http
|
||||||
&& !(packet.destination_port == 80 || packet.destination_port == 443))
|
&& !(packet.destination_port == 80 || packet.destination_port == 443))
|
||||||
|| (block_udp && packet.stream_type == StreamType::Udp)
|
|| (mux_options.block_udp && packet.stream_type == StreamType::Udp)
|
||||||
{
|
{
|
||||||
let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await;
|
let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if block_local {
|
if mux_options.block_local {
|
||||||
match lookup_host(format!(
|
match lookup_host(format!(
|
||||||
"{}:{}",
|
"{}:{}",
|
||||||
packet.destination_hostname, packet.destination_port
|
packet.destination_hostname, packet.destination_port
|
||||||
|
@ -310,10 +395,9 @@ async fn accept_ws(
|
||||||
})
|
})
|
||||||
.and_then(|should_send| async move {
|
.and_then(|should_send| async move {
|
||||||
if should_send {
|
if should_send {
|
||||||
close_ok.close(CloseReason::Voluntary).await
|
let _ = close_ok.close(CloseReason::Voluntary).await;
|
||||||
} else {
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
});
|
});
|
||||||
|
|
|
@ -11,7 +11,13 @@ use hyper::{
|
||||||
};
|
};
|
||||||
use simple_moving_average::{SingleSumSMA, SMA};
|
use simple_moving_average::{SingleSumSMA, SMA};
|
||||||
use std::{
|
use std::{
|
||||||
error::Error, future::Future, io::{stdout, IsTerminal, Write}, net::SocketAddr, process::exit, sync::Arc, time::{Duration, Instant}, usize
|
error::Error,
|
||||||
|
future::Future,
|
||||||
|
io::{stdout, IsTerminal, Write},
|
||||||
|
net::SocketAddr,
|
||||||
|
process::exit,
|
||||||
|
sync::Arc,
|
||||||
|
time::{Duration, Instant},
|
||||||
};
|
};
|
||||||
use tokio::{
|
use tokio::{
|
||||||
net::TcpStream,
|
net::TcpStream,
|
||||||
|
@ -21,7 +27,14 @@ use tokio::{
|
||||||
};
|
};
|
||||||
use tokio_native_tls::{native_tls, TlsConnector};
|
use tokio_native_tls::{native_tls, TlsConnector};
|
||||||
use tokio_util::either::Either;
|
use tokio_util::either::Either;
|
||||||
use wisp_mux::{ClientMux, StreamType, WispError};
|
use wisp_mux::{
|
||||||
|
extensions::{
|
||||||
|
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
|
||||||
|
udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder},
|
||||||
|
ProtocolExtensionBuilder,
|
||||||
|
},
|
||||||
|
ClientMux, StreamType, WispError,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
enum WispClientError {
|
enum WispClientError {
|
||||||
|
@ -71,6 +84,17 @@ struct Cli {
|
||||||
/// Duration to run the test for
|
/// Duration to run the test for
|
||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
duration: Option<humantime::Duration>,
|
duration: Option<humantime::Duration>,
|
||||||
|
/// Ask for UDP
|
||||||
|
#[arg(short, long)]
|
||||||
|
udp: bool,
|
||||||
|
/// Enable auth: format is `username:password`
|
||||||
|
///
|
||||||
|
/// Usernames and passwords are sent in plaintext!!
|
||||||
|
#[arg(long)]
|
||||||
|
auth: Option<String>,
|
||||||
|
/// Make a Wisp V1 connection
|
||||||
|
#[arg(long)]
|
||||||
|
wisp_v1: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main(flavor = "multi_thread")]
|
#[tokio::main(flavor = "multi_thread")]
|
||||||
|
@ -94,6 +118,13 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
let addr_dest = opts.tcp.ip().to_string();
|
let addr_dest = opts.tcp.ip().to_string();
|
||||||
let addr_dest_port = opts.tcp.port();
|
let addr_dest_port = opts.tcp.port();
|
||||||
|
|
||||||
|
let auth = opts.auth.map(|auth| {
|
||||||
|
let split: Vec<_> = auth.split(':').collect();
|
||||||
|
let username = split[0].to_string();
|
||||||
|
let password = split[1..].join(":");
|
||||||
|
PasswordProtocolExtensionBuilder::new_client(username, password)
|
||||||
|
});
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"connecting to {} and sending &[0; 1024 * {}] to {} with threads {}",
|
"connecting to {} and sending &[0; 1024 * {}] to {} with threads {}",
|
||||||
opts.wisp, opts.packet_size, opts.tcp, opts.streams,
|
opts.wisp, opts.packet_size, opts.tcp, opts.streams,
|
||||||
|
@ -117,7 +148,6 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
fastwebsockets::handshake::generate_key(),
|
fastwebsockets::handshake::generate_key(),
|
||||||
)
|
)
|
||||||
.header("Sec-WebSocket-Version", "13")
|
.header("Sec-WebSocket-Version", "13")
|
||||||
.header("Sec-WebSocket-Protocol", "wisp-v1")
|
|
||||||
.body(Empty::<Bytes>::new())?;
|
.body(Empty::<Bytes>::new())?;
|
||||||
|
|
||||||
let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?;
|
let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?;
|
||||||
|
@ -125,7 +155,53 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
let (rx, tx) = ws.split(tokio::io::split);
|
let (rx, tx) = ws.split(tokio::io::split);
|
||||||
let rx = FragmentCollectorRead::new(rx);
|
let rx = FragmentCollectorRead::new(rx);
|
||||||
|
|
||||||
let (mux, fut) = ClientMux::new(rx, tx).await?;
|
let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new();
|
||||||
|
if opts.udp {
|
||||||
|
extensions.push(Box::new(UdpProtocolExtensionBuilder()));
|
||||||
|
}
|
||||||
|
let enforce_auth = auth.is_some();
|
||||||
|
if let Some(auth) = auth {
|
||||||
|
extensions.push(Box::new(auth));
|
||||||
|
}
|
||||||
|
|
||||||
|
let (mut mux, fut) = if opts.wisp_v1 {
|
||||||
|
ClientMux::new(rx, tx, None).await?
|
||||||
|
} else {
|
||||||
|
ClientMux::new(rx, tx, Some(extensions.as_slice())).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
if opts.udp
|
||||||
|
&& !mux
|
||||||
|
.supported_extension_ids
|
||||||
|
.iter()
|
||||||
|
.any(|x| *x == UdpProtocolExtension::ID)
|
||||||
|
{
|
||||||
|
println!(
|
||||||
|
"server did not support udp, was downgraded {}, extensions supported {:?}",
|
||||||
|
mux.downgraded, mux.supported_extension_ids
|
||||||
|
);
|
||||||
|
mux.close_extension_incompat().await?;
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
if enforce_auth
|
||||||
|
&& !mux
|
||||||
|
.supported_extension_ids
|
||||||
|
.iter()
|
||||||
|
.any(|x| *x == PasswordProtocolExtension::ID)
|
||||||
|
{
|
||||||
|
println!(
|
||||||
|
"server did not support passwords or password was incorrect, was downgraded {}, extensions supported {:?}",
|
||||||
|
mux.downgraded, mux.supported_extension_ids
|
||||||
|
);
|
||||||
|
mux.close_extension_incompat().await?;
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"connected and created ClientMux, was downgraded {}, extensions supported {:?}",
|
||||||
|
mux.downgraded, mux.supported_extension_ids
|
||||||
|
);
|
||||||
|
|
||||||
let mut threads = Vec::with_capacity(opts.streams * 2 + 3);
|
let mut threads = Vec::with_capacity(opts.streams * 2 + 3);
|
||||||
|
|
||||||
threads.push(tokio::spawn(fut));
|
threads.push(tokio::spawn(fut));
|
||||||
|
@ -177,7 +253,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
avg.get_average() * opts.packet_size,
|
avg.get_average() * opts.packet_size,
|
||||||
);
|
);
|
||||||
if is_term {
|
if is_term {
|
||||||
print!("\x1b[2K{}\r", stat);
|
println!("\x1b[1A\x1b[2K{}\r", stat);
|
||||||
} else {
|
} else {
|
||||||
println!("{}", stat);
|
println!("{}", stat);
|
||||||
}
|
}
|
||||||
|
@ -208,6 +284,8 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
|
|
||||||
let out = select_all(threads.into_iter()).await;
|
let out = select_all(threads.into_iter()).await;
|
||||||
|
|
||||||
|
let duration_since = Instant::now().duration_since(start_time);
|
||||||
|
|
||||||
if let Err(err) = out.0? {
|
if let Err(err) = out.0? {
|
||||||
println!("\n\nerr: {:?}", err);
|
println!("\n\nerr: {:?}", err);
|
||||||
exit(1);
|
exit(1);
|
||||||
|
@ -215,10 +293,10 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
|
|
||||||
out.2.into_iter().for_each(|x| x.abort());
|
out.2.into_iter().for_each(|x| x.abort());
|
||||||
|
|
||||||
let duration_since = Instant::now().duration_since(start_time);
|
mux.close().await?;
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"\n\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)",
|
"\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)",
|
||||||
cnt.get(),
|
cnt.get(),
|
||||||
opts.packet_size,
|
opts.packet_size,
|
||||||
cnt.get() * opts.packet_size,
|
cnt.get() * opts.packet_size,
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "wisp-mux"
|
name = "wisp-mux"
|
||||||
version = "3.0.0"
|
version = "4.0.0"
|
||||||
license = "LGPL-3.0-only"
|
license = "LGPL-3.0-only"
|
||||||
description = "A library for easily creating Wisp servers and clients."
|
description = "A library for easily creating Wisp servers and clients."
|
||||||
homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp"
|
homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp"
|
||||||
|
@ -9,12 +9,15 @@ readme = "README.md"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
async-trait = "0.1.79"
|
||||||
async_io_stream = "0.3.3"
|
async_io_stream = "0.3.3"
|
||||||
bytes = "1.5.0"
|
bytes = "1.5.0"
|
||||||
dashmap = { version = "5.5.3", features = ["inline"] }
|
dashmap = { version = "5.5.3", features = ["inline"] }
|
||||||
event-listener = "5.0.0"
|
event-listener = "5.0.0"
|
||||||
fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true }
|
fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true }
|
||||||
|
flume = "0.11.0"
|
||||||
futures = "0.3.30"
|
futures = "0.3.30"
|
||||||
|
futures-timer = "3.0.3"
|
||||||
futures-util = "0.3.30"
|
futures-util = "0.3.30"
|
||||||
pin-project-lite = "0.2.13"
|
pin-project-lite = "0.2.13"
|
||||||
tokio = { version = "1.35.1", optional = true, default-features = false }
|
tokio = { version = "1.35.1", optional = true, default-features = false }
|
||||||
|
@ -22,6 +25,7 @@ tokio = { version = "1.35.1", optional = true, default-features = false }
|
||||||
[features]
|
[features]
|
||||||
fastwebsockets = ["dep:fastwebsockets", "dep:tokio"]
|
fastwebsockets = ["dep:fastwebsockets", "dep:tokio"]
|
||||||
tokio_io = ["async_io_stream/tokio_io"]
|
tokio_io = ["async_io_stream/tokio_io"]
|
||||||
|
wasm = ["futures-timer/wasm-bindgen"]
|
||||||
|
|
||||||
[package.metadata.docs.rs]
|
[package.metadata.docs.rs]
|
||||||
all-features = true
|
all-features = true
|
||||||
|
|
106
wisp/src/extensions/mod.rs
Normal file
106
wisp/src/extensions/mod.rs
Normal file
|
@ -0,0 +1,106 @@
|
||||||
|
//! Wisp protocol extensions.
|
||||||
|
pub mod password;
|
||||||
|
pub mod udp;
|
||||||
|
|
||||||
|
use std::ops::{Deref, DerefMut};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use bytes::{BufMut, Bytes, BytesMut};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||||
|
Role, WispError,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Type-erased protocol extension that implements Clone.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct AnyProtocolExtension(Box<dyn ProtocolExtension + Sync + Send>);
|
||||||
|
|
||||||
|
impl AnyProtocolExtension {
|
||||||
|
/// Create a new type-erased protocol extension.
|
||||||
|
pub fn new<T: ProtocolExtension + Sync + Send + 'static>(extension: T) -> Self {
|
||||||
|
Self(Box::new(extension))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Deref for AnyProtocolExtension {
|
||||||
|
type Target = dyn ProtocolExtension;
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
self.0.deref()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DerefMut for AnyProtocolExtension {
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
self.0.deref_mut()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for AnyProtocolExtension {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self(self.0.box_clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<AnyProtocolExtension> for Bytes {
|
||||||
|
fn from(value: AnyProtocolExtension) -> Self {
|
||||||
|
let mut bytes = BytesMut::with_capacity(5);
|
||||||
|
let payload = value.encode();
|
||||||
|
bytes.put_u8(value.get_id());
|
||||||
|
bytes.put_u32_le(payload.len() as u32);
|
||||||
|
bytes.extend(payload);
|
||||||
|
bytes.freeze()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Wisp protocol extension.
|
||||||
|
///
|
||||||
|
/// See [the
|
||||||
|
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#protocol-extensions).
|
||||||
|
#[async_trait]
|
||||||
|
pub trait ProtocolExtension: std::fmt::Debug {
|
||||||
|
/// Get the protocol extension ID.
|
||||||
|
fn get_id(&self) -> u8;
|
||||||
|
/// Get the protocol extension's supported packets.
|
||||||
|
///
|
||||||
|
/// Used to decide whether to call the protocol extension's packet handler.
|
||||||
|
fn get_supported_packets(&self) -> &'static [u8];
|
||||||
|
|
||||||
|
/// Encode self into Bytes.
|
||||||
|
fn encode(&self) -> Bytes;
|
||||||
|
|
||||||
|
/// Handle the handshake part of a Wisp connection.
|
||||||
|
///
|
||||||
|
/// This should be used to send or receive data before any streams are created.
|
||||||
|
async fn handle_handshake(
|
||||||
|
&mut self,
|
||||||
|
read: &mut dyn WebSocketRead,
|
||||||
|
write: &LockedWebSocketWrite,
|
||||||
|
) -> Result<(), WispError>;
|
||||||
|
|
||||||
|
/// Handle receiving a packet.
|
||||||
|
async fn handle_packet(
|
||||||
|
&mut self,
|
||||||
|
packet: Bytes,
|
||||||
|
read: &mut dyn WebSocketRead,
|
||||||
|
write: &LockedWebSocketWrite,
|
||||||
|
) -> Result<(), WispError>;
|
||||||
|
|
||||||
|
/// Clone the protocol extension.
|
||||||
|
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Trait to build a Wisp protocol extension from a payload.
|
||||||
|
pub trait ProtocolExtensionBuilder {
|
||||||
|
/// Get the protocol extension ID.
|
||||||
|
///
|
||||||
|
/// Used to decide whether this builder should be used.
|
||||||
|
fn get_id(&self) -> u8;
|
||||||
|
|
||||||
|
/// Build a protocol extension from the extension's metadata.
|
||||||
|
fn build_from_bytes(&self, bytes: Bytes, role: Role)
|
||||||
|
-> Result<AnyProtocolExtension, WispError>;
|
||||||
|
|
||||||
|
/// Build a protocol extension to send to the other side.
|
||||||
|
fn build_to_extension(&self, role: Role) -> AnyProtocolExtension;
|
||||||
|
}
|
276
wisp/src/extensions/password.rs
Normal file
276
wisp/src/extensions/password.rs
Normal file
|
@ -0,0 +1,276 @@
|
||||||
|
//! Password protocol extension.
|
||||||
|
//!
|
||||||
|
//! Passwords are sent in plain text!!
|
||||||
|
//!
|
||||||
|
//! # Example
|
||||||
|
//! Server:
|
||||||
|
//! ```
|
||||||
|
//! let mut passwords = HashMap::new();
|
||||||
|
//! passwords.insert("user1".to_string(), "pw".to_string());
|
||||||
|
//! let (mux, fut) = ServerMux::new(
|
||||||
|
//! rx,
|
||||||
|
//! tx,
|
||||||
|
//! 128,
|
||||||
|
//! Some(&[Box::new(PasswordProtocolExtensionBuilder::new_server(passwords))])
|
||||||
|
//! );
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! Client:
|
||||||
|
//! ```
|
||||||
|
//! let (mux, fut) = ClientMux::new(
|
||||||
|
//! rx,
|
||||||
|
//! tx,
|
||||||
|
//! 128,
|
||||||
|
//! Some(&[
|
||||||
|
//! Box::new(PasswordProtocolExtensionBuilder::new_client(
|
||||||
|
//! "user1".to_string(),
|
||||||
|
//! "pw".to_string()
|
||||||
|
//! ))
|
||||||
|
//! ])
|
||||||
|
//! );
|
||||||
|
//! ```
|
||||||
|
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x02---password-authentication)
|
||||||
|
|
||||||
|
use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||||
|
Role, WispError,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
/// Password protocol extension.
|
||||||
|
///
|
||||||
|
/// **Passwords are sent in plain text!!**
|
||||||
|
/// **This extension will panic when encoding if the username's length does not fit within a u8
|
||||||
|
/// or the password's length does not fit within a u16.**
|
||||||
|
pub struct PasswordProtocolExtension {
|
||||||
|
/// The username to log in with.
|
||||||
|
///
|
||||||
|
/// This string's length must fit within a u8.
|
||||||
|
pub username: String,
|
||||||
|
/// The password to log in with.
|
||||||
|
///
|
||||||
|
/// This string's length must fit within a u16.
|
||||||
|
pub password: String,
|
||||||
|
role: Role,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PasswordProtocolExtension {
|
||||||
|
/// Password protocol extension ID.
|
||||||
|
pub const ID: u8 = 0x02;
|
||||||
|
|
||||||
|
/// Create a new password protocol extension for the server.
|
||||||
|
///
|
||||||
|
/// This signifies that the server requires a password.
|
||||||
|
pub fn new_server() -> Self {
|
||||||
|
Self {
|
||||||
|
username: String::new(),
|
||||||
|
password: String::new(),
|
||||||
|
role: Role::Server,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new password protocol extension for the client, with a username and password.
|
||||||
|
///
|
||||||
|
/// The username's length must fit within a u8. The password's length must fit within a
|
||||||
|
/// u16.
|
||||||
|
pub fn new_client(username: String, password: String) -> Self {
|
||||||
|
Self {
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
role: Role::Client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ProtocolExtension for PasswordProtocolExtension {
|
||||||
|
fn get_id(&self) -> u8 {
|
||||||
|
Self::ID
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_supported_packets(&self) -> &'static [u8] {
|
||||||
|
&[]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode(&self) -> Bytes {
|
||||||
|
match self.role {
|
||||||
|
Role::Server => Bytes::new(),
|
||||||
|
Role::Client => {
|
||||||
|
let username = Bytes::from(self.username.clone().into_bytes());
|
||||||
|
let password = Bytes::from(self.password.clone().into_bytes());
|
||||||
|
let username_len = u8::try_from(username.len()).expect("username was too long");
|
||||||
|
let password_len = u16::try_from(password.len()).expect("password was too long");
|
||||||
|
|
||||||
|
let mut bytes =
|
||||||
|
BytesMut::with_capacity(3 + username_len as usize + password_len as usize);
|
||||||
|
bytes.put_u8(username_len);
|
||||||
|
bytes.put_u16_le(password_len);
|
||||||
|
bytes.extend(username);
|
||||||
|
bytes.extend(password);
|
||||||
|
bytes.freeze()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_handshake(
|
||||||
|
&mut self,
|
||||||
|
_: &mut dyn WebSocketRead,
|
||||||
|
_: &LockedWebSocketWrite,
|
||||||
|
) -> Result<(), WispError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_packet(
|
||||||
|
&mut self,
|
||||||
|
_: Bytes,
|
||||||
|
_: &mut dyn WebSocketRead,
|
||||||
|
_: &LockedWebSocketWrite,
|
||||||
|
) -> Result<(), WispError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
|
||||||
|
Box::new(self.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum PasswordProtocolExtensionError {
|
||||||
|
Utf8Error(FromUtf8Error),
|
||||||
|
InvalidUsername,
|
||||||
|
InvalidPassword,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for PasswordProtocolExtensionError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
use PasswordProtocolExtensionError as E;
|
||||||
|
match self {
|
||||||
|
E::Utf8Error(e) => write!(f, "{}", e),
|
||||||
|
E::InvalidUsername => write!(f, "Invalid username"),
|
||||||
|
E::InvalidPassword => write!(f, "Invalid password"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error for PasswordProtocolExtensionError {}
|
||||||
|
|
||||||
|
impl From<PasswordProtocolExtensionError> for WispError {
|
||||||
|
fn from(value: PasswordProtocolExtensionError) -> Self {
|
||||||
|
WispError::ExtensionImplError(Box::new(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<FromUtf8Error> for PasswordProtocolExtensionError {
|
||||||
|
fn from(value: FromUtf8Error) -> Self {
|
||||||
|
PasswordProtocolExtensionError::Utf8Error(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<PasswordProtocolExtension> for AnyProtocolExtension {
|
||||||
|
fn from(value: PasswordProtocolExtension) -> Self {
|
||||||
|
AnyProtocolExtension(Box::new(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Password protocol extension builder.
|
||||||
|
///
|
||||||
|
/// **Passwords are sent in plain text!!**
|
||||||
|
pub struct PasswordProtocolExtensionBuilder {
|
||||||
|
/// Map of users and their passwords to allow. Only used on server.
|
||||||
|
pub users: HashMap<String, String>,
|
||||||
|
/// Username to authenticate with. Only used on client.
|
||||||
|
pub username: String,
|
||||||
|
/// Password to authenticate with. Only used on client.
|
||||||
|
pub password: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PasswordProtocolExtensionBuilder {
|
||||||
|
/// Create a new password protocol extension builder for the server, with a map of users
|
||||||
|
/// and passwords to allow.
|
||||||
|
pub fn new_server(users: HashMap<String, String>) -> Self {
|
||||||
|
Self {
|
||||||
|
users,
|
||||||
|
username: String::new(),
|
||||||
|
password: String::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new password protocol extension builder for the client, with a username and
|
||||||
|
/// password to authenticate with.
|
||||||
|
pub fn new_client(username: String, password: String) -> Self {
|
||||||
|
Self {
|
||||||
|
users: HashMap::new(),
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder {
|
||||||
|
fn get_id(&self) -> u8 {
|
||||||
|
PasswordProtocolExtension::ID
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_from_bytes(
|
||||||
|
&self,
|
||||||
|
mut payload: Bytes,
|
||||||
|
role: crate::Role,
|
||||||
|
) -> Result<AnyProtocolExtension, WispError> {
|
||||||
|
match role {
|
||||||
|
Role::Server => {
|
||||||
|
if payload.remaining() < 3 {
|
||||||
|
return Err(WispError::PacketTooSmall);
|
||||||
|
}
|
||||||
|
|
||||||
|
let username_len = payload.get_u8();
|
||||||
|
let password_len = payload.get_u16_le();
|
||||||
|
if payload.remaining() < (password_len + username_len as u16) as usize {
|
||||||
|
return Err(WispError::PacketTooSmall);
|
||||||
|
}
|
||||||
|
|
||||||
|
use PasswordProtocolExtensionError as EError;
|
||||||
|
let username =
|
||||||
|
String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec())
|
||||||
|
.map_err(|x| WispError::from(EError::from(x)))?;
|
||||||
|
let password =
|
||||||
|
String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec())
|
||||||
|
.map_err(|x| WispError::from(EError::from(x)))?;
|
||||||
|
|
||||||
|
let Some(user) = self.users.iter().find(|x| *x.0 == username) else {
|
||||||
|
return Err(EError::InvalidUsername.into());
|
||||||
|
};
|
||||||
|
|
||||||
|
if *user.1 != password {
|
||||||
|
return Err(EError::InvalidPassword.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(PasswordProtocolExtension {
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
role,
|
||||||
|
}
|
||||||
|
.into())
|
||||||
|
}
|
||||||
|
Role::Client => {
|
||||||
|
Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_to_extension(&self, role: Role) -> AnyProtocolExtension {
|
||||||
|
match role {
|
||||||
|
Role::Server => PasswordProtocolExtension::new_server(),
|
||||||
|
Role::Client => {
|
||||||
|
PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.into()
|
||||||
|
}
|
||||||
|
}
|
93
wisp/src/extensions/udp.rs
Normal file
93
wisp/src/extensions/udp.rs
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
//! UDP protocol extension.
|
||||||
|
//!
|
||||||
|
//! # Example
|
||||||
|
//! ```
|
||||||
|
//! let (mux, fut) = ServerMux::new(
|
||||||
|
//! rx,
|
||||||
|
//! tx,
|
||||||
|
//! 128,
|
||||||
|
//! Some(&[Box::new(UdpProtocolExtensionBuilder())])
|
||||||
|
//! );
|
||||||
|
//! ```
|
||||||
|
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp)
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use bytes::Bytes;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||||
|
WispError,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
/// UDP protocol extension.
|
||||||
|
pub struct UdpProtocolExtension();
|
||||||
|
|
||||||
|
impl UdpProtocolExtension {
|
||||||
|
/// UDP protocol extension ID.
|
||||||
|
pub const ID: u8 = 0x01;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ProtocolExtension for UdpProtocolExtension {
|
||||||
|
fn get_id(&self) -> u8 {
|
||||||
|
Self::ID
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_supported_packets(&self) -> &'static [u8] {
|
||||||
|
&[]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode(&self) -> Bytes {
|
||||||
|
Bytes::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_handshake(
|
||||||
|
&mut self,
|
||||||
|
_: &mut dyn WebSocketRead,
|
||||||
|
_: &LockedWebSocketWrite,
|
||||||
|
) -> Result<(), WispError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_packet(
|
||||||
|
&mut self,
|
||||||
|
_: Bytes,
|
||||||
|
_: &mut dyn WebSocketRead,
|
||||||
|
_: &LockedWebSocketWrite,
|
||||||
|
) -> Result<(), WispError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
|
||||||
|
Box::new(Self())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<UdpProtocolExtension> for AnyProtocolExtension {
|
||||||
|
fn from(value: UdpProtocolExtension) -> Self {
|
||||||
|
AnyProtocolExtension(Box::new(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// UDP protocol extension builder.
|
||||||
|
pub struct UdpProtocolExtensionBuilder();
|
||||||
|
|
||||||
|
impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder {
|
||||||
|
fn get_id(&self) -> u8 {
|
||||||
|
UdpProtocolExtension::ID
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_from_bytes(
|
||||||
|
&self,
|
||||||
|
_: Bytes,
|
||||||
|
_: crate::Role,
|
||||||
|
) -> Result<AnyProtocolExtension, WispError> {
|
||||||
|
Ok(UdpProtocolExtension().into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension {
|
||||||
|
UdpProtocolExtension().into()
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,9 +1,14 @@
|
||||||
use bytes::Bytes;
|
use std::ops::Deref;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use bytes::BytesMut;
|
||||||
use fastwebsockets::{
|
use fastwebsockets::{
|
||||||
FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite,
|
CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite,
|
||||||
};
|
};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
|
||||||
|
use crate::{ws::LockedWebSocketWrite, WispError};
|
||||||
|
|
||||||
impl From<OpCode> for crate::ws::OpCode {
|
impl From<OpCode> for crate::ws::OpCode {
|
||||||
fn from(opcode: OpCode) -> Self {
|
fn from(opcode: OpCode) -> Self {
|
||||||
use OpCode::*;
|
use OpCode::*;
|
||||||
|
@ -21,29 +26,25 @@ impl From<OpCode> for crate::ws::OpCode {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Frame<'_>> for crate::ws::Frame {
|
impl From<Frame<'_>> for crate::ws::Frame {
|
||||||
fn from(mut frame: Frame) -> Self {
|
fn from(frame: Frame) -> Self {
|
||||||
Self {
|
Self {
|
||||||
finished: frame.fin,
|
finished: frame.fin,
|
||||||
opcode: frame.opcode.into(),
|
opcode: frame.opcode.into(),
|
||||||
payload: Bytes::copy_from_slice(frame.payload.to_mut()),
|
payload: BytesMut::from(frame.payload.deref()).freeze(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<crate::ws::Frame> for Frame<'_> {
|
impl<'a> From<crate::ws::Frame> for Frame<'a> {
|
||||||
fn from(frame: crate::ws::Frame) -> Self {
|
fn from(frame: crate::ws::Frame) -> Self {
|
||||||
use crate::ws::OpCode::*;
|
use crate::ws::OpCode::*;
|
||||||
|
let payload = Payload::Owned(frame.payload.into());
|
||||||
match frame.opcode {
|
match frame.opcode {
|
||||||
Text => Self::text(Payload::Owned(frame.payload.to_vec())),
|
Text => Self::text(payload),
|
||||||
Binary => Self::binary(Payload::Owned(frame.payload.to_vec())),
|
Binary => Self::binary(payload),
|
||||||
Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())),
|
Close => Self::close_raw(payload),
|
||||||
Ping => Self::new(
|
Ping => Self::new(true, OpCode::Ping, None, payload),
|
||||||
true,
|
Pong => Self::pong(payload),
|
||||||
OpCode::Ping,
|
|
||||||
None,
|
|
||||||
Payload::Owned(frame.payload.to_vec()),
|
|
||||||
),
|
|
||||||
Pong => Self::pong(Payload::Owned(frame.payload.to_vec())),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -58,11 +59,12 @@ impl From<WebSocketError> for crate::WispError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: AsyncRead + Unpin> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
|
#[async_trait]
|
||||||
|
impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
|
||||||
async fn wisp_read_frame(
|
async fn wisp_read_frame(
|
||||||
&mut self,
|
&mut self,
|
||||||
tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
|
tx: &LockedWebSocketWrite,
|
||||||
) -> Result<crate::ws::Frame, crate::WispError> {
|
) -> Result<crate::ws::Frame, WispError> {
|
||||||
Ok(self
|
Ok(self
|
||||||
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
|
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
|
||||||
.await?
|
.await?
|
||||||
|
@ -70,8 +72,15 @@ impl<S: AsyncRead + Unpin> crate::ws::WebSocketRead for FragmentCollectorRead<S>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: AsyncWrite + Unpin> crate::ws::WebSocketWrite for WebSocketWrite<S> {
|
#[async_trait]
|
||||||
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
|
impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
|
||||||
|
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), WispError> {
|
||||||
self.write_frame(frame.into()).await.map_err(|e| e.into())
|
self.write_frame(frame.into()).await.map_err(|e| e.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn wisp_close(&mut self) -> Result<(), WispError> {
|
||||||
|
self.write_frame(Frame::close(CloseCode::Normal.into(), b""))
|
||||||
|
.await
|
||||||
|
.map_err(|e| e.into())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
671
wisp/src/lib.rs
671
wisp/src/lib.rs
|
@ -1,9 +1,10 @@
|
||||||
#![deny(missing_docs)]
|
#![deny(missing_docs, warnings)]
|
||||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||||
//! A library for easily creating [Wisp] clients and servers.
|
//! A library for easily creating [Wisp] clients and servers.
|
||||||
//!
|
//!
|
||||||
//! [Wisp]: https://github.com/MercuryWorkshop/wisp-protocol
|
//! [Wisp]: https://github.com/MercuryWorkshop/wisp-protocol
|
||||||
|
|
||||||
|
pub mod extensions;
|
||||||
#[cfg(feature = "fastwebsockets")]
|
#[cfg(feature = "fastwebsockets")]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "fastwebsockets")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "fastwebsockets")))]
|
||||||
mod fastwebsockets;
|
mod fastwebsockets;
|
||||||
|
@ -12,18 +13,26 @@ mod sink_unfold;
|
||||||
mod stream;
|
mod stream;
|
||||||
pub mod ws;
|
pub mod ws;
|
||||||
|
|
||||||
pub use crate::packet::*;
|
pub use crate::{packet::*, stream::*};
|
||||||
pub use crate::stream::*;
|
|
||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use event_listener::Event;
|
use event_listener::Event;
|
||||||
use futures::SinkExt;
|
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
|
||||||
use futures::{channel::mpsc, Future, FutureExt, StreamExt};
|
use flume as mpsc;
|
||||||
use std::sync::{
|
use futures::{channel::oneshot, select, Future, FutureExt};
|
||||||
atomic::{AtomicBool, AtomicU32, Ordering},
|
use futures_timer::Delay;
|
||||||
Arc,
|
use std::{
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicBool, AtomicU32, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
time::Duration,
|
||||||
};
|
};
|
||||||
|
use ws::AppendingWebSocketRead;
|
||||||
|
|
||||||
|
/// Wisp version supported by this crate.
|
||||||
|
pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
|
||||||
|
|
||||||
/// The role of the multiplexor.
|
/// The role of the multiplexor.
|
||||||
#[derive(Debug, PartialEq, Copy, Clone)]
|
#[derive(Debug, PartialEq, Copy, Clone)]
|
||||||
|
@ -37,29 +46,29 @@ pub enum Role {
|
||||||
/// Errors the Wisp implementation can return.
|
/// Errors the Wisp implementation can return.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum WispError {
|
pub enum WispError {
|
||||||
/// The packet recieved did not have enough data.
|
/// The packet received did not have enough data.
|
||||||
PacketTooSmall,
|
PacketTooSmall,
|
||||||
/// The packet recieved had an invalid type.
|
/// The packet received had an invalid type.
|
||||||
InvalidPacketType,
|
InvalidPacketType,
|
||||||
/// The stream had an invalid type.
|
|
||||||
InvalidStreamType,
|
|
||||||
/// The stream had an invalid ID.
|
/// The stream had an invalid ID.
|
||||||
InvalidStreamId,
|
InvalidStreamId,
|
||||||
/// The close packet had an invalid reason.
|
/// The close packet had an invalid reason.
|
||||||
InvalidCloseReason,
|
InvalidCloseReason,
|
||||||
/// The URI recieved was invalid.
|
/// The URI received was invalid.
|
||||||
InvalidUri,
|
InvalidUri,
|
||||||
/// The URI recieved had no host.
|
/// The URI received had no host.
|
||||||
UriHasNoHost,
|
UriHasNoHost,
|
||||||
/// The URI recieved had no port.
|
/// The URI received had no port.
|
||||||
UriHasNoPort,
|
UriHasNoPort,
|
||||||
/// The max stream count was reached.
|
/// The max stream count was reached.
|
||||||
MaxStreamCountReached,
|
MaxStreamCountReached,
|
||||||
|
/// The Wisp protocol version was incompatible.
|
||||||
|
IncompatibleProtocolVersion,
|
||||||
/// The stream had already been closed.
|
/// The stream had already been closed.
|
||||||
StreamAlreadyClosed,
|
StreamAlreadyClosed,
|
||||||
/// The websocket frame recieved had an invalid type.
|
/// The websocket frame received had an invalid type.
|
||||||
WsFrameInvalidType,
|
WsFrameInvalidType,
|
||||||
/// The websocket frame recieved was not finished.
|
/// The websocket frame received was not finished.
|
||||||
WsFrameNotFinished,
|
WsFrameNotFinished,
|
||||||
/// Error specific to the websocket implementation.
|
/// Error specific to the websocket implementation.
|
||||||
WsImplError(Box<dyn std::error::Error + Sync + Send>),
|
WsImplError(Box<dyn std::error::Error + Sync + Send>),
|
||||||
|
@ -67,17 +76,33 @@ pub enum WispError {
|
||||||
WsImplSocketClosed,
|
WsImplSocketClosed,
|
||||||
/// The websocket implementation did not support the action.
|
/// The websocket implementation did not support the action.
|
||||||
WsImplNotSupported,
|
WsImplNotSupported,
|
||||||
|
/// Error specific to the protocol extension implementation.
|
||||||
|
ExtensionImplError(Box<dyn std::error::Error + Sync + Send>),
|
||||||
|
/// The protocol extension implementation did not support the action.
|
||||||
|
ExtensionImplNotSupported,
|
||||||
|
/// The UDP protocol extension is not supported by the server.
|
||||||
|
UdpExtensionNotSupported,
|
||||||
/// The string was invalid UTF-8.
|
/// The string was invalid UTF-8.
|
||||||
Utf8Error(std::str::Utf8Error),
|
Utf8Error(std::str::Utf8Error),
|
||||||
|
/// The integer failed to convert.
|
||||||
|
TryFromIntError(std::num::TryFromIntError),
|
||||||
/// Other error.
|
/// Other error.
|
||||||
Other(Box<dyn std::error::Error + Sync + Send>),
|
Other(Box<dyn std::error::Error + Sync + Send>),
|
||||||
/// Failed to send message to multiplexor task.
|
/// Failed to send message to multiplexor task.
|
||||||
MuxMessageFailedToSend,
|
MuxMessageFailedToSend,
|
||||||
|
/// Failed to receive message from multiplexor task.
|
||||||
|
MuxMessageFailedToRecv,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<std::str::Utf8Error> for WispError {
|
impl From<std::str::Utf8Error> for WispError {
|
||||||
fn from(err: std::str::Utf8Error) -> WispError {
|
fn from(err: std::str::Utf8Error) -> Self {
|
||||||
WispError::Utf8Error(err)
|
Self::Utf8Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::num::TryFromIntError> for WispError {
|
||||||
|
fn from(value: std::num::TryFromIntError) -> Self {
|
||||||
|
Self::TryFromIntError(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,13 +111,13 @@ impl std::fmt::Display for WispError {
|
||||||
match self {
|
match self {
|
||||||
Self::PacketTooSmall => write!(f, "Packet too small"),
|
Self::PacketTooSmall => write!(f, "Packet too small"),
|
||||||
Self::InvalidPacketType => write!(f, "Invalid packet type"),
|
Self::InvalidPacketType => write!(f, "Invalid packet type"),
|
||||||
Self::InvalidStreamType => write!(f, "Invalid stream type"),
|
|
||||||
Self::InvalidStreamId => write!(f, "Invalid stream id"),
|
Self::InvalidStreamId => write!(f, "Invalid stream id"),
|
||||||
Self::InvalidCloseReason => write!(f, "Invalid close reason"),
|
Self::InvalidCloseReason => write!(f, "Invalid close reason"),
|
||||||
Self::InvalidUri => write!(f, "Invalid URI"),
|
Self::InvalidUri => write!(f, "Invalid URI"),
|
||||||
Self::UriHasNoHost => write!(f, "URI has no host"),
|
Self::UriHasNoHost => write!(f, "URI has no host"),
|
||||||
Self::UriHasNoPort => write!(f, "URI has no port"),
|
Self::UriHasNoPort => write!(f, "URI has no port"),
|
||||||
Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
|
Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
|
||||||
|
Self::IncompatibleProtocolVersion => write!(f, "Incompatible Wisp protocol version"),
|
||||||
Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
|
Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
|
||||||
Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
|
Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
|
||||||
Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
|
Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
|
||||||
|
@ -103,9 +128,21 @@ impl std::fmt::Display for WispError {
|
||||||
Self::WsImplNotSupported => {
|
Self::WsImplNotSupported => {
|
||||||
write!(f, "Websocket implementation error: unsupported feature")
|
write!(f, "Websocket implementation error: unsupported feature")
|
||||||
}
|
}
|
||||||
|
Self::ExtensionImplError(err) => {
|
||||||
|
write!(f, "Protocol extension implementation error: {}", err)
|
||||||
|
}
|
||||||
|
Self::ExtensionImplNotSupported => {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"Protocol extension implementation error: unsupported feature"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Self::UdpExtensionNotSupported => write!(f, "UDP protocol extension not supported"),
|
||||||
Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
|
Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
|
||||||
|
Self::TryFromIntError(err) => write!(f, "Integer conversion error: {}", err),
|
||||||
Self::Other(err) => write!(f, "Other error: {}", err),
|
Self::Other(err) => write!(f, "Other error: {}", err),
|
||||||
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
|
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
|
||||||
|
Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -113,36 +150,36 @@ impl std::fmt::Display for WispError {
|
||||||
impl std::error::Error for WispError {}
|
impl std::error::Error for WispError {}
|
||||||
|
|
||||||
struct MuxMapValue {
|
struct MuxMapValue {
|
||||||
stream: mpsc::UnboundedSender<Bytes>,
|
stream: mpsc::Sender<Bytes>,
|
||||||
stream_type: StreamType,
|
stream_type: StreamType,
|
||||||
flow_control: Arc<AtomicU32>,
|
flow_control: Arc<AtomicU32>,
|
||||||
flow_control_event: Arc<Event>,
|
flow_control_event: Arc<Event>,
|
||||||
is_closed: Arc<AtomicBool>,
|
is_closed: Arc<AtomicBool>,
|
||||||
|
is_closed_event: Arc<Event>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct MuxInner<W>
|
struct MuxInner {
|
||||||
where
|
tx: ws::LockedWebSocketWrite,
|
||||||
W: ws::WebSocketWrite,
|
stream_map: DashMap<u32, MuxMapValue>,
|
||||||
{
|
buffer_size: u32,
|
||||||
tx: ws::LockedWebSocketWrite<W>,
|
|
||||||
stream_map: Arc<DashMap<u32, MuxMapValue>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<W: ws::WebSocketWrite> MuxInner<W> {
|
impl MuxInner {
|
||||||
pub async fn server_into_future<R>(
|
pub async fn server_into_future<R>(
|
||||||
self,
|
self,
|
||||||
rx: R,
|
rx: R,
|
||||||
|
extensions: Vec<AnyProtocolExtension>,
|
||||||
close_rx: mpsc::Receiver<WsEvent>,
|
close_rx: mpsc::Receiver<WsEvent>,
|
||||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
|
muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
|
||||||
buffer_size: u32,
|
|
||||||
close_tx: mpsc::Sender<WsEvent>,
|
close_tx: mpsc::Sender<WsEvent>,
|
||||||
) -> Result<(), WispError>
|
) -> Result<(), WispError>
|
||||||
where
|
where
|
||||||
R: ws::WebSocketRead,
|
R: ws::WebSocketRead + Send,
|
||||||
{
|
{
|
||||||
self.into_future(
|
self.as_future(
|
||||||
close_rx,
|
close_rx,
|
||||||
self.server_loop(rx, muxstream_sender, buffer_size, close_tx),
|
close_tx.clone(),
|
||||||
|
self.server_loop(rx, extensions, muxstream_sender, close_tx),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
@ -150,34 +187,83 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
|
||||||
pub async fn client_into_future<R>(
|
pub async fn client_into_future<R>(
|
||||||
self,
|
self,
|
||||||
rx: R,
|
rx: R,
|
||||||
|
extensions: Vec<AnyProtocolExtension>,
|
||||||
close_rx: mpsc::Receiver<WsEvent>,
|
close_rx: mpsc::Receiver<WsEvent>,
|
||||||
|
close_tx: mpsc::Sender<WsEvent>,
|
||||||
) -> Result<(), WispError>
|
) -> Result<(), WispError>
|
||||||
where
|
where
|
||||||
R: ws::WebSocketRead,
|
R: ws::WebSocketRead + Send,
|
||||||
{
|
{
|
||||||
self.into_future(close_rx, self.client_loop(rx)).await
|
self.as_future(close_rx, close_tx, self.client_loop(rx, extensions))
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn into_future(
|
async fn as_future(
|
||||||
&self,
|
&self,
|
||||||
close_rx: mpsc::Receiver<WsEvent>,
|
close_rx: mpsc::Receiver<WsEvent>,
|
||||||
|
close_tx: mpsc::Sender<WsEvent>,
|
||||||
wisp_fut: impl Future<Output = Result<(), WispError>>,
|
wisp_fut: impl Future<Output = Result<(), WispError>>,
|
||||||
) -> Result<(), WispError> {
|
) -> Result<(), WispError> {
|
||||||
let ret = futures::select! {
|
let ret = futures::select! {
|
||||||
_ = self.stream_loop(close_rx).fuse() => Ok(()),
|
_ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()),
|
||||||
x = wisp_fut.fuse() => x,
|
x = wisp_fut.fuse() => x,
|
||||||
};
|
};
|
||||||
self.stream_map.iter_mut().for_each(|mut x| {
|
for x in self.stream_map.iter_mut() {
|
||||||
x.is_closed.store(true, Ordering::Release);
|
x.is_closed.store(true, Ordering::Release);
|
||||||
x.stream.disconnect();
|
x.is_closed_event.notify(usize::MAX);
|
||||||
x.stream.close_channel();
|
}
|
||||||
});
|
|
||||||
self.stream_map.clear();
|
self.stream_map.clear();
|
||||||
|
let _ = self.tx.close().await;
|
||||||
ret
|
ret
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn stream_loop(&self, mut stream_rx: mpsc::Receiver<WsEvent>) {
|
async fn create_new_stream(
|
||||||
while let Some(msg) = stream_rx.next().await {
|
&self,
|
||||||
|
stream_id: u32,
|
||||||
|
stream_type: StreamType,
|
||||||
|
role: Role,
|
||||||
|
stream_tx: mpsc::Sender<WsEvent>,
|
||||||
|
target_buffer_size: u32,
|
||||||
|
) -> Result<(MuxMapValue, MuxStream), WispError> {
|
||||||
|
let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize);
|
||||||
|
|
||||||
|
let flow_control_event: Arc<Event> = Event::new().into();
|
||||||
|
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
|
||||||
|
|
||||||
|
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
||||||
|
let is_closed_event: Arc<Event> = Event::new().into();
|
||||||
|
|
||||||
|
Ok((
|
||||||
|
MuxMapValue {
|
||||||
|
stream: ch_tx,
|
||||||
|
stream_type,
|
||||||
|
flow_control: flow_control.clone(),
|
||||||
|
flow_control_event: flow_control_event.clone(),
|
||||||
|
is_closed: is_closed.clone(),
|
||||||
|
is_closed_event: is_closed_event.clone(),
|
||||||
|
},
|
||||||
|
MuxStream::new(
|
||||||
|
stream_id,
|
||||||
|
role,
|
||||||
|
stream_type,
|
||||||
|
ch_rx,
|
||||||
|
stream_tx.clone(),
|
||||||
|
is_closed,
|
||||||
|
is_closed_event,
|
||||||
|
flow_control,
|
||||||
|
flow_control_event,
|
||||||
|
target_buffer_size,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stream_loop(
|
||||||
|
&self,
|
||||||
|
stream_rx: mpsc::Receiver<WsEvent>,
|
||||||
|
stream_tx: mpsc::Sender<WsEvent>,
|
||||||
|
) {
|
||||||
|
let mut next_free_stream_id: u32 = 1;
|
||||||
|
while let Ok(msg) = stream_rx.recv_async().await {
|
||||||
match msg {
|
match msg {
|
||||||
WsEvent::SendPacket(packet, channel) => {
|
WsEvent::SendPacket(packet, channel) => {
|
||||||
if self.stream_map.get(&packet.stream_id).is_some() {
|
if self.stream_map.get(&packet.stream_id).is_some() {
|
||||||
|
@ -186,16 +272,55 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
|
||||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
WsEvent::CreateStream(stream_type, host, port, channel) => {
|
||||||
|
let ret: Result<MuxStream, WispError> = async {
|
||||||
|
let stream_id = next_free_stream_id;
|
||||||
|
let next_stream_id = next_free_stream_id
|
||||||
|
.checked_add(1)
|
||||||
|
.ok_or(WispError::MaxStreamCountReached)?;
|
||||||
|
|
||||||
|
let (map_value, stream) = self
|
||||||
|
.create_new_stream(
|
||||||
|
stream_id,
|
||||||
|
stream_type,
|
||||||
|
Role::Client,
|
||||||
|
stream_tx.clone(),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.tx
|
||||||
|
.write_frame(
|
||||||
|
Packet::new_connect(stream_id, stream_type, port, host).into(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.stream_map.insert(stream_id, map_value);
|
||||||
|
|
||||||
|
next_free_stream_id = next_stream_id;
|
||||||
|
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
.await;
|
||||||
|
let _ = channel.send(ret);
|
||||||
|
}
|
||||||
WsEvent::Close(packet, channel) => {
|
WsEvent::Close(packet, channel) => {
|
||||||
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
|
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||||
stream.stream.disconnect();
|
|
||||||
stream.stream.close_channel();
|
|
||||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||||
|
drop(stream.stream)
|
||||||
} else {
|
} else {
|
||||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
WsEvent::EndFut => break,
|
WsEvent::EndFut(x) => {
|
||||||
|
if let Some(reason) = x {
|
||||||
|
let _ = self
|
||||||
|
.tx
|
||||||
|
.write_frame(Packet::new_close(0, reason).into())
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -203,122 +328,115 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
|
||||||
async fn server_loop<R>(
|
async fn server_loop<R>(
|
||||||
&self,
|
&self,
|
||||||
mut rx: R,
|
mut rx: R,
|
||||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
|
mut extensions: Vec<AnyProtocolExtension>,
|
||||||
buffer_size: u32,
|
muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
|
||||||
close_tx: mpsc::Sender<WsEvent>,
|
stream_tx: mpsc::Sender<WsEvent>,
|
||||||
) -> Result<(), WispError>
|
) -> Result<(), WispError>
|
||||||
where
|
where
|
||||||
R: ws::WebSocketRead,
|
R: ws::WebSocketRead + Send,
|
||||||
{
|
{
|
||||||
// will send continues once flow_control is at 10% of max
|
// will send continues once flow_control is at 10% of max
|
||||||
let target_buffer_size = ((buffer_size as u64 * 90) / 100) as u32;
|
let target_buffer_size = ((self.buffer_size as u64 * 90) / 100) as u32;
|
||||||
self.tx
|
|
||||||
.write_frame(Packet::new_continue(0, buffer_size).into())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let frame = rx.wisp_read_frame(&self.tx).await?;
|
let frame = rx.wisp_read_frame(&self.tx).await?;
|
||||||
if frame.opcode == ws::OpCode::Close {
|
if frame.opcode == ws::OpCode::Close {
|
||||||
break Ok(());
|
break Ok(());
|
||||||
}
|
}
|
||||||
let packet = Packet::try_from(frame)?;
|
if let Some(packet) =
|
||||||
|
Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
|
||||||
use PacketType::*;
|
{
|
||||||
match packet.packet_type {
|
use PacketType::*;
|
||||||
Connect(inner_packet) => {
|
match packet.packet_type {
|
||||||
let (ch_tx, ch_rx) = mpsc::unbounded();
|
Connect(inner_packet) => {
|
||||||
let stream_type = inner_packet.stream_type;
|
let (map_value, stream) = self
|
||||||
let flow_control: Arc<AtomicU32> = AtomicU32::new(buffer_size).into();
|
.create_new_stream(
|
||||||
let flow_control_event: Arc<Event> = Event::new().into();
|
|
||||||
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
|
||||||
|
|
||||||
self.stream_map.insert(
|
|
||||||
packet.stream_id,
|
|
||||||
MuxMapValue {
|
|
||||||
stream: ch_tx,
|
|
||||||
stream_type,
|
|
||||||
flow_control: flow_control.clone(),
|
|
||||||
flow_control_event: flow_control_event.clone(),
|
|
||||||
is_closed: is_closed.clone(),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
muxstream_sender
|
|
||||||
.unbounded_send((
|
|
||||||
inner_packet,
|
|
||||||
MuxStream::new(
|
|
||||||
packet.stream_id,
|
packet.stream_id,
|
||||||
|
inner_packet.stream_type,
|
||||||
Role::Server,
|
Role::Server,
|
||||||
stream_type,
|
stream_tx.clone(),
|
||||||
ch_rx,
|
|
||||||
close_tx.clone(),
|
|
||||||
is_closed,
|
|
||||||
flow_control,
|
|
||||||
flow_control_event,
|
|
||||||
target_buffer_size,
|
target_buffer_size,
|
||||||
),
|
)
|
||||||
))
|
.await?;
|
||||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
muxstream_sender
|
||||||
}
|
.send_async((inner_packet, stream))
|
||||||
Data(data) => {
|
.await
|
||||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||||
let _ = stream.stream.unbounded_send(data);
|
self.stream_map.insert(packet.stream_id, map_value);
|
||||||
if stream.stream_type == StreamType::Tcp {
|
}
|
||||||
stream.flow_control.store(
|
Data(data) => {
|
||||||
stream
|
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||||
.flow_control
|
let _ = stream.stream.send_async(data).await;
|
||||||
.load(Ordering::Acquire)
|
if stream.stream_type == StreamType::Tcp {
|
||||||
.saturating_sub(1),
|
stream.flow_control.store(
|
||||||
Ordering::Release,
|
stream
|
||||||
);
|
.flow_control
|
||||||
|
.load(Ordering::Acquire)
|
||||||
|
.saturating_sub(1),
|
||||||
|
Ordering::Release,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
|
||||||
Continue(_) => break Err(WispError::InvalidPacketType),
|
Close(_) => {
|
||||||
Close(_) => {
|
if packet.stream_id == 0 {
|
||||||
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
|
break Ok(());
|
||||||
stream.is_closed.store(true, Ordering::Release);
|
}
|
||||||
stream.stream.disconnect();
|
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||||
stream.stream.close_channel();
|
stream.is_closed.store(true, Ordering::Release);
|
||||||
|
stream.is_closed_event.notify(usize::MAX);
|
||||||
|
drop(stream.stream)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn client_loop<R>(&self, mut rx: R) -> Result<(), WispError>
|
async fn client_loop<R>(
|
||||||
|
&self,
|
||||||
|
mut rx: R,
|
||||||
|
mut extensions: Vec<AnyProtocolExtension>,
|
||||||
|
) -> Result<(), WispError>
|
||||||
where
|
where
|
||||||
R: ws::WebSocketRead,
|
R: ws::WebSocketRead + Send,
|
||||||
{
|
{
|
||||||
loop {
|
loop {
|
||||||
let frame = rx.wisp_read_frame(&self.tx).await?;
|
let frame = rx.wisp_read_frame(&self.tx).await?;
|
||||||
if frame.opcode == ws::OpCode::Close {
|
if frame.opcode == ws::OpCode::Close {
|
||||||
break Ok(());
|
break Ok(());
|
||||||
}
|
}
|
||||||
let packet = Packet::try_from(frame)?;
|
if let Some(packet) =
|
||||||
|
Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
|
||||||
use PacketType::*;
|
{
|
||||||
match packet.packet_type {
|
use PacketType::*;
|
||||||
Connect(_) => break Err(WispError::InvalidPacketType),
|
match packet.packet_type {
|
||||||
Data(data) => {
|
Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
|
||||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
Data(data) => {
|
||||||
let _ = stream.stream.unbounded_send(data);
|
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||||
}
|
let _ = stream.stream.send_async(data).await;
|
||||||
}
|
|
||||||
Continue(inner_packet) => {
|
|
||||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
|
||||||
if stream.stream_type == StreamType::Tcp {
|
|
||||||
stream
|
|
||||||
.flow_control
|
|
||||||
.store(inner_packet.buffer_remaining, Ordering::Release);
|
|
||||||
let _ = stream.flow_control_event.notify(u32::MAX);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
Continue(inner_packet) => {
|
||||||
Close(_) => {
|
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||||
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
|
if stream.stream_type == StreamType::Tcp {
|
||||||
stream.is_closed.store(true, Ordering::Release);
|
stream
|
||||||
stream.stream.disconnect();
|
.flow_control
|
||||||
stream.stream.close_channel();
|
.store(inner_packet.buffer_remaining, Ordering::Release);
|
||||||
|
let _ = stream.flow_control_event.notify(u32::MAX);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Close(_) => {
|
||||||
|
if packet.stream_id == 0 {
|
||||||
|
break Ok(());
|
||||||
|
}
|
||||||
|
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||||
|
stream.is_closed.store(true, Ordering::Release);
|
||||||
|
stream.is_closed_event.notify(usize::MAX);
|
||||||
|
drop(stream.stream)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -332,7 +450,7 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
|
||||||
/// ```
|
/// ```
|
||||||
/// use wisp_mux::ServerMux;
|
/// use wisp_mux::ServerMux;
|
||||||
///
|
///
|
||||||
/// let (mux, fut) = ServerMux::new(rx, tx, 128);
|
/// let (mux, fut) = ServerMux::new(rx, tx, 128, Some([]));
|
||||||
/// tokio::spawn(async move {
|
/// tokio::spawn(async move {
|
||||||
/// if let Err(e) = fut.await {
|
/// if let Err(e) = fut.await {
|
||||||
/// println!("error in multiplexor: {:?}", e);
|
/// println!("error in multiplexor: {:?}", e);
|
||||||
|
@ -346,39 +464,103 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
|
||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
pub struct ServerMux {
|
pub struct ServerMux {
|
||||||
|
/// Whether the connection was downgraded to Wisp v1.
|
||||||
|
///
|
||||||
|
/// If this variable is true you must assume no extensions are supported.
|
||||||
|
pub downgraded: bool,
|
||||||
|
/// Extensions that are supported by both sides.
|
||||||
|
pub supported_extension_ids: Vec<u8>,
|
||||||
close_tx: mpsc::Sender<WsEvent>,
|
close_tx: mpsc::Sender<WsEvent>,
|
||||||
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>,
|
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ServerMux {
|
impl ServerMux {
|
||||||
/// Create a new server-side multiplexor.
|
/// Create a new server-side multiplexor.
|
||||||
pub fn new<R, W: ws::WebSocketWrite>(
|
///
|
||||||
read: R,
|
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||||
|
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||||
|
/// if the extensions you need are available after the multiplexor has been created.
|
||||||
|
pub async fn new<R, W>(
|
||||||
|
mut read: R,
|
||||||
write: W,
|
write: W,
|
||||||
buffer_size: u32,
|
buffer_size: u32,
|
||||||
) -> (Self, impl Future<Output = Result<(), WispError>>)
|
extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
|
||||||
|
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError>
|
||||||
where
|
where
|
||||||
R: ws::WebSocketRead,
|
R: ws::WebSocketRead + Send,
|
||||||
|
W: ws::WebSocketWrite + Send + 'static,
|
||||||
{
|
{
|
||||||
let (close_tx, close_rx) = mpsc::channel::<WsEvent>(256);
|
let (close_tx, close_rx) = mpsc::bounded::<WsEvent>(256);
|
||||||
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
|
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
|
||||||
let write = ws::LockedWebSocketWrite::new(write);
|
let write = ws::LockedWebSocketWrite::new(Box::new(write));
|
||||||
(
|
|
||||||
|
write
|
||||||
|
.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut supported_extensions = Vec::new();
|
||||||
|
let mut extra_packet = Vec::with_capacity(1);
|
||||||
|
let mut downgraded = true;
|
||||||
|
|
||||||
|
if let Some(builders) = extension_builders {
|
||||||
|
let extensions: Vec<_> = builders
|
||||||
|
.iter()
|
||||||
|
.map(|x| x.build_to_extension(Role::Server))
|
||||||
|
.collect();
|
||||||
|
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
|
||||||
|
write
|
||||||
|
.write_frame(Packet::new_info(extensions).into())
|
||||||
|
.await?;
|
||||||
|
if let Some(frame) = select! {
|
||||||
|
x = read.wisp_read_frame(&write).fuse() => Some(x?),
|
||||||
|
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
||||||
|
} {
|
||||||
|
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
|
||||||
|
if let PacketType::Info(info) = packet.packet_type {
|
||||||
|
supported_extensions = info
|
||||||
|
.extensions
|
||||||
|
.into_iter()
|
||||||
|
.filter(|x| extension_ids.contains(&x.get_id()))
|
||||||
|
.collect();
|
||||||
|
downgraded = false;
|
||||||
|
} else {
|
||||||
|
extra_packet.push(packet.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((
|
||||||
Self {
|
Self {
|
||||||
muxstream_recv: rx,
|
muxstream_recv: rx,
|
||||||
close_tx: close_tx.clone(),
|
close_tx: close_tx.clone(),
|
||||||
|
downgraded,
|
||||||
|
supported_extension_ids: supported_extensions.iter().map(|x| x.get_id()).collect(),
|
||||||
},
|
},
|
||||||
MuxInner {
|
MuxInner {
|
||||||
tx: write,
|
tx: write,
|
||||||
stream_map: DashMap::new().into(),
|
stream_map: DashMap::new(),
|
||||||
|
buffer_size,
|
||||||
}
|
}
|
||||||
.server_into_future(read, close_rx, tx, buffer_size, close_tx),
|
.server_into_future(
|
||||||
)
|
AppendingWebSocketRead(extra_packet, read),
|
||||||
|
supported_extensions,
|
||||||
|
close_rx,
|
||||||
|
tx,
|
||||||
|
close_tx,
|
||||||
|
),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Wait for a stream to be created.
|
/// Wait for a stream to be created.
|
||||||
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> {
|
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> {
|
||||||
self.muxstream_recv.next().await
|
self.muxstream_recv.recv_async().await.ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
||||||
|
self.close_tx
|
||||||
|
.send_async(WsEvent::EndFut(reason))
|
||||||
|
.await
|
||||||
|
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Close all streams.
|
/// Close all streams.
|
||||||
|
@ -386,19 +568,32 @@ impl ServerMux {
|
||||||
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
|
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
|
||||||
/// this function is called.
|
/// this function is called.
|
||||||
pub async fn close(&mut self) -> Result<(), WispError> {
|
pub async fn close(&mut self) -> Result<(), WispError> {
|
||||||
self.close_tx
|
self.close_internal(None).await
|
||||||
.send(WsEvent::EndFut)
|
}
|
||||||
|
|
||||||
|
/// Close all streams and send an extension incompatibility error to the client.
|
||||||
|
///
|
||||||
|
/// Also terminates the multiplexor future. Waiting for a new stream will never succed after
|
||||||
|
/// this function is called.
|
||||||
|
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> {
|
||||||
|
self.close_internal(Some(CloseReason::IncompatibleExtensions))
|
||||||
.await
|
.await
|
||||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Drop for ServerMux {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let _ = self.close_tx.send(WsEvent::EndFut(None));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Client side multiplexor.
|
/// Client side multiplexor.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
/// ```
|
/// ```
|
||||||
/// use wisp_mux::{ClientMux, StreamType};
|
/// use wisp_mux::{ClientMux, StreamType};
|
||||||
///
|
///
|
||||||
/// let (mux, fut) = ClientMux::new(rx, tx).await?;
|
/// let (mux, fut) = ClientMux::new(rx, tx, Some([])).await?;
|
||||||
/// tokio::spawn(async move {
|
/// tokio::spawn(async move {
|
||||||
/// if let Err(e) = fut.await {
|
/// if let Err(e) = fut.await {
|
||||||
/// println!("error in multiplexor: {:?}", e);
|
/// println!("error in multiplexor: {:?}", e);
|
||||||
|
@ -406,50 +601,93 @@ impl ServerMux {
|
||||||
/// });
|
/// });
|
||||||
/// let stream = mux.client_new_stream(StreamType::Tcp, "google.com", 80);
|
/// let stream = mux.client_new_stream(StreamType::Tcp, "google.com", 80);
|
||||||
/// ```
|
/// ```
|
||||||
pub struct ClientMux<W>
|
pub struct ClientMux {
|
||||||
where
|
/// Whether the connection was downgraded to Wisp v1.
|
||||||
W: ws::WebSocketWrite,
|
///
|
||||||
{
|
/// If this variable is true you must assume no extensions are supported.
|
||||||
tx: ws::LockedWebSocketWrite<W>,
|
pub downgraded: bool,
|
||||||
stream_map: Arc<DashMap<u32, MuxMapValue>>,
|
/// Extensions that are supported by both sides.
|
||||||
next_free_stream_id: AtomicU32,
|
pub supported_extension_ids: Vec<u8>,
|
||||||
close_tx: mpsc::Sender<WsEvent>,
|
stream_tx: mpsc::Sender<WsEvent>,
|
||||||
buf_size: u32,
|
|
||||||
target_buf_size: u32,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<W: ws::WebSocketWrite> ClientMux<W> {
|
impl ClientMux {
|
||||||
/// Create a new client side multiplexor.
|
/// Create a new client side multiplexor.
|
||||||
pub async fn new<R>(
|
///
|
||||||
|
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||||
|
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||||
|
/// if the extensions you need are available after the multiplexor has been created.
|
||||||
|
pub async fn new<R, W>(
|
||||||
mut read: R,
|
mut read: R,
|
||||||
write: W,
|
write: W,
|
||||||
) -> Result<(Self, impl Future<Output = Result<(), WispError>>), WispError>
|
extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
|
||||||
|
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError>
|
||||||
where
|
where
|
||||||
R: ws::WebSocketRead,
|
R: ws::WebSocketRead + Send,
|
||||||
|
W: ws::WebSocketWrite + Send + 'static,
|
||||||
{
|
{
|
||||||
let write = ws::LockedWebSocketWrite::new(write);
|
let write = ws::LockedWebSocketWrite::new(Box::new(write));
|
||||||
let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?;
|
let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?;
|
||||||
if first_packet.stream_id != 0 {
|
if first_packet.stream_id != 0 {
|
||||||
return Err(WispError::InvalidStreamId);
|
return Err(WispError::InvalidStreamId);
|
||||||
}
|
}
|
||||||
if let PacketType::Continue(packet) = first_packet.packet_type {
|
if let PacketType::Continue(packet) = first_packet.packet_type {
|
||||||
let (tx, rx) = mpsc::channel::<WsEvent>(256);
|
let mut supported_extensions = Vec::new();
|
||||||
let map = Arc::new(DashMap::new());
|
let mut extra_packet = Vec::with_capacity(1);
|
||||||
|
let mut downgraded = true;
|
||||||
|
|
||||||
|
if let Some(builders) = extension_builders {
|
||||||
|
let extensions: Vec<_> = builders
|
||||||
|
.iter()
|
||||||
|
.map(|x| x.build_to_extension(Role::Client))
|
||||||
|
.collect();
|
||||||
|
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
|
||||||
|
if let Some(frame) = select! {
|
||||||
|
x = read.wisp_read_frame(&write).fuse() => Some(x?),
|
||||||
|
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
||||||
|
} {
|
||||||
|
let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?;
|
||||||
|
if let PacketType::Info(info) = packet.packet_type {
|
||||||
|
supported_extensions = info
|
||||||
|
.extensions
|
||||||
|
.into_iter()
|
||||||
|
.filter(|x| extension_ids.contains(&x.get_id()))
|
||||||
|
.collect();
|
||||||
|
write
|
||||||
|
.write_frame(Packet::new_info(extensions).into())
|
||||||
|
.await?;
|
||||||
|
downgraded = false;
|
||||||
|
} else {
|
||||||
|
extra_packet.push(packet.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for extension in supported_extensions.iter_mut() {
|
||||||
|
extension.handle_handshake(&mut read, &write).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let (tx, rx) = mpsc::bounded::<WsEvent>(256);
|
||||||
Ok((
|
Ok((
|
||||||
Self {
|
Self {
|
||||||
tx: write.clone(),
|
stream_tx: tx.clone(),
|
||||||
stream_map: map.clone(),
|
downgraded,
|
||||||
next_free_stream_id: AtomicU32::new(1),
|
supported_extension_ids: supported_extensions
|
||||||
close_tx: tx.clone(),
|
.iter()
|
||||||
buf_size: packet.buffer_remaining,
|
.map(|x| x.get_id())
|
||||||
// server-only
|
.collect(),
|
||||||
target_buf_size: 0,
|
|
||||||
},
|
},
|
||||||
MuxInner {
|
MuxInner {
|
||||||
tx: write.clone(),
|
tx: write,
|
||||||
stream_map: map.clone(),
|
stream_map: DashMap::new(),
|
||||||
|
buffer_size: packet.buffer_remaining,
|
||||||
}
|
}
|
||||||
.client_into_future(read, rx),
|
.client_into_future(
|
||||||
|
AppendingWebSocketRead(extra_packet, read),
|
||||||
|
supported_extensions,
|
||||||
|
rx,
|
||||||
|
tx,
|
||||||
|
),
|
||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
Err(WispError::InvalidPacketType)
|
Err(WispError::InvalidPacketType)
|
||||||
|
@ -458,51 +696,32 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
|
||||||
|
|
||||||
/// Create a new stream, multiplexed through Wisp.
|
/// Create a new stream, multiplexed through Wisp.
|
||||||
pub async fn client_new_stream(
|
pub async fn client_new_stream(
|
||||||
&self,
|
&mut self,
|
||||||
stream_type: StreamType,
|
stream_type: StreamType,
|
||||||
host: String,
|
host: String,
|
||||||
port: u16,
|
port: u16,
|
||||||
) -> Result<MuxStream, WispError> {
|
) -> Result<MuxStream, WispError> {
|
||||||
let (ch_tx, ch_rx) = mpsc::unbounded();
|
if stream_type == StreamType::Udp
|
||||||
let stream_id = self.next_free_stream_id.load(Ordering::Acquire);
|
&& !self
|
||||||
let next_stream_id = stream_id
|
.supported_extension_ids
|
||||||
.checked_add(1)
|
.iter()
|
||||||
.ok_or(WispError::MaxStreamCountReached)?;
|
.any(|x| *x == UdpProtocolExtension::ID)
|
||||||
|
{
|
||||||
|
return Err(WispError::UdpExtensionNotSupported);
|
||||||
|
}
|
||||||
|
let (tx, rx) = oneshot::channel();
|
||||||
|
self.stream_tx
|
||||||
|
.send_async(WsEvent::CreateStream(stream_type, host, port, tx))
|
||||||
|
.await
|
||||||
|
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||||
|
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
|
||||||
|
}
|
||||||
|
|
||||||
let flow_control_event: Arc<Event> = Event::new().into();
|
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
||||||
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buf_size).into();
|
self.stream_tx
|
||||||
|
.send_async(WsEvent::EndFut(reason))
|
||||||
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
.await
|
||||||
|
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||||
self.tx
|
|
||||||
.write_frame(Packet::new_connect(stream_id, stream_type, port, host).into())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
self.next_free_stream_id
|
|
||||||
.store(next_stream_id, Ordering::Release);
|
|
||||||
|
|
||||||
self.stream_map.insert(
|
|
||||||
stream_id,
|
|
||||||
MuxMapValue {
|
|
||||||
stream: ch_tx,
|
|
||||||
stream_type,
|
|
||||||
flow_control: flow_control.clone(),
|
|
||||||
flow_control_event: flow_control_event.clone(),
|
|
||||||
is_closed: is_closed.clone(),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(MuxStream::new(
|
|
||||||
stream_id,
|
|
||||||
Role::Client,
|
|
||||||
stream_type,
|
|
||||||
ch_rx,
|
|
||||||
self.close_tx.clone(),
|
|
||||||
is_closed,
|
|
||||||
flow_control,
|
|
||||||
flow_control_event,
|
|
||||||
self.target_buf_size,
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Close all streams.
|
/// Close all streams.
|
||||||
|
@ -510,9 +729,21 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
|
||||||
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
|
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
|
||||||
/// function.
|
/// function.
|
||||||
pub async fn close(&mut self) -> Result<(), WispError> {
|
pub async fn close(&mut self) -> Result<(), WispError> {
|
||||||
self.close_tx
|
self.close_internal(None).await
|
||||||
.send(WsEvent::EndFut)
|
}
|
||||||
|
|
||||||
|
/// Close all streams and send an extension incompatibility error to the client.
|
||||||
|
///
|
||||||
|
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
|
||||||
|
/// function.
|
||||||
|
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> {
|
||||||
|
self.close_internal(Some(CloseReason::IncompatibleExtensions))
|
||||||
.await
|
.await
|
||||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for ClientMux {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let _ = self.stream_tx.send(WsEvent::EndFut(None));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,23 +1,39 @@
|
||||||
use crate::{ws, WispError};
|
use crate::{
|
||||||
|
extensions::{AnyProtocolExtension, ProtocolExtensionBuilder},
|
||||||
|
ws::{self, Frame, LockedWebSocketWrite, OpCode, WebSocketRead},
|
||||||
|
Role, WispError, WISP_VERSION,
|
||||||
|
};
|
||||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||||
|
|
||||||
/// Wisp stream type.
|
/// Wisp stream type.
|
||||||
#[derive(Debug, PartialEq, Copy, Clone)]
|
#[derive(Debug, PartialEq, Copy, Clone)]
|
||||||
pub enum StreamType {
|
pub enum StreamType {
|
||||||
/// TCP Wisp stream.
|
/// TCP Wisp stream.
|
||||||
Tcp = 0x01,
|
Tcp,
|
||||||
/// UDP Wisp stream.
|
/// UDP Wisp stream.
|
||||||
Udp = 0x02,
|
Udp,
|
||||||
|
/// Unknown Wisp stream type used for custom streams by protocol extensions.
|
||||||
|
Unknown(u8),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<u8> for StreamType {
|
impl From<u8> for StreamType {
|
||||||
type Error = WispError;
|
fn from(value: u8) -> Self {
|
||||||
fn try_from(stream_type: u8) -> Result<Self, Self::Error> {
|
use StreamType as S;
|
||||||
use StreamType::*;
|
match value {
|
||||||
match stream_type {
|
0x01 => S::Tcp,
|
||||||
0x01 => Ok(Tcp),
|
0x02 => S::Udp,
|
||||||
0x02 => Ok(Udp),
|
x => S::Unknown(x),
|
||||||
_ => Err(Self::Error::InvalidStreamType),
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StreamType> for u8 {
|
||||||
|
fn from(value: StreamType) -> Self {
|
||||||
|
use StreamType as S;
|
||||||
|
match value {
|
||||||
|
S::Tcp => 0x01,
|
||||||
|
S::Udp => 0x02,
|
||||||
|
S::Unknown(x) => x,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -34,6 +50,8 @@ pub enum CloseReason {
|
||||||
Voluntary = 0x02,
|
Voluntary = 0x02,
|
||||||
/// Unexpected stream closure due to a network error.
|
/// Unexpected stream closure due to a network error.
|
||||||
Unexpected = 0x03,
|
Unexpected = 0x03,
|
||||||
|
/// Incompatible extensions. Only used during the handshake.
|
||||||
|
IncompatibleExtensions = 0x04,
|
||||||
/// Stream creation failed due to invalid information.
|
/// Stream creation failed due to invalid information.
|
||||||
ServerStreamInvalidInfo = 0x41,
|
ServerStreamInvalidInfo = 0x41,
|
||||||
/// Stream creation failed due to an unreachable destination host.
|
/// Stream creation failed due to an unreachable destination host.
|
||||||
|
@ -54,21 +72,22 @@ pub enum CloseReason {
|
||||||
|
|
||||||
impl TryFrom<u8> for CloseReason {
|
impl TryFrom<u8> for CloseReason {
|
||||||
type Error = WispError;
|
type Error = WispError;
|
||||||
fn try_from(stream_type: u8) -> Result<Self, Self::Error> {
|
fn try_from(close_reason: u8) -> Result<Self, Self::Error> {
|
||||||
use CloseReason::*;
|
use CloseReason as R;
|
||||||
match stream_type {
|
match close_reason {
|
||||||
0x01 => Ok(Unknown),
|
0x01 => Ok(R::Unknown),
|
||||||
0x02 => Ok(Voluntary),
|
0x02 => Ok(R::Voluntary),
|
||||||
0x03 => Ok(Unexpected),
|
0x03 => Ok(R::Unexpected),
|
||||||
0x41 => Ok(ServerStreamInvalidInfo),
|
0x04 => Ok(R::IncompatibleExtensions),
|
||||||
0x42 => Ok(ServerStreamUnreachable),
|
0x41 => Ok(R::ServerStreamInvalidInfo),
|
||||||
0x43 => Ok(ServerStreamConnectionTimedOut),
|
0x42 => Ok(R::ServerStreamUnreachable),
|
||||||
0x44 => Ok(ServerStreamConnectionRefused),
|
0x43 => Ok(R::ServerStreamConnectionTimedOut),
|
||||||
0x47 => Ok(ServerStreamTimedOut),
|
0x44 => Ok(R::ServerStreamConnectionRefused),
|
||||||
0x48 => Ok(ServerStreamBlockedAddress),
|
0x47 => Ok(R::ServerStreamTimedOut),
|
||||||
0x49 => Ok(ServerStreamThrottled),
|
0x48 => Ok(R::ServerStreamBlockedAddress),
|
||||||
0x81 => Ok(ClientUnexpected),
|
0x49 => Ok(R::ServerStreamThrottled),
|
||||||
_ => Err(Self::Error::InvalidStreamType),
|
0x81 => Ok(R::ClientUnexpected),
|
||||||
|
_ => Err(Self::Error::InvalidCloseReason),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -108,7 +127,7 @@ impl TryFrom<Bytes> for ConnectPacket {
|
||||||
return Err(Self::Error::PacketTooSmall);
|
return Err(Self::Error::PacketTooSmall);
|
||||||
}
|
}
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
stream_type: bytes.get_u8().try_into()?,
|
stream_type: bytes.get_u8().into(),
|
||||||
destination_port: bytes.get_u16_le(),
|
destination_port: bytes.get_u16_le(),
|
||||||
destination_hostname: std::str::from_utf8(&bytes)?.to_string(),
|
destination_hostname: std::str::from_utf8(&bytes)?.to_string(),
|
||||||
})
|
})
|
||||||
|
@ -118,7 +137,7 @@ impl TryFrom<Bytes> for ConnectPacket {
|
||||||
impl From<ConnectPacket> for Bytes {
|
impl From<ConnectPacket> for Bytes {
|
||||||
fn from(packet: ConnectPacket) -> Self {
|
fn from(packet: ConnectPacket) -> Self {
|
||||||
let mut encoded = BytesMut::with_capacity(1 + 2 + packet.destination_hostname.len());
|
let mut encoded = BytesMut::with_capacity(1 + 2 + packet.destination_hostname.len());
|
||||||
encoded.put_u8(packet.stream_type as u8);
|
encoded.put_u8(packet.stream_type.into());
|
||||||
encoded.put_u16_le(packet.destination_port);
|
encoded.put_u16_le(packet.destination_port);
|
||||||
encoded.extend(packet.destination_hostname.bytes());
|
encoded.extend(packet.destination_hostname.bytes());
|
||||||
encoded.freeze()
|
encoded.freeze()
|
||||||
|
@ -198,6 +217,38 @@ impl From<ClosePacket> for Bytes {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Wisp version sent in the handshake.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct WispVersion {
|
||||||
|
/// Major Wisp version according to semver.
|
||||||
|
pub major: u8,
|
||||||
|
/// Minor Wisp version according to semver.
|
||||||
|
pub minor: u8,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Packet used in the initial handshake.
|
||||||
|
///
|
||||||
|
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x05---info)
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct InfoPacket {
|
||||||
|
/// Wisp version sent in the packet.
|
||||||
|
pub version: WispVersion,
|
||||||
|
/// List of protocol extensions sent in the packet.
|
||||||
|
pub extensions: Vec<AnyProtocolExtension>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<InfoPacket> for Bytes {
|
||||||
|
fn from(value: InfoPacket) -> Self {
|
||||||
|
let mut bytes = BytesMut::with_capacity(2);
|
||||||
|
bytes.put_u8(value.version.major);
|
||||||
|
bytes.put_u8(value.version.minor);
|
||||||
|
for extension in value.extensions {
|
||||||
|
bytes.extend(Bytes::from(extension));
|
||||||
|
}
|
||||||
|
bytes.freeze()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
/// Type of packet recieved.
|
/// Type of packet recieved.
|
||||||
pub enum PacketType {
|
pub enum PacketType {
|
||||||
|
@ -209,29 +260,33 @@ pub enum PacketType {
|
||||||
Continue(ContinuePacket),
|
Continue(ContinuePacket),
|
||||||
/// Close packet.
|
/// Close packet.
|
||||||
Close(ClosePacket),
|
Close(ClosePacket),
|
||||||
|
/// Info packet.
|
||||||
|
Info(InfoPacket),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PacketType {
|
impl PacketType {
|
||||||
/// Get the packet type used in the protocol.
|
/// Get the packet type used in the protocol.
|
||||||
pub fn as_u8(&self) -> u8 {
|
pub fn as_u8(&self) -> u8 {
|
||||||
use PacketType::*;
|
use PacketType as P;
|
||||||
match self {
|
match self {
|
||||||
Connect(_) => 0x01,
|
P::Connect(_) => 0x01,
|
||||||
Data(_) => 0x02,
|
P::Data(_) => 0x02,
|
||||||
Continue(_) => 0x03,
|
P::Continue(_) => 0x03,
|
||||||
Close(_) => 0x04,
|
P::Close(_) => 0x04,
|
||||||
|
P::Info(_) => 0x05,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<PacketType> for Bytes {
|
impl From<PacketType> for Bytes {
|
||||||
fn from(packet: PacketType) -> Self {
|
fn from(packet: PacketType) -> Self {
|
||||||
use PacketType::*;
|
use PacketType as P;
|
||||||
match packet {
|
match packet {
|
||||||
Connect(x) => x.into(),
|
P::Connect(x) => x.into(),
|
||||||
Data(x) => x,
|
P::Data(x) => x,
|
||||||
Continue(x) => x.into(),
|
P::Continue(x) => x.into(),
|
||||||
Close(x) => x.into(),
|
P::Close(x) => x.into(),
|
||||||
|
P::Info(x) => x.into(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -296,26 +351,141 @@ impl Packet {
|
||||||
packet_type: PacketType::Close(ClosePacket::new(reason)),
|
packet_type: PacketType::Close(ClosePacket::new(reason)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn new_info(extensions: Vec<AnyProtocolExtension>) -> Self {
|
||||||
|
Self {
|
||||||
|
stream_id: 0,
|
||||||
|
packet_type: PacketType::Info(InfoPacket {
|
||||||
|
version: WISP_VERSION,
|
||||||
|
extensions,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result<Self, WispError> {
|
||||||
|
use PacketType as P;
|
||||||
|
Ok(Self {
|
||||||
|
stream_id: bytes.get_u32_le(),
|
||||||
|
packet_type: match packet_type {
|
||||||
|
0x01 => P::Connect(ConnectPacket::try_from(bytes)?),
|
||||||
|
0x02 => P::Data(bytes),
|
||||||
|
0x03 => P::Continue(ContinuePacket::try_from(bytes)?),
|
||||||
|
0x04 => P::Close(ClosePacket::try_from(bytes)?),
|
||||||
|
// 0x05 is handled seperately
|
||||||
|
_ => return Err(WispError::InvalidPacketType),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn maybe_parse_info(
|
||||||
|
frame: Frame,
|
||||||
|
role: Role,
|
||||||
|
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
||||||
|
) -> Result<Self, WispError> {
|
||||||
|
if !frame.finished {
|
||||||
|
return Err(WispError::WsFrameNotFinished);
|
||||||
|
}
|
||||||
|
if frame.opcode != OpCode::Binary {
|
||||||
|
return Err(WispError::WsFrameInvalidType);
|
||||||
|
}
|
||||||
|
let mut bytes = frame.payload;
|
||||||
|
if bytes.remaining() < 1 {
|
||||||
|
return Err(WispError::PacketTooSmall);
|
||||||
|
}
|
||||||
|
let packet_type = bytes.get_u8();
|
||||||
|
if packet_type == 0x05 {
|
||||||
|
Self::parse_info(bytes, role, extension_builders)
|
||||||
|
} else {
|
||||||
|
Self::parse_packet(packet_type, bytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn maybe_handle_extension(
|
||||||
|
frame: Frame,
|
||||||
|
extensions: &mut [AnyProtocolExtension],
|
||||||
|
read: &mut (dyn WebSocketRead + Send),
|
||||||
|
write: &LockedWebSocketWrite,
|
||||||
|
) -> Result<Option<Self>, WispError> {
|
||||||
|
if !frame.finished {
|
||||||
|
return Err(WispError::WsFrameNotFinished);
|
||||||
|
}
|
||||||
|
if frame.opcode != OpCode::Binary {
|
||||||
|
return Err(WispError::WsFrameInvalidType);
|
||||||
|
}
|
||||||
|
let mut bytes = frame.payload;
|
||||||
|
if bytes.remaining() < 1 {
|
||||||
|
return Err(WispError::PacketTooSmall);
|
||||||
|
}
|
||||||
|
let packet_type = bytes.get_u8();
|
||||||
|
if let Some(extension) = extensions
|
||||||
|
.iter_mut()
|
||||||
|
.find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
|
||||||
|
{
|
||||||
|
extension.handle_packet(bytes, read, write).await?;
|
||||||
|
Ok(None)
|
||||||
|
} else {
|
||||||
|
Ok(Some(Self::parse_packet(packet_type, bytes)?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_info(
|
||||||
|
mut bytes: Bytes,
|
||||||
|
role: Role,
|
||||||
|
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
||||||
|
) -> Result<Self, WispError> {
|
||||||
|
// packet type is already read by code that calls this
|
||||||
|
if bytes.remaining() < 4 + 2 {
|
||||||
|
return Err(WispError::PacketTooSmall);
|
||||||
|
}
|
||||||
|
if bytes.get_u32_le() != 0 {
|
||||||
|
return Err(WispError::InvalidStreamId);
|
||||||
|
}
|
||||||
|
|
||||||
|
let version = WispVersion {
|
||||||
|
major: bytes.get_u8(),
|
||||||
|
minor: bytes.get_u8(),
|
||||||
|
};
|
||||||
|
|
||||||
|
if version.major != WISP_VERSION.major {
|
||||||
|
return Err(WispError::IncompatibleProtocolVersion);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut extensions = Vec::new();
|
||||||
|
|
||||||
|
while bytes.remaining() > 4 {
|
||||||
|
// We have some extensions
|
||||||
|
let id = bytes.get_u8();
|
||||||
|
let length = usize::try_from(bytes.get_u32_le())?;
|
||||||
|
if bytes.remaining() < length {
|
||||||
|
return Err(WispError::PacketTooSmall);
|
||||||
|
}
|
||||||
|
if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) {
|
||||||
|
if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role) {
|
||||||
|
extensions.push(extension)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bytes.advance(length)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
stream_id: 0,
|
||||||
|
packet_type: PacketType::Info(InfoPacket {
|
||||||
|
version,
|
||||||
|
extensions,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<Bytes> for Packet {
|
impl TryFrom<Bytes> for Packet {
|
||||||
type Error = WispError;
|
type Error = WispError;
|
||||||
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
|
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
|
||||||
if bytes.remaining() < 5 {
|
if bytes.remaining() < 1 {
|
||||||
return Err(Self::Error::PacketTooSmall);
|
return Err(Self::Error::PacketTooSmall);
|
||||||
}
|
}
|
||||||
let packet_type = bytes.get_u8();
|
let packet_type = bytes.get_u8();
|
||||||
use PacketType::*;
|
Self::parse_packet(packet_type, bytes)
|
||||||
Ok(Self {
|
|
||||||
stream_id: bytes.get_u32_le(),
|
|
||||||
packet_type: match packet_type {
|
|
||||||
0x01 => Connect(ConnectPacket::try_from(bytes)?),
|
|
||||||
0x02 => Data(bytes),
|
|
||||||
0x03 => Continue(ContinuePacket::try_from(bytes)?),
|
|
||||||
0x04 => Close(ClosePacket::try_from(bytes)?),
|
|
||||||
_ => return Err(Self::Error::InvalidPacketType),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
//! futures sink unfold with a close function
|
//! futures sink unfold with a close function
|
||||||
use core::{future::Future, pin::Pin};
|
use core::{future::Future, pin::Pin};
|
||||||
use futures::ready;
|
use futures::{
|
||||||
use futures::task::{Context, Poll};
|
ready,
|
||||||
use futures::Sink;
|
task::{Context, Poll},
|
||||||
|
Sink,
|
||||||
|
};
|
||||||
use pin_project_lite::pin_project;
|
use pin_project_lite::pin_project;
|
||||||
|
|
||||||
pin_project! {
|
pin_project! {
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
use crate::{sink_unfold, CloseReason, Packet, Role, StreamType, WispError};
|
use crate::{sink_unfold, CloseReason, Packet, Role, StreamType, WispError};
|
||||||
|
|
||||||
use async_io_stream::IoStream;
|
pub use async_io_stream::IoStream;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use event_listener::Event;
|
use event_listener::Event;
|
||||||
|
use flume as mpsc;
|
||||||
use futures::{
|
use futures::{
|
||||||
channel::{mpsc, oneshot},
|
channel::oneshot,
|
||||||
stream,
|
select, stream,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
Sink, SinkExt, Stream, StreamExt,
|
FutureExt, Sink, Stream,
|
||||||
};
|
};
|
||||||
use pin_project_lite::pin_project;
|
use pin_project_lite::pin_project;
|
||||||
use std::{
|
use std::{
|
||||||
|
@ -21,7 +22,13 @@ use std::{
|
||||||
pub(crate) enum WsEvent {
|
pub(crate) enum WsEvent {
|
||||||
SendPacket(Packet, oneshot::Sender<Result<(), WispError>>),
|
SendPacket(Packet, oneshot::Sender<Result<(), WispError>>),
|
||||||
Close(Packet, oneshot::Sender<Result<(), WispError>>),
|
Close(Packet, oneshot::Sender<Result<(), WispError>>),
|
||||||
EndFut,
|
CreateStream(
|
||||||
|
StreamType,
|
||||||
|
String,
|
||||||
|
u16,
|
||||||
|
oneshot::Sender<Result<MuxStream, WispError>>,
|
||||||
|
),
|
||||||
|
EndFut(Option<CloseReason>),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Read side of a multiplexor stream.
|
/// Read side of a multiplexor stream.
|
||||||
|
@ -32,8 +39,9 @@ pub struct MuxStreamRead {
|
||||||
pub stream_type: StreamType,
|
pub stream_type: StreamType,
|
||||||
role: Role,
|
role: Role,
|
||||||
tx: mpsc::Sender<WsEvent>,
|
tx: mpsc::Sender<WsEvent>,
|
||||||
rx: mpsc::UnboundedReceiver<Bytes>,
|
rx: mpsc::Receiver<Bytes>,
|
||||||
is_closed: Arc<AtomicBool>,
|
is_closed: Arc<AtomicBool>,
|
||||||
|
is_closed_event: Arc<Event>,
|
||||||
flow_control: Arc<AtomicU32>,
|
flow_control: Arc<AtomicU32>,
|
||||||
flow_control_read: AtomicU32,
|
flow_control_read: AtomicU32,
|
||||||
target_flow_control: u32,
|
target_flow_control: u32,
|
||||||
|
@ -45,13 +53,16 @@ impl MuxStreamRead {
|
||||||
if self.is_closed.load(Ordering::Acquire) {
|
if self.is_closed.load(Ordering::Acquire) {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
let bytes = self.rx.next().await?;
|
let bytes = select! {
|
||||||
|
x = self.rx.recv_async() => x.ok()?,
|
||||||
|
_ = self.is_closed_event.listen().fuse() => return None
|
||||||
|
};
|
||||||
if self.role == Role::Server && self.stream_type == StreamType::Tcp {
|
if self.role == Role::Server && self.stream_type == StreamType::Tcp {
|
||||||
let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1;
|
let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1;
|
||||||
if val > self.target_flow_control {
|
if val > self.target_flow_control {
|
||||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
||||||
self.tx
|
self.tx
|
||||||
.send(WsEvent::SendPacket(
|
.send_async(WsEvent::SendPacket(
|
||||||
Packet::new_continue(
|
Packet::new_continue(
|
||||||
self.stream_id,
|
self.stream_id,
|
||||||
self.flow_control.fetch_add(val, Ordering::AcqRel) + val,
|
self.flow_control.fetch_add(val, Ordering::AcqRel) + val,
|
||||||
|
@ -101,13 +112,13 @@ impl MuxStreamWrite {
|
||||||
}
|
}
|
||||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
||||||
self.tx
|
self.tx
|
||||||
.send(WsEvent::SendPacket(
|
.send_async(WsEvent::SendPacket(
|
||||||
Packet::new_data(self.stream_id, data),
|
Packet::new_data(self.stream_id, data),
|
||||||
tx,
|
tx,
|
||||||
))
|
))
|
||||||
.await
|
.await
|
||||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||||
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
|
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
|
||||||
if self.role == Role::Client && self.stream_type == StreamType::Tcp {
|
if self.role == Role::Client && self.stream_type == StreamType::Tcp {
|
||||||
self.flow_control.store(
|
self.flow_control.store(
|
||||||
self.flow_control.load(Ordering::Acquire).saturating_sub(1),
|
self.flow_control.load(Ordering::Acquire).saturating_sub(1),
|
||||||
|
@ -145,13 +156,13 @@ impl MuxStreamWrite {
|
||||||
|
|
||||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
||||||
self.tx
|
self.tx
|
||||||
.send(WsEvent::Close(
|
.send_async(WsEvent::Close(
|
||||||
Packet::new_close(self.stream_id, reason),
|
Packet::new_close(self.stream_id, reason),
|
||||||
tx,
|
tx,
|
||||||
))
|
))
|
||||||
.await
|
.await
|
||||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||||
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
|
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -173,6 +184,19 @@ impl MuxStreamWrite {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Drop for MuxStreamWrite {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if !self.is_closed.load(Ordering::Acquire) {
|
||||||
|
self.is_closed.store(true, Ordering::Release);
|
||||||
|
let (tx, _) = oneshot::channel();
|
||||||
|
let _ = self.tx.send(WsEvent::Close(
|
||||||
|
Packet::new_close(self.stream_id, CloseReason::Unknown),
|
||||||
|
tx,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Multiplexor stream.
|
/// Multiplexor stream.
|
||||||
pub struct MuxStream {
|
pub struct MuxStream {
|
||||||
/// ID of the stream.
|
/// ID of the stream.
|
||||||
|
@ -187,9 +211,10 @@ impl MuxStream {
|
||||||
stream_id: u32,
|
stream_id: u32,
|
||||||
role: Role,
|
role: Role,
|
||||||
stream_type: StreamType,
|
stream_type: StreamType,
|
||||||
rx: mpsc::UnboundedReceiver<Bytes>,
|
rx: mpsc::Receiver<Bytes>,
|
||||||
tx: mpsc::Sender<WsEvent>,
|
tx: mpsc::Sender<WsEvent>,
|
||||||
is_closed: Arc<AtomicBool>,
|
is_closed: Arc<AtomicBool>,
|
||||||
|
is_closed_event: Arc<Event>,
|
||||||
flow_control: Arc<AtomicU32>,
|
flow_control: Arc<AtomicU32>,
|
||||||
continue_recieved: Arc<Event>,
|
continue_recieved: Arc<Event>,
|
||||||
target_flow_control: u32,
|
target_flow_control: u32,
|
||||||
|
@ -203,6 +228,7 @@ impl MuxStream {
|
||||||
tx: tx.clone(),
|
tx: tx.clone(),
|
||||||
rx,
|
rx,
|
||||||
is_closed: is_closed.clone(),
|
is_closed: is_closed.clone(),
|
||||||
|
is_closed_event: is_closed_event.clone(),
|
||||||
flow_control: flow_control.clone(),
|
flow_control: flow_control.clone(),
|
||||||
flow_control_read: AtomicU32::new(0),
|
flow_control_read: AtomicU32::new(0),
|
||||||
target_flow_control,
|
target_flow_control,
|
||||||
|
@ -282,13 +308,13 @@ impl MuxStreamCloser {
|
||||||
|
|
||||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
||||||
self.close_channel
|
self.close_channel
|
||||||
.send(WsEvent::Close(
|
.send_async(WsEvent::Close(
|
||||||
Packet::new_close(self.stream_id, reason),
|
Packet::new_close(self.stream_id, reason),
|
||||||
tx,
|
tx,
|
||||||
))
|
))
|
||||||
.await
|
.await
|
||||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||||
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
|
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -317,7 +343,10 @@ impl MuxStreamIo {
|
||||||
impl Stream for MuxStreamIo {
|
impl Stream for MuxStreamIo {
|
||||||
type Item = Result<Vec<u8>, std::io::Error>;
|
type Item = Result<Vec<u8>, std::io::Error>;
|
||||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
self.project().rx.poll_next(cx).map(|x| x.map(|x| Ok(x.to_vec())))
|
self.project()
|
||||||
|
.rx
|
||||||
|
.poll_next(cx)
|
||||||
|
.map(|x| x.map(|x| Ok(x.to_vec())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,9 +4,10 @@
|
||||||
//! for other WebSocket implementations.
|
//! for other WebSocket implementations.
|
||||||
//!
|
//!
|
||||||
//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs
|
//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs
|
||||||
|
use crate::WispError;
|
||||||
|
use async_trait::async_trait;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use futures::lock::Mutex;
|
use futures::lock::Mutex;
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
/// Opcode of the WebSocket frame.
|
/// Opcode of the WebSocket frame.
|
||||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||||
|
@ -64,40 +65,55 @@ impl Frame {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generic WebSocket read trait.
|
/// Generic WebSocket read trait.
|
||||||
|
#[async_trait]
|
||||||
pub trait WebSocketRead {
|
pub trait WebSocketRead {
|
||||||
/// Read a frame from the socket.
|
/// Read a frame from the socket.
|
||||||
fn wisp_read_frame(
|
async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result<Frame, WispError>;
|
||||||
&mut self,
|
|
||||||
tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
|
|
||||||
) -> impl std::future::Future<Output = Result<Frame, crate::WispError>>;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generic WebSocket write trait.
|
/// Generic WebSocket write trait.
|
||||||
|
#[async_trait]
|
||||||
pub trait WebSocketWrite {
|
pub trait WebSocketWrite {
|
||||||
/// Write a frame to the socket.
|
/// Write a frame to the socket.
|
||||||
fn wisp_write_frame(
|
async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError>;
|
||||||
&mut self,
|
|
||||||
frame: Frame,
|
/// Close the socket.
|
||||||
) -> impl std::future::Future<Output = Result<(), crate::WispError>>;
|
async fn wisp_close(&mut self) -> Result<(), WispError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Locked WebSocket that can be shared between threads.
|
/// Locked WebSocket.
|
||||||
pub struct LockedWebSocketWrite<S>(Arc<Mutex<S>>);
|
pub struct LockedWebSocketWrite(Mutex<Box<dyn WebSocketWrite + Send>>);
|
||||||
|
|
||||||
impl<S: WebSocketWrite> LockedWebSocketWrite<S> {
|
impl LockedWebSocketWrite {
|
||||||
/// Create a new locked websocket.
|
/// Create a new locked websocket.
|
||||||
pub fn new(ws: S) -> Self {
|
pub fn new(ws: Box<dyn WebSocketWrite + Send>) -> Self {
|
||||||
Self(Arc::new(Mutex::new(ws)))
|
Self(Mutex::new(ws))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Write a frame to the websocket.
|
/// Write a frame to the websocket.
|
||||||
pub async fn write_frame(&self, frame: Frame) -> Result<(), crate::WispError> {
|
pub async fn write_frame(&self, frame: Frame) -> Result<(), WispError> {
|
||||||
self.0.lock().await.wisp_write_frame(frame).await
|
self.0.lock().await.wisp_write_frame(frame).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Close the websocket.
|
||||||
|
pub async fn close(&self) -> Result<(), WispError> {
|
||||||
|
self.0.lock().await.wisp_close().await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: WebSocketWrite> Clone for LockedWebSocketWrite<S> {
|
pub(crate) struct AppendingWebSocketRead<R>(pub Vec<Frame>, pub R)
|
||||||
fn clone(&self) -> Self {
|
where
|
||||||
Self(self.0.clone())
|
R: WebSocketRead + Send;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<R> WebSocketRead for AppendingWebSocketRead<R>
|
||||||
|
where
|
||||||
|
R: WebSocketRead + Send,
|
||||||
|
{
|
||||||
|
async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result<Frame, WispError> {
|
||||||
|
if let Some(x) = self.0.pop() {
|
||||||
|
return Ok(x);
|
||||||
|
}
|
||||||
|
return self.1.wisp_read_frame(tx).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue