From 569789c2a02a4c2e2e6ac48bf2eaccf68d03ac2b Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Fri, 2 Aug 2024 23:01:47 -0700 Subject: [PATCH] expose close reasons --- Cargo.lock | 12 +++ client/src/io_stream.rs | 2 +- client/src/lib.rs | 12 ++- client/src/stream_provider.rs | 131 ++++++++++++++++++++++++++++---- server/src/handle/wsproxy.rs | 1 - wisp/Cargo.toml | 1 + wisp/src/lib.rs | 27 +++++-- wisp/src/packet.rs | 139 ++++++++++++++++++++++------------ wisp/src/stream.rs | 43 ++++++++++- 9 files changed, 294 insertions(+), 74 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a3dae43..6ce7d4b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -155,6 +155,17 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atomic_enum" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99e1aca718ea7b89985790c94aad72d77533063fe00bc497bb79a7c2dae6a661" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.3.0" @@ -2178,6 +2189,7 @@ name = "wisp-mux" version = "5.0.1" dependencies = [ "async-trait", + "atomic_enum", "bytes", "dashmap 5.5.3", "event-listener", diff --git a/client/src/io_stream.rs b/client/src/io_stream.rs index 86d7be6..d0df66e 100644 --- a/client/src/io_stream.rs +++ b/client/src/io_stream.rs @@ -111,7 +111,7 @@ pub struct EpoxyUdpStream { #[wasm_bindgen] impl EpoxyUdpStream { pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self { - let (mut rx, tx) = stream.into_split(); + let (mut rx, tx) = stream.into_inner().into_split(); let EpoxyHandlers { onopen, diff --git a/client/src/lib.rs b/client/src/lib.rs index 5c039b0..b9e7d64 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -30,6 +30,7 @@ use wasm_streams::ReadableStream; use web_sys::ResponseInit; #[cfg(feature = "full")] use websocket::EpoxyWebSocket; +use wisp_mux::CloseReason; #[cfg(feature = "full")] use wisp_mux::StreamType; @@ -50,6 +51,8 @@ pub enum EpoxyError { InvalidDnsName(#[from] futures_rustls::rustls::pki_types::InvalidDnsNameError), #[error("Wisp: {0:?} ({0})")] Wisp(#[from] wisp_mux::WispError), + #[error("Wisp server closed: {0}")] + WispCloseReason(wisp_mux::CloseReason), #[error("IO: {0:?} ({0})")] Io(#[from] std::io::Error), #[error("HTTP: {0:?} ({0})")] @@ -61,9 +64,6 @@ pub enum EpoxyError { #[error("HTTP ToStr: {0:?} ({0})")] ToStr(#[from] http::header::ToStrError), #[cfg(feature = "full")] - #[error("Getrandom: {0:?} ({0})")] - GetRandom(#[from] getrandom::Error), - #[cfg(feature = "full")] #[error("Fastwebsockets: {0:?} ({0})")] FastWebSockets(#[from] fastwebsockets::WebSocketError), @@ -135,6 +135,12 @@ impl From for EpoxyError { } } +impl From for EpoxyError { + fn from(value: CloseReason) -> Self { + EpoxyError::WispCloseReason(value) + } +} + #[derive(Debug)] enum EpoxyResponse { Success(Response), diff --git a/client/src/stream_provider.rs b/client/src/stream_provider.rs index ec6e394..6f232e6 100644 --- a/client/src/stream_provider.rs +++ b/client/src/stream_provider.rs @@ -1,4 +1,10 @@ -use std::{pin::Pin, sync::Arc, task::Poll}; +use std::{ + io::ErrorKind, + ops::{Deref, DerefMut}, + pin::Pin, + sync::Arc, + task::Poll, +}; use futures_rustls::{ rustls::{ClientConfig, RootCertStore}, @@ -16,7 +22,7 @@ use wasm_bindgen_futures::spawn_local; use webpki_roots::TLS_SERVER_ROOTS; use wisp_mux::{ extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder}, - ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, + ClientMux, MuxStreamAsyncRW, MuxStreamCloser, MuxStreamIo, StreamType, }; use crate::{console_log, ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError}; @@ -32,6 +38,94 @@ lazy_static! { }; } +pin_project! { + pub struct CloserWrapper { + #[pin] + pub inner: T, + pub closer: MuxStreamCloser, + } +} + +impl CloserWrapper { + pub fn new(inner: T, closer: MuxStreamCloser) -> Self { + Self { inner, closer } + } + + pub fn into_inner(self) -> T { + self.inner + } +} + +impl Deref for CloserWrapper { + type Target = T; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for CloserWrapper { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl AsyncRead for CloserWrapper { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.project().inner.poll_read(cx, buf) + } + + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &mut [std::io::IoSliceMut<'_>], + ) -> Poll> { + self.project().inner.poll_read_vectored(cx, bufs) + } +} + +impl AsyncWrite for CloserWrapper { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().inner.poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + self.project().inner.poll_write_vectored(cx, bufs) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().inner.poll_close(cx) + } +} + +impl From> for CloserWrapper { + fn from(value: CloserWrapper) -> Self { + let CloserWrapper { inner, closer } = value; + CloserWrapper::new(inner.into_asyncrw(), closer) + } +} + pub struct StreamProvider { wisp_url: String, @@ -42,8 +136,8 @@ pub struct StreamProvider { current_client: Arc>>, } -pub type ProviderUnencryptedStream = MuxStreamIo; -pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW; +pub type ProviderUnencryptedStream = CloserWrapper; +pub type ProviderUnencryptedAsyncRW = CloserWrapper; pub type ProviderTlsAsyncRW = TlsStream; pub type ProviderAsyncRW = Either; @@ -101,10 +195,9 @@ impl StreamProvider { Box::pin(async { let locked = self.current_client.lock().await; if let Some(mux) = locked.as_ref() { - Ok(mux - .client_new_stream(stream_type, host, port) - .await? - .into_io()) + let stream = mux.client_new_stream(stream_type, host, port).await?; + let closer = stream.get_close_handle(); + Ok(CloserWrapper::new(stream.into_io(), closer)) } else { self.create_client(locked).await?; self.get_stream(stream_type, host, port).await @@ -119,10 +212,7 @@ impl StreamProvider { host: String, port: u16, ) -> Result { - Ok(self - .get_stream(stream_type, host, port) - .await? - .into_asyncrw()) + Ok(self.get_stream(stream_type, host, port).await?.into()) } pub async fn get_tls_stream( @@ -134,7 +224,22 @@ impl StreamProvider { .get_asyncread(StreamType::Tcp, host.clone(), port) .await?; let connector = TlsConnector::from(CLIENT_CONFIG.clone()); - Ok(connector.connect(host.try_into()?, stream).await?.into()) + let ret = connector + .connect(host.try_into()?, stream) + .into_fallible() + .await; + match ret { + Ok(stream) => Ok(stream.into()), + Err((err, stream)) => { + if matches!(err.kind(), ErrorKind::UnexpectedEof) { + // maybe actually a wisp error? + if let Some(reason) = stream.closer.get_close_reason() { + return Err(reason.into()); + } + } + Err(err.into()) + } + } } } diff --git a/server/src/handle/wsproxy.rs b/server/src/handle/wsproxy.rs index ea85852..7a18d8e 100644 --- a/server/src/handle/wsproxy.rs +++ b/server/src/handle/wsproxy.rs @@ -2,7 +2,6 @@ use std::str::FromStr; use anyhow::Context; use fastwebsockets::{upgrade::UpgradeFut, CloseCode, FragmentCollector}; -use futures_util::io::Close; use tokio::{ io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, select, diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 26e4922..7bd7aab 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -10,6 +10,7 @@ edition = "2021" [dependencies] async-trait = "0.1.79" +atomic_enum = "0.3.0" bytes = "1.5.0" dashmap = { version = "5.5.3", features = ["inline"] } event-listener = "5.0.0" diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 8d54b3a..a078f02 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -157,9 +157,12 @@ impl std::error::Error for WispError {} struct MuxMapValue { stream: mpsc::Sender, stream_type: StreamType, + flow_control: Arc, flow_control_event: Arc, + is_closed: Arc, + close_reason: Arc, is_closed_event: Arc, } @@ -239,15 +242,20 @@ impl MuxInner { let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); let is_closed: Arc = AtomicBool::new(false).into(); + let close_reason: Arc = + AtomicCloseReason::new(CloseReason::Unknown).into(); let is_closed_event: Arc = Event::new().into(); Ok(( MuxMapValue { stream: ch_tx, stream_type, + flow_control: flow_control.clone(), flow_control_event: flow_control_event.clone(), + is_closed: is_closed.clone(), + close_reason: close_reason.clone(), is_closed_event: is_closed_event.clone(), }, MuxStream::new( @@ -259,6 +267,7 @@ impl MuxInner { tx, is_closed, is_closed_event, + close_reason, flow_control, flow_control_event, target_buffer_size, @@ -309,6 +318,9 @@ impl MuxInner { } WsEvent::Close(packet, channel) => { if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { + if let PacketType::Close(close) = packet.packet_type { + self.close_stream(packet.stream_id, close); + } let _ = channel.send(self.tx.write_frame(packet.into()).await); drop(stream.stream) } else { @@ -328,8 +340,11 @@ impl MuxInner { } } - fn close_stream(&self, packet: Packet) { - if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { + fn close_stream(&self, stream_id: u32, close_packet: ClosePacket) { + if let Some((_, stream)) = self.stream_map.remove(&stream_id) { + stream + .close_reason + .store(close_packet.reason, Ordering::Release); stream.is_closed.store(true, Ordering::Release); stream.is_closed_event.notify(usize::MAX); stream.flow_control.store(u32::MAX, Ordering::Release); @@ -410,11 +425,11 @@ impl MuxInner { } } } - Close(_) => { + Close(inner_packet) => { if packet.stream_id == 0 { break Ok(()); } - self.close_stream(packet) + self.close_stream(packet.stream_id, inner_packet) } } } @@ -472,11 +487,11 @@ impl MuxInner { } } } - Close(_) => { + Close(inner_packet) => { if packet.stream_id == 0 { break Ok(()); } - self.close_stream(packet) + self.close_stream(packet.stream_id, inner_packet); } } } diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 85a82e7..ce857d2 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -38,60 +38,101 @@ impl From for u8 { } } -/// Close reason. -/// -/// See [the -/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#clientserver-close-reasons) -#[derive(Debug, PartialEq, Copy, Clone)] -pub enum CloseReason { - /// Reason unspecified or unknown. - Unknown = 0x01, - /// Voluntary stream closure. - Voluntary = 0x02, - /// Unexpected stream closure due to a network error. - Unexpected = 0x03, - /// Incompatible extensions. Only used during the handshake. - IncompatibleExtensions = 0x04, - /// Stream creation failed due to invalid information. - ServerStreamInvalidInfo = 0x41, - /// Stream creation failed due to an unreachable destination host. - ServerStreamUnreachable = 0x42, - /// Stream creation timed out due to the destination server not responding. - ServerStreamConnectionTimedOut = 0x43, - /// Stream creation failed due to the destination server refusing the connection. - ServerStreamConnectionRefused = 0x44, - /// TCP data transfer timed out. - ServerStreamTimedOut = 0x47, - /// Stream destination address/domain is intentionally blocked by the proxy server. - ServerStreamBlockedAddress = 0x48, - /// Connection throttled by the server. - ServerStreamThrottled = 0x49, - /// The client has encountered an unexpected error. - ClientUnexpected = 0x81, -} +mod close { + use std::fmt::Display; -impl TryFrom for CloseReason { - type Error = WispError; - fn try_from(close_reason: u8) -> Result { - use CloseReason as R; - match close_reason { - 0x01 => Ok(R::Unknown), - 0x02 => Ok(R::Voluntary), - 0x03 => Ok(R::Unexpected), - 0x04 => Ok(R::IncompatibleExtensions), - 0x41 => Ok(R::ServerStreamInvalidInfo), - 0x42 => Ok(R::ServerStreamUnreachable), - 0x43 => Ok(R::ServerStreamConnectionTimedOut), - 0x44 => Ok(R::ServerStreamConnectionRefused), - 0x47 => Ok(R::ServerStreamTimedOut), - 0x48 => Ok(R::ServerStreamBlockedAddress), - 0x49 => Ok(R::ServerStreamThrottled), - 0x81 => Ok(R::ClientUnexpected), - _ => Err(Self::Error::InvalidCloseReason), + use atomic_enum::atomic_enum; + + use crate::WispError; + + /// Close reason. + /// + /// See [the + /// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#clientserver-close-reasons) + #[derive(PartialEq)] + #[repr(u8)] + #[atomic_enum] + pub enum CloseReason { + /// Reason unspecified or unknown. + Unknown = 0x01, + /// Voluntary stream closure. + Voluntary = 0x02, + /// Unexpected stream closure due to a network error. + Unexpected = 0x03, + /// Incompatible extensions. Only used during the handshake. + IncompatibleExtensions = 0x04, + /// Stream creation failed due to invalid information. + ServerStreamInvalidInfo = 0x41, + /// Stream creation failed due to an unreachable destination host. + ServerStreamUnreachable = 0x42, + /// Stream creation timed out due to the destination server not responding. + ServerStreamConnectionTimedOut = 0x43, + /// Stream creation failed due to the destination server refusing the connection. + ServerStreamConnectionRefused = 0x44, + /// TCP data transfer timed out. + ServerStreamTimedOut = 0x47, + /// Stream destination address/domain is intentionally blocked by the proxy server. + ServerStreamBlockedAddress = 0x48, + /// Connection throttled by the server. + ServerStreamThrottled = 0x49, + /// The client has encountered an unexpected error. + ClientUnexpected = 0x81, + } + + impl TryFrom for CloseReason { + type Error = WispError; + fn try_from(close_reason: u8) -> Result { + use CloseReason as R; + match close_reason { + 0x01 => Ok(R::Unknown), + 0x02 => Ok(R::Voluntary), + 0x03 => Ok(R::Unexpected), + 0x04 => Ok(R::IncompatibleExtensions), + 0x41 => Ok(R::ServerStreamInvalidInfo), + 0x42 => Ok(R::ServerStreamUnreachable), + 0x43 => Ok(R::ServerStreamConnectionTimedOut), + 0x44 => Ok(R::ServerStreamConnectionRefused), + 0x47 => Ok(R::ServerStreamTimedOut), + 0x48 => Ok(R::ServerStreamBlockedAddress), + 0x49 => Ok(R::ServerStreamThrottled), + 0x81 => Ok(R::ClientUnexpected), + _ => Err(Self::Error::InvalidCloseReason), + } + } + } + + impl Display for CloseReason { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use CloseReason as C; + write!( + f, + "{}", + match self { + C::Unknown => "Unknown close reason", + C::Voluntary => "Voluntarily closed", + C::Unexpected => "Unexpectedly closed", + C::IncompatibleExtensions => "Incompatible protocol extensions", + C::ServerStreamInvalidInfo => + "Stream creation failed due to invalid information", + C::ServerStreamUnreachable => + "Stream creation failed due to an unreachable destination", + C::ServerStreamConnectionTimedOut => + "Stream creation failed due to destination not responding", + C::ServerStreamConnectionRefused => + "Stream creation failed due to destination refusing connection", + C::ServerStreamTimedOut => "TCP timed out", + C::ServerStreamBlockedAddress => "Destination address is blocked", + C::ServerStreamThrottled => "Throttled", + C::ClientUnexpected => "Client encountered unexpected error", + } + ) } } } +pub(crate) use close::AtomicCloseReason; +pub use close::CloseReason; + trait Encode { fn encode(self, bytes: &mut BytesMut); } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 7233f8c..eb7c045 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -1,7 +1,7 @@ use crate::{ sink_unfold, ws::{Frame, LockedWebSocketWrite, Payload}, - CloseReason, Packet, Role, StreamType, WispError, + AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError, }; use bytes::{BufMut, Bytes, BytesMut}; @@ -40,11 +40,16 @@ pub struct MuxStreamRead { pub stream_id: u32, /// Type of the stream. pub stream_type: StreamType, + role: Role, + tx: LockedWebSocketWrite, rx: mpsc::Receiver, + is_closed: Arc, is_closed_event: Arc, + close_reason: Arc, + flow_control: Arc, flow_control_read: AtomicU32, target_flow_control: u32, @@ -91,6 +96,15 @@ impl MuxStreamRead { rx: self.into_inner_stream(), } } + + /// Get the stream's close reason, if it was closed. + pub fn get_close_reason(&self) -> Option { + if self.is_closed.load(Ordering::Acquire) { + Some(self.close_reason.load(Ordering::Acquire)) + } else { + None + } + } } /// Write side of a multiplexor stream. @@ -99,10 +113,14 @@ pub struct MuxStreamWrite { pub stream_id: u32, /// Type of the stream. pub stream_type: StreamType, + role: Role, mux_tx: mpsc::Sender, tx: LockedWebSocketWrite, + is_closed: Arc, + close_reason: Arc, + continue_recieved: Arc, flow_control: Arc, } @@ -165,6 +183,7 @@ impl MuxStreamWrite { stream_id: self.stream_id, close_channel: self.mux_tx.clone(), is_closed: self.is_closed.clone(), + close_reason: self.close_reason.clone(), } } @@ -197,6 +216,15 @@ impl MuxStreamWrite { Ok(()) } + /// Get the stream's close reason, if it was closed. + pub fn get_close_reason(&self) -> Option { + if self.is_closed.load(Ordering::Acquire) { + Some(self.close_reason.load(Ordering::Acquire)) + } else { + None + } + } + pub(crate) fn into_inner_sink( self, ) -> Pin, Error = WispError> + Send>> { @@ -255,6 +283,7 @@ impl MuxStream { tx: LockedWebSocketWrite, is_closed: Arc, is_closed_event: Arc, + close_reason: Arc, flow_control: Arc, continue_recieved: Arc, target_flow_control: u32, @@ -269,6 +298,7 @@ impl MuxStream { rx, is_closed: is_closed.clone(), is_closed_event: is_closed_event.clone(), + close_reason: close_reason.clone(), flow_control: flow_control.clone(), flow_control_read: AtomicU32::new(0), target_flow_control, @@ -280,6 +310,7 @@ impl MuxStream { mux_tx, tx, is_closed: is_closed.clone(), + close_reason: close_reason.clone(), flow_control: flow_control.clone(), continue_recieved: continue_recieved.clone(), }, @@ -347,6 +378,7 @@ pub struct MuxStreamCloser { pub stream_id: u32, close_channel: mpsc::Sender, is_closed: Arc, + close_reason: Arc, } impl MuxStreamCloser { @@ -369,6 +401,15 @@ impl MuxStreamCloser { Ok(()) } + + /// Get the stream's close reason, if it was closed. + pub fn get_close_reason(&self) -> Option { + if self.is_closed.load(Ordering::Acquire) { + Some(self.close_reason.load(Ordering::Acquire)) + } else { + None + } + } } /// Stream for sending arbitrary protocol extension packets.