diff --git a/client/src/lib.rs b/client/src/lib.rs index da00140..3d79084 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -269,8 +269,8 @@ impl EpoxyClient { return Err(EpoxyError::WebSocketConnectFailed); } Ok(( - Box::new(read) as Box, - Box::new(write) as Box, + Box::new(read) as Box, + Box::new(write) as Box, )) }) }), @@ -311,8 +311,8 @@ impl EpoxyClient { }; Ok(( - Box::new(read) as Box, - Box::new(write) as Box, + Box::new(read) as Box, + Box::new(write) as Box, )) })) }), diff --git a/client/src/stream_provider.rs b/client/src/stream_provider.rs index 54ab34a..bde717c 100644 --- a/client/src/stream_provider.rs +++ b/client/src/stream_provider.rs @@ -2,7 +2,7 @@ use std::{io::ErrorKind, pin::Pin, sync::Arc, task::Poll}; use futures_rustls::{ rustls::{ClientConfig, RootCertStore}, - TlsConnector, TlsStream, + TlsConnector, }; use futures_util::{ future::Either, @@ -20,7 +20,7 @@ use wisp_mux::{ ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, }; -use crate::{console_log, EpoxyClientOptions, EpoxyError}; +use crate::{console_log, utils::IgnoreCloseNotify, EpoxyClientOptions, EpoxyError}; lazy_static! { static ref CLIENT_CONFIG: Arc = { @@ -35,22 +35,24 @@ lazy_static! { pub type ProviderUnencryptedStream = MuxStreamIo; pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW; -pub type ProviderTlsAsyncRW = TlsStream; +pub type ProviderTlsAsyncRW = IgnoreCloseNotify; pub type ProviderAsyncRW = Either; pub type ProviderWispTransportGenerator = Box< dyn Fn() -> Pin< - Box< - dyn Future< - Output = Result< - ( - Box, - Box, - ), - EpoxyError, - >, - > + Sync + Send, - >, - > + Sync + Send, + Box< + dyn Future< + Output = Result< + ( + Box, + Box, + ), + EpoxyError, + >, + > + Sync + + Send, + >, + > + Sync + + Send, >; pub struct StreamProvider { @@ -153,7 +155,9 @@ impl StreamProvider { .into_fallible() .await; match ret { - Ok(stream) => Ok(stream.into()), + Ok(stream) => Ok(IgnoreCloseNotify { + inner: stream.into(), + }), Err((err, stream)) => { if matches!(err.kind(), ErrorKind::UnexpectedEof) { // maybe actually a wisp error? diff --git a/client/src/utils.rs b/client/src/utils.rs index fa507ff..9693944 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -1,11 +1,13 @@ use std::{ + io::ErrorKind, pin::Pin, task::{Context, Poll}, }; use async_trait::async_trait; use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut}; -use futures_util::{ready, AsyncRead, Future, Stream, StreamExt, TryStreamExt}; +use futures_rustls::TlsStream; +use futures_util::{ready, AsyncRead, AsyncWrite, Future, Stream, StreamExt, TryStreamExt}; use http::{HeaderValue, Uri}; use hyper::{body::Body, rt::Executor}; use js_sys::{Array, ArrayBuffer, JsString, Object, Uint8Array}; @@ -20,7 +22,7 @@ use wisp_mux::{ WispError, }; -use crate::EpoxyError; +use crate::{stream_provider::ProviderUnencryptedAsyncRW, EpoxyError}; #[wasm_bindgen] extern "C" { @@ -230,6 +232,83 @@ impl WebSocketWrite for WispTransportWrite { } } +fn map_close_notify(x: std::io::Result) -> std::io::Result { + match x { + Ok(x) => Ok(x), + Err(x) => { + // hacky way to find if it's actually a rustls close notify error + if x.kind() == ErrorKind::UnexpectedEof + && format!("{:?}", x).contains("TLS close_notify") + { + Ok(0) + } else { + Err(x) + } + } + } +} + +pin_project! { + pub struct IgnoreCloseNotify { + #[pin] + pub inner: TlsStream, + } +} + +impl AsyncRead for IgnoreCloseNotify { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.project() + .inner + .poll_read(cx, buf) + .map(map_close_notify) + } + + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [std::io::IoSliceMut<'_>], + ) -> Poll> { + self.project() + .inner + .poll_read_vectored(cx, bufs) + .map(map_close_notify) + } +} + +impl AsyncWrite for IgnoreCloseNotify { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project() + .inner + .poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + self.project() + .inner + .poll_write_vectored(cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_close(cx) + } +} + pub fn is_redirect(code: u16) -> bool { [301, 302, 303, 307, 308].contains(&code) } diff --git a/server/src/handle/wisp.rs b/server/src/handle/wisp.rs index f9f9958..0d74e42 100644 --- a/server/src/handle/wisp.rs +++ b/server/src/handle/wisp.rs @@ -48,8 +48,13 @@ async fn copy_write_fast(muxtx: MuxStreamWrite, tcprx: OwnedReadHalf) -> anyhow: let mut tcprx = BufReader::new(tcprx); loop { let buf = tcprx.fill_buf().await?; - muxtx.write(&buf).await?; + let len = buf.len(); + if len == 0 { + return Ok(()) + } + + muxtx.write(&buf).await?; tcprx.consume(len); } }