vendor fastwebsockets upgrade

This commit is contained in:
Toshit Chawda 2024-11-27 22:45:10 -08:00
parent f942c0a7c6
commit ed8d22a52f
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 112 additions and 17 deletions

1
Cargo.lock generated
View file

@ -758,6 +758,7 @@ dependencies = [
"serde",
"serde_json",
"serde_yaml",
"sha1",
"sha2",
"shell-words",
"tikv-jemalloc-ctl",

View file

@ -17,7 +17,7 @@ clap = { version = "4.5.16", features = ["cargo", "derive"] }
ed25519-dalek = { version = "2.1.1", features = ["pem"] }
env_logger = "0.11.5"
event-listener = "5.3.1"
fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] }
fastwebsockets = { version = "0.8.0", features = ["unstable-split"] }
futures-util = "0.3.30"
hickory-resolver = "0.24.1"
http-body-util = "0.1.2"
@ -34,6 +34,7 @@ rustls-pemfile = "2.1.3"
serde = { version = "1.0.208", features = ["derive"] }
serde_json = "1.0.125"
serde_yaml = { version = "0.9.34", optional = true }
sha1 = "0.10.6"
sha2 = "0.10.8"
shell-words = { version = "1.1.0", optional = true }
tikv-jemalloc-ctl = { version = "0.6.0", features = ["stats", "use_std"] }

View file

@ -39,6 +39,8 @@ mod stats;
mod stream;
#[doc(hidden)]
mod util_chain;
#[doc(hidden)]
mod upgrade;
#[doc(hidden)]
type Client = (Mutex<HashMap<Uuid, (ConnectPacket, ConnectPacket)>>, String);

View file

