Merge pull request #1 from MercuryWorkshop/wisp

Wisp
This commit is contained in:
Toshit 2024-02-07 16:50:51 -08:00 committed by GitHub
commit 21d847fb56
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
29 changed files with 2701 additions and 746 deletions

11
.gitignore vendored
View file

@ -1,9 +1,12 @@
/target
server/src/*.pem
**/*.pem
client/pkg
client/out
.direnv
client/index.js
client/module.js
client/module.d.ts
client/epoxy-bundled.js
client/epoxy-module-bundled.js
client/epoxy-module-bundled.d.ts
client/epoxy.js
client/epoxy.d.ts
client/epoxy.wasm
pnpm-lock.yaml

764
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,11 @@
[workspace]
resolver = "2"
members = ["server", "client"]
members = ["server", "client", "wisp", "simple-wisp-client"]
[patch.crates-io]
rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" }
[profile.release]
lto = true
opt-level = 'z'
codegen-units = 1

View file

@ -1,29 +1,55 @@
# epoxy
Epoxy is an encrypted proxy for browser javascript. It allows you to make requests that bypass cors without compromising security, by running SSL/TLS inside webassembly.
Simple usage example for making a secure GET request to httpbin.org:
## Using the client
Epoxy must be run from within a web worker and must be served with the [security headers needed for `SharedArrayBuffer`](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/SharedArrayBuffer#security_requirements). Here is a simple usage example:
```javascript
import epoxy from "@mercuryworkshop/epoxy-tls";
importScripts("epoxy-bundled.js");
const { EpoxyClient } = await epoxy();
let client = await new EpoxyClient("wss://localhost:4000", navigator.userAgent, 10);
let response = await client.fetch("https://httpbin.org/get");
await response.text();
```
Epoxy also allows you to make arbitrary end to end encrypted TCP connections safely directly from the browser.
## Using the server
```
$ cargo r -r --bin epoxy-server -- --help
Implementation of the Wisp protocol in Rust, made for epoxy.
Usage: epoxy-server [OPTIONS] --pubkey <PUBKEY> --privkey <PRIVKEY>
Options:
--prefix <PREFIX> [default: ]
-l, --port <PORT> [default: 4000]
-p, --pubkey <PUBKEY>
-P, --privkey <PRIVKEY>
-h, --help Print help
-V, --version Print version
```
## Building
Rust nightly is required.
### Server
1. Generate certs with `mkcert` and place the public certificate in `./server/src/pem.pem` and private certificate in `./server/src/key.pem`
2. Run `cargo r --bin epoxy-server`, optionally with `-r` flag for release
```
cargo b -r --bin epoxy-server
```
The executable will be placed at `target/release/epoxy-server`.
### Client
Note: Building the client is only supported on linux
> [!IMPORTANT]
> Building the client is only supported on Linux.
1. Make sure you have the `wasm32-unknown-unknown` target installed, `wasm-bindgen` and `wasm-opt` executables installed, and `bash`, `python3` packages (`python3` is used for `http.server` module)
2. Run `pnpm build`
Make sure you have the `wasm32-unknown-unknown` rust target, the `rust-std` component, and the `wasm-bindgen`, `wasm-opt`, and `base64` binaries installed.
In the `client` directory:
```
bash build.sh
```
To host a local server with the required headers:
```
python3 serve.py
```

View file

@ -6,18 +6,12 @@ edition = "2021"
[lib]
crate-type = ["cdylib"]
[features]
default = ["console_error_panic_hook"]
[dependencies]
bytes = "1.5.0"
console_error_panic_hook = { version = "0.1.7", optional = true }
http = "1.0.0"
http-body-util = "0.1.0"
hyper = { version = "1.1.0", features = ["client", "http1"] }
hyper = { version = "1.1.0", features = ["client", "http1", "http2"] }
pin-project-lite = "0.2.13"
penguin-mux-wasm = { git = "https://github.com/r58Playz/penguin-mux-wasm" }
tokio = { version = "1.35.1", default_features = false }
wasm-bindgen = "0.2"
wasm-bindgen-futures = "0.4.39"
ws_stream_wasm = { version = "0.7.4", features = ["tokio_io"] }
@ -25,17 +19,19 @@ futures-util = "0.3.30"
js-sys = "0.3.66"
webpki-roots = "0.26.0"
tokio-rustls = "0.25.0"
web-sys = { version = "0.3.66", features = ["TextEncoder", "Navigator", "Response", "ResponseInit"] }
web-sys = { version = "0.3.66", features = ["TextEncoder", "Response", "ResponseInit"] }
wasm-streams = "0.4.0"
either = "1.9.0"
tokio-util = { version = "0.7.10", features = ["io"] }
async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] }
fastwebsockets = { version = "0.6.0", features = ["simdutf8", "unstable-split"] }
rand = "0.8.5"
fastwebsockets = { version = "0.6.0", features = ["unstable-split"] }
base64 = "0.21.7"
[dependencies.getrandom]
features = ["js"]
wisp-mux = { path = "../wisp", features = ["ws_stream_wasm", "tokio_io", "hyper_tower"] }
async_io_stream = { version = "0.3.3", features = ["tokio_io"] }
getrandom = { version = "0.2.12", features = ["js"] }
hyper-util = { git = "https://github.com/r58Playz/hyper-util-wasm", features = ["client", "client-legacy", "http1", "http2"] }
tokio = { version = "1.36.0", default-features = false }
tower-service = "0.3.2"
console_error_panic_hook = "0.1.7"
[dependencies.ring]
features = ["wasm32_unknown_unknown_js"]

View file

@ -5,27 +5,33 @@ shopt -s inherit_errexit
rm -rf out/ || true
mkdir out/
cargo build --target wasm32-unknown-unknown --release
echo "[ws] built rust"
RUSTFLAGS='-C target-feature=+atomics,+bulk-memory' cargo build --target wasm32-unknown-unknown -Z build-std=panic_abort,std --release
echo "[ws] cargo finished"
wasm-bindgen --weak-refs --target no-modules --no-modules-global epoxy --out-dir out/ ../target/wasm32-unknown-unknown/release/epoxy_client.wasm
echo "[ws] bindgen finished"
echo "[ws] wasm-bindgen finished"
mv out/epoxy_client_bg.wasm out/epoxy_client_unoptimized.wasm
time wasm-opt -O4 out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm
echo "[ws] optimized"
time wasm-opt -Oz --vacuum --dce --enable-threads --enable-bulk-memory --enable-simd out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm
echo "[ws] wasm-opt finished"
AUTOGENERATED_SOURCE=$(<"out/epoxy_client.js")
# patch for websocket sharedarraybuffer error
AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//getObject(arg0).send(getArrayU8FromWasm0(arg1, arg2)/getObject(arg0).send(new Uint8Array(getArrayU8FromWasm0(arg1, arg2))}
WASM_BASE64=$(base64 -w0 out/epoxy_client_bg.wasm)
AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//__wbg_init(input) \{/__wbg_init() \{let input=\'data:application/wasm;base64,$WASM_BASE64\'}
AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//__wbg_init(input, maybe_memory) \{/__wbg_init(input, maybe_memory) \{$'\n'if (!input) \{input=\'data:application/wasm;base64,$WASM_BASE64\'\}}
AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//return __wbg_finalize_init\(instance\, module\);/__wbg_finalize_init\(instance\, module\); return epoxy}
echo "$AUTOGENERATED_SOURCE" > index.js
cp index.js module.js
echo "module.exports = epoxy" >> module.js
echo "$AUTOGENERATED_SOURCE" > epoxy-bundled.js
cp epoxy-bundled.js epoxy-module-bundled.js
echo "module.exports = epoxy" >> epoxy-module-bundled.js
AUTOGENERATED_TYPEDEFS=$(<"out/epoxy_client.d.ts")
AUTOGENERATED_TYPEDEFS=${AUTOGENERATED_TYPEDEFS%%export class IntoUnderlyingByteSource*}
echo "$AUTOGENERATED_TYPEDEFS" >"module.d.ts"
echo "} export default function epoxy(): Promise<typeof wasm_bindgen>;" >> "module.d.ts"
echo "$AUTOGENERATED_TYPEDEFS" >"epoxy-module-bundled.d.ts"
echo "} export default function epoxy(): Promise<typeof wasm_bindgen>;" >> "epoxy-module-bundled.d.ts"
cp out/epoxy_client.js epoxy.js
cp out/epoxy_client.d.ts epoxy.d.ts
cp out/epoxy_client_bg.wasm epoxy.wasm
rm -rf out/
echo "[ws] done!"

View file

