expose close reasons

This commit is contained in:
Toshit Chawda 2024-08-02 23:01:47 -07:00
parent 8cbab94955
commit 569789c2a0
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
9 changed files with 294 additions and 74 deletions

12
Cargo.lock generated
View file

@ -155,6 +155,17 @@ version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "atomic_enum"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99e1aca718ea7b89985790c94aad72d77533063fe00bc497bb79a7c2dae6a661"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "autocfg"
version = "1.3.0"
@ -2178,6 +2189,7 @@ name = "wisp-mux"
version = "5.0.1"
dependencies = [
"async-trait",
"atomic_enum",
"bytes",
"dashmap 5.5.3",
"event-listener",

View file

@ -111,7 +111,7 @@ pub struct EpoxyUdpStream {
#[wasm_bindgen]
impl EpoxyUdpStream {
pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self {
let (mut rx, tx) = stream.into_split();
let (mut rx, tx) = stream.into_inner().into_split();
let EpoxyHandlers {
onopen,

View file

@ -30,6 +30,7 @@ use wasm_streams::ReadableStream;
use web_sys::ResponseInit;
#[cfg(feature = "full")]
use websocket::EpoxyWebSocket;
use wisp_mux::CloseReason;
#[cfg(feature = "full")]
use wisp_mux::StreamType;
@ -50,6 +51,8 @@ pub enum EpoxyError {
InvalidDnsName(#[from] futures_rustls::rustls::pki_types::InvalidDnsNameError),
#[error("Wisp: {0:?} ({0})")]
Wisp(#[from] wisp_mux::WispError),
#[error("Wisp server closed: {0}")]
WispCloseReason(wisp_mux::CloseReason),
#[error("IO: {0:?} ({0})")]
Io(#[from] std::io::Error),
#[error("HTTP: {0:?} ({0})")]
@ -61,9 +64,6 @@ pub enum EpoxyError {
#[error("HTTP ToStr: {0:?} ({0})")]
ToStr(#[from] http::header::ToStrError),
#[cfg(feature = "full")]
#[error("Getrandom: {0:?} ({0})")]
GetRandom(#[from] getrandom::Error),
#[cfg(feature = "full")]
#[error("Fastwebsockets: {0:?} ({0})")]
FastWebSockets(#[from] fastwebsockets::WebSocketError),
@ -135,6 +135,12 @@ impl From<InvalidMethod> for EpoxyError {
}
}
impl From<CloseReason> for EpoxyError {
fn from(value: CloseReason) -> Self {
EpoxyError::WispCloseReason(value)
}
}
#[derive(Debug)]
enum EpoxyResponse {
Success(Response<Incoming>),

View file

@ -1,4 +1,10 @@
use std::{pin::Pin, sync::Arc, task::Poll};
use std::{
io::ErrorKind,
ops::{Deref, DerefMut},
pin::Pin,
sync::Arc,
task::Poll,
};
use futures_rustls::{
rustls::{ClientConfig, RootCertStore},
@ -16,7 +22,7 @@ use wasm_bindgen_futures::spawn_local;
use webpki_roots::TLS_SERVER_ROOTS;
use wisp_mux::{
extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder},
ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType,
ClientMux, MuxStreamAsyncRW, MuxStreamCloser, MuxStreamIo, StreamType,
};
use crate::{console_log, ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError};
@ -32,6 +38,94 @@ lazy_static! {
};
}
pin_project! {
pub struct CloserWrapper<T> {
#[pin]
pub inner: T,
pub closer: MuxStreamCloser,
}
}
impl<T> CloserWrapper<T> {
pub fn new(inner: T, closer: MuxStreamCloser) -> Self {
Self { inner, closer }
}
pub fn into_inner(self) -> T {
self.inner
}
}
impl<T> Deref for CloserWrapper<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<T> DerefMut for CloserWrapper<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl<T: AsyncRead> AsyncRead for CloserWrapper<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_read(cx, buf)
}
fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &mut [std::io::IoSliceMut<'_>],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_read_vectored(cx, bufs)
}
}
impl<T: AsyncWrite> AsyncWrite for CloserWrapper<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_write(cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_write_vectored(cx, bufs)
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.project().inner.poll_close(cx)
}
}
impl From<CloserWrapper<MuxStreamIo>> for CloserWrapper<MuxStreamAsyncRW> {
fn from(value: CloserWrapper<MuxStreamIo>) -> Self {
let CloserWrapper { inner, closer } = value;
CloserWrapper::new(inner.into_asyncrw(), closer)
}
}
pub struct StreamProvider {
wisp_url: String,
@ -42,8 +136,8 @@ pub struct StreamProvider {
current_client: Arc<Mutex<Option<ClientMux>>>,
}
pub type ProviderUnencryptedStream = MuxStreamIo;
pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW;
pub type ProviderUnencryptedStream = CloserWrapper<MuxStreamIo>;
pub type ProviderUnencryptedAsyncRW = CloserWrapper<MuxStreamAsyncRW>;
pub type ProviderTlsAsyncRW = TlsStream<ProviderUnencryptedAsyncRW>;
pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>;
@ -101,10 +195,9 @@ impl StreamProvider {
Box::pin(async {
let locked = self.current_client.lock().await;
if let Some(mux) = locked.as_ref() {
Ok(mux
.client_new_stream(stream_type, host, port)
.await?
.into_io())
let stream = mux.client_new_stream(stream_type, host, port).await?;
let closer = stream.get_close_handle();
Ok(CloserWrapper::new(stream.into_io(), closer))
} else {
self.create_client(locked).await?;
self.get_stream(stream_type, host, port).await
@ -119,10 +212,7 @@ impl StreamProvider {
host: String,
port: u16,
) -> Result<ProviderUnencryptedAsyncRW, EpoxyError> {
Ok(self
.get_stream(stream_type, host, port)
.await?
.into_asyncrw())
Ok(self.get_stream(stream_type, host, port).await?.into())
}
pub async fn get_tls_stream(
@ -134,7 +224,22 @@ impl StreamProvider {
.get_asyncread(StreamType::Tcp, host.clone(), port)
.await?;
let connector = TlsConnector::from(CLIENT_CONFIG.clone());
Ok(connector.connect(host.try_into()?, stream).await?.into())
let ret = connector
.connect(host.try_into()?, stream)
.into_fallible()
.await;
match ret {
Ok(stream) => Ok(stream.into()),
Err((err, stream)) => {
if matches!(err.kind(), ErrorKind::UnexpectedEof) {
// maybe actually a wisp error?
if let Some(reason) = stream.closer.get_close_reason() {
return Err(reason.into());
}
}
Err(err.into())
}
}
}
}

View file

@ -2,7 +2,6 @@ use std::str::FromStr;
use anyhow::Context;
use fastwebsockets::{upgrade::UpgradeFut, CloseCode, FragmentCollector};
use futures_util::io::Close;
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
select,

View file

@ -10,6 +10,7 @@ edition = "2021"
[dependencies]
async-trait = "0.1.79"
atomic_enum = "0.3.0"
bytes = "1.5.0"
dashmap = { version = "5.5.3", features = ["inline"] }
event-listener = "5.0.0"

View file

@ -157,9 +157,12 @@ impl std::error::Error for WispError {}
struct MuxMapValue {
stream: mpsc::Sender<Bytes>,
stream_type: StreamType,
flow_control: Arc<AtomicU32>,
flow_control_event: Arc<Event>,
is_closed: Arc<AtomicBool>,
close_reason: Arc<AtomicCloseReason>,
is_closed_event: Arc<Event>,
}
@ -239,15 +242,20 @@ impl MuxInner {
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
let close_reason: Arc<AtomicCloseReason> =
AtomicCloseReason::new(CloseReason::Unknown).into();
let is_closed_event: Arc<Event> = Event::new().into();
Ok((
MuxMapValue {
stream: ch_tx,
stream_type,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
is_closed: is_closed.clone(),
close_reason: close_reason.clone(),
is_closed_event: is_closed_event.clone(),
},
MuxStream::new(
@ -259,6 +267,7 @@ impl MuxInner {
tx,
is_closed,
is_closed_event,
close_reason,
flow_control,
flow_control_event,
target_buffer_size,
@ -309,6 +318,9 @@ impl MuxInner {
}
WsEvent::Close(packet, channel) => {
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
if let PacketType::Close(close) = packet.packet_type {
self.close_stream(packet.stream_id, close);
}
let _ = channel.send(self.tx.write_frame(packet.into()).await);
drop(stream.stream)
} else {
@ -328,8 +340,11 @@ impl MuxInner {
}
}
fn close_stream(&self, packet: Packet) {
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
fn close_stream(&self, stream_id: u32, close_packet: ClosePacket) {
if let Some((_, stream)) = self.stream_map.remove(&stream_id) {
stream
.close_reason
.store(close_packet.reason, Ordering::Release);
stream.is_closed.store(true, Ordering::Release);
stream.is_closed_event.notify(usize::MAX);
stream.flow_control.store(u32::MAX, Ordering::Release);
@ -410,11 +425,11 @@ impl MuxInner {
}
}
}
Close(_) => {
Close(inner_packet) => {
if packet.stream_id == 0 {
break Ok(());
}
self.close_stream(packet)
self.close_stream(packet.stream_id, inner_packet)
}
}
}
@ -472,11 +487,11 @@ impl MuxInner {
}
}
}
Close(_) => {
Close(inner_packet) => {
if packet.stream_id == 0 {
break Ok(());
}
self.close_stream(packet)
self.close_stream(packet.stream_id, inner_packet);
}
}
}

View file

@ -38,11 +38,20 @@ impl From<StreamType> for u8 {
}
}
mod close {
use std::fmt::Display;
use atomic_enum::atomic_enum;
use crate::WispError;
/// Close reason.
///
/// See [the
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#clientserver-close-reasons)
#[derive(Debug, PartialEq, Copy, Clone)]
#[derive(PartialEq)]
#[repr(u8)]
#[atomic_enum]
pub enum CloseReason {
/// Reason unspecified or unknown.
Unknown = 0x01,
@ -92,6 +101,38 @@ impl TryFrom<u8> for CloseReason {
}
}
impl Display for CloseReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use CloseReason as C;
write!(
f,
"{}",
match self {
C::Unknown => "Unknown close reason",
C::Voluntary => "Voluntarily closed",
C::Unexpected => "Unexpectedly closed",
C::IncompatibleExtensions => "Incompatible protocol extensions",
C::ServerStreamInvalidInfo =>
"Stream creation failed due to invalid information",
C::ServerStreamUnreachable =>
"Stream creation failed due to an unreachable destination",
C::ServerStreamConnectionTimedOut =>
"Stream creation failed due to destination not responding",
C::ServerStreamConnectionRefused =>
"Stream creation failed due to destination refusing connection",
C::ServerStreamTimedOut => "TCP timed out",
C::ServerStreamBlockedAddress => "Destination address is blocked",
C::ServerStreamThrottled => "Throttled",
C::ClientUnexpected => "Client encountered unexpected error",
}
)
}
}
}
pub(crate) use close::AtomicCloseReason;
pub use close::CloseReason;
trait Encode {
fn encode(self, bytes: &mut BytesMut);
}

View file

@ -1,7 +1,7 @@
use crate::{
sink_unfold,
ws::{Frame, LockedWebSocketWrite, Payload},
CloseReason, Packet, Role, StreamType, WispError,
AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError,
};
use bytes::{BufMut, Bytes, BytesMut};
@ -40,11 +40,16 @@ pub struct MuxStreamRead {
pub stream_id: u32,
/// Type of the stream.
pub stream_type: StreamType,
role: Role,
tx: LockedWebSocketWrite,
rx: mpsc::Receiver<Bytes>,
is_closed: Arc<AtomicBool>,
is_closed_event: Arc<Event>,
close_reason: Arc<AtomicCloseReason>,
flow_control: Arc<AtomicU32>,
flow_control_read: AtomicU32,
target_flow_control: u32,
@ -91,6 +96,15 @@ impl MuxStreamRead {
rx: self.into_inner_stream(),
}
}
/// Get the stream's close reason, if it was closed.
pub fn get_close_reason(&self) -> Option<CloseReason> {
if self.is_closed.load(Ordering::Acquire) {
Some(self.close_reason.load(Ordering::Acquire))
} else {
None
}
}
}
/// Write side of a multiplexor stream.
@ -99,10 +113,14 @@ pub struct MuxStreamWrite {
pub stream_id: u32,
/// Type of the stream.
pub stream_type: StreamType,
role: Role,
mux_tx: mpsc::Sender<WsEvent>,
tx: LockedWebSocketWrite,
is_closed: Arc<AtomicBool>,
close_reason: Arc<AtomicCloseReason>,
continue_recieved: Arc<Event>,
flow_control: Arc<AtomicU32>,
}
@ -165,6 +183,7 @@ impl MuxStreamWrite {
stream_id: self.stream_id,
close_channel: self.mux_tx.clone(),
is_closed: self.is_closed.clone(),
close_reason: self.close_reason.clone(),
}
}
@ -197,6 +216,15 @@ impl MuxStreamWrite {
Ok(())
}
/// Get the stream's close reason, if it was closed.
pub fn get_close_reason(&self) -> Option<CloseReason> {
if self.is_closed.load(Ordering::Acquire) {
Some(self.close_reason.load(Ordering::Acquire))
} else {
None
}
}
pub(crate) fn into_inner_sink(
self,
) -> Pin<Box<dyn Sink<Payload<'static>, Error = WispError> + Send>> {
@ -255,6 +283,7 @@ impl MuxStream {
tx: LockedWebSocketWrite,
is_closed: Arc<AtomicBool>,
is_closed_event: Arc<Event>,
close_reason: Arc<AtomicCloseReason>,
flow_control: Arc<AtomicU32>,
continue_recieved: Arc<Event>,
target_flow_control: u32,
@ -269,6 +298,7 @@ impl MuxStream {
rx,
is_closed: is_closed.clone(),
is_closed_event: is_closed_event.clone(),
close_reason: close_reason.clone(),
flow_control: flow_control.clone(),
flow_control_read: AtomicU32::new(0),
target_flow_control,
@ -280,6 +310,7 @@ impl MuxStream {
mux_tx,
tx,
is_closed: is_closed.clone(),
close_reason: close_reason.clone(),
flow_control: flow_control.clone(),
continue_recieved: continue_recieved.clone(),
},
@ -347,6 +378,7 @@ pub struct MuxStreamCloser {
pub stream_id: u32,
close_channel: mpsc::Sender<WsEvent>,
is_closed: Arc<AtomicBool>,
close_reason: Arc<AtomicCloseReason>,
}
impl MuxStreamCloser {
@ -369,6 +401,15 @@ impl MuxStreamCloser {
Ok(())
}
/// Get the stream's close reason, if it was closed.
pub fn get_close_reason(&self) -> Option<CloseReason> {
if self.is_closed.load(Ordering::Acquire) {
Some(self.close_reason.load(Ordering::Acquire))
} else {
None
}
}
}
/// Stream for sending arbitrary protocol extension packets.