use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, str::FromStr, }; use anyhow::Context; use bytes::BytesMut; use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, WebSocketError}; use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; use tokio::{ fs::{remove_file, try_exists}, net::{lookup_host, tcp, unix, TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream}, }; use tokio_util::either::Either; use uuid::Uuid; use wisp_mux::{ConnectPacket, StreamType}; use crate::{config::SocketType, CONFIG}; pub enum ServerListener { Tcp(TcpListener), Unix(UnixListener), } pub type ServerStream = Either; pub type ServerStreamRead = Either; pub type ServerStreamWrite = Either; pub trait ServerStreamExt { fn split(self) -> (ServerStreamRead, ServerStreamWrite); } impl ServerStreamExt for ServerStream { fn split(self) -> (ServerStreamRead, ServerStreamWrite) { match self { Self::Left(x) => { let (r, w) = x.into_split(); (Either::Left(r), Either::Left(w)) } Self::Right(x) => { let (r, w) = x.into_split(); (Either::Right(r), Either::Right(w)) } } } } impl ServerListener { pub async fn new() -> anyhow::Result { Ok(match CONFIG.server.socket { SocketType::Tcp => Self::Tcp( TcpListener::bind(&CONFIG.server.bind) .await .with_context(|| { format!("failed to bind to tcp address `{}`", CONFIG.server.bind) })?, ), SocketType::Unix => { if try_exists(&CONFIG.server.bind).await? { remove_file(&CONFIG.server.bind).await?; } Self::Unix(UnixListener::bind(&CONFIG.server.bind).with_context(|| { format!("failed to bind to unix socket at `{}`", CONFIG.server.bind) })?) } }) } pub async fn accept(&self) -> anyhow::Result<(ServerStream, String)> { match self { Self::Tcp(x) => x .accept() .await .map(|(x, y)| (Either::Left(x), y.to_string())) .context("failed to accept tcp connection"), Self::Unix(x) => x .accept() .await .map(|(x, y)| { ( Either::Right(x), y.as_pathname() .and_then(|x| x.to_str()) .map(ToString::to_string) .unwrap_or_else(|| Uuid::new_v4().to_string() + "-unix_socket"), ) }) .context("failed to accept unix socket connection"), } } } pub enum ClientStream { Tcp(TcpStream), Udp(UdpSocket), Blocked, Invalid, } pub enum ResolvedPacket { Valid(ConnectPacket), NoResolvedAddrs, Blocked, } impl ClientStream { pub async fn resolve(packet: ConnectPacket) -> anyhow::Result { if !CONFIG.stream.allow_udp && packet.stream_type == StreamType::Udp { return Ok(ResolvedPacket::Blocked); } if CONFIG .stream .blocked_ports() .iter() .any(|x| x.contains(&packet.destination_port)) && !CONFIG .stream .allowed_ports() .iter() .any(|x| x.contains(&packet.destination_port)) { return Ok(ResolvedPacket::Blocked); } if let Ok(addr) = IpAddr::from_str(&packet.destination_hostname) { if !CONFIG.stream.allow_direct_ip { return Ok(ResolvedPacket::Blocked); } if addr.is_loopback() && !CONFIG.stream.allow_loopback { return Ok(ResolvedPacket::Blocked); } if addr.is_multicast() && !CONFIG.stream.allow_multicast { return Ok(ResolvedPacket::Blocked); } if (addr.is_global() && !CONFIG.stream.allow_global) || (!addr.is_global() && !CONFIG.stream.allow_non_global) { return Ok(ResolvedPacket::Blocked); } } if CONFIG .stream .blocked_hosts() .is_match(&packet.destination_hostname) && !CONFIG .stream .allowed_hosts() .is_match(&packet.destination_hostname) { return Ok(ResolvedPacket::Blocked); } let packet = lookup_host(packet.destination_hostname + ":0") .await .context("failed to resolve hostname")? .filter(|x| CONFIG.server.resolve_ipv6 || x.is_ipv4()) .map(|x| ConnectPacket { stream_type: packet.stream_type, destination_hostname: x.ip().to_string(), destination_port: packet.destination_port, }) .next(); Ok(packet .map(ResolvedPacket::Valid) .unwrap_or(ResolvedPacket::NoResolvedAddrs)) } pub async fn connect(packet: ConnectPacket) -> anyhow::Result { let ipaddr = IpAddr::from_str(&packet.destination_hostname) .context("failed to parse hostname as ipaddr")?; match packet.stream_type { StreamType::Tcp => { let stream = TcpStream::connect(SocketAddr::new(ipaddr, packet.destination_port)) .await .with_context(|| { format!("failed to connect to host {}", packet.destination_hostname) })?; Ok(ClientStream::Tcp(stream)) } StreamType::Udp => { if !CONFIG.stream.allow_udp { return Ok(ClientStream::Blocked); } let bind_addr = if ipaddr.is_ipv4() { SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0) } else { SocketAddr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), 0) }; let stream = UdpSocket::bind(bind_addr).await?; stream .connect(SocketAddr::new(ipaddr, packet.destination_port)) .await?; Ok(ClientStream::Udp(stream)) } StreamType::Unknown(_) => Ok(ClientStream::Invalid), } } } pub enum WebSocketFrame { Data(BytesMut), Close, Ignore, } pub struct WebSocketStreamWrapper(pub FragmentCollector>); impl WebSocketStreamWrapper { pub async fn read(&mut self) -> Result { let frame = self.0.read_frame().await?; Ok(match frame.opcode { OpCode::Text | OpCode::Binary => WebSocketFrame::Data(frame.payload.into()), OpCode::Close => WebSocketFrame::Close, _ => WebSocketFrame::Ignore, }) } pub async fn write(&mut self, data: &[u8]) -> Result<(), WebSocketError> { self.0 .write_frame(Frame::binary(Payload::Borrowed(data))) .await } pub async fn close(&mut self, code: u16, reason: &[u8]) -> Result<(), WebSocketError> { self.0.write_frame(Frame::close(code, reason)).await } }