remove frame_size as that is causing issues

This commit is contained in:
Toshit Chawda 2024-03-29 13:51:48 -07:00
parent 795269ca42
commit 4301bb8b65
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D

View file

@ -4,12 +4,12 @@ use std::io::Error;
use bytes::Bytes; use bytes::Bytes;
use clap::Parser; use clap::Parser;
use fastwebsockets::{ use fastwebsockets::{
upgrade, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, upgrade::{self, UpgradeFut}, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload,
WebSocket, WebSocketError, WebSocketError,
}; };
use futures_util::{SinkExt, StreamExt, TryFutureExt}; use futures_util::{SinkExt, StreamExt, TryFutureExt};
use hyper::{ use hyper::{
body::Incoming, server::conn::http1, service::service_fn, upgrade::Upgraded, Request, Response, body::Incoming, server::conn::http1, service::service_fn, Request, Response,
StatusCode, StatusCode,
}; };
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
@ -54,9 +54,6 @@ struct Cli {
/// Whether the server should block ports other than 80 or 443 /// Whether the server should block ports other than 80 or 443
#[arg(long)] #[arg(long)]
block_non_http: bool, block_non_http: bool,
/// Maximum WebSocket frame size allowed
#[arg(long, short, default_value_t = 64 << 20)]
frame_size: usize,
} }
#[cfg(not(unix))] #[cfg(not(unix))]
@ -143,7 +140,6 @@ async fn main() -> Result<(), Error> {
while let Ok((stream, addr)) = socket.accept().await { while let Ok((stream, addr)) = socket.accept().await {
let prefix = prefix.clone(); let prefix = prefix.clone();
tokio::spawn(async move { tokio::spawn(async move {
let io = TokioIo::new(stream);
let service = service_fn(move |res| { let service = service_fn(move |res| {
accept_http( accept_http(
res, res,
@ -152,11 +148,10 @@ async fn main() -> Result<(), Error> {
opt.block_local, opt.block_local,
opt.block_udp, opt.block_udp,
opt.block_non_http, opt.block_non_http,
opt.frame_size,
) )
}); });
let conn = http1::Builder::new() let conn = http1::Builder::new()
.serve_connection(io, service) .serve_connection(TokioIo::new(stream), service)
.with_upgrades(); .with_upgrades();
if let Err(err) = conn.await { if let Err(err) = conn.await {
println!("failed to serve conn: {:?}", err); println!("failed to serve conn: {:?}", err);
@ -174,7 +169,6 @@ async fn accept_http(
block_local: bool, block_local: bool,
block_udp: bool, block_udp: bool,
block_non_http: bool, block_non_http: bool,
max_size: usize,
) -> Result<Response<HttpBody>, WebSocketError> { ) -> Result<Response<HttpBody>, WebSocketError> {
let uri = req.uri().path().to_string(); let uri = req.uri().path().to_string();
if upgrade::is_upgrade_request(&req) if upgrade::is_upgrade_request(&req)
@ -182,17 +176,13 @@ async fn accept_http(
{ {
let (res, fut) = upgrade::upgrade(&mut req)?; let (res, fut) = upgrade::upgrade(&mut req)?;
let mut ws = fut.await?;
ws.set_max_message_size(max_size);
if uri.is_empty() { if uri.is_empty() {
tokio::spawn(async move { tokio::spawn(async move {
accept_ws(ws, addr.clone(), block_local, block_udp, block_non_http).await accept_ws(fut, addr.clone(), block_local, block_udp, block_non_http).await
}); });
} else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) { } else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) {
tokio::spawn(async move { tokio::spawn(async move {
accept_wsproxy(ws, uri, addr.clone(), block_local, block_non_http).await accept_wsproxy(fut, uri, addr.clone(), block_local, block_non_http).await
}); });
} }
@ -260,13 +250,13 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result<bool
} }
async fn accept_ws( async fn accept_ws(
ws: WebSocket<TokioIo<Upgraded>>, ws: UpgradeFut,
addr: String, addr: String,
block_local: bool, block_local: bool,
block_non_http: bool, block_non_http: bool,
block_udp: bool, block_udp: bool,
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> { ) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
let (rx, tx) = ws.split(tokio::io::split); let (rx, tx) = ws.await?.split(tokio::io::split);
let rx = FragmentCollectorRead::new(rx); let rx = FragmentCollectorRead::new(rx);
println!("{:?}: connected", addr); println!("{:?}: connected", addr);
@ -334,13 +324,13 @@ async fn accept_ws(
} }
async fn accept_wsproxy( async fn accept_wsproxy(
ws: WebSocket<TokioIo<Upgraded>>, ws: UpgradeFut,
incoming_uri: String, incoming_uri: String,
addr: String, addr: String,
block_local: bool, block_local: bool,
block_non_http: bool, block_non_http: bool,
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> { ) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
let mut ws_stream = FragmentCollector::new(ws); let mut ws_stream = FragmentCollector::new(ws.await?);
println!("{:?}: connected (wsproxy): {:?}", addr, incoming_uri); println!("{:?}: connected (wsproxy): {:?}", addr, incoming_uri);