wsproxy support with udp, logger, other random stuff

This commit is contained in:
Toshit Chawda 2024-07-21 21:35:33 -07:00
parent 4b44567a0e
commit 04b8feaaf3
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
9 changed files with 637 additions and 203 deletions

View file

@ -5,17 +5,16 @@ use std::{
use anyhow::Context;
use bytes::BytesMut;
use futures_util::AsyncBufReadExt;
use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, WebSocketError};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{
lookup_host,
tcp::{self, OwnedReadHalf, OwnedWriteHalf},
unix, TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream,
},
fs::{remove_file, try_exists},
net::{lookup_host, tcp, unix, TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream},
};
use tokio_util::either::Either;
use wisp_mux::{ConnectPacket, MuxStreamAsyncRead, MuxStreamWrite, StreamType};
use uuid::Uuid;
use wisp_mux::{ConnectPacket, StreamType};
use crate::{config::SocketType, CONFIG};
@ -58,6 +57,9 @@ impl ServerListener {
})?,
),
SocketType::Unix => {
if try_exists(&CONFIG.server.bind).await? {
remove_file(&CONFIG.server.bind).await?;
}
Self::Unix(UnixListener::bind(&CONFIG.server.bind).with_context(|| {
format!("failed to bind to unix socket at `{}`", CONFIG.server.bind)
})?)
@ -65,12 +67,12 @@ impl ServerListener {
})
}
pub async fn accept(&self) -> anyhow::Result<(ServerStream, Option<String>)> {
pub async fn accept(&self) -> anyhow::Result<(ServerStream, String)> {
match self {
Self::Tcp(x) => x
.accept()
.await
.map(|(x, y)| (Either::Left(x), Some(y.to_string())))
.map(|(x, y)| (Either::Left(x), y.to_string()))
.context("failed to accept tcp connection"),
Self::Unix(x) => x
.accept()
@ -80,7 +82,8 @@ impl ServerListener {
Either::Right(x),
y.as_pathname()
.and_then(|x| x.to_str())
.map(ToString::to_string),
.map(ToString::to_string)
.unwrap_or_else(|| Uuid::new_v4().to_string() + "-unix_socket"),
)
})
.context("failed to accept unix socket connection"),
@ -207,34 +210,31 @@ impl ClientStream {
}
}
pub async fn copy_read_fast(
mut muxrx: MuxStreamAsyncRead,
mut tcptx: OwnedWriteHalf,
) -> std::io::Result<()> {
loop {
let buf = muxrx.fill_buf().await?;
if buf.is_empty() {
tcptx.flush().await?;
return Ok(());
}
let i = tcptx.write(buf).await?;
if i == 0 {
return Err(std::io::ErrorKind::WriteZero.into());
}
muxrx.consume_unpin(i);
}
pub enum WebSocketFrame {
Data(BytesMut),
Close,
Ignore,
}
#[allow(dead_code)]
pub async fn copy_write_fast(
muxtx: MuxStreamWrite,
mut tcprx: OwnedReadHalf,
) -> anyhow::Result<()> {
loop {
let mut buf = BytesMut::with_capacity(8 * 1024);
let amt = tcprx.read(&mut buf).await?;
muxtx.write(&buf[..amt]).await?;
pub struct WebSocketStreamWrapper(pub FragmentCollector<TokioIo<Upgraded>>);
impl WebSocketStreamWrapper {
pub async fn read(&mut self) -> Result<WebSocketFrame, WebSocketError> {
let frame = self.0.read_frame().await?;
Ok(match frame.opcode {
OpCode::Text | OpCode::Binary => WebSocketFrame::Data(frame.payload.into()),
OpCode::Close => WebSocketFrame::Close,
_ => WebSocketFrame::Ignore,
})
}
pub async fn write(&mut self, data: &[u8]) -> Result<(), WebSocketError> {
self.0
.write_frame(Frame::binary(Payload::Borrowed(data)))
.await
}
pub async fn close(&mut self, code: u16, reason: &[u8]) -> Result<(), WebSocketError> {
self.0.write_frame(Frame::close(code, reason)).await
}
}