diff --git a/Cargo.lock b/Cargo.lock index 5c9e0b1..fc2ffc8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1860,7 +1860,7 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "wisp-mux" -version = "1.0.0" +version = "1.1.0" dependencies = [ "async_io_stream", "bytes", diff --git a/client/demo.js b/client/demo.js index 5b11bd1..89b0bd5 100644 --- a/client/demo.js +++ b/client/demo.js @@ -160,6 +160,7 @@ onmessage = async (msg) => { "alicesworld.tech:443", ); await ws.send("GET / HTTP 1.1\r\nHost: alicesworld.tech\r\nConnection: close\r\n\r\n"); + await ws.close(); } else { let resp = await epoxy_client.fetch("https://httpbin.org/get"); console.warn(resp, Object.fromEntries(resp.headers)); diff --git a/server/src/main.rs b/server/src/main.rs index d3c5d94..31dc7dd 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -17,7 +17,9 @@ use tokio::net::{TcpListener, TcpStream, UdpSocket}; use tokio_native_tls::{native_tls, TlsAcceptor}; use tokio_util::codec::{BytesCodec, Framed}; -use wisp_mux::{ws, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, MuxEvent}; +use wisp_mux::{ + ws, CloseReason, ConnectPacket, MuxEvent, MuxStream, ServerMux, StreamType, WispError, +}; type HttpBody = http_body_util::Full; @@ -192,12 +194,12 @@ async fn accept_ws( let close_ok = stream.get_close_handle(); let _ = handle_mux(packet, stream) .or_else(|err| async move { - let _ = close_err.close(0x03).await; + let _ = close_err.close(CloseReason::Unexpected).await; Err(err) }) .and_then(|should_send| async move { if should_send { - close_ok.close(0x02).await + close_ok.close(CloseReason::Voluntary).await } else { Ok(()) } @@ -222,7 +224,9 @@ async fn accept_wsproxy( match hyper::Uri::try_from(incoming_uri.clone()) { Ok(_) => (), Err(err) => { - ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"invalid uri")).await?; + ws_stream + .write_frame(Frame::close(CloseCode::Away.into(), b"invalid uri")) + .await?; return Err(Box::new(err)); } } diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 3241aac..8a9240d 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "wisp-mux" -version = "1.0.0" +version = "1.1.0" license = "AGPL-3.0-only" description = "A library for easily creating Wisp servers and clients." homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp" @@ -28,3 +28,5 @@ ws_stream_wasm = ["dep:ws_stream_wasm"] tokio_io = ["async_io_stream/tokio_io"] hyper_tower = ["dep:tower-service", "dep:hyper", "dep:tokio", "dep:hyper-util-wasm"] +[package.metadata.docs.rs] +features = ["hyper_tower"] diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index f3e153e..3c46837 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -6,6 +6,7 @@ #[cfg(feature = "fastwebsockets")] mod fastwebsockets; +mod sink_unfold; mod packet; mod stream; #[cfg(feature = "hyper_tower")] @@ -49,6 +50,8 @@ pub enum WispError { InvalidStreamType, /// The stream had an invalid ID. InvalidStreamId, + /// The close packet had an invalid reason. + InvalidCloseReason, /// The URI recieved was invalid. InvalidUri, /// The URI recieved had no host. @@ -89,6 +92,7 @@ impl std::fmt::Display for WispError { InvalidPacketType => write!(f, "Invalid packet type"), InvalidStreamType => write!(f, "Invalid stream type"), InvalidStreamId => write!(f, "Invalid stream id"), + InvalidCloseReason => write!(f, "Invalid close reason"), InvalidUri => write!(f, "Invalid URI"), UriHasNoHost => write!(f, "URI has no host"), UriHasNoPort => write!(f, "URI has no port"), @@ -132,7 +136,7 @@ impl ServerMuxInner { x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x }; self.stream_map.lock().await.iter().for_each(|x| { - let _ = x.1.unbounded_send(MuxEvent::Close(ClosePacket::new(0x01))); + let _ = x.1.unbounded_send(MuxEvent::Close(ClosePacket::new(CloseReason::Unknown))); }); ret } diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index e52a8be..513e037 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -23,6 +23,57 @@ impl TryFrom for StreamType { } } +/// Close reason. +/// +/// See [the +/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#clientserver-close-reasons) +#[derive(Debug, PartialEq, Copy, Clone)] +pub enum CloseReason { + /// Reason unspecified or unknown. + Unknown = 0x01, + /// Voluntary stream closure. + Voluntary = 0x02, + /// Unexpected stream closure due to a network error. + Unexpected = 0x03, + /// Stream creation failed due to invalid information. + ServerStreamInvalidInfo = 0x41, + /// Stream creation failed due to an unreachable destination host. + ServerStreamUnreachable = 0x42, + /// Stream creation timed out due to the destination server not responding. + ServerStreamConnectionTimedOut = 0x43, + /// Stream creation failed due to the destination server refusing the connection. + ServerStreamConnectionRefused = 0x44, + /// TCP data transfer timed out. + ServerStreamTimedOut = 0x47, + /// Stream destination address/domain is intentionally blocked by the proxy server. + ServerStreamBlockedAddress = 0x48, + /// Connection throttled by the server. + ServerStreamThrottled = 0x49, + /// The client has encountered an unexpected error. + ClientUnexpected = 0x81, +} + +impl TryFrom for CloseReason { + type Error = WispError; + fn try_from(stream_type: u8) -> Result { + use CloseReason::*; + match stream_type { + 0x01 => Ok(Unknown), + 0x02 => Ok(Voluntary), + 0x03 => Ok(Unexpected), + 0x41 => Ok(ServerStreamInvalidInfo), + 0x42 => Ok(ServerStreamUnreachable), + 0x43 => Ok(ServerStreamConnectionTimedOut), + 0x44 => Ok(ServerStreamConnectionRefused), + 0x47 => Ok(ServerStreamTimedOut), + 0x48 => Ok(ServerStreamBlockedAddress), + 0x49 => Ok(ServerStreamThrottled), + 0x81 => Ok(ClientUnexpected), + _ => Err(Self::Error::InvalidStreamType), + } + } +} + /// Packet used to create a new stream. /// /// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---connect). @@ -118,15 +169,12 @@ impl From for Vec { #[derive(Debug, Copy, Clone)] pub struct ClosePacket { /// The close reason. - /// - /// See [the - /// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#clientserver-close-reasons). - pub reason: u8, + pub reason: CloseReason, } impl ClosePacket { /// Create a new close packet. - pub fn new(reason: u8) -> Self { + pub fn new(reason: CloseReason) -> Self { Self { reason } } } @@ -138,7 +186,7 @@ impl TryFrom for ClosePacket { return Err(Self::Error::PacketTooSmall); } Ok(Self { - reason: bytes.get_u8(), + reason: bytes.get_u8().try_into()?, }) } } @@ -146,7 +194,7 @@ impl TryFrom for ClosePacket { impl From for Vec { fn from(packet: ClosePacket) -> Self { let mut encoded = Self::with_capacity(1); - encoded.put_u8(packet.reason); + encoded.put_u8(packet.reason as u8); encoded } } @@ -240,7 +288,7 @@ impl Packet { } /// Create a new close packet. - pub fn new_close(stream_id: u32, reason: u8) -> Self { + pub fn new_close(stream_id: u32, reason: CloseReason) -> Self { Self { stream_id, packet: PacketType::Close(ClosePacket::new(reason)), diff --git a/wisp/src/sink_unfold.rs b/wisp/src/sink_unfold.rs new file mode 100644 index 0000000..ee9e337 --- /dev/null +++ b/wisp/src/sink_unfold.rs @@ -0,0 +1,109 @@ +//! futures sink unfold with a close function +use core::{future::Future, pin::Pin}; +use futures::ready; +use futures::task::{Context, Poll}; +use futures::Sink; +use pin_project_lite::pin_project; + +pin_project! { + /// UnfoldState used for stream and sink unfolds + #[project = UnfoldStateProj] + #[project_replace = UnfoldStateProjReplace] + #[derive(Debug)] + pub(crate) enum UnfoldState { + Value { + value: T, + }, + Future { + #[pin] + future: Fut, + }, + Empty, + } +} + +impl UnfoldState { + pub(crate) fn project_future(self: Pin<&mut Self>) -> Option> { + match self.project() { + UnfoldStateProj::Future { future } => Some(future), + _ => None, + } + } + + pub(crate) fn take_value(self: Pin<&mut Self>) -> Option { + match &*self { + Self::Value { .. } => match self.project_replace(Self::Empty) { + UnfoldStateProjReplace::Value { value } => Some(value), + _ => unreachable!(), + }, + _ => None, + } + } +} + +pin_project! { + /// Sink for the [`unfold`] function. + #[derive(Debug)] + #[must_use = "sinks do nothing unless polled"] + pub struct Unfold { + function: F, + close_function: FC, + #[pin] + state: UnfoldState, + } +} + +pub(crate) fn unfold(init: T, function: F, close_function: FC) -> Unfold +where + F: FnMut(T, Item) -> R, + R: Future>, + FC: Fn() -> Result<(), E>, +{ + Unfold { function, close_function, state: UnfoldState::Value { value: init } } +} + +impl Sink for Unfold +where + F: FnMut(T, Item) -> R, + R: Future>, + FC: Fn() -> Result<(), E>, +{ + type Error = E; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } + + fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { + let mut this = self.project(); + let future = match this.state.as_mut().take_value() { + Some(value) => (this.function)(value, item), + None => panic!("start_send called without poll_ready being called first"), + }; + this.state.set(UnfoldState::Future { future }); + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + Poll::Ready(if let Some(future) = this.state.as_mut().project_future() { + match ready!(future.poll(cx)) { + Ok(state) => { + this.state.set(UnfoldState::Value { value: state }); + Ok(()) + } + Err(err) => { + this.state.set(UnfoldState::Empty); + Err(err) + } + } + } else { + Ok(()) + }) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + Poll::Ready((self.close_function)()) + } +} diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 1c249a2..84ac0ce 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -3,7 +3,7 @@ use bytes::Bytes; use event_listener::Event; use futures::{ channel::{mpsc, oneshot}, - sink, stream, + stream, task::{Context, Poll}, Sink, Stream, StreamExt, }; @@ -25,7 +25,7 @@ pub enum MuxEvent { } pub(crate) enum WsEvent { - Close(u32, u8, oneshot::Sender>), + Close(u32, crate::CloseReason, oneshot::Sender>), } /// Read side of a multiplexor stream. @@ -143,7 +143,7 @@ impl MuxStreamWrite { } /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> { + pub async fn close(&self, reason: crate::CloseReason) -> Result<(), crate::WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(crate::WispError::StreamAlreadyClosed); } @@ -159,9 +159,12 @@ impl MuxStreamWrite { } pub(crate) fn into_sink(self) -> Pin + Send>> { - Box::pin(sink::unfold(self, |tx, data| async move { + let handle = self.get_close_handle(); + Box::pin(crate::sink_unfold::unfold(self, |tx, data| async move { tx.write(data).await?; Ok(tx) + }, move || { + handle.close_sync(crate::CloseReason::Unknown) })) } } @@ -171,7 +174,7 @@ impl Drop for MuxStreamWrite { let (tx, _) = oneshot::channel::>(); let _ = self .close_channel - .unbounded_send(WsEvent::Close(self.stream_id, 0x01, tx)); + .unbounded_send(WsEvent::Close(self.stream_id, crate::CloseReason::Unknown, tx)); } } @@ -248,7 +251,7 @@ impl MuxStream { } /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> { + pub async fn close(&self, reason: crate::CloseReason) -> Result<(), crate::WispError> { self.tx.close(reason).await } @@ -267,6 +270,7 @@ impl MuxStream { } /// Close handle for a multiplexor stream. +#[derive(Clone)] pub struct MuxStreamCloser { /// ID of the stream. pub stream_id: u32, @@ -276,7 +280,7 @@ pub struct MuxStreamCloser { impl MuxStreamCloser { /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> { + pub async fn close(&self, reason: crate::CloseReason) -> Result<(), crate::WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(crate::WispError::StreamAlreadyClosed); } @@ -289,6 +293,19 @@ impl MuxStreamCloser { self.is_closed.store(true, Ordering::Release); Ok(()) } + + /// Close the stream. This function does not check if it was actually closed. + pub(crate) fn close_sync(&self, reason: crate::CloseReason) -> Result<(), crate::WispError> { + if self.is_closed.load(Ordering::Acquire) { + return Err(crate::WispError::StreamAlreadyClosed); + } + let (tx, _) = oneshot::channel::>(); + self.close_channel + .unbounded_send(WsEvent::Close(self.stream_id, reason, tx)) + .map_err(|x| crate::WispError::Other(Box::new(x)))?; + self.is_closed.store(true, Ordering::Release); + Ok(()) + } } pin_project! {