add wisp lib

This commit is contained in:
Toshit Chawda 2024-01-22 08:59:53 -08:00
parent 48e9836515
commit ad7a34e86d
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
10 changed files with 857 additions and 255 deletions

View file

@ -1,176 +1,196 @@
use std::{convert::Infallible, env, net::SocketAddr, sync::Arc};
use std::io::Error;
use bytes::Bytes;
use fastwebsockets::{
upgrade, CloseCode, FragmentCollector, Frame, OpCode, Payload, WebSocketError,
};
use futures_util::{SinkExt, StreamExt};
use hyper::{
body::Incoming,
header::{
HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL,
SEC_WEBSOCKET_VERSION, UPGRADE,
},
server::conn::http1,
service::service_fn,
upgrade::Upgraded,
Method, Request, Response, StatusCode, Version,
body::Incoming, header::HeaderValue, server::conn::http1, service::service_fn, Request,
Response, StatusCode,
};
use hyper_util::rt::TokioIo;
use penguin_mux::{Multiplexor, MuxStream};
use tokio::{
net::{TcpListener, TcpStream},
task::{JoinError, JoinSet},
};
use tokio::net::{TcpListener, TcpStream};
use tokio_native_tls::{native_tls, TlsAcceptor};
use tokio_tungstenite::{
tungstenite::{handshake::derive_accept_key, protocol::Role},
WebSocketStream,
};
use tokio_util::codec::{BytesCodec, Framed};
type Body = http_body_util::Empty<hyper::body::Bytes>;
type HttpBody = http_body_util::Empty<hyper::body::Bytes>;
type MultiplexorStream = MuxStream<WebSocketStream<TokioIo<Upgraded>>>;
async fn forward(mut stream: MultiplexorStream) -> Result<(), JoinError> {
println!("forwarding");
let host = std::str::from_utf8(&stream.dest_host).unwrap();
let mut tcp_stream = TcpStream::connect((host, stream.dest_port)).await.unwrap();
println!("connected to {:?}", tcp_stream.peer_addr().unwrap());
tokio::io::copy_bidirectional(&mut stream, &mut tcp_stream)
.await
.unwrap();
println!("finished");
Ok(())
}
async fn handle_connection(ws_stream: WebSocketStream<TokioIo<Upgraded>>, addr: SocketAddr) {
println!("WebSocket connection established: {}", addr);
let mux = Multiplexor::new(ws_stream, penguin_mux::Role::Server, None, None);
let mut jobs = JoinSet::new();
println!("muxing");
loop {
tokio::select! {
Some(result) = jobs.join_next() => {
match result {
Ok(Ok(())) => {}
Ok(Err(err)) | Err(err) => eprintln!("failed to forward: {:?}", err),
}
}
Ok(result) = mux.server_new_stream_channel() => {
jobs.spawn(forward(result));
}
else => {
break;
}
}
}
println!("{} disconnected", &addr);
}
async fn handle_request(
mut req: Request<Incoming>,
addr: SocketAddr,
) -> Result<Response<Body>, Infallible> {
let headers = req.headers();
let derived = headers
.get(SEC_WEBSOCKET_KEY)
.map(|k| derive_accept_key(k.as_bytes()));
let mut negotiated_protocol: Option<String> = None;
if let Some(protocols) = headers
.get(SEC_WEBSOCKET_PROTOCOL)
.and_then(|h| h.to_str().ok())
{
negotiated_protocol = protocols.split(',').next().map(|h| h.trim().to_string());
}
if req.method() != Method::GET
|| req.version() < Version::HTTP_11
|| !headers
.get(CONNECTION)
.and_then(|h| h.to_str().ok())
.map(|h| {
h.split(|c| c == ' ' || c == ',')
.any(|p| p.eq_ignore_ascii_case("upgrade"))
})
.unwrap_or(false)
|| !headers
.get(UPGRADE)
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
|| !headers
.get(SEC_WEBSOCKET_VERSION)
.map(|h| h == "13")
.unwrap_or(false)
|| derived.is_none()
{
return Ok(Response::new(Body::default()));
}
let ver = req.version();
tokio::task::spawn(async move {
match hyper::upgrade::on(&mut req).await {
Ok(upgraded) => {
let upgraded = TokioIo::new(upgraded);
handle_connection(
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
addr,
)
.await;
}
Err(e) => eprintln!("upgrade error: {}", e),
}
});
let mut res = Response::new(Body::default());
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
*res.version_mut() = ver;
res.headers_mut()
.append(CONNECTION, HeaderValue::from_static("Upgrade"));
res.headers_mut()
.append(UPGRADE, HeaderValue::from_static("websocket"));
res.headers_mut()
.append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap());
if let Some(protocol) = negotiated_protocol {
res.headers_mut()
.append(SEC_WEBSOCKET_PROTOCOL, protocol.parse().unwrap());
}
Ok(res)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr = env::args()
.nth(1)
.unwrap_or_else(|| "0.0.0.0:4000".to_string())
.parse::<SocketAddr>()?;
#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<(), Error> {
let pem = include_bytes!("./pem.pem");
let key = include_bytes!("./key.pem");
let identity = native_tls::Identity::from_pkcs8(pem, key).expect("failed to make identity");
let prefix = if let Some(prefix) = std::env::args().nth(1) {
prefix
} else {
"/".to_string()
};
let port = if let Some(prefix) = std::env::args().nth(1) {
prefix
} else {
"4000".to_string()
};
let identity = native_tls::Identity::from_pkcs8(pem, key).expect("invalid pem/key");
let acceptor = TlsAcceptor::from(native_tls::TlsAcceptor::new(identity).unwrap());
let acceptor = Arc::new(acceptor);
let listener = TcpListener::bind(addr).await?;
println!("listening on {}", addr);
loop {
let (stream, remote_addr) = listener.accept().await?;
let acceptor = acceptor.clone();
let socket = TcpListener::bind(format!("0.0.0.0:{}", port))
.await
.expect("failed to bind");
let acceptor = TlsAcceptor::from(
native_tls::TlsAcceptor::new(identity).expect("failed to make tls acceptor"),
);
let acceptor = std::sync::Arc::new(acceptor);
println!("listening on 0.0.0.0:4000");
while let Ok((stream, addr)) = socket.accept().await {
let acceptor_cloned = acceptor.clone();
let prefix_cloned = prefix.clone();
tokio::spawn(async move {
let stream = acceptor.accept(stream).await.expect("not tls");
let stream = acceptor_cloned.accept(stream).await.expect("not tls");
let io = TokioIo::new(stream);
let service = service_fn(move |req| handle_request(req, remote_addr));
let service = service_fn(move |res| accept_http(res, addr.to_string(), prefix_cloned.clone()));
let conn = http1::Builder::new()
.serve_connection(io, service)
.with_upgrades();
if let Err(err) = conn.await {
eprintln!("failed to serve connection: {:?}", err);
println!("{:?}: failed to serve conn: {:?}", addr, err);
}
});
}
Ok(())
}
async fn accept_http(
mut req: Request<Incoming>,
addr: String,
prefix: String,
) -> Result<Response<HttpBody>, WebSocketError> {
if upgrade::is_upgrade_request(&req) && req.uri().path().to_string().starts_with(&prefix) {
let uri = req.uri().clone();
let (mut res, fut) = upgrade::upgrade(&mut req)?;
tokio::spawn(async move {
if *uri.path() != prefix {
if let Err(e) =
accept_wsproxy(fut, uri.path().to_string(), addr.clone(), prefix).await
{
println!("{:?}: error in ws handling: {:?}", addr, e);
}
}
});
if let Some(protocol) = req.headers().get("Sec-Websocket-Protocol") {
let first_protocol = protocol
.to_str()
.expect("failed to get protocol")
.split(',')
.next()
.expect("failed to get first protocol")
.trim();
res.headers_mut().insert(
"Sec-Websocket-Protocol",
HeaderValue::from_str(first_protocol).unwrap(),
);
}
Ok(res)
} else {
Ok(Response::builder()
.status(StatusCode::OK)
.body(HttpBody::new())
.unwrap())
}
}
async fn accept_wsproxy(
fut: upgrade::UpgradeFut,
incoming_uri: String,
addr: String,
prefix: String,
) -> Result<(), Box<dyn std::error::Error>> {
let mut ws_stream = FragmentCollector::new(fut.await?);
// should always have prefix
let incoming_uri = incoming_uri.strip_prefix(&prefix).unwrap();
println!("{:?}: connected", addr);
let tcp_stream = match TcpStream::connect(incoming_uri).await {
Ok(stream) => stream,
Err(err) => {
ws_stream
.write_frame(Frame::close(CloseCode::Away.into(), b"failed to connect"))
.await
.unwrap();
return Err(Box::new(err));
}
};
let mut tcp_stream_framed = Framed::new(tcp_stream, BytesCodec::new());
loop {
tokio::select! {
event = ws_stream.read_frame() => {
match event {
Ok(frame) => {
print!("{:?}: event ws - ", addr);
match frame.opcode {
OpCode::Text | OpCode::Binary => {
if tcp_stream_framed.send(Bytes::from(frame.payload.to_vec())).await.is_ok() {
println!("sent success");
} else {
println!("sent FAILED");
}
}
OpCode::Close => {
if <Framed<tokio::net::TcpStream, BytesCodec> as SinkExt<Bytes>>::close(&mut tcp_stream_framed).await.is_ok() {
println!("closed success");
} else {
println!("closed FAILED");
}
break;
}
_ => {
println!("ignored");
}
}
},
Err(err) => {
print!("{:?}: err in ws: {:?} - ", addr, err);
if <Framed<tokio::net::TcpStream, BytesCodec> as SinkExt<Bytes>>::close(&mut tcp_stream_framed).await.is_ok() {
println!("closed tcp success");
} else {
println!("closed tcp FAILED");
}
break;
}
}
},
event = tcp_stream_framed.next() => {
if let Some(res) = event {
print!("{:?}: event tcp - ", addr);
match res {
Ok(buf) => {
if ws_stream.write_frame(Frame::binary(Payload::Owned(buf.to_vec()))).await.is_ok() {
println!("sent success");
} else {
println!("sent FAILED");
}
}
Err(_) => {
if ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"tcp side is going away")).await.is_ok() {
println!("closed success");
} else {
println!("closed FAILED");
}
}
}
}
}
}
}
println!("\"{}\": connection closed", addr);
Ok(())
}