@ -1,21 +1,38 @@
(async () => {
importScripts("epoxy-bundled.js");
onmessage = async (msg) => {
console.debug("recieved:", msg);
let [should_feature_test, should_multiparallel_test, should_parallel_test, should_multiperf_test, should_perf_test, should_ws_test, should_tls_test] = msg.data;
console.log(
"%cWASM is significantly slower with DevTools open!",
"color:red;font-size:3rem;font-weight:bold"
);
const should_feature_test = (new URL(window.location.href)).searchParams.has("feature_test");
const should_perf_test = (new URL(window.location.href)).searchParams.has("perf_test");
const should_ws_test = (new URL(window.location.href)).searchParams.has("ws_test");
const log = (str) => {
console.warn(str);
postMessage(str);
}
let { EpoxyClient } = await epoxy();
const { EpoxyClient } = await epoxy();
const tconn0 = performance.now();
// args: websocket url, user agent, redirect limit
let epoxy_client = await new EpoxyClient("wss://localhost:4000", navigator.userAgent, 10);
const tconn1 = performance.now();
console.warn(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`);
log(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`);
const test_mux = async (url) => {
const t0 = performance.now();
await epoxy_client.fetch(url);
const t1 = performance.now();
return t1 - t0;
};
const test_native = async (url) => {
const t0 = performance.now();
await fetch(url, { cache: "no-store" });
const t1 = performance.now();
return t1 - t0;
};
if (should_feature_test) {
for (const url of [
@ -23,44 +40,102 @@
["https://httpbin.org/gzip", {}],
["https://httpbin.org/brotli", {}],
["https://httpbin.org/redirect/11", {}],
["https://httpbin.org/redirect/1", { redirect: "manual" }]
["https://httpbin.org/redirect/1", { redirect: "manual" }],
]) {
let resp = await epoxy_client.fetch(url[0], url[1]);
console.warn(url, resp, Object.fromEntries(resp.headers));
console.warn(await resp.text());
}
} else if (should_perf_test) {
const test_mux = async (url) => {
const t0 = performance.now();
await epoxy_client.fetch(url);
const t1 = performance.now();
return t1 - t0;
};
} else if (should_multiparallel_test) {
const num_tests = 10;
let total_mux_minus_native = 0;
for (const _ of Array(num_tests).keys()) {
let total_mux = 0;
await Promise.all([...Array(num_tests).keys()].map(async i => {
log(`running mux test ${i}`);
return await test_mux("https://httpbin.org/get");
})).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) });
total_mux = total_mux / num_tests;
const test_native = async (url) => {
const t0 = performance.now();
await fetch(url);
const t1 = performance.now();
return t1 - t0;
};
let total_native = 0;
await Promise.all([...Array(num_tests).keys()].map(async i => {
log(`running native test ${i}`);
return await test_native("https://httpbin.org/get");
})).then((vals) => { total_native = vals.reduce((acc, x) => acc + x, 0) });
total_native = total_native / num_tests;
log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`);
log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`);
log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
total_mux_minus_native += total_mux - total_native;
}
total_mux_minus_native = total_mux_minus_native / num_tests;
log(`total mux - native (${num_tests} tests of ${num_tests} reqs): ${total_mux_minus_native} ms or ${total_mux_minus_native / 1000} s`);
} else if (should_parallel_test) {
const num_tests = 10;
let total_mux = 0;
await Promise.all([...Array(num_tests).keys()].map(async i => {
log(`running mux test ${i}`);
return await test_mux("https://httpbin.org/get");
})).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) });
total_mux = total_mux / num_tests;
let total_native = 0;
await Promise.all([...Array(num_tests).keys()].map(async i => {
log(`running native test ${i}`);
return await test_native("https://httpbin.org/get");
})).then((vals) => { total_native = vals.reduce((acc, x) => acc + x, 0) });
total_native = total_native / num_tests;
log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`);
log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`);
log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
} else if (should_multiperf_test) {
const num_tests = 10;
let total_mux_minus_native = 0;
for (const _ of Array(num_tests).keys()) {
let total_mux = 0;
for (const i of Array(num_tests).keys()) {
log(`running mux test ${i}`);
total_mux += await test_mux("https://httpbin.org/get");
}
total_mux = total_mux / num_tests;
let total_native = 0;
for (const i of Array(num_tests).keys()) {
log(`running native test ${i}`);
total_native += await test_native("https://httpbin.org/get");
}
total_native = total_native / num_tests;
log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`);
log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`);
log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
total_mux_minus_native += total_mux - total_native;
}
total_mux_minus_native = total_mux_minus_native / num_tests;
log(`total mux - native (${num_tests} tests of ${num_tests} reqs): ${total_mux_minus_native} ms or ${total_mux_minus_native / 1000} s`);
} else if (should_perf_test) {
const num_tests = 10;
let total_mux = 0;
for (const i of Array(num_tests).keys()) {
log(`running mux test ${i}`);
total_mux += await test_mux("https://httpbin.org/get");
}
total_mux = total_mux / num_tests;
let total_native = 0;
for (const _ of Array(num_tests).keys()) {
for (const i of Array(num_tests).keys()) {
log(`running native test ${i}`);
total_native += await test_native("https://httpbin.org/get");
}
total_native = total_native / num_tests;
console.warn(`avg mux (10) took ${total_mux} ms or ${total_mux / 1000} s`);
console.warn(`avg native (10) took ${total_native} ms or ${total_native / 1000} s`);
console.warn(`mux - native: ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`);
log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`);
log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
} else if (should_ws_test) {
let ws = await epoxy_client.connect_ws(
() => console.log("opened"),
@ -75,10 +150,20 @@
await ws.send("data");
await (new Promise((res, _) => setTimeout(res, 100)));
}
} else if (should_tls_test) {
let decoder = new TextDecoder();
let ws = await epoxy_client.connect_tls(
() => console.log("opened"),
() => console.log("closed"),
err => console.error(err),
msg => { console.log(msg); console.log(decoder.decode(msg)) },
"alicesworld.tech:443",
);
await ws.send("GET / HTTP 1.1\r\nHost: alicesworld.tech\r\nConnection: close\r\n\r\n");
} else {
let resp = await epoxy_client.fetch("https://httpbin.org/get");
console.warn(resp, Object.fromEntries(resp.headers));
console.warn(await resp.text());
}
if (!should_ws_test) alert("you can open console now");
})();
log("done");
};

View file

@ -2,12 +2,35 @@
<head>
<title>epoxy</title>
<script src="index.js"></script>
<script src="demo.js"></script>
<style>
body { font-family: sans-serif }
#logs > * { font-family: monospace }
</style>
<script>
const params = (new URL(window.location.href)).searchParams;
const should_feature_test = params.has("feature_test");
const should_multiparallel_test = params.has("multi_parallel_test");
const should_parallel_test = params.has("parallel_test");
const should_multiperf_test = params.has("multi_perf_test");
const should_perf_test = params.has("perf_test");
const should_ws_test = params.has("ws_test");
const should_tls_test = params.has("rawtls_test");
const worker = new Worker("demo.js");
worker.onmessage = (msg) => {
let el = document.createElement("div");
el.textContent = msg.data;
document.getElementById("logs").appendChild(el);
window.scrollTo(0, document.body.scrollHeight);
};
worker.postMessage([should_feature_test, should_multiparallel_test, should_parallel_test, should_multiperf_test, should_perf_test, should_ws_test, should_tls_test]);
</script>
</head>
<body>
running... (wait for the browser alert if not running ws test)
<div>
running... (wait for the browser alert if not running ws test)
</div>
<div id="logs"></div>
</body>
</html>

10
client/serve.py Normal file
View file

@ -0,0 +1,10 @@
from http.server import HTTPServer, SimpleHTTPRequestHandler, test
import sys
class RequestHandler (SimpleHTTPRequestHandler):
def end_headers (self):
self.send_header('Cross-Origin-Opener-Policy', 'same-origin')
self.send_header('Cross-Origin-Embedder-Policy', 'require-corp')
SimpleHTTPRequestHandler.end_headers(self)
test(RequestHandler, HTTPServer, port=int(sys.argv[1]) if len(sys.argv) > 1 else 8000)

View file

@ -1,24 +1,25 @@
#![feature(let_chains)]
#![feature(let_chains, impl_trait_in_assoc_type)]
#[macro_use]
mod utils;
mod tokioio;
mod tls_stream;
mod websocket;
mod wrappers;
use tokioio::TokioIo;
use tls_stream::EpxTlsStream;
use utils::{ReplaceErr, UriExt};
use websocket::EpxWebSocket;
use wrappers::{IncomingBody, WsStreamWrapper};
use wrappers::{IncomingBody, TlsWispService};
use std::sync::Arc;
use async_compression::tokio::bufread as async_comp;
use async_io_stream::IoStream;
use bytes::Bytes;
use futures_util::StreamExt;
use futures_util::{stream::SplitSink, StreamExt};
use http::{uri, HeaderName, HeaderValue, Request, Response};
use hyper::{body::Incoming, client::conn::http1::Builder, Uri};
use hyper::{body::Incoming, Uri};
use hyper_util::client::legacy::Client;
use js_sys::{Array, Function, Object, Reflect, Uint8Array};
use penguin_mux_wasm::{Multiplexor, MuxStream};
use tokio_rustls::{client::TlsStream, rustls, rustls::RootCertStore, TlsConnector};
use tokio_util::{
either::Either,
@ -26,13 +27,15 @@ use tokio_util::{
};
use wasm_bindgen::prelude::*;
use web_sys::TextEncoder;
use wisp_mux::{tokioio::TokioIo, tower::ServiceWrapper, ClientMux, MuxStreamIo, StreamType};
use ws_stream_wasm::{WsMessage, WsMeta, WsStream};
type HttpBody = http_body_util::Full<Bytes>;
#[derive(Debug)]
enum EpxResponse {
Success(Response<Incoming>),
Redirect((Response<Incoming>, http::Request<HttpBody>, Uri)),
Redirect((Response<Incoming>, http::Request<HttpBody>)),
}
enum EpxCompression {
@ -40,80 +43,20 @@ enum EpxCompression {
Gzip,
}
type EpxTlsStream = TlsStream<MuxStream<WsStreamWrapper>>;
type EpxUnencryptedStream = MuxStream<WsStreamWrapper>;
type EpxStream = Either<EpxTlsStream, EpxUnencryptedStream>;
async fn send_req(
req: http::Request<HttpBody>,
should_redirect: bool,
io: EpxStream,
) -> Result<EpxResponse, JsError> {
let (mut req_sender, conn) = Builder::new()
.title_case_headers(true)
.preserve_header_case(true)
.handshake(TokioIo::new(io))
.await
.replace_err("Failed to connect to host")?;
wasm_bindgen_futures::spawn_local(async move {
if let Err(e) = conn.await {
error!("epoxy: error in muxed hyper connection! {:?}", e);
}
});
let new_req = if should_redirect {
Some(req.clone())
} else {
None
};
let res = req_sender
.send_request(req)
.await
.replace_err("Failed to send request");
match res {
Ok(res) => {
if utils::is_redirect(res.status().as_u16())
&& let Some(mut new_req) = new_req
&& let Some(location) = res.headers().get("Location")
&& let Ok(redirect_url) = new_req.uri().get_redirect(location)
&& let Some(redirect_url_authority) = redirect_url
.clone()
.authority()
.replace_err("Redirect URL must have an authority")
.ok()
{
let should_strip = new_req.uri().is_same_host(&redirect_url);
if should_strip {
new_req.headers_mut().remove("authorization");
new_req.headers_mut().remove("cookie");
new_req.headers_mut().remove("www-authenticate");
}
let new_url = redirect_url.clone();
*new_req.uri_mut() = redirect_url;
new_req.headers_mut().insert(
"Host",
HeaderValue::from_str(redirect_url_authority.as_str())?,
);
Ok(EpxResponse::Redirect((res, new_req, new_url)))
} else {
Ok(EpxResponse::Success(res))
}
}
Err(err) => Err(err),
}
}
type EpxIoTlsStream = TlsStream<IoStream<MuxStreamIo, Vec<u8>>>;
type EpxIoUnencryptedStream = IoStream<MuxStreamIo, Vec<u8>>;
type EpxIoStream = Either<EpxIoTlsStream, EpxIoUnencryptedStream>;
#[wasm_bindgen(start)]
async fn start() {
utils::set_panic_hook();
fn init() {
console_error_panic_hook::set_once();
}
#[wasm_bindgen]
pub struct EpoxyClient {
rustls_config: Arc<rustls::ClientConfig>,
mux: Multiplexor<WsStreamWrapper>,
mux: Arc<ClientMux<SplitSink<WsStream, WsMessage>>>,
hyper_client: Client<TlsWispService<SplitSink<WsStream, WsMessage>>, HttpBody>,
useragent: String,
redirect_limit: usize,
}
@ -138,11 +81,19 @@ impl EpoxyClient {
}
debug!("connecting to ws {:?}", ws_url);
let ws = WsStreamWrapper::connect(ws_url, None)
let (_, ws) = WsMeta::connect(ws_url, vec!["wisp-v1"])
.await
.replace_err("Failed to connect to websocket")?;
debug!("connected!");
let mux = Multiplexor::new(ws, penguin_mux_wasm::Role::Client, None, None);
let (wtx, wrx) = ws.split();
let (mux, fut) = ClientMux::new(wrx, wtx).await?;
let mux = Arc::new(mux);
wasm_bindgen_futures::spawn_local(async move {
if let Err(err) = fut.await {
error!("epoxy: error in mux future! {:?}", err);
}
});
let mut certstore = RootCertStore::empty();
certstore.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
@ -154,37 +105,86 @@ impl EpoxyClient {
);
Ok(EpoxyClient {
mux,
mux: mux.clone(),
hyper_client: Client::builder(utils::WasmExecutor {})
.http09_responses(true)
.http1_title_case_headers(true)
.http1_preserve_header_case(true)
.build(TlsWispService {
rustls_config: rustls_config.clone(),
service: ServiceWrapper(mux),
}),
rustls_config,
useragent,
redirect_limit,
})
}
async fn get_http_io(&self, url: &Uri) -> Result<EpxStream, JsError> {
let url_host = url.host().replace_err("URL must have a host")?;
let url_port = utils::get_url_port(url)?;
async fn get_tls_io(&self, url_host: &str, url_port: u16) -> Result<EpxIoTlsStream, JsError> {
let channel = self
.mux
.client_new_stream_channel(url_host.as_bytes(), url_port)
.client_new_stream(StreamType::Tcp, url_host.to_string(), url_port)
.await
.replace_err("Failed to create multiplexor channel")?;
.replace_err("Failed to create multiplexor channel")?
.into_io()
.into_asyncrw();
let cloned_uri = url_host.to_string().clone();
let connector = TlsConnector::from(self.rustls_config.clone());
debug!("connecting channel");
let io = connector
.connect(
cloned_uri
.try_into()
.replace_err("Failed to parse URL (rustls)")?,
channel,
)
.await
.replace_err("Failed to perform TLS handshake")?;
debug!("connected channel");
Ok(io)
}
if utils::get_is_secure(url)? {
let cloned_uri = url_host.to_string().clone();
let connector = TlsConnector::from(self.rustls_config.clone());
let io = connector
.connect(
cloned_uri
.try_into()
.replace_err("Failed to parse URL (rustls)")?,
channel,
)
.await
.replace_err("Failed to perform TLS handshake")?;
Ok(EpxStream::Left(io))
async fn send_req_inner(
&self,
req: http::Request<HttpBody>,
should_redirect: bool,
) -> Result<EpxResponse, JsError> {
let new_req = if should_redirect {
Some(req.clone())
} else {
Ok(EpxStream::Right(channel))
None
};
debug!("sending req");
let res = self
.hyper_client
.request(req)
.await
.replace_err("Failed to send request");
debug!("recieved res");
match res {
Ok(res) => {
if utils::is_redirect(res.status().as_u16())
&& let Some(mut new_req) = new_req
&& let Some(location) = res.headers().get("Location")
&& let Ok(redirect_url) = new_req.uri().get_redirect(location)
&& let Some(redirect_url_authority) = redirect_url
.clone()
.authority()
.replace_err("Redirect URL must have an authority")
.ok()
{
*new_req.uri_mut() = redirect_url;
new_req.headers_mut().insert(
"Host",
HeaderValue::from_str(redirect_url_authority.as_str())?,
);
Ok(EpxResponse::Redirect((res, new_req)))
} else {
Ok(EpxResponse::Success(res))
}
}
Err(err) => Err(err),
}
}
@ -194,23 +194,22 @@ impl EpoxyClient {
should_redirect: bool,
) -> Result<(hyper::Response<Incoming>, Uri, bool), JsError> {
let mut redirected = false;
let uri = req.uri().clone();
let mut current_resp: EpxResponse =
send_req(req, should_redirect, self.get_http_io(&uri).await?).await?;
let mut current_url = req.uri().clone();
let mut current_resp: EpxResponse = self.send_req_inner(req, should_redirect).await?;
for _ in 0..self.redirect_limit - 1 {
match current_resp {
EpxResponse::Success(_) => break,
EpxResponse::Redirect((_, req, new_url)) => {
EpxResponse::Redirect((_, req)) => {
redirected = true;
current_resp =
send_req(req, should_redirect, self.get_http_io(&new_url).await?).await?
current_url = req.uri().clone();
current_resp = self.send_req_inner(req, should_redirect).await?
}
}
}
match current_resp {
EpxResponse::Success(resp) => Ok((resp, uri, redirected)),
EpxResponse::Redirect((resp, _, new_url)) => Ok((resp, new_url, redirected)),
EpxResponse::Success(resp) => Ok((resp, current_url, redirected)),
EpxResponse::Redirect((resp, _)) => Ok((resp, current_url, redirected)),
}
}
@ -232,6 +231,17 @@ impl EpoxyClient {
.await
}
pub async fn connect_tls(
&self,
onopen: Function,
onclose: Function,
onerror: Function,
onmessage: Function,
url: String,
) -> Result<EpxTlsStream, JsError> {
EpxTlsStream::connect(self, onopen, onclose, onerror, onmessage, url).await
}
pub async fn fetch(&self, url: String, options: Object) -> Result<web_sys::Response, JsError> {
let uri = url.parse::<uri::Uri>().replace_err("Failed to parse URL")?;
let uri_scheme = uri.scheme().replace_err("URL must have a scheme")?;
@ -292,7 +302,7 @@ impl EpoxyClient {
let headers_map = builder.headers_mut().replace_err("Failed to get headers")?;
headers_map.insert("Accept-Encoding", HeaderValue::from_str("gzip, br")?);
headers_map.insert("Connection", HeaderValue::from_str("close")?);
headers_map.insert("Connection", HeaderValue::from_str("keep-alive")?);
headers_map.insert("User-Agent", HeaderValue::from_str(&self.useragent)?);
headers_map.insert("Host", HeaderValue::from_str(uri_host)?);
if body_bytes.is_empty() {
@ -314,7 +324,7 @@ impl EpoxyClient {
.body(HttpBody::new(body_bytes))
.replace_err("Failed to make request")?;
let (resp, last_url, req_redirected) = self.send_req(request, req_should_redirect).await?;
let (resp, resp_uri, req_redirected) = self.send_req(request, req_should_redirect).await?;
let resp_headers_raw = resp.headers().clone();
@ -378,7 +388,7 @@ impl EpoxyClient {
Object::define_property(
&resp,
&jval!("url"),
&utils::define_property_obj(jval!(last_url.to_string()), false)
&utils::define_property_obj(jval!(resp_uri.to_string()), false)
.replace_err("Failed to make define_property object for url")?,
);

82
client/src/tls_stream.rs Normal file
View file

@ -0,0 +1,82 @@
use crate::*;
use js_sys::Function;
use tokio::io::{split, AsyncWriteExt, WriteHalf};
use tokio_util::io::ReaderStream;
#[wasm_bindgen]
pub struct EpxTlsStream {
tx: WriteHalf<EpxIoTlsStream>,
onerror: Function,
}
#[wasm_bindgen]
impl EpxTlsStream {
#[wasm_bindgen(constructor)]
pub fn new() -> Result<EpxTlsStream, JsError> {
Err(jerr!("Use EpoxyClient.connect_tls() instead."))
}
// shut up
#[allow(clippy::too_many_arguments)]
pub async fn connect(
tcp: &EpoxyClient,
onopen: Function,
onclose: Function,
onerror: Function,
onmessage: Function,
url: String,
) -> Result<EpxTlsStream, JsError> {
let onerr = onerror.clone();
let ret: Result<EpxTlsStream, JsError> = async move {
let url = Uri::try_from(url).replace_err("Failed to parse URL")?;
let url_host = url.host().replace_err("URL must have a host")?;
let url_port = url.port().replace_err("URL must have a port")?.into();
let io = tcp.get_tls_io(url_host, url_port).await?;
let (rx, tx) = split(io);
let mut rx = ReaderStream::new(rx);
wasm_bindgen_futures::spawn_local(async move {
while let Some(Ok(data)) = rx.next().await {
let _ = onmessage.call1(
&JsValue::null(),
&jval!(Uint8Array::from(data.to_vec().as_slice())),
);
}
let _ = onclose.call0(&JsValue::null());
});
onopen
.call0(&Object::default())
.replace_err("Failed to call onopen")?;
Ok(Self { tx, onerror })
}
.await;
if let Err(ret) = ret {
let _ = onerr.call1(&JsValue::null(), &jval!(ret.clone()));
Err(ret)
} else {
ret
}
}
#[wasm_bindgen]
pub async fn send(&mut self, payload: Uint8Array) -> Result<(), JsError> {
let onerr = self.onerror.clone();
let ret = self.tx.write_all(&payload.to_vec()).await;
if let Err(ret) = ret {
let _ = onerr.call1(&JsValue::null(), &jval!(format!("{}", ret)));
Err(ret.into())
} else {
Ok(ret?)
}
}
#[wasm_bindgen]
pub async fn close(&mut self) -> Result<(), JsError> {
self.tx.shutdown().await?;
Ok(())
}
}

View file

@ -1,13 +1,9 @@
use wasm_bindgen::prelude::*;
use hyper::rt::Executor;
use hyper::{header::HeaderValue, Uri};
use http::uri;
use js_sys::{Array, Object};
pub fn set_panic_hook() {
#[cfg(feature = "console_error_panic_hook")]
console_error_panic_hook::set_once();
}
use std::future::Future;
#[wasm_bindgen]
extern "C" {
@ -55,15 +51,15 @@ pub trait ReplaceErr {
fn replace_err_jv(self, err: &str) -> Result<Self::Ok, JsValue>;
}
impl<T, E> ReplaceErr for Result<T, E> {
impl<T, E: std::fmt::Debug> ReplaceErr for Result<T, E> {
type Ok = T;
fn replace_err(self, err: &str) -> Result<<Self as ReplaceErr>::Ok, JsError> {
self.map_err(|_| jerr!(err))
self.map_err(|oe| jerr!(&format!("{}, original error: {:?}", err, oe)))
}
fn replace_err_jv(self, err: &str) -> Result<<Self as ReplaceErr>::Ok, JsValue> {
self.map_err(|_| jval!(err))
self.map_err(|oe| jval!(&format!("{}, original error: {:?}", err, oe)))
}
}
@ -102,6 +98,21 @@ impl UriExt for Uri {
}
}
#[derive(Clone)]
pub struct WasmExecutor;
impl<F> Executor<F> for WasmExecutor
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
fn execute(&self, future: F) {
wasm_bindgen_futures::spawn_local(async move {
let _ = future.await;
});
}
}
pub fn entries_of_object(obj: &Object) -> Vec<Vec<String>> {
js_sys::Object::entries(obj)
.to_vec()
@ -131,41 +142,19 @@ pub fn is_redirect(code: u16) -> bool {
}
pub fn get_is_secure(url: &Uri) -> Result<bool, JsError> {
let url_scheme = url.scheme().replace_err("URL must have a scheme")?;
let url_scheme_str = url.scheme_str().replace_err("URL must have a scheme")?;
// can't use match, compiler error
// error: to use a constant of type `Scheme` in a pattern, `Scheme` must be annotated with `#[derive(PartialEq, Eq)]`
if *url_scheme == uri::Scheme::HTTP {
Ok(false)
} else if *url_scheme == uri::Scheme::HTTPS {
Ok(true)
} else if url_scheme_str == "ws" {
Ok(false)
} else if url_scheme_str == "wss" {
Ok(true)
} else {
return Ok(false);
match url_scheme_str {
"https" | "wss" => Ok(true),
_ => Ok(false),
}
}
pub fn get_url_port(url: &Uri) -> Result<u16, JsError> {
let url_scheme = url.scheme().replace_err("URL must have a scheme")?;
let url_scheme_str = url.scheme_str().replace_err("URL must have a scheme")?;
if let Some(port) = url.port() {
Ok(port.as_u16())
} else if get_is_secure(url)? {
Ok(443)
} else {
// can't use match, compiler error
// error: to use a constant of type `Scheme` in a pattern, `Scheme` must be annotated with `#[derive(PartialEq, Eq)]`
if *url_scheme == uri::Scheme::HTTP {
Ok(80)
} else if *url_scheme == uri::Scheme::HTTPS {
Ok(443)
} else if url_scheme_str == "ws" {
Ok(80)
} else if url_scheme_str == "wss" {
Ok(443)
} else {
return Err(jerr!("Failed to coerce port from scheme"));
}
Ok(80)
}
}

View file

@ -4,7 +4,8 @@ use base64::{engine::general_purpose::STANDARD, Engine};
use fastwebsockets::{
CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, Role, WebSocket, WebSocketWrite,
};
use http_body_util::Empty;
use futures_util::lock::Mutex;
use http_body_util::Full;
use hyper::{
header::{CONNECTION, UPGRADE},
upgrade::Upgraded,
@ -16,7 +17,7 @@ use tokio::io::WriteHalf;
#[wasm_bindgen]
pub struct EpxWebSocket {
tx: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>,
onerror: Function,
}
@ -44,7 +45,8 @@ impl EpxWebSocket {
let url = Uri::try_from(url).replace_err("Failed to parse URL")?;
let host = url.host().replace_err("URL must have a host")?;
let rand: [u8; 16] = rand::random();
let mut rand: [u8; 16] = [0; 16];
getrandom::getrandom(&mut rand)?;
let key = STANDARD.encode(rand);
let mut builder = Request::builder()
@ -61,23 +63,9 @@ impl EpxWebSocket {
builder = builder.header("Sec-WebSocket-Protocol", protocols.join(", "));
}
let req = builder.body(Empty::<Bytes>::new())?;
let req = builder.body(Full::<Bytes>::new(Bytes::new()))?;
let stream = tcp.get_http_io(&url).await?;
let (mut sender, conn) = Builder::new()
.title_case_headers(true)
.preserve_header_case(true)
.handshake::<TokioIo<EpxStream>, Empty<Bytes>>(TokioIo::new(stream))
.await?;
wasm_bindgen_futures::spawn_local(async move {
if let Err(e) = conn.with_upgrades().await {
error!("epoxy: error in muxed hyper connection (ws)! {:?}", e);
}
});
let mut response = sender.send_request(req).await?;
let mut response = tcp.hyper_client.request(req).await?;
verify(&response)?;
let ws = WebSocket::after_handshake(
@ -88,16 +76,12 @@ impl EpxWebSocket {
let (rx, tx) = ws.split(tokio::io::split);
let mut rx = FragmentCollectorRead::new(rx);
let tx = Arc::new(Mutex::new(tx));
let tx_cloned = tx.clone();
wasm_bindgen_futures::spawn_local(async move {
while let Ok(frame) = rx
.read_frame(&mut |arg| async move {
error!(
"wtf is an obligated write {:?}, {:?}, {:?}",
arg.fin, arg.opcode, arg.payload
);
Ok::<(), std::io::Error>(())
})
.read_frame(&mut |arg| async { tx_cloned.lock().await.write_frame(arg).await })
.await
{
match frame.opcode {
@ -137,10 +121,12 @@ impl EpxWebSocket {
}
#[wasm_bindgen]
pub async fn send(&mut self, payload: String) -> Result<(), JsError> {
pub async fn send(&self, payload: String) -> Result<(), JsError> {
let onerr = self.onerror.clone();
let ret = self
.tx
.lock()
.await
.write_frame(Frame::text(Payload::Owned(payload.as_bytes().to_vec())))
.await;
if let Err(ret) = ret {
@ -152,8 +138,10 @@ impl EpxWebSocket {
}
#[wasm_bindgen]
pub async fn close(&mut self) -> Result<(), JsError> {
pub async fn close(&self) -> Result<(), JsError> {
self.tx
.lock()
.await
.write_frame(Frame::close(CloseCode::Normal.into(), b""))
.await?;
Ok(())

View file

@ -4,117 +4,11 @@ use std::{
task::{Context, Poll},
};
use futures_util::{Sink, Stream};
use futures_util::Stream;
use hyper::body::Body;
use penguin_mux_wasm::ws;
use pin_project_lite::pin_project;
use ws_stream_wasm::{WsErr, WsMessage, WsMeta, WsStream};
pin_project! {
pub struct WsStreamWrapper {
#[pin]
ws: WsStream,
}
}
impl WsStreamWrapper {
pub async fn connect(
url: impl AsRef<str>,
protocols: impl Into<Option<Vec<&str>>>,
) -> Result<Self, WsErr> {
let (_, wsstream) = WsMeta::connect(url, protocols).await?;
Ok(WsStreamWrapper { ws: wsstream })
}
}
impl Stream for WsStreamWrapper {
type Item = Result<ws::Message, ws::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
let ret = this.ws.poll_next(cx);
match ret {
Poll::Ready(item) => Poll::<Option<Self::Item>>::Ready(item.map(|x| {
Ok(match x {
WsMessage::Text(txt) => ws::Message::Text(txt),
WsMessage::Binary(bin) => ws::Message::Binary(bin),
})
})),
Poll::Pending => Poll::<Option<Self::Item>>::Pending,
}
}
}
fn wserr_to_ws_err(err: WsErr) -> ws::Error {
debug!("err: {:?}", err);
match err {
WsErr::ConnectionNotOpen => ws::Error::AlreadyClosed,
_ => ws::Error::Io(std::io::Error::other(format!("{:?}", err))),
}
}
impl Sink<ws::Message> for WsStreamWrapper {
type Error = ws::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let this = self.project();
let ret = this.ws.poll_ready(cx);
match ret {
Poll::Ready(item) => Poll::<Result<(), Self::Error>>::Ready(match item {
Ok(_) => Ok(()),
Err(err) => Err(wserr_to_ws_err(err)),
}),
Poll::Pending => Poll::<Result<(), Self::Error>>::Pending,
}
}
fn start_send(self: Pin<&mut Self>, item: ws::Message) -> Result<(), Self::Error> {
use ws::Message::*;
let item = match item {
Text(txt) => WsMessage::Text(txt),
Binary(bin) => WsMessage::Binary(bin),
Close(_) => {
debug!("closing");
return match self.ws.wrapped().close() {
Ok(_) => Ok(()),
Err(err) => Err(ws::Error::Io(std::io::Error::other(format!(
"ws close err: {:?}",
err
)))),
};
}
Ping(_) | Pong(_) | Frame(_) => return Ok(()),
};
let this = self.project();
let ret = this.ws.start_send(item);
match ret {
Ok(_) => Ok(()),
Err(err) => Err(wserr_to_ws_err(err)),
}
}
// no point wrapping this as it's not going to do anything
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Ok(()).into()
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let this = self.project();
let ret = this.ws.poll_close(cx);
match ret {
Poll::Ready(item) => Poll::<Result<(), Self::Error>>::Ready(match item {
Ok(_) => Ok(()),
Err(err) => Err(wserr_to_ws_err(err)),
}),
Poll::Pending => Poll::<Result<(), Self::Error>>::Pending,
}
}
}
impl ws::WebSocketStream for WsStreamWrapper {
fn ping_auto_pong(&self) -> bool {
true
}
}
use std::future::Future;
use wisp_mux::{tokioio::TokioIo, tower::ServiceWrapper, WispError};
pin_project! {
pub struct IncomingBody {
@ -138,7 +32,8 @@ impl Stream for IncomingBody {
Poll::Ready(item) => Poll::<Option<Self::Item>>::Ready(match item {
Some(frame) => frame
.map(|x| {
x.into_data().map_err(|_| std::io::Error::other("not data frame"))
x.into_data()
.map_err(|_| std::io::Error::other("not data frame"))
})
.ok(),
None => None,
@ -147,3 +42,68 @@ impl Stream for IncomingBody {
}
}
}
pub struct TlsWispService<W>
where
W: wisp_mux::ws::WebSocketWrite + Send + 'static,
{
pub service: ServiceWrapper<W>,
pub rustls_config: Arc<rustls::ClientConfig>,
}
impl<W: wisp_mux::ws::WebSocketWrite + Send + 'static> tower_service::Service<hyper::Uri>
for TlsWispService<W>
{
type Response = TokioIo<EpxIoStream>;
type Error = WispError;
type Future = Pin<Box<impl Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&mut self, req: http::Uri) -> Self::Future {
let mut service = self.service.clone();
let rustls_config = self.rustls_config.clone();
Box::pin(async move {
let uri_host = req
.host()
.ok_or(WispError::UriHasNoHost)?
.to_string()
.clone();
let uri_parsed = Uri::builder()
.authority(format!(
"{}:{}",
uri_host,
utils::get_url_port(&req).map_err(|_| WispError::UriHasNoPort)?
))
.build()
.map_err(|x| WispError::Other(Box::new(x)))?;
let stream = service.call(uri_parsed).await?.into_inner();
if utils::get_is_secure(&req).map_err(|_| WispError::InvalidUri)? {
let connector = TlsConnector::from(rustls_config);
Ok(TokioIo::new(Either::Left(
connector
.connect(
uri_host.try_into().map_err(|_| WispError::InvalidUri)?,
stream,
)
.await
.map_err(|x| WispError::Other(Box::new(x)))?,
)))
} else {
Ok(TokioIo::new(Either::Right(stream)))
}
})
}
}
impl<W: wisp_mux::ws::WebSocketWrite + Send + 'static> Clone for TlsWispService<W> {
fn clone(&self) -> Self {
Self {
rustls_config: self.rustls_config.clone(),
service: self.service.clone(),
}
}
}

View file

@ -16,8 +16,8 @@
"author": "MercuryWorkshop",
"repository": "https://github.com/MercuryWorkshop/epoxy-tls",
"license": "MIT",
"browser": "./client/module.js",
"module": "./client/module.js",
"main": "./client/module.js",
"types": "./client/module.d.ts"
"browser": "./client/epoxy-module-bundled.js",
"module": "./client/epoxy-module-bundled.js",
"main": "./client/epoxy-module-bundled.js",
"types": "./client/epoxy-module-bundled.d.ts"
}

View file

@ -4,10 +4,16 @@ version = "1.0.0"
edition = "2021"
[dependencies]
bytes = "1.5.0"
clap = { version = "4.4.18", features = ["derive", "help", "usage", "color", "wrap_help", "cargo"] }
clio = { version = "0.3.5", features = ["clap-parse"] }
dashmap = "5.5.3"
fastwebsockets = { version = "0.6.0", features = ["upgrade", "simdutf8", "unstable-split"] }
futures-util = { version = "0.3.30", features = ["sink"] }
http-body-util = "0.1.0"
hyper = { version = "1.1.0", features = ["server", "http1"] }
hyper-util = { version = "0.1.2", features = ["tokio"] }
rusty-penguin = { version = "0.5.3", default-features = false }
tokio = { version = "1.35.1", features = ["rt-multi-thread", "net", "macros"] }
tokio = { version = "1.5.1", features = ["rt-multi-thread", "macros"] }
tokio-native-tls = "0.3.1"
tokio-tungstenite = "0.21.0"
tokio-util = { version = "0.7.10", features = ["codec"] }
wisp-mux = { path = "../wisp", features = ["fastwebsockets", "tokio_io"] }

View file

@ -1,176 +1,283 @@
use std::{convert::Infallible, env, net::SocketAddr, sync::Arc};
#![feature(let_chains)]
use std::io::{Error, Read};
use bytes::Bytes;
use clap::Parser;
use fastwebsockets::{
upgrade, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload,
WebSocketError,
};
use futures_util::{SinkExt, StreamExt, TryFutureExt};
use hyper::{
body::Incoming,
header::{
HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL,
SEC_WEBSOCKET_VERSION, UPGRADE,
},
server::conn::http1,
service::service_fn,
upgrade::Upgraded,
Method, Request, Response, StatusCode, Version,
body::Incoming, header::HeaderValue, server::conn::http1, service::service_fn, Request,
Response, StatusCode,
};
use hyper_util::rt::TokioIo;
use penguin_mux::{Multiplexor, MuxStream};
use tokio::{
net::{TcpListener, TcpStream},
task::{JoinError, JoinSet},
};
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio_native_tls::{native_tls, TlsAcceptor};
use tokio_tungstenite::{
tungstenite::{handshake::derive_accept_key, protocol::Role},
WebSocketStream,
};
use tokio_util::codec::{BytesCodec, Framed};
type Body = http_body_util::Empty<hyper::body::Bytes>;
use wisp_mux::{ws, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, WsEvent};
type MultiplexorStream = MuxStream<WebSocketStream<TokioIo<Upgraded>>>;
type HttpBody = http_body_util::Full<hyper::body::Bytes>;
async fn forward(mut stream: MultiplexorStream) -> Result<(), JoinError> {
println!("forwarding");
let host = std::str::from_utf8(&stream.dest_host).unwrap();
let mut tcp_stream = TcpStream::connect((host, stream.dest_port)).await.unwrap();
println!("connected to {:?}", tcp_stream.peer_addr().unwrap());
tokio::io::copy_bidirectional(&mut stream, &mut tcp_stream)
#[derive(Parser)]
#[command(version = clap::crate_version!(), about = "Implementation of the Wisp protocol in Rust, made for epoxy.")]
struct Cli {
#[arg(long, default_value = "")]
prefix: String,
#[arg(
long = "port",
short = 'l',
value_name = "PORT",
default_value = "4000"
)]
listen_port: String,
#[arg(long, short, value_parser)]
pubkey: clio::Input,
#[arg(long, short = 'P', value_parser)]
privkey: clio::Input,
}
#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<(), Error> {
let mut opt = Cli::parse();
let mut pem = Vec::new();
opt.pubkey.read_to_end(&mut pem)?;
let mut key = Vec::new();
opt.privkey.read_to_end(&mut key)?;
let identity = native_tls::Identity::from_pkcs8(&pem, &key).expect("failed to make identity");
let socket = TcpListener::bind(format!("0.0.0.0:{}", opt.listen_port))
.await
.unwrap();
println!("finished");
Ok(())
}
async fn handle_connection(ws_stream: WebSocketStream<TokioIo<Upgraded>>, addr: SocketAddr) {
println!("WebSocket connection established: {}", addr);
let mux = Multiplexor::new(ws_stream, penguin_mux::Role::Server, None, None);
let mut jobs = JoinSet::new();
println!("muxing");
loop {
tokio::select! {
Some(result) = jobs.join_next() => {
match result {
Ok(Ok(())) => {}
Ok(Err(err)) | Err(err) => eprintln!("failed to forward: {:?}", err),
}
}
Ok(result) = mux.server_new_stream_channel() => {
jobs.spawn(forward(result));
}
else => {
break;
}
}
}
println!("{} disconnected", &addr);
}
async fn handle_request(
mut req: Request<Incoming>,
addr: SocketAddr,
) -> Result<Response<Body>, Infallible> {
let headers = req.headers();
let derived = headers
.get(SEC_WEBSOCKET_KEY)
.map(|k| derive_accept_key(k.as_bytes()));
let mut negotiated_protocol: Option<String> = None;
if let Some(protocols) = headers
.get(SEC_WEBSOCKET_PROTOCOL)
.and_then(|h| h.to_str().ok())
{
negotiated_protocol = protocols.split(',').next().map(|h| h.trim().to_string());
}
if req.method() != Method::GET
|| req.version() < Version::HTTP_11
|| !headers
.get(CONNECTION)
.and_then(|h| h.to_str().ok())
.map(|h| {
h.split(|c| c == ' ' || c == ',')
.any(|p| p.eq_ignore_ascii_case("upgrade"))
})
.unwrap_or(false)
|| !headers
.get(UPGRADE)
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
|| !headers
.get(SEC_WEBSOCKET_VERSION)
.map(|h| h == "13")
.unwrap_or(false)
|| derived.is_none()
{
return Ok(Response::new(Body::default()));
}
let ver = req.version();
tokio::task::spawn(async move {
match hyper::upgrade::on(&mut req).await {
Ok(upgraded) => {
let upgraded = TokioIo::new(upgraded);
handle_connection(
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
addr,
)
.await;
}
Err(e) => eprintln!("upgrade error: {}", e),
}
});
let mut res = Response::new(Body::default());
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
*res.version_mut() = ver;
res.headers_mut()
.append(CONNECTION, HeaderValue::from_static("Upgrade"));
res.headers_mut()
.append(UPGRADE, HeaderValue::from_static("websocket"));
res.headers_mut()
.append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap());
if let Some(protocol) = negotiated_protocol {
res.headers_mut()
.append(SEC_WEBSOCKET_PROTOCOL, protocol.parse().unwrap());
}
Ok(res)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr = env::args()
.nth(1)
.unwrap_or_else(|| "0.0.0.0:4000".to_string())
.parse::<SocketAddr>()?;
let pem = include_bytes!("./pem.pem");
let key = include_bytes!("./key.pem");
let identity = native_tls::Identity::from_pkcs8(pem, key).expect("invalid pem/key");
let acceptor = TlsAcceptor::from(native_tls::TlsAcceptor::new(identity).unwrap());
let acceptor = Arc::new(acceptor);
let listener = TcpListener::bind(addr).await?;
println!("listening on {}", addr);
loop {
let (stream, remote_addr) = listener.accept().await?;
let acceptor = acceptor.clone();
.expect("failed to bind");
let acceptor = TlsAcceptor::from(
native_tls::TlsAcceptor::new(identity).expect("failed to make tls acceptor"),
);
let acceptor = std::sync::Arc::new(acceptor);
println!("listening on 0.0.0.0:4000");
while let Ok((stream, addr)) = socket.accept().await {
let acceptor_cloned = acceptor.clone();
let prefix_cloned = opt.prefix.clone();
tokio::spawn(async move {
let stream = acceptor.accept(stream).await.expect("not tls");
let stream = acceptor_cloned.accept(stream).await.expect("not tls");
let io = TokioIo::new(stream);
let service = service_fn(move |req| handle_request(req, remote_addr));
let service =
service_fn(move |res| accept_http(res, addr.to_string(), prefix_cloned.clone()));
let conn = http1::Builder::new()
.serve_connection(io, service)
.with_upgrades();
if let Err(err) = conn.await {
eprintln!("failed to serve connection: {:?}", err);
println!("{:?}: failed to serve conn: {:?}", addr, err);
}
});
}
Ok(())
}
async fn accept_http(
mut req: Request<Incoming>,
addr: String,
prefix: String,
) -> Result<Response<HttpBody>, WebSocketError> {
let uri = req.uri().clone().path().to_string();
if upgrade::is_upgrade_request(&req)
&& let Some(uri) = uri.strip_prefix(&prefix)
{
let (mut res, fut) = upgrade::upgrade(&mut req)?;
if let Some(protocols) = req.headers().get("Sec-Websocket-Protocol").and_then(|x| {
Some(
x.to_str()
.ok()?
.split(',')
.map(|x| x.trim())
.collect::<Vec<&str>>(),
)
}) && protocols.contains(&"wisp-v1")
&& (uri == "" || uri == "/")
{
tokio::spawn(async move { accept_ws(fut, addr.clone()).await });
res.headers_mut().insert(
"Sec-Websocket-Protocol",
HeaderValue::from_str("wisp-v1").unwrap(),
);
} else {
let uri = uri.strip_prefix("/").unwrap_or(uri).to_string();
tokio::spawn(async move { accept_wsproxy(fut, uri, addr.clone()).await });
}
Ok(Response::from_parts(
res.into_parts().0,
HttpBody::new(Bytes::new()),
))
} else {
println!("random request to path {:?}", uri);
Ok(Response::builder()
.status(StatusCode::OK)
.body(HttpBody::new(":3".to_string().into()))
.unwrap())
}
}
async fn handle_mux(
packet: ConnectPacket,
mut stream: MuxStream<impl ws::WebSocketWrite + Send + 'static>,
) -> Result<bool, WispError> {
let uri = format!(
"{}:{}",
packet.destination_hostname, packet.destination_port
);
match packet.stream_type {
StreamType::Tcp => {
let mut tcp_stream = TcpStream::connect(uri)
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
let mut mux_stream = stream.into_io().into_asyncrw();
tokio::io::copy_bidirectional(&mut tcp_stream, &mut mux_stream)
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
}
StreamType::Udp => {
let udp_socket = UdpSocket::bind(uri)
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
let mut data = vec![0u8; 65507]; // udp standard max datagram size
loop {
tokio::select! {
size = udp_socket.recv(&mut data).map_err(|x| WispError::Other(Box::new(x))) => {
let size = size?;
stream.write(Bytes::copy_from_slice(&data[..size])).await?
},
event = stream.read() => {
match event {
Some(event) => match event {
WsEvent::Send(data) => {
udp_socket.send(&data).await.map_err(|x| WispError::Other(Box::new(x)))?;
}
WsEvent::Close(_) => return Ok(false),
},
None => break,
}
}
}
}
}
}
Ok(true)
}
async fn accept_ws(
fut: upgrade::UpgradeFut,
addr: String,
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
let (rx, tx) = fut.await?.split(tokio::io::split);
let rx = FragmentCollectorRead::new(rx);
println!("{:?}: connected", addr);
let (mut mux, fut) = ServerMux::new(rx, tx, 128);
tokio::spawn(async move {
if let Err(e) = fut.await {
println!("err in mux: {:?}", e);
}
});
while let Some((packet, stream)) = mux.server_new_stream().await {
tokio::spawn(async move {
let close_err = stream.get_close_handle();
let close_ok = stream.get_close_handle();
let _ = handle_mux(packet, stream)
.or_else(|err| async move {
let _ = close_err.close(0x03).await;
Err(err)
})
.and_then(|should_send| async move {
if should_send {
close_ok.close(0x02).await
} else {
Ok(())
}
})
.await;
});
}
println!("{:?}: disconnected", addr);
Ok(())
}
async fn accept_wsproxy(
fut: upgrade::UpgradeFut,
incoming_uri: String,
addr: String,
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
let mut ws_stream = FragmentCollector::new(fut.await?);
println!("{:?}: connected (wsproxy): {:?}", addr, incoming_uri);
match hyper::Uri::try_from(incoming_uri.clone()) {
Ok(_) => (),
Err(err) => {
ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"invalid uri")).await?;
return Err(Box::new(err));
}
}
let tcp_stream = match TcpStream::connect(incoming_uri).await {
Ok(stream) => stream,
Err(err) => {
ws_stream
.write_frame(Frame::close(CloseCode::Away.into(), b"failed to connect"))
.await?;
return Err(Box::new(err));
}
};
let mut tcp_stream_framed = Framed::new(tcp_stream, BytesCodec::new());
loop {
tokio::select! {
event = ws_stream.read_frame() => {
match event {
Ok(frame) => {
match frame.opcode {
OpCode::Text | OpCode::Binary => {
let _ = tcp_stream_framed.send(Bytes::from(frame.payload.to_vec())).await;
}
OpCode::Close => {
// tokio closes the stream for us
drop(tcp_stream_framed);
break;
}
_ => {}
}
},
Err(_) => {
// tokio closes the stream for us
drop(tcp_stream_framed);
break;
}
}
},
event = tcp_stream_framed.next() => {
if let Some(res) = event {
match res {
Ok(buf) => {
let _ = ws_stream.write_frame(Frame::binary(Payload::Borrowed(&buf))).await;
}
Err(_) => {
let _ = ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"tcp side is going away")).await;
}
}
}
}
}
}
println!("{:?}: disconnected (wsproxy)", addr);
Ok(())
}

View file

@ -0,0 +1,16 @@
[package]
name = "simple-wisp-client"
version = "0.1.0"
edition = "2021"
[dependencies]
bytes = "1.5.0"
fastwebsockets = { version = "0.6.0", features = ["unstable-split", "upgrade"] }
futures = "0.3.30"
http-body-util = "0.1.0"
hyper = { version = "1.1.0", features = ["http1", "client"] }
tokio = { version = "1.36.0", features = ["full"] }
tokio-native-tls = "0.3.1"
tokio-util = "0.7.10"
wisp-mux = { path = "../wisp", features = ["fastwebsockets"]}

View file

@ -0,0 +1,114 @@
use bytes::Bytes;
use fastwebsockets::{handshake, FragmentCollectorRead};
use futures::io::AsyncWriteExt;
use http_body_util::Empty;
use hyper::{
header::{CONNECTION, UPGRADE},
Request,
};
use std::{error::Error, future::Future};
use tokio::net::TcpStream;
use tokio_native_tls::{native_tls, TlsConnector};
use wisp_mux::{ClientMux, StreamType};
use tokio_util::either::Either;
#[derive(Debug)]
struct StrError(String);
impl StrError {
pub fn new(str: &str) -> Self {
Self(str.to_string())
}
}
impl std::fmt::Display for StrError {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
write!(fmt, "{}", self.0)
}
}
impl Error for StrError {}
struct SpawnExecutor;
impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
where
Fut: Future + Send + 'static,
Fut::Output: Send + 'static,
{
fn execute(&self, fut: Fut) {
tokio::task::spawn(fut);
}
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let addr = std::env::args()
.nth(1)
.ok_or(StrError::new("no src addr"))?;
let addr_port: u16 = std::env::args()
.nth(2)
.ok_or(StrError::new("no src port"))?
.parse()?;
let addr_dest = std::env::args()
.nth(3)
.ok_or(StrError::new("no dest addr"))?;
let addr_dest_port: u16 = std::env::args()
.nth(4)
.ok_or(StrError::new("no dest port"))?
.parse()?;
let should_tls: bool = std::env::args()
.nth(5)
.ok_or(StrError::new("no should tls"))?
.parse()?;
let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?;
let socket = if should_tls {
let cx = TlsConnector::from(native_tls::TlsConnector::builder().build()?);
Either::Left(cx.connect(&addr, socket).await?)
} else {
Either::Right(socket)
};
let req = Request::builder()
.method("GET")
.uri(format!("wss://{}:{}/", &addr, addr_port))
.header("Host", &addr)
.header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade")
.header(
"Sec-WebSocket-Key",
fastwebsockets::handshake::generate_key(),
)
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Protocol", "wisp-v1")
.body(Empty::<Bytes>::new())?;
let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?;
let (rx, tx) = ws.split(tokio::io::split);
let rx = FragmentCollectorRead::new(rx);
let (mux, fut) = ClientMux::new(rx, tx).await?;
tokio::task::spawn(fut);
let mut hi: u64 = 0;
loop {
let mut channel = mux
.client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port)
.await?
.into_io()
.into_asyncrw();
for _ in 0..10 {
channel.write_all(b"hiiiiiiii").await?;
hi += 1;
println!("said hi {}", hi);
}
}
#[allow(unreachable_code)]
Ok(())
}

1
wisp/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/target

25
wisp/Cargo.toml Normal file
View file

@ -0,0 +1,25 @@
[package]
name = "wisp-mux"
version = "0.1.0"
edition = "2021"
[dependencies]
async_io_stream = "0.3.3"
bytes = "1.5.0"
event-listener = "5.0.0"
fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true }
futures = "0.3.30"
futures-util = "0.3.30"
hyper = { version = "1.1.0", optional = true }
hyper-util = { git = "https://github.com/r58Playz/hyper-util-wasm", features = ["client", "client-legacy"], optional = true }
pin-project-lite = "0.2.13"
tokio = { version = "1.35.1", optional = true, default-features = false }
tower-service = { version = "0.3.2", optional = true }
ws_stream_wasm = { version = "0.7.4", optional = true }
[features]
fastwebsockets = ["dep:fastwebsockets", "dep:tokio"]
ws_stream_wasm = ["dep:ws_stream_wasm"]
tokio_io = ["async_io_stream/tokio_io"]
hyper_tower = ["dep:tower-service", "dep:hyper", "dep:tokio", "dep:hyper-util"]

View file

@ -0,0 +1,72 @@
use bytes::Bytes;
use fastwebsockets::{
FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite,
};
use tokio::io::{AsyncRead, AsyncWrite};
impl From<OpCode> for crate::ws::OpCode {
fn from(opcode: OpCode) -> Self {
use OpCode::*;
match opcode {
Continuation => unreachable!(),
Text => Self::Text,
Binary => Self::Binary,
Close => Self::Close,
Ping => Self::Ping,
Pong => Self::Pong,
}
}
}
impl From<Frame<'_>> for crate::ws::Frame {
fn from(mut frame: Frame) -> Self {
Self {
finished: frame.fin,
opcode: frame.opcode.into(),
payload: Bytes::copy_from_slice(frame.payload.to_mut()),
}
}
}
impl TryFrom<crate::ws::Frame> for Frame<'_> {
type Error = crate::WispError;
fn try_from(frame: crate::ws::Frame) -> Result<Self, Self::Error> {
use crate::ws::OpCode::*;
Ok(match frame.opcode {
Text => Self::text(Payload::Owned(frame.payload.to_vec())),
Binary => Self::binary(Payload::Owned(frame.payload.to_vec())),
Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())),
Ping => Self::new(
true,
OpCode::Ping,
None,
Payload::Owned(frame.payload.to_vec()),
),
Pong => Self::pong(Payload::Owned(frame.payload.to_vec())),
})
}
}
impl From<WebSocketError> for crate::WispError {
fn from(err: WebSocketError) -> Self {
Self::WsImplError(Box::new(err))
}
}
impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
async fn wisp_read_frame(
&mut self,
tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite + Send>,
) -> Result<crate::ws::Frame, crate::WispError> {
Ok(self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
.await?
.into())
}
}
impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
self.write_frame(frame.try_into()?).await.map_err(|e| e.into())
}
}

