fix tls close_notify error, empty data packet spam, compilation errors

This commit is contained in:
Toshit Chawda 2024-08-24 22:29:16 -07:00
parent 6f0e1e7feb
commit 028c0c1332
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
4 changed files with 111 additions and 23 deletions

View file

@ -269,8 +269,8 @@ impl EpoxyClient {
return Err(EpoxyError::WebSocketConnectFailed); return Err(EpoxyError::WebSocketConnectFailed);
} }
Ok(( Ok((
Box::new(read) as Box<dyn WebSocketRead + Send + Sync>, Box::new(read) as Box<dyn WebSocketRead + Send>,
Box::new(write) as Box<dyn WebSocketWrite + Send + Sync>, Box::new(write) as Box<dyn WebSocketWrite + Send>,
)) ))
}) })
}), }),
@ -311,8 +311,8 @@ impl EpoxyClient {
}; };
Ok(( Ok((
Box::new(read) as Box<dyn WebSocketRead + Send + Sync>, Box::new(read) as Box<dyn WebSocketRead + Send>,
Box::new(write) as Box<dyn WebSocketWrite + Send + Sync>, Box::new(write) as Box<dyn WebSocketWrite + Send>,
)) ))
})) }))
}), }),

View file

@ -2,7 +2,7 @@ use std::{io::ErrorKind, pin::Pin, sync::Arc, task::Poll};
use futures_rustls::{ use futures_rustls::{
rustls::{ClientConfig, RootCertStore}, rustls::{ClientConfig, RootCertStore},
TlsConnector, TlsStream, TlsConnector,
}; };
use futures_util::{ use futures_util::{
future::Either, future::Either,
@ -20,7 +20,7 @@ use wisp_mux::{
ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType,
}; };
use crate::{console_log, EpoxyClientOptions, EpoxyError}; use crate::{console_log, utils::IgnoreCloseNotify, EpoxyClientOptions, EpoxyError};
lazy_static! { lazy_static! {
static ref CLIENT_CONFIG: Arc<ClientConfig> = { static ref CLIENT_CONFIG: Arc<ClientConfig> = {
@ -35,22 +35,24 @@ lazy_static! {
pub type ProviderUnencryptedStream = MuxStreamIo; pub type ProviderUnencryptedStream = MuxStreamIo;
pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW; pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW;
pub type ProviderTlsAsyncRW = TlsStream<ProviderUnencryptedAsyncRW>; pub type ProviderTlsAsyncRW = IgnoreCloseNotify;
pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>; pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>;
pub type ProviderWispTransportGenerator = Box< pub type ProviderWispTransportGenerator = Box<
dyn Fn() -> Pin< dyn Fn() -> Pin<
Box< Box<
dyn Future< dyn Future<
Output = Result< Output = Result<
( (
Box<dyn WebSocketRead + Sync + Send>, Box<dyn WebSocketRead + Send>,
Box<dyn WebSocketWrite + Sync + Send>, Box<dyn WebSocketWrite + Send>,
), ),
EpoxyError, EpoxyError,
>, >,
> + Sync + Send, > + Sync
>, + Send,
> + Sync + Send, >,
> + Sync
+ Send,
>; >;
pub struct StreamProvider { pub struct StreamProvider {
@ -153,7 +155,9 @@ impl StreamProvider {
.into_fallible() .into_fallible()
.await; .await;
match ret { match ret {
Ok(stream) => Ok(stream.into()), Ok(stream) => Ok(IgnoreCloseNotify {
inner: stream.into(),
}),
Err((err, stream)) => { Err((err, stream)) => {
if matches!(err.kind(), ErrorKind::UnexpectedEof) { if matches!(err.kind(), ErrorKind::UnexpectedEof) {
// maybe actually a wisp error? // maybe actually a wisp error?

View file

@ -1,11 +1,13 @@
use std::{ use std::{
io::ErrorKind,
pin::Pin, pin::Pin,
task::{Context, Poll}, task::{Context, Poll},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut}; use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut};
use futures_util::{ready, AsyncRead, Future, Stream, StreamExt, TryStreamExt}; use futures_rustls::TlsStream;
use futures_util::{ready, AsyncRead, AsyncWrite, Future, Stream, StreamExt, TryStreamExt};
use http::{HeaderValue, Uri}; use http::{HeaderValue, Uri};
use hyper::{body::Body, rt::Executor}; use hyper::{body::Body, rt::Executor};
use js_sys::{Array, ArrayBuffer, JsString, Object, Uint8Array}; use js_sys::{Array, ArrayBuffer, JsString, Object, Uint8Array};
@ -20,7 +22,7 @@ use wisp_mux::{
WispError, WispError,
}; };
use crate::EpoxyError; use crate::{stream_provider::ProviderUnencryptedAsyncRW, EpoxyError};
#[wasm_bindgen] #[wasm_bindgen]
extern "C" { extern "C" {
@ -230,6 +232,83 @@ impl WebSocketWrite for WispTransportWrite {
} }
} }
fn map_close_notify(x: std::io::Result<usize>) -> std::io::Result<usize> {
match x {
Ok(x) => Ok(x),
Err(x) => {
// hacky way to find if it's actually a rustls close notify error
if x.kind() == ErrorKind::UnexpectedEof
&& format!("{:?}", x).contains("TLS close_notify")
{
Ok(0)
} else {
Err(x)
}
}
}
}
pin_project! {
pub struct IgnoreCloseNotify {
#[pin]
pub inner: TlsStream<ProviderUnencryptedAsyncRW>,
}
}
impl AsyncRead for IgnoreCloseNotify {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
self.project()
.inner
.poll_read(cx, buf)
.map(map_close_notify)
}
fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [std::io::IoSliceMut<'_>],
) -> Poll<std::io::Result<usize>> {
self.project()
.inner
.poll_read_vectored(cx, bufs)
.map(map_close_notify)
}
}
impl AsyncWrite for IgnoreCloseNotify {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.project()
.inner
.poll_write(cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<std::io::Result<usize>> {
self.project()
.inner
.poll_write_vectored(cx, bufs)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.project().inner.poll_close(cx)
}
}
pub fn is_redirect(code: u16) -> bool { pub fn is_redirect(code: u16) -> bool {
[301, 302, 303, 307, 308].contains(&code) [301, 302, 303, 307, 308].contains(&code)
} }

View file

@ -48,8 +48,13 @@ async fn copy_write_fast(muxtx: MuxStreamWrite, tcprx: OwnedReadHalf) -> anyhow:
let mut tcprx = BufReader::new(tcprx); let mut tcprx = BufReader::new(tcprx);
loop { loop {
let buf = tcprx.fill_buf().await?; let buf = tcprx.fill_buf().await?;
muxtx.write(&buf).await?;
let len = buf.len(); let len = buf.len();
if len == 0 {
return Ok(())
}
muxtx.write(&buf).await?;
tcprx.consume(len); tcprx.consume(len);
} }
} }