add pty listener support

This commit is contained in:
Toshit Chawda 2024-09-01 18:48:46 -07:00
parent 67c9e3d982
commit 80c91c9381
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 274 additions and 16 deletions

20
Cargo.lock generated
View file

@ -321,6 +321,12 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "clap"
version = "4.5.16"
@ -567,9 +573,9 @@ dependencies = [
"lazy_static",
"libc",
"log",
"nix",
"pty-process",
"regex",
"rustix",
"serde",
"serde_json",
"serde_yaml",
@ -1145,6 +1151,18 @@ dependencies = [
"getrandom",
]
[[package]]
name = "nix"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46"
dependencies = [
"bitflags",
"cfg-if",
"cfg_aliases",
"libc",
]
[[package]]
name = "nohash-hasher"
version = "0.2.0"

View file

@ -20,9 +20,9 @@ hyper-util = { version = "0.1.7", features = ["tokio"] }
lazy_static = "1.5.0"
libc = { version = "0.2.158", optional = true }
log = { version = "0.4.22", features = ["serde", "std"] }
nix = { version = "0.29.0", features = ["term"] }
pty-process = { version = "0.4.0", features = ["async", "tokio"], optional = true }
regex = "1.10.6"
rustix = { version = "0.38.34", optional = true }
serde = { version = "1.0.208", features = ["derive"] }
serde_json = { version = "1.0.125", optional = true }
serde_yaml = { version = "0.9.34", optional = true }
@ -42,4 +42,4 @@ json = ["dep:serde_json"]
yaml = ["dep:serde_yaml"]
toml = ["dep:toml"]
twisp = ["dep:pty-process", "dep:libc", "dep:rustix", "dep:async-trait", "dep:shell-words"]
twisp = ["dep:pty-process", "dep:libc", "dep:async-trait", "dep:shell-words"]

View file

@ -20,6 +20,9 @@ pub enum SocketType {
Tcp,
/// Unix socket listener.
Unix,
/// File "socket" "listener".
/// "Accepts" a "connection" immediately.
File,
}
#[derive(Serialize, Deserialize, Default, Debug)]
@ -47,6 +50,8 @@ pub struct ServerConfig {
pub resolve_ipv6: bool,
/// Whether or not to enable TCP nodelay on client TCP streams.
pub tcp_nodelay: bool,
/// Whether or not to set "raw mode" for the file.
pub file_raw_mode: bool,
/// Whether or not to show what upstreams each client is connected to in stats. This can
/// heavily increase the size of the stats.
@ -206,6 +211,7 @@ impl Default for ServerConfig {
transport: SocketTransport::default(),
resolve_ipv6: false,
tcp_nodelay: false,
file_raw_mode: false,
verbose_stats: true,
stats_endpoint: "/stats".to_string(),

View file

@ -1,16 +1,190 @@
use std::{os::fd::AsFd, path::PathBuf, pin::Pin};
use anyhow::Context;
use tokio::{
fs::{remove_file, try_exists},
fs::{remove_file, try_exists, File},
io::{AsyncBufRead, AsyncRead, AsyncWrite},
net::{tcp, unix, TcpListener, TcpStream, UnixListener, UnixStream},
};
use tokio_util::either::Either;
use uuid::Uuid;
use crate::{config::SocketType, CONFIG};
pub type ServerStream = Either<TcpStream, UnixStream>;
pub type ServerStreamRead = Either<tcp::OwnedReadHalf, unix::OwnedReadHalf>;
pub type ServerStreamWrite = Either<tcp::OwnedWriteHalf, unix::OwnedWriteHalf>;
pub enum Trio<A, B, C> {
One(A),
Two(B),
Three(C),
}
impl<A: AsyncRead + Unpin, B: AsyncRead + Unpin, C: AsyncRead + Unpin> AsyncRead for Trio<A, B, C> {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match self.get_mut() {
Self::One(x) => Pin::new(x).poll_read(cx, buf),
Self::Two(x) => Pin::new(x).poll_read(cx, buf),
Self::Three(x) => Pin::new(x).poll_read(cx, buf),
}
}
}
impl<A: AsyncBufRead + Unpin, B: AsyncBufRead + Unpin, C: AsyncBufRead + Unpin> AsyncBufRead
for Trio<A, B, C>
{
fn poll_fill_buf(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<&[u8]>> {
match self.get_mut() {
Self::One(x) => Pin::new(x).poll_fill_buf(cx),
Self::Two(x) => Pin::new(x).poll_fill_buf(cx),
Self::Three(x) => Pin::new(x).poll_fill_buf(cx),
}
}
fn consume(self: Pin<&mut Self>, amt: usize) {
match self.get_mut() {
Self::One(x) => Pin::new(x).consume(amt),
Self::Two(x) => Pin::new(x).consume(amt),
Self::Three(x) => Pin::new(x).consume(amt),
}
}
}
impl<A: AsyncWrite + Unpin, B: AsyncWrite + Unpin, C: AsyncWrite + Unpin> AsyncWrite
for Trio<A, B, C>
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
match self.get_mut() {
Self::One(x) => Pin::new(x).poll_write(cx, buf),
Self::Two(x) => Pin::new(x).poll_write(cx, buf),
Self::Three(x) => Pin::new(x).poll_write(cx, buf),
}
}
fn is_write_vectored(&self) -> bool {
match self {
Self::One(x) => x.is_write_vectored(),
Self::Two(x) => x.is_write_vectored(),
Self::Three(x) => x.is_write_vectored(),
}
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> std::task::Poll<Result<usize, std::io::Error>> {
match self.get_mut() {
Self::One(x) => Pin::new(x).poll_write_vectored(cx, bufs),
Self::Two(x) => Pin::new(x).poll_write_vectored(cx, bufs),
Self::Three(x) => Pin::new(x).poll_write_vectored(cx, bufs),
}
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match self.get_mut() {
Self::One(x) => Pin::new(x).poll_flush(cx),
Self::Two(x) => Pin::new(x).poll_flush(cx),
Self::Three(x) => Pin::new(x).poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match self.get_mut() {
Self::One(x) => Pin::new(x).poll_shutdown(cx),
Self::Two(x) => Pin::new(x).poll_shutdown(cx),
Self::Three(x) => Pin::new(x).poll_shutdown(cx),
}
}
}
pub struct Duplex<A, B>(A, B);
impl<A, B> Duplex<A, B> {
pub fn new(a: A, b: B) -> Self {
Self(a, b)
}
pub fn into_split(self) -> (A, B) {
(self.0, self.1)
}
}
impl<A: AsyncRead + Unpin, B: Unpin> AsyncRead for Duplex<A, B> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
}
}
impl<A: AsyncBufRead + Unpin, B: Unpin> AsyncBufRead for Duplex<A, B> {
fn poll_fill_buf(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<&[u8]>> {
Pin::new(&mut self.get_mut().0).poll_fill_buf(cx)
}
fn consume(self: Pin<&mut Self>, amt: usize) {
Pin::new(&mut self.get_mut().0).consume(amt)
}
}
impl<A: Unpin, B: AsyncWrite + Unpin> AsyncWrite for Duplex<A, B> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
Pin::new(&mut self.get_mut().1).poll_write(cx, buf)
}
fn is_write_vectored(&self) -> bool {
self.1.is_write_vectored()
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> std::task::Poll<Result<usize, std::io::Error>> {
Pin::new(&mut self.get_mut().1).poll_write_vectored(cx, bufs)
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.get_mut().1).poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.get_mut().1).poll_shutdown(cx)
}
}
pub type ServerStream = Trio<TcpStream, UnixStream, Duplex<File, File>>;
pub type ServerStreamRead = Trio<tcp::OwnedReadHalf, unix::OwnedReadHalf, File>;
pub type ServerStreamWrite = Trio<tcp::OwnedWriteHalf, unix::OwnedWriteHalf, File>;
pub trait ServerStreamExt {
fn split(self) -> (ServerStreamRead, ServerStreamWrite);
@ -19,13 +193,17 @@ pub trait ServerStreamExt {
impl ServerStreamExt for ServerStream {
fn split(self) -> (ServerStreamRead, ServerStreamWrite) {
match self {
Self::Left(x) => {
Self::One(x) => {
let (r, w) = x.into_split();
(Either::Left(r), Either::Left(w))
(Trio::One(r), Trio::One(w))
}
Self::Right(x) => {
Self::Two(x) => {
let (r, w) = x.into_split();
(Either::Right(r), Either::Right(w))
(Trio::Two(r), Trio::Two(w))
}
Self::Three(x) => {
let (r, w) = x.into_split();
(Trio::Three(r), Trio::Three(w))
}
}
}
@ -34,6 +212,7 @@ impl ServerStreamExt for ServerStream {
pub enum ServerListener {
Tcp(TcpListener),
Unix(UnixListener),
File(Option<PathBuf>),
}
impl ServerListener {
@ -54,10 +233,15 @@ impl ServerListener {
format!("failed to bind to unix socket at `{}`", CONFIG.server.bind)
})?)
}
SocketType::File => {
Self::File(Some(PathBuf::try_from(&CONFIG.server.bind).with_context(
|| format!("failed to parse path `{}` for file", CONFIG.server.bind),
)?))
}
})
}
pub async fn accept(&self) -> anyhow::Result<(ServerStream, String)> {
pub async fn accept(&mut self) -> anyhow::Result<(ServerStream, String)> {
match self {
Self::Tcp(x) => {
let (stream, addr) = x
@ -69,14 +253,14 @@ impl ServerListener {
.set_nodelay(true)
.context("failed to set tcp nodelay")?;
}
Ok((Either::Left(stream), addr.to_string()))
Ok((Trio::One(stream), addr.to_string()))
}
Self::Unix(x) => x
.accept()
.await
.map(|(x, y)| {
(
Either::Right(x),
Trio::Two(x),
y.as_pathname()
.and_then(|x| x.to_str())
.map(ToString::to_string)
@ -84,6 +268,56 @@ impl ServerListener {
)
})
.context("failed to accept unix socket connection"),
Self::File(path) => {
if let Some(path) = path.take() {
let rx = File::options()
.read(true)
.write(false)
.open(&path)
.await
.context("failed to open read file")?;
if CONFIG.server.file_raw_mode {
let mut termios = nix::sys::termios::tcgetattr(rx.as_fd())
.context("failed to get termios for read file")?
.clone();
nix::sys::termios::cfmakeraw(&mut termios);
nix::sys::termios::tcsetattr(
rx.as_fd(),
nix::sys::termios::SetArg::TCSANOW,
&termios,
)
.context("failed to set raw mode for read file")?;
}
let tx = File::options()
.read(false)
.write(true)
.open(&path)
.await
.context("failed to open write file")?;
if CONFIG.server.file_raw_mode {
let mut termios = nix::sys::termios::tcgetattr(tx.as_fd())
.context("failed to get termios for write file")?
.clone();
nix::sys::termios::cfmakeraw(&mut termios);
nix::sys::termios::tcsetattr(
tx.as_fd(),
nix::sys::termios::SetArg::TCSANOW,
&termios,
)
.context("failed to set raw mode for write file")?;
}
Ok((
Trio::Three(Duplex::new(rx, tx)),
path.to_string_lossy().to_string(),
))
} else {
std::future::pending().await
}
}
}
}
}

View file

@ -175,7 +175,7 @@ async fn main() -> anyhow::Result<()> {
}
});
let listener = ServerListener::new().await?;
let mut listener = ServerListener::new().await?;
loop {
let (stream, id) = listener.accept().await?;
tokio::spawn(async move {