massive speed improvements

This commit is contained in:
Toshit Chawda 2024-07-05 16:03:55 -07:00
parent b22ff47f19
commit 4f0a362390
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
10 changed files with 282 additions and 89 deletions

View file

@ -2,6 +2,7 @@
use std::{collections::HashMap, io::Error, path::PathBuf, sync::Arc};
use bytes::Bytes;
use cfg_if::cfg_if;
use clap::Parser;
use fastwebsockets::{
upgrade::{self, UpgradeFut},
@ -9,18 +10,23 @@ use fastwebsockets::{
};
use futures_util::{SinkExt, StreamExt, TryFutureExt};
use hyper::{
body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode,
body::Incoming, server::conn::http1, service::service_fn, upgrade::Parts, Request, Response,
StatusCode,
};
use hyper_util::rt::TokioIo;
#[cfg(unix)]
use tokio::net::{UnixListener, UnixStream};
use tokio::{
io::copy_bidirectional,
io::{copy, AsyncBufReadExt, AsyncWriteExt},
net::{lookup_host, TcpListener, TcpStream, UdpSocket},
select,
};
use tokio_util::codec::{BytesCodec, Framed};
#[cfg(unix)]
use tokio_util::either::Either;
use tokio_util::{
codec::{BytesCodec, Framed},
compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt},
};
use wisp_mux::{
extensions::{
@ -28,7 +34,7 @@ use wisp_mux::{
udp::UdpProtocolExtensionBuilder,
ProtocolExtensionBuilder,
},
CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError,
CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRW, ServerMux, StreamType, WispError,
};
type HttpBody = http_body_util::Full<hyper::body::Bytes>;
@ -83,10 +89,13 @@ struct MuxOptions {
pub wisp_v1: bool,
}
#[cfg(not(unix))]
type ListenerStream = TcpStream;
#[cfg(unix)]
type ListenerStream = Either<TcpStream, UnixStream>;
cfg_if! {
if #[cfg(unix)] {
type ListenerStream = Either<TcpStream, UnixStream>;
} else {
type ListenerStream = TcpStream;
}
}
enum Listener {
Tcp(TcpListener),
@ -99,13 +108,12 @@ impl Listener {
Ok(match self {
Listener::Tcp(listener) => {
let (stream, addr) = listener.accept().await?;
#[cfg(not(unix))]
{
(stream, addr.to_string())
}
#[cfg(unix)]
{
(Either::Left(stream), addr.to_string())
cfg_if! {
if #[cfg(unix)] {
(Either::Left(stream), addr.to_string())
} else {
(stream, addr.to_string())
}
}
}
#[cfg(unix)]
@ -123,17 +131,20 @@ impl Listener {
}
async fn bind(addr: &str, unix: bool) -> Result<Listener, std::io::Error> {
#[cfg(unix)]
if unix {
if std::fs::metadata(addr).is_ok() {
println!("attempting to remove old socket {:?}", addr);
std::fs::remove_file(addr)?;
cfg_if! {
if #[cfg(unix)] {
if unix {
if std::fs::metadata(addr).is_ok() {
println!("attempting to remove old socket {:?}", addr);
std::fs::remove_file(addr)?;
}
return Ok(Listener::Unix(UnixListener::bind(addr)?));
}
} else {
if unix {
panic!("Unix sockets are only supported on Unix.");
}
}
return Ok(Listener::Unix(UnixListener::bind(addr)?));
}
#[cfg(not(unix))]
if unix {
panic!("Unix sockets are only supported on Unix.");
}
Ok(Listener::Tcp(TcpListener::bind(addr).await?))
@ -258,6 +269,38 @@ async fn accept_http(
}
}
async fn copy_buf(mux: MuxStreamAsyncRW, tcp: TcpStream) -> std::io::Result<()> {
let (muxrx, muxtx) = mux.into_split();
let mut muxrx = muxrx.compat();
let mut muxtx = muxtx.compat_write();
let (mut tcprx, mut tcptx) = tcp.into_split();
let fast_fut = async {
loop {
let buf = muxrx.fill_buf().await?;
if buf.is_empty() {
tcptx.flush().await?;
return Ok(());
}
let i = tcptx.write(buf).await?;
if i == 0 {
return Err(std::io::ErrorKind::WriteZero.into());
}
muxrx.consume(i);
}
};
let slow_fut = copy(&mut tcprx, &mut muxtx);
select! {
x = fast_fut => x,
x = slow_fut => x.map(|_| ()),
}
}
async fn handle_mux(
packet: ConnectPacket,
stream: MuxStream,
@ -268,9 +311,9 @@ async fn handle_mux(
);
match packet.stream_type {
StreamType::Tcp => {
let mut tcp_stream = TcpStream::connect(uri).await?;
let mut mux_stream = stream.into_io().into_asyncrw();
copy_bidirectional(&mut mux_stream, &mut tcp_stream).await?;
let tcp_stream = TcpStream::connect(uri).await?;
let mux = stream.into_io().into_asyncrw();
copy_buf(mux, tcp_stream).await?;
}
StreamType::Udp => {
let uri = lookup_host(uri)
@ -315,7 +358,31 @@ async fn accept_ws(
// to prevent memory ""leaks"" because users are sending in packets way too fast the message
// size is set to 1M
ws.set_max_message_size(1024 * 1024);
let (rx, tx) = ws.split(tokio::io::split);
let (rx, tx) = ws.split(|x| {
let Parts {
io, read_buf: buf, ..
} = x
.into_inner()
.downcast::<TokioIo<ListenerStream>>()
.unwrap();
assert_eq!(buf.len(), 0);
cfg_if! {
if #[cfg(unix)] {
match io.into_inner() {
Either::Left(x) => {
let (rx, tx) = x.into_split();
(Either::Left(rx), Either::Left(tx))
}
Either::Right(x) => {
let (rx, tx) = x.into_split();
(Either::Right(rx), Either::Right(tx))
}
}
} else {
io.into_inner().into_split()
}
}
});
let rx = FragmentCollectorRead::new(rx);
println!("{:?}: connected", addr);