expose close reasons

This commit is contained in:
Toshit Chawda 2024-08-02 23:01:47 -07:00
parent 8cbab94955
commit 569789c2a0
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
9 changed files with 294 additions and 74 deletions

View file

@ -1,4 +1,10 @@
use std::{pin::Pin, sync::Arc, task::Poll};
use std::{
io::ErrorKind,
ops::{Deref, DerefMut},
pin::Pin,
sync::Arc,
task::Poll,
};
use futures_rustls::{
rustls::{ClientConfig, RootCertStore},
@ -16,7 +22,7 @@ use wasm_bindgen_futures::spawn_local;
use webpki_roots::TLS_SERVER_ROOTS;
use wisp_mux::{
extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder},
ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType,
ClientMux, MuxStreamAsyncRW, MuxStreamCloser, MuxStreamIo, StreamType,
};
use crate::{console_log, ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError};
@ -32,6 +38,94 @@ lazy_static! {
};
}
pin_project! {
pub struct CloserWrapper<T> {
#[pin]
pub inner: T,
pub closer: MuxStreamCloser,
}
}
impl<T> CloserWrapper<T> {
pub fn new(inner: T, closer: MuxStreamCloser) -> Self {
Self { inner, closer }
}
pub fn into_inner(self) -> T {
self.inner
}
}
impl<T> Deref for CloserWrapper<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<T> DerefMut for CloserWrapper<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl<T: AsyncRead> AsyncRead for CloserWrapper<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_read(cx, buf)
}
fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &mut [std::io::IoSliceMut<'_>],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_read_vectored(cx, bufs)
}
}
impl<T: AsyncWrite> AsyncWrite for CloserWrapper<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_write(cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_write_vectored(cx, bufs)
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.project().inner.poll_close(cx)
}
}
impl From<CloserWrapper<MuxStreamIo>> for CloserWrapper<MuxStreamAsyncRW> {
fn from(value: CloserWrapper<MuxStreamIo>) -> Self {
let CloserWrapper { inner, closer } = value;
CloserWrapper::new(inner.into_asyncrw(), closer)
}
}
pub struct StreamProvider {
wisp_url: String,
@ -42,8 +136,8 @@ pub struct StreamProvider {
current_client: Arc<Mutex<Option<ClientMux>>>,
}
pub type ProviderUnencryptedStream = MuxStreamIo;
pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW;
pub type ProviderUnencryptedStream = CloserWrapper<MuxStreamIo>;
pub type ProviderUnencryptedAsyncRW = CloserWrapper<MuxStreamAsyncRW>;
pub type ProviderTlsAsyncRW = TlsStream<ProviderUnencryptedAsyncRW>;
pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>;
@ -101,10 +195,9 @@ impl StreamProvider {
Box::pin(async {
let locked = self.current_client.lock().await;
if let Some(mux) = locked.as_ref() {
Ok(mux
.client_new_stream(stream_type, host, port)
.await?
.into_io())
let stream = mux.client_new_stream(stream_type, host, port).await?;
let closer = stream.get_close_handle();
Ok(CloserWrapper::new(stream.into_io(), closer))
} else {
self.create_client(locked).await?;
self.get_stream(stream_type, host, port).await
@ -119,10 +212,7 @@ impl StreamProvider {
host: String,
port: u16,
) -> Result<ProviderUnencryptedAsyncRW, EpoxyError> {
Ok(self
.get_stream(stream_type, host, port)
.await?
.into_asyncrw())
Ok(self.get_stream(stream_type, host, port).await?.into())
}
pub async fn get_tls_stream(
@ -134,7 +224,22 @@ impl StreamProvider {
.get_asyncread(StreamType::Tcp, host.clone(), port)
.await?;
let connector = TlsConnector::from(CLIENT_CONFIG.clone());
Ok(connector.connect(host.try_into()?, stream).await?.into())
let ret = connector
.connect(host.try_into()?, stream)
.into_fallible()
.await;
match ret {
Ok(stream) => Ok(stream.into()),
Err((err, stream)) => {
if matches!(err.kind(), ErrorKind::UnexpectedEof) {
// maybe actually a wisp error?
if let Some(reason) = stream.closer.get_close_reason() {
return Err(reason.into());
}
}
Err(err.into())
}
}
}
}