390
wisp/src/lib.rs Normal file
View file

@ -0,0 +1,390 @@
#![feature(impl_trait_in_assoc_type)]
#[cfg(feature = "fastwebsockets")]
mod fastwebsockets;
mod packet;
mod stream;
#[cfg(feature = "hyper_tower")]
pub mod tokioio;
#[cfg(feature = "hyper_tower")]
pub mod tower;
pub mod ws;
#[cfg(feature = "ws_stream_wasm")]
mod ws_stream_wasm;
pub use crate::packet::*;
pub use crate::stream::*;
use event_listener::Event;
use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt};
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
},
};
#[derive(Debug, PartialEq, Copy, Clone)]
pub enum Role {
Client,
Server,
}
#[derive(Debug)]
pub enum WispError {
PacketTooSmall,
InvalidPacketType,
InvalidStreamType,
InvalidStreamId,
InvalidUri,
UriHasNoHost,
UriHasNoPort,
MaxStreamCountReached,
StreamAlreadyClosed,
WsFrameInvalidType,
WsFrameNotFinished,
WsImplError(Box<dyn std::error::Error + Sync + Send>),
WsImplSocketClosed,
WsImplNotSupported,
Utf8Error(std::str::Utf8Error),
Other(Box<dyn std::error::Error + Sync + Send>),
}
impl From<std::str::Utf8Error> for WispError {
fn from(err: std::str::Utf8Error) -> WispError {
WispError::Utf8Error(err)
}
}
impl std::fmt::Display for WispError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
use WispError::*;
match self {
PacketTooSmall => write!(f, "Packet too small"),
InvalidPacketType => write!(f, "Invalid packet type"),
InvalidStreamType => write!(f, "Invalid stream type"),
InvalidStreamId => write!(f, "Invalid stream id"),
InvalidUri => write!(f, "Invalid URI"),
UriHasNoHost => write!(f, "URI has no host"),
UriHasNoPort => write!(f, "URI has no port"),
MaxStreamCountReached => write!(f, "Maximum stream count reached"),
StreamAlreadyClosed => write!(f, "Stream already closed"),
WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err),
WsImplSocketClosed => write!(f, "Websocket implementation error: websocket closed"),
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"),
Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err),
Other(err) => write!(f, "Other error: {:?}", err),
}
}
}
impl std::error::Error for WispError {}
struct ServerMuxInner<W>
where
W: ws::WebSocketWrite + Send + 'static,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>,
close_tx: mpsc::UnboundedSender<MuxEvent>,
}
impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
pub async fn into_future<R>(
self,
rx: R,
close_rx: mpsc::UnboundedReceiver<MuxEvent>,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>,
buffer_size: u32
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
let ret = futures::select! {
x = self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()).fuse() => x,
x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x
};
self.stream_map.lock().await.iter().for_each(|x| {
let _ = x.1.unbounded_send(WsEvent::Close(ClosePacket::new(0x01)));
});
ret
}
async fn server_close_loop(
&self,
mut close_rx: mpsc::UnboundedReceiver<MuxEvent>,
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>,
tx: ws::LockedWebSocketWrite<W>,
) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await {
match msg {
MuxEvent::Close(stream_id, reason, channel) => {
if stream_map.lock().await.remove(&stream_id).is_some() {
let _ = channel.send(
tx.write_frame(Packet::new_close(stream_id, reason).into())
.await,
);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
}
}
Ok(())
}
async fn server_msg_loop<R>(
&self,
mut rx: R,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>,
buffer_size: u32,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
self.tx
.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
while let Ok(frame) = rx.wisp_read_frame(&self.tx).await {
if let Ok(packet) = Packet::try_from(frame) {
use PacketType::*;
match packet.packet {
Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::unbounded();
self.stream_map.lock().await.insert(packet.stream_id, ch_tx);
muxstream_sender
.unbounded_send((
inner_packet,
MuxStream::new(
packet.stream_id,
Role::Server,
ch_rx,
self.tx.clone(),
self.close_tx.clone(),
AtomicBool::new(false).into(),
AtomicU32::new(buffer_size).into(),
Event::new().into(),
),
))
.map_err(|x| WispError::Other(Box::new(x)))?;
}
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Send(data));
}
}
Continue(_) => unreachable!(),
Close(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Close(inner_packet));
}
self.stream_map.lock().await.remove(&packet.stream_id);
}
}
} else {
break;
}
}
drop(muxstream_sender);
Ok(())
}
}
pub struct ServerMux<W>
where
W: ws::WebSocketWrite + Send + 'static,
{
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream<W>)>,
}
impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
pub fn new<R>(read: R, write: W, buffer_size: u32) -> (Self, impl Future<Output = Result<(), WispError>>)
where
R: ws::WebSocketRead,
{
let (close_tx, close_rx) = mpsc::unbounded::<MuxEvent>();
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>();
let write = ws::LockedWebSocketWrite::new(write);
let map = Arc::new(Mutex::new(HashMap::new()));
(
Self { muxstream_recv: rx },
ServerMuxInner {
tx: write,
close_tx,
stream_map: map.clone(),
}
.into_future(read, close_rx, tx, buffer_size),
)
}
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream<W>)> {
self.muxstream_recv.next().await
}
}
pub struct ClientMuxInner<W>
where
W: ws::WebSocketWrite,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map:
Arc<Mutex<HashMap<u32, (mpsc::UnboundedSender<WsEvent>, Arc<AtomicU32>, Arc<Event>)>>>,
}
impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
pub async fn into_future<R>(
self,
rx: R,
close_rx: mpsc::UnboundedReceiver<MuxEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
futures::select! {
x = self.client_bg_loop(close_rx).fuse() => x,
x = self.client_loop(rx).fuse() => x
}
}
async fn client_bg_loop(
&self,
mut close_rx: mpsc::UnboundedReceiver<MuxEvent>,
) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await {
match msg {
MuxEvent::Close(stream_id, reason, channel) => {
if self.stream_map.lock().await.remove(&stream_id).is_some() {
let _ = channel.send(
self.tx
.write_frame(Packet::new_close(stream_id, reason).into())
.await,
);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
}
}
Ok(())
}
async fn client_loop<R>(&self, mut rx: R) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
while let Ok(frame) = rx.wisp_read_frame(&self.tx).await {
if let Ok(packet) = Packet::try_from(frame) {
use PacketType::*;
match packet.packet {
Connect(_) => unreachable!(),
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.0.unbounded_send(WsEvent::Send(data));
}
}
Continue(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
stream
.1
.store(inner_packet.buffer_remaining, Ordering::Release);
let _ = stream.2.notify(u32::MAX);
}
}
Close(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.0.unbounded_send(WsEvent::Close(inner_packet));
}
self.stream_map.lock().await.remove(&packet.stream_id);
}
}
}
}
Ok(())
}
}
pub struct ClientMux<W>
where
W: ws::WebSocketWrite,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map:
Arc<Mutex<HashMap<u32, (mpsc::UnboundedSender<WsEvent>, Arc<AtomicU32>, Arc<Event>)>>>,
next_free_stream_id: AtomicU32,
close_tx: mpsc::UnboundedSender<MuxEvent>,
buf_size: u32,
}
impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
pub async fn new<R>(
mut read: R,
write: W,
) -> Result<(Self, impl Future<Output = Result<(), WispError>>), WispError>
where
R: ws::WebSocketRead,
{
let write = ws::LockedWebSocketWrite::new(write);
let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?;
if first_packet.stream_id != 0 {
return Err(WispError::InvalidStreamId);
}
if let PacketType::Continue(packet) = first_packet.packet {
let (tx, rx) = mpsc::unbounded::<MuxEvent>();
let map = Arc::new(Mutex::new(HashMap::new()));
Ok((
Self {
tx: write.clone(),
stream_map: map.clone(),
next_free_stream_id: AtomicU32::new(1),
close_tx: tx,
buf_size: packet.buffer_remaining,
},
ClientMuxInner {
tx: write.clone(),
stream_map: map.clone(),
}
.into_future(read, rx),
))
} else {
Err(WispError::InvalidPacketType)
}
}
pub async fn client_new_stream(
&self,
stream_type: StreamType,
host: String,
port: u16,
) -> Result<MuxStream<W>, WispError> {
let (ch_tx, ch_rx) = mpsc::unbounded();
let evt: Arc<Event> = Event::new().into();
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buf_size).into();
let stream_id = self.next_free_stream_id.load(Ordering::Acquire);
self.tx
.write_frame(Packet::new_connect(stream_id, stream_type, port, host).into())
.await?;
self.next_free_stream_id.store(
stream_id
.checked_add(1)
.ok_or(WispError::MaxStreamCountReached)?,
Ordering::Release,
);
self.stream_map
.lock()
.await
.insert(stream_id, (ch_tx, flow_control.clone(), evt.clone()));
Ok(MuxStream::new(
stream_id,
Role::Client,
ch_rx,
self.tx.clone(),
self.close_tx.clone(),
AtomicBool::new(false).into(),
flow_control,
evt,
))
}
}

