mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 14:00:01 -04:00
commit
21d847fb56
29 changed files with 2701 additions and 746 deletions
11
.gitignore
vendored
11
.gitignore
vendored
|
@ -1,9 +1,12 @@
|
|||
/target
|
||||
server/src/*.pem
|
||||
**/*.pem
|
||||
client/pkg
|
||||
client/out
|
||||
.direnv
|
||||
client/index.js
|
||||
client/module.js
|
||||
client/module.d.ts
|
||||
client/epoxy-bundled.js
|
||||
client/epoxy-module-bundled.js
|
||||
client/epoxy-module-bundled.d.ts
|
||||
client/epoxy.js
|
||||
client/epoxy.d.ts
|
||||
client/epoxy.wasm
|
||||
pnpm-lock.yaml
|
||||
|
|
764
Cargo.lock
generated
764
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -1,6 +1,11 @@
|
|||
[workspace]
|
||||
resolver = "2"
|
||||
members = ["server", "client"]
|
||||
members = ["server", "client", "wisp", "simple-wisp-client"]
|
||||
|
||||
[patch.crates-io]
|
||||
rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" }
|
||||
|
||||
[profile.release]
|
||||
lto = true
|
||||
opt-level = 'z'
|
||||
codegen-units = 1
|
||||
|
|
46
README.md
46
README.md
|
@ -1,29 +1,55 @@
|
|||
# epoxy
|
||||
Epoxy is an encrypted proxy for browser javascript. It allows you to make requests that bypass cors without compromising security, by running SSL/TLS inside webassembly.
|
||||
|
||||
Simple usage example for making a secure GET request to httpbin.org:
|
||||
## Using the client
|
||||
Epoxy must be run from within a web worker and must be served with the [security headers needed for `SharedArrayBuffer`](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/SharedArrayBuffer#security_requirements). Here is a simple usage example:
|
||||
```javascript
|
||||
import epoxy from "@mercuryworkshop/epoxy-tls";
|
||||
importScripts("epoxy-bundled.js");
|
||||
|
||||
const { EpoxyClient } = await epoxy();
|
||||
let client = await new EpoxyClient("wss://localhost:4000", navigator.userAgent, 10);
|
||||
|
||||
let response = await client.fetch("https://httpbin.org/get");
|
||||
await response.text();
|
||||
|
||||
```
|
||||
|
||||
Epoxy also allows you to make arbitrary end to end encrypted TCP connections safely directly from the browser.
|
||||
## Using the server
|
||||
```
|
||||
$ cargo r -r --bin epoxy-server -- --help
|
||||
Implementation of the Wisp protocol in Rust, made for epoxy.
|
||||
|
||||
Usage: epoxy-server [OPTIONS] --pubkey <PUBKEY> --privkey <PRIVKEY>
|
||||
|
||||
Options:
|
||||
--prefix <PREFIX> [default: ]
|
||||
-l, --port <PORT> [default: 4000]
|
||||
-p, --pubkey <PUBKEY>
|
||||
-P, --privkey <PRIVKEY>
|
||||
-h, --help Print help
|
||||
-V, --version Print version
|
||||
```
|
||||
|
||||
## Building
|
||||
Rust nightly is required.
|
||||
|
||||
### Server
|
||||
|
||||
1. Generate certs with `mkcert` and place the public certificate in `./server/src/pem.pem` and private certificate in `./server/src/key.pem`
|
||||
2. Run `cargo r --bin epoxy-server`, optionally with `-r` flag for release
|
||||
```
|
||||
cargo b -r --bin epoxy-server
|
||||
```
|
||||
The executable will be placed at `target/release/epoxy-server`.
|
||||
|
||||
### Client
|
||||
Note: Building the client is only supported on linux
|
||||
> [!IMPORTANT]
|
||||
> Building the client is only supported on Linux.
|
||||
|
||||
1. Make sure you have the `wasm32-unknown-unknown` target installed, `wasm-bindgen` and `wasm-opt` executables installed, and `bash`, `python3` packages (`python3` is used for `http.server` module)
|
||||
2. Run `pnpm build`
|
||||
Make sure you have the `wasm32-unknown-unknown` rust target, the `rust-std` component, and the `wasm-bindgen`, `wasm-opt`, and `base64` binaries installed.
|
||||
|
||||
In the `client` directory:
|
||||
```
|
||||
bash build.sh
|
||||
```
|
||||
|
||||
To host a local server with the required headers:
|
||||
```
|
||||
python3 serve.py
|
||||
```
|
||||
|
|
|
@ -6,18 +6,12 @@ edition = "2021"
|
|||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[features]
|
||||
default = ["console_error_panic_hook"]
|
||||
|
||||
[dependencies]
|
||||
bytes = "1.5.0"
|
||||
console_error_panic_hook = { version = "0.1.7", optional = true }
|
||||
http = "1.0.0"
|
||||
http-body-util = "0.1.0"
|
||||
hyper = { version = "1.1.0", features = ["client", "http1"] }
|
||||
hyper = { version = "1.1.0", features = ["client", "http1", "http2"] }
|
||||
pin-project-lite = "0.2.13"
|
||||
penguin-mux-wasm = { git = "https://github.com/r58Playz/penguin-mux-wasm" }
|
||||
tokio = { version = "1.35.1", default_features = false }
|
||||
wasm-bindgen = "0.2"
|
||||
wasm-bindgen-futures = "0.4.39"
|
||||
ws_stream_wasm = { version = "0.7.4", features = ["tokio_io"] }
|
||||
|
@ -25,17 +19,19 @@ futures-util = "0.3.30"
|
|||
js-sys = "0.3.66"
|
||||
webpki-roots = "0.26.0"
|
||||
tokio-rustls = "0.25.0"
|
||||
web-sys = { version = "0.3.66", features = ["TextEncoder", "Navigator", "Response", "ResponseInit"] }
|
||||
web-sys = { version = "0.3.66", features = ["TextEncoder", "Response", "ResponseInit"] }
|
||||
wasm-streams = "0.4.0"
|
||||
either = "1.9.0"
|
||||
tokio-util = { version = "0.7.10", features = ["io"] }
|
||||
async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] }
|
||||
fastwebsockets = { version = "0.6.0", features = ["simdutf8", "unstable-split"] }
|
||||
rand = "0.8.5"
|
||||
fastwebsockets = { version = "0.6.0", features = ["unstable-split"] }
|
||||
base64 = "0.21.7"
|
||||
|
||||
[dependencies.getrandom]
|
||||
features = ["js"]
|
||||
wisp-mux = { path = "../wisp", features = ["ws_stream_wasm", "tokio_io", "hyper_tower"] }
|
||||
async_io_stream = { version = "0.3.3", features = ["tokio_io"] }
|
||||
getrandom = { version = "0.2.12", features = ["js"] }
|
||||
hyper-util = { git = "https://github.com/r58Playz/hyper-util-wasm", features = ["client", "client-legacy", "http1", "http2"] }
|
||||
tokio = { version = "1.36.0", default-features = false }
|
||||
tower-service = "0.3.2"
|
||||
console_error_panic_hook = "0.1.7"
|
||||
|
||||
[dependencies.ring]
|
||||
features = ["wasm32_unknown_unknown_js"]
|
||||
|
|
|
@ -5,27 +5,33 @@ shopt -s inherit_errexit
|
|||
rm -rf out/ || true
|
||||
mkdir out/
|
||||
|
||||
cargo build --target wasm32-unknown-unknown --release
|
||||
echo "[ws] built rust"
|
||||
RUSTFLAGS='-C target-feature=+atomics,+bulk-memory' cargo build --target wasm32-unknown-unknown -Z build-std=panic_abort,std --release
|
||||
echo "[ws] cargo finished"
|
||||
wasm-bindgen --weak-refs --target no-modules --no-modules-global epoxy --out-dir out/ ../target/wasm32-unknown-unknown/release/epoxy_client.wasm
|
||||
echo "[ws] bindgen finished"
|
||||
echo "[ws] wasm-bindgen finished"
|
||||
|
||||
mv out/epoxy_client_bg.wasm out/epoxy_client_unoptimized.wasm
|
||||
time wasm-opt -O4 out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm
|
||||
echo "[ws] optimized"
|
||||
time wasm-opt -Oz --vacuum --dce --enable-threads --enable-bulk-memory --enable-simd out/epoxy_client_unoptimized.wasm -o out/epoxy_client_bg.wasm
|
||||
echo "[ws] wasm-opt finished"
|
||||
|
||||
AUTOGENERATED_SOURCE=$(<"out/epoxy_client.js")
|
||||
# patch for websocket sharedarraybuffer error
|
||||
AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//getObject(arg0).send(getArrayU8FromWasm0(arg1, arg2)/getObject(arg0).send(new Uint8Array(getArrayU8FromWasm0(arg1, arg2))}
|
||||
WASM_BASE64=$(base64 -w0 out/epoxy_client_bg.wasm)
|
||||
AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//__wbg_init(input) \{/__wbg_init() \{let input=\'data:application/wasm;base64,$WASM_BASE64\'}
|
||||
AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//__wbg_init(input, maybe_memory) \{/__wbg_init(input, maybe_memory) \{$'\n'if (!input) \{input=\'data:application/wasm;base64,$WASM_BASE64\'\}}
|
||||
AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//return __wbg_finalize_init\(instance\, module\);/__wbg_finalize_init\(instance\, module\); return epoxy}
|
||||
echo "$AUTOGENERATED_SOURCE" > index.js
|
||||
cp index.js module.js
|
||||
echo "module.exports = epoxy" >> module.js
|
||||
echo "$AUTOGENERATED_SOURCE" > epoxy-bundled.js
|
||||
cp epoxy-bundled.js epoxy-module-bundled.js
|
||||
echo "module.exports = epoxy" >> epoxy-module-bundled.js
|
||||
|
||||
AUTOGENERATED_TYPEDEFS=$(<"out/epoxy_client.d.ts")
|
||||
AUTOGENERATED_TYPEDEFS=${AUTOGENERATED_TYPEDEFS%%export class IntoUnderlyingByteSource*}
|
||||
echo "$AUTOGENERATED_TYPEDEFS" >"module.d.ts"
|
||||
echo "} export default function epoxy(): Promise<typeof wasm_bindgen>;" >> "module.d.ts"
|
||||
echo "$AUTOGENERATED_TYPEDEFS" >"epoxy-module-bundled.d.ts"
|
||||
echo "} export default function epoxy(): Promise<typeof wasm_bindgen>;" >> "epoxy-module-bundled.d.ts"
|
||||
|
||||
cp out/epoxy_client.js epoxy.js
|
||||
cp out/epoxy_client.d.ts epoxy.d.ts
|
||||
cp out/epoxy_client_bg.wasm epoxy.wasm
|
||||
|
||||
rm -rf out/
|
||||
echo "[ws] done!"
|
||||
|
|
137
client/demo.js
137
client/demo.js
|
@ -1,21 +1,38 @@
|
|||
(async () => {
|
||||
importScripts("epoxy-bundled.js");
|
||||
onmessage = async (msg) => {
|
||||
console.debug("recieved:", msg);
|
||||
let [should_feature_test, should_multiparallel_test, should_parallel_test, should_multiperf_test, should_perf_test, should_ws_test, should_tls_test] = msg.data;
|
||||
console.log(
|
||||
"%cWASM is significantly slower with DevTools open!",
|
||||
"color:red;font-size:3rem;font-weight:bold"
|
||||
);
|
||||
|
||||
const should_feature_test = (new URL(window.location.href)).searchParams.has("feature_test");
|
||||
const should_perf_test = (new URL(window.location.href)).searchParams.has("perf_test");
|
||||
const should_ws_test = (new URL(window.location.href)).searchParams.has("ws_test");
|
||||
const log = (str) => {
|
||||
console.warn(str);
|
||||
postMessage(str);
|
||||
}
|
||||
|
||||
let { EpoxyClient } = await epoxy();
|
||||
const { EpoxyClient } = await epoxy();
|
||||
|
||||
const tconn0 = performance.now();
|
||||
// args: websocket url, user agent, redirect limit
|
||||
let epoxy_client = await new EpoxyClient("wss://localhost:4000", navigator.userAgent, 10);
|
||||
const tconn1 = performance.now();
|
||||
console.warn(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`);
|
||||
log(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`);
|
||||
|
||||
const test_mux = async (url) => {
|
||||
const t0 = performance.now();
|
||||
await epoxy_client.fetch(url);
|
||||
const t1 = performance.now();
|
||||
return t1 - t0;
|
||||
};
|
||||
|
||||
const test_native = async (url) => {
|
||||
const t0 = performance.now();
|
||||
await fetch(url, { cache: "no-store" });
|
||||
const t1 = performance.now();
|
||||
return t1 - t0;
|
||||
};
|
||||
|
||||
if (should_feature_test) {
|
||||
for (const url of [
|
||||
|
@ -23,44 +40,102 @@
|
|||
["https://httpbin.org/gzip", {}],
|
||||
["https://httpbin.org/brotli", {}],
|
||||
["https://httpbin.org/redirect/11", {}],
|
||||
["https://httpbin.org/redirect/1", { redirect: "manual" }]
|
||||
["https://httpbin.org/redirect/1", { redirect: "manual" }],
|
||||
]) {
|
||||
let resp = await epoxy_client.fetch(url[0], url[1]);
|
||||
console.warn(url, resp, Object.fromEntries(resp.headers));
|
||||
console.warn(await resp.text());
|
||||
}
|
||||
} else if (should_perf_test) {
|
||||
const test_mux = async (url) => {
|
||||
const t0 = performance.now();
|
||||
await epoxy_client.fetch(url);
|
||||
const t1 = performance.now();
|
||||
return t1 - t0;
|
||||
};
|
||||
} else if (should_multiparallel_test) {
|
||||
const num_tests = 10;
|
||||
let total_mux_minus_native = 0;
|
||||
for (const _ of Array(num_tests).keys()) {
|
||||
let total_mux = 0;
|
||||
await Promise.all([...Array(num_tests).keys()].map(async i => {
|
||||
log(`running mux test ${i}`);
|
||||
return await test_mux("https://httpbin.org/get");
|
||||
})).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) });
|
||||
total_mux = total_mux / num_tests;
|
||||
|
||||
const test_native = async (url) => {
|
||||
const t0 = performance.now();
|
||||
await fetch(url);
|
||||
const t1 = performance.now();
|
||||
return t1 - t0;
|
||||
};
|
||||
let total_native = 0;
|
||||
await Promise.all([...Array(num_tests).keys()].map(async i => {
|
||||
log(`running native test ${i}`);
|
||||
return await test_native("https://httpbin.org/get");
|
||||
})).then((vals) => { total_native = vals.reduce((acc, x) => acc + x, 0) });
|
||||
total_native = total_native / num_tests;
|
||||
|
||||
log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`);
|
||||
log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`);
|
||||
log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
|
||||
total_mux_minus_native += total_mux - total_native;
|
||||
}
|
||||
total_mux_minus_native = total_mux_minus_native / num_tests;
|
||||
log(`total mux - native (${num_tests} tests of ${num_tests} reqs): ${total_mux_minus_native} ms or ${total_mux_minus_native / 1000} s`);
|
||||
} else if (should_parallel_test) {
|
||||
const num_tests = 10;
|
||||
|
||||
let total_mux = 0;
|
||||
await Promise.all([...Array(num_tests).keys()].map(async i => {
|
||||
log(`running mux test ${i}`);
|
||||
return await test_mux("https://httpbin.org/get");
|
||||
})).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) });
|
||||
total_mux = total_mux / num_tests;
|
||||
|
||||
let total_native = 0;
|
||||
await Promise.all([...Array(num_tests).keys()].map(async i => {
|
||||
log(`running native test ${i}`);
|
||||
return await test_native("https://httpbin.org/get");
|
||||
})).then((vals) => { total_native = vals.reduce((acc, x) => acc + x, 0) });
|
||||
total_native = total_native / num_tests;
|
||||
|
||||
log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`);
|
||||
log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`);
|
||||
log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
|
||||
} else if (should_multiperf_test) {
|
||||
const num_tests = 10;
|
||||
let total_mux_minus_native = 0;
|
||||
for (const _ of Array(num_tests).keys()) {
|
||||
let total_mux = 0;
|
||||
for (const i of Array(num_tests).keys()) {
|
||||
log(`running mux test ${i}`);
|
||||
total_mux += await test_mux("https://httpbin.org/get");
|
||||
}
|
||||
total_mux = total_mux / num_tests;
|
||||
|
||||
let total_native = 0;
|
||||
for (const i of Array(num_tests).keys()) {
|
||||
log(`running native test ${i}`);
|
||||
total_native += await test_native("https://httpbin.org/get");
|
||||
}
|
||||
total_native = total_native / num_tests;
|
||||
|
||||
log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`);
|
||||
log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`);
|
||||
log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
|
||||
total_mux_minus_native += total_mux - total_native;
|
||||
}
|
||||
total_mux_minus_native = total_mux_minus_native / num_tests;
|
||||
log(`total mux - native (${num_tests} tests of ${num_tests} reqs): ${total_mux_minus_native} ms or ${total_mux_minus_native / 1000} s`);
|
||||
} else if (should_perf_test) {
|
||||
const num_tests = 10;
|
||||
|
||||
let total_mux = 0;
|
||||
for (const i of Array(num_tests).keys()) {
|
||||
log(`running mux test ${i}`);
|
||||
total_mux += await test_mux("https://httpbin.org/get");
|
||||
}
|
||||
total_mux = total_mux / num_tests;
|
||||
|
||||
let total_native = 0;
|
||||
for (const _ of Array(num_tests).keys()) {
|
||||
for (const i of Array(num_tests).keys()) {
|
||||
log(`running native test ${i}`);
|
||||
total_native += await test_native("https://httpbin.org/get");
|
||||
}
|
||||
total_native = total_native / num_tests;
|
||||
|
||||
console.warn(`avg mux (10) took ${total_mux} ms or ${total_mux / 1000} s`);
|
||||
console.warn(`avg native (10) took ${total_native} ms or ${total_native / 1000} s`);
|
||||
console.warn(`mux - native: ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
|
||||
log(`avg mux (${num_tests}) took ${total_mux} ms or ${total_mux / 1000} s`);
|
||||
log(`avg native (${num_tests}) took ${total_native} ms or ${total_native / 1000} s`);
|
||||
log(`avg mux - avg native (${num_tests}): ${total_mux - total_native} ms or ${(total_mux - total_native) / 1000} s`);
|
||||
} else if (should_ws_test) {
|
||||
let ws = await epoxy_client.connect_ws(
|
||||
() => console.log("opened"),
|
||||
|
@ -75,10 +150,20 @@
|
|||
await ws.send("data");
|
||||
await (new Promise((res, _) => setTimeout(res, 100)));
|
||||
}
|
||||
} else if (should_tls_test) {
|
||||
let decoder = new TextDecoder();
|
||||
let ws = await epoxy_client.connect_tls(
|
||||
() => console.log("opened"),
|
||||
() => console.log("closed"),
|
||||
err => console.error(err),
|
||||
msg => { console.log(msg); console.log(decoder.decode(msg)) },
|
||||
"alicesworld.tech:443",
|
||||
);
|
||||
await ws.send("GET / HTTP 1.1\r\nHost: alicesworld.tech\r\nConnection: close\r\n\r\n");
|
||||
} else {
|
||||
let resp = await epoxy_client.fetch("https://httpbin.org/get");
|
||||
console.warn(resp, Object.fromEntries(resp.headers));
|
||||
console.warn(await resp.text());
|
||||
}
|
||||
if (!should_ws_test) alert("you can open console now");
|
||||
})();
|
||||
log("done");
|
||||
};
|
||||
|
|
|
@ -2,12 +2,35 @@
|
|||
|
||||
<head>
|
||||
<title>epoxy</title>
|
||||
<script src="index.js"></script>
|
||||
<script src="demo.js"></script>
|
||||
<style>
|
||||
body { font-family: sans-serif }
|
||||
#logs > * { font-family: monospace }
|
||||
</style>
|
||||
<script>
|
||||
const params = (new URL(window.location.href)).searchParams;
|
||||
const should_feature_test = params.has("feature_test");
|
||||
const should_multiparallel_test = params.has("multi_parallel_test");
|
||||
const should_parallel_test = params.has("parallel_test");
|
||||
const should_multiperf_test = params.has("multi_perf_test");
|
||||
const should_perf_test = params.has("perf_test");
|
||||
const should_ws_test = params.has("ws_test");
|
||||
const should_tls_test = params.has("rawtls_test");
|
||||
const worker = new Worker("demo.js");
|
||||
worker.onmessage = (msg) => {
|
||||
let el = document.createElement("div");
|
||||
el.textContent = msg.data;
|
||||
document.getElementById("logs").appendChild(el);
|
||||
window.scrollTo(0, document.body.scrollHeight);
|
||||
};
|
||||
worker.postMessage([should_feature_test, should_multiparallel_test, should_parallel_test, should_multiperf_test, should_perf_test, should_ws_test, should_tls_test]);
|
||||
</script>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
running... (wait for the browser alert if not running ws test)
|
||||
<div>
|
||||
running... (wait for the browser alert if not running ws test)
|
||||
</div>
|
||||
<div id="logs"></div>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
|
|
10
client/serve.py
Normal file
10
client/serve.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
from http.server import HTTPServer, SimpleHTTPRequestHandler, test
|
||||
import sys
|
||||
|
||||
class RequestHandler (SimpleHTTPRequestHandler):
|
||||
def end_headers (self):
|
||||
self.send_header('Cross-Origin-Opener-Policy', 'same-origin')
|
||||
self.send_header('Cross-Origin-Embedder-Policy', 'require-corp')
|
||||
SimpleHTTPRequestHandler.end_headers(self)
|
||||
|
||||
test(RequestHandler, HTTPServer, port=int(sys.argv[1]) if len(sys.argv) > 1 else 8000)
|
|
@ -1,24 +1,25 @@
|
|||
#![feature(let_chains)]
|
||||
#![feature(let_chains, impl_trait_in_assoc_type)]
|
||||
#[macro_use]
|
||||
mod utils;
|
||||
mod tokioio;
|
||||
mod tls_stream;
|
||||
mod websocket;
|
||||
mod wrappers;
|
||||
|
||||
use tokioio::TokioIo;
|
||||
use tls_stream::EpxTlsStream;
|
||||
use utils::{ReplaceErr, UriExt};
|
||||
use websocket::EpxWebSocket;
|
||||
use wrappers::{IncomingBody, WsStreamWrapper};
|
||||
use wrappers::{IncomingBody, TlsWispService};
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_compression::tokio::bufread as async_comp;
|
||||
use async_io_stream::IoStream;
|
||||
use bytes::Bytes;
|
||||
use futures_util::StreamExt;
|
||||
use futures_util::{stream::SplitSink, StreamExt};
|
||||
use http::{uri, HeaderName, HeaderValue, Request, Response};
|
||||
use hyper::{body::Incoming, client::conn::http1::Builder, Uri};
|
||||
use hyper::{body::Incoming, Uri};
|
||||
use hyper_util::client::legacy::Client;
|
||||
use js_sys::{Array, Function, Object, Reflect, Uint8Array};
|
||||
use penguin_mux_wasm::{Multiplexor, MuxStream};
|
||||
use tokio_rustls::{client::TlsStream, rustls, rustls::RootCertStore, TlsConnector};
|
||||
use tokio_util::{
|
||||
either::Either,
|
||||
|
@ -26,13 +27,15 @@ use tokio_util::{
|
|||
};
|
||||
use wasm_bindgen::prelude::*;
|
||||
use web_sys::TextEncoder;
|
||||
use wisp_mux::{tokioio::TokioIo, tower::ServiceWrapper, ClientMux, MuxStreamIo, StreamType};
|
||||
use ws_stream_wasm::{WsMessage, WsMeta, WsStream};
|
||||
|
||||
type HttpBody = http_body_util::Full<Bytes>;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum EpxResponse {
|
||||
Success(Response<Incoming>),
|
||||
Redirect((Response<Incoming>, http::Request<HttpBody>, Uri)),
|
||||
Redirect((Response<Incoming>, http::Request<HttpBody>)),
|
||||
}
|
||||
|
||||
enum EpxCompression {
|
||||
|
@ -40,80 +43,20 @@ enum EpxCompression {
|
|||
Gzip,
|
||||
}
|
||||
|
||||
type EpxTlsStream = TlsStream<MuxStream<WsStreamWrapper>>;
|
||||
type EpxUnencryptedStream = MuxStream<WsStreamWrapper>;
|
||||
type EpxStream = Either<EpxTlsStream, EpxUnencryptedStream>;
|
||||
|
||||
async fn send_req(
|
||||
req: http::Request<HttpBody>,
|
||||
should_redirect: bool,
|
||||
io: EpxStream,
|
||||
) -> Result<EpxResponse, JsError> {
|
||||
let (mut req_sender, conn) = Builder::new()
|
||||
.title_case_headers(true)
|
||||
.preserve_header_case(true)
|
||||
.handshake(TokioIo::new(io))
|
||||
.await
|
||||
.replace_err("Failed to connect to host")?;
|
||||
|
||||
wasm_bindgen_futures::spawn_local(async move {
|
||||
if let Err(e) = conn.await {
|
||||
error!("epoxy: error in muxed hyper connection! {:?}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let new_req = if should_redirect {
|
||||
Some(req.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let res = req_sender
|
||||
.send_request(req)
|
||||
.await
|
||||
.replace_err("Failed to send request");
|
||||
match res {
|
||||
Ok(res) => {
|
||||
if utils::is_redirect(res.status().as_u16())
|
||||
&& let Some(mut new_req) = new_req
|
||||
&& let Some(location) = res.headers().get("Location")
|
||||
&& let Ok(redirect_url) = new_req.uri().get_redirect(location)
|
||||
&& let Some(redirect_url_authority) = redirect_url
|
||||
.clone()
|
||||
.authority()
|
||||
.replace_err("Redirect URL must have an authority")
|
||||
.ok()
|
||||
{
|
||||
let should_strip = new_req.uri().is_same_host(&redirect_url);
|
||||
if should_strip {
|
||||
new_req.headers_mut().remove("authorization");
|
||||
new_req.headers_mut().remove("cookie");
|
||||
new_req.headers_mut().remove("www-authenticate");
|
||||
}
|
||||
let new_url = redirect_url.clone();
|
||||
*new_req.uri_mut() = redirect_url;
|
||||
new_req.headers_mut().insert(
|
||||
"Host",
|
||||
HeaderValue::from_str(redirect_url_authority.as_str())?,
|
||||
);
|
||||
Ok(EpxResponse::Redirect((res, new_req, new_url)))
|
||||
} else {
|
||||
Ok(EpxResponse::Success(res))
|
||||
}
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
type EpxIoTlsStream = TlsStream<IoStream<MuxStreamIo, Vec<u8>>>;
|
||||
type EpxIoUnencryptedStream = IoStream<MuxStreamIo, Vec<u8>>;
|
||||
type EpxIoStream = Either<EpxIoTlsStream, EpxIoUnencryptedStream>;
|
||||
|
||||
#[wasm_bindgen(start)]
|
||||
async fn start() {
|
||||
utils::set_panic_hook();
|
||||
fn init() {
|
||||
console_error_panic_hook::set_once();
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub struct EpoxyClient {
|
||||
rustls_config: Arc<rustls::ClientConfig>,
|
||||
mux: Multiplexor<WsStreamWrapper>,
|
||||
mux: Arc<ClientMux<SplitSink<WsStream, WsMessage>>>,
|
||||
hyper_client: Client<TlsWispService<SplitSink<WsStream, WsMessage>>, HttpBody>,
|
||||
useragent: String,
|
||||
redirect_limit: usize,
|
||||
}
|
||||
|
@ -138,11 +81,19 @@ impl EpoxyClient {
|
|||
}
|
||||
|
||||
debug!("connecting to ws {:?}", ws_url);
|
||||
let ws = WsStreamWrapper::connect(ws_url, None)
|
||||
let (_, ws) = WsMeta::connect(ws_url, vec!["wisp-v1"])
|
||||
.await
|
||||
.replace_err("Failed to connect to websocket")?;
|
||||
debug!("connected!");
|
||||
let mux = Multiplexor::new(ws, penguin_mux_wasm::Role::Client, None, None);
|
||||
let (wtx, wrx) = ws.split();
|
||||
let (mux, fut) = ClientMux::new(wrx, wtx).await?;
|
||||
let mux = Arc::new(mux);
|
||||
|
||||
wasm_bindgen_futures::spawn_local(async move {
|
||||
if let Err(err) = fut.await {
|
||||
error!("epoxy: error in mux future! {:?}", err);
|
||||
}
|
||||
});
|
||||
|
||||
let mut certstore = RootCertStore::empty();
|
||||
certstore.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
|
@ -154,37 +105,86 @@ impl EpoxyClient {
|
|||
);
|
||||
|
||||
Ok(EpoxyClient {
|
||||
mux,
|
||||
mux: mux.clone(),
|
||||
hyper_client: Client::builder(utils::WasmExecutor {})
|
||||
.http09_responses(true)
|
||||
.http1_title_case_headers(true)
|
||||
.http1_preserve_header_case(true)
|
||||
.build(TlsWispService {
|
||||
rustls_config: rustls_config.clone(),
|
||||
service: ServiceWrapper(mux),
|
||||
}),
|
||||
rustls_config,
|
||||
useragent,
|
||||
redirect_limit,
|
||||
})
|
||||
}
|
||||
|
||||
async fn get_http_io(&self, url: &Uri) -> Result<EpxStream, JsError> {
|
||||
let url_host = url.host().replace_err("URL must have a host")?;
|
||||
let url_port = utils::get_url_port(url)?;
|
||||
async fn get_tls_io(&self, url_host: &str, url_port: u16) -> Result<EpxIoTlsStream, JsError> {
|
||||
let channel = self
|
||||
.mux
|
||||
.client_new_stream_channel(url_host.as_bytes(), url_port)
|
||||
.client_new_stream(StreamType::Tcp, url_host.to_string(), url_port)
|
||||
.await
|
||||
.replace_err("Failed to create multiplexor channel")?;
|
||||
.replace_err("Failed to create multiplexor channel")?
|
||||
.into_io()
|
||||
.into_asyncrw();
|
||||
let cloned_uri = url_host.to_string().clone();
|
||||
let connector = TlsConnector::from(self.rustls_config.clone());
|
||||
debug!("connecting channel");
|
||||
let io = connector
|
||||
.connect(
|
||||
cloned_uri
|
||||
.try_into()
|
||||
.replace_err("Failed to parse URL (rustls)")?,
|
||||
channel,
|
||||
)
|
||||
.await
|
||||
.replace_err("Failed to perform TLS handshake")?;
|
||||
debug!("connected channel");
|
||||
Ok(io)
|
||||
}
|
||||
|
||||
if utils::get_is_secure(url)? {
|
||||
let cloned_uri = url_host.to_string().clone();
|
||||
let connector = TlsConnector::from(self.rustls_config.clone());
|
||||
let io = connector
|
||||
.connect(
|
||||
cloned_uri
|
||||
.try_into()
|
||||
.replace_err("Failed to parse URL (rustls)")?,
|
||||
channel,
|
||||
)
|
||||
.await
|
||||
.replace_err("Failed to perform TLS handshake")?;
|
||||
Ok(EpxStream::Left(io))
|
||||
async fn send_req_inner(
|
||||
&self,
|
||||
req: http::Request<HttpBody>,
|
||||
should_redirect: bool,
|
||||
) -> Result<EpxResponse, JsError> {
|
||||
let new_req = if should_redirect {
|
||||
Some(req.clone())
|
||||
} else {
|
||||
Ok(EpxStream::Right(channel))
|
||||
None
|
||||
};
|
||||
|
||||
debug!("sending req");
|
||||
let res = self
|
||||
.hyper_client
|
||||
.request(req)
|
||||
.await
|
||||
.replace_err("Failed to send request");
|
||||
debug!("recieved res");
|
||||
match res {
|
||||
Ok(res) => {
|
||||
if utils::is_redirect(res.status().as_u16())
|
||||
&& let Some(mut new_req) = new_req
|
||||
&& let Some(location) = res.headers().get("Location")
|
||||
&& let Ok(redirect_url) = new_req.uri().get_redirect(location)
|
||||
&& let Some(redirect_url_authority) = redirect_url
|
||||
.clone()
|
||||
.authority()
|
||||
.replace_err("Redirect URL must have an authority")
|
||||
.ok()
|
||||
{
|
||||
*new_req.uri_mut() = redirect_url;
|
||||
new_req.headers_mut().insert(
|
||||
"Host",
|
||||
HeaderValue::from_str(redirect_url_authority.as_str())?,
|
||||
);
|
||||
Ok(EpxResponse::Redirect((res, new_req)))
|
||||
} else {
|
||||
Ok(EpxResponse::Success(res))
|
||||
}
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -194,23 +194,22 @@ impl EpoxyClient {
|
|||
should_redirect: bool,
|
||||
) -> Result<(hyper::Response<Incoming>, Uri, bool), JsError> {
|
||||
let mut redirected = false;
|
||||
let uri = req.uri().clone();
|
||||
let mut current_resp: EpxResponse =
|
||||
send_req(req, should_redirect, self.get_http_io(&uri).await?).await?;
|
||||
let mut current_url = req.uri().clone();
|
||||
let mut current_resp: EpxResponse = self.send_req_inner(req, should_redirect).await?;
|
||||
for _ in 0..self.redirect_limit - 1 {
|
||||
match current_resp {
|
||||
EpxResponse::Success(_) => break,
|
||||
EpxResponse::Redirect((_, req, new_url)) => {
|
||||
EpxResponse::Redirect((_, req)) => {
|
||||
redirected = true;
|
||||
current_resp =
|
||||
send_req(req, should_redirect, self.get_http_io(&new_url).await?).await?
|
||||
current_url = req.uri().clone();
|
||||
current_resp = self.send_req_inner(req, should_redirect).await?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match current_resp {
|
||||
EpxResponse::Success(resp) => Ok((resp, uri, redirected)),
|
||||
EpxResponse::Redirect((resp, _, new_url)) => Ok((resp, new_url, redirected)),
|
||||
EpxResponse::Success(resp) => Ok((resp, current_url, redirected)),
|
||||
EpxResponse::Redirect((resp, _)) => Ok((resp, current_url, redirected)),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -232,6 +231,17 @@ impl EpoxyClient {
|
|||
.await
|
||||
}
|
||||
|
||||
pub async fn connect_tls(
|
||||
&self,
|
||||
onopen: Function,
|
||||
onclose: Function,
|
||||
onerror: Function,
|
||||
onmessage: Function,
|
||||
url: String,
|
||||
) -> Result<EpxTlsStream, JsError> {
|
||||
EpxTlsStream::connect(self, onopen, onclose, onerror, onmessage, url).await
|
||||
}
|
||||
|
||||
pub async fn fetch(&self, url: String, options: Object) -> Result<web_sys::Response, JsError> {
|
||||
let uri = url.parse::<uri::Uri>().replace_err("Failed to parse URL")?;
|
||||
let uri_scheme = uri.scheme().replace_err("URL must have a scheme")?;
|
||||
|
@ -292,7 +302,7 @@ impl EpoxyClient {
|
|||
|
||||
let headers_map = builder.headers_mut().replace_err("Failed to get headers")?;
|
||||
headers_map.insert("Accept-Encoding", HeaderValue::from_str("gzip, br")?);
|
||||
headers_map.insert("Connection", HeaderValue::from_str("close")?);
|
||||
headers_map.insert("Connection", HeaderValue::from_str("keep-alive")?);
|
||||
headers_map.insert("User-Agent", HeaderValue::from_str(&self.useragent)?);
|
||||
headers_map.insert("Host", HeaderValue::from_str(uri_host)?);
|
||||
if body_bytes.is_empty() {
|
||||
|
@ -314,7 +324,7 @@ impl EpoxyClient {
|
|||
.body(HttpBody::new(body_bytes))
|
||||
.replace_err("Failed to make request")?;
|
||||
|
||||
let (resp, last_url, req_redirected) = self.send_req(request, req_should_redirect).await?;
|
||||
let (resp, resp_uri, req_redirected) = self.send_req(request, req_should_redirect).await?;
|
||||
|
||||
let resp_headers_raw = resp.headers().clone();
|
||||
|
||||
|
@ -378,7 +388,7 @@ impl EpoxyClient {
|
|||
Object::define_property(
|
||||
&resp,
|
||||
&jval!("url"),
|
||||
&utils::define_property_obj(jval!(last_url.to_string()), false)
|
||||
&utils::define_property_obj(jval!(resp_uri.to_string()), false)
|
||||
.replace_err("Failed to make define_property object for url")?,
|
||||
);
|
||||
|
||||
|
|
82
client/src/tls_stream.rs
Normal file
82
client/src/tls_stream.rs
Normal file
|
@ -0,0 +1,82 @@
|
|||
use crate::*;
|
||||
|
||||
use js_sys::Function;
|
||||
use tokio::io::{split, AsyncWriteExt, WriteHalf};
|
||||
use tokio_util::io::ReaderStream;
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub struct EpxTlsStream {
|
||||
tx: WriteHalf<EpxIoTlsStream>,
|
||||
onerror: Function,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl EpxTlsStream {
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> Result<EpxTlsStream, JsError> {
|
||||
Err(jerr!("Use EpoxyClient.connect_tls() instead."))
|
||||
}
|
||||
|
||||
// shut up
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn connect(
|
||||
tcp: &EpoxyClient,
|
||||
onopen: Function,
|
||||
onclose: Function,
|
||||
onerror: Function,
|
||||
onmessage: Function,
|
||||
url: String,
|
||||
) -> Result<EpxTlsStream, JsError> {
|
||||
let onerr = onerror.clone();
|
||||
let ret: Result<EpxTlsStream, JsError> = async move {
|
||||
let url = Uri::try_from(url).replace_err("Failed to parse URL")?;
|
||||
let url_host = url.host().replace_err("URL must have a host")?;
|
||||
let url_port = url.port().replace_err("URL must have a port")?.into();
|
||||
|
||||
let io = tcp.get_tls_io(url_host, url_port).await?;
|
||||
let (rx, tx) = split(io);
|
||||
let mut rx = ReaderStream::new(rx);
|
||||
|
||||
wasm_bindgen_futures::spawn_local(async move {
|
||||
while let Some(Ok(data)) = rx.next().await {
|
||||
let _ = onmessage.call1(
|
||||
&JsValue::null(),
|
||||
&jval!(Uint8Array::from(data.to_vec().as_slice())),
|
||||
);
|
||||
}
|
||||
let _ = onclose.call0(&JsValue::null());
|
||||
});
|
||||
|
||||
onopen
|
||||
.call0(&Object::default())
|
||||
.replace_err("Failed to call onopen")?;
|
||||
|
||||
Ok(Self { tx, onerror })
|
||||
}
|
||||
.await;
|
||||
if let Err(ret) = ret {
|
||||
let _ = onerr.call1(&JsValue::null(), &jval!(ret.clone()));
|
||||
Err(ret)
|
||||
} else {
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub async fn send(&mut self, payload: Uint8Array) -> Result<(), JsError> {
|
||||
let onerr = self.onerror.clone();
|
||||
let ret = self.tx.write_all(&payload.to_vec()).await;
|
||||
if let Err(ret) = ret {
|
||||
let _ = onerr.call1(&JsValue::null(), &jval!(format!("{}", ret)));
|
||||
Err(ret.into())
|
||||
} else {
|
||||
Ok(ret?)
|
||||
}
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub async fn close(&mut self) -> Result<(), JsError> {
|
||||
self.tx.shutdown().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,13 +1,9 @@
|
|||
use wasm_bindgen::prelude::*;
|
||||
|
||||
use hyper::rt::Executor;
|
||||
use hyper::{header::HeaderValue, Uri};
|
||||
use http::uri;
|
||||
use js_sys::{Array, Object};
|
||||
|
||||
pub fn set_panic_hook() {
|
||||
#[cfg(feature = "console_error_panic_hook")]
|
||||
console_error_panic_hook::set_once();
|
||||
}
|
||||
use std::future::Future;
|
||||
|
||||
#[wasm_bindgen]
|
||||
extern "C" {
|
||||
|
@ -55,15 +51,15 @@ pub trait ReplaceErr {
|
|||
fn replace_err_jv(self, err: &str) -> Result<Self::Ok, JsValue>;
|
||||
}
|
||||
|
||||
impl<T, E> ReplaceErr for Result<T, E> {
|
||||
impl<T, E: std::fmt::Debug> ReplaceErr for Result<T, E> {
|
||||
type Ok = T;
|
||||
|
||||
fn replace_err(self, err: &str) -> Result<<Self as ReplaceErr>::Ok, JsError> {
|
||||
self.map_err(|_| jerr!(err))
|
||||
self.map_err(|oe| jerr!(&format!("{}, original error: {:?}", err, oe)))
|
||||
}
|
||||
|
||||
fn replace_err_jv(self, err: &str) -> Result<<Self as ReplaceErr>::Ok, JsValue> {
|
||||
self.map_err(|_| jval!(err))
|
||||
self.map_err(|oe| jval!(&format!("{}, original error: {:?}", err, oe)))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -102,6 +98,21 @@ impl UriExt for Uri {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WasmExecutor;
|
||||
|
||||
impl<F> Executor<F> for WasmExecutor
|
||||
where
|
||||
F: Future + Send + 'static,
|
||||
F::Output: Send + 'static,
|
||||
{
|
||||
fn execute(&self, future: F) {
|
||||
wasm_bindgen_futures::spawn_local(async move {
|
||||
let _ = future.await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn entries_of_object(obj: &Object) -> Vec<Vec<String>> {
|
||||
js_sys::Object::entries(obj)
|
||||
.to_vec()
|
||||
|
@ -131,41 +142,19 @@ pub fn is_redirect(code: u16) -> bool {
|
|||
}
|
||||
|
||||
pub fn get_is_secure(url: &Uri) -> Result<bool, JsError> {
|
||||
let url_scheme = url.scheme().replace_err("URL must have a scheme")?;
|
||||
let url_scheme_str = url.scheme_str().replace_err("URL must have a scheme")?;
|
||||
// can't use match, compiler error
|
||||
// error: to use a constant of type `Scheme` in a pattern, `Scheme` must be annotated with `#[derive(PartialEq, Eq)]`
|
||||
if *url_scheme == uri::Scheme::HTTP {
|
||||
Ok(false)
|
||||
} else if *url_scheme == uri::Scheme::HTTPS {
|
||||
Ok(true)
|
||||
} else if url_scheme_str == "ws" {
|
||||
Ok(false)
|
||||
} else if url_scheme_str == "wss" {
|
||||
Ok(true)
|
||||
} else {
|
||||
return Ok(false);
|
||||
match url_scheme_str {
|
||||
"https" | "wss" => Ok(true),
|
||||
_ => Ok(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_url_port(url: &Uri) -> Result<u16, JsError> {
|
||||
let url_scheme = url.scheme().replace_err("URL must have a scheme")?;
|
||||
let url_scheme_str = url.scheme_str().replace_err("URL must have a scheme")?;
|
||||
if let Some(port) = url.port() {
|
||||
Ok(port.as_u16())
|
||||
} else if get_is_secure(url)? {
|
||||
Ok(443)
|
||||
} else {
|
||||
// can't use match, compiler error
|
||||
// error: to use a constant of type `Scheme` in a pattern, `Scheme` must be annotated with `#[derive(PartialEq, Eq)]`
|
||||
if *url_scheme == uri::Scheme::HTTP {
|
||||
Ok(80)
|
||||
} else if *url_scheme == uri::Scheme::HTTPS {
|
||||
Ok(443)
|
||||
} else if url_scheme_str == "ws" {
|
||||
Ok(80)
|
||||
} else if url_scheme_str == "wss" {
|
||||
Ok(443)
|
||||
} else {
|
||||
return Err(jerr!("Failed to coerce port from scheme"));
|
||||
}
|
||||
Ok(80)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,8 @@ use base64::{engine::general_purpose::STANDARD, Engine};
|
|||
use fastwebsockets::{
|
||||
CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, Role, WebSocket, WebSocketWrite,
|
||||
};
|
||||
use http_body_util::Empty;
|
||||
use futures_util::lock::Mutex;
|
||||
use http_body_util::Full;
|
||||
use hyper::{
|
||||
header::{CONNECTION, UPGRADE},
|
||||
upgrade::Upgraded,
|
||||
|
@ -16,7 +17,7 @@ use tokio::io::WriteHalf;
|
|||
|
||||
#[wasm_bindgen]
|
||||
pub struct EpxWebSocket {
|
||||
tx: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
|
||||
tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>,
|
||||
onerror: Function,
|
||||
}
|
||||
|
||||
|
@ -44,7 +45,8 @@ impl EpxWebSocket {
|
|||
let url = Uri::try_from(url).replace_err("Failed to parse URL")?;
|
||||
let host = url.host().replace_err("URL must have a host")?;
|
||||
|
||||
let rand: [u8; 16] = rand::random();
|
||||
let mut rand: [u8; 16] = [0; 16];
|
||||
getrandom::getrandom(&mut rand)?;
|
||||
let key = STANDARD.encode(rand);
|
||||
|
||||
let mut builder = Request::builder()
|
||||
|
@ -61,23 +63,9 @@ impl EpxWebSocket {
|
|||
builder = builder.header("Sec-WebSocket-Protocol", protocols.join(", "));
|
||||
}
|
||||
|
||||
let req = builder.body(Empty::<Bytes>::new())?;
|
||||
let req = builder.body(Full::<Bytes>::new(Bytes::new()))?;
|
||||
|
||||
let stream = tcp.get_http_io(&url).await?;
|
||||
|
||||
let (mut sender, conn) = Builder::new()
|
||||
.title_case_headers(true)
|
||||
.preserve_header_case(true)
|
||||
.handshake::<TokioIo<EpxStream>, Empty<Bytes>>(TokioIo::new(stream))
|
||||
.await?;
|
||||
|
||||
wasm_bindgen_futures::spawn_local(async move {
|
||||
if let Err(e) = conn.with_upgrades().await {
|
||||
error!("epoxy: error in muxed hyper connection (ws)! {:?}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let mut response = sender.send_request(req).await?;
|
||||
let mut response = tcp.hyper_client.request(req).await?;
|
||||
verify(&response)?;
|
||||
|
||||
let ws = WebSocket::after_handshake(
|
||||
|
@ -88,16 +76,12 @@ impl EpxWebSocket {
|
|||
let (rx, tx) = ws.split(tokio::io::split);
|
||||
|
||||
let mut rx = FragmentCollectorRead::new(rx);
|
||||
let tx = Arc::new(Mutex::new(tx));
|
||||
let tx_cloned = tx.clone();
|
||||
|
||||
wasm_bindgen_futures::spawn_local(async move {
|
||||
while let Ok(frame) = rx
|
||||
.read_frame(&mut |arg| async move {
|
||||
error!(
|
||||
"wtf is an obligated write {:?}, {:?}, {:?}",
|
||||
arg.fin, arg.opcode, arg.payload
|
||||
);
|
||||
Ok::<(), std::io::Error>(())
|
||||
})
|
||||
.read_frame(&mut |arg| async { tx_cloned.lock().await.write_frame(arg).await })
|
||||
.await
|
||||
{
|
||||
match frame.opcode {
|
||||
|
@ -137,10 +121,12 @@ impl EpxWebSocket {
|
|||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub async fn send(&mut self, payload: String) -> Result<(), JsError> {
|
||||
pub async fn send(&self, payload: String) -> Result<(), JsError> {
|
||||
let onerr = self.onerror.clone();
|
||||
let ret = self
|
||||
.tx
|
||||
.lock()
|
||||
.await
|
||||
.write_frame(Frame::text(Payload::Owned(payload.as_bytes().to_vec())))
|
||||
.await;
|
||||
if let Err(ret) = ret {
|
||||
|
@ -152,8 +138,10 @@ impl EpxWebSocket {
|
|||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub async fn close(&mut self) -> Result<(), JsError> {
|
||||
pub async fn close(&self) -> Result<(), JsError> {
|
||||
self.tx
|
||||
.lock()
|
||||
.await
|
||||
.write_frame(Frame::close(CloseCode::Normal.into(), b""))
|
||||
.await?;
|
||||
Ok(())
|
||||
|
|
|
@ -4,117 +4,11 @@ use std::{
|
|||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures_util::{Sink, Stream};
|
||||
use futures_util::Stream;
|
||||
use hyper::body::Body;
|
||||
use penguin_mux_wasm::ws;
|
||||
use pin_project_lite::pin_project;
|
||||
use ws_stream_wasm::{WsErr, WsMessage, WsMeta, WsStream};
|
||||
|
||||
pin_project! {
|
||||
pub struct WsStreamWrapper {
|
||||
#[pin]
|
||||
ws: WsStream,
|
||||
}
|
||||
}
|
||||
|
||||
impl WsStreamWrapper {
|
||||
pub async fn connect(
|
||||
url: impl AsRef<str>,
|
||||
protocols: impl Into<Option<Vec<&str>>>,
|
||||
) -> Result<Self, WsErr> {
|
||||
let (_, wsstream) = WsMeta::connect(url, protocols).await?;
|
||||
Ok(WsStreamWrapper { ws: wsstream })
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for WsStreamWrapper {
|
||||
type Item = Result<ws::Message, ws::Error>;
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.project();
|
||||
let ret = this.ws.poll_next(cx);
|
||||
match ret {
|
||||
Poll::Ready(item) => Poll::<Option<Self::Item>>::Ready(item.map(|x| {
|
||||
Ok(match x {
|
||||
WsMessage::Text(txt) => ws::Message::Text(txt),
|
||||
WsMessage::Binary(bin) => ws::Message::Binary(bin),
|
||||
})
|
||||
})),
|
||||
Poll::Pending => Poll::<Option<Self::Item>>::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn wserr_to_ws_err(err: WsErr) -> ws::Error {
|
||||
debug!("err: {:?}", err);
|
||||
match err {
|
||||
WsErr::ConnectionNotOpen => ws::Error::AlreadyClosed,
|
||||
_ => ws::Error::Io(std::io::Error::other(format!("{:?}", err))),
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<ws::Message> for WsStreamWrapper {
|
||||
type Error = ws::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
let this = self.project();
|
||||
let ret = this.ws.poll_ready(cx);
|
||||
match ret {
|
||||
Poll::Ready(item) => Poll::<Result<(), Self::Error>>::Ready(match item {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(wserr_to_ws_err(err)),
|
||||
}),
|
||||
Poll::Pending => Poll::<Result<(), Self::Error>>::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: ws::Message) -> Result<(), Self::Error> {
|
||||
use ws::Message::*;
|
||||
let item = match item {
|
||||
Text(txt) => WsMessage::Text(txt),
|
||||
Binary(bin) => WsMessage::Binary(bin),
|
||||
Close(_) => {
|
||||
debug!("closing");
|
||||
return match self.ws.wrapped().close() {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(ws::Error::Io(std::io::Error::other(format!(
|
||||
"ws close err: {:?}",
|
||||
err
|
||||
)))),
|
||||
};
|
||||
}
|
||||
Ping(_) | Pong(_) | Frame(_) => return Ok(()),
|
||||
};
|
||||
let this = self.project();
|
||||
let ret = this.ws.start_send(item);
|
||||
match ret {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(wserr_to_ws_err(err)),
|
||||
}
|
||||
}
|
||||
|
||||
// no point wrapping this as it's not going to do anything
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Ok(()).into()
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
let this = self.project();
|
||||
let ret = this.ws.poll_close(cx);
|
||||
match ret {
|
||||
Poll::Ready(item) => Poll::<Result<(), Self::Error>>::Ready(match item {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(wserr_to_ws_err(err)),
|
||||
}),
|
||||
Poll::Pending => Poll::<Result<(), Self::Error>>::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ws::WebSocketStream for WsStreamWrapper {
|
||||
fn ping_auto_pong(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
use std::future::Future;
|
||||
use wisp_mux::{tokioio::TokioIo, tower::ServiceWrapper, WispError};
|
||||
|
||||
pin_project! {
|
||||
pub struct IncomingBody {
|
||||
|
@ -138,7 +32,8 @@ impl Stream for IncomingBody {
|
|||
Poll::Ready(item) => Poll::<Option<Self::Item>>::Ready(match item {
|
||||
Some(frame) => frame
|
||||
.map(|x| {
|
||||
x.into_data().map_err(|_| std::io::Error::other("not data frame"))
|
||||
x.into_data()
|
||||
.map_err(|_| std::io::Error::other("not data frame"))
|
||||
})
|
||||
.ok(),
|
||||
None => None,
|
||||
|
@ -147,3 +42,68 @@ impl Stream for IncomingBody {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TlsWispService<W>
|
||||
where
|
||||
W: wisp_mux::ws::WebSocketWrite + Send + 'static,
|
||||
{
|
||||
pub service: ServiceWrapper<W>,
|
||||
pub rustls_config: Arc<rustls::ClientConfig>,
|
||||
}
|
||||
|
||||
|
||||
impl<W: wisp_mux::ws::WebSocketWrite + Send + 'static> tower_service::Service<hyper::Uri>
|
||||
for TlsWispService<W>
|
||||
{
|
||||
type Response = TokioIo<EpxIoStream>;
|
||||
type Error = WispError;
|
||||
type Future = Pin<Box<impl Future<Output = Result<Self::Response, Self::Error>>>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.service.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, req: http::Uri) -> Self::Future {
|
||||
let mut service = self.service.clone();
|
||||
let rustls_config = self.rustls_config.clone();
|
||||
Box::pin(async move {
|
||||
let uri_host = req
|
||||
.host()
|
||||
.ok_or(WispError::UriHasNoHost)?
|
||||
.to_string()
|
||||
.clone();
|
||||
let uri_parsed = Uri::builder()
|
||||
.authority(format!(
|
||||
"{}:{}",
|
||||
uri_host,
|
||||
utils::get_url_port(&req).map_err(|_| WispError::UriHasNoPort)?
|
||||
))
|
||||
.build()
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
let stream = service.call(uri_parsed).await?.into_inner();
|
||||
if utils::get_is_secure(&req).map_err(|_| WispError::InvalidUri)? {
|
||||
let connector = TlsConnector::from(rustls_config);
|
||||
Ok(TokioIo::new(Either::Left(
|
||||
connector
|
||||
.connect(
|
||||
uri_host.try_into().map_err(|_| WispError::InvalidUri)?,
|
||||
stream,
|
||||
)
|
||||
.await
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?,
|
||||
)))
|
||||
} else {
|
||||
Ok(TokioIo::new(Either::Right(stream)))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: wisp_mux::ws::WebSocketWrite + Send + 'static> Clone for TlsWispService<W> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
rustls_config: self.rustls_config.clone(),
|
||||
service: self.service.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
"author": "MercuryWorkshop",
|
||||
"repository": "https://github.com/MercuryWorkshop/epoxy-tls",
|
||||
"license": "MIT",
|
||||
"browser": "./client/module.js",
|
||||
"module": "./client/module.js",
|
||||
"main": "./client/module.js",
|
||||
"types": "./client/module.d.ts"
|
||||
"browser": "./client/epoxy-module-bundled.js",
|
||||
"module": "./client/epoxy-module-bundled.js",
|
||||
"main": "./client/epoxy-module-bundled.js",
|
||||
"types": "./client/epoxy-module-bundled.d.ts"
|
||||
}
|
||||
|
|
|
@ -4,10 +4,16 @@ version = "1.0.0"
|
|||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
bytes = "1.5.0"
|
||||
clap = { version = "4.4.18", features = ["derive", "help", "usage", "color", "wrap_help", "cargo"] }
|
||||
clio = { version = "0.3.5", features = ["clap-parse"] }
|
||||
dashmap = "5.5.3"
|
||||
fastwebsockets = { version = "0.6.0", features = ["upgrade", "simdutf8", "unstable-split"] }
|
||||
futures-util = { version = "0.3.30", features = ["sink"] }
|
||||
http-body-util = "0.1.0"
|
||||
hyper = { version = "1.1.0", features = ["server", "http1"] }
|
||||
hyper-util = { version = "0.1.2", features = ["tokio"] }
|
||||
rusty-penguin = { version = "0.5.3", default-features = false }
|
||||
tokio = { version = "1.35.1", features = ["rt-multi-thread", "net", "macros"] }
|
||||
tokio = { version = "1.5.1", features = ["rt-multi-thread", "macros"] }
|
||||
tokio-native-tls = "0.3.1"
|
||||
tokio-tungstenite = "0.21.0"
|
||||
tokio-util = { version = "0.7.10", features = ["codec"] }
|
||||
wisp-mux = { path = "../wisp", features = ["fastwebsockets", "tokio_io"] }
|
||||
|
|
|
@ -1,176 +1,283 @@
|
|||
use std::{convert::Infallible, env, net::SocketAddr, sync::Arc};
|
||||
#![feature(let_chains)]
|
||||
use std::io::{Error, Read};
|
||||
|
||||
use bytes::Bytes;
|
||||
use clap::Parser;
|
||||
use fastwebsockets::{
|
||||
upgrade, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload,
|
||||
WebSocketError,
|
||||
};
|
||||
use futures_util::{SinkExt, StreamExt, TryFutureExt};
|
||||
use hyper::{
|
||||
body::Incoming,
|
||||
header::{
|
||||
HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL,
|
||||
SEC_WEBSOCKET_VERSION, UPGRADE,
|
||||
},
|
||||
server::conn::http1,
|
||||
service::service_fn,
|
||||
upgrade::Upgraded,
|
||||
Method, Request, Response, StatusCode, Version,
|
||||
body::Incoming, header::HeaderValue, server::conn::http1, service::service_fn, Request,
|
||||
Response, StatusCode,
|
||||
};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use penguin_mux::{Multiplexor, MuxStream};
|
||||
use tokio::{
|
||||
net::{TcpListener, TcpStream},
|
||||
task::{JoinError, JoinSet},
|
||||
};
|
||||
use tokio::net::{TcpListener, TcpStream, UdpSocket};
|
||||
use tokio_native_tls::{native_tls, TlsAcceptor};
|
||||
use tokio_tungstenite::{
|
||||
tungstenite::{handshake::derive_accept_key, protocol::Role},
|
||||
WebSocketStream,
|
||||
};
|
||||
use tokio_util::codec::{BytesCodec, Framed};
|
||||
|
||||
type Body = http_body_util::Empty<hyper::body::Bytes>;
|
||||
use wisp_mux::{ws, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, WsEvent};
|
||||
|
||||
type MultiplexorStream = MuxStream<WebSocketStream<TokioIo<Upgraded>>>;
|
||||
type HttpBody = http_body_util::Full<hyper::body::Bytes>;
|
||||
|
||||
async fn forward(mut stream: MultiplexorStream) -> Result<(), JoinError> {
|
||||
println!("forwarding");
|
||||
let host = std::str::from_utf8(&stream.dest_host).unwrap();
|
||||
let mut tcp_stream = TcpStream::connect((host, stream.dest_port)).await.unwrap();
|
||||
println!("connected to {:?}", tcp_stream.peer_addr().unwrap());
|
||||
tokio::io::copy_bidirectional(&mut stream, &mut tcp_stream)
|
||||
#[derive(Parser)]
|
||||
#[command(version = clap::crate_version!(), about = "Implementation of the Wisp protocol in Rust, made for epoxy.")]
|
||||
struct Cli {
|
||||
#[arg(long, default_value = "")]
|
||||
prefix: String,
|
||||
#[arg(
|
||||
long = "port",
|
||||
short = 'l',
|
||||
value_name = "PORT",
|
||||
default_value = "4000"
|
||||
)]
|
||||
listen_port: String,
|
||||
#[arg(long, short, value_parser)]
|
||||
pubkey: clio::Input,
|
||||
#[arg(long, short = 'P', value_parser)]
|
||||
privkey: clio::Input,
|
||||
}
|
||||
|
||||
#[tokio::main(flavor = "multi_thread")]
|
||||
async fn main() -> Result<(), Error> {
|
||||
let mut opt = Cli::parse();
|
||||
let mut pem = Vec::new();
|
||||
opt.pubkey.read_to_end(&mut pem)?;
|
||||
let mut key = Vec::new();
|
||||
opt.privkey.read_to_end(&mut key)?;
|
||||
let identity = native_tls::Identity::from_pkcs8(&pem, &key).expect("failed to make identity");
|
||||
|
||||
let socket = TcpListener::bind(format!("0.0.0.0:{}", opt.listen_port))
|
||||
.await
|
||||
.unwrap();
|
||||
println!("finished");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_connection(ws_stream: WebSocketStream<TokioIo<Upgraded>>, addr: SocketAddr) {
|
||||
println!("WebSocket connection established: {}", addr);
|
||||
let mux = Multiplexor::new(ws_stream, penguin_mux::Role::Server, None, None);
|
||||
let mut jobs = JoinSet::new();
|
||||
println!("muxing");
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(result) = jobs.join_next() => {
|
||||
match result {
|
||||
Ok(Ok(())) => {}
|
||||
Ok(Err(err)) | Err(err) => eprintln!("failed to forward: {:?}", err),
|
||||
}
|
||||
}
|
||||
Ok(result) = mux.server_new_stream_channel() => {
|
||||
jobs.spawn(forward(result));
|
||||
}
|
||||
else => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
println!("{} disconnected", &addr);
|
||||
}
|
||||
|
||||
async fn handle_request(
|
||||
mut req: Request<Incoming>,
|
||||
addr: SocketAddr,
|
||||
) -> Result<Response<Body>, Infallible> {
|
||||
let headers = req.headers();
|
||||
let derived = headers
|
||||
.get(SEC_WEBSOCKET_KEY)
|
||||
.map(|k| derive_accept_key(k.as_bytes()));
|
||||
|
||||
let mut negotiated_protocol: Option<String> = None;
|
||||
if let Some(protocols) = headers
|
||||
.get(SEC_WEBSOCKET_PROTOCOL)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
{
|
||||
negotiated_protocol = protocols.split(',').next().map(|h| h.trim().to_string());
|
||||
}
|
||||
|
||||
if req.method() != Method::GET
|
||||
|| req.version() < Version::HTTP_11
|
||||
|| !headers
|
||||
.get(CONNECTION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|h| {
|
||||
h.split(|c| c == ' ' || c == ',')
|
||||
.any(|p| p.eq_ignore_ascii_case("upgrade"))
|
||||
})
|
||||
.unwrap_or(false)
|
||||
|| !headers
|
||||
.get(UPGRADE)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|h| h.eq_ignore_ascii_case("websocket"))
|
||||
.unwrap_or(false)
|
||||
|| !headers
|
||||
.get(SEC_WEBSOCKET_VERSION)
|
||||
.map(|h| h == "13")
|
||||
.unwrap_or(false)
|
||||
|| derived.is_none()
|
||||
{
|
||||
return Ok(Response::new(Body::default()));
|
||||
}
|
||||
|
||||
let ver = req.version();
|
||||
tokio::task::spawn(async move {
|
||||
match hyper::upgrade::on(&mut req).await {
|
||||
Ok(upgraded) => {
|
||||
let upgraded = TokioIo::new(upgraded);
|
||||
handle_connection(
|
||||
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
|
||||
addr,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Err(e) => eprintln!("upgrade error: {}", e),
|
||||
}
|
||||
});
|
||||
|
||||
let mut res = Response::new(Body::default());
|
||||
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
|
||||
*res.version_mut() = ver;
|
||||
res.headers_mut()
|
||||
.append(CONNECTION, HeaderValue::from_static("Upgrade"));
|
||||
res.headers_mut()
|
||||
.append(UPGRADE, HeaderValue::from_static("websocket"));
|
||||
res.headers_mut()
|
||||
.append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap());
|
||||
if let Some(protocol) = negotiated_protocol {
|
||||
res.headers_mut()
|
||||
.append(SEC_WEBSOCKET_PROTOCOL, protocol.parse().unwrap());
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let addr = env::args()
|
||||
.nth(1)
|
||||
.unwrap_or_else(|| "0.0.0.0:4000".to_string())
|
||||
.parse::<SocketAddr>()?;
|
||||
let pem = include_bytes!("./pem.pem");
|
||||
let key = include_bytes!("./key.pem");
|
||||
|
||||
let identity = native_tls::Identity::from_pkcs8(pem, key).expect("invalid pem/key");
|
||||
|
||||
let acceptor = TlsAcceptor::from(native_tls::TlsAcceptor::new(identity).unwrap());
|
||||
let acceptor = Arc::new(acceptor);
|
||||
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
|
||||
println!("listening on {}", addr);
|
||||
|
||||
loop {
|
||||
let (stream, remote_addr) = listener.accept().await?;
|
||||
let acceptor = acceptor.clone();
|
||||
.expect("failed to bind");
|
||||
let acceptor = TlsAcceptor::from(
|
||||
native_tls::TlsAcceptor::new(identity).expect("failed to make tls acceptor"),
|
||||
);
|
||||
let acceptor = std::sync::Arc::new(acceptor);
|
||||
|
||||
println!("listening on 0.0.0.0:4000");
|
||||
while let Ok((stream, addr)) = socket.accept().await {
|
||||
let acceptor_cloned = acceptor.clone();
|
||||
let prefix_cloned = opt.prefix.clone();
|
||||
tokio::spawn(async move {
|
||||
let stream = acceptor.accept(stream).await.expect("not tls");
|
||||
let stream = acceptor_cloned.accept(stream).await.expect("not tls");
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let service = service_fn(move |req| handle_request(req, remote_addr));
|
||||
|
||||
let service =
|
||||
service_fn(move |res| accept_http(res, addr.to_string(), prefix_cloned.clone()));
|
||||
let conn = http1::Builder::new()
|
||||
.serve_connection(io, service)
|
||||
.with_upgrades();
|
||||
|
||||
if let Err(err) = conn.await {
|
||||
eprintln!("failed to serve connection: {:?}", err);
|
||||
println!("{:?}: failed to serve conn: {:?}", addr, err);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept_http(
|
||||
mut req: Request<Incoming>,
|
||||
addr: String,
|
||||
prefix: String,
|
||||
) -> Result<Response<HttpBody>, WebSocketError> {
|
||||
let uri = req.uri().clone().path().to_string();
|
||||
if upgrade::is_upgrade_request(&req)
|
||||
&& let Some(uri) = uri.strip_prefix(&prefix)
|
||||
{
|
||||
let (mut res, fut) = upgrade::upgrade(&mut req)?;
|
||||
|
||||
if let Some(protocols) = req.headers().get("Sec-Websocket-Protocol").and_then(|x| {
|
||||
Some(
|
||||
x.to_str()
|
||||
.ok()?
|
||||
.split(',')
|
||||
.map(|x| x.trim())
|
||||
.collect::<Vec<&str>>(),
|
||||
)
|
||||
}) && protocols.contains(&"wisp-v1")
|
||||
&& (uri == "" || uri == "/")
|
||||
{
|
||||
tokio::spawn(async move { accept_ws(fut, addr.clone()).await });
|
||||
res.headers_mut().insert(
|
||||
"Sec-Websocket-Protocol",
|
||||
HeaderValue::from_str("wisp-v1").unwrap(),
|
||||
);
|
||||
} else {
|
||||
let uri = uri.strip_prefix("/").unwrap_or(uri).to_string();
|
||||
tokio::spawn(async move { accept_wsproxy(fut, uri, addr.clone()).await });
|
||||
}
|
||||
|
||||
Ok(Response::from_parts(
|
||||
res.into_parts().0,
|
||||
HttpBody::new(Bytes::new()),
|
||||
))
|
||||
} else {
|
||||
println!("random request to path {:?}", uri);
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.body(HttpBody::new(":3".to_string().into()))
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_mux(
|
||||
packet: ConnectPacket,
|
||||
mut stream: MuxStream<impl ws::WebSocketWrite + Send + 'static>,
|
||||
) -> Result<bool, WispError> {
|
||||
let uri = format!(
|
||||
"{}:{}",
|
||||
packet.destination_hostname, packet.destination_port
|
||||
);
|
||||
match packet.stream_type {
|
||||
StreamType::Tcp => {
|
||||
let mut tcp_stream = TcpStream::connect(uri)
|
||||
.await
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
let mut mux_stream = stream.into_io().into_asyncrw();
|
||||
tokio::io::copy_bidirectional(&mut tcp_stream, &mut mux_stream)
|
||||
.await
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
}
|
||||
StreamType::Udp => {
|
||||
let udp_socket = UdpSocket::bind(uri)
|
||||
.await
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
let mut data = vec![0u8; 65507]; // udp standard max datagram size
|
||||
loop {
|
||||
tokio::select! {
|
||||
size = udp_socket.recv(&mut data).map_err(|x| WispError::Other(Box::new(x))) => {
|
||||
let size = size?;
|
||||
stream.write(Bytes::copy_from_slice(&data[..size])).await?
|
||||
},
|
||||
event = stream.read() => {
|
||||
match event {
|
||||
Some(event) => match event {
|
||||
WsEvent::Send(data) => {
|
||||
udp_socket.send(&data).await.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
}
|
||||
WsEvent::Close(_) => return Ok(false),
|
||||
},
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn accept_ws(
|
||||
fut: upgrade::UpgradeFut,
|
||||
addr: String,
|
||||
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
|
||||
let (rx, tx) = fut.await?.split(tokio::io::split);
|
||||
let rx = FragmentCollectorRead::new(rx);
|
||||
|
||||
println!("{:?}: connected", addr);
|
||||
|
||||
let (mut mux, fut) = ServerMux::new(rx, tx, 128);
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = fut.await {
|
||||
println!("err in mux: {:?}", e);
|
||||
}
|
||||
});
|
||||
|
||||
while let Some((packet, stream)) = mux.server_new_stream().await {
|
||||
tokio::spawn(async move {
|
||||
let close_err = stream.get_close_handle();
|
||||
let close_ok = stream.get_close_handle();
|
||||
let _ = handle_mux(packet, stream)
|
||||
.or_else(|err| async move {
|
||||
let _ = close_err.close(0x03).await;
|
||||
Err(err)
|
||||
})
|
||||
.and_then(|should_send| async move {
|
||||
if should_send {
|
||||
close_ok.close(0x02).await
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
println!("{:?}: disconnected", addr);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept_wsproxy(
|
||||
fut: upgrade::UpgradeFut,
|
||||
incoming_uri: String,
|
||||
addr: String,
|
||||
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
|
||||
let mut ws_stream = FragmentCollector::new(fut.await?);
|
||||
|
||||
println!("{:?}: connected (wsproxy): {:?}", addr, incoming_uri);
|
||||
|
||||
match hyper::Uri::try_from(incoming_uri.clone()) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"invalid uri")).await?;
|
||||
return Err(Box::new(err));
|
||||
}
|
||||
}
|
||||
|
||||
let tcp_stream = match TcpStream::connect(incoming_uri).await {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
ws_stream
|
||||
.write_frame(Frame::close(CloseCode::Away.into(), b"failed to connect"))
|
||||
.await?;
|
||||
return Err(Box::new(err));
|
||||
}
|
||||
};
|
||||
let mut tcp_stream_framed = Framed::new(tcp_stream, BytesCodec::new());
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
event = ws_stream.read_frame() => {
|
||||
match event {
|
||||
Ok(frame) => {
|
||||
match frame.opcode {
|
||||
OpCode::Text | OpCode::Binary => {
|
||||
let _ = tcp_stream_framed.send(Bytes::from(frame.payload.to_vec())).await;
|
||||
}
|
||||
OpCode::Close => {
|
||||
// tokio closes the stream for us
|
||||
drop(tcp_stream_framed);
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
// tokio closes the stream for us
|
||||
drop(tcp_stream_framed);
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
event = tcp_stream_framed.next() => {
|
||||
if let Some(res) = event {
|
||||
match res {
|
||||
Ok(buf) => {
|
||||
let _ = ws_stream.write_frame(Frame::binary(Payload::Borrowed(&buf))).await;
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"tcp side is going away")).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("{:?}: disconnected (wsproxy)", addr);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
16
simple-wisp-client/Cargo.toml
Normal file
16
simple-wisp-client/Cargo.toml
Normal file
|
@ -0,0 +1,16 @@
|
|||
[package]
|
||||
name = "simple-wisp-client"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
bytes = "1.5.0"
|
||||
fastwebsockets = { version = "0.6.0", features = ["unstable-split", "upgrade"] }
|
||||
futures = "0.3.30"
|
||||
http-body-util = "0.1.0"
|
||||
hyper = { version = "1.1.0", features = ["http1", "client"] }
|
||||
tokio = { version = "1.36.0", features = ["full"] }
|
||||
tokio-native-tls = "0.3.1"
|
||||
tokio-util = "0.7.10"
|
||||
wisp-mux = { path = "../wisp", features = ["fastwebsockets"]}
|
||||
|
114
simple-wisp-client/src/main.rs
Normal file
114
simple-wisp-client/src/main.rs
Normal file
|
@ -0,0 +1,114 @@
|
|||
use bytes::Bytes;
|
||||
use fastwebsockets::{handshake, FragmentCollectorRead};
|
||||
use futures::io::AsyncWriteExt;
|
||||
use http_body_util::Empty;
|
||||
use hyper::{
|
||||
header::{CONNECTION, UPGRADE},
|
||||
Request,
|
||||
};
|
||||
use std::{error::Error, future::Future};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_native_tls::{native_tls, TlsConnector};
|
||||
use wisp_mux::{ClientMux, StreamType};
|
||||
use tokio_util::either::Either;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct StrError(String);
|
||||
|
||||
impl StrError {
|
||||
pub fn new(str: &str) -> Self {
|
||||
Self(str.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for StrError {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
|
||||
write!(fmt, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for StrError {}
|
||||
|
||||
struct SpawnExecutor;
|
||||
|
||||
impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
|
||||
where
|
||||
Fut: Future + Send + 'static,
|
||||
Fut::Output: Send + 'static,
|
||||
{
|
||||
fn execute(&self, fut: Fut) {
|
||||
tokio::task::spawn(fut);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
let addr = std::env::args()
|
||||
.nth(1)
|
||||
.ok_or(StrError::new("no src addr"))?;
|
||||
|
||||
let addr_port: u16 = std::env::args()
|
||||
.nth(2)
|
||||
.ok_or(StrError::new("no src port"))?
|
||||
.parse()?;
|
||||
|
||||
let addr_dest = std::env::args()
|
||||
.nth(3)
|
||||
.ok_or(StrError::new("no dest addr"))?;
|
||||
|
||||
let addr_dest_port: u16 = std::env::args()
|
||||
.nth(4)
|
||||
.ok_or(StrError::new("no dest port"))?
|
||||
.parse()?;
|
||||
let should_tls: bool = std::env::args()
|
||||
.nth(5)
|
||||
.ok_or(StrError::new("no should tls"))?
|
||||
.parse()?;
|
||||
|
||||
let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?;
|
||||
let socket = if should_tls {
|
||||
let cx = TlsConnector::from(native_tls::TlsConnector::builder().build()?);
|
||||
Either::Left(cx.connect(&addr, socket).await?)
|
||||
} else {
|
||||
Either::Right(socket)
|
||||
};
|
||||
let req = Request::builder()
|
||||
.method("GET")
|
||||
.uri(format!("wss://{}:{}/", &addr, addr_port))
|
||||
.header("Host", &addr)
|
||||
.header(UPGRADE, "websocket")
|
||||
.header(CONNECTION, "upgrade")
|
||||
.header(
|
||||
"Sec-WebSocket-Key",
|
||||
fastwebsockets::handshake::generate_key(),
|
||||
)
|
||||
.header("Sec-WebSocket-Version", "13")
|
||||
.header("Sec-WebSocket-Protocol", "wisp-v1")
|
||||
.body(Empty::<Bytes>::new())?;
|
||||
|
||||
let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?;
|
||||
|
||||
let (rx, tx) = ws.split(tokio::io::split);
|
||||
let rx = FragmentCollectorRead::new(rx);
|
||||
|
||||
let (mux, fut) = ClientMux::new(rx, tx).await?;
|
||||
|
||||
tokio::task::spawn(fut);
|
||||
|
||||
let mut hi: u64 = 0;
|
||||
loop {
|
||||
let mut channel = mux
|
||||
.client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port)
|
||||
.await?
|
||||
.into_io()
|
||||
.into_asyncrw();
|
||||
for _ in 0..10 {
|
||||
channel.write_all(b"hiiiiiiii").await?;
|
||||
hi += 1;
|
||||
println!("said hi {}", hi);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unreachable_code)]
|
||||
Ok(())
|
||||
}
|
1
wisp/.gitignore
vendored
Normal file
1
wisp/.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
/target
|
25
wisp/Cargo.toml
Normal file
25
wisp/Cargo.toml
Normal file
|
@ -0,0 +1,25 @@
|
|||
[package]
|
||||
name = "wisp-mux"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
async_io_stream = "0.3.3"
|
||||
bytes = "1.5.0"
|
||||
event-listener = "5.0.0"
|
||||
fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true }
|
||||
futures = "0.3.30"
|
||||
futures-util = "0.3.30"
|
||||
hyper = { version = "1.1.0", optional = true }
|
||||
hyper-util = { git = "https://github.com/r58Playz/hyper-util-wasm", features = ["client", "client-legacy"], optional = true }
|
||||
pin-project-lite = "0.2.13"
|
||||
tokio = { version = "1.35.1", optional = true, default-features = false }
|
||||
tower-service = { version = "0.3.2", optional = true }
|
||||
ws_stream_wasm = { version = "0.7.4", optional = true }
|
||||
|
||||
[features]
|
||||
fastwebsockets = ["dep:fastwebsockets", "dep:tokio"]
|
||||
ws_stream_wasm = ["dep:ws_stream_wasm"]
|
||||
tokio_io = ["async_io_stream/tokio_io"]
|
||||
hyper_tower = ["dep:tower-service", "dep:hyper", "dep:tokio", "dep:hyper-util"]
|
||||
|
72
wisp/src/fastwebsockets.rs
Normal file
72
wisp/src/fastwebsockets.rs
Normal file
|
@ -0,0 +1,72 @@
|
|||
use bytes::Bytes;
|
||||
use fastwebsockets::{
|
||||
FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite,
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
impl From<OpCode> for crate::ws::OpCode {
|
||||
fn from(opcode: OpCode) -> Self {
|
||||
use OpCode::*;
|
||||
match opcode {
|
||||
Continuation => unreachable!(),
|
||||
Text => Self::Text,
|
||||
Binary => Self::Binary,
|
||||
Close => Self::Close,
|
||||
Ping => Self::Ping,
|
||||
Pong => Self::Pong,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Frame<'_>> for crate::ws::Frame {
|
||||
fn from(mut frame: Frame) -> Self {
|
||||
Self {
|
||||
finished: frame.fin,
|
||||
opcode: frame.opcode.into(),
|
||||
payload: Bytes::copy_from_slice(frame.payload.to_mut()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<crate::ws::Frame> for Frame<'_> {
|
||||
type Error = crate::WispError;
|
||||
fn try_from(frame: crate::ws::Frame) -> Result<Self, Self::Error> {
|
||||
use crate::ws::OpCode::*;
|
||||
Ok(match frame.opcode {
|
||||
Text => Self::text(Payload::Owned(frame.payload.to_vec())),
|
||||
Binary => Self::binary(Payload::Owned(frame.payload.to_vec())),
|
||||
Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())),
|
||||
Ping => Self::new(
|
||||
true,
|
||||
OpCode::Ping,
|
||||
None,
|
||||
Payload::Owned(frame.payload.to_vec()),
|
||||
),
|
||||
Pong => Self::pong(Payload::Owned(frame.payload.to_vec())),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WebSocketError> for crate::WispError {
|
||||
fn from(err: WebSocketError) -> Self {
|
||||
Self::WsImplError(Box::new(err))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
|
||||
async fn wisp_read_frame(
|
||||
&mut self,
|
||||
tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite + Send>,
|
||||
) -> Result<crate::ws::Frame, crate::WispError> {
|
||||
Ok(self
|
||||
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
|
||||
.await?
|
||||
.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
|
||||
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
|
||||
self.write_frame(frame.try_into()?).await.map_err(|e| e.into())
|
||||
}
|
||||
}
|
390
wisp/src/lib.rs
Normal file
390
wisp/src/lib.rs
Normal file
|
@ -0,0 +1,390 @@
|
|||
#![feature(impl_trait_in_assoc_type)]
|
||||
#[cfg(feature = "fastwebsockets")]
|
||||
mod fastwebsockets;
|
||||
mod packet;
|
||||
mod stream;
|
||||
#[cfg(feature = "hyper_tower")]
|
||||
pub mod tokioio;
|
||||
#[cfg(feature = "hyper_tower")]
|
||||
pub mod tower;
|
||||
pub mod ws;
|
||||
#[cfg(feature = "ws_stream_wasm")]
|
||||
mod ws_stream_wasm;
|
||||
|
||||
pub use crate::packet::*;
|
||||
pub use crate::stream::*;
|
||||
|
||||
use event_listener::Event;
|
||||
use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU32, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Debug, PartialEq, Copy, Clone)]
|
||||
pub enum Role {
|
||||
Client,
|
||||
Server,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum WispError {
|
||||
PacketTooSmall,
|
||||
InvalidPacketType,
|
||||
InvalidStreamType,
|
||||
InvalidStreamId,
|
||||
InvalidUri,
|
||||
UriHasNoHost,
|
||||
UriHasNoPort,
|
||||
MaxStreamCountReached,
|
||||
StreamAlreadyClosed,
|
||||
WsFrameInvalidType,
|
||||
WsFrameNotFinished,
|
||||
WsImplError(Box<dyn std::error::Error + Sync + Send>),
|
||||
WsImplSocketClosed,
|
||||
WsImplNotSupported,
|
||||
Utf8Error(std::str::Utf8Error),
|
||||
Other(Box<dyn std::error::Error + Sync + Send>),
|
||||
}
|
||||
|
||||
impl From<std::str::Utf8Error> for WispError {
|
||||
fn from(err: std::str::Utf8Error) -> WispError {
|
||||
WispError::Utf8Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for WispError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
|
||||
use WispError::*;
|
||||
match self {
|
||||
PacketTooSmall => write!(f, "Packet too small"),
|
||||
InvalidPacketType => write!(f, "Invalid packet type"),
|
||||
InvalidStreamType => write!(f, "Invalid stream type"),
|
||||
InvalidStreamId => write!(f, "Invalid stream id"),
|
||||
InvalidUri => write!(f, "Invalid URI"),
|
||||
UriHasNoHost => write!(f, "URI has no host"),
|
||||
UriHasNoPort => write!(f, "URI has no port"),
|
||||
MaxStreamCountReached => write!(f, "Maximum stream count reached"),
|
||||
StreamAlreadyClosed => write!(f, "Stream already closed"),
|
||||
WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
|
||||
WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
|
||||
WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err),
|
||||
WsImplSocketClosed => write!(f, "Websocket implementation error: websocket closed"),
|
||||
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"),
|
||||
Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err),
|
||||
Other(err) => write!(f, "Other error: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for WispError {}
|
||||
|
||||
struct ServerMuxInner<W>
|
||||
where
|
||||
W: ws::WebSocketWrite + Send + 'static,
|
||||
{
|
||||
tx: ws::LockedWebSocketWrite<W>,
|
||||
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>,
|
||||
close_tx: mpsc::UnboundedSender<MuxEvent>,
|
||||
}
|
||||
|
||||
impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
|
||||
pub async fn into_future<R>(
|
||||
self,
|
||||
rx: R,
|
||||
close_rx: mpsc::UnboundedReceiver<MuxEvent>,
|
||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>,
|
||||
buffer_size: u32
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
let ret = futures::select! {
|
||||
x = self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()).fuse() => x,
|
||||
x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x
|
||||
};
|
||||
self.stream_map.lock().await.iter().for_each(|x| {
|
||||
let _ = x.1.unbounded_send(WsEvent::Close(ClosePacket::new(0x01)));
|
||||
});
|
||||
ret
|
||||
}
|
||||
|
||||
async fn server_close_loop(
|
||||
&self,
|
||||
mut close_rx: mpsc::UnboundedReceiver<MuxEvent>,
|
||||
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>,
|
||||
tx: ws::LockedWebSocketWrite<W>,
|
||||
) -> Result<(), WispError> {
|
||||
while let Some(msg) = close_rx.next().await {
|
||||
match msg {
|
||||
MuxEvent::Close(stream_id, reason, channel) => {
|
||||
if stream_map.lock().await.remove(&stream_id).is_some() {
|
||||
let _ = channel.send(
|
||||
tx.write_frame(Packet::new_close(stream_id, reason).into())
|
||||
.await,
|
||||
);
|
||||
} else {
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn server_msg_loop<R>(
|
||||
&self,
|
||||
mut rx: R,
|
||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>,
|
||||
buffer_size: u32,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
self.tx
|
||||
.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||
.await?;
|
||||
|
||||
while let Ok(frame) = rx.wisp_read_frame(&self.tx).await {
|
||||
if let Ok(packet) = Packet::try_from(frame) {
|
||||
use PacketType::*;
|
||||
match packet.packet {
|
||||
Connect(inner_packet) => {
|
||||
let (ch_tx, ch_rx) = mpsc::unbounded();
|
||||
self.stream_map.lock().await.insert(packet.stream_id, ch_tx);
|
||||
muxstream_sender
|
||||
.unbounded_send((
|
||||
inner_packet,
|
||||
MuxStream::new(
|
||||
packet.stream_id,
|
||||
Role::Server,
|
||||
ch_rx,
|
||||
self.tx.clone(),
|
||||
self.close_tx.clone(),
|
||||
AtomicBool::new(false).into(),
|
||||
AtomicU32::new(buffer_size).into(),
|
||||
Event::new().into(),
|
||||
),
|
||||
))
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
}
|
||||
Data(data) => {
|
||||
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||
let _ = stream.unbounded_send(WsEvent::Send(data));
|
||||
}
|
||||
}
|
||||
Continue(_) => unreachable!(),
|
||||
Close(inner_packet) => {
|
||||
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||
let _ = stream.unbounded_send(WsEvent::Close(inner_packet));
|
||||
}
|
||||
self.stream_map.lock().await.remove(&packet.stream_id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
drop(muxstream_sender);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ServerMux<W>
|
||||
where
|
||||
W: ws::WebSocketWrite + Send + 'static,
|
||||
{
|
||||
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream<W>)>,
|
||||
}
|
||||
|
||||
impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
|
||||
pub fn new<R>(read: R, write: W, buffer_size: u32) -> (Self, impl Future<Output = Result<(), WispError>>)
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
let (close_tx, close_rx) = mpsc::unbounded::<MuxEvent>();
|
||||
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>();
|
||||
let write = ws::LockedWebSocketWrite::new(write);
|
||||
let map = Arc::new(Mutex::new(HashMap::new()));
|
||||
(
|
||||
Self { muxstream_recv: rx },
|
||||
ServerMuxInner {
|
||||
tx: write,
|
||||
close_tx,
|
||||
stream_map: map.clone(),
|
||||
}
|
||||
.into_future(read, close_rx, tx, buffer_size),
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream<W>)> {
|
||||
self.muxstream_recv.next().await
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ClientMuxInner<W>
|
||||
where
|
||||
W: ws::WebSocketWrite,
|
||||
{
|
||||
tx: ws::LockedWebSocketWrite<W>,
|
||||
stream_map:
|
||||
Arc<Mutex<HashMap<u32, (mpsc::UnboundedSender<WsEvent>, Arc<AtomicU32>, Arc<Event>)>>>,
|
||||
}
|
||||
|
||||
impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
|
||||
pub async fn into_future<R>(
|
||||
self,
|
||||
rx: R,
|
||||
close_rx: mpsc::UnboundedReceiver<MuxEvent>,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
futures::select! {
|
||||
x = self.client_bg_loop(close_rx).fuse() => x,
|
||||
x = self.client_loop(rx).fuse() => x
|
||||
}
|
||||
}
|
||||
|
||||
async fn client_bg_loop(
|
||||
&self,
|
||||
mut close_rx: mpsc::UnboundedReceiver<MuxEvent>,
|
||||
) -> Result<(), WispError> {
|
||||
while let Some(msg) = close_rx.next().await {
|
||||
match msg {
|
||||
MuxEvent::Close(stream_id, reason, channel) => {
|
||||
if self.stream_map.lock().await.remove(&stream_id).is_some() {
|
||||
let _ = channel.send(
|
||||
self.tx
|
||||
.write_frame(Packet::new_close(stream_id, reason).into())
|
||||
.await,
|
||||
);
|
||||
} else {
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn client_loop<R>(&self, mut rx: R) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
while let Ok(frame) = rx.wisp_read_frame(&self.tx).await {
|
||||
if let Ok(packet) = Packet::try_from(frame) {
|
||||
use PacketType::*;
|
||||
match packet.packet {
|
||||
Connect(_) => unreachable!(),
|
||||
Data(data) => {
|
||||
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||
let _ = stream.0.unbounded_send(WsEvent::Send(data));
|
||||
}
|
||||
}
|
||||
Continue(inner_packet) => {
|
||||
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||
stream
|
||||
.1
|
||||
.store(inner_packet.buffer_remaining, Ordering::Release);
|
||||
let _ = stream.2.notify(u32::MAX);
|
||||
}
|
||||
}
|
||||
Close(inner_packet) => {
|
||||
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||
let _ = stream.0.unbounded_send(WsEvent::Close(inner_packet));
|
||||
}
|
||||
self.stream_map.lock().await.remove(&packet.stream_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ClientMux<W>
|
||||
where
|
||||
W: ws::WebSocketWrite,
|
||||
{
|
||||
tx: ws::LockedWebSocketWrite<W>,
|
||||
stream_map:
|
||||
Arc<Mutex<HashMap<u32, (mpsc::UnboundedSender<WsEvent>, Arc<AtomicU32>, Arc<Event>)>>>,
|
||||
next_free_stream_id: AtomicU32,
|
||||
close_tx: mpsc::UnboundedSender<MuxEvent>,
|
||||
buf_size: u32,
|
||||
}
|
||||
|
||||
impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
|
||||
pub async fn new<R>(
|
||||
mut read: R,
|
||||
write: W,
|
||||
) -> Result<(Self, impl Future<Output = Result<(), WispError>>), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
let write = ws::LockedWebSocketWrite::new(write);
|
||||
let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?;
|
||||
if first_packet.stream_id != 0 {
|
||||
return Err(WispError::InvalidStreamId);
|
||||
}
|
||||
if let PacketType::Continue(packet) = first_packet.packet {
|
||||
let (tx, rx) = mpsc::unbounded::<MuxEvent>();
|
||||
let map = Arc::new(Mutex::new(HashMap::new()));
|
||||
Ok((
|
||||
Self {
|
||||
tx: write.clone(),
|
||||
stream_map: map.clone(),
|
||||
next_free_stream_id: AtomicU32::new(1),
|
||||
close_tx: tx,
|
||||
buf_size: packet.buffer_remaining,
|
||||
},
|
||||
ClientMuxInner {
|
||||
tx: write.clone(),
|
||||
stream_map: map.clone(),
|
||||
}
|
||||
.into_future(read, rx),
|
||||
))
|
||||
} else {
|
||||
Err(WispError::InvalidPacketType)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn client_new_stream(
|
||||
&self,
|
||||
stream_type: StreamType,
|
||||
host: String,
|
||||
port: u16,
|
||||
) -> Result<MuxStream<W>, WispError> {
|
||||
let (ch_tx, ch_rx) = mpsc::unbounded();
|
||||
let evt: Arc<Event> = Event::new().into();
|
||||
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buf_size).into();
|
||||
let stream_id = self.next_free_stream_id.load(Ordering::Acquire);
|
||||
self.tx
|
||||
.write_frame(Packet::new_connect(stream_id, stream_type, port, host).into())
|
||||
.await?;
|
||||
self.next_free_stream_id.store(
|
||||
stream_id
|
||||
.checked_add(1)
|
||||
.ok_or(WispError::MaxStreamCountReached)?,
|
||||
Ordering::Release,
|
||||
);
|
||||
self.stream_map
|
||||
.lock()
|
||||
.await
|
||||
.insert(stream_id, (ch_tx, flow_control.clone(), evt.clone()));
|
||||
Ok(MuxStream::new(
|
||||
stream_id,
|
||||
Role::Client,
|
||||
ch_rx,
|
||||
self.tx.clone(),
|
||||
self.close_tx.clone(),
|
||||
AtomicBool::new(false).into(),
|
||||
flow_control,
|
||||
evt,
|
||||
))
|
||||
}
|
||||
}
|
255
wisp/src/packet.rs
Normal file
255
wisp/src/packet.rs
Normal file
|
@ -0,0 +1,255 @@
|
|||
use crate::ws;
|
||||
use crate::WispError;
|
||||
use bytes::{Buf, BufMut, Bytes};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum StreamType {
|
||||
Tcp = 0x01,
|
||||
Udp = 0x02,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for StreamType {
|
||||
type Error = WispError;
|
||||
fn try_from(stream_type: u8) -> Result<Self, Self::Error> {
|
||||
use StreamType::*;
|
||||
match stream_type {
|
||||
0x01 => Ok(Tcp),
|
||||
0x02 => Ok(Udp),
|
||||
_ => Err(Self::Error::InvalidStreamType),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ConnectPacket {
|
||||
pub stream_type: StreamType,
|
||||
pub destination_port: u16,
|
||||
pub destination_hostname: String,
|
||||
}
|
||||
|
||||
impl ConnectPacket {
|
||||
pub fn new(stream_type: StreamType, destination_port: u16, destination_hostname: String) -> Self {
|
||||
Self {
|
||||
stream_type,
|
||||
destination_port,
|
||||
destination_hostname,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Bytes> for ConnectPacket {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < (1 + 2) {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
Ok(Self {
|
||||
stream_type: bytes.get_u8().try_into()?,
|
||||
destination_port: bytes.get_u16_le(),
|
||||
destination_hostname: std::str::from_utf8(&bytes)?.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ConnectPacket> for Vec<u8> {
|
||||
fn from(packet: ConnectPacket) -> Self {
|
||||
let mut encoded = Self::with_capacity(1 + 2 + packet.destination_hostname.len());
|
||||
encoded.put_u8(packet.stream_type as u8);
|
||||
encoded.put_u16_le(packet.destination_port);
|
||||
encoded.extend(packet.destination_hostname.bytes());
|
||||
encoded
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ContinuePacket {
|
||||
pub buffer_remaining: u32,
|
||||
}
|
||||
|
||||
impl ContinuePacket {
|
||||
pub fn new(buffer_remaining: u32) -> Self {
|
||||
Self { buffer_remaining }
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Bytes> for ContinuePacket {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < 4 {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
Ok(Self {
|
||||
buffer_remaining: bytes.get_u32_le(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ContinuePacket> for Vec<u8> {
|
||||
fn from(packet: ContinuePacket) -> Self {
|
||||
let mut encoded = Self::with_capacity(4);
|
||||
encoded.put_u32_le(packet.buffer_remaining);
|
||||
encoded
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ClosePacket {
|
||||
pub reason: u8,
|
||||
}
|
||||
|
||||
impl ClosePacket {
|
||||
pub fn new(reason: u8) -> Self {
|
||||
Self { reason }
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Bytes> for ClosePacket {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < 1 {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
Ok(Self {
|
||||
reason: bytes.get_u8(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ClosePacket> for Vec<u8> {
|
||||
fn from(packet: ClosePacket) -> Self {
|
||||
let mut encoded = Self::with_capacity(1);
|
||||
encoded.put_u8(packet.reason);
|
||||
encoded
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum PacketType {
|
||||
Connect(ConnectPacket),
|
||||
Data(Bytes),
|
||||
Continue(ContinuePacket),
|
||||
Close(ClosePacket),
|
||||
}
|
||||
|
||||
impl PacketType {
|
||||
pub fn as_u8(&self) -> u8 {
|
||||
use PacketType::*;
|
||||
match self {
|
||||
Connect(_) => 0x01,
|
||||
Data(_) => 0x02,
|
||||
Continue(_) => 0x03,
|
||||
Close(_) => 0x04,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PacketType> for Vec<u8> {
|
||||
fn from(packet: PacketType) -> Self {
|
||||
use PacketType::*;
|
||||
match packet {
|
||||
Connect(x) => x.into(),
|
||||
Data(x) => x.to_vec(),
|
||||
Continue(x) => x.into(),
|
||||
Close(x) => x.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Packet {
|
||||
pub stream_id: u32,
|
||||
pub packet: PacketType,
|
||||
}
|
||||
|
||||
impl Packet {
|
||||
pub fn new(stream_id: u32, packet: PacketType) -> Self {
|
||||
Self { stream_id, packet }
|
||||
}
|
||||
|
||||
pub fn new_connect(
|
||||
stream_id: u32,
|
||||
stream_type: StreamType,
|
||||
destination_port: u16,
|
||||
destination_hostname: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
packet: PacketType::Connect(ConnectPacket::new(
|
||||
stream_type,
|
||||
destination_port,
|
||||
destination_hostname,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_data(stream_id: u32, data: Bytes) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
packet: PacketType::Data(data),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
packet: PacketType::Continue(ContinuePacket::new(buffer_remaining)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_close(stream_id: u32, reason: u8) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
packet: PacketType::Close(ClosePacket::new(reason)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Bytes> for Packet {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < 5 {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
let packet_type = bytes.get_u8();
|
||||
use PacketType::*;
|
||||
Ok(Self {
|
||||
stream_id: bytes.get_u32_le(),
|
||||
packet: match packet_type {
|
||||
0x01 => Connect(ConnectPacket::try_from(bytes)?),
|
||||
0x02 => Data(bytes),
|
||||
0x03 => Continue(ContinuePacket::try_from(bytes)?),
|
||||
0x04 => Close(ClosePacket::try_from(bytes)?),
|
||||
_ => return Err(Self::Error::InvalidPacketType),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Packet> for Vec<u8> {
|
||||
fn from(packet: Packet) -> Self {
|
||||
let mut encoded = Self::with_capacity(1 + 4);
|
||||
encoded.push(packet.packet.as_u8());
|
||||
encoded.put_u32_le(packet.stream_id);
|
||||
encoded.extend(Vec::<u8>::from(packet.packet));
|
||||
encoded
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ws::Frame> for Packet {
|
||||
type Error = WispError;
|
||||
fn try_from(frame: ws::Frame) -> Result<Self, Self::Error> {
|
||||
if !frame.finished {
|
||||
return Err(Self::Error::WsFrameNotFinished);
|
||||
}
|
||||
if frame.opcode != ws::OpCode::Binary {
|
||||
return Err(Self::Error::WsFrameInvalidType);
|
||||
}
|
||||
frame.payload.try_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Packet> for ws::Frame {
|
||||
fn from(packet: Packet) -> Self {
|
||||
Self::binary(Vec::<u8>::from(packet).into())
|
||||
}
|
||||
}
|
298
wisp/src/stream.rs
Normal file
298
wisp/src/stream.rs
Normal file
|
@ -0,0 +1,298 @@
|
|||
use async_io_stream::IoStream;
|
||||
use bytes::Bytes;
|
||||
use event_listener::Event;
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
sink, stream,
|
||||
task::{Context, Poll},
|
||||
Sink, Stream, StreamExt,
|
||||
};
|
||||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU32, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
pub enum WsEvent {
|
||||
Send(Bytes),
|
||||
Close(crate::ClosePacket),
|
||||
}
|
||||
|
||||
pub enum MuxEvent {
|
||||
Close(u32, u8, oneshot::Sender<Result<(), crate::WispError>>),
|
||||
}
|
||||
|
||||
pub struct MuxStreamRead<W>
|
||||
where
|
||||
W: crate::ws::WebSocketWrite,
|
||||
{
|
||||
pub stream_id: u32,
|
||||
role: crate::Role,
|
||||
tx: crate::ws::LockedWebSocketWrite<W>,
|
||||
rx: mpsc::UnboundedReceiver<WsEvent>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
flow_control: Arc<AtomicU32>,
|
||||
}
|
||||
|
||||
impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamRead<W> {
|
||||
pub async fn read(&mut self) -> Option<WsEvent> {
|
||||
if self.is_closed.load(Ordering::Acquire) {
|
||||
return None;
|
||||
}
|
||||
match self.rx.next().await? {
|
||||
WsEvent::Send(bytes) => {
|
||||
if self.role == crate::Role::Server {
|
||||
let old_val = self.flow_control.fetch_add(1, Ordering::SeqCst);
|
||||
self.tx
|
||||
.write_frame(
|
||||
crate::Packet::new_continue(self.stream_id, old_val + 1).into(),
|
||||
)
|
||||
.await
|
||||
.ok()?;
|
||||
}
|
||||
Some(WsEvent::Send(bytes))
|
||||
}
|
||||
WsEvent::Close(packet) => {
|
||||
self.is_closed.store(true, Ordering::Release);
|
||||
Some(WsEvent::Close(packet))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn into_stream(self) -> Pin<Box<dyn Stream<Item = Bytes> + Send>> {
|
||||
Box::pin(stream::unfold(self, |mut rx| async move {
|
||||
let evt = rx.read().await?;
|
||||
Some((
|
||||
match evt {
|
||||
WsEvent::Send(bytes) => bytes,
|
||||
WsEvent::Close(_) => return None,
|
||||
},
|
||||
rx,
|
||||
))
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MuxStreamWrite<W>
|
||||
where
|
||||
W: crate::ws::WebSocketWrite,
|
||||
{
|
||||
pub stream_id: u32,
|
||||
role: crate::Role,
|
||||
tx: crate::ws::LockedWebSocketWrite<W>,
|
||||
close_channel: mpsc::UnboundedSender<MuxEvent>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
continue_recieved: Arc<Event>,
|
||||
flow_control: Arc<AtomicU32>,
|
||||
}
|
||||
|
||||
impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamWrite<W> {
|
||||
pub async fn write(&self, data: Bytes) -> Result<(), crate::WispError> {
|
||||
if self.is_closed.load(Ordering::Acquire) {
|
||||
return Err(crate::WispError::StreamAlreadyClosed);
|
||||
}
|
||||
if self.role == crate::Role::Client && self.flow_control.load(Ordering::Acquire) <= 0 {
|
||||
self.continue_recieved.listen().await;
|
||||
}
|
||||
self.tx
|
||||
.write_frame(crate::Packet::new_data(self.stream_id, data).into())
|
||||
.await?;
|
||||
if self.role == crate::Role::Client {
|
||||
self.flow_control.store(
|
||||
self.flow_control
|
||||
.load(Ordering::Acquire)
|
||||
.checked_add(1)
|
||||
.unwrap_or(0),
|
||||
Ordering::Release,
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_close_handle(&self) -> MuxStreamCloser {
|
||||
MuxStreamCloser {
|
||||
stream_id: self.stream_id,
|
||||
close_channel: self.close_channel.clone(),
|
||||
is_closed: self.is_closed.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> {
|
||||
if self.is_closed.load(Ordering::Acquire) {
|
||||
return Err(crate::WispError::StreamAlreadyClosed);
|
||||
}
|
||||
let (tx, rx) = oneshot::channel::<Result<(), crate::WispError>>();
|
||||
self.close_channel
|
||||
.unbounded_send(MuxEvent::Close(self.stream_id, reason, tx))
|
||||
.map_err(|x| crate::WispError::Other(Box::new(x)))?;
|
||||
rx.await
|
||||
.map_err(|x| crate::WispError::Other(Box::new(x)))??;
|
||||
|
||||
self.is_closed.store(true, Ordering::Release);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn into_sink(self) -> Pin<Box<dyn Sink<Bytes, Error = crate::WispError> + Send>> {
|
||||
Box::pin(sink::unfold(self, |tx, data| async move {
|
||||
tx.write(data).await?;
|
||||
Ok(tx)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: crate::ws::WebSocketWrite> Drop for MuxStreamWrite<W> {
|
||||
fn drop(&mut self) {
|
||||
let (tx, _) = oneshot::channel::<Result<(), crate::WispError>>();
|
||||
let _ = self
|
||||
.close_channel
|
||||
.unbounded_send(MuxEvent::Close(self.stream_id, 0x01, tx));
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MuxStream<W>
|
||||
where
|
||||
W: crate::ws::WebSocketWrite,
|
||||
{
|
||||
pub stream_id: u32,
|
||||
rx: MuxStreamRead<W>,
|
||||
tx: MuxStreamWrite<W>,
|
||||
}
|
||||
|
||||
impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStream<W> {
|
||||
pub(crate) fn new(
|
||||
stream_id: u32,
|
||||
role: crate::Role,
|
||||
rx: mpsc::UnboundedReceiver<WsEvent>,
|
||||
tx: crate::ws::LockedWebSocketWrite<W>,
|
||||
close_channel: mpsc::UnboundedSender<MuxEvent>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
flow_control: Arc<AtomicU32>,
|
||||
continue_recieved: Arc<Event>
|
||||
) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
rx: MuxStreamRead {
|
||||
stream_id,
|
||||
role,
|
||||
tx: tx.clone(),
|
||||
rx,
|
||||
is_closed: is_closed.clone(),
|
||||
flow_control: flow_control.clone(),
|
||||
},
|
||||
tx: MuxStreamWrite {
|
||||
stream_id,
|
||||
role,
|
||||
tx,
|
||||
close_channel,
|
||||
is_closed: is_closed.clone(),
|
||||
flow_control: flow_control.clone(),
|
||||
continue_recieved: continue_recieved.clone(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read(&mut self) -> Option<WsEvent> {
|
||||
self.rx.read().await
|
||||
}
|
||||
|
||||
pub async fn write(&self, data: Bytes) -> Result<(), crate::WispError> {
|
||||
self.tx.write(data).await
|
||||
}
|
||||
|
||||
pub fn get_close_handle(&self) -> MuxStreamCloser {
|
||||
self.tx.get_close_handle()
|
||||
}
|
||||
|
||||
pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> {
|
||||
self.tx.close(reason).await
|
||||
}
|
||||
|
||||
pub fn into_split(self) -> (MuxStreamRead<W>, MuxStreamWrite<W>) {
|
||||
(self.rx, self.tx)
|
||||
}
|
||||
|
||||
pub fn into_io(self) -> MuxStreamIo {
|
||||
MuxStreamIo {
|
||||
rx: self.rx.into_stream(),
|
||||
tx: self.tx.into_sink(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MuxStreamCloser {
|
||||
stream_id: u32,
|
||||
close_channel: mpsc::UnboundedSender<MuxEvent>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl MuxStreamCloser {
|
||||
pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> {
|
||||
if self.is_closed.load(Ordering::Acquire) {
|
||||
return Err(crate::WispError::StreamAlreadyClosed);
|
||||
}
|
||||
let (tx, rx) = oneshot::channel::<Result<(), crate::WispError>>();
|
||||
self.close_channel
|
||||
.unbounded_send(MuxEvent::Close(self.stream_id, reason, tx))
|
||||
.map_err(|x| crate::WispError::Other(Box::new(x)))?;
|
||||
rx.await
|
||||
.map_err(|x| crate::WispError::Other(Box::new(x)))??;
|
||||
self.is_closed.store(true, Ordering::Release);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
pub struct MuxStreamIo {
|
||||
#[pin]
|
||||
rx: Pin<Box<dyn Stream<Item = Bytes> + Send>>,
|
||||
#[pin]
|
||||
tx: Pin<Box<dyn Sink<Bytes, Error = crate::WispError> + Send>>,
|
||||
}
|
||||
}
|
||||
|
||||
impl MuxStreamIo {
|
||||
pub fn into_asyncrw(self) -> IoStream<MuxStreamIo, Vec<u8>> {
|
||||
IoStream::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for MuxStreamIo {
|
||||
type Item = Result<Vec<u8>, std::io::Error>;
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
self.project()
|
||||
.rx
|
||||
.poll_next(cx)
|
||||
.map(|x| x.map(|x| Ok(x.to_vec())))
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<Vec<u8>> for MuxStreamIo {
|
||||
type Error = std::io::Error;
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_ready(cx)
|
||||
.map_err(std::io::Error::other)
|
||||
}
|
||||
fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
|
||||
self.project()
|
||||
.tx
|
||||
.start_send(item.into())
|
||||
.map_err(std::io::Error::other)
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_flush(cx)
|
||||
.map_err(std::io::Error::other)
|
||||
}
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_close(cx)
|
||||
.map_err(std::io::Error::other)
|
||||
}
|
||||
}
|
|
@ -1,7 +1,6 @@
|
|||
#![allow(dead_code)]
|
||||
// Taken from https://github.com/hyperium/hyper-util/blob/master/src/rt/tokio.rs
|
||||
// hyper-util fails to compile on WASM as it has a dependency on socket2, but I only need
|
||||
// hyper-util for TokioIo.
|
||||
// hyper-util fails to compile on WASM as it has a dependency on socket2
|
||||
|
||||
use std::{
|
||||
pin::Pin,
|
||||
|
@ -169,3 +168,9 @@ where
|
|||
hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> hyper_util::client::legacy::connect::Connection for TokioIo<T> {
|
||||
fn connected(&self) -> hyper_util::client::legacy::connect::Connected {
|
||||
hyper_util::client::legacy::connect::Connected::new()
|
||||
}
|
||||
}
|
41
wisp/src/tower.rs
Normal file
41
wisp/src/tower.rs
Normal file
|
@ -0,0 +1,41 @@
|
|||
use crate::{tokioio::TokioIo, ws::WebSocketWrite, ClientMux, MuxStreamIo, StreamType, WispError};
|
||||
use async_io_stream::IoStream;
|
||||
use futures::{
|
||||
task::{Context, Poll},
|
||||
Future,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct ServiceWrapper<W: WebSocketWrite + Send + 'static>(pub Arc<ClientMux<W>>);
|
||||
|
||||
impl<W: WebSocketWrite + Send + 'static> tower_service::Service<hyper::Uri> for ServiceWrapper<W> {
|
||||
type Response = TokioIo<IoStream<MuxStreamIo, Vec<u8>>>;
|
||||
type Error = WispError;
|
||||
type Future = impl Future<Output = Result<Self::Response, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: hyper::Uri) -> Self::Future {
|
||||
let mux = self.0.clone();
|
||||
async move {
|
||||
Ok(TokioIo::new(
|
||||
mux.client_new_stream(
|
||||
StreamType::Tcp,
|
||||
req.host().ok_or(WispError::UriHasNoHost)?.to_string(),
|
||||
req.port().ok_or(WispError::UriHasNoPort)?.into(),
|
||||
)
|
||||
.await?
|
||||
.into_io()
|
||||
.into_asyncrw(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: WebSocketWrite + Send + 'static> Clone for ServiceWrapper<W> {
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.clone())
|
||||
}
|
||||
}
|
76
wisp/src/ws.rs
Normal file
76
wisp/src/ws.rs
Normal file
|
@ -0,0 +1,76 @@
|
|||
use bytes::Bytes;
|
||||
use futures::lock::Mutex;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||
pub enum OpCode {
|
||||
Text,
|
||||
Binary,
|
||||
Close,
|
||||
Ping,
|
||||
Pong,
|
||||
}
|
||||
|
||||
pub struct Frame {
|
||||
pub finished: bool,
|
||||
pub opcode: OpCode,
|
||||
pub payload: Bytes,
|
||||
}
|
||||
|
||||
impl Frame {
|
||||
pub fn text(payload: Bytes) -> Self {
|
||||
Self {
|
||||
finished: true,
|
||||
opcode: OpCode::Text,
|
||||
payload,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn binary(payload: Bytes) -> Self {
|
||||
Self {
|
||||
finished: true,
|
||||
opcode: OpCode::Binary,
|
||||
payload,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn close(payload: Bytes) -> Self {
|
||||
Self {
|
||||
finished: true,
|
||||
opcode: OpCode::Close,
|
||||
payload,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait WebSocketRead {
|
||||
fn wisp_read_frame(
|
||||
&mut self,
|
||||
tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite + Send>,
|
||||
) -> impl std::future::Future<Output = Result<Frame, crate::WispError>> + Send;
|
||||
}
|
||||
|
||||
pub trait WebSocketWrite {
|
||||
fn wisp_write_frame(
|
||||
&mut self,
|
||||
frame: Frame,
|
||||
) -> impl std::future::Future<Output = Result<(), crate::WispError>> + Send;
|
||||
}
|
||||
|
||||
pub struct LockedWebSocketWrite<S>(Arc<Mutex<S>>);
|
||||
|
||||
impl<S: WebSocketWrite + Send> LockedWebSocketWrite<S> {
|
||||
pub fn new(ws: S) -> Self {
|
||||
Self(Arc::new(Mutex::new(ws)))
|
||||
}
|
||||
|
||||
pub async fn write_frame(&self, frame: Frame) -> Result<(), crate::WispError> {
|
||||
self.0.lock().await.wisp_write_frame(frame).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: WebSocketWrite> Clone for LockedWebSocketWrite<S> {
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.clone())
|
||||
}
|
||||
}
|
60
wisp/src/ws_stream_wasm.rs
Normal file
60
wisp/src/ws_stream_wasm.rs
Normal file
|
@ -0,0 +1,60 @@
|
|||
use futures::{stream::{SplitStream, SplitSink}, SinkExt, StreamExt};
|
||||
use ws_stream_wasm::{WsErr, WsMessage, WsStream};
|
||||
|
||||
impl From<WsMessage> for crate::ws::Frame {
|
||||
fn from(msg: WsMessage) -> Self {
|
||||
use crate::ws::OpCode;
|
||||
match msg {
|
||||
WsMessage::Text(str) => Self {
|
||||
finished: true,
|
||||
opcode: OpCode::Text,
|
||||
payload: str.into(),
|
||||
},
|
||||
WsMessage::Binary(bin) => Self {
|
||||
finished: true,
|
||||
opcode: OpCode::Binary,
|
||||
payload: bin.into(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<crate::ws::Frame> for WsMessage {
|
||||
type Error = crate::WispError;
|
||||
fn try_from(msg: crate::ws::Frame) -> Result<Self, Self::Error> {
|
||||
use crate::ws::OpCode;
|
||||
match msg.opcode {
|
||||
OpCode::Text => Ok(Self::Text(std::str::from_utf8(&msg.payload)?.to_string())),
|
||||
OpCode::Binary => Ok(Self::Binary(msg.payload.to_vec())),
|
||||
_ => Err(Self::Error::WsImplNotSupported),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WsErr> for crate::WispError {
|
||||
fn from(err: WsErr) -> Self {
|
||||
Self::WsImplError(Box::new(err))
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::ws::WebSocketRead for SplitStream<WsStream> {
|
||||
async fn wisp_read_frame(
|
||||
&mut self,
|
||||
_: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
|
||||
) -> Result<crate::ws::Frame, crate::WispError> {
|
||||
Ok(self
|
||||
.next()
|
||||
.await
|
||||
.ok_or(crate::WispError::WsImplSocketClosed)?
|
||||
.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::ws::WebSocketWrite for SplitSink<WsStream, WsMessage> {
|
||||
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
|
||||
self
|
||||
.send(frame.try_into()?)
|
||||
.await
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue