massive speed improvements

This commit is contained in:
Toshit Chawda 2024-07-05 16:03:55 -07:00
parent b22ff47f19
commit 4f0a362390
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
10 changed files with 282 additions and 89 deletions

31
Cargo.lock generated
View file

@ -143,17 +143,6 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "async_io_stream"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6d7b9decdf35d8908a7e3ef02f64c5e9b1695e230154c0e8de3969142d9b94c"
dependencies = [
"futures",
"rustc_version",
"tokio",
]
[[package]] [[package]]
name = "atomic-counter" name = "atomic-counter"
version = "1.0.1" version = "1.0.1"
@ -559,6 +548,7 @@ name = "epoxy-server"
version = "1.0.0" version = "1.0.0"
dependencies = [ dependencies = [
"bytes", "bytes",
"cfg-if",
"clap", "clap",
"clio", "clio",
"console-subscriber", "console-subscriber",
@ -1556,15 +1546,6 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc_version"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366"
dependencies = [
"semver",
]
[[package]] [[package]]
name = "rustix" name = "rustix"
version = "0.38.34" version = "0.38.34"
@ -1671,12 +1652,6 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "semver"
version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
[[package]] [[package]]
name = "send_wrapper" name = "send_wrapper"
version = "0.4.0" version = "0.4.0"
@ -1963,6 +1938,7 @@ checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1"
dependencies = [ dependencies = [
"bytes", "bytes",
"futures-core", "futures-core",
"futures-io",
"futures-sink", "futures-sink",
"pin-project-lite", "pin-project-lite",
"tokio", "tokio",
@ -2482,10 +2458,9 @@ checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
[[package]] [[package]]
name = "wisp-mux" name = "wisp-mux"
version = "4.0.1" version = "5.0.0"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"async_io_stream",
"bytes", "bytes",
"dashmap", "dashmap",
"event-listener", "event-listener",

View file

@ -32,7 +32,7 @@ wasm-bindgen = "0.2.92"
wasm-bindgen-futures = "0.4.42" wasm-bindgen-futures = "0.4.42"
wasm-streams = "0.4.0" wasm-streams = "0.4.0"
web-sys = { version = "0.3.69", features = ["BinaryType", "Headers", "MessageEvent", "Request", "RequestInit", "Response", "ResponseInit", "WebSocket"] } web-sys = { version = "0.3.69", features = ["BinaryType", "Headers", "MessageEvent", "Request", "RequestInit", "Response", "ResponseInit", "WebSocket"] }
wisp-mux = { version = "4.0.1", path = "../wisp", features = ["wasm"] } wisp-mux = { path = "../wisp", features = ["wasm"] }
[dependencies.ring] [dependencies.ring]
# update whenever rustls updates # update whenever rustls updates

View file

@ -1,4 +1,4 @@
use bytes::{buf::UninitSlice, BufMut, BytesMut}; use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut};
use futures_util::{ use futures_util::{
io::WriteHalf, lock::Mutex, stream::SplitSink, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt, io::WriteHalf, lock::Mutex, stream::SplitSink, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt,
}; };
@ -105,7 +105,7 @@ impl EpoxyIoStream {
#[wasm_bindgen] #[wasm_bindgen]
pub struct EpoxyUdpStream { pub struct EpoxyUdpStream {
tx: Mutex<SplitSink<ProviderUnencryptedStream, Vec<u8>>>, tx: Mutex<SplitSink<ProviderUnencryptedStream, Bytes>>,
onerror: Function, onerror: Function,
} }
@ -154,7 +154,7 @@ impl EpoxyUdpStream {
.map_err(|_| EpoxyError::InvalidPayload)? .map_err(|_| EpoxyError::InvalidPayload)?
.0 .0
.to_vec(); .to_vec();
Ok(self.tx.lock().await.send(payload).await?) Ok(self.tx.lock().await.send(payload.into()).await?)
} }
.await; .await;

View file

@ -17,8 +17,7 @@ use tower_service::Service;
use wasm_bindgen::{JsCast, JsValue}; use wasm_bindgen::{JsCast, JsValue};
use wasm_bindgen_futures::spawn_local; use wasm_bindgen_futures::spawn_local;
use wisp_mux::{ use wisp_mux::{
extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder}, extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder}, ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType
ClientMux, IoStream, MuxStreamIo, StreamType,
}; };
use crate::{ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError}; use crate::{ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError};
@ -50,7 +49,7 @@ pub struct StreamProvider {
} }
pub type ProviderUnencryptedStream = MuxStreamIo; pub type ProviderUnencryptedStream = MuxStreamIo;
pub type ProviderUnencryptedAsyncRW = IoStream<ProviderUnencryptedStream, Vec<u8>>; pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW;
pub type ProviderTlsAsyncRW = TlsStream<ProviderUnencryptedAsyncRW>; pub type ProviderTlsAsyncRW = TlsStream<ProviderUnencryptedAsyncRW>;
pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>; pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>;

View file

@ -5,6 +5,7 @@ edition = "2021"
[dependencies] [dependencies]
bytes = "1.5.0" bytes = "1.5.0"
cfg-if = "1.0.0"
clap = { version = "4.4.18", features = ["derive", "help", "usage", "color", "wrap_help", "cargo"] } clap = { version = "4.4.18", features = ["derive", "help", "usage", "color", "wrap_help", "cargo"] }
clio = { version = "0.3.5", features = ["clap-parse"] } clio = { version = "0.3.5", features = ["clap-parse"] }
console-subscriber = { version = "0.2.0", optional = true } console-subscriber = { version = "0.2.0", optional = true }
@ -15,8 +16,8 @@ http-body-util = "0.1.0"
hyper = { version = "1.1.0", features = ["server", "http1"] } hyper = { version = "1.1.0", features = ["server", "http1"] }
hyper-util = { version = "0.1.2", features = ["tokio"] } hyper-util = { version = "0.1.2", features = ["tokio"] }
tokio = { version = "1.5.1", features = ["rt-multi-thread", "macros"] } tokio = { version = "1.5.1", features = ["rt-multi-thread", "macros"] }
tokio-util = { version = "0.7.10", features = ["codec"] } tokio-util = { version = "0.7.10", features = ["codec", "compat"] }
wisp-mux = { path = "../wisp", features = ["fastwebsockets", "tokio_io"] } wisp-mux = { path = "../wisp", features = ["fastwebsockets"] }
[features] [features]
tokio-console = ["tokio/tracing", "dep:console-subscriber"] tokio-console = ["tokio/tracing", "dep:console-subscriber"]

View file