255
wisp/src/packet.rs Normal file
View file

@ -0,0 +1,255 @@
use crate::ws;
use crate::WispError;
use bytes::{Buf, BufMut, Bytes};
#[derive(Debug)]
pub enum StreamType {
Tcp = 0x01,
Udp = 0x02,
}
impl TryFrom<u8> for StreamType {
type Error = WispError;
fn try_from(stream_type: u8) -> Result<Self, Self::Error> {
use StreamType::*;
match stream_type {
0x01 => Ok(Tcp),
0x02 => Ok(Udp),
_ => Err(Self::Error::InvalidStreamType),
}
}
}
#[derive(Debug)]
pub struct ConnectPacket {
pub stream_type: StreamType,
pub destination_port: u16,
pub destination_hostname: String,
}
impl ConnectPacket {
pub fn new(stream_type: StreamType, destination_port: u16, destination_hostname: String) -> Self {
Self {
stream_type,
destination_port,
destination_hostname,
}
}
}
impl TryFrom<Bytes> for ConnectPacket {
type Error = WispError;
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
if bytes.remaining() < (1 + 2) {
return Err(Self::Error::PacketTooSmall);
}
Ok(Self {
stream_type: bytes.get_u8().try_into()?,
destination_port: bytes.get_u16_le(),
destination_hostname: std::str::from_utf8(&bytes)?.to_string(),
})
}
}
impl From<ConnectPacket> for Vec<u8> {
fn from(packet: ConnectPacket) -> Self {
let mut encoded = Self::with_capacity(1 + 2 + packet.destination_hostname.len());
encoded.put_u8(packet.stream_type as u8);
encoded.put_u16_le(packet.destination_port);
encoded.extend(packet.destination_hostname.bytes());
encoded
}
}
#[derive(Debug)]
pub struct ContinuePacket {
pub buffer_remaining: u32,
}
impl ContinuePacket {
pub fn new(buffer_remaining: u32) -> Self {
Self { buffer_remaining }
}
}
impl TryFrom<Bytes> for ContinuePacket {
type Error = WispError;
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
if bytes.remaining() < 4 {
return Err(Self::Error::PacketTooSmall);
}
Ok(Self {
buffer_remaining: bytes.get_u32_le(),
})
}
}
impl From<ContinuePacket> for Vec<u8> {
fn from(packet: ContinuePacket) -> Self {
let mut encoded = Self::with_capacity(4);
encoded.put_u32_le(packet.buffer_remaining);
encoded
}
}
#[derive(Debug)]
pub struct ClosePacket {
pub reason: u8,
}
impl ClosePacket {
pub fn new(reason: u8) -> Self {
Self { reason }
}
}
impl TryFrom<Bytes> for ClosePacket {
type Error = WispError;
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
if bytes.remaining() < 1 {
return Err(Self::Error::PacketTooSmall);
}
Ok(Self {
reason: bytes.get_u8(),
})
}
}
impl From<ClosePacket> for Vec<u8> {
fn from(packet: ClosePacket) -> Self {
let mut encoded = Self::with_capacity(1);
encoded.put_u8(packet.reason);
encoded
}
}
#[derive(Debug)]
pub enum PacketType {
Connect(ConnectPacket),
Data(Bytes),
Continue(ContinuePacket),
Close(ClosePacket),
}
impl PacketType {
pub fn as_u8(&self) -> u8 {
use PacketType::*;
match self {
Connect(_) => 0x01,
Data(_) => 0x02,
Continue(_) => 0x03,
Close(_) => 0x04,
}
}
}
impl From<PacketType> for Vec<u8> {
fn from(packet: PacketType) -> Self {
use PacketType::*;
match packet {
Connect(x) => x.into(),
Data(x) => x.to_vec(),
Continue(x) => x.into(),
Close(x) => x.into(),
}
}
}
#[derive(Debug)]
pub struct Packet {
pub stream_id: u32,
pub packet: PacketType,
}
impl Packet {
pub fn new(stream_id: u32, packet: PacketType) -> Self {
Self { stream_id, packet }
}
pub fn new_connect(
stream_id: u32,
stream_type: StreamType,
destination_port: u16,
destination_hostname: String,
) -> Self {
Self {
stream_id,
packet: PacketType::Connect(ConnectPacket::new(
stream_type,
destination_port,
destination_hostname,
)),
}
}
pub fn new_data(stream_id: u32, data: Bytes) -> Self {
Self {
stream_id,
packet: PacketType::Data(data),
}
}
pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self {
Self {
stream_id,
packet: PacketType::Continue(ContinuePacket::new(buffer_remaining)),
}
}
pub fn new_close(stream_id: u32, reason: u8) -> Self {
Self {
stream_id,
packet: PacketType::Close(ClosePacket::new(reason)),
}
}
}
impl TryFrom<Bytes> for Packet {
type Error = WispError;
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
if bytes.remaining() < 5 {
return Err(Self::Error::PacketTooSmall);
}
let packet_type = bytes.get_u8();
use PacketType::*;
Ok(Self {
stream_id: bytes.get_u32_le(),
packet: match packet_type {
0x01 => Connect(ConnectPacket::try_from(bytes)?),
0x02 => Data(bytes),
0x03 => Continue(ContinuePacket::try_from(bytes)?),
0x04 => Close(ClosePacket::try_from(bytes)?),
_ => return Err(Self::Error::InvalidPacketType),
},
})
}
}
impl From<Packet> for Vec<u8> {
fn from(packet: Packet) -> Self {
let mut encoded = Self::with_capacity(1 + 4);
encoded.push(packet.packet.as_u8());
encoded.put_u32_le(packet.stream_id);
encoded.extend(Vec::<u8>::from(packet.packet));
encoded
}
}
impl TryFrom<ws::Frame> for Packet {
type Error = WispError;
fn try_from(frame: ws::Frame) -> Result<Self, Self::Error> {
if !frame.finished {
return Err(Self::Error::WsFrameNotFinished);
}
if frame.opcode != ws::OpCode::Binary {
return Err(Self::Error::WsFrameInvalidType);
}
frame.payload.try_into()
}
}
impl From<Packet> for ws::Frame {
fn from(packet: Packet) -> Self {
Self::binary(Vec::<u8>::from(packet).into())
}
}

