use hyper client

This commit is contained in:
r58Playz 2024-02-05 19:10:40 -08:00
parent 6ca14ad26a
commit b16fb8f654
12 changed files with 297 additions and 342 deletions

72
Cargo.lock generated
View file

@ -235,12 +235,14 @@ dependencies = [
"http",
"http-body-util",
"hyper",
"hyper-util 0.1.3 (git+https://github.com/r58Playz/hyper-util-wasm)",
"js-sys",
"pin-project-lite",
"ring",
"tokio",
"tokio-rustls",
"tokio-util",
"tower-service",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
@ -260,13 +262,19 @@ dependencies = [
"futures-util",
"http-body-util",
"hyper",
"hyper-util",
"hyper-util 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)",
"tokio",
"tokio-native-tls",
"tokio-util",
"wisp-mux",
]
[[package]]
name = "equivalent"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "errno"
version = "0.3.8"
@ -292,7 +300,7 @@ dependencies = [
"base64",
"http-body-util",
"hyper",
"hyper-util",
"hyper-util 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)",
"pin-project",
"rand",
"sha1",
@ -451,6 +459,25 @@ version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
[[package]]
name = "h2"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943"
dependencies = [
"bytes",
"fnv",
"futures-core",
"futures-sink",
"futures-util",
"http",
"indexmap",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "hashbrown"
version = "0.14.3"
@ -518,6 +545,7 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"h2",
"http",
"http-body",
"httparse",
@ -530,9 +558,24 @@ dependencies = [
[[package]]
name = "hyper-util"
version = "0.1.2"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdea9aac0dbe5a9240d68cfd9501e2db94222c6dc06843e06640b9e07f0fdc67"
checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa"
dependencies = [
"bytes",
"futures-util",
"http",
"http-body",
"hyper",
"pin-project-lite",
"socket2",
"tokio",
]
[[package]]
name = "hyper-util"
version = "0.1.3"
source = "git+https://github.com/r58Playz/hyper-util-wasm#40813384dc4971677cd2a9aeb90f61b392a5bb70"
dependencies = [
"bytes",
"futures-channel",
@ -541,11 +584,21 @@ dependencies = [
"http-body",
"hyper",
"pin-project-lite",
"socket2",
"tokio",
"tower",
"tower-service",
"tracing",
]
[[package]]
name = "indexmap"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f"
dependencies = [
"equivalent",
"hashbrown",
]
[[package]]
name = "itoa"
version = "1.0.10"
@ -1160,6 +1213,10 @@ version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
dependencies = [
"futures-core",
"futures-util",
"pin-project",
"pin-project-lite",
"tower-layer",
"tower-service",
"tracing",
@ -1494,9 +1551,10 @@ dependencies = [
"futures",
"futures-util",
"hyper",
"hyper-util 0.1.3 (git+https://github.com/r58Playz/hyper-util-wasm)",
"pin-project-lite",
"tokio",
"tower",
"tower-service",
"ws_stream_wasm",
]

View file

@ -10,9 +10,8 @@ crate-type = ["cdylib"]
bytes = "1.5.0"
http = "1.0.0"
http-body-util = "0.1.0"
hyper = { version = "1.1.0", features = ["client", "http1"] }
hyper = { version = "1.1.0", features = ["client", "http1", "http2"] }
pin-project-lite = "0.2.13"
tokio = { version = "1.35.1", default_features = false }
wasm-bindgen = "0.2"
wasm-bindgen-futures = "0.4.39"
ws_stream_wasm = { version = "0.7.4", features = ["tokio_io"] }
@ -26,9 +25,12 @@ tokio-util = { version = "0.7.10", features = ["io"] }
async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] }
fastwebsockets = { version = "0.6.0", features = ["unstable-split"] }
base64 = "0.21.7"
wisp-mux = { path = "../wisp", features = ["ws_stream_wasm", "tokio_io"] }
wisp-mux = { path = "../wisp", features = ["ws_stream_wasm", "tokio_io", "hyper_tower"] }
async_io_stream = { version = "0.3.3", features = ["tokio_io"] }
getrandom = { version = "0.2.12", features = ["js"] }
hyper-util = { git = "https://github.com/r58Playz/hyper-util-wasm", features = ["client", "client-legacy", "http1", "http2"] }
tokio = { version = "1.36.0", default-features = false }
tower-service = "0.3.2"
[dependencies.ring]
features = ["wasm32_unknown_unknown_js"]

View file

@ -50,7 +50,13 @@
["https://httpbin.org/gzip", {}],
["https://httpbin.org/brotli", {}],
["https://httpbin.org/redirect/11", {}],
["https://httpbin.org/redirect/1", { redirect: "manual" }]
["https://httpbin.org/redirect/1", { redirect: "manual" }],
["https://nghttp2.org/httpbin/get", {}],
["https://nghttp2.org/httpbin/gzip", {}],
["https://nghttp2.org/httpbin/brotli", {}],
["https://nghttp2.org/httpbin/redirect/11", {}],
["https://nghttp2.org/httpbin/redirect/1", { redirect: "manual" }]
]) {
let resp = await epoxy_client.fetch(url[0], url[1]);
console.warn(url, resp, Object.fromEntries(resp.headers));

View file

@ -1,16 +1,14 @@
#![feature(let_chains)]
#![feature(let_chains, impl_trait_in_assoc_type)]
#[macro_use]
mod utils;
mod tls_stream;
mod tokioio;
mod websocket;
mod wrappers;
use tls_stream::EpxTlsStream;
use tokioio::TokioIo;
use utils::{ReplaceErr, UriExt};
use websocket::EpxWebSocket;
use wrappers::IncomingBody;
use wrappers::{IncomingBody, TlsWispService};
use std::sync::Arc;
@ -19,7 +17,8 @@ use async_io_stream::IoStream;
use bytes::Bytes;
use futures_util::{stream::SplitSink, StreamExt};
use http::{uri, HeaderName, HeaderValue, Request, Response};
use hyper::{body::Incoming, client::conn::http1::Builder, Uri};
use hyper::{body::Incoming, Uri};
use hyper_util::client::legacy::Client;
use js_sys::{Array, Function, Object, Reflect, Uint8Array};
use tokio_rustls::{client::TlsStream, rustls, rustls::RootCertStore, TlsConnector};
use tokio_util::{
@ -28,7 +27,7 @@ use tokio_util::{
};
use wasm_bindgen::prelude::*;
use web_sys::TextEncoder;
use wisp_mux::{ClientMux, MuxStreamIo, StreamType};
use wisp_mux::{tokioio::TokioIo, tower::ServiceWrapper, ClientMux, MuxStreamIo, StreamType};
use ws_stream_wasm::{WsMessage, WsMeta, WsStream};
type HttpBody = http_body_util::Full<Bytes>;
@ -36,7 +35,7 @@ type HttpBody = http_body_util::Full<Bytes>;
#[derive(Debug)]
enum EpxResponse {
Success(Response<Incoming>),
Redirect((Response<Incoming>, http::Request<HttpBody>, Uri)),
Redirect((Response<Incoming>, http::Request<HttpBody>)),
}
enum EpxCompression {
@ -48,73 +47,11 @@ type EpxIoTlsStream = TlsStream<IoStream<MuxStreamIo, Vec<u8>>>;
type EpxIoUnencryptedStream = IoStream<MuxStreamIo, Vec<u8>>;
type EpxIoStream = Either<EpxIoTlsStream, EpxIoUnencryptedStream>;
async fn send_req(
req: http::Request<HttpBody>,
should_redirect: bool,
io: EpxIoStream,
) -> Result<EpxResponse, JsError> {
let (mut req_sender, conn) = Builder::new()
.title_case_headers(true)
.preserve_header_case(true)
.handshake(TokioIo::new(io))
.await
.replace_err("Failed to connect to host")?;
wasm_bindgen_futures::spawn_local(async move {
if let Err(e) = conn.await {
error!("epoxy: error in muxed hyper connection! {:?}", e);
}
});
let new_req = if should_redirect {
Some(req.clone())
} else {
None
};
debug!("sending req");
let res = req_sender
.send_request(req)
.await
.replace_err("Failed to send request");
debug!("recieved res");
match res {
Ok(res) => {
if utils::is_redirect(res.status().as_u16())
&& let Some(mut new_req) = new_req
&& let Some(location) = res.headers().get("Location")
&& let Ok(redirect_url) = new_req.uri().get_redirect(location)
&& let Some(redirect_url_authority) = redirect_url
.clone()
.authority()
.replace_err("Redirect URL must have an authority")
.ok()
{
let should_strip = new_req.uri().is_same_host(&redirect_url);
if should_strip {
new_req.headers_mut().remove("authorization");
new_req.headers_mut().remove("cookie");
new_req.headers_mut().remove("www-authenticate");
}
let new_url = redirect_url.clone();
*new_req.uri_mut() = redirect_url;
new_req.headers_mut().insert(
"Host",
HeaderValue::from_str(redirect_url_authority.as_str())?,
);
Ok(EpxResponse::Redirect((res, new_req, new_url)))
} else {
Ok(EpxResponse::Success(res))
}
}
Err(err) => Err(err),
}
}
#[wasm_bindgen]
pub struct EpoxyClient {
rustls_config: Arc<rustls::ClientConfig>,
mux: ClientMux<SplitSink<WsStream, WsMessage>>,
mux: Arc<ClientMux<SplitSink<WsStream, WsMessage>>>,
hyper_client: Client<TlsWispService<SplitSink<WsStream, WsMessage>>, HttpBody>,
useragent: String,
redirect_limit: usize,
}
@ -145,6 +82,7 @@ impl EpoxyClient {
debug!("connected!");
let (wtx, wrx) = ws.split();
let (mux, fut) = ClientMux::new(wrx, wtx);
let mux = Arc::new(mux);
wasm_bindgen_futures::spawn_local(async move {
if let Err(err) = fut.await {
@ -162,7 +100,15 @@ impl EpoxyClient {
);
Ok(EpoxyClient {
mux,
mux: mux.clone(),
hyper_client: Client::builder(utils::WasmExecutor {})
.http09_responses(true)
.http1_title_case_headers(true)
.http1_preserve_header_case(true)
.build(TlsWispService {
rustls_config: rustls_config.clone(),
service: ServiceWrapper(mux),
}),
rustls_config,
useragent,
redirect_limit,
@ -193,26 +139,53 @@ impl EpoxyClient {
Ok(io)
}
async fn get_http_io(&self, url: &Uri) -> Result<EpxIoStream, JsError> {
let url_host = url.host().replace_err("URL must have a host")?;
let url_port = utils::get_url_port(url)?;
if utils::get_is_secure(url)? {
Ok(EpxIoStream::Left(
self.get_tls_io(url_host, url_port).await?,
))
async fn send_req_inner(
&self,
req: http::Request<HttpBody>,
should_redirect: bool,
) -> Result<EpxResponse, JsError> {
let new_req = if should_redirect {
Some(req.clone())
} else {
debug!("making channel");
let channel = self
.mux
.client_new_stream(StreamType::Tcp, url_host.to_string(), url_port)
.await
.replace_err("Failed to create multiplexor channel")?
.into_io()
.into_asyncrw();
debug!("connecting channel");
debug!("connected channel");
Ok(EpxIoStream::Right(channel))
None
};
debug!("sending req");
let res = self
.hyper_client
.request(req)
.await
.replace_err("Failed to send request");
debug!("recieved res");
match res {
Ok(res) => {
if utils::is_redirect(res.status().as_u16())
&& let Some(mut new_req) = new_req
&& let Some(location) = res.headers().get("Location")
&& let Ok(redirect_url) = new_req.uri().get_redirect(location)
&& let Some(redirect_url_authority) = redirect_url
.clone()
.authority()
.replace_err("Redirect URL must have an authority")
.ok()
{
let should_strip = new_req.uri().is_same_host(&redirect_url);
if should_strip {
new_req.headers_mut().remove("authorization");
new_req.headers_mut().remove("cookie");
new_req.headers_mut().remove("www-authenticate");
}
*new_req.uri_mut() = redirect_url;
new_req.headers_mut().insert(
"Host",
HeaderValue::from_str(redirect_url_authority.as_str())?,
);
Ok(EpxResponse::Redirect((res, new_req)))
} else {
Ok(EpxResponse::Success(res))
}
}
Err(err) => Err(err),
}
}
@ -222,23 +195,22 @@ impl EpoxyClient {
should_redirect: bool,
) -> Result<(hyper::Response<Incoming>, Uri, bool), JsError> {
let mut redirected = false;
let uri = req.uri().clone();
let mut current_resp: EpxResponse =
send_req(req, should_redirect, self.get_http_io(&uri).await?).await?;
let mut current_url = req.uri().clone();
let mut current_resp: EpxResponse = self.send_req_inner(req, should_redirect).await?;
for _ in 0..self.redirect_limit - 1 {
match current_resp {
EpxResponse::Success(_) => break,
EpxResponse::Redirect((_, req, new_url)) => {
EpxResponse::Redirect((_, req)) => {
redirected = true;
current_resp =
send_req(req, should_redirect, self.get_http_io(&new_url).await?).await?
current_url = req.uri().clone();
current_resp = self.send_req_inner(req, should_redirect).await?
}
}
}
match current_resp {
EpxResponse::Success(resp) => Ok((resp, uri, redirected)),
EpxResponse::Redirect((resp, _, new_url)) => Ok((resp, new_url, redirected)),
EpxResponse::Success(resp) => Ok((resp, current_url, redirected)),
EpxResponse::Redirect((resp, _)) => Ok((resp, current_url, redirected)),
}
}
@ -353,7 +325,7 @@ impl EpoxyClient {
.body(HttpBody::new(body_bytes))
.replace_err("Failed to make request")?;
let (resp, last_url, req_redirected) = self.send_req(request, req_should_redirect).await?;
let (resp, resp_uri, req_redirected) = self.send_req(request, req_should_redirect).await?;
let resp_headers_raw = resp.headers().clone();
@ -417,7 +389,7 @@ impl EpoxyClient {
Object::define_property(
&resp,
&jval!("url"),
&utils::define_property_obj(jval!(last_url.to_string()), false)
&utils::define_property_obj(jval!(resp_uri.to_string()), false)
.replace_err("Failed to make define_property object for url")?,
);

View file

@ -1,171 +0,0 @@
#![allow(dead_code)]
// Taken from https://github.com/hyperium/hyper-util/blob/master/src/rt/tokio.rs
// hyper-util fails to compile on WASM as it has a dependency on socket2, but I only need
// hyper-util for TokioIo.
use std::{
pin::Pin,
task::{Context, Poll},
};
use pin_project_lite::pin_project;
pin_project! {
/// A wrapping implementing hyper IO traits for a type that
/// implements Tokio's IO traits.
#[derive(Debug)]
pub struct TokioIo<T> {
#[pin]
inner: T,
}
}
impl<T> TokioIo<T> {
/// Wrap a type implementing Tokio's IO traits.
pub fn new(inner: T) -> Self {
Self { inner }
}
/// Borrow the inner type.
pub fn inner(&self) -> &T {
&self.inner
}
/// Mut borrow the inner type.
pub fn inner_mut(&mut self) -> &mut T {
&mut self.inner
}
/// Consume this wrapper and get the inner type.
pub fn into_inner(self) -> T {
self.inner
}
}
impl<T> hyper::rt::Read for TokioIo<T>
where
T: tokio::io::AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut buf: hyper::rt::ReadBufCursor<'_>,
) -> Poll<Result<(), std::io::Error>> {
let n = unsafe {
let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
Poll::Ready(Ok(())) => tbuf.filled().len(),
other => return other,
}
};
unsafe {
buf.advance(n);
}
Poll::Ready(Ok(()))
}
}
impl<T> hyper::rt::Write for TokioIo<T>
where
T: tokio::io::AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
}
fn is_write_vectored(&self) -> bool {
tokio::io::AsyncWrite::is_write_vectored(&self.inner)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
}
}
impl<T> tokio::io::AsyncRead for TokioIo<T>
where
T: hyper::rt::Read,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
tbuf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<Result<(), std::io::Error>> {
//let init = tbuf.initialized().len();
let filled = tbuf.filled().len();
let sub_filled = unsafe {
let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());
match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) {
Poll::Ready(Ok(())) => buf.filled().len(),
other => return other,
}
};
let n_filled = filled + sub_filled;
// At least sub_filled bytes had to have been initialized.
let n_init = sub_filled;
unsafe {
tbuf.assume_init(n_init);
tbuf.set_filled(n_filled);
}
Poll::Ready(Ok(()))
}
}
impl<T> tokio::io::AsyncWrite for TokioIo<T>
where
T: hyper::rt::Write,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
hyper::rt::Write::poll_write(self.project().inner, cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
hyper::rt::Write::poll_flush(self.project().inner, cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
hyper::rt::Write::poll_shutdown(self.project().inner, cx)
}
fn is_write_vectored(&self) -> bool {
hyper::rt::Write::is_write_vectored(&self.inner)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
}
}

View file

@ -1,8 +1,9 @@
use wasm_bindgen::prelude::*;
use hyper::rt::Executor;
use hyper::{header::HeaderValue, Uri};
use http::uri;
use js_sys::{Array, Object};
use std::future::Future;
#[wasm_bindgen]
extern "C" {
@ -97,6 +98,21 @@ impl UriExt for Uri {
}
}
#[derive(Clone)]
pub struct WasmExecutor;
impl<F> Executor<F> for WasmExecutor
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
fn execute(&self, future: F) {
wasm_bindgen_futures::spawn_local(async move {
let _ = future.await;
});
}
}
pub fn entries_of_object(obj: &Object) -> Vec<Vec<String>> {
js_sys::Object::entries(obj)
.to_vec()
@ -126,41 +142,19 @@ pub fn is_redirect(code: u16) -> bool {
}
pub fn get_is_secure(url: &Uri) -> Result<bool, JsError> {
let url_scheme = url.scheme().replace_err("URL must have a scheme")?;
let url_scheme_str = url.scheme_str().replace_err("URL must have a scheme")?;
// can't use match, compiler error
// error: to use a constant of type `Scheme` in a pattern, `Scheme` must be annotated with `#[derive(PartialEq, Eq)]`
if *url_scheme == uri::Scheme::HTTP {
Ok(false)
} else if *url_scheme == uri::Scheme::HTTPS {
Ok(true)
} else if url_scheme_str == "ws" {
Ok(false)
} else if url_scheme_str == "wss" {
Ok(true)
} else {
return Ok(false);
match url_scheme_str {
"https" | "wss" => Ok(true),
_ => Ok(false),
}
}
pub fn get_url_port(url: &Uri) -> Result<u16, JsError> {
let url_scheme = url.scheme().replace_err("URL must have a scheme")?;
let url_scheme_str = url.scheme_str().replace_err("URL must have a scheme")?;
if let Some(port) = url.port() {
Ok(port.as_u16())
} else if get_is_secure(url)? {
Ok(443)
} else {
// can't use match, compiler error
// error: to use a constant of type `Scheme` in a pattern, `Scheme` must be annotated with `#[derive(PartialEq, Eq)]`
if *url_scheme == uri::Scheme::HTTP {
Ok(80)
} else if *url_scheme == uri::Scheme::HTTPS {
Ok(443)
} else if url_scheme_str == "ws" {
Ok(80)
} else if url_scheme_str == "wss" {
Ok(443)
} else {
return Err(jerr!("Failed to coerce port from scheme"));
}
Ok(80)
}
}

View file

@ -5,7 +5,7 @@ use fastwebsockets::{
CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, Role, WebSocket, WebSocketWrite,
};
use futures_util::lock::Mutex;
use http_body_util::Empty;
use http_body_util::Full;
use hyper::{
header::{CONNECTION, UPGRADE},
upgrade::Upgraded,
@ -63,23 +63,9 @@ impl EpxWebSocket {
builder = builder.header("Sec-WebSocket-Protocol", protocols.join(", "));
}
let req = builder.body(Empty::<Bytes>::new())?;
let req = builder.body(Full::<Bytes>::new(Bytes::new()))?;
let stream = tcp.get_http_io(&url).await?;
let (mut sender, conn) = Builder::new()
.title_case_headers(true)
.preserve_header_case(true)
.handshake::<TokioIo<EpxIoStream>, Empty<Bytes>>(TokioIo::new(stream))
.await?;
wasm_bindgen_futures::spawn_local(async move {
if let Err(e) = conn.with_upgrades().await {
error!("epoxy: error in muxed hyper connection (ws)! {:?}", e);
}
});
let mut response = sender.send_request(req).await?;
let mut response = tcp.hyper_client.request(req).await?;
verify(&response)?;
let ws = WebSocket::after_handshake(

View file

@ -7,6 +7,8 @@ use std::{
use futures_util::Stream;
use hyper::body::Body;
use pin_project_lite::pin_project;
use std::future::Future;
use wisp_mux::{tokioio::TokioIo, tower::ServiceWrapper, WispError};
pin_project! {
pub struct IncomingBody {
@ -30,7 +32,8 @@ impl Stream for IncomingBody {
Poll::Ready(item) => Poll::<Option<Self::Item>>::Ready(match item {
Some(frame) => frame
.map(|x| {
x.into_data().map_err(|_| std::io::Error::other("not data frame"))
x.into_data()
.map_err(|_| std::io::Error::other("not data frame"))
})
.ok(),
None => None,
@ -39,3 +42,68 @@ impl Stream for IncomingBody {
}
}
}
pub struct TlsWispService<W>
where
W: wisp_mux::ws::WebSocketWrite + Send + 'static,
{
pub service: ServiceWrapper<W>,
pub rustls_config: Arc<rustls::ClientConfig>,
}
impl<W: wisp_mux::ws::WebSocketWrite + Send + 'static> tower_service::Service<hyper::Uri>
for TlsWispService<W>
{
type Response = TokioIo<EpxIoStream>;
type Error = WispError;
type Future = Pin<Box<impl Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&mut self, req: http::Uri) -> Self::Future {
let mut service = self.service.clone();
let rustls_config = self.rustls_config.clone();
Box::pin(async move {
let uri_host = req
.host()
.ok_or(WispError::UriHasNoHost)?
.to_string()
.clone();
let uri_parsed = Uri::builder()
.authority(format!(
"{}:{}",
uri_host,
utils::get_url_port(&req).map_err(|_| WispError::UriHasNoPort)?
))
.build()
.map_err(|x| WispError::Other(Box::new(x)))?;
let stream = service.call(uri_parsed).await?.into_inner();
if utils::get_is_secure(&req).map_err(|_| WispError::InvalidUri)? {
let connector = TlsConnector::from(rustls_config);
Ok(TokioIo::new(Either::Left(
connector
.connect(
uri_host.try_into().map_err(|_| WispError::InvalidUri)?,
stream,
)
.await
.map_err(|x| WispError::Other(Box::new(x)))?,
)))
} else {
Ok(TokioIo::new(Either::Right(stream)))
}
})
}
}
impl<W: wisp_mux::ws::WebSocketWrite + Send + 'static> Clone for TlsWispService<W> {
fn clone(&self) -> Self {
Self {
rustls_config: self.rustls_config.clone(),
service: self.service.clone(),
}
}
}

View file

@ -10,14 +10,15 @@ fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional =
futures = "0.3.30"
futures-util = "0.3.30"
hyper = { version = "1.1.0", optional = true }
hyper-util = { git = "https://github.com/r58Playz/hyper-util-wasm", features = ["client", "client-legacy"], optional = true }
pin-project-lite = "0.2.13"
tokio = { version = "1.35.1", optional = true }
tower = { version = "0.4.13", optional = true }
tokio = { version = "1.35.1", optional = true, default-features = false }
tower-service = { version = "0.3.2", optional = true }
ws_stream_wasm = { version = "0.7.4", optional = true }
[features]
fastwebsockets = ["dep:fastwebsockets", "dep:tokio"]
ws_stream_wasm = ["dep:ws_stream_wasm"]
tokio_io = ["async_io_stream/tokio_io"]
hyper_tower = ["dep:tower", "dep:hyper", "dep:tokio"]
hyper_tower = ["dep:tower-service", "dep:hyper", "dep:tokio", "dep:hyper-util"]

View file

@ -35,6 +35,9 @@ pub enum WispError {
InvalidPacketType,
InvalidStreamType,
InvalidStreamId,
InvalidUri,
UriHasNoHost,
UriHasNoPort,
MaxStreamCountReached,
StreamAlreadyClosed,
WsFrameInvalidType,
@ -60,6 +63,9 @@ impl std::fmt::Display for WispError {
InvalidPacketType => write!(f, "Invalid packet type"),
InvalidStreamType => write!(f, "Invalid stream type"),
InvalidStreamId => write!(f, "Invalid stream id"),
InvalidUri => write!(f, "Invalid URI"),
UriHasNoHost => write!(f, "URI has no host"),
UriHasNoPort => write!(f, "URI has no port"),
MaxStreamCountReached => write!(f, "Maximum stream count reached"),
StreamAlreadyClosed => write!(f, "Stream already closed"),
WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
@ -329,7 +335,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
stream_type: StreamType,
host: String,
port: u16,
) -> Result<MuxStream<impl ws::WebSocketWrite>, WispError> {
) -> Result<MuxStream<W>, WispError> {
let (ch_tx, ch_rx) = mpsc::unbounded();
let stream_id = self.next_free_stream_id.load(Ordering::Acquire);
self.tx

View file

@ -1,7 +1,6 @@
#![allow(dead_code)]
// Taken from https://github.com/hyperium/hyper-util/blob/master/src/rt/tokio.rs
// hyper-util fails to compile on WASM as it has a dependency on socket2, but I only need
// hyper-util for TokioIo.
// hyper-util fails to compile on WASM as it has a dependency on socket2
use std::{
pin::Pin,
@ -169,3 +168,9 @@ where
hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
}
}
impl<T> hyper_util::client::legacy::connect::Connection for TokioIo<T> {
fn connected(&self) -> hyper_util::client::legacy::connect::Connected {
hyper_util::client::legacy::connect::Connected::new()
}
}

View file

@ -1,13 +1,41 @@
use futures::{Future, task::{Poll, Context}};
use crate::{tokioio::TokioIo, ws::WebSocketWrite, ClientMux, MuxStreamIo, StreamType, WispError};
use async_io_stream::IoStream;
use futures::{
task::{Context, Poll},
Future,
};
use std::sync::Arc;
impl<W: crate::ws::WebSocketWrite> tower::Service<hyper::Uri> for crate::ClientMux<W> {
type Response = crate::tokioio::TokioIo<crate::MuxStream<W>>;
type Error = crate::WispError;
pub struct ServiceWrapper<W: WebSocketWrite + Send + 'static>(pub Arc<ClientMux<W>>);
impl<W: WebSocketWrite + Send + 'static> tower_service::Service<hyper::Uri> for ServiceWrapper<W> {
type Response = TokioIo<IoStream<MuxStreamIo, Vec<u8>>>;
type Error = WispError;
type Future = impl Future<Output = Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: hyper::Uri) -> Self::Future {
fn call(&mut self, req: hyper::Uri) -> Self::Future {
let mux = self.0.clone();
async move {
Ok(TokioIo::new(
mux.client_new_stream(
StreamType::Tcp,
req.host().ok_or(WispError::UriHasNoHost)?.to_string(),
req.port().ok_or(WispError::UriHasNoPort)?.into(),
)
.await?
.into_io()
.into_asyncrw(),
))
}
}
}
impl<W: WebSocketWrite + Send + 'static> Clone for ServiceWrapper<W> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}