use knockoff dynosaur to remove async_trait on wsr/wsw

This commit is contained in:
Toshit Chawda 2024-11-23 15:00:12 -08:00
parent 5e54465e58
commit 9129d767f8
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
31 changed files with 692 additions and 258 deletions

View file

@ -24,6 +24,7 @@ lazy_static = "1.5.0"
libc = { version = "0.2.158", optional = true }
log = { version = "0.4.22", features = ["serde", "std"] }
nix = { version = "0.29.0", features = ["term"] }
pin-project-lite = "0.2.15"
pty-process = { version = "0.4.0", features = ["async", "tokio"], optional = true }
regex = "1.10.6"
rustls-pemfile = "2.1.3"

View file

@ -25,7 +25,7 @@ use wisp_mux::{
};
use crate::{
route::WispResult,
route::{WispResult, WispStreamWrite},
stream::{ClientStream, ResolvedPacket},
CLIENTS, CONFIG,
};
@ -58,7 +58,7 @@ async fn copy_read_fast(
}
async fn copy_write_fast(
muxtx: MuxStreamWrite,
muxtx: MuxStreamWrite<WispStreamWrite>,
tcprx: OwnedReadHalf,
#[cfg(feature = "speed-limit")] limiter: async_speed_limit::Limiter<
async_speed_limit::clock::StandardClock,
@ -83,7 +83,7 @@ async fn copy_write_fast(
async fn handle_stream(
connect: ConnectPacket,
muxstream: MuxStream,
muxstream: MuxStream<WispStreamWrite>,
id: String,
event: Arc<Event>,
#[cfg(feature = "twisp")] twisp_map: twisp::TwispMap,

View file

@ -37,6 +37,8 @@ mod route;
mod stats;
#[doc(hidden)]
mod stream;
#[doc(hidden)]
mod util_chain;
#[doc(hidden)]
type Client = (DashMap<Uuid, (ConnectPacket, ConnectPacket)>, bool);

View file

@ -2,7 +2,7 @@ use std::{fmt::Display, future::Future, io::Cursor};
use anyhow::Context;
use bytes::Bytes;
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector};
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector, WebSocketRead, WebSocketWrite};
use http_body_util::Full;
use hyper::{
body::Incoming, header::SEC_WEBSOCKET_PROTOCOL, server::conn::http1::Builder,
@ -10,25 +10,30 @@ use hyper::{
};
use hyper_util::rt::TokioIo;
use log::{debug, error, trace};
use tokio::io::AsyncReadExt;
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
use wisp_mux::{
generic::{GenericWebSocketRead, GenericWebSocketWrite},
ws::{WebSocketRead, WebSocketWrite},
ws::{EitherWebSocketRead, EitherWebSocketWrite},
};
use crate::{
config::SocketTransport,
generate_stats,
listener::{ServerStream, ServerStreamExt},
listener::{ServerStream, ServerStreamExt, ServerStreamRead, ServerStreamWrite},
stream::WebSocketStreamWrapper,
util_chain::{chain, Chain},
CONFIG,
};
pub type WispResult = (
Box<dyn WebSocketRead + Send>,
Box<dyn WebSocketWrite + Send>,
);
pub type WispStreamRead = EitherWebSocketRead<
WebSocketRead<Chain<Cursor<Bytes>, ServerStreamRead>>,
GenericWebSocketRead<FramedRead<ServerStreamRead, LengthDelimitedCodec>, std::io::Error>,
>;
pub type WispStreamWrite = EitherWebSocketWrite<
WebSocketWrite<ServerStreamWrite>,
GenericWebSocketWrite<FramedWrite<ServerStreamWrite, LengthDelimitedCodec>, std::io::Error>,
>;
pub type WispResult = (WispStreamRead, WispStreamWrite);
pub enum ServerRouteResult {
Wisp(WispResult, bool),
@ -190,12 +195,15 @@ pub async fn route(
.downcast::<TokioIo<ServerStream>>()
.unwrap();
let (r, w) = parts.io.into_inner().split();
(Cursor::new(parts.read_buf).chain(r), w)
(chain(Cursor::new(parts.read_buf), r), w)
});
(callback)(
ServerRouteResult::Wisp(
(Box::new(read), Box::new(write)),
(
EitherWebSocketRead::Left(read),
EitherWebSocketWrite::Left(write),
),
is_v2,
),
maybe_ip,
@ -229,7 +237,13 @@ pub async fn route(
let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec));
(callback)(
ServerRouteResult::Wisp((Box::new(read), Box::new(write)), true),
ServerRouteResult::Wisp(
(
EitherWebSocketRead::Right(read),
EitherWebSocketWrite::Right(write),
),
true,
),
None,
);
}

View file

@ -44,7 +44,6 @@ pub enum ClientStream {
Invalid,
}
// taken from rust 1.82.0
fn ipv4_is_global(addr: &Ipv4Addr) -> bool {
!(addr.octets()[0] == 0 // "This network"

100
server/src/util_chain.rs Normal file
View file

@ -0,0 +1,100 @@
// taken from tokio io util
use std::{
fmt, io,
pin::Pin,
task::{Context, Poll},
};
use futures_util::ready;
use pin_project_lite::pin_project;
use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
pin_project! {
pub struct Chain<T, U> {
#[pin]
first: T,
#[pin]
second: U,
done_first: bool,
}
}
pub fn chain<T, U>(first: T, second: U) -> Chain<T, U>
where
T: AsyncRead,
U: AsyncRead,
{
Chain {
first,
second,
done_first: false,
}
}
impl<T, U> fmt::Debug for Chain<T, U>
where
T: fmt::Debug,
U: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Chain")
.field("t", &self.first)
.field("u", &self.second)
.finish()
}
}
impl<T, U> AsyncRead for Chain<T, U>
where
T: AsyncRead,
U: AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let me = self.project();
if !*me.done_first {
let rem = buf.remaining();
ready!(me.first.poll_read(cx, buf))?;
if buf.remaining() == rem {
*me.done_first = true;
} else {
return Poll::Ready(Ok(()));
}
}
me.second.poll_read(cx, buf)
}
}
impl<T, U> AsyncBufRead for Chain<T, U>
where
T: AsyncBufRead,
U: AsyncBufRead,
{
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
let me = self.project();
if !*me.done_first {
match ready!(me.first.poll_fill_buf(cx)?) {
[] => {
*me.done_first = true;
}
buf => return Poll::Ready(Ok(buf)),
}
}
me.second.poll_fill_buf(cx)
}
fn consume(self: Pin<&mut Self>, amt: usize) {
let me = self.project();
if !*me.done_first {
me.first.consume(amt)
} else {
me.second.consume(amt)
}
}
}