@ -2,6 +2,7 @@
use std::{collections::HashMap, io::Error, path::PathBuf, sync::Arc}; use std::{collections::HashMap, io::Error, path::PathBuf, sync::Arc};
use bytes::Bytes; use bytes::Bytes;
use cfg_if::cfg_if;
use clap::Parser; use clap::Parser;
use fastwebsockets::{ use fastwebsockets::{
upgrade::{self, UpgradeFut}, upgrade::{self, UpgradeFut},
@ -9,18 +10,23 @@ use fastwebsockets::{
}; };
use futures_util::{SinkExt, StreamExt, TryFutureExt}; use futures_util::{SinkExt, StreamExt, TryFutureExt};
use hyper::{ use hyper::{
body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode, body::Incoming, server::conn::http1, service::service_fn, upgrade::Parts, Request, Response,
StatusCode,
}; };
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
#[cfg(unix)] #[cfg(unix)]
use tokio::net::{UnixListener, UnixStream}; use tokio::net::{UnixListener, UnixStream};
use tokio::{ use tokio::{
io::copy_bidirectional, io::{copy, AsyncBufReadExt, AsyncWriteExt},
net::{lookup_host, TcpListener, TcpStream, UdpSocket}, net::{lookup_host, TcpListener, TcpStream, UdpSocket},
select,
}; };
use tokio_util::codec::{BytesCodec, Framed};
#[cfg(unix)] #[cfg(unix)]
use tokio_util::either::Either; use tokio_util::either::Either;
use tokio_util::{
codec::{BytesCodec, Framed},
compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt},
};
use wisp_mux::{ use wisp_mux::{
extensions::{ extensions::{
@ -28,7 +34,7 @@ use wisp_mux::{
udp::UdpProtocolExtensionBuilder, udp::UdpProtocolExtensionBuilder,
ProtocolExtensionBuilder, ProtocolExtensionBuilder,
}, },
CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRW, ServerMux, StreamType, WispError,
}; };
type HttpBody = http_body_util::Full<hyper::body::Bytes>; type HttpBody = http_body_util::Full<hyper::body::Bytes>;
@ -83,10 +89,13 @@ struct MuxOptions {
pub wisp_v1: bool, pub wisp_v1: bool,
} }
#[cfg(not(unix))] cfg_if! {
type ListenerStream = TcpStream; if #[cfg(unix)] {
#[cfg(unix)] type ListenerStream = Either<TcpStream, UnixStream>;
type ListenerStream = Either<TcpStream, UnixStream>; } else {
type ListenerStream = TcpStream;
}
}
enum Listener { enum Listener {
Tcp(TcpListener), Tcp(TcpListener),
@ -99,13 +108,12 @@ impl Listener {
Ok(match self { Ok(match self {
Listener::Tcp(listener) => { Listener::Tcp(listener) => {
let (stream, addr) = listener.accept().await?; let (stream, addr) = listener.accept().await?;
#[cfg(not(unix))] cfg_if! {
{ if #[cfg(unix)] {
(Either::Left(stream), addr.to_string())
} else {
(stream, addr.to_string()) (stream, addr.to_string())
} }
#[cfg(unix)]
{
(Either::Left(stream), addr.to_string())
} }
} }
#[cfg(unix)] #[cfg(unix)]
@ -123,7 +131,8 @@ impl Listener {
} }
async fn bind(addr: &str, unix: bool) -> Result<Listener, std::io::Error> { async fn bind(addr: &str, unix: bool) -> Result<Listener, std::io::Error> {
#[cfg(unix)] cfg_if! {
if #[cfg(unix)] {
if unix { if unix {
if std::fs::metadata(addr).is_ok() { if std::fs::metadata(addr).is_ok() {
println!("attempting to remove old socket {:?}", addr); println!("attempting to remove old socket {:?}", addr);
@ -131,10 +140,12 @@ async fn bind(addr: &str, unix: bool) -> Result<Listener, std::io::Error> {
} }
return Ok(Listener::Unix(UnixListener::bind(addr)?)); return Ok(Listener::Unix(UnixListener::bind(addr)?));
} }
#[cfg(not(unix))] } else {
if unix { if unix {
panic!("Unix sockets are only supported on Unix."); panic!("Unix sockets are only supported on Unix.");
} }
}
}
Ok(Listener::Tcp(TcpListener::bind(addr).await?)) Ok(Listener::Tcp(TcpListener::bind(addr).await?))
} }
@ -258,6 +269,38 @@ async fn accept_http(
} }
} }
async fn copy_buf(mux: MuxStreamAsyncRW, tcp: TcpStream) -> std::io::Result<()> {
let (muxrx, muxtx) = mux.into_split();
let mut muxrx = muxrx.compat();
let mut muxtx = muxtx.compat_write();
let (mut tcprx, mut tcptx) = tcp.into_split();
let fast_fut = async {
loop {
let buf = muxrx.fill_buf().await?;
if buf.is_empty() {
tcptx.flush().await?;
return Ok(());
}
let i = tcptx.write(buf).await?;
if i == 0 {
return Err(std::io::ErrorKind::WriteZero.into());
}
muxrx.consume(i);
}
};
let slow_fut = copy(&mut tcprx, &mut muxtx);
select! {
x = fast_fut => x,
x = slow_fut => x.map(|_| ()),
}
}
async fn handle_mux( async fn handle_mux(
packet: ConnectPacket, packet: ConnectPacket,
stream: MuxStream, stream: MuxStream,
@ -268,9 +311,9 @@ async fn handle_mux(
); );
match packet.stream_type { match packet.stream_type {
StreamType::Tcp => { StreamType::Tcp => {
let mut tcp_stream = TcpStream::connect(uri).await?; let tcp_stream = TcpStream::connect(uri).await?;
let mut mux_stream = stream.into_io().into_asyncrw(); let mux = stream.into_io().into_asyncrw();
copy_bidirectional(&mut mux_stream, &mut tcp_stream).await?; copy_buf(mux, tcp_stream).await?;
} }
StreamType::Udp => { StreamType::Udp => {
let uri = lookup_host(uri) let uri = lookup_host(uri)
@ -315,7 +358,31 @@ async fn accept_ws(
// to prevent memory ""leaks"" because users are sending in packets way too fast the message // to prevent memory ""leaks"" because users are sending in packets way too fast the message
// size is set to 1M // size is set to 1M
ws.set_max_message_size(1024 * 1024); ws.set_max_message_size(1024 * 1024);
let (rx, tx) = ws.split(tokio::io::split); let (rx, tx) = ws.split(|x| {
let Parts {
io, read_buf: buf, ..
} = x
.into_inner()
.downcast::<TokioIo<ListenerStream>>()
.unwrap();
assert_eq!(buf.len(), 0);
cfg_if! {
if #[cfg(unix)] {
match io.into_inner() {
Either::Left(x) => {
let (rx, tx) = x.into_split();
(Either::Left(rx), Either::Left(tx))
}
Either::Right(x) => {
let (rx, tx) = x.into_split();
(Either::Right(rx), Either::Right(tx))
}
}
} else {
io.into_inner().into_split()
}
}
});
let rx = FragmentCollectorRead::new(rx); let rx = FragmentCollectorRead::new(rx);
println!("{:?}: connected", addr); println!("{:?}: connected", addr);

View file

@ -1,6 +1,6 @@
[package] [package]
name = "wisp-mux" name = "wisp-mux"
version = "4.0.1" version = "5.0.0"
license = "LGPL-3.0-only" license = "LGPL-3.0-only"
description = "A library for easily creating Wisp servers and clients." description = "A library for easily creating Wisp servers and clients."
homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp" homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp"
@ -10,7 +10,6 @@ edition = "2021"
[dependencies] [dependencies]
async-trait = "0.1.79" async-trait = "0.1.79"
async_io_stream = "0.3.3"
bytes = "1.5.0" bytes = "1.5.0"
dashmap = { version = "5.5.3", features = ["inline"] } dashmap = { version = "5.5.3", features = ["inline"] }
event-listener = "5.0.0" event-listener = "5.0.0"
@ -23,7 +22,6 @@ tokio = { version = "1.35.1", optional = true, default-features = false }
[features] [features]
fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"]
tokio_io = ["async_io_stream/tokio_io"]
wasm = ["futures-timer/wasm-bindgen"] wasm = ["futures-timer/wasm-bindgen"]
[package.metadata.docs.rs] [package.metadata.docs.rs]

View file

@ -9,6 +9,15 @@ use tokio::io::{AsyncRead, AsyncWrite};
use crate::{ws::LockedWebSocketWrite, WispError}; use crate::{ws::LockedWebSocketWrite, WispError};
fn match_payload(payload: Payload) -> BytesMut {
match payload {
Payload::Bytes(x) => x,
Payload::Owned(x) => BytesMut::from(x.deref()),
Payload::BorrowedMut(x) => BytesMut::from(x.deref()),
Payload::Borrowed(x) => BytesMut::from(x),
}
}
impl From<OpCode> for crate::ws::OpCode { impl From<OpCode> for crate::ws::OpCode {
fn from(opcode: OpCode) -> Self { fn from(opcode: OpCode) -> Self {
use OpCode::*; use OpCode::*;
@ -30,7 +39,7 @@ impl From<Frame<'_>> for crate::ws::Frame {
Self { Self {
finished: frame.fin, finished: frame.fin,
opcode: frame.opcode.into(), opcode: frame.opcode.into(),
payload: BytesMut::from(frame.payload.deref()), payload: match_payload(frame.payload),
} }
} }
} }

View file

@ -240,7 +240,7 @@ impl Encode for InfoPacket {
bytes.put_u8(self.version.major); bytes.put_u8(self.version.major);
bytes.put_u8(self.version.minor); bytes.put_u8(self.version.minor);
for extension in self.extensions { for extension in self.extensions {
bytes.extend(Bytes::from(extension)); bytes.extend_from_slice(&Bytes::from(extension));
} }
} }
} }
@ -290,7 +290,7 @@ impl Encode for PacketType {
use PacketType as P; use PacketType as P;
match self { match self {
P::Connect(x) => x.encode(bytes), P::Connect(x) => x.encode(bytes),
P::Data(x) => bytes.extend(x), P::Data(x) => bytes.extend_from_slice(&x),
P::Continue(x) => x.encode(bytes), P::Continue(x) => x.encode(bytes),
P::Close(x) => x.encode(bytes), P::Close(x) => x.encode(bytes),
P::Info(x) => x.encode(bytes), P::Info(x) => x.encode(bytes),

View file

@ -4,15 +4,15 @@ use crate::{
CloseReason, Packet, Role, StreamType, WispError, CloseReason, Packet, Role, StreamType, WispError,
}; };
pub use async_io_stream::IoStream;
use bytes::{BufMut, Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use event_listener::Event; use event_listener::Event;
use flume as mpsc; use flume as mpsc;
use futures::{ use futures::{
channel::oneshot, channel::oneshot,
select, stream, select,
stream::{self, IntoAsyncRead, SplitSink, SplitStream},
task::{Context, Poll}, task::{Context, Poll},
FutureExt, Sink, Stream, AsyncBufRead, AsyncRead, AsyncWrite, FutureExt, Sink, Stream, StreamExt, TryStreamExt,
}; };
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use std::{ use std::{
@ -21,6 +21,7 @@ use std::{
atomic::{AtomicBool, AtomicU32, Ordering}, atomic::{AtomicBool, AtomicU32, Ordering},
Arc, Arc,
}, },
task::ready,
}; };
pub(crate) enum WsEvent { pub(crate) enum WsEvent {
@ -367,26 +368,24 @@ pin_project! {
} }
impl MuxStreamIo { impl MuxStreamIo {
/// Turn the stream into one that implements futures `AsyncRead + AsyncWrite`. /// Turn the stream into one that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`.
/// pub fn into_asyncrw(self) -> MuxStreamAsyncRW {
/// Enable the `tokio_io` feature to implement the tokio version of `AsyncRead` and let (tx, rx) = self.split();
/// `AsyncWrite`. MuxStreamAsyncRW {
pub fn into_asyncrw(self) -> IoStream<MuxStreamIo, Vec<u8>> { rx: MuxStreamAsyncRead::new(rx),
IoStream::new(self) tx: MuxStreamAsyncWrite::new(tx),
}
} }
} }
impl Stream for MuxStreamIo { impl Stream for MuxStreamIo {
type Item = Result<Vec<u8>, std::io::Error>; type Item = Result<Bytes, std::io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project() self.project().rx.poll_next(cx).map(|x| x.map(Ok))
.rx
.poll_next(cx)
.map(|x| x.map(|x| Ok(x.to_vec())))
} }
} }
impl Sink<Vec<u8>> for MuxStreamIo { impl Sink<Bytes> for MuxStreamIo {
type Error = std::io::Error; type Error = std::io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project() self.project()
@ -394,10 +393,10 @@ impl Sink<Vec<u8>> for MuxStreamIo {
.poll_ready(cx) .poll_ready(cx)
.map_err(std::io::Error::other) .map_err(std::io::Error::other)
} }
fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> { fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
self.project() self.project()
.tx .tx
.start_send(item.into()) .start_send(item)
.map_err(std::io::Error::other) .map_err(std::io::Error::other)
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -413,3 +412,148 @@ impl Sink<Vec<u8>> for MuxStreamIo {
.map_err(std::io::Error::other) .map_err(std::io::Error::other)
} }
} }
pin_project! {
/// Multiplexor stream that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`.
pub struct MuxStreamAsyncRW {
#[pin]
rx: MuxStreamAsyncRead,
#[pin]
tx: MuxStreamAsyncWrite,
}
}
impl MuxStreamAsyncRW {
/// Split the stream into read and write parts, consuming it.
pub fn into_split(self) -> (MuxStreamAsyncRead, MuxStreamAsyncWrite) {
(self.rx, self.tx)
}
}
impl AsyncRead for MuxStreamAsyncRW {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
self.project().rx.poll_read(cx, buf)
}
fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [std::io::IoSliceMut<'_>],
) -> Poll<std::io::Result<usize>> {
self.project().rx.poll_read_vectored(cx, bufs)
}
}
impl AsyncBufRead for MuxStreamAsyncRW {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
self.project().rx.poll_fill_buf(cx)
}
fn consume(self: Pin<&mut Self>, amt: usize) {
self.project().rx.consume(amt)
}
}
impl AsyncWrite for MuxStreamAsyncRW {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.project().tx.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.project().tx.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.project().tx.poll_close(cx)
}
}
pin_project! {
/// Read side of a multiplexor stream that implements futures `AsyncRead + AsyncBufRead`.
pub struct MuxStreamAsyncRead {
#[pin]
rx: IntoAsyncRead<SplitStream<MuxStreamIo>>,
}
}
impl MuxStreamAsyncRead {
pub(crate) fn new(stream: SplitStream<MuxStreamIo>) -> Self {
Self {
rx: stream.into_async_read(),
}
}
}
impl AsyncRead for MuxStreamAsyncRead {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
self.project().rx.poll_read(cx, buf)
}
fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [std::io::IoSliceMut<'_>],
) -> Poll<std::io::Result<usize>> {
self.project().rx.poll_read_vectored(cx, bufs)
}
}
impl AsyncBufRead for MuxStreamAsyncRead {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
self.project().rx.poll_fill_buf(cx)
}
fn consume(self: Pin<&mut Self>, amt: usize) {
self.project().rx.consume(amt)
}
}
pin_project! {
/// Write side of a multiplexor stream that implements futures `AsyncWrite`.
pub struct MuxStreamAsyncWrite {
#[pin]
tx: SplitSink<MuxStreamIo, Bytes>,
}
}
impl MuxStreamAsyncWrite {
pub(crate) fn new(sink: SplitSink<MuxStreamIo, Bytes>) -> Self {
Self { tx: sink }
}
}
impl AsyncWrite for MuxStreamAsyncWrite {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let mut this = self.project();
ready!(this.tx.as_mut().poll_ready(cx))?;
match this.tx.start_send(Bytes::copy_from_slice(buf)) {
Ok(()) => Poll::Ready(Ok(buf.len())),
Err(e) => Poll::Ready(Err(e)),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.project().tx.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.project().tx.poll_close(cx)
}
}