mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
expose close reasons
This commit is contained in:
parent
8cbab94955
commit
569789c2a0
9 changed files with 294 additions and 74 deletions
|
@ -1,4 +1,10 @@
|
|||
use std::{pin::Pin, sync::Arc, task::Poll};
|
||||
use std::{
|
||||
io::ErrorKind,
|
||||
ops::{Deref, DerefMut},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::Poll,
|
||||
};
|
||||
|
||||
use futures_rustls::{
|
||||
rustls::{ClientConfig, RootCertStore},
|
||||
|
@ -16,7 +22,7 @@ use wasm_bindgen_futures::spawn_local;
|
|||
use webpki_roots::TLS_SERVER_ROOTS;
|
||||
use wisp_mux::{
|
||||
extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder},
|
||||
ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType,
|
||||
ClientMux, MuxStreamAsyncRW, MuxStreamCloser, MuxStreamIo, StreamType,
|
||||
};
|
||||
|
||||
use crate::{console_log, ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError};
|
||||
|
@ -32,6 +38,94 @@ lazy_static! {
|
|||
};
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
pub struct CloserWrapper<T> {
|
||||
#[pin]
|
||||
pub inner: T,
|
||||
pub closer: MuxStreamCloser,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> CloserWrapper<T> {
|
||||
pub fn new(inner: T, closer: MuxStreamCloser) -> Self {
|
||||
Self { inner, closer }
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> T {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Deref for CloserWrapper<T> {
|
||||
type Target = T;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> DerefMut for CloserWrapper<T> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead> AsyncRead for CloserWrapper<T> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
self.project().inner.poll_read(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_read_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
bufs: &mut [std::io::IoSliceMut<'_>],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
self.project().inner.poll_read_vectored(cx, bufs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncWrite> AsyncWrite for CloserWrapper<T> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
self.project().inner.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
self.project().inner.poll_write_vectored(cx, bufs)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
self.project().inner.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
self.project().inner.poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CloserWrapper<MuxStreamIo>> for CloserWrapper<MuxStreamAsyncRW> {
|
||||
fn from(value: CloserWrapper<MuxStreamIo>) -> Self {
|
||||
let CloserWrapper { inner, closer } = value;
|
||||
CloserWrapper::new(inner.into_asyncrw(), closer)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StreamProvider {
|
||||
wisp_url: String,
|
||||
|
||||
|
@ -42,8 +136,8 @@ pub struct StreamProvider {
|
|||
current_client: Arc<Mutex<Option<ClientMux>>>,
|
||||
}
|
||||
|
||||
pub type ProviderUnencryptedStream = MuxStreamIo;
|
||||
pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW;
|
||||
pub type ProviderUnencryptedStream = CloserWrapper<MuxStreamIo>;
|
||||
pub type ProviderUnencryptedAsyncRW = CloserWrapper<MuxStreamAsyncRW>;
|
||||
pub type ProviderTlsAsyncRW = TlsStream<ProviderUnencryptedAsyncRW>;
|
||||
pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>;
|
||||
|
||||
|
@ -101,10 +195,9 @@ impl StreamProvider {
|
|||
Box::pin(async {
|
||||
let locked = self.current_client.lock().await;
|
||||
if let Some(mux) = locked.as_ref() {
|
||||
Ok(mux
|
||||
.client_new_stream(stream_type, host, port)
|
||||
.await?
|
||||
.into_io())
|
||||
let stream = mux.client_new_stream(stream_type, host, port).await?;
|
||||
let closer = stream.get_close_handle();
|
||||
Ok(CloserWrapper::new(stream.into_io(), closer))
|
||||
} else {
|
||||
self.create_client(locked).await?;
|
||||
self.get_stream(stream_type, host, port).await
|
||||
|
@ -119,10 +212,7 @@ impl StreamProvider {
|
|||
host: String,
|
||||
port: u16,
|
||||
) -> Result<ProviderUnencryptedAsyncRW, EpoxyError> {
|
||||
Ok(self
|
||||
.get_stream(stream_type, host, port)
|
||||
.await?
|
||||
.into_asyncrw())
|
||||
Ok(self.get_stream(stream_type, host, port).await?.into())
|
||||
}
|
||||
|
||||
pub async fn get_tls_stream(
|
||||
|
@ -134,7 +224,22 @@ impl StreamProvider {
|
|||
.get_asyncread(StreamType::Tcp, host.clone(), port)
|
||||
.await?;
|
||||
let connector = TlsConnector::from(CLIENT_CONFIG.clone());
|
||||
Ok(connector.connect(host.try_into()?, stream).await?.into())
|
||||
let ret = connector
|
||||
.connect(host.try_into()?, stream)
|
||||
.into_fallible()
|
||||
.await;
|
||||
match ret {
|
||||
Ok(stream) => Ok(stream.into()),
|
||||
Err((err, stream)) => {
|
||||
if matches!(err.kind(), ErrorKind::UnexpectedEof) {
|
||||
// maybe actually a wisp error?
|
||||
if let Some(reason) = stream.closer.get_close_reason() {
|
||||
return Err(reason.into());
|
||||
}
|
||||
}
|
||||
Err(err.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue