diff --git a/Cargo.lock b/Cargo.lock index eed8bfa..0d6939f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -235,12 +235,14 @@ dependencies = [ "http", "http-body-util", "hyper", + "hyper-util 0.1.3 (git+https://github.com/r58Playz/hyper-util-wasm)", "js-sys", "pin-project-lite", "ring", "tokio", "tokio-rustls", "tokio-util", + "tower-service", "wasm-bindgen", "wasm-bindgen-futures", "wasm-streams", @@ -260,13 +262,19 @@ dependencies = [ "futures-util", "http-body-util", "hyper", - "hyper-util", + "hyper-util 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", "tokio", "tokio-native-tls", "tokio-util", "wisp-mux", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" version = "0.3.8" @@ -292,7 +300,7 @@ dependencies = [ "base64", "http-body-util", "hyper", - "hyper-util", + "hyper-util 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", "pin-project", "rand", "sha1", @@ -451,6 +459,25 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "h2" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "hashbrown" version = "0.14.3" @@ -518,6 +545,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", + "h2", "http", "http-body", "httparse", @@ -530,9 +558,24 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdea9aac0dbe5a9240d68cfd9501e2db94222c6dc06843e06640b9e07f0fdc67" +checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa" +dependencies = [ + "bytes", + "futures-util", + "http", + "http-body", + "hyper", + "pin-project-lite", + "socket2", + "tokio", +] + +[[package]] +name = "hyper-util" +version = "0.1.3" +source = "git+https://github.com/r58Playz/hyper-util-wasm#40813384dc4971677cd2a9aeb90f61b392a5bb70" dependencies = [ "bytes", "futures-channel", @@ -541,11 +584,21 @@ dependencies = [ "http-body", "hyper", "pin-project-lite", - "socket2", - "tokio", + "tower", + "tower-service", "tracing", ] +[[package]] +name = "indexmap" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" +dependencies = [ + "equivalent", + "hashbrown", +] + [[package]] name = "itoa" version = "1.0.10" @@ -1160,6 +1213,10 @@ version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", "tower-layer", "tower-service", "tracing", @@ -1494,9 +1551,10 @@ dependencies = [ "futures", "futures-util", "hyper", + "hyper-util 0.1.3 (git+https://github.com/r58Playz/hyper-util-wasm)", "pin-project-lite", "tokio", - "tower", + "tower-service", "ws_stream_wasm", ] diff --git a/client/Cargo.toml b/client/Cargo.toml index 3f8efd3..6e64d1d 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -10,9 +10,8 @@ crate-type = ["cdylib"] bytes = "1.5.0" http = "1.0.0" http-body-util = "0.1.0" -hyper = { version = "1.1.0", features = ["client", "http1"] } +hyper = { version = "1.1.0", features = ["client", "http1", "http2"] } pin-project-lite = "0.2.13" -tokio = { version = "1.35.1", default_features = false } wasm-bindgen = "0.2" wasm-bindgen-futures = "0.4.39" ws_stream_wasm = { version = "0.7.4", features = ["tokio_io"] } @@ -26,9 +25,12 @@ tokio-util = { version = "0.7.10", features = ["io"] } async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] } fastwebsockets = { version = "0.6.0", features = ["unstable-split"] } base64 = "0.21.7" -wisp-mux = { path = "../wisp", features = ["ws_stream_wasm", "tokio_io"] } +wisp-mux = { path = "../wisp", features = ["ws_stream_wasm", "tokio_io", "hyper_tower"] } async_io_stream = { version = "0.3.3", features = ["tokio_io"] } getrandom = { version = "0.2.12", features = ["js"] } +hyper-util = { git = "https://github.com/r58Playz/hyper-util-wasm", features = ["client", "client-legacy", "http1", "http2"] } +tokio = { version = "1.36.0", default-features = false } +tower-service = "0.3.2" [dependencies.ring] features = ["wasm32_unknown_unknown_js"] diff --git a/client/demo.js b/client/demo.js index 30379b8..6f07fc6 100644 --- a/client/demo.js +++ b/client/demo.js @@ -50,7 +50,13 @@ ["https://httpbin.org/gzip", {}], ["https://httpbin.org/brotli", {}], ["https://httpbin.org/redirect/11", {}], - ["https://httpbin.org/redirect/1", { redirect: "manual" }] + ["https://httpbin.org/redirect/1", { redirect: "manual" }], + ["https://nghttp2.org/httpbin/get", {}], + ["https://nghttp2.org/httpbin/gzip", {}], + ["https://nghttp2.org/httpbin/brotli", {}], + ["https://nghttp2.org/httpbin/redirect/11", {}], + ["https://nghttp2.org/httpbin/redirect/1", { redirect: "manual" }] + ]) { let resp = await epoxy_client.fetch(url[0], url[1]); console.warn(url, resp, Object.fromEntries(resp.headers)); diff --git a/client/src/lib.rs b/client/src/lib.rs index 639c786..07214de 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -1,16 +1,14 @@ -#![feature(let_chains)] +#![feature(let_chains, impl_trait_in_assoc_type)] #[macro_use] mod utils; mod tls_stream; -mod tokioio; mod websocket; mod wrappers; use tls_stream::EpxTlsStream; -use tokioio::TokioIo; use utils::{ReplaceErr, UriExt}; use websocket::EpxWebSocket; -use wrappers::IncomingBody; +use wrappers::{IncomingBody, TlsWispService}; use std::sync::Arc; @@ -19,7 +17,8 @@ use async_io_stream::IoStream; use bytes::Bytes; use futures_util::{stream::SplitSink, StreamExt}; use http::{uri, HeaderName, HeaderValue, Request, Response}; -use hyper::{body::Incoming, client::conn::http1::Builder, Uri}; +use hyper::{body::Incoming, Uri}; +use hyper_util::client::legacy::Client; use js_sys::{Array, Function, Object, Reflect, Uint8Array}; use tokio_rustls::{client::TlsStream, rustls, rustls::RootCertStore, TlsConnector}; use tokio_util::{ @@ -28,7 +27,7 @@ use tokio_util::{ }; use wasm_bindgen::prelude::*; use web_sys::TextEncoder; -use wisp_mux::{ClientMux, MuxStreamIo, StreamType}; +use wisp_mux::{tokioio::TokioIo, tower::ServiceWrapper, ClientMux, MuxStreamIo, StreamType}; use ws_stream_wasm::{WsMessage, WsMeta, WsStream}; type HttpBody = http_body_util::Full; @@ -36,7 +35,7 @@ type HttpBody = http_body_util::Full; #[derive(Debug)] enum EpxResponse { Success(Response), - Redirect((Response, http::Request, Uri)), + Redirect((Response, http::Request)), } enum EpxCompression { @@ -48,73 +47,11 @@ type EpxIoTlsStream = TlsStream>>; type EpxIoUnencryptedStream = IoStream>; type EpxIoStream = Either; -async fn send_req( - req: http::Request, - should_redirect: bool, - io: EpxIoStream, -) -> Result { - let (mut req_sender, conn) = Builder::new() - .title_case_headers(true) - .preserve_header_case(true) - .handshake(TokioIo::new(io)) - .await - .replace_err("Failed to connect to host")?; - - wasm_bindgen_futures::spawn_local(async move { - if let Err(e) = conn.await { - error!("epoxy: error in muxed hyper connection! {:?}", e); - } - }); - - let new_req = if should_redirect { - Some(req.clone()) - } else { - None - }; - - debug!("sending req"); - let res = req_sender - .send_request(req) - .await - .replace_err("Failed to send request"); - debug!("recieved res"); - match res { - Ok(res) => { - if utils::is_redirect(res.status().as_u16()) - && let Some(mut new_req) = new_req - && let Some(location) = res.headers().get("Location") - && let Ok(redirect_url) = new_req.uri().get_redirect(location) - && let Some(redirect_url_authority) = redirect_url - .clone() - .authority() - .replace_err("Redirect URL must have an authority") - .ok() - { - let should_strip = new_req.uri().is_same_host(&redirect_url); - if should_strip { - new_req.headers_mut().remove("authorization"); - new_req.headers_mut().remove("cookie"); - new_req.headers_mut().remove("www-authenticate"); - } - let new_url = redirect_url.clone(); - *new_req.uri_mut() = redirect_url; - new_req.headers_mut().insert( - "Host", - HeaderValue::from_str(redirect_url_authority.as_str())?, - ); - Ok(EpxResponse::Redirect((res, new_req, new_url))) - } else { - Ok(EpxResponse::Success(res)) - } - } - Err(err) => Err(err), - } -} - #[wasm_bindgen] pub struct EpoxyClient { rustls_config: Arc, - mux: ClientMux>, + mux: Arc>>, + hyper_client: Client>, HttpBody>, useragent: String, redirect_limit: usize, } @@ -145,6 +82,7 @@ impl EpoxyClient { debug!("connected!"); let (wtx, wrx) = ws.split(); let (mux, fut) = ClientMux::new(wrx, wtx); + let mux = Arc::new(mux); wasm_bindgen_futures::spawn_local(async move { if let Err(err) = fut.await { @@ -162,7 +100,15 @@ impl EpoxyClient { ); Ok(EpoxyClient { - mux, + mux: mux.clone(), + hyper_client: Client::builder(utils::WasmExecutor {}) + .http09_responses(true) + .http1_title_case_headers(true) + .http1_preserve_header_case(true) + .build(TlsWispService { + rustls_config: rustls_config.clone(), + service: ServiceWrapper(mux), + }), rustls_config, useragent, redirect_limit, @@ -193,26 +139,53 @@ impl EpoxyClient { Ok(io) } - async fn get_http_io(&self, url: &Uri) -> Result { - let url_host = url.host().replace_err("URL must have a host")?; - let url_port = utils::get_url_port(url)?; - - if utils::get_is_secure(url)? { - Ok(EpxIoStream::Left( - self.get_tls_io(url_host, url_port).await?, - )) + async fn send_req_inner( + &self, + req: http::Request, + should_redirect: bool, + ) -> Result { + let new_req = if should_redirect { + Some(req.clone()) } else { - debug!("making channel"); - let channel = self - .mux - .client_new_stream(StreamType::Tcp, url_host.to_string(), url_port) - .await - .replace_err("Failed to create multiplexor channel")? - .into_io() - .into_asyncrw(); - debug!("connecting channel"); - debug!("connected channel"); - Ok(EpxIoStream::Right(channel)) + None + }; + + debug!("sending req"); + let res = self + .hyper_client + .request(req) + .await + .replace_err("Failed to send request"); + debug!("recieved res"); + match res { + Ok(res) => { + if utils::is_redirect(res.status().as_u16()) + && let Some(mut new_req) = new_req + && let Some(location) = res.headers().get("Location") + && let Ok(redirect_url) = new_req.uri().get_redirect(location) + && let Some(redirect_url_authority) = redirect_url + .clone() + .authority() + .replace_err("Redirect URL must have an authority") + .ok() + { + let should_strip = new_req.uri().is_same_host(&redirect_url); + if should_strip { + new_req.headers_mut().remove("authorization"); + new_req.headers_mut().remove("cookie"); + new_req.headers_mut().remove("www-authenticate"); + } + *new_req.uri_mut() = redirect_url; + new_req.headers_mut().insert( + "Host", + HeaderValue::from_str(redirect_url_authority.as_str())?, + ); + Ok(EpxResponse::Redirect((res, new_req))) + } else { + Ok(EpxResponse::Success(res)) + } + } + Err(err) => Err(err), } } @@ -222,23 +195,22 @@ impl EpoxyClient { should_redirect: bool, ) -> Result<(hyper::Response, Uri, bool), JsError> { let mut redirected = false; - let uri = req.uri().clone(); - let mut current_resp: EpxResponse = - send_req(req, should_redirect, self.get_http_io(&uri).await?).await?; + let mut current_url = req.uri().clone(); + let mut current_resp: EpxResponse = self.send_req_inner(req, should_redirect).await?; for _ in 0..self.redirect_limit - 1 { match current_resp { EpxResponse::Success(_) => break, - EpxResponse::Redirect((_, req, new_url)) => { + EpxResponse::Redirect((_, req)) => { redirected = true; - current_resp = - send_req(req, should_redirect, self.get_http_io(&new_url).await?).await? + current_url = req.uri().clone(); + current_resp = self.send_req_inner(req, should_redirect).await? } } } match current_resp { - EpxResponse::Success(resp) => Ok((resp, uri, redirected)), - EpxResponse::Redirect((resp, _, new_url)) => Ok((resp, new_url, redirected)), + EpxResponse::Success(resp) => Ok((resp, current_url, redirected)), + EpxResponse::Redirect((resp, _)) => Ok((resp, current_url, redirected)), } } @@ -353,7 +325,7 @@ impl EpoxyClient { .body(HttpBody::new(body_bytes)) .replace_err("Failed to make request")?; - let (resp, last_url, req_redirected) = self.send_req(request, req_should_redirect).await?; + let (resp, resp_uri, req_redirected) = self.send_req(request, req_should_redirect).await?; let resp_headers_raw = resp.headers().clone(); @@ -417,7 +389,7 @@ impl EpoxyClient { Object::define_property( &resp, &jval!("url"), - &utils::define_property_obj(jval!(last_url.to_string()), false) + &utils::define_property_obj(jval!(resp_uri.to_string()), false) .replace_err("Failed to make define_property object for url")?, ); diff --git a/client/src/tokioio.rs b/client/src/tokioio.rs deleted file mode 100644 index 7d6acc0..0000000 --- a/client/src/tokioio.rs +++ /dev/null @@ -1,171 +0,0 @@ -#![allow(dead_code)] -// Taken from https://github.com/hyperium/hyper-util/blob/master/src/rt/tokio.rs -// hyper-util fails to compile on WASM as it has a dependency on socket2, but I only need -// hyper-util for TokioIo. - -use std::{ - pin::Pin, - task::{Context, Poll}, -}; - -use pin_project_lite::pin_project; - -pin_project! { - /// A wrapping implementing hyper IO traits for a type that - /// implements Tokio's IO traits. - #[derive(Debug)] - pub struct TokioIo { - #[pin] - inner: T, - } -} - -impl TokioIo { - /// Wrap a type implementing Tokio's IO traits. - pub fn new(inner: T) -> Self { - Self { inner } - } - - /// Borrow the inner type. - pub fn inner(&self) -> &T { - &self.inner - } - - /// Mut borrow the inner type. - pub fn inner_mut(&mut self) -> &mut T { - &mut self.inner - } - - /// Consume this wrapper and get the inner type. - pub fn into_inner(self) -> T { - self.inner - } -} - -impl hyper::rt::Read for TokioIo -where - T: tokio::io::AsyncRead, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut buf: hyper::rt::ReadBufCursor<'_>, - ) -> Poll> { - let n = unsafe { - let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); - match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { - Poll::Ready(Ok(())) => tbuf.filled().len(), - other => return other, - } - }; - - unsafe { - buf.advance(n); - } - Poll::Ready(Ok(())) - } -} - -impl hyper::rt::Write for TokioIo -where - T: tokio::io::AsyncWrite, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) - } - - fn is_write_vectored(&self) -> bool { - tokio::io::AsyncWrite::is_write_vectored(&self.inner) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) - } -} - -impl tokio::io::AsyncRead for TokioIo -where - T: hyper::rt::Read, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - tbuf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - //let init = tbuf.initialized().len(); - let filled = tbuf.filled().len(); - let sub_filled = unsafe { - let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); - - match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) { - Poll::Ready(Ok(())) => buf.filled().len(), - other => return other, - } - }; - - let n_filled = filled + sub_filled; - // At least sub_filled bytes had to have been initialized. - let n_init = sub_filled; - unsafe { - tbuf.assume_init(n_init); - tbuf.set_filled(n_filled); - } - - Poll::Ready(Ok(())) - } -} - -impl tokio::io::AsyncWrite for TokioIo -where - T: hyper::rt::Write, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - hyper::rt::Write::poll_write(self.project().inner, cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - hyper::rt::Write::poll_flush(self.project().inner, cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - hyper::rt::Write::poll_shutdown(self.project().inner, cx) - } - - fn is_write_vectored(&self) -> bool { - hyper::rt::Write::is_write_vectored(&self.inner) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) - } -} diff --git a/client/src/utils.rs b/client/src/utils.rs index caf4498..0c71583 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -1,8 +1,9 @@ use wasm_bindgen::prelude::*; +use hyper::rt::Executor; use hyper::{header::HeaderValue, Uri}; -use http::uri; use js_sys::{Array, Object}; +use std::future::Future; #[wasm_bindgen] extern "C" { @@ -97,6 +98,21 @@ impl UriExt for Uri { } } +#[derive(Clone)] +pub struct WasmExecutor; + +impl Executor for WasmExecutor +where + F: Future + Send + 'static, + F::Output: Send + 'static, +{ + fn execute(&self, future: F) { + wasm_bindgen_futures::spawn_local(async move { + let _ = future.await; + }); + } +} + pub fn entries_of_object(obj: &Object) -> Vec> { js_sys::Object::entries(obj) .to_vec() @@ -126,41 +142,19 @@ pub fn is_redirect(code: u16) -> bool { } pub fn get_is_secure(url: &Uri) -> Result { - let url_scheme = url.scheme().replace_err("URL must have a scheme")?; let url_scheme_str = url.scheme_str().replace_err("URL must have a scheme")?; - // can't use match, compiler error - // error: to use a constant of type `Scheme` in a pattern, `Scheme` must be annotated with `#[derive(PartialEq, Eq)]` - if *url_scheme == uri::Scheme::HTTP { - Ok(false) - } else if *url_scheme == uri::Scheme::HTTPS { - Ok(true) - } else if url_scheme_str == "ws" { - Ok(false) - } else if url_scheme_str == "wss" { - Ok(true) - } else { - return Ok(false); + match url_scheme_str { + "https" | "wss" => Ok(true), + _ => Ok(false), } } pub fn get_url_port(url: &Uri) -> Result { - let url_scheme = url.scheme().replace_err("URL must have a scheme")?; - let url_scheme_str = url.scheme_str().replace_err("URL must have a scheme")?; if let Some(port) = url.port() { Ok(port.as_u16()) + } else if get_is_secure(url)? { + Ok(443) } else { - // can't use match, compiler error - // error: to use a constant of type `Scheme` in a pattern, `Scheme` must be annotated with `#[derive(PartialEq, Eq)]` - if *url_scheme == uri::Scheme::HTTP { - Ok(80) - } else if *url_scheme == uri::Scheme::HTTPS { - Ok(443) - } else if url_scheme_str == "ws" { - Ok(80) - } else if url_scheme_str == "wss" { - Ok(443) - } else { - return Err(jerr!("Failed to coerce port from scheme")); - } + Ok(80) } } diff --git a/client/src/websocket.rs b/client/src/websocket.rs index f823077..d186b17 100644 --- a/client/src/websocket.rs +++ b/client/src/websocket.rs @@ -5,7 +5,7 @@ use fastwebsockets::{ CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, Role, WebSocket, WebSocketWrite, }; use futures_util::lock::Mutex; -use http_body_util::Empty; +use http_body_util::Full; use hyper::{ header::{CONNECTION, UPGRADE}, upgrade::Upgraded, @@ -63,23 +63,9 @@ impl EpxWebSocket { builder = builder.header("Sec-WebSocket-Protocol", protocols.join(", ")); } - let req = builder.body(Empty::::new())?; + let req = builder.body(Full::::new(Bytes::new()))?; - let stream = tcp.get_http_io(&url).await?; - - let (mut sender, conn) = Builder::new() - .title_case_headers(true) - .preserve_header_case(true) - .handshake::, Empty>(TokioIo::new(stream)) - .await?; - - wasm_bindgen_futures::spawn_local(async move { - if let Err(e) = conn.with_upgrades().await { - error!("epoxy: error in muxed hyper connection (ws)! {:?}", e); - } - }); - - let mut response = sender.send_request(req).await?; + let mut response = tcp.hyper_client.request(req).await?; verify(&response)?; let ws = WebSocket::after_handshake( diff --git a/client/src/wrappers.rs b/client/src/wrappers.rs index 8526a98..5df0814 100644 --- a/client/src/wrappers.rs +++ b/client/src/wrappers.rs @@ -7,6 +7,8 @@ use std::{ use futures_util::Stream; use hyper::body::Body; use pin_project_lite::pin_project; +use std::future::Future; +use wisp_mux::{tokioio::TokioIo, tower::ServiceWrapper, WispError}; pin_project! { pub struct IncomingBody { @@ -30,7 +32,8 @@ impl Stream for IncomingBody { Poll::Ready(item) => Poll::>::Ready(match item { Some(frame) => frame .map(|x| { - x.into_data().map_err(|_| std::io::Error::other("not data frame")) + x.into_data() + .map_err(|_| std::io::Error::other("not data frame")) }) .ok(), None => None, @@ -39,3 +42,68 @@ impl Stream for IncomingBody { } } } + +pub struct TlsWispService +where + W: wisp_mux::ws::WebSocketWrite + Send + 'static, +{ + pub service: ServiceWrapper, + pub rustls_config: Arc, +} + + +impl tower_service::Service + for TlsWispService +{ + type Response = TokioIo; + type Error = WispError; + type Future = Pin>>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: http::Uri) -> Self::Future { + let mut service = self.service.clone(); + let rustls_config = self.rustls_config.clone(); + Box::pin(async move { + let uri_host = req + .host() + .ok_or(WispError::UriHasNoHost)? + .to_string() + .clone(); + let uri_parsed = Uri::builder() + .authority(format!( + "{}:{}", + uri_host, + utils::get_url_port(&req).map_err(|_| WispError::UriHasNoPort)? + )) + .build() + .map_err(|x| WispError::Other(Box::new(x)))?; + let stream = service.call(uri_parsed).await?.into_inner(); + if utils::get_is_secure(&req).map_err(|_| WispError::InvalidUri)? { + let connector = TlsConnector::from(rustls_config); + Ok(TokioIo::new(Either::Left( + connector + .connect( + uri_host.try_into().map_err(|_| WispError::InvalidUri)?, + stream, + ) + .await + .map_err(|x| WispError::Other(Box::new(x)))?, + ))) + } else { + Ok(TokioIo::new(Either::Right(stream))) + } + }) + } +} + +impl Clone for TlsWispService { + fn clone(&self) -> Self { + Self { + rustls_config: self.rustls_config.clone(), + service: self.service.clone(), + } + } +} diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 9448613..42f65e2 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -10,14 +10,15 @@ fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = futures = "0.3.30" futures-util = "0.3.30" hyper = { version = "1.1.0", optional = true } +hyper-util = { git = "https://github.com/r58Playz/hyper-util-wasm", features = ["client", "client-legacy"], optional = true } pin-project-lite = "0.2.13" -tokio = { version = "1.35.1", optional = true } -tower = { version = "0.4.13", optional = true } +tokio = { version = "1.35.1", optional = true, default-features = false } +tower-service = { version = "0.3.2", optional = true } ws_stream_wasm = { version = "0.7.4", optional = true } [features] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] ws_stream_wasm = ["dep:ws_stream_wasm"] tokio_io = ["async_io_stream/tokio_io"] -hyper_tower = ["dep:tower", "dep:hyper", "dep:tokio"] +hyper_tower = ["dep:tower-service", "dep:hyper", "dep:tokio", "dep:hyper-util"] diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index e211f13..9ee6785 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -35,6 +35,9 @@ pub enum WispError { InvalidPacketType, InvalidStreamType, InvalidStreamId, + InvalidUri, + UriHasNoHost, + UriHasNoPort, MaxStreamCountReached, StreamAlreadyClosed, WsFrameInvalidType, @@ -60,6 +63,9 @@ impl std::fmt::Display for WispError { InvalidPacketType => write!(f, "Invalid packet type"), InvalidStreamType => write!(f, "Invalid stream type"), InvalidStreamId => write!(f, "Invalid stream id"), + InvalidUri => write!(f, "Invalid URI"), + UriHasNoHost => write!(f, "URI has no host"), + UriHasNoPort => write!(f, "URI has no port"), MaxStreamCountReached => write!(f, "Maximum stream count reached"), StreamAlreadyClosed => write!(f, "Stream already closed"), WsFrameInvalidType => write!(f, "Invalid websocket frame type"), @@ -329,7 +335,7 @@ impl ClientMux { stream_type: StreamType, host: String, port: u16, - ) -> Result, WispError> { + ) -> Result, WispError> { let (ch_tx, ch_rx) = mpsc::unbounded(); let stream_id = self.next_free_stream_id.load(Ordering::Acquire); self.tx diff --git a/wisp/src/tokioio.rs b/wisp/src/tokioio.rs index 7d6acc0..a3ca7be 100644 --- a/wisp/src/tokioio.rs +++ b/wisp/src/tokioio.rs @@ -1,7 +1,6 @@ #![allow(dead_code)] // Taken from https://github.com/hyperium/hyper-util/blob/master/src/rt/tokio.rs -// hyper-util fails to compile on WASM as it has a dependency on socket2, but I only need -// hyper-util for TokioIo. +// hyper-util fails to compile on WASM as it has a dependency on socket2 use std::{ pin::Pin, @@ -169,3 +168,9 @@ where hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) } } + +impl hyper_util::client::legacy::connect::Connection for TokioIo { + fn connected(&self) -> hyper_util::client::legacy::connect::Connected { + hyper_util::client::legacy::connect::Connected::new() + } +} diff --git a/wisp/src/tower.rs b/wisp/src/tower.rs index 6bf635c..06f3ebc 100644 --- a/wisp/src/tower.rs +++ b/wisp/src/tower.rs @@ -1,13 +1,41 @@ -use futures::{Future, task::{Poll, Context}}; +use crate::{tokioio::TokioIo, ws::WebSocketWrite, ClientMux, MuxStreamIo, StreamType, WispError}; +use async_io_stream::IoStream; +use futures::{ + task::{Context, Poll}, + Future, +}; +use std::sync::Arc; -impl tower::Service for crate::ClientMux { - type Response = crate::tokioio::TokioIo>; - type Error = crate::WispError; +pub struct ServiceWrapper(pub Arc>); + +impl tower_service::Service for ServiceWrapper { + type Response = TokioIo>>; + type Error = WispError; type Future = impl Future>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, req: hyper::Uri) -> Self::Future { + fn call(&mut self, req: hyper::Uri) -> Self::Future { + let mux = self.0.clone(); + async move { + Ok(TokioIo::new( + mux.client_new_stream( + StreamType::Tcp, + req.host().ok_or(WispError::UriHasNoHost)?.to_string(), + req.port().ok_or(WispError::UriHasNoPort)?.into(), + ) + .await? + .into_io() + .into_asyncrw(), + )) + } + } +} + +impl Clone for ServiceWrapper { + fn clone(&self) -> Self { + Self(self.0.clone()) } }