From 80c91c9381bbeb7c6697ce0d5c2945d0a1aff4ec Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sun, 1 Sep 2024 18:48:46 -0700 Subject: [PATCH] add pty listener support --- Cargo.lock | 20 +++- server/Cargo.toml | 4 +- server/src/config.rs | 6 + server/src/listener.rs | 258 +++++++++++++++++++++++++++++++++++++++-- server/src/main.rs | 2 +- 5 files changed, 274 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1873e41..d10e6e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -321,6 +321,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "clap" version = "4.5.16" @@ -567,9 +573,9 @@ dependencies = [ "lazy_static", "libc", "log", + "nix", "pty-process", "regex", - "rustix", "serde", "serde_json", "serde_yaml", @@ -1145,6 +1151,18 @@ dependencies = [ "getrandom", ] +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags", + "cfg-if", + "cfg_aliases", + "libc", +] + [[package]] name = "nohash-hasher" version = "0.2.0" diff --git a/server/Cargo.toml b/server/Cargo.toml index 76dfd83..0d7fb3e 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -20,9 +20,9 @@ hyper-util = { version = "0.1.7", features = ["tokio"] } lazy_static = "1.5.0" libc = { version = "0.2.158", optional = true } log = { version = "0.4.22", features = ["serde", "std"] } +nix = { version = "0.29.0", features = ["term"] } pty-process = { version = "0.4.0", features = ["async", "tokio"], optional = true } regex = "1.10.6" -rustix = { version = "0.38.34", optional = true } serde = { version = "1.0.208", features = ["derive"] } serde_json = { version = "1.0.125", optional = true } serde_yaml = { version = "0.9.34", optional = true } @@ -42,4 +42,4 @@ json = ["dep:serde_json"] yaml = ["dep:serde_yaml"] toml = ["dep:toml"] -twisp = ["dep:pty-process", "dep:libc", "dep:rustix", "dep:async-trait", "dep:shell-words"] +twisp = ["dep:pty-process", "dep:libc", "dep:async-trait", "dep:shell-words"] diff --git a/server/src/config.rs b/server/src/config.rs index 6f4eb92..5ed11d3 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -20,6 +20,9 @@ pub enum SocketType { Tcp, /// Unix socket listener. Unix, + /// File "socket" "listener". + /// "Accepts" a "connection" immediately. + File, } #[derive(Serialize, Deserialize, Default, Debug)] @@ -47,6 +50,8 @@ pub struct ServerConfig { pub resolve_ipv6: bool, /// Whether or not to enable TCP nodelay on client TCP streams. pub tcp_nodelay: bool, + /// Whether or not to set "raw mode" for the file. + pub file_raw_mode: bool, /// Whether or not to show what upstreams each client is connected to in stats. This can /// heavily increase the size of the stats. @@ -206,6 +211,7 @@ impl Default for ServerConfig { transport: SocketTransport::default(), resolve_ipv6: false, tcp_nodelay: false, + file_raw_mode: false, verbose_stats: true, stats_endpoint: "/stats".to_string(), diff --git a/server/src/listener.rs b/server/src/listener.rs index 9d48d73..6529455 100644 --- a/server/src/listener.rs +++ b/server/src/listener.rs @@ -1,16 +1,190 @@ +use std::{os::fd::AsFd, path::PathBuf, pin::Pin}; + use anyhow::Context; use tokio::{ - fs::{remove_file, try_exists}, + fs::{remove_file, try_exists, File}, + io::{AsyncBufRead, AsyncRead, AsyncWrite}, net::{tcp, unix, TcpListener, TcpStream, UnixListener, UnixStream}, }; -use tokio_util::either::Either; use uuid::Uuid; use crate::{config::SocketType, CONFIG}; -pub type ServerStream = Either; -pub type ServerStreamRead = Either; -pub type ServerStreamWrite = Either; +pub enum Trio { + One(A), + Two(B), + Three(C), +} + +impl AsyncRead for Trio { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + match self.get_mut() { + Self::One(x) => Pin::new(x).poll_read(cx, buf), + Self::Two(x) => Pin::new(x).poll_read(cx, buf), + Self::Three(x) => Pin::new(x).poll_read(cx, buf), + } + } +} + +impl AsyncBufRead + for Trio +{ + fn poll_fill_buf( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.get_mut() { + Self::One(x) => Pin::new(x).poll_fill_buf(cx), + Self::Two(x) => Pin::new(x).poll_fill_buf(cx), + Self::Three(x) => Pin::new(x).poll_fill_buf(cx), + } + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + match self.get_mut() { + Self::One(x) => Pin::new(x).consume(amt), + Self::Two(x) => Pin::new(x).consume(amt), + Self::Three(x) => Pin::new(x).consume(amt), + } + } +} + +impl AsyncWrite + for Trio +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + match self.get_mut() { + Self::One(x) => Pin::new(x).poll_write(cx, buf), + Self::Two(x) => Pin::new(x).poll_write(cx, buf), + Self::Three(x) => Pin::new(x).poll_write(cx, buf), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + Self::One(x) => x.is_write_vectored(), + Self::Two(x) => x.is_write_vectored(), + Self::Three(x) => x.is_write_vectored(), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> std::task::Poll> { + match self.get_mut() { + Self::One(x) => Pin::new(x).poll_write_vectored(cx, bufs), + Self::Two(x) => Pin::new(x).poll_write_vectored(cx, bufs), + Self::Three(x) => Pin::new(x).poll_write_vectored(cx, bufs), + } + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.get_mut() { + Self::One(x) => Pin::new(x).poll_flush(cx), + Self::Two(x) => Pin::new(x).poll_flush(cx), + Self::Three(x) => Pin::new(x).poll_flush(cx), + } + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.get_mut() { + Self::One(x) => Pin::new(x).poll_shutdown(cx), + Self::Two(x) => Pin::new(x).poll_shutdown(cx), + Self::Three(x) => Pin::new(x).poll_shutdown(cx), + } + } +} + +pub struct Duplex(A, B); + +impl Duplex { + pub fn new(a: A, b: B) -> Self { + Self(a, b) + } + + pub fn into_split(self) -> (A, B) { + (self.0, self.1) + } +} + +impl AsyncRead for Duplex { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.get_mut().0).poll_read(cx, buf) + } +} + +impl AsyncBufRead for Duplex { + fn poll_fill_buf( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.get_mut().0).poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + Pin::new(&mut self.get_mut().0).consume(amt) + } +} + +impl AsyncWrite for Duplex { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + Pin::new(&mut self.get_mut().1).poll_write(cx, buf) + } + + fn is_write_vectored(&self) -> bool { + self.1.is_write_vectored() + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> std::task::Poll> { + Pin::new(&mut self.get_mut().1).poll_write_vectored(cx, bufs) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.get_mut().1).poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.get_mut().1).poll_shutdown(cx) + } +} + +pub type ServerStream = Trio>; +pub type ServerStreamRead = Trio; +pub type ServerStreamWrite = Trio; pub trait ServerStreamExt { fn split(self) -> (ServerStreamRead, ServerStreamWrite); @@ -19,13 +193,17 @@ pub trait ServerStreamExt { impl ServerStreamExt for ServerStream { fn split(self) -> (ServerStreamRead, ServerStreamWrite) { match self { - Self::Left(x) => { + Self::One(x) => { let (r, w) = x.into_split(); - (Either::Left(r), Either::Left(w)) + (Trio::One(r), Trio::One(w)) } - Self::Right(x) => { + Self::Two(x) => { let (r, w) = x.into_split(); - (Either::Right(r), Either::Right(w)) + (Trio::Two(r), Trio::Two(w)) + } + Self::Three(x) => { + let (r, w) = x.into_split(); + (Trio::Three(r), Trio::Three(w)) } } } @@ -34,6 +212,7 @@ impl ServerStreamExt for ServerStream { pub enum ServerListener { Tcp(TcpListener), Unix(UnixListener), + File(Option), } impl ServerListener { @@ -54,10 +233,15 @@ impl ServerListener { format!("failed to bind to unix socket at `{}`", CONFIG.server.bind) })?) } + SocketType::File => { + Self::File(Some(PathBuf::try_from(&CONFIG.server.bind).with_context( + || format!("failed to parse path `{}` for file", CONFIG.server.bind), + )?)) + } }) } - pub async fn accept(&self) -> anyhow::Result<(ServerStream, String)> { + pub async fn accept(&mut self) -> anyhow::Result<(ServerStream, String)> { match self { Self::Tcp(x) => { let (stream, addr) = x @@ -69,14 +253,14 @@ impl ServerListener { .set_nodelay(true) .context("failed to set tcp nodelay")?; } - Ok((Either::Left(stream), addr.to_string())) + Ok((Trio::One(stream), addr.to_string())) } Self::Unix(x) => x .accept() .await .map(|(x, y)| { ( - Either::Right(x), + Trio::Two(x), y.as_pathname() .and_then(|x| x.to_str()) .map(ToString::to_string) @@ -84,6 +268,56 @@ impl ServerListener { ) }) .context("failed to accept unix socket connection"), + Self::File(path) => { + if let Some(path) = path.take() { + let rx = File::options() + .read(true) + .write(false) + .open(&path) + .await + .context("failed to open read file")?; + + if CONFIG.server.file_raw_mode { + let mut termios = nix::sys::termios::tcgetattr(rx.as_fd()) + .context("failed to get termios for read file")? + .clone(); + nix::sys::termios::cfmakeraw(&mut termios); + nix::sys::termios::tcsetattr( + rx.as_fd(), + nix::sys::termios::SetArg::TCSANOW, + &termios, + ) + .context("failed to set raw mode for read file")?; + } + + let tx = File::options() + .read(false) + .write(true) + .open(&path) + .await + .context("failed to open write file")?; + + if CONFIG.server.file_raw_mode { + let mut termios = nix::sys::termios::tcgetattr(tx.as_fd()) + .context("failed to get termios for write file")? + .clone(); + nix::sys::termios::cfmakeraw(&mut termios); + nix::sys::termios::tcsetattr( + tx.as_fd(), + nix::sys::termios::SetArg::TCSANOW, + &termios, + ) + .context("failed to set raw mode for write file")?; + } + + Ok(( + Trio::Three(Duplex::new(rx, tx)), + path.to_string_lossy().to_string(), + )) + } else { + std::future::pending().await + } + } } } } diff --git a/server/src/main.rs b/server/src/main.rs index 321f3f3..91442e0 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -175,7 +175,7 @@ async fn main() -> anyhow::Result<()> { } }); - let listener = ServerListener::new().await?; + let mut listener = ServerListener::new().await?; loop { let (stream, id) = listener.accept().await?; tokio::spawn(async move {