298
wisp/src/stream.rs Normal file
View file

@ -0,0 +1,298 @@
use async_io_stream::IoStream;
use bytes::Bytes;
use event_listener::Event;
use futures::{
channel::{mpsc, oneshot},
sink, stream,
task::{Context, Poll},
Sink, Stream, StreamExt,
};
use pin_project_lite::pin_project;
use std::{
pin::Pin,
sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
},
};
pub enum WsEvent {
Send(Bytes),
Close(crate::ClosePacket),
}
pub enum MuxEvent {
Close(u32, u8, oneshot::Sender<Result<(), crate::WispError>>),
}
pub struct MuxStreamRead<W>
where
W: crate::ws::WebSocketWrite,
{
pub stream_id: u32,
role: crate::Role,
tx: crate::ws::LockedWebSocketWrite<W>,
rx: mpsc::UnboundedReceiver<WsEvent>,
is_closed: Arc<AtomicBool>,
flow_control: Arc<AtomicU32>,
}
impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamRead<W> {
pub async fn read(&mut self) -> Option<WsEvent> {
if self.is_closed.load(Ordering::Acquire) {
return None;
}
match self.rx.next().await? {
WsEvent::Send(bytes) => {
if self.role == crate::Role::Server {
let old_val = self.flow_control.fetch_add(1, Ordering::SeqCst);
self.tx
.write_frame(
crate::Packet::new_continue(self.stream_id, old_val + 1).into(),
)
.await
.ok()?;
}
Some(WsEvent::Send(bytes))
}
WsEvent::Close(packet) => {
self.is_closed.store(true, Ordering::Release);
Some(WsEvent::Close(packet))
}
}
}
pub(crate) fn into_stream(self) -> Pin<Box<dyn Stream<Item = Bytes> + Send>> {
Box::pin(stream::unfold(self, |mut rx| async move {
let evt = rx.read().await?;
Some((
match evt {
WsEvent::Send(bytes) => bytes,
WsEvent::Close(_) => return None,
},
rx,
))
}))
}
}
pub struct MuxStreamWrite<W>
where
W: crate::ws::WebSocketWrite,
{
pub stream_id: u32,
role: crate::Role,
tx: crate::ws::LockedWebSocketWrite<W>,
close_channel: mpsc::UnboundedSender<MuxEvent>,
is_closed: Arc<AtomicBool>,
continue_recieved: Arc<Event>,
flow_control: Arc<AtomicU32>,
}
impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamWrite<W> {
pub async fn write(&self, data: Bytes) -> Result<(), crate::WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(crate::WispError::StreamAlreadyClosed);
}
if self.role == crate::Role::Client && self.flow_control.load(Ordering::Acquire) <= 0 {
self.continue_recieved.listen().await;
}
self.tx
.write_frame(crate::Packet::new_data(self.stream_id, data).into())
.await?;
if self.role == crate::Role::Client {
self.flow_control.store(
self.flow_control
.load(Ordering::Acquire)
.checked_add(1)
.unwrap_or(0),
Ordering::Release,
);
}
Ok(())
}
pub fn get_close_handle(&self) -> MuxStreamCloser {
MuxStreamCloser {
stream_id: self.stream_id,
close_channel: self.close_channel.clone(),
is_closed: self.is_closed.clone(),
}
}
pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(crate::WispError::StreamAlreadyClosed);
}
let (tx, rx) = oneshot::channel::<Result<(), crate::WispError>>();
self.close_channel
.unbounded_send(MuxEvent::Close(self.stream_id, reason, tx))
.map_err(|x| crate::WispError::Other(Box::new(x)))?;
rx.await
.map_err(|x| crate::WispError::Other(Box::new(x)))??;
self.is_closed.store(true, Ordering::Release);
Ok(())
}
pub(crate) fn into_sink(self) -> Pin<Box<dyn Sink<Bytes, Error = crate::WispError> + Send>> {
Box::pin(sink::unfold(self, |tx, data| async move {
tx.write(data).await?;
Ok(tx)
}))
}
}
impl<W: crate::ws::WebSocketWrite> Drop for MuxStreamWrite<W> {
fn drop(&mut self) {
let (tx, _) = oneshot::channel::<Result<(), crate::WispError>>();
let _ = self
.close_channel
.unbounded_send(MuxEvent::Close(self.stream_id, 0x01, tx));
}
}
pub struct MuxStream<W>
where
W: crate::ws::WebSocketWrite,
{
pub stream_id: u32,
rx: MuxStreamRead<W>,
tx: MuxStreamWrite<W>,
}
impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStream<W> {
pub(crate) fn new(
stream_id: u32,
role: crate::Role,
rx: mpsc::UnboundedReceiver<WsEvent>,
tx: crate::ws::LockedWebSocketWrite<W>,
close_channel: mpsc::UnboundedSender<MuxEvent>,
is_closed: Arc<AtomicBool>,
flow_control: Arc<AtomicU32>,
continue_recieved: Arc<Event>
) -> Self {
Self {
stream_id,
rx: MuxStreamRead {
stream_id,
role,
tx: tx.clone(),
rx,
is_closed: is_closed.clone(),
flow_control: flow_control.clone(),
},
tx: MuxStreamWrite {
stream_id,
role,
tx,
close_channel,
is_closed: is_closed.clone(),
flow_control: flow_control.clone(),
continue_recieved: continue_recieved.clone(),
},
}
}
pub async fn read(&mut self) -> Option<WsEvent> {
self.rx.read().await
}
pub async fn write(&self, data: Bytes) -> Result<(), crate::WispError> {
self.tx.write(data).await
}
pub fn get_close_handle(&self) -> MuxStreamCloser {
self.tx.get_close_handle()
}
pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> {
self.tx.close(reason).await
}
pub fn into_split(self) -> (MuxStreamRead<W>, MuxStreamWrite<W>) {
(self.rx, self.tx)
}
pub fn into_io(self) -> MuxStreamIo {
MuxStreamIo {
rx: self.rx.into_stream(),
tx: self.tx.into_sink(),
}
}
}
pub struct MuxStreamCloser {
stream_id: u32,
close_channel: mpsc::UnboundedSender<MuxEvent>,
is_closed: Arc<AtomicBool>,
}
impl MuxStreamCloser {
pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(crate::WispError::StreamAlreadyClosed);
}
let (tx, rx) = oneshot::channel::<Result<(), crate::WispError>>();
self.close_channel
.unbounded_send(MuxEvent::Close(self.stream_id, reason, tx))
.map_err(|x| crate::WispError::Other(Box::new(x)))?;
rx.await
.map_err(|x| crate::WispError::Other(Box::new(x)))??;
self.is_closed.store(true, Ordering::Release);
Ok(())
}
}
pin_project! {
pub struct MuxStreamIo {
#[pin]
rx: Pin<Box<dyn Stream<Item = Bytes> + Send>>,
#[pin]
tx: Pin<Box<dyn Sink<Bytes, Error = crate::WispError> + Send>>,
}
}
impl MuxStreamIo {
pub fn into_asyncrw(self) -> IoStream<MuxStreamIo, Vec<u8>> {
IoStream::new(self)
}
}
impl Stream for MuxStreamIo {
type Item = Result<Vec<u8>, std::io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project()
.rx
.poll_next(cx)
.map(|x| x.map(|x| Ok(x.to_vec())))
}
}
impl Sink<Vec<u8>> for MuxStreamIo {
type Error = std::io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_ready(cx)
.map_err(std::io::Error::other)
}
fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
self.project()
.tx
.start_send(item.into())
.map_err(std::io::Error::other)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_flush(cx)
.map_err(std::io::Error::other)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_close(cx)
.map_err(std::io::Error::other)
}
}

