diff --git a/server/src/handle/wisp.rs b/server/src/handle/wisp.rs index 0d74e42..425aeda 100644 --- a/server/src/handle/wisp.rs +++ b/server/src/handle/wisp.rs @@ -18,7 +18,7 @@ use wisp_mux::{ }; use crate::{ - listener::WispResult, + route::WispResult, stream::{ClientStream, ResolvedPacket}, CLIENTS, CONFIG, }; @@ -51,7 +51,7 @@ async fn copy_write_fast(muxtx: MuxStreamWrite, tcprx: OwnedReadHalf) -> anyhow: let len = buf.len(); if len == 0 { - return Ok(()) + return Ok(()); } muxtx.write(&buf).await?; @@ -261,7 +261,7 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> { let _ = mux.close().await; event.notify(usize::MAX); - while set.join_next().await.is_some() {}; + while set.join_next().await.is_some() {} debug!("wisp client id {:?} disconnected", id); diff --git a/server/src/listener.rs b/server/src/listener.rs index 41fc736..9d48d73 100644 --- a/server/src/listener.rs +++ b/server/src/listener.rs @@ -1,116 +1,19 @@ -use std::{future::Future, io::Cursor}; - use anyhow::Context; -use bytes::Bytes; -use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector}; -use http_body_util::Full; -use hyper::{ - body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response, - StatusCode, -}; -use hyper_util::rt::TokioIo; -use log::{debug, error}; use tokio::{ fs::{remove_file, try_exists}, - io::AsyncReadExt, net::{tcp, unix, TcpListener, TcpStream, UnixListener, UnixStream}, }; -use tokio_util::{ - codec::{FramedRead, FramedWrite, LengthDelimitedCodec}, - either::Either, -}; +use tokio_util::either::Either; use uuid::Uuid; -use wisp_mux::{ - generic::{GenericWebSocketRead, GenericWebSocketWrite}, - ws::{WebSocketRead, WebSocketWrite}, -}; -use crate::{ - config::{SocketTransport, SocketType}, - generate_stats, - stream::WebSocketStreamWrapper, - CONFIG, -}; +use crate::{config::SocketType, CONFIG}; pub type ServerStream = Either; pub type ServerStreamRead = Either; pub type ServerStreamWrite = Either; -type Body = Full; -fn non_ws_resp() -> Response { - Response::builder() - .status(StatusCode::OK) - .body(Body::new(CONFIG.server.non_ws_response.as_bytes().into())) - .unwrap() -} - -async fn ws_upgrade(mut req: Request, callback: T) -> anyhow::Result> -where - T: FnOnce(UpgradeFut, bool, bool, String) -> R + Send + 'static, - R: Future> + Send, -{ - let is_upgrade = fastwebsockets::upgrade::is_upgrade_request(&req); - - if !is_upgrade - && CONFIG.server.enable_stats_endpoint - && req.uri().path() == CONFIG.server.stats_endpoint - { - match generate_stats() { - Ok(x) => { - debug!("sent server stats to http client"); - return Ok(Response::builder() - .status(StatusCode::OK) - .body(Body::new(x.into())) - .unwrap()); - } - Err(x) => { - error!("failed to send stats to http client: {:?}", x); - return Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::new(x.to_string().into())) - .unwrap()); - } - } - } else if !is_upgrade { - debug!("sent non_ws_response to http client"); - return Ok(non_ws_resp()); - } - - let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?; - // replace body of Empty with Full - let resp = Response::from_parts(resp.into_parts().0, Body::new(Bytes::new())); - - if req - .uri() - .path() - .starts_with(&(CONFIG.server.prefix.clone() + "/")) - { - tokio::spawn(async move { - if let Err(err) = (callback)(fut, false, false, req.uri().path().to_string()).await { - error!("error while serving client: {:?}", err); - } - }); - } else if CONFIG.wisp.allow_wsproxy { - let udp = req.uri().query().unwrap_or_default() == "?udp"; - tokio::spawn(async move { - if let Err(err) = (callback)(fut, false, udp, req.uri().path().to_string()).await { - error!("error while serving client: {:?}", err); - } - }); - } else { - debug!("sent non_ws_response to http client"); - return Ok(non_ws_resp()); - } - - Ok(resp) -} - pub trait ServerStreamExt { fn split(self) -> (ServerStreamRead, ServerStreamWrite); - async fn route( - self, - callback: impl FnOnce(ServerRouteResult) + Clone + Send + 'static, - ) -> anyhow::Result<()>; } impl ServerStreamExt for ServerStream { @@ -126,79 +29,6 @@ impl ServerStreamExt for ServerStream { } } } - - async fn route( - self, - callback: impl FnOnce(ServerRouteResult) + Clone + Send + 'static, - ) -> anyhow::Result<()> { - match CONFIG.server.transport { - SocketTransport::WebSocket => { - let stream = TokioIo::new(self); - - let fut = Builder::new() - .serve_connection( - stream, - service_fn(move |req| { - let callback = callback.clone(); - - ws_upgrade(req, |fut, wsproxy, udp, path| async move { - let mut ws = fut.await.context("failed to await upgrade future")?; - ws.set_max_message_size(CONFIG.server.max_message_size); - - if wsproxy { - let ws = WebSocketStreamWrapper(FragmentCollector::new(ws)); - (callback)(ServerRouteResult::WsProxy(ws, path, udp)); - } else { - let (read, write) = ws.split(|x| { - let parts = x - .into_inner() - .downcast::>() - .unwrap(); - let (r, w) = parts.io.into_inner().split(); - (Cursor::new(parts.read_buf).chain(r), w) - }); - - (callback)(ServerRouteResult::Wisp(( - Box::new(read), - Box::new(write), - ))) - } - - Ok(()) - }) - }), - ) - .with_upgrades(); - - if let Err(e) = fut.await { - error!("error while serving client: {:?}", e); - } - } - SocketTransport::LengthDelimitedLe => { - let codec = LengthDelimitedCodec::builder() - .little_endian() - .max_frame_length(usize::MAX) - .new_codec(); - - let (read, write) = self.split(); - let read = GenericWebSocketRead::new(FramedRead::new(read, codec.clone())); - let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec)); - - (callback)(ServerRouteResult::Wisp((Box::new(read), Box::new(write)))); - } - } - Ok(()) - } -} - -pub type WispResult = ( - Box, - Box, -); - -pub enum ServerRouteResult { - Wisp(WispResult), - WsProxy(WebSocketStreamWrapper, String, bool), } pub enum ServerListener { diff --git a/server/src/main.rs b/server/src/main.rs index c497adc..321f3f3 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -7,8 +7,9 @@ use config::{validate_config_cache, Cli, Config}; use dashmap::DashMap; use handle::{handle_wisp, handle_wsproxy}; use lazy_static::lazy_static; -use listener::{ServerListener, ServerRouteResult, ServerStreamExt}; +use listener::ServerListener; use log::{error, info}; +use route::ServerRouteResult; use tokio::signal::unix::{signal, SignalKind}; use uuid::Uuid; use wisp_mux::{ConnectPacket, StreamType}; @@ -16,6 +17,7 @@ use wisp_mux::{ConnectPacket, StreamType}; mod config; mod handle; mod listener; +mod route; mod stream; type Client = (DashMap, bool); @@ -177,7 +179,7 @@ async fn main() -> anyhow::Result<()> { loop { let (stream, id) = listener.accept().await?; tokio::spawn(async move { - let res = stream.route(move |stream| handle_stream(stream, id)).await; + let res = route::route(stream, move |stream| handle_stream(stream, id)).await; if let Err(e) = res { error!("error while routing client: {:?}", e); diff --git a/server/src/route.rs b/server/src/route.rs new file mode 100644 index 0000000..08ee608 --- /dev/null +++ b/server/src/route.rs @@ -0,0 +1,166 @@ +use std::{future::Future, io::Cursor}; + +use anyhow::Context; +use bytes::Bytes; +use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector}; +use http_body_util::Full; +use hyper::{ + body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response, + StatusCode, +}; +use hyper_util::rt::TokioIo; +use log::{debug, error}; +use tokio::io::AsyncReadExt; +use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; +use wisp_mux::{ + generic::{GenericWebSocketRead, GenericWebSocketWrite}, + ws::{WebSocketRead, WebSocketWrite}, +}; + +use crate::{ + config::SocketTransport, + generate_stats, + listener::{ServerStream, ServerStreamExt}, + stream::WebSocketStreamWrapper, + CONFIG, +}; + +type Body = Full; +fn non_ws_resp() -> Response { + Response::builder() + .status(StatusCode::OK) + .body(Body::new(CONFIG.server.non_ws_response.as_bytes().into())) + .unwrap() +} + +async fn ws_upgrade(mut req: Request, callback: T) -> anyhow::Result> +where + T: FnOnce(UpgradeFut, bool, bool, String) -> R + Send + 'static, + R: Future> + Send, +{ + let is_upgrade = fastwebsockets::upgrade::is_upgrade_request(&req); + + if !is_upgrade + && CONFIG.server.enable_stats_endpoint + && req.uri().path() == CONFIG.server.stats_endpoint + { + match generate_stats() { + Ok(x) => { + debug!("sent server stats to http client"); + return Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::new(x.into())) + .unwrap()); + } + Err(x) => { + error!("failed to send stats to http client: {:?}", x); + return Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::new(x.to_string().into())) + .unwrap()); + } + } + } else if !is_upgrade { + debug!("sent non_ws_response to http client"); + return Ok(non_ws_resp()); + } + + let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?; + // replace body of Empty with Full + let resp = Response::from_parts(resp.into_parts().0, Body::new(Bytes::new())); + + if req + .uri() + .path() + .starts_with(&(CONFIG.server.prefix.clone() + "/")) + { + tokio::spawn(async move { + if let Err(err) = (callback)(fut, false, false, req.uri().path().to_string()).await { + error!("error while serving client: {:?}", err); + } + }); + } else if CONFIG.wisp.allow_wsproxy { + let udp = req.uri().query().unwrap_or_default() == "?udp"; + tokio::spawn(async move { + if let Err(err) = (callback)(fut, false, udp, req.uri().path().to_string()).await { + error!("error while serving client: {:?}", err); + } + }); + } else { + debug!("sent non_ws_response to http client"); + return Ok(non_ws_resp()); + } + + Ok(resp) +} + +pub async fn route( + stream: ServerStream, + callback: impl FnOnce(ServerRouteResult) + Clone + Send + 'static, +) -> anyhow::Result<()> { + match CONFIG.server.transport { + SocketTransport::WebSocket => { + let stream = TokioIo::new(stream); + + let fut = Builder::new() + .serve_connection( + stream, + service_fn(move |req| { + let callback = callback.clone(); + + ws_upgrade(req, |fut, wsproxy, udp, path| async move { + let mut ws = fut.await.context("failed to await upgrade future")?; + ws.set_max_message_size(CONFIG.server.max_message_size); + + if wsproxy { + let ws = WebSocketStreamWrapper(FragmentCollector::new(ws)); + (callback)(ServerRouteResult::WsProxy(ws, path, udp)); + } else { + let (read, write) = ws.split(|x| { + let parts = + x.into_inner().downcast::>().unwrap(); + let (r, w) = parts.io.into_inner().split(); + (Cursor::new(parts.read_buf).chain(r), w) + }); + + (callback)(ServerRouteResult::Wisp(( + Box::new(read), + Box::new(write), + ))) + } + + Ok(()) + }) + }), + ) + .with_upgrades(); + + if let Err(e) = fut.await { + error!("error while serving client: {:?}", e); + } + } + SocketTransport::LengthDelimitedLe => { + let codec = LengthDelimitedCodec::builder() + .little_endian() + .max_frame_length(usize::MAX) + .new_codec(); + + let (read, write) = stream.split(); + let read = GenericWebSocketRead::new(FramedRead::new(read, codec.clone())); + let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec)); + + (callback)(ServerRouteResult::Wisp((Box::new(read), Box::new(write)))); + } + } + Ok(()) +} + +pub type WispResult = ( + Box, + Box, +); + +pub enum ServerRouteResult { + Wisp(WispResult), + WsProxy(WebSocketStreamWrapper, String, bool), +}