diff --git a/Cargo.lock b/Cargo.lock index 972812a..8dd7229 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -758,6 +758,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "sha1", "sha2", "shell-words", "tikv-jemalloc-ctl", diff --git a/server/Cargo.toml b/server/Cargo.toml index 0248352..b3094e6 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -17,7 +17,7 @@ clap = { version = "4.5.16", features = ["cargo", "derive"] } ed25519-dalek = { version = "2.1.1", features = ["pem"] } env_logger = "0.11.5" event-listener = "5.3.1" -fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] } +fastwebsockets = { version = "0.8.0", features = ["unstable-split"] } futures-util = "0.3.30" hickory-resolver = "0.24.1" http-body-util = "0.1.2" @@ -34,6 +34,7 @@ rustls-pemfile = "2.1.3" serde = { version = "1.0.208", features = ["derive"] } serde_json = "1.0.125" serde_yaml = { version = "0.9.34", optional = true } +sha1 = "0.10.6" sha2 = "0.10.8" shell-words = { version = "1.1.0", optional = true } tikv-jemalloc-ctl = { version = "0.6.0", features = ["stats", "use_std"] } diff --git a/server/src/main.rs b/server/src/main.rs index a0c76c9..74b5186 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -39,6 +39,8 @@ mod stats; mod stream; #[doc(hidden)] mod util_chain; +#[doc(hidden)] +mod upgrade; #[doc(hidden)] type Client = (Mutex>, String); diff --git a/server/src/route.rs b/server/src/route.rs index 64e6eac..5216d86 100644 --- a/server/src/route.rs +++ b/server/src/route.rs @@ -2,11 +2,11 @@ use std::{fmt::Display, future::Future, io::Cursor}; use anyhow::Context; use bytes::Bytes; -use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector, WebSocketRead, WebSocketWrite}; +use fastwebsockets::{FragmentCollector, Role, WebSocket, WebSocketRead, WebSocketWrite}; use http_body_util::Full; use hyper::{ body::Incoming, header::SEC_WEBSOCKET_PROTOCOL, server::conn::http1::Builder, - service::service_fn, HeaderMap, Request, Response, StatusCode, + service::service_fn, upgrade::OnUpgrade, HeaderMap, Request, Response, StatusCode, }; use hyper_util::rt::TokioIo; use log::{debug, error, trace}; @@ -21,6 +21,7 @@ use crate::{ generate_stats, listener::{ServerStream, ServerStreamExt, ServerStreamRead, ServerStreamWrite}, stream::WebSocketStreamWrapper, + upgrade::{is_upgrade_request, upgrade}, util_chain::{chain, Chain}, CONFIG, }; @@ -55,7 +56,7 @@ impl Display for ServerRouteResult { match self { Self::Wisp { .. } => write!(f, "Wisp"), Self::Wispnet { .. } => write!(f, "Wispnet"), - Self::WsProxy { path, udp, .. } => write!(f, "WsProxy path {:?} udp {:?}", path, udp), + Self::WsProxy { path, udp, .. } => write!(f, "WsProxy path {path:?} udp {udp:?}"), } } } @@ -88,7 +89,7 @@ fn get_header(headers: &HeaderMap, header: &str) -> Option { headers .get(header) .and_then(|x| x.to_str().ok()) - .map(|x| x.to_string()) + .map(ToString::to_string) } enum HttpUpgradeResult { @@ -108,28 +109,25 @@ async fn ws_upgrade( callback: F, ) -> anyhow::Result> where - F: FnOnce(UpgradeFut, HttpUpgradeResult, Option) -> R + Send + 'static, + F: FnOnce(OnUpgrade, HttpUpgradeResult, Option) -> R + Send + 'static, R: Future> + Send, { - let is_upgrade = fastwebsockets::upgrade::is_upgrade_request(&req); + let is_upgrade = is_upgrade_request(&req); if !is_upgrade { if let Some(stats_endpoint) = stats_endpoint { if req.uri().path() == stats_endpoint { return send_stats().await; - } else { - debug!("sent non_ws_response to http client"); - return non_ws_resp(); } - } else { - debug!("sent non_ws_response to http client"); - return non_ws_resp(); } + + debug!("sent non_ws_response to http client"); + return non_ws_resp(); } trace!("recieved request {:?}", req); - let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?; + let (resp, fut) = upgrade(&mut req)?; // replace body of Empty with Full let mut resp = Response::from_parts(resp.into_parts().0, Body::new(Bytes::new())); @@ -145,7 +143,8 @@ where if req_path.ends_with(&(CONFIG.wisp.prefix.clone() + "/")) { let has_ws_protocol = ws_protocol.is_some(); - let is_wispnet = CONFIG.wisp.has_wispnet() && req.uri().query().unwrap_or_default() == "net"; + let is_wispnet = + CONFIG.wisp.has_wispnet() && req.uri().query().unwrap_or_default() == "net"; tokio::spawn(async move { if let Err(err) = (callback)( fut, @@ -215,7 +214,10 @@ pub async fn route( req, stats_endpoint.clone(), |fut, res, maybe_ip| async move { - let mut ws = fut.await.context("failed to await upgrade future")?; + let ws = fut.await.context("failed to await upgrade future")?; + + let mut ws = + WebSocket::after_handshake(TokioIo::new(ws), Role::Server); ws.set_max_message_size(CONFIG.server.max_message_size); ws.set_auto_pong(false); @@ -250,7 +252,7 @@ pub async fn route( } }; - (callback)(result, maybe_ip) + (callback)(result, maybe_ip); } HttpUpgradeResult::WsProxy { path, udp } => { let ws = WebSocketStreamWrapper(FragmentCollector::new(ws)); diff --git a/server/src/upgrade.rs b/server/src/upgrade.rs new file mode 100644 index 0000000..db966e1 --- /dev/null +++ b/server/src/upgrade.rs @@ -0,0 +1,89 @@ +//! taken from https://github.com/denoland/fastwebsockets/blob/main/src/upgrade.rs + +use anyhow::{bail, Context, Result}; +use base64::{prelude::BASE64_STANDARD, Engine}; +use bytes::Bytes; +use http_body_util::Empty; +use hyper::{header::HeaderValue, upgrade::OnUpgrade, Request, Response}; +use sha1::{Digest, Sha1}; + +pub fn is_upgrade_request(request: &hyper::Request) -> bool { + header_contains_value(request.headers(), hyper::header::CONNECTION, "Upgrade") + && header_contains_value(request.headers(), hyper::header::UPGRADE, "websocket") +} + +/// Check if there is a header of the given name containing the wanted value. +fn header_contains_value( + headers: &hyper::HeaderMap, + header: impl hyper::header::AsHeaderName, + value: impl AsRef<[u8]>, +) -> bool { + let value = value.as_ref(); + for header in headers.get_all(header) { + if header + .as_bytes() + .split(|&c| c == b',') + .any(|x| trim(x).eq_ignore_ascii_case(value)) + { + return true; + } + } + false +} + +fn trim(data: &[u8]) -> &[u8] { + trim_end(trim_start(data)) +} + +fn trim_start(data: &[u8]) -> &[u8] { + if let Some(start) = data.iter().position(|x| !x.is_ascii_whitespace()) { + &data[start..] + } else { + b"" + } +} + +fn trim_end(data: &[u8]) -> &[u8] { + if let Some(last) = data.iter().rposition(|x| !x.is_ascii_whitespace()) { + &data[..=last] + } else { + b"" + } +} + +fn sec_websocket_protocol(key: &[u8]) -> String { + let mut sha1 = Sha1::new(); + sha1.update(key); + sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); // magic string + let result = sha1.finalize(); + BASE64_STANDARD.encode(&result[..]) +} + +// slightly modified to use anyhow +pub fn upgrade(request: &mut Request) -> Result<(Response>, OnUpgrade)> { + let key = request + .headers() + .get("Sec-WebSocket-Key") + .context("missing Sec-WebSocket-Key")?; + if request + .headers() + .get("Sec-WebSocket-Version") + .map(HeaderValue::as_bytes) + != Some(b"13") + { + bail!("invalid Sec-WebSocket-Version, not 13"); + } + + let response = Response::builder() + .status(hyper::StatusCode::SWITCHING_PROTOCOLS) + .header(hyper::header::CONNECTION, "upgrade") + .header(hyper::header::UPGRADE, "websocket") + .header( + "Sec-WebSocket-Accept", + &sec_websocket_protocol(key.as_bytes()), + ) + .body(Empty::new()) + .context("failed to build upgrade response")?; + + Ok((response, hyper::upgrade::on(request))) +}