View file

@ -1,7 +1,6 @@
#![allow(dead_code)]
// Taken from https://github.com/hyperium/hyper-util/blob/master/src/rt/tokio.rs
// hyper-util fails to compile on WASM as it has a dependency on socket2, but I only need
// hyper-util for TokioIo.
// hyper-util fails to compile on WASM as it has a dependency on socket2
use std::{
pin::Pin,
@ -169,3 +168,9 @@ where
hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
}
}
impl<T> hyper_util::client::legacy::connect::Connection for TokioIo<T> {
fn connected(&self) -> hyper_util::client::legacy::connect::Connected {
hyper_util::client::legacy::connect::Connected::new()
}
}

41
wisp/src/tower.rs Normal file
View file

@ -0,0 +1,41 @@
use crate::{tokioio::TokioIo, ws::WebSocketWrite, ClientMux, MuxStreamIo, StreamType, WispError};
use async_io_stream::IoStream;
use futures::{
task::{Context, Poll},
Future,
};
use std::sync::Arc;
pub struct ServiceWrapper<W: WebSocketWrite + Send + 'static>(pub Arc<ClientMux<W>>);
impl<W: WebSocketWrite + Send + 'static> tower_service::Service<hyper::Uri> for ServiceWrapper<W> {
type Response = TokioIo<IoStream<MuxStreamIo, Vec<u8>>>;
type Error = WispError;
type Future = impl Future<Output = Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: hyper::Uri) -> Self::Future {
let mux = self.0.clone();
async move {
Ok(TokioIo::new(
mux.client_new_stream(
StreamType::Tcp,
req.host().ok_or(WispError::UriHasNoHost)?.to_string(),
req.port().ok_or(WispError::UriHasNoPort)?.into(),
)
.await?
.into_io()
.into_asyncrw(),
))
}
}
}
impl<W: WebSocketWrite + Send + 'static> Clone for ServiceWrapper<W> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

