From 7bb39ae069cc592aa8c80a9717eab2f03c9058ca Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Thu, 11 Jan 2024 22:21:37 -0800 Subject: [PATCH] redirects --- Cargo.lock | 21 +++++ client/Cargo.toml | 2 + client/src/lib.rs | 165 +++++++++++++++++++++++++++------------- client/src/utils.rs | 49 +++++++++++- client/src/web/index.js | 34 ++++----- 5 files changed, 201 insertions(+), 70 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ab3968a..472c372 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -165,6 +165,12 @@ dependencies = [ "crypto-common", ] +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + [[package]] name = "errno" version = "0.3.8" @@ -1049,6 +1055,19 @@ dependencies = [ "tungstenite", ] +[[package]] +name = "tokio-util" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "tracing" version = "0.1.40" @@ -1437,6 +1456,7 @@ version = "1.0.0" dependencies = [ "bytes", "console_error_panic_hook", + "either", "futures-util", "getrandom", "http 1.0.0", @@ -1448,6 +1468,7 @@ dependencies = [ "ring", "tokio", "tokio-rustls", + "tokio-util", "wasm-bindgen", "wasm-bindgen-futures", "wasm-streams", diff --git a/client/Cargo.toml b/client/Cargo.toml index 3e722dd..7455715 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -27,6 +27,8 @@ webpki-roots = "0.26.0" tokio-rustls = "0.25.0" web-sys = { version = "0.3.66", features = ["TextEncoder", "Navigator", "Response", "ResponseInit"] } wasm-streams = "0.4.0" +either = "1.9.0" +tokio-util = "0.7.10" [dependencies.getrandom] features = ["js"] diff --git a/client/src/lib.rs b/client/src/lib.rs index b85780e..8ec7992 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -1,32 +1,42 @@ +#![feature(let_chains)] #[macro_use] mod utils; mod tokioio; mod wrappers; use tokioio::TokioIo; -use utils::ReplaceErr; +use utils::{ReplaceErr, UriExt}; use wrappers::{IncomingBody, WsStreamWrapper}; use std::sync::Arc; use bytes::Bytes; use http::{uri, HeaderName, HeaderValue, Request, Response}; -use hyper::{body::Incoming, client::conn as hyper_conn}; +use hyper::{body::Incoming, client::conn as hyper_conn, Uri}; use js_sys::{Array, Object, Reflect, Uint8Array}; -use penguin_mux_wasm::{Multiplexor, Role}; -use tokio_rustls::{rustls, rustls::RootCertStore, TlsConnector}; +use penguin_mux_wasm::{Multiplexor, MuxStream, Role}; +use tokio_rustls::{client::TlsStream, rustls, rustls::RootCertStore, TlsConnector}; +use tokio_util::either::Either; use wasm_bindgen::prelude::*; use web_sys::TextEncoder; type HttpBody = http_body_util::Full; -async fn send_req(req: http::Request, io: T) -> Result, JsError> -where - T: hyper::rt::Read + hyper::rt::Write + std::marker::Unpin + 'static, -{ - let (mut req_sender, conn) = hyper_conn::http1::handshake::(io) - .await - .replace_err("Failed to connect to host")?; +#[derive(Debug)] +enum WsTcpResponse { + Success(Response), + Redirect((Response, http::Request, Uri)), +} + +type WsTcpTlsStream = TlsStream>; +type WsTcpUnencryptedStream = MuxStream; +type WsTcpStream = Either; + +async fn send_req(req: http::Request, io: WsTcpStream) -> Result { + let (mut req_sender, conn) = + hyper_conn::http1::handshake::, HttpBody>(TokioIo::new(io)) + .await + .replace_err("Failed to connect to host")?; wasm_bindgen_futures::spawn_local(async move { if let Err(e) = conn.await { @@ -34,10 +44,38 @@ where } }); - req_sender + let mut new_req = req.clone(); + + let res = req_sender .send_request(req) .await - .replace_err("Failed to send request") + .replace_err("Failed to send request"); + match res { + Ok(res) => { + if utils::is_redirect(res.status().as_u16()) + && let Some(location) = res.headers().get("Location") + && let Ok(redirect_url) = new_req.uri().get_redirect(location) + && let Some(redirect_url_authority) = redirect_url.clone().authority().replace_err("Redirect URL must have an authority").ok() + { + let should_strip = new_req.uri().is_same_host(&redirect_url); + if should_strip { + new_req.headers_mut().remove("authorization"); + new_req.headers_mut().remove("cookie"); + new_req.headers_mut().remove("www-authenticate"); + } + let new_url = redirect_url.clone(); + *new_req.uri_mut() = redirect_url; + new_req.headers_mut().remove("Host"); + new_req + .headers_mut() + .insert("Host", HeaderValue::from_str(redirect_url_authority.as_str())?); + Ok(WsTcpResponse::Redirect((res, new_req, new_url))) + } else { + Ok(WsTcpResponse::Success(res)) + } + } + Err(err) => Err(err), + } } #[wasm_bindgen(start)] @@ -50,12 +88,13 @@ pub struct WsTcp { rustls_config: Arc, mux: Multiplexor, useragent: String, + redirect_limit: usize, } #[wasm_bindgen] impl WsTcp { #[wasm_bindgen(constructor)] - pub async fn new(ws_url: String, useragent: String) -> Result { + pub async fn new(ws_url: String, useragent: String, redirect_limit: usize) -> Result { let ws_uri = ws_url .parse::() .replace_err("Failed to parse websocket url")?; @@ -87,9 +126,60 @@ impl WsTcp { mux, rustls_config, useragent, + redirect_limit }) } + async fn get_http_io(&self, url: &Uri) -> Result { + let url_host = url.host().replace_err("URL must have a host")?; + let url_port = utils::get_url_port(url)?; + let channel = self + .mux + .client_new_stream_channel(url_host.as_bytes(), url_port) + .await + .replace_err("Failed to create multiplexor channel")?; + + if *url.scheme().replace_err("URL must have a scheme")? == uri::Scheme::HTTPS { + let cloned_uri = url_host.to_string().clone(); + let connector = TlsConnector::from(self.rustls_config.clone()); + let io = connector + .connect( + cloned_uri + .try_into() + .replace_err("Failed to parse URL (rustls)")?, + channel, + ) + .await + .replace_err("Failed to perform TLS handshake")?; + Ok(WsTcpStream::Left(io)) + } else { + Ok(WsTcpStream::Right(channel)) + } + } + + async fn send_req( + &self, + req: http::Request, + ) -> Result<(hyper::Response, Uri, bool), JsError> { + let mut redirected = false; + let uri = req.uri().clone(); + let mut current_resp: WsTcpResponse = send_req(req, self.get_http_io(&uri).await?).await?; + for _ in 0..self.redirect_limit-1 { + match current_resp { + WsTcpResponse::Success(_) => break, + WsTcpResponse::Redirect((_, req, new_url)) => { + redirected = true; + current_resp = send_req(req, self.get_http_io(&new_url).await?).await? + } + } + } + + match current_resp { + WsTcpResponse::Success(resp) => Ok((resp, uri, redirected)), + WsTcpResponse::Redirect((resp, _, new_url)) => Ok((resp, new_url, redirected)), + } + } + pub async fn fetch(&self, url: String, options: Object) -> Result { let uri = url.parse::().replace_err("Failed to parse URL")?; let uri_scheme = uri.scheme().replace_err("URL must have a scheme")?; @@ -97,19 +187,6 @@ impl WsTcp { return Err(jerr!("Scheme must be either `http` or `https`")); } let uri_host = uri.host().replace_err("URL must have a host")?; - let uri_port = if let Some(port) = uri.port() { - port.as_u16() - } else { - // can't use match, compiler error - // error: to use a constant of type `Scheme` in a pattern, `Scheme` must be annotated with `#[derive(PartialEq, Eq)]` - if *uri_scheme == uri::Scheme::HTTP { - 80 - } else if *uri_scheme == uri::Scheme::HTTPS { - 443 - } else { - return Err(jerr!("Failed to coerce port from scheme")); - } - }; let req_method_string: String = match Reflect::get(&options, &jval!("method")) { Ok(val) => val.as_string().unwrap_or("GET".to_string()), @@ -174,30 +251,7 @@ impl WsTcp { .body(HttpBody::new(body_bytes)) .replace_err("Failed to make request")?; - let channel = self - .mux - .client_new_stream_channel(uri_host.as_bytes(), uri_port) - .await - .replace_err("Failed to create multiplexor channel")?; - - let resp: hyper::Response; - - if *uri_scheme == uri::Scheme::HTTPS { - let cloned_uri = uri_host.to_string().clone(); - let connector = TlsConnector::from(self.rustls_config.clone()); - let io = connector - .connect( - cloned_uri - .try_into() - .replace_err("Failed to parse URL (rustls)")?, - channel, - ) - .await - .replace_err("Failed to perform TLS handshake")?; - resp = send_req(request, TokioIo::new(io)).await?; - } else { - resp = send_req(request, TokioIo::new(channel)).await?; - } + let (resp, last_url, req_redirected) = self.send_req(request).await?; let resp_headers_jsarray = resp .headers() @@ -231,10 +285,17 @@ impl WsTcp { Object::define_property( &resp, &jval!("url"), - &utils::define_property_obj(jval!(url), false) + &utils::define_property_obj(jval!(last_url.to_string()), false) .replace_err("Failed to make define_property object for url")?, ); + Object::define_property( + &resp, + &jval!("redirected"), + &utils::define_property_obj(jval!(req_redirected), false) + .replace_err("Failed to make define_property object for redirected")?, + ); + Ok(resp) } } diff --git a/client/src/utils.rs b/client/src/utils.rs index cc10709..69a2613 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -1,5 +1,7 @@ use wasm_bindgen::prelude::*; +use hyper::{header::HeaderValue, Uri}; +use http::uri; use js_sys::{Array, Object}; pub fn set_panic_hook() { @@ -77,6 +79,29 @@ impl ReplaceErr for Option { } } +pub trait UriExt { + fn get_redirect(&self, location: &HeaderValue) -> Result; + fn is_same_host(&self, other: &Uri) -> bool; +} + +impl UriExt for Uri { + fn get_redirect(&self, location: &HeaderValue) -> Result { + let new_uri = location.to_str()?.parse::()?; + let mut new_parts: http::uri::Parts = new_uri.into(); + if new_parts.scheme.is_none() { + new_parts.scheme = self.scheme().cloned(); + } + if new_parts.authority.is_none() { + new_parts.authority = self.authority().cloned(); + } + + Ok(Uri::from_parts(new_parts)?) + } + fn is_same_host(&self, other: &Uri) -> bool { + self.host() == other.host() && self.port() == other.port() + } +} + pub fn entries_of_object(obj: &Object) -> Vec> { js_sys::Object::entries(obj) .to_vec() @@ -96,6 +121,28 @@ pub fn define_property_obj(value: JsValue, writable: bool) -> Result(); + .iter() + .collect::(); Object::from_entries(&entries) } + +pub fn is_redirect(code: u16) -> bool { + [301, 302, 303, 307, 308].contains(&code) +} + +pub fn get_url_port(url: &Uri) -> Result { + let url_scheme = url.scheme().replace_err("URL must have a scheme")?; + if let Some(port) = url.port() { + Ok(port.as_u16()) + } else { + // can't use match, compiler error + // error: to use a constant of type `Scheme` in a pattern, `Scheme` must be annotated with `#[derive(PartialEq, Eq)]` + if *url_scheme == uri::Scheme::HTTP { + Ok(80) + } else if *url_scheme == uri::Scheme::HTTPS { + Ok(443) + } else { + return Err(jerr!("Failed to coerce port from scheme")); + } + } +} diff --git a/client/src/web/index.js b/client/src/web/index.js index 3132eed..2380ca9 100644 --- a/client/src/web/index.js +++ b/client/src/web/index.js @@ -1,20 +1,20 @@ (async () => { - console.log( - "%cWASM is significantly slower with DevTools open!", - "color:red;font-size:2rem;font-weight:bold" - ); - await wasm_bindgen("./wstcp_client_bg.wasm"); - const tconn0 = performance.now(); - let wstcp = await new wasm_bindgen.WsTcp("wss://localhost:4000", navigator.userAgent); - const tconn1 = performance.now(); - console.warn(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`); - const t0 = performance.now(); - let resp = await wstcp.fetch("https://httpbin.org/post", { method: "POST", body: "test", headers: { "X-Header-One": "one", "x-header-one": "One", "X-Header-Two": "two" } }); - const t1 = performance.now(); + console.log( + "%cWASM is significantly slower with DevTools open!", + "color:red;font-size:2rem;font-weight:bold" + ); + await wasm_bindgen("./wstcp_client_bg.wasm"); + const tconn0 = performance.now(); + // args: websocket url, user agent, redirect limit + let wstcp = await new wasm_bindgen.WsTcp("wss://localhost:4000", navigator.userAgent, 10); + const tconn1 = performance.now(); + console.warn(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`); + const t0 = performance.now(); + let resp = await wstcp.fetch("http://httpbin.org/redirect/11"); + const t1 = performance.now(); - console.warn(resp); - console.warn(await fetch("https://httpbin.org/post", { method: "POST", body: "test", headers: { "X-Header-One": "one", "x-header-one": "One", "X-Header-Two": "two" } })); - console.warn(Object.fromEntries(resp.headers)); - console.warn(await resp.text()); - console.warn(`mux 1 took ${t1 - t0} ms or ${(t1 - t0) / 1000} s`); + console.warn(resp); + console.warn(Object.fromEntries(resp.headers)); + console.warn(await resp.text()); + console.warn(`mux 1 took ${t1 - t0} ms or ${(t1 - t0) / 1000} s`); })();