diff --git a/server/Cargo.toml b/server/Cargo.toml index 78b0834..1d6c0fb 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -18,10 +18,16 @@ lazy_static = "1.5.0" log = { version = "0.4.22", features = ["serde", "std"] } regex = "1.10.6" serde = { version = "1.0.208", features = ["derive"] } -serde_json = "1.0.125" -serde_yaml = "0.9.34" +serde_json = { version = "1.0.125", optional = true } +serde_yaml = { version = "0.9.34", optional = true } tokio = { version = "1.39.3", features = ["full"] } -tokio-util = { version = "0.7.11", features = ["compat", "io-util", "net"] } -toml = "0.8.19" +tokio-util = { version = "0.7.11", features = ["codec", "compat", "io-util", "net"] } +toml = { version = "0.8.19", optional = true } uuid = { version = "1.10.0", features = ["v4"] } -wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets"] } +wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets", "generic_stream"] } + +[features] +default = ["toml"] +json = ["dep:serde_json"] +yaml = ["dep:serde_yaml"] +toml = ["dep:toml"] diff --git a/server/src/config.rs b/server/src/config.rs index 338d7d0..5d0897a 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -22,6 +22,18 @@ pub enum SocketType { Unix, } +#[derive(Serialize, Deserialize, Default, Debug)] +#[serde(rename_all = "lowercase")] +pub enum SocketTransport { + /// WebSocket transport. + #[default] + WebSocket, + /// Little-endian u32 length-delimited codec. See + /// [tokio-util](https://docs.rs/tokio-util/latest/tokio_util/codec/length_delimited/index.html) + /// for more information. + LengthDelimitedLe, +} + #[derive(Serialize, Deserialize)] #[serde(default)] pub struct ServerConfig { @@ -29,6 +41,8 @@ pub struct ServerConfig { pub bind: String, /// Socket type to listen on. pub socket: SocketType, + /// Transport to listen on. + pub transport: SocketTransport, /// Whether or not to resolve and connect to IPV6 upstream addresses. pub resolve_ipv6: bool, /// Whether or not to enable TCP nodelay on client TCP streams. @@ -189,6 +203,7 @@ impl Default for ServerConfig { Self { bind: "127.0.0.1:4000".to_string(), socket: SocketType::default(), + transport: SocketTransport::default(), resolve_ipv6: false, tcp_nodelay: false, @@ -318,16 +333,22 @@ impl StreamConfig { impl Config { pub fn ser(&self) -> anyhow::Result { Ok(match CLI.format { + #[cfg(feature = "toml")] ConfigFormat::Toml => toml::to_string_pretty(self)?, + #[cfg(feature = "json")] ConfigFormat::Json => serde_json::to_string_pretty(self)?, + #[cfg(feature = "yaml")] ConfigFormat::Yaml => serde_yaml::to_string(self)?, }) } pub fn de(string: String) -> anyhow::Result { Ok(match CLI.format { + #[cfg(feature = "toml")] ConfigFormat::Toml => toml::from_str(&string)?, + #[cfg(feature = "json")] ConfigFormat::Json => serde_json::from_str(&string)?, + #[cfg(feature = "yaml")] ConfigFormat::Yaml => serde_yaml::from_str(&string)?, }) } @@ -335,9 +356,12 @@ impl Config { #[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Default, ValueEnum)] pub enum ConfigFormat { + #[cfg(feature = "toml")] #[default] Toml, + #[cfg(feature = "json")] Json, + #[cfg(feature = "yaml")] Yaml, } diff --git a/server/src/handle/wisp.rs b/server/src/handle/wisp.rs index 711ba7a..4c63518 100644 --- a/server/src/handle/wisp.rs +++ b/server/src/handle/wisp.rs @@ -1,11 +1,7 @@ -use std::io::Cursor; - use anyhow::Context; -use fastwebsockets::upgrade::UpgradeFut; use futures_util::FutureExt; -use hyper_util::rt::TokioIo; use tokio::{ - io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, + io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, net::tcp::{OwnedReadHalf, OwnedWriteHalf}, select, task::JoinSet, @@ -17,7 +13,8 @@ use wisp_mux::{ }; use crate::{ - stream::{ClientStream, ResolvedPacket, ServerStream, ServerStreamExt}, + listener::WispResult, + stream::{ClientStream, ResolvedPacket}, CLIENTS, CONFIG, }; @@ -162,16 +159,8 @@ async fn handle_stream(connect: ConnectPacket, muxstream: MuxStream, id: String) CLIENTS.get(&id).unwrap().0.remove(&uuid); } -pub async fn handle_wisp(fut: UpgradeFut, id: String) -> anyhow::Result<()> { - let mut ws = fut.await.context("failed to await upgrade future")?; - ws.set_max_message_size(CONFIG.server.max_message_size); - - let (read, write) = ws.split(|x| { - let parts = x.into_inner().downcast::>().unwrap(); - let (r, w) = parts.io.into_inner().split(); - (Cursor::new(parts.read_buf).chain(r), w) - }); - +pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> { + let (read, write) = stream; let (extensions, buffer_size) = CONFIG.wisp.to_opts(); let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions) diff --git a/server/src/handle/wsproxy.rs b/server/src/handle/wsproxy.rs index 7a18d8e..8898be7 100644 --- a/server/src/handle/wsproxy.rs +++ b/server/src/handle/wsproxy.rs @@ -1,7 +1,6 @@ use std::str::FromStr; -use anyhow::Context; -use fastwebsockets::{upgrade::UpgradeFut, CloseCode, FragmentCollector}; +use fastwebsockets::CloseCode; use tokio::{ io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, select, @@ -15,16 +14,11 @@ use crate::{ }; pub async fn handle_wsproxy( - fut: UpgradeFut, + mut ws: WebSocketStreamWrapper, id: String, path: String, udp: bool, ) -> anyhow::Result<()> { - let mut ws = fut.await.context("failed to await upgrade future")?; - ws.set_max_message_size(CONFIG.server.max_message_size); - let ws = FragmentCollector::new(ws); - let mut ws = WebSocketStreamWrapper(ws); - if udp && !CONFIG.stream.allow_wsproxy_udp { let _ = ws.close(CloseCode::Error.into(), b"udp is blocked").await; return Ok(()); diff --git a/server/src/listener.rs b/server/src/listener.rs new file mode 100644 index 0000000..6a35246 --- /dev/null +++ b/server/src/listener.rs @@ -0,0 +1,236 @@ +use std::{future::Future, io::Cursor}; + +use anyhow::Context; +use bytes::Bytes; +use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector}; +use http_body_util::Full; +use hyper::{ + body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response, + StatusCode, +}; +use hyper_util::rt::TokioIo; +use log::error; +use tokio::{ + fs::{remove_file, try_exists}, + io::AsyncReadExt, + net::{tcp, unix, TcpListener, TcpStream, UnixListener, UnixStream}, +}; +use tokio_util::{ + codec::{FramedRead, FramedWrite, LengthDelimitedCodec}, + either::Either, +}; +use uuid::Uuid; +use wisp_mux::{ + generic::{GenericWebSocketRead, GenericWebSocketWrite}, + ws::{WebSocketRead, WebSocketWrite}, +}; + +use crate::{ + config::{SocketTransport, SocketType}, + generate_stats, + stream::WebSocketStreamWrapper, + CONFIG, +}; + +pub type ServerStream = Either; +pub type ServerStreamRead = Either; +pub type ServerStreamWrite = Either; + +type Body = Full; +fn non_ws_resp() -> Response { + Response::builder() + .status(StatusCode::OK) + .body(Body::new(CONFIG.server.non_ws_response.as_bytes().into())) + .unwrap() +} + +async fn ws_upgrade(mut req: Request, callback: T) -> anyhow::Result> +where + T: FnOnce(UpgradeFut, bool, bool, String) -> R, + R: Future>, +{ + if CONFIG.server.enable_stats_endpoint && req.uri().path() == CONFIG.server.stats_endpoint { + match generate_stats() { + Ok(x) => { + return Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::new(x.into())) + .unwrap()) + } + Err(x) => { + return Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::new(x.to_string().into())) + .unwrap()) + } + } + } else if !fastwebsockets::upgrade::is_upgrade_request(&req) { + return Ok(non_ws_resp()); + } + + let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?; + // replace body of Empty with Full + let resp = Response::from_parts(resp.into_parts().0, Body::new(Bytes::new())); + + if req + .uri() + .path() + .starts_with(&(CONFIG.server.prefix.clone() + "/")) + { + (callback)(fut, false, false, req.uri().path().to_string()); + } else if CONFIG.wisp.allow_wsproxy { + let udp = req.uri().query().unwrap_or_default() == "?udp"; + (callback)(fut, true, udp, req.uri().path().to_string()); + } else { + return Ok(non_ws_resp()); + } + + Ok(resp) +} + +pub trait ServerStreamExt { + fn split(self) -> (ServerStreamRead, ServerStreamWrite); + async fn route(self, callback: impl FnOnce(ServerRouteResult) + Clone) -> anyhow::Result<()>; +} + +impl ServerStreamExt for ServerStream { + fn split(self) -> (ServerStreamRead, ServerStreamWrite) { + match self { + Self::Left(x) => { + let (r, w) = x.into_split(); + (Either::Left(r), Either::Left(w)) + } + Self::Right(x) => { + let (r, w) = x.into_split(); + (Either::Right(r), Either::Right(w)) + } + } + } + + async fn route(self, callback: impl FnOnce(ServerRouteResult) + Clone) -> anyhow::Result<()> { + match CONFIG.server.transport { + SocketTransport::WebSocket => { + let stream = TokioIo::new(self); + + let fut = Builder::new() + .serve_connection( + stream, + service_fn(move |req| { + let callback = callback.clone(); + + ws_upgrade(req, |fut, wsproxy, udp, path| async move { + let mut ws = fut.await.context("failed to await upgrade future")?; + ws.set_max_message_size(CONFIG.server.max_message_size); + + if wsproxy { + let ws = WebSocketStreamWrapper(FragmentCollector::new(ws)); + (callback)(ServerRouteResult::WsProxy(ws, path, udp)); + } else { + let (read, write) = ws.split(|x| { + let parts = x + .into_inner() + .downcast::>() + .unwrap(); + let (r, w) = parts.io.into_inner().split(); + (Cursor::new(parts.read_buf).chain(r), w) + }); + + (callback)(ServerRouteResult::Wisp(( + Box::new(read), + Box::new(write), + ))) + } + + Ok(()) + }) + }), + ) + .with_upgrades(); + + if let Err(e) = fut.await { + error!("error while serving client: {:?}", e); + } + } + SocketTransport::LengthDelimitedLe => { + let codec = LengthDelimitedCodec::builder() + .little_endian() + .max_frame_length(usize::MAX) + .new_codec(); + + let (read, write) = self.split(); + let read = GenericWebSocketRead::new(FramedRead::new(read, codec.clone())); + let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec)); + + (callback)(ServerRouteResult::Wisp((Box::new(read), Box::new(write)))); + } + } + Ok(()) + } +} + +pub type WispResult = ( + Box, + Box, +); + +pub enum ServerRouteResult { + Wisp(WispResult), + WsProxy(WebSocketStreamWrapper, String, bool), +} + +pub enum ServerListener { + Tcp(TcpListener), + Unix(UnixListener), +} + +impl ServerListener { + pub async fn new() -> anyhow::Result { + Ok(match CONFIG.server.socket { + SocketType::Tcp => Self::Tcp( + TcpListener::bind(&CONFIG.server.bind) + .await + .with_context(|| { + format!("failed to bind to tcp address `{}`", CONFIG.server.bind) + })?, + ), + SocketType::Unix => { + if try_exists(&CONFIG.server.bind).await? { + remove_file(&CONFIG.server.bind).await?; + } + Self::Unix(UnixListener::bind(&CONFIG.server.bind).with_context(|| { + format!("failed to bind to unix socket at `{}`", CONFIG.server.bind) + })?) + } + }) + } + + pub async fn accept(&self) -> anyhow::Result<(ServerStream, String)> { + match self { + Self::Tcp(x) => { + let (stream, addr) = x + .accept() + .await + .context("failed to accept tcp connection")?; + if CONFIG.server.tcp_nodelay { + stream + .set_nodelay(true) + .context("failed to set tcp nodelay")?; + } + Ok((Either::Left(stream), addr.to_string())) + } + Self::Unix(x) => x + .accept() + .await + .map(|(x, y)| { + ( + Either::Right(x), + y.as_pathname() + .and_then(|x| x.to_str()) + .map(ToString::to_string) + .unwrap_or_else(|| Uuid::new_v4().to_string() + "-unix_socket"), + ) + }) + .context("failed to accept unix socket connection"), + } + } +} diff --git a/server/src/main.rs b/server/src/main.rs index b100dcb..9df0430 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -2,26 +2,20 @@ use std::{fmt::Write, fs::read_to_string}; -use bytes::Bytes; use clap::Parser; use config::{validate_config_cache, Cli, Config}; use dashmap::DashMap; use handle::{handle_wisp, handle_wsproxy}; -use http_body_util::Full; -use hyper::{ - body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response, - StatusCode, -}; -use hyper_util::rt::TokioIo; use lazy_static::lazy_static; +use listener::{ServerListener, ServerRouteResult, ServerStreamExt}; use log::{error, info}; -use stream::ServerListener; use tokio::signal::unix::{signal, SignalKind}; use uuid::Uuid; use wisp_mux::{ConnectPacket, StreamType}; mod config; mod handle; +mod listener; mod stream; type Client = (DashMap, bool); @@ -38,67 +32,6 @@ lazy_static! { pub static ref CLIENTS: DashMap = DashMap::new(); } -type Body = Full; -fn non_ws_resp() -> Response { - Response::builder() - .status(StatusCode::OK) - .body(Body::new(CONFIG.server.non_ws_response.as_bytes().into())) - .unwrap() -} - -async fn upgrade(mut req: Request, id: String) -> anyhow::Result> { - if CONFIG.server.enable_stats_endpoint && req.uri().path() == CONFIG.server.stats_endpoint { - match generate_stats() { - Ok(x) => { - return Ok(Response::builder() - .status(StatusCode::OK) - .body(Body::new(x.into())) - .unwrap()) - } - Err(x) => { - return Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::new(x.to_string().into())) - .unwrap()) - } - } - } else if !fastwebsockets::upgrade::is_upgrade_request(&req) { - return Ok(non_ws_resp()); - } - - let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?; - // replace body of Empty with Full - let resp = Response::from_parts(resp.into_parts().0, Body::new(Bytes::new())); - - if req - .uri() - .path() - .starts_with(&(CONFIG.server.prefix.clone() + "/")) - { - tokio::spawn(async move { - CLIENTS.insert(id.clone(), (DashMap::new(), false)); - if let Err(e) = handle_wisp(fut, id.clone()).await { - error!("error while handling upgraded client: {:?}", e); - }; - CLIENTS.remove(&id) - }); - } else if CONFIG.wisp.allow_wsproxy { - let udp = req.uri().query().unwrap_or_default() == "?udp"; - tokio::spawn(async move { - CLIENTS.insert(id.clone(), (DashMap::new(), true)); - if let Err(e) = handle_wsproxy(fut, id.clone(), req.uri().path().to_string(), udp).await - { - error!("error while handling upgraded client: {:?}", e); - }; - CLIENTS.remove(&id) - }); - } else { - return Ok(non_ws_resp()); - } - - Ok(resp) -} - fn format_stream_type(stream_type: StreamType) -> &'static str { match stream_type { StreamType::Tcp => "tcp", @@ -159,6 +92,22 @@ fn generate_stats() -> Result { Ok(out) } +fn handle_stream(stream: ServerRouteResult, id: String) { + tokio::spawn(async move { + CLIENTS.insert(id.clone(), (DashMap::new(), false)); + let res = match stream { + ServerRouteResult::Wisp(stream) => handle_wisp(stream, id.clone()).await, + ServerRouteResult::WsProxy(ws, path, udp) => { + handle_wsproxy(ws, id.clone(), path, udp).await + } + }; + if let Err(e) = res { + error!("error while handling client: {:?}", e); + } + CLIENTS.remove(&id) + }); +} + #[tokio::main(flavor = "multi_thread")] async fn main() -> anyhow::Result<()> { if CLI.default_config { @@ -174,8 +123,8 @@ async fn main() -> anyhow::Result<()> { validate_config_cache(); info!( - "listening on {:?} with socket type {:?}", - CONFIG.server.bind, CONFIG.server.socket + "listening on {:?} with socket type {:?} and socket transport {:?}", + CONFIG.server.bind, CONFIG.server.socket, CONFIG.server.transport ); tokio::spawn(async { @@ -189,14 +138,10 @@ async fn main() -> anyhow::Result<()> { loop { let (stream, id) = listener.accept().await?; tokio::spawn(async move { - let stream = TokioIo::new(stream); + let res = stream.route(move |stream| handle_stream(stream, id)).await; - let fut = Builder::new() - .serve_connection(stream, service_fn(|req| upgrade(req, id.clone()))) - .with_upgrades(); - - if let Err(e) = fut.await { - error!("error while serving client: {:?}", e); + if let Err(e) = res { + error!("error while routing client: {:?}", e); } }); } diff --git a/server/src/stream.rs b/server/src/stream.rs index 56965b7..6180add 100644 --- a/server/src/stream.rs +++ b/server/src/stream.rs @@ -9,95 +9,10 @@ use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, WebSocketError}; use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; use regex::RegexSet; -use tokio::{ - fs::{remove_file, try_exists}, - net::{lookup_host, tcp, unix, TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream}, -}; -use tokio_util::either::Either; -use uuid::Uuid; +use tokio::net::{lookup_host, TcpStream, UdpSocket}; use wisp_mux::{ConnectPacket, StreamType}; -use crate::{config::SocketType, CONFIG}; - -pub enum ServerListener { - Tcp(TcpListener), - Unix(UnixListener), -} - -pub type ServerStream = Either; -pub type ServerStreamRead = Either; -pub type ServerStreamWrite = Either; - -pub trait ServerStreamExt { - fn split(self) -> (ServerStreamRead, ServerStreamWrite); -} - -impl ServerStreamExt for ServerStream { - fn split(self) -> (ServerStreamRead, ServerStreamWrite) { - match self { - Self::Left(x) => { - let (r, w) = x.into_split(); - (Either::Left(r), Either::Left(w)) - } - Self::Right(x) => { - let (r, w) = x.into_split(); - (Either::Right(r), Either::Right(w)) - } - } - } -} - -impl ServerListener { - pub async fn new() -> anyhow::Result { - Ok(match CONFIG.server.socket { - SocketType::Tcp => Self::Tcp( - TcpListener::bind(&CONFIG.server.bind) - .await - .with_context(|| { - format!("failed to bind to tcp address `{}`", CONFIG.server.bind) - })?, - ), - SocketType::Unix => { - if try_exists(&CONFIG.server.bind).await? { - remove_file(&CONFIG.server.bind).await?; - } - Self::Unix(UnixListener::bind(&CONFIG.server.bind).with_context(|| { - format!("failed to bind to unix socket at `{}`", CONFIG.server.bind) - })?) - } - }) - } - - pub async fn accept(&self) -> anyhow::Result<(ServerStream, String)> { - match self { - Self::Tcp(x) => { - let (stream, addr) = x - .accept() - .await - .context("failed to accept tcp connection")?; - if CONFIG.server.tcp_nodelay { - stream - .set_nodelay(true) - .context("failed to set tcp nodelay")?; - } - Ok((Either::Left(stream), addr.to_string())) - } - Self::Unix(x) => x - .accept() - .await - .map(|(x, y)| { - ( - Either::Right(x), - y.as_pathname() - .and_then(|x| x.to_str()) - .map(ToString::to_string) - .unwrap_or_else(|| Uuid::new_v4().to_string() + "-unix_socket"), - ) - }) - .context("failed to accept unix socket connection"), - } - } -} +use crate::CONFIG; fn match_addr(str: &str, allowed: &RegexSet, blocked: &RegexSet) -> bool { blocked.is_match(str) && !allowed.is_match(str) diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 4803d5f..0d9ffac 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -22,7 +22,9 @@ pin-project-lite = "0.2.14" tokio = { version = "1.39.3", optional = true, default-features = false } [features] +default = ["generic_stream"] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] +generic_stream = [] wasm = ["futures-timer/wasm-bindgen"] [package.metadata.docs.rs] diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index 6f129af..63a463a 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -1,3 +1,5 @@ +//! WebSocketRead + WebSocketWrite implementation for the fastwebsockets library. + use std::ops::Deref; use async_trait::async_trait; diff --git a/wisp/src/generic.rs b/wisp/src/generic.rs new file mode 100644 index 0000000..262749c --- /dev/null +++ b/wisp/src/generic.rs @@ -0,0 +1,87 @@ +//! WebSocketRead + WebSocketWrite implementation for generic `Stream + Sink`s. + +use async_trait::async_trait; +use bytes::{Bytes, BytesMut}; +use futures::{Sink, SinkExt, Stream, StreamExt}; +use std::error::Error; + +use crate::{ + ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, + WispError, +}; + +/// WebSocketRead implementation for generic `Stream`s. +pub struct GenericWebSocketRead< + T: Stream> + Send + Unpin, + E: Error + Sync + Send + 'static, +>(T); + +impl> + Send + Unpin, E: Error + Sync + Send + 'static> + GenericWebSocketRead +{ + /// Create a new wrapper WebSocketRead implementation. + pub fn new(stream: T) -> Self { + Self(stream) + } + + /// Get the inner Stream from the wrapper. + pub fn into_inner(self) -> T { + self.0 + } +} + +#[async_trait] +impl> + Send + Unpin, E: Error + Sync + Send + 'static> + WebSocketRead for GenericWebSocketRead +{ + async fn wisp_read_frame( + &mut self, + _tx: &LockedWebSocketWrite, + ) -> Result, WispError> { + match self.0.next().await { + Some(data) => Ok(Frame::binary(Payload::Bytes( + data.map_err(|x| WispError::WsImplError(Box::new(x)))?, + ))), + None => Ok(Frame::close(Payload::Bytes(BytesMut::new()))), + } + } +} + +/// WebSocketWrite implementation for generic `Sink`s. +pub struct GenericWebSocketWrite< + T: Sink + Send + Unpin, + E: Error + Sync + Send + 'static, +>(T); + +impl + Send + Unpin, E: Error + Sync + Send + 'static> + GenericWebSocketWrite +{ + /// Create a new wrapper WebSocketWrite implementation. + pub fn new(stream: T) -> Self { + Self(stream) + } + + /// Get the inner Sink from the wrapper. + pub fn into_inner(self) -> T { + self.0 + } +} + +#[async_trait] +impl + Send + Unpin, E: Error + Sync + Send + 'static> WebSocketWrite + for GenericWebSocketWrite +{ + async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { + self.0 + .send(BytesMut::from(frame.payload).freeze()) + .await + .map_err(|x| WispError::WsImplError(Box::new(x))) + } + + async fn wisp_close(&mut self) -> Result<(), WispError> { + self.0 + .close() + .await + .map_err(|x| WispError::WsImplError(Box::new(x))) + } +} diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index a078f02..1ba04e8 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -8,6 +8,9 @@ pub mod extensions; #[cfg(feature = "fastwebsockets")] #[cfg_attr(docsrs, doc(cfg(feature = "fastwebsockets")))] mod fastwebsockets; +#[cfg(feature = "generic_stream")] +#[cfg_attr(docsrs, doc(cfg(feature = "generic_stream")))] +pub mod generic; mod packet; mod sink_unfold; mod stream; diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index fbf63af..9ca9342 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -167,7 +167,7 @@ pub trait WebSocketRead { } #[async_trait] -impl WebSocketRead for Box { +impl WebSocketRead for Box { async fn wisp_read_frame( &mut self, tx: &LockedWebSocketWrite, @@ -206,7 +206,7 @@ pub trait WebSocketWrite { } #[async_trait] -impl WebSocketWrite for Box { +impl WebSocketWrite for Box { async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { self.as_mut().wisp_write_frame(frame).await }