76
wisp/src/ws.rs Normal file
View file

@ -0,0 +1,76 @@
use bytes::Bytes;
use futures::lock::Mutex;
use std::sync::Arc;
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum OpCode {
Text,
Binary,
Close,
Ping,
Pong,
}
pub struct Frame {
pub finished: bool,
pub opcode: OpCode,
pub payload: Bytes,
}
impl Frame {
pub fn text(payload: Bytes) -> Self {
Self {
finished: true,
opcode: OpCode::Text,
payload,
}
}
pub fn binary(payload: Bytes) -> Self {
Self {
finished: true,
opcode: OpCode::Binary,
payload,
}
}
pub fn close(payload: Bytes) -> Self {
Self {
finished: true,
opcode: OpCode::Close,
payload,
}
}
}
pub trait WebSocketRead {
fn wisp_read_frame(
&mut self,
tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite + Send>,
) -> impl std::future::Future<Output = Result<Frame, crate::WispError>> + Send;
}
pub trait WebSocketWrite {
fn wisp_write_frame(
&mut self,
frame: Frame,
) -> impl std::future::Future<Output = Result<(), crate::WispError>> + Send;
}
pub struct LockedWebSocketWrite<S>(Arc<Mutex<S>>);
impl<S: WebSocketWrite + Send> LockedWebSocketWrite<S> {
pub fn new(ws: S) -> Self {
Self(Arc::new(Mutex::new(ws)))
}
pub async fn write_frame(&self, frame: Frame) -> Result<(), crate::WispError> {
self.0.lock().await.wisp_write_frame(frame).await
}
}
impl<S: WebSocketWrite> Clone for LockedWebSocketWrite<S> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

