From 3b14ae0d91dc63338d274abfe134dde9ed302a4f Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Fri, 23 Aug 2024 21:33:37 -0700 Subject: [PATCH] add twisp to epoxy-server --- Cargo.lock | 53 +++++++++++++ server/Cargo.toml | 9 +++ server/src/config.rs | 25 +++--- server/src/handle/mod.rs | 2 + server/src/handle/twisp.rs | 148 +++++++++++++++++++++++++++++++++++ server/src/handle/wisp.rs | 51 +++++++++++- server/src/handle/wsproxy.rs | 4 + server/src/listener.rs | 35 +++++++-- server/src/stream.rs | 46 +++++++++-- 9 files changed, 340 insertions(+), 33 deletions(-) create mode 100644 server/src/handle/twisp.rs diff --git a/Cargo.lock b/Cargo.lock index 4664f5c..2db31df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -545,7 +545,9 @@ name = "epoxy-server" version = "2.0.0" dependencies = [ "anyhow", + "async-trait", "bytes", + "cfg-if", "clap", "dashmap", "env_logger", @@ -555,11 +557,15 @@ dependencies = [ "hyper", "hyper-util", "lazy_static", + "libc", "log", + "pty-process", "regex", + "rustix", "serde", "serde_json", "serde_yaml", + "shell-words", "tokio", "tokio-util", "toml", @@ -573,6 +579,16 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys", +] + [[package]] name = "event-listener" version = "5.3.1" @@ -1022,6 +1038,12 @@ version = "0.2.158" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + [[package]] name = "lock_api" version = "0.4.12" @@ -1255,6 +1277,17 @@ dependencies = [ "prost", ] +[[package]] +name = "pty-process" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8749b545e244c90bf74a5767764cc2194f1888bb42f84015486a64c82bea5cc0" +dependencies = [ + "libc", + "rustix", + "tokio", +] + [[package]] name = "quote" version = "1.0.36" @@ -1368,6 +1401,20 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustix" +version = "0.38.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +dependencies = [ + "bitflags", + "errno", + "itoa", + "libc", + "linux-raw-sys", + "windows-sys", +] + [[package]] name = "rustls" version = "0.23.12" @@ -1509,6 +1556,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shell-words" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde" + [[package]] name = "shlex" version = "1.3.0" diff --git a/server/Cargo.toml b/server/Cargo.toml index 1d6c0fb..a809bb0 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -5,7 +5,9 @@ edition = "2021" [dependencies] anyhow = "1.0.86" +async-trait = { version = "0.1.81", optional = true } bytes = "1.7.1" +cfg-if = "1.0.0" clap = { version = "4.5.16", features = ["cargo", "derive"] } dashmap = "6.0.1" env_logger = "0.11.5" @@ -15,11 +17,15 @@ http-body-util = "0.1.2" hyper = { version = "1.4.1", features = ["server", "http1"] } hyper-util = { version = "0.1.7", features = ["tokio"] } lazy_static = "1.5.0" +libc = { version = "0.2.158", optional = true } log = { version = "0.4.22", features = ["serde", "std"] } +pty-process = { version = "0.4.0", features = ["async", "tokio"], optional = true } regex = "1.10.6" +rustix = { version = "0.38.34", optional = true } serde = { version = "1.0.208", features = ["derive"] } serde_json = { version = "1.0.125", optional = true } serde_yaml = { version = "0.9.34", optional = true } +shell-words = { version = "1.1.0", optional = true } tokio = { version = "1.39.3", features = ["full"] } tokio-util = { version = "0.7.11", features = ["codec", "compat", "io-util", "net"] } toml = { version = "0.8.19", optional = true } @@ -28,6 +34,9 @@ wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets", [features] default = ["toml"] + json = ["dep:serde_json"] yaml = ["dep:serde_yaml"] toml = ["dep:toml"] + +twisp = ["dep:pty-process", "dep:libc", "dep:rustix", "dep:async-trait", "dep:shell-words"] diff --git a/server/src/config.rs b/server/src/config.rs index 5d0897a..6f4eb92 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -104,6 +104,9 @@ pub struct StreamConfig { pub allow_udp: bool, /// Whether or not to enable nonstandard legacy wsproxy UDP streams. pub allow_wsproxy_udp: bool, + /// Whether or not to allow TWisp streams. + #[cfg(feature = "twisp")] + pub allow_twisp: bool, /// Whether or not to allow connections to IP addresses. pub allow_direct_ip: bool, @@ -160,8 +163,6 @@ struct ConfigCache { pub allowed_udp_hosts: RegexSet, pub blocked_udp_hosts: RegexSet, - - pub wisp_config: (Option>, u32), } lazy_static! { @@ -188,14 +189,13 @@ lazy_static! { allowed_udp_hosts: RegexSet::new(&CONFIG.stream.allow_udp_hosts).unwrap(), blocked_udp_hosts: RegexSet::new(&CONFIG.stream.block_udp_hosts).unwrap(), - - wisp_config: CONFIG.wisp.to_opts_inner().unwrap(), } }; } pub fn validate_config_cache() { - let _ = CONFIG_CACHE.wisp_config; + let _ = CONFIG_CACHE.allowed_ports; + CONFIG.wisp.to_opts().unwrap(); } impl Default for ServerConfig { @@ -236,11 +236,9 @@ impl Default for WispConfig { } impl WispConfig { - pub(super) fn to_opts_inner( - &self, - ) -> anyhow::Result<(Option>, u32)> { + pub fn to_opts(&self) -> anyhow::Result<(Option>, u32)> { if self.wisp_v2 { - let mut extensions: Vec> = Vec::new(); + let mut extensions: Vec = Vec::new(); if self.extensions.contains(&ProtocolExtension::Udp) { extensions.push(Box::new(UdpProtocolExtensionBuilder)); @@ -257,13 +255,6 @@ impl WispConfig { Ok((None, self.buffer_size)) } } - - pub fn to_opts(&self) -> (Option<&'static [AnyProtocolExtensionBuilder]>, u32) { - ( - CONFIG_CACHE.wisp_config.0.as_deref(), - CONFIG_CACHE.wisp_config.1, - ) - } } impl Default for StreamConfig { @@ -273,6 +264,8 @@ impl Default for StreamConfig { allow_udp: true, allow_wsproxy_udp: false, + #[cfg(feature = "twisp")] + allow_twisp: false, allow_direct_ip: true, allow_loopback: true, diff --git a/server/src/handle/mod.rs b/server/src/handle/mod.rs index 90663fc..a01971a 100644 --- a/server/src/handle/mod.rs +++ b/server/src/handle/mod.rs @@ -1,5 +1,7 @@ mod wisp; mod wsproxy; +#[cfg(feature = "twisp")] +pub mod twisp; pub use wisp::handle_wisp; pub use wsproxy::handle_wsproxy; diff --git a/server/src/handle/twisp.rs b/server/src/handle/twisp.rs new file mode 100644 index 0000000..bef1ceb --- /dev/null +++ b/server/src/handle/twisp.rs @@ -0,0 +1,148 @@ +use std::{ + collections::HashMap, + os::fd::{AsRawFd, RawFd}, + sync::Arc, +}; + +use async_trait::async_trait; +use bytes::{Buf, Bytes}; +use pty_process::{Pty, Size}; +use tokio::{io::copy, process::Child, select, sync::Mutex}; +use tokio_util::compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt}; +use wisp_mux::{ + extensions::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}, + ws::{LockedWebSocketWrite, WebSocketRead}, + MuxStreamAsyncRead, MuxStreamAsyncWrite, WispError, +}; + +pub type TwispMap = Arc>>; + +pub const STREAM_TYPE: u8 = 0x03; + +#[derive(Debug, Clone)] +pub struct TWispServerProtocolExtension(TwispMap); + +impl TWispServerProtocolExtension { + const ID: u8 = 0xF0; +} + +#[async_trait] +impl ProtocolExtension for TWispServerProtocolExtension { + fn get_id(&self) -> u8 { + Self::ID + } + + fn get_supported_packets(&self) -> &'static [u8] { + // Resize PTY + &[0xF0] + } + + fn encode(&self) -> Bytes { + Bytes::new() + } + + async fn handle_handshake( + &mut self, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> std::result::Result<(), WispError> { + Ok(()) + } + + async fn handle_packet( + &mut self, + mut packet: Bytes, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> std::result::Result<(), WispError> { + if packet.remaining() < 4 + 2 + 2 { + return Err(WispError::PacketTooSmall); + } + let stream_id = packet.get_u32_le(); + let row = packet.get_u16_le(); + let col = packet.get_u16_le(); + + if let Some(pty) = self.0.lock().await.get(&stream_id) { + let _ = set_term_size(*pty, Size::new(row, col)); + } + Ok(()) + } + + fn box_clone(&self) -> Box { + Box::new(self.clone()) + } +} + +impl From for AnyProtocolExtension { + fn from(value: TWispServerProtocolExtension) -> Self { + AnyProtocolExtension::new(value) + } +} + +pub struct TWispServerProtocolExtensionBuilder(TwispMap); + +impl ProtocolExtensionBuilder for TWispServerProtocolExtensionBuilder { + fn get_id(&self) -> u8 { + TWispServerProtocolExtension::ID + } + + fn build_from_bytes( + &self, + _: Bytes, + _: wisp_mux::Role, + ) -> std::result::Result { + Ok(TWispServerProtocolExtension(self.0.clone()).into()) + } + + fn build_to_extension(&self, _: wisp_mux::Role) -> AnyProtocolExtension { + TWispServerProtocolExtension(self.0.clone()).into() + } +} + +fn set_term_size(fd: RawFd, size: Size) -> anyhow::Result<()> { + let size = libc::winsize::from(size); + let ret = unsafe { libc::ioctl(fd, libc::TIOCSWINSZ, std::ptr::addr_of!(size)) }; + if ret == -1 { + Err(rustix::io::Errno::from_raw_os_error( + std::io::Error::last_os_error().raw_os_error().unwrap_or(0), + ) + .into()) + } else { + Ok(()) + } +} + +pub fn new_map() -> TwispMap { + Arc::new(Mutex::new(HashMap::new())) +} + +pub fn new_ext(map: TwispMap) -> Box { + Box::new(TWispServerProtocolExtensionBuilder(map)) +} + +pub async fn handle_twisp( + id: u32, + streamrx: &mut MuxStreamAsyncRead, + streamtx: &mut MuxStreamAsyncWrite, + map: TwispMap, + mut pty: Pty, + mut cmd: Child, +) -> anyhow::Result<()> { + map.lock().await.insert(id, pty.as_raw_fd()); + let ret = async { + let (mut ptyrx, mut ptytx) = pty.split(); + let mut streamrx = streamrx.compat(); + let mut streamtx = streamtx.compat_write(); + + select! { + x = copy(&mut ptyrx, &mut streamtx) => x.map(|_| {}), + x = copy(&mut streamrx, &mut ptytx) => x.map(|_| {}), + x = cmd.wait() => x.map(|_| {}), + }?; + Ok(()) + } + .await; + map.lock().await.remove(&id); + let _ = cmd.kill().await; + ret +} diff --git a/server/src/handle/wisp.rs b/server/src/handle/wisp.rs index 4c63518..1f75d79 100644 --- a/server/src/handle/wisp.rs +++ b/server/src/handle/wisp.rs @@ -1,4 +1,5 @@ use anyhow::Context; +use cfg_if::cfg_if; use futures_util::FutureExt; use tokio::{ io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, @@ -9,7 +10,8 @@ use tokio::{ use tokio_util::compat::FuturesAsyncReadCompatExt; use uuid::Uuid; use wisp_mux::{ - CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite, ServerMux, + CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite, + ServerMux, }; use crate::{ @@ -49,7 +51,12 @@ async fn copy_write_fast(muxtx: MuxStreamWrite, tcprx: OwnedReadHalf) -> anyhow: } } -async fn handle_stream(connect: ConnectPacket, muxstream: MuxStream, id: String) { +async fn handle_stream( + connect: ConnectPacket, + muxstream: MuxStream, + id: String, + #[cfg(feature = "twisp")] twisp_map: super::twisp::TwispMap, +) { let requested_stream = connect.clone(); let Ok(resolved) = ClientStream::resolve(connect).await else { @@ -146,6 +153,23 @@ async fn handle_stream(connect: ConnectPacket, muxstream: MuxStream, id: String) } } } + #[cfg(feature = "twisp")] + ClientStream::Pty(cmd, pty) => { + let closer = muxstream.get_close_handle(); + let id = muxstream.stream_id; + let (mut rx, mut tx) = muxstream.into_io().into_asyncrw().into_split(); + + match super::twisp::handle_twisp(id, &mut rx, &mut tx, twisp_map.clone(), pty, cmd) + .await + { + Ok(()) => { + let _ = closer.close(CloseReason::Voluntary).await; + } + Err(_) => { + let _ = closer.close(CloseReason::Unexpected).await; + } + } + } ClientStream::Invalid => { let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await; } @@ -161,9 +185,26 @@ async fn handle_stream(connect: ConnectPacket, muxstream: MuxStream, id: String) pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> { let (read, write) = stream; - let (extensions, buffer_size) = CONFIG.wisp.to_opts(); + cfg_if! { + if #[cfg(feature = "twisp")] { + let twisp_map = super::twisp::new_map(); + let (extensions, buffer_size) = CONFIG.wisp.to_opts()?; - let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions) + let extensions = match extensions { + Some(mut exts) => { + exts.push(super::twisp::new_ext(twisp_map.clone())); + Some(exts) + }, + None => { + None + } + }; + } else { + let (extensions, buffer_size) = CONFIG.wisp.to_opts()?; + } + } + + let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions.as_deref()) .await .context("failed to create server multiplexor")? .with_no_required_extensions(); @@ -177,6 +218,8 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> { connect, stream, id.clone(), + #[cfg(feature = "twisp")] + twisp_map.clone(), ))); } diff --git a/server/src/handle/wsproxy.rs b/server/src/handle/wsproxy.rs index 8898be7..6d5dfbf 100644 --- a/server/src/handle/wsproxy.rs +++ b/server/src/handle/wsproxy.rs @@ -159,6 +159,10 @@ pub async fn handle_wsproxy( } } } + #[cfg(feature = "twisp")] + ClientStream::Pty(_, _) => { + let _ = ws.close(CloseCode::Error.into(), b"twisp is not supported").await; + } ClientStream::Blocked => { let _ = ws.close(CloseCode::Error.into(), b"host is blocked").await; } diff --git a/server/src/listener.rs b/server/src/listener.rs index 6a35246..df62c2a 100644 --- a/server/src/listener.rs +++ b/server/src/listener.rs @@ -46,10 +46,15 @@ fn non_ws_resp() -> Response { async fn ws_upgrade(mut req: Request, callback: T) -> anyhow::Result> where - T: FnOnce(UpgradeFut, bool, bool, String) -> R, - R: Future>, + T: FnOnce(UpgradeFut, bool, bool, String) -> R + Send + 'static, + R: Future> + Send, { - if CONFIG.server.enable_stats_endpoint && req.uri().path() == CONFIG.server.stats_endpoint { + let is_upgrade = fastwebsockets::upgrade::is_upgrade_request(&req); + + if !is_upgrade + && CONFIG.server.enable_stats_endpoint + && req.uri().path() == CONFIG.server.stats_endpoint + { match generate_stats() { Ok(x) => { return Ok(Response::builder() @@ -64,7 +69,7 @@ where .unwrap()) } } - } else if !fastwebsockets::upgrade::is_upgrade_request(&req) { + } else if !is_upgrade { return Ok(non_ws_resp()); } @@ -77,10 +82,18 @@ where .path() .starts_with(&(CONFIG.server.prefix.clone() + "/")) { - (callback)(fut, false, false, req.uri().path().to_string()); + tokio::spawn(async move { + if let Err(err) = (callback)(fut, false, false, req.uri().path().to_string()).await { + error!("error while serving client: {:?}", err); + } + }); } else if CONFIG.wisp.allow_wsproxy { let udp = req.uri().query().unwrap_or_default() == "?udp"; - (callback)(fut, true, udp, req.uri().path().to_string()); + tokio::spawn(async move { + if let Err(err) = (callback)(fut, false, udp, req.uri().path().to_string()).await { + error!("error while serving client: {:?}", err); + } + }); } else { return Ok(non_ws_resp()); } @@ -90,7 +103,10 @@ where pub trait ServerStreamExt { fn split(self) -> (ServerStreamRead, ServerStreamWrite); - async fn route(self, callback: impl FnOnce(ServerRouteResult) + Clone) -> anyhow::Result<()>; + async fn route( + self, + callback: impl FnOnce(ServerRouteResult) + Clone + Send + 'static, + ) -> anyhow::Result<()>; } impl ServerStreamExt for ServerStream { @@ -107,7 +123,10 @@ impl ServerStreamExt for ServerStream { } } - async fn route(self, callback: impl FnOnce(ServerRouteResult) + Clone) -> anyhow::Result<()> { + async fn route( + self, + callback: impl FnOnce(ServerRouteResult) + Clone + Send + 'static, + ) -> anyhow::Result<()> { match CONFIG.server.transport { SocketTransport::WebSocket => { let stream = TokioIo::new(self); diff --git a/server/src/stream.rs b/server/src/stream.rs index 6180add..67ca2ff 100644 --- a/server/src/stream.rs +++ b/server/src/stream.rs @@ -5,6 +5,7 @@ use std::{ use anyhow::Context; use bytes::BytesMut; +use cfg_if::cfg_if; use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, WebSocketError}; use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; @@ -37,6 +38,8 @@ fn blocked_set(stream_type: StreamType) -> &'static RegexSet { pub enum ClientStream { Tcp(TcpStream), Udp(UdpSocket), + #[cfg(feature = "twisp")] + Pty(tokio::process::Child, pty_process::Pty), Blocked, Invalid, } @@ -50,8 +53,20 @@ pub enum ResolvedPacket { impl ClientStream { pub async fn resolve(packet: ConnectPacket) -> anyhow::Result { - if matches!(packet.stream_type, StreamType::Unknown(_)) { - return Ok(ResolvedPacket::Invalid); + cfg_if! { + if #[cfg(feature = "twisp")] { + if let StreamType::Unknown(ty) = packet.stream_type { + if ty == crate::handle::twisp::STREAM_TYPE && CONFIG.stream.allow_twisp && CONFIG.wisp.wisp_v2 { + return Ok(ResolvedPacket::Valid(packet)); + } else { + return Ok(ResolvedPacket::Invalid); + } + } + } else { + if matches!(packet.stream_type, StreamType::Unknown(_)) { + return Ok(ResolvedPacket::Invalid); + } + } } if !CONFIG.stream.allow_udp && packet.stream_type == StreamType::Udp { @@ -127,11 +142,10 @@ impl ClientStream { } pub async fn connect(packet: ConnectPacket) -> anyhow::Result { - let ipaddr = IpAddr::from_str(&packet.destination_hostname) - .context("failed to parse hostname as ipaddr")?; - match packet.stream_type { StreamType::Tcp => { + let ipaddr = IpAddr::from_str(&packet.destination_hostname) + .context("failed to parse hostname as ipaddr")?; let stream = TcpStream::connect(SocketAddr::new(ipaddr, packet.destination_port)) .await .with_context(|| { @@ -151,6 +165,9 @@ impl ClientStream { return Ok(ClientStream::Blocked); } + let ipaddr = IpAddr::from_str(&packet.destination_hostname) + .context("failed to parse hostname as ipaddr")?; + let bind_addr = if ipaddr.is_ipv4() { SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0) } else { @@ -165,6 +182,25 @@ impl ClientStream { Ok(ClientStream::Udp(stream)) } + #[cfg(feature = "twisp")] + StreamType::Unknown(crate::handle::twisp::STREAM_TYPE) => { + if !CONFIG.stream.allow_twisp { + return Ok(ClientStream::Blocked); + } + + let cmdline: Vec = + shell_words::split(&packet.destination_hostname)? + .into_iter() + .map(Into::into) + .collect(); + let pty = pty_process::Pty::new()?; + + let cmd = pty_process::Command::new(&cmdline[0]) + .args(&cmdline[1..]) + .spawn(&pty.pts()?)?; + + Ok(ClientStream::Pty(cmd, pty)) + } StreamType::Unknown(_) => Ok(ClientStream::Invalid), } }