redirects

This commit is contained in:
Toshit Chawda 2024-01-11 22:21:37 -08:00
parent f92062c5f5
commit 7bb39ae069
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 201 additions and 70 deletions

21
Cargo.lock generated
View file

@ -165,6 +165,12 @@ dependencies = [
"crypto-common",
]
[[package]]
name = "either"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
[[package]]
name = "errno"
version = "0.3.8"
@ -1049,6 +1055,19 @@ dependencies = [
"tungstenite",
]
[[package]]
name = "tokio-util"
version = "0.7.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"pin-project-lite",
"tokio",
]
[[package]]
name = "tracing"
version = "0.1.40"
@ -1437,6 +1456,7 @@ version = "1.0.0"
dependencies = [
"bytes",
"console_error_panic_hook",
"either",
"futures-util",
"getrandom",
"http 1.0.0",
@ -1448,6 +1468,7 @@ dependencies = [
"ring",
"tokio",
"tokio-rustls",
"tokio-util",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",

View file

@ -27,6 +27,8 @@ webpki-roots = "0.26.0"
tokio-rustls = "0.25.0"
web-sys = { version = "0.3.66", features = ["TextEncoder", "Navigator", "Response", "ResponseInit"] }
wasm-streams = "0.4.0"
either = "1.9.0"
tokio-util = "0.7.10"
[dependencies.getrandom]
features = ["js"]

View file

@ -1,32 +1,42 @@
#![feature(let_chains)]
#[macro_use]
mod utils;
mod tokioio;
mod wrappers;
use tokioio::TokioIo;
use utils::ReplaceErr;
use utils::{ReplaceErr, UriExt};
use wrappers::{IncomingBody, WsStreamWrapper};
use std::sync::Arc;
use bytes::Bytes;
use http::{uri, HeaderName, HeaderValue, Request, Response};
use hyper::{body::Incoming, client::conn as hyper_conn};
use hyper::{body::Incoming, client::conn as hyper_conn, Uri};
use js_sys::{Array, Object, Reflect, Uint8Array};
use penguin_mux_wasm::{Multiplexor, Role};
use tokio_rustls::{rustls, rustls::RootCertStore, TlsConnector};
use penguin_mux_wasm::{Multiplexor, MuxStream, Role};
use tokio_rustls::{client::TlsStream, rustls, rustls::RootCertStore, TlsConnector};
use tokio_util::either::Either;
use wasm_bindgen::prelude::*;
use web_sys::TextEncoder;
type HttpBody = http_body_util::Full<Bytes>;
async fn send_req<T>(req: http::Request<HttpBody>, io: T) -> Result<Response<Incoming>, JsError>
where
T: hyper::rt::Read + hyper::rt::Write + std::marker::Unpin + 'static,
{
let (mut req_sender, conn) = hyper_conn::http1::handshake::<T, HttpBody>(io)
.await
.replace_err("Failed to connect to host")?;
#[derive(Debug)]
enum WsTcpResponse {
Success(Response<Incoming>),
Redirect((Response<Incoming>, http::Request<HttpBody>, Uri)),
}
type WsTcpTlsStream = TlsStream<MuxStream<WsStreamWrapper>>;
type WsTcpUnencryptedStream = MuxStream<WsStreamWrapper>;
type WsTcpStream = Either<WsTcpTlsStream, WsTcpUnencryptedStream>;
async fn send_req(req: http::Request<HttpBody>, io: WsTcpStream) -> Result<WsTcpResponse, JsError> {
let (mut req_sender, conn) =
hyper_conn::http1::handshake::<TokioIo<WsTcpStream>, HttpBody>(TokioIo::new(io))
.await
.replace_err("Failed to connect to host")?;
wasm_bindgen_futures::spawn_local(async move {
if let Err(e) = conn.await {
@ -34,10 +44,38 @@ where
}
});
req_sender
let mut new_req = req.clone();
let res = req_sender
.send_request(req)
.await
.replace_err("Failed to send request")
.replace_err("Failed to send request");
match res {
Ok(res) => {
if utils::is_redirect(res.status().as_u16())
&& let Some(location) = res.headers().get("Location")
&& let Ok(redirect_url) = new_req.uri().get_redirect(location)
&& let Some(redirect_url_authority) = redirect_url.clone().authority().replace_err("Redirect URL must have an authority").ok()
{
let should_strip = new_req.uri().is_same_host(&redirect_url);
if should_strip {
new_req.headers_mut().remove("authorization");
new_req.headers_mut().remove("cookie");
new_req.headers_mut().remove("www-authenticate");
}
let new_url = redirect_url.clone();
*new_req.uri_mut() = redirect_url;
new_req.headers_mut().remove("Host");
new_req
.headers_mut()
.insert("Host", HeaderValue::from_str(redirect_url_authority.as_str())?);
Ok(WsTcpResponse::Redirect((res, new_req, new_url)))
} else {
Ok(WsTcpResponse::Success(res))
}
}
Err(err) => Err(err),
}
}
#[wasm_bindgen(start)]
@ -50,12 +88,13 @@ pub struct WsTcp {
rustls_config: Arc<rustls::ClientConfig>,
mux: Multiplexor<WsStreamWrapper>,
useragent: String,
redirect_limit: usize,
}
#[wasm_bindgen]
impl WsTcp {
#[wasm_bindgen(constructor)]
pub async fn new(ws_url: String, useragent: String) -> Result<WsTcp, JsError> {
pub async fn new(ws_url: String, useragent: String, redirect_limit: usize) -> Result<WsTcp, JsError> {
let ws_uri = ws_url
.parse::<uri::Uri>()
.replace_err("Failed to parse websocket url")?;
@ -87,9 +126,60 @@ impl WsTcp {
mux,
rustls_config,
useragent,
redirect_limit
})
}
async fn get_http_io(&self, url: &Uri) -> Result<WsTcpStream, JsError> {
let url_host = url.host().replace_err("URL must have a host")?;
let url_port = utils::get_url_port(url)?;
let channel = self
.mux
.client_new_stream_channel(url_host.as_bytes(), url_port)
.await
.replace_err("Failed to create multiplexor channel")?;
if *url.scheme().replace_err("URL must have a scheme")? == uri::Scheme::HTTPS {
let cloned_uri = url_host.to_string().clone();
let connector = TlsConnector::from(self.rustls_config.clone());
let io = connector
.connect(
cloned_uri
.try_into()
.replace_err("Failed to parse URL (rustls)")?,
channel,
)
.await
.replace_err("Failed to perform TLS handshake")?;
Ok(WsTcpStream::Left(io))
} else {
Ok(WsTcpStream::Right(channel))
}
}
async fn send_req(
&self,
req: http::Request<HttpBody>,
) -> Result<(hyper::Response<Incoming>, Uri, bool), JsError> {
let mut redirected = false;
let uri = req.uri().clone();
let mut current_resp: WsTcpResponse = send_req(req, self.get_http_io(&uri).await?).await?;
for _ in 0..self.redirect_limit-1 {
match current_resp {
WsTcpResponse::Success(_) => break,
WsTcpResponse::Redirect((_, req, new_url)) => {
redirected = true;
current_resp = send_req(req, self.get_http_io(&new_url).await?).await?
}
}
}
match current_resp {
WsTcpResponse::Success(resp) => Ok((resp, uri, redirected)),
WsTcpResponse::Redirect((resp, _, new_url)) => Ok((resp, new_url, redirected)),
}
}
pub async fn fetch(&self, url: String, options: Object) -> Result<web_sys::Response, JsError> {
let uri = url.parse::<uri::Uri>().replace_err("Failed to parse URL")?;
let uri_scheme = uri.scheme().replace_err("URL must have a scheme")?;
@ -97,19 +187,6 @@ impl WsTcp {
return Err(jerr!("Scheme must be either `http` or `https`"));
}
let uri_host = uri.host().replace_err("URL must have a host")?;
let uri_port = if let Some(port) = uri.port() {
port.as_u16()
} else {
// can't use match, compiler error
// error: to use a constant of type `Scheme` in a pattern, `Scheme` must be annotated with `#[derive(PartialEq, Eq)]`
if *uri_scheme == uri::Scheme::HTTP {
80
} else if *uri_scheme == uri::Scheme::HTTPS {
443
} else {
return Err(jerr!("Failed to coerce port from scheme"));
}
};
let req_method_string: String = match Reflect::get(&options, &jval!("method")) {
Ok(val) => val.as_string().unwrap_or("GET".to_string()),
@ -174,30 +251,7 @@ impl WsTcp {
.body(HttpBody::new(body_bytes))
.replace_err("Failed to make request")?;
let channel = self
.mux
.client_new_stream_channel(uri_host.as_bytes(), uri_port)
.await
.replace_err("Failed to create multiplexor channel")?;
let resp: hyper::Response<Incoming>;
if *uri_scheme == uri::Scheme::HTTPS {
let cloned_uri = uri_host.to_string().clone();
let connector = TlsConnector::from(self.rustls_config.clone());
let io = connector
.connect(
cloned_uri
.try_into()
.replace_err("Failed to parse URL (rustls)")?,
channel,
)
.await
.replace_err("Failed to perform TLS handshake")?;
resp = send_req(request, TokioIo::new(io)).await?;
} else {
resp = send_req(request, TokioIo::new(channel)).await?;
}
let (resp, last_url, req_redirected) = self.send_req(request).await?;
let resp_headers_jsarray = resp
.headers()
@ -231,10 +285,17 @@ impl WsTcp {
Object::define_property(
&resp,
&jval!("url"),
&utils::define_property_obj(jval!(url), false)
&utils::define_property_obj(jval!(last_url.to_string()), false)
.replace_err("Failed to make define_property object for url")?,
);
Object::define_property(
&resp,
&jval!("redirected"),
&utils::define_property_obj(jval!(req_redirected), false)
.replace_err("Failed to make define_property object for redirected")?,
);
Ok(resp)
}
}

View file

@ -1,5 +1,7 @@
use wasm_bindgen::prelude::*;
use hyper::{header::HeaderValue, Uri};
use http::uri;
use js_sys::{Array, Object};
pub fn set_panic_hook() {
@ -77,6 +79,29 @@ impl<T> ReplaceErr for Option<T> {
}
}
pub trait UriExt {
fn get_redirect(&self, location: &HeaderValue) -> Result<Uri, JsError>;
fn is_same_host(&self, other: &Uri) -> bool;
}
impl UriExt for Uri {
fn get_redirect(&self, location: &HeaderValue) -> Result<Uri, JsError> {
let new_uri = location.to_str()?.parse::<hyper::Uri>()?;
let mut new_parts: http::uri::Parts = new_uri.into();
if new_parts.scheme.is_none() {
new_parts.scheme = self.scheme().cloned();
}
if new_parts.authority.is_none() {
new_parts.authority = self.authority().cloned();
}
Ok(Uri::from_parts(new_parts)?)
}
fn is_same_host(&self, other: &Uri) -> bool {
self.host() == other.host() && self.port() == other.port()
}
}
pub fn entries_of_object(obj: &Object) -> Vec<Vec<String>> {
js_sys::Object::entries(obj)
.to_vec()
@ -96,6 +121,28 @@ pub fn define_property_obj(value: JsValue, writable: bool) -> Result<Object, JsV
Array::of2(&jval!("value"), &jval!(value)),
Array::of2(&jval!("writable"), &jval!(writable)),
]
.iter().collect::<Array>();
.iter()
.collect::<Array>();
Object::from_entries(&entries)
}
pub fn is_redirect(code: u16) -> bool {
[301, 302, 303, 307, 308].contains(&code)
}
pub fn get_url_port(url: &Uri) -> Result<u16, JsError> {
let url_scheme = url.scheme().replace_err("URL must have a scheme")?;
if let Some(port) = url.port() {
Ok(port.as_u16())
} else {
// can't use match, compiler error
// error: to use a constant of type `Scheme` in a pattern, `Scheme` must be annotated with `#[derive(PartialEq, Eq)]`
if *url_scheme == uri::Scheme::HTTP {
Ok(80)
} else if *url_scheme == uri::Scheme::HTTPS {
Ok(443)
} else {
return Err(jerr!("Failed to coerce port from scheme"));
}
}
}

View file

@ -1,20 +1,20 @@
(async () => {
console.log(
"%cWASM is significantly slower with DevTools open!",
"color:red;font-size:2rem;font-weight:bold"
);
await wasm_bindgen("./wstcp_client_bg.wasm");
const tconn0 = performance.now();
let wstcp = await new wasm_bindgen.WsTcp("wss://localhost:4000", navigator.userAgent);
const tconn1 = performance.now();
console.warn(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`);
const t0 = performance.now();
let resp = await wstcp.fetch("https://httpbin.org/post", { method: "POST", body: "test", headers: { "X-Header-One": "one", "x-header-one": "One", "X-Header-Two": "two" } });
const t1 = performance.now();
console.log(
"%cWASM is significantly slower with DevTools open!",
"color:red;font-size:2rem;font-weight:bold"
);
await wasm_bindgen("./wstcp_client_bg.wasm");
const tconn0 = performance.now();
// args: websocket url, user agent, redirect limit
let wstcp = await new wasm_bindgen.WsTcp("wss://localhost:4000", navigator.userAgent, 10);
const tconn1 = performance.now();
console.warn(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`);
const t0 = performance.now();
let resp = await wstcp.fetch("http://httpbin.org/redirect/11");
const t1 = performance.now();
console.warn(resp);
console.warn(await fetch("https://httpbin.org/post", { method: "POST", body: "test", headers: { "X-Header-One": "one", "x-header-one": "One", "X-Header-Two": "two" } }));
console.warn(Object.fromEntries(resp.headers));
console.warn(await resp.text());
console.warn(`mux 1 took ${t1 - t0} ms or ${(t1 - t0) / 1000} s`);
console.warn(resp);
console.warn(Object.fromEntries(resp.headers));
console.warn(await resp.text());
console.warn(`mux 1 took ${t1 - t0} ms or ${(t1 - t0) / 1000} s`);
})();