View file

@ -0,0 +1,60 @@
use futures::{stream::{SplitStream, SplitSink}, SinkExt, StreamExt};
use ws_stream_wasm::{WsErr, WsMessage, WsStream};
impl From<WsMessage> for crate::ws::Frame {
fn from(msg: WsMessage) -> Self {
use crate::ws::OpCode;
match msg {
WsMessage::Text(str) => Self {
finished: true,
opcode: OpCode::Text,
payload: str.into(),
},
WsMessage::Binary(bin) => Self {
finished: true,
opcode: OpCode::Binary,
payload: bin.into(),
},
}
}
}
impl TryFrom<crate::ws::Frame> for WsMessage {
type Error = crate::WispError;
fn try_from(msg: crate::ws::Frame) -> Result<Self, Self::Error> {
use crate::ws::OpCode;
match msg.opcode {
OpCode::Text => Ok(Self::Text(std::str::from_utf8(&msg.payload)?.to_string())),
OpCode::Binary => Ok(Self::Binary(msg.payload.to_vec())),
_ => Err(Self::Error::WsImplNotSupported),
}
}
}
impl From<WsErr> for crate::WispError {
fn from(err: WsErr) -> Self {
Self::WsImplError(Box::new(err))
}
}
impl crate::ws::WebSocketRead for SplitStream<WsStream> {
async fn wisp_read_frame(
&mut self,
_: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
) -> Result<crate::ws::Frame, crate::WispError> {
Ok(self
.next()
.await
.ok_or(crate::WispError::WsImplSocketClosed)?
.into())
}
}
impl crate::ws::WebSocketWrite for SplitSink<WsStream, WsMessage> {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
self
.send(frame.try_into()?)
.await
.map_err(|e| e.into())
}
}