@ -2,11 +2,11 @@ use std::{fmt::Display, future::Future, io::Cursor};
use anyhow::Context;
use bytes::Bytes;
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector, WebSocketRead, WebSocketWrite};
use fastwebsockets::{FragmentCollector, Role, WebSocket, WebSocketRead, WebSocketWrite};
use http_body_util::Full;
use hyper::{
body::Incoming, header::SEC_WEBSOCKET_PROTOCOL, server::conn::http1::Builder,
service::service_fn, HeaderMap, Request, Response, StatusCode,
service::service_fn, upgrade::OnUpgrade, HeaderMap, Request, Response, StatusCode,
};
use hyper_util::rt::TokioIo;
use log::{debug, error, trace};
@ -21,6 +21,7 @@ use crate::{
generate_stats,
listener::{ServerStream, ServerStreamExt, ServerStreamRead, ServerStreamWrite},
stream::WebSocketStreamWrapper,
upgrade::{is_upgrade_request, upgrade},
util_chain::{chain, Chain},
CONFIG,
};
@ -55,7 +56,7 @@ impl Display for ServerRouteResult {
match self {
Self::Wisp { .. } => write!(f, "Wisp"),
Self::Wispnet { .. } => write!(f, "Wispnet"),
Self::WsProxy { path, udp, .. } => write!(f, "WsProxy path {:?} udp {:?}", path, udp),
Self::WsProxy { path, udp, .. } => write!(f, "WsProxy path {path:?} udp {udp:?}"),
}
}
}
@ -88,7 +89,7 @@ fn get_header(headers: &HeaderMap, header: &str) -> Option<String> {
headers
.get(header)
.and_then(|x| x.to_str().ok())
.map(|x| x.to_string())
.map(ToString::to_string)
}
enum HttpUpgradeResult {
@ -108,28 +109,25 @@ async fn ws_upgrade<F, R>(
callback: F,
) -> anyhow::Result<Response<Body>>
where
F: FnOnce(UpgradeFut, HttpUpgradeResult, Option<String>) -> R + Send + 'static,
F: FnOnce(OnUpgrade, HttpUpgradeResult, Option<String>) -> R + Send + 'static,
R: Future<Output = anyhow::Result<()>> + Send,
{
let is_upgrade = fastwebsockets::upgrade::is_upgrade_request(&req);
let is_upgrade = is_upgrade_request(&req);
if !is_upgrade {
if let Some(stats_endpoint) = stats_endpoint {
if req.uri().path() == stats_endpoint {
return send_stats().await;
} else {
debug!("sent non_ws_response to http client");
return non_ws_resp();
}
} else {
debug!("sent non_ws_response to http client");
return non_ws_resp();
}
debug!("sent non_ws_response to http client");
return non_ws_resp();
}
trace!("recieved request {:?}", req);
let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?;
let (resp, fut) = upgrade(&mut req)?;
// replace body of Empty<Bytes> with Full<Bytes>
let mut resp = Response::from_parts(resp.into_parts().0, Body::new(Bytes::new()));
@ -145,7 +143,8 @@ where
if req_path.ends_with(&(CONFIG.wisp.prefix.clone() + "/")) {
let has_ws_protocol = ws_protocol.is_some();
let is_wispnet = CONFIG.wisp.has_wispnet() && req.uri().query().unwrap_or_default() == "net";
let is_wispnet =
CONFIG.wisp.has_wispnet() && req.uri().query().unwrap_or_default() == "net";
tokio::spawn(async move {
if let Err(err) = (callback)(
fut,
@ -215,7 +214,10 @@ pub async fn route(
req,
stats_endpoint.clone(),
|fut, res, maybe_ip| async move {
let mut ws = fut.await.context("failed to await upgrade future")?;
let ws = fut.await.context("failed to await upgrade future")?;
let mut ws =
WebSocket::after_handshake(TokioIo::new(ws), Role::Server);
ws.set_max_message_size(CONFIG.server.max_message_size);
ws.set_auto_pong(false);
@ -250,7 +252,7 @@ pub async fn route(
}
};
(callback)(result, maybe_ip)
(callback)(result, maybe_ip);
}
HttpUpgradeResult::WsProxy { path, udp } => {
let ws = WebSocketStreamWrapper(FragmentCollector::new(ws));

89
server/src/upgrade.rs Normal file
View file

@ -0,0 +1,89 @@
//! taken from https://github.com/denoland/fastwebsockets/blob/main/src/upgrade.rs
use anyhow::{bail, Context, Result};
use base64::{prelude::BASE64_STANDARD, Engine};
use bytes::Bytes;
use http_body_util::Empty;
use hyper::{header::HeaderValue, upgrade::OnUpgrade, Request, Response};
use sha1::{Digest, Sha1};
pub fn is_upgrade_request<B>(request: &hyper::Request<B>) -> bool {
header_contains_value(request.headers(), hyper::header::CONNECTION, "Upgrade")
&& header_contains_value(request.headers(), hyper::header::UPGRADE, "websocket")
}
/// Check if there is a header of the given name containing the wanted value.
fn header_contains_value(
headers: &hyper::HeaderMap,
header: impl hyper::header::AsHeaderName,
value: impl AsRef<[u8]>,
) -> bool {
let value = value.as_ref();
for header in headers.get_all(header) {
if header
.as_bytes()
.split(|&c| c == b',')
.any(|x| trim(x).eq_ignore_ascii_case(value))
{
return true;
}
}
false
}
fn trim(data: &[u8]) -> &[u8] {
trim_end(trim_start(data))
}
fn trim_start(data: &[u8]) -> &[u8] {
if let Some(start) = data.iter().position(|x| !x.is_ascii_whitespace()) {
&data[start..]
} else {
b""
}
}
fn trim_end(data: &[u8]) -> &[u8] {
if let Some(last) = data.iter().rposition(|x| !x.is_ascii_whitespace()) {
&data[..=last]
} else {
b""
}
}
fn sec_websocket_protocol(key: &[u8]) -> String {
let mut sha1 = Sha1::new();
sha1.update(key);
sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); // magic string
let result = sha1.finalize();
BASE64_STANDARD.encode(&result[..])
}
// slightly modified to use anyhow
pub fn upgrade<B>(request: &mut Request<B>) -> Result<(Response<Empty<Bytes>>, OnUpgrade)> {
let key = request
.headers()
.get("Sec-WebSocket-Key")
.context("missing Sec-WebSocket-Key")?;
if request
.headers()
.get("Sec-WebSocket-Version")
.map(HeaderValue::as_bytes)
!= Some(b"13")
{
bail!("invalid Sec-WebSocket-Version, not 13");
}
let response = Response::builder()
.status(hyper::StatusCode::SWITCHING_PROTOCOLS)
.header(hyper::header::CONNECTION, "upgrade")
.header(hyper::header::UPGRADE, "websocket")
.header(
"Sec-WebSocket-Accept",
&sec_websocket_protocol(key.as_bytes()),
)
.body(Empty::new())
.context("failed to build upgrade response")?;
Ok((response, hyper::upgrade::on(request)))
}