mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-16 23:50:01 -04:00
add wisp lib
This commit is contained in:
parent
48e9836515
commit
ad7a34e86d
10 changed files with 857 additions and 255 deletions
|
@ -1,176 +1,196 @@
|
|||
use std::{convert::Infallible, env, net::SocketAddr, sync::Arc};
|
||||
use std::io::Error;
|
||||
|
||||
use bytes::Bytes;
|
||||
use fastwebsockets::{
|
||||
upgrade, CloseCode, FragmentCollector, Frame, OpCode, Payload, WebSocketError,
|
||||
};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use hyper::{
|
||||
body::Incoming,
|
||||
header::{
|
||||
HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL,
|
||||
SEC_WEBSOCKET_VERSION, UPGRADE,
|
||||
},
|
||||
server::conn::http1,
|
||||
service::service_fn,
|
||||
upgrade::Upgraded,
|
||||
Method, Request, Response, StatusCode, Version,
|
||||
body::Incoming, header::HeaderValue, server::conn::http1, service::service_fn, Request,
|
||||
Response, StatusCode,
|
||||
};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use penguin_mux::{Multiplexor, MuxStream};
|
||||
use tokio::{
|
||||
net::{TcpListener, TcpStream},
|
||||
task::{JoinError, JoinSet},
|
||||
};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_native_tls::{native_tls, TlsAcceptor};
|
||||
use tokio_tungstenite::{
|
||||
tungstenite::{handshake::derive_accept_key, protocol::Role},
|
||||
WebSocketStream,
|
||||
};
|
||||
use tokio_util::codec::{BytesCodec, Framed};
|
||||
|
||||
type Body = http_body_util::Empty<hyper::body::Bytes>;
|
||||
type HttpBody = http_body_util::Empty<hyper::body::Bytes>;
|
||||
|
||||
type MultiplexorStream = MuxStream<WebSocketStream<TokioIo<Upgraded>>>;
|
||||
|
||||
async fn forward(mut stream: MultiplexorStream) -> Result<(), JoinError> {
|
||||
println!("forwarding");
|
||||
let host = std::str::from_utf8(&stream.dest_host).unwrap();
|
||||
let mut tcp_stream = TcpStream::connect((host, stream.dest_port)).await.unwrap();
|
||||
println!("connected to {:?}", tcp_stream.peer_addr().unwrap());
|
||||
tokio::io::copy_bidirectional(&mut stream, &mut tcp_stream)
|
||||
.await
|
||||
.unwrap();
|
||||
println!("finished");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_connection(ws_stream: WebSocketStream<TokioIo<Upgraded>>, addr: SocketAddr) {
|
||||
println!("WebSocket connection established: {}", addr);
|
||||
let mux = Multiplexor::new(ws_stream, penguin_mux::Role::Server, None, None);
|
||||
let mut jobs = JoinSet::new();
|
||||
println!("muxing");
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(result) = jobs.join_next() => {
|
||||
match result {
|
||||
Ok(Ok(())) => {}
|
||||
Ok(Err(err)) | Err(err) => eprintln!("failed to forward: {:?}", err),
|
||||
}
|
||||
}
|
||||
Ok(result) = mux.server_new_stream_channel() => {
|
||||
jobs.spawn(forward(result));
|
||||
}
|
||||
else => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
println!("{} disconnected", &addr);
|
||||
}
|
||||
|
||||
async fn handle_request(
|
||||
mut req: Request<Incoming>,
|
||||
addr: SocketAddr,
|
||||
) -> Result<Response<Body>, Infallible> {
|
||||
let headers = req.headers();
|
||||
let derived = headers
|
||||
.get(SEC_WEBSOCKET_KEY)
|
||||
.map(|k| derive_accept_key(k.as_bytes()));
|
||||
|
||||
let mut negotiated_protocol: Option<String> = None;
|
||||
if let Some(protocols) = headers
|
||||
.get(SEC_WEBSOCKET_PROTOCOL)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
{
|
||||
negotiated_protocol = protocols.split(',').next().map(|h| h.trim().to_string());
|
||||
}
|
||||
|
||||
if req.method() != Method::GET
|
||||
|| req.version() < Version::HTTP_11
|
||||
|| !headers
|
||||
.get(CONNECTION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|h| {
|
||||
h.split(|c| c == ' ' || c == ',')
|
||||
.any(|p| p.eq_ignore_ascii_case("upgrade"))
|
||||
})
|
||||
.unwrap_or(false)
|
||||
|| !headers
|
||||
.get(UPGRADE)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|h| h.eq_ignore_ascii_case("websocket"))
|
||||
.unwrap_or(false)
|
||||
|| !headers
|
||||
.get(SEC_WEBSOCKET_VERSION)
|
||||
.map(|h| h == "13")
|
||||
.unwrap_or(false)
|
||||
|| derived.is_none()
|
||||
{
|
||||
return Ok(Response::new(Body::default()));
|
||||
}
|
||||
|
||||
let ver = req.version();
|
||||
tokio::task::spawn(async move {
|
||||
match hyper::upgrade::on(&mut req).await {
|
||||
Ok(upgraded) => {
|
||||
let upgraded = TokioIo::new(upgraded);
|
||||
handle_connection(
|
||||
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
|
||||
addr,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Err(e) => eprintln!("upgrade error: {}", e),
|
||||
}
|
||||
});
|
||||
|
||||
let mut res = Response::new(Body::default());
|
||||
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
|
||||
*res.version_mut() = ver;
|
||||
res.headers_mut()
|
||||
.append(CONNECTION, HeaderValue::from_static("Upgrade"));
|
||||
res.headers_mut()
|
||||
.append(UPGRADE, HeaderValue::from_static("websocket"));
|
||||
res.headers_mut()
|
||||
.append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap());
|
||||
if let Some(protocol) = negotiated_protocol {
|
||||
res.headers_mut()
|
||||
.append(SEC_WEBSOCKET_PROTOCOL, protocol.parse().unwrap());
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let addr = env::args()
|
||||
.nth(1)
|
||||
.unwrap_or_else(|| "0.0.0.0:4000".to_string())
|
||||
.parse::<SocketAddr>()?;
|
||||
#[tokio::main(flavor = "multi_thread")]
|
||||
async fn main() -> Result<(), Error> {
|
||||
let pem = include_bytes!("./pem.pem");
|
||||
let key = include_bytes!("./key.pem");
|
||||
let identity = native_tls::Identity::from_pkcs8(pem, key).expect("failed to make identity");
|
||||
let prefix = if let Some(prefix) = std::env::args().nth(1) {
|
||||
prefix
|
||||
} else {
|
||||
"/".to_string()
|
||||
};
|
||||
let port = if let Some(prefix) = std::env::args().nth(1) {
|
||||
prefix
|
||||
} else {
|
||||
"4000".to_string()
|
||||
};
|
||||
|
||||
let identity = native_tls::Identity::from_pkcs8(pem, key).expect("invalid pem/key");
|
||||
|
||||
let acceptor = TlsAcceptor::from(native_tls::TlsAcceptor::new(identity).unwrap());
|
||||
let acceptor = Arc::new(acceptor);
|
||||
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
|
||||
println!("listening on {}", addr);
|
||||
|
||||
loop {
|
||||
let (stream, remote_addr) = listener.accept().await?;
|
||||
let acceptor = acceptor.clone();
|
||||
let socket = TcpListener::bind(format!("0.0.0.0:{}", port))
|
||||
.await
|
||||
.expect("failed to bind");
|
||||
let acceptor = TlsAcceptor::from(
|
||||
native_tls::TlsAcceptor::new(identity).expect("failed to make tls acceptor"),
|
||||
);
|
||||
let acceptor = std::sync::Arc::new(acceptor);
|
||||
|
||||
println!("listening on 0.0.0.0:4000");
|
||||
while let Ok((stream, addr)) = socket.accept().await {
|
||||
let acceptor_cloned = acceptor.clone();
|
||||
let prefix_cloned = prefix.clone();
|
||||
tokio::spawn(async move {
|
||||
let stream = acceptor.accept(stream).await.expect("not tls");
|
||||
let stream = acceptor_cloned.accept(stream).await.expect("not tls");
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let service = service_fn(move |req| handle_request(req, remote_addr));
|
||||
|
||||
let service = service_fn(move |res| accept_http(res, addr.to_string(), prefix_cloned.clone()));
|
||||
let conn = http1::Builder::new()
|
||||
.serve_connection(io, service)
|
||||
.with_upgrades();
|
||||
|
||||
if let Err(err) = conn.await {
|
||||
eprintln!("failed to serve connection: {:?}", err);
|
||||
println!("{:?}: failed to serve conn: {:?}", addr, err);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept_http(
|
||||
mut req: Request<Incoming>,
|
||||
addr: String,
|
||||
prefix: String,
|
||||
) -> Result<Response<HttpBody>, WebSocketError> {
|
||||
if upgrade::is_upgrade_request(&req) && req.uri().path().to_string().starts_with(&prefix) {
|
||||
let uri = req.uri().clone();
|
||||
let (mut res, fut) = upgrade::upgrade(&mut req)?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if *uri.path() != prefix {
|
||||
if let Err(e) =
|
||||
accept_wsproxy(fut, uri.path().to_string(), addr.clone(), prefix).await
|
||||
{
|
||||
println!("{:?}: error in ws handling: {:?}", addr, e);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(protocol) = req.headers().get("Sec-Websocket-Protocol") {
|
||||
let first_protocol = protocol
|
||||
.to_str()
|
||||
.expect("failed to get protocol")
|
||||
.split(',')
|
||||
.next()
|
||||
.expect("failed to get first protocol")
|
||||
.trim();
|
||||
res.headers_mut().insert(
|
||||
"Sec-Websocket-Protocol",
|
||||
HeaderValue::from_str(first_protocol).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
} else {
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.body(HttpBody::new())
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
async fn accept_wsproxy(
|
||||
fut: upgrade::UpgradeFut,
|
||||
incoming_uri: String,
|
||||
addr: String,
|
||||
prefix: String,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let mut ws_stream = FragmentCollector::new(fut.await?);
|
||||
|
||||
// should always have prefix
|
||||
let incoming_uri = incoming_uri.strip_prefix(&prefix).unwrap();
|
||||
|
||||
println!("{:?}: connected", addr);
|
||||
|
||||
let tcp_stream = match TcpStream::connect(incoming_uri).await {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
ws_stream
|
||||
.write_frame(Frame::close(CloseCode::Away.into(), b"failed to connect"))
|
||||
.await
|
||||
.unwrap();
|
||||
return Err(Box::new(err));
|
||||
}
|
||||
};
|
||||
let mut tcp_stream_framed = Framed::new(tcp_stream, BytesCodec::new());
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
event = ws_stream.read_frame() => {
|
||||
match event {
|
||||
Ok(frame) => {
|
||||
print!("{:?}: event ws - ", addr);
|
||||
match frame.opcode {
|
||||
OpCode::Text | OpCode::Binary => {
|
||||
if tcp_stream_framed.send(Bytes::from(frame.payload.to_vec())).await.is_ok() {
|
||||
println!("sent success");
|
||||
} else {
|
||||
println!("sent FAILED");
|
||||
}
|
||||
}
|
||||
OpCode::Close => {
|
||||
if <Framed<tokio::net::TcpStream, BytesCodec> as SinkExt<Bytes>>::close(&mut tcp_stream_framed).await.is_ok() {
|
||||
println!("closed success");
|
||||
} else {
|
||||
println!("closed FAILED");
|
||||
}
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
println!("ignored");
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
print!("{:?}: err in ws: {:?} - ", addr, err);
|
||||
if <Framed<tokio::net::TcpStream, BytesCodec> as SinkExt<Bytes>>::close(&mut tcp_stream_framed).await.is_ok() {
|
||||
println!("closed tcp success");
|
||||
} else {
|
||||
println!("closed tcp FAILED");
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
event = tcp_stream_framed.next() => {
|
||||
if let Some(res) = event {
|
||||
print!("{:?}: event tcp - ", addr);
|
||||
match res {
|
||||
Ok(buf) => {
|
||||
if ws_stream.write_frame(Frame::binary(Payload::Owned(buf.to_vec()))).await.is_ok() {
|
||||
println!("sent success");
|
||||
} else {
|
||||
println!("sent FAILED");
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
if ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"tcp side is going away")).await.is_ok() {
|
||||
println!("closed success");
|
||||
} else {
|
||||
println!("closed FAILED");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("\"{}\": connection closed", addr);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue