mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
move the wisp logic into wisp lib
This commit is contained in:
parent
379e07d643
commit
2a5684192a
8 changed files with 314 additions and 198 deletions
|
@ -1,25 +1,22 @@
|
|||
mod lockedws;
|
||||
|
||||
use std::{io::Error, sync::Arc};
|
||||
#![feature(let_chains)]
|
||||
use std::io::Error;
|
||||
|
||||
use bytes::Bytes;
|
||||
use fastwebsockets::{
|
||||
upgrade, CloseCode, FragmentCollector, Frame, OpCode, Payload, WebSocketError,
|
||||
upgrade, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload,
|
||||
WebSocketError,
|
||||
};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use futures_util::{SinkExt, StreamExt, TryFutureExt};
|
||||
use hyper::{
|
||||
body::Incoming, header::HeaderValue, server::conn::http1, service::service_fn, Request,
|
||||
Response, StatusCode,
|
||||
};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use tokio::{
|
||||
net::{TcpListener, TcpStream},
|
||||
sync::mpsc,
|
||||
};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_native_tls::{native_tls, TlsAcceptor};
|
||||
use tokio_util::codec::{BytesCodec, Framed};
|
||||
|
||||
use wisp_mux::{ws, Packet, PacketType};
|
||||
use wisp_mux::{ws, ConnectPacket, MuxStream, Packet, ServerMux, StreamType, WispError, WsEvent};
|
||||
|
||||
type HttpBody = http_body_util::Empty<hyper::body::Bytes>;
|
||||
|
||||
|
@ -73,37 +70,26 @@ async fn accept_http(
|
|||
addr: String,
|
||||
prefix: String,
|
||||
) -> Result<Response<HttpBody>, WebSocketError> {
|
||||
if upgrade::is_upgrade_request(&req) && req.uri().path().to_string().starts_with(&prefix) {
|
||||
if upgrade::is_upgrade_request(&req)
|
||||
&& req.uri().path().to_string().starts_with(&prefix)
|
||||
&& let Some(protocol) = req.headers().get("Sec-Websocket-Protocol")
|
||||
&& protocol == "wisp-v1"
|
||||
{
|
||||
let uri = req.uri().clone();
|
||||
let (mut res, fut) = upgrade::upgrade(&mut req)?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if *uri.path() != prefix {
|
||||
if let Err(e) =
|
||||
accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone())
|
||||
.await
|
||||
{
|
||||
println!("{:?}: error in ws handling: {:?}", addr, e);
|
||||
}
|
||||
} else if let Err(e) = accept_ws(fut, addr.clone()).await {
|
||||
println!("{:?}: error in ws handling: {:?}", addr, e);
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(protocol) = req.headers().get("Sec-Websocket-Protocol") {
|
||||
let first_protocol = protocol
|
||||
.to_str()
|
||||
.expect("failed to get protocol")
|
||||
.split(',')
|
||||
.next()
|
||||
.expect("failed to get first protocol")
|
||||
.trim();
|
||||
res.headers_mut().insert(
|
||||
"Sec-Websocket-Protocol",
|
||||
HeaderValue::from_str(first_protocol).unwrap(),
|
||||
);
|
||||
if *uri.path() != prefix {
|
||||
tokio::spawn(async move {
|
||||
accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone()).await
|
||||
});
|
||||
} else {
|
||||
tokio::spawn(async move { accept_ws(fut, addr.clone()).await });
|
||||
}
|
||||
|
||||
res.headers_mut().insert(
|
||||
"Sec-Websocket-Protocol",
|
||||
HeaderValue::from_str("wisp-v1").unwrap(),
|
||||
);
|
||||
Ok(res)
|
||||
} else {
|
||||
Ok(Response::builder()
|
||||
|
@ -113,127 +99,80 @@ async fn accept_http(
|
|||
}
|
||||
}
|
||||
|
||||
enum WsEvent {
|
||||
Send(Bytes),
|
||||
Close,
|
||||
async fn handle_mux(
|
||||
packet: ConnectPacket,
|
||||
mut stream: MuxStream<impl ws::WebSocketWrite>,
|
||||
) -> Result<(), WispError> {
|
||||
let uri = format!(
|
||||
"{}:{}",
|
||||
packet.destination_hostname, packet.destination_port
|
||||
);
|
||||
match packet.stream_type {
|
||||
StreamType::Tcp => {
|
||||
let tcp_stream = TcpStream::connect(uri)
|
||||
.await
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
let mut tcp_stream_framed = Framed::new(tcp_stream, BytesCodec::new());
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
event = stream.read() => {
|
||||
match event {
|
||||
Some(event) => match event {
|
||||
WsEvent::Send(data) => {
|
||||
tcp_stream_framed.send(data).await.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
}
|
||||
WsEvent::Close(_) => break,
|
||||
},
|
||||
None => break
|
||||
}
|
||||
},
|
||||
event = tcp_stream_framed.next() => {
|
||||
match event.and_then(|x| x.ok()) {
|
||||
Some(event) => stream.write(event.into()).await?,
|
||||
None => break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
StreamType::Udp => todo!(),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept_ws(
|
||||
fut: upgrade::UpgradeFut,
|
||||
addr: String,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (mut rx, tx) = fut.await?.split(tokio::io::split);
|
||||
let tx = lockedws::LockedWebSocketWrite::new(tx);
|
||||
|
||||
let stream_map = Arc::new(dashmap::DashMap::<u32, mpsc::UnboundedSender<WsEvent>>::new());
|
||||
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
|
||||
let (rx, tx) = fut.await?.split(tokio::io::split);
|
||||
let rx = FragmentCollectorRead::new(rx);
|
||||
|
||||
println!("{:?}: connected", addr);
|
||||
|
||||
tx.write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)).into())
|
||||
.await?;
|
||||
let mut mux = ServerMux::new(rx, tx);
|
||||
|
||||
while let Ok(frame) = rx.read_frame(&mut |x| tx.write_frame(x)).await {
|
||||
if let Ok(packet) = Packet::try_from(ws::Frame::try_from(frame)?) {
|
||||
use PacketType::*;
|
||||
match packet.packet {
|
||||
Connect(inner_packet) => {
|
||||
let (ch_tx, mut ch_rx) = mpsc::unbounded_channel::<WsEvent>();
|
||||
stream_map.clone().insert(packet.stream_id, ch_tx);
|
||||
let tx_cloned = tx.clone();
|
||||
tokio::spawn(async move {
|
||||
let tcp_stream = match TcpStream::connect(format!(
|
||||
"{}:{}",
|
||||
inner_packet.destination_hostname, inner_packet.destination_port
|
||||
))
|
||||
mux.server_loop(&mut |packet, stream| async move {
|
||||
let tx_cloned = stream.get_write_half();
|
||||
let stream_id = stream.stream_id;
|
||||
tokio::spawn(async move {
|
||||
let _ = handle_mux(packet, stream)
|
||||
.or_else(|err| async {
|
||||
let _ = tx_cloned
|
||||
.write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x03)))
|
||||
.await;
|
||||
Err(err)
|
||||
})
|
||||
.and_then(|_| async {
|
||||
tx_cloned
|
||||
.write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x02)))
|
||||
.await
|
||||
{
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
tx_cloned
|
||||
.write_frame(
|
||||
ws::Frame::from(Packet::new_close(packet.stream_id, 0x03))
|
||||
.into(),
|
||||
)
|
||||
.await
|
||||
.map_err(std::io::Error::other)?;
|
||||
return Err(Box::new(err));
|
||||
}
|
||||
};
|
||||
let mut tcp_stream = Framed::new(tcp_stream, BytesCodec::new());
|
||||
loop {
|
||||
tokio::select! {
|
||||
event = tcp_stream.next() => {
|
||||
if let Some(res) = event {
|
||||
match res {
|
||||
Ok(buf) => {
|
||||
tx_cloned.write_frame(
|
||||
ws::Frame::from(
|
||||
Packet::new_data(
|
||||
packet.stream_id,
|
||||
buf.to_vec()
|
||||
)
|
||||
).into()
|
||||
).await.map_err(std::io::Error::other)?;
|
||||
}
|
||||
Err(err) => {
|
||||
tx_cloned
|
||||
.write_frame(
|
||||
ws::Frame::from(Packet::new_close(
|
||||
packet.stream_id,
|
||||
0x03,
|
||||
))
|
||||
.into(),
|
||||
)
|
||||
.await
|
||||
.map_err(std::io::Error::other)?;
|
||||
return Err(Box::new(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
event = ch_rx.recv() => {
|
||||
if let Some(event) = event {
|
||||
match event {
|
||||
WsEvent::Send(buf) => {
|
||||
tcp_stream.send(buf).await?;
|
||||
tx_cloned
|
||||
.write_frame(
|
||||
ws::Frame::from(
|
||||
Packet::new_continue(
|
||||
packet.stream_id,
|
||||
u32::MAX
|
||||
)
|
||||
).into()
|
||||
).await.map_err(std::io::Error::other)?;
|
||||
}
|
||||
WsEvent::Close => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
Data(inner_packet) => {
|
||||
println!("recieved data for {:?}", packet.stream_id);
|
||||
if let Some(stream) = stream_map.clone().get(&packet.stream_id) {
|
||||
let _ = stream.send(WsEvent::Send(inner_packet.into()));
|
||||
}
|
||||
}
|
||||
Continue(_) => unreachable!(),
|
||||
Close(_) => {
|
||||
if let Some(stream) = stream_map.clone().get(&packet.stream_id) {
|
||||
let _ = stream.send(WsEvent::Close);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
});
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
|
||||
println!("{:?}: disconnected", addr);
|
||||
Ok(())
|
||||
|
@ -243,7 +182,7 @@ async fn accept_wsproxy(
|
|||
fut: upgrade::UpgradeFut,
|
||||
incoming_uri: &str,
|
||||
addr: String,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
|
||||
let mut ws_stream = FragmentCollector::new(fut.await?);
|
||||
|
||||
println!("{:?}: connected (wsproxy)", addr);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue