mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 14:00:01 -04:00
349 lines
9.2 KiB
Rust
349 lines
9.2 KiB
Rust
use std::{io::ErrorKind, pin::Pin, sync::Arc, task::Poll};
|
|
|
|
use bytes::BytesMut;
|
|
use cfg_if::cfg_if;
|
|
use futures_rustls::{
|
|
rustls::{ClientConfig, RootCertStore},
|
|
TlsConnector,
|
|
};
|
|
use futures_util::{
|
|
future::Either,
|
|
lock::{Mutex, MutexGuard},
|
|
AsyncRead, AsyncWrite, Future, Stream,
|
|
};
|
|
use hyper_util_wasm::client::legacy::connect::{ConnectSvc, Connected, Connection};
|
|
use pin_project_lite::pin_project;
|
|
use wasm_bindgen_futures::spawn_local;
|
|
use webpki_roots::TLS_SERVER_ROOTS;
|
|
use wisp_mux::{
|
|
extensions::{udp::UdpProtocolExtensionBuilder, AnyProtocolExtensionBuilder},
|
|
generic::GenericWebSocketRead,
|
|
ws::{EitherWebSocketRead, EitherWebSocketWrite},
|
|
ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, WispV2Handshake,
|
|
};
|
|
|
|
use crate::{
|
|
console_error, console_log,
|
|
utils::{IgnoreCloseNotify, NoCertificateVerification, WispTransportWrite},
|
|
ws_wrapper::{WebSocketReader, WebSocketWrapper},
|
|
EpoxyClientOptions, EpoxyError,
|
|
};
|
|
|
|
pub type ProviderUnencryptedStream = MuxStreamIo;
|
|
pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW;
|
|
pub type ProviderTlsAsyncRW = IgnoreCloseNotify;
|
|
pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>;
|
|
pub type ProviderWispTransportRead = EitherWebSocketRead<
|
|
WebSocketReader,
|
|
GenericWebSocketRead<
|
|
Pin<Box<dyn Stream<Item = Result<BytesMut, EpoxyError>> + Send>>,
|
|
EpoxyError,
|
|
>,
|
|
>;
|
|
pub type ProviderWispTransportWrite = EitherWebSocketWrite<WebSocketWrapper, WispTransportWrite>;
|
|
pub type ProviderWispTransportGenerator = Box<
|
|
dyn Fn(
|
|
bool,
|
|
) -> Pin<
|
|
Box<
|
|
dyn Future<
|
|
Output = Result<
|
|
(ProviderWispTransportRead, ProviderWispTransportWrite),
|
|
EpoxyError,
|
|
>,
|
|
> + Sync
|
|
+ Send,
|
|
>,
|
|
> + Sync
|
|
+ Send,
|
|
>;
|
|
|
|
pub struct StreamProvider {
|
|
wisp_generator: ProviderWispTransportGenerator,
|
|
|
|
wisp_v2: bool,
|
|
udp_extension: bool,
|
|
|
|
current_client: Arc<Mutex<Option<ClientMux<ProviderWispTransportWrite>>>>,
|
|
|
|
h2_config: Arc<ClientConfig>,
|
|
client_config: Arc<ClientConfig>,
|
|
}
|
|
|
|
impl StreamProvider {
|
|
pub fn new(
|
|
wisp_generator: ProviderWispTransportGenerator,
|
|
options: &EpoxyClientOptions,
|
|
) -> Result<Self, EpoxyError> {
|
|
let provider = Arc::new(futures_rustls::rustls::crypto::ring::default_provider());
|
|
let client_config = ClientConfig::builder_with_provider(provider.clone())
|
|
.with_safe_default_protocol_versions()?;
|
|
let mut client_config = if options.disable_certificate_validation {
|
|
client_config
|
|
.dangerous()
|
|
.with_custom_certificate_verifier(Arc::new(NoCertificateVerification(provider)))
|
|
} else {
|
|
cfg_if! {
|
|
if #[cfg(feature = "full")] {
|
|
let pems: Result<Result<Vec<_>, webpki::Error>, std::io::Error> = options
|
|
.pem_files
|
|
.iter()
|
|
.flat_map(|x| {
|
|
rustls_pemfile::certs(&mut std::io::BufReader::new(x.as_bytes()))
|
|
.map(|x| x.map(|x| webpki::anchor_from_trusted_cert(&x).map(|x| x.to_owned())))
|
|
.collect::<Vec<_>>()
|
|
})
|
|
.collect();
|
|
let pems = pems.map_err(EpoxyError::Pemfile)??;
|
|
let certstore: RootCertStore = pems.into_iter().chain(TLS_SERVER_ROOTS.iter().cloned()).collect();
|
|
} else {
|
|
let certstore: RootCertStore = TLS_SERVER_ROOTS.iter().cloned().collect();
|
|
}
|
|
}
|
|
client_config.with_root_certificates(certstore)
|
|
}
|
|
.with_no_client_auth();
|
|
let no_alpn_client_config = Arc::new(client_config.clone());
|
|
#[cfg(feature = "full")]
|
|
{
|
|
client_config.alpn_protocols =
|
|
vec!["h2".as_bytes().to_vec(), "http/1.1".as_bytes().to_vec()];
|
|
}
|
|
let client_config = Arc::new(client_config);
|
|
|
|
Ok(Self {
|
|
wisp_generator,
|
|
current_client: Arc::new(Mutex::new(None)),
|
|
wisp_v2: options.wisp_v2,
|
|
udp_extension: options.udp_extension_required,
|
|
h2_config: client_config,
|
|
client_config: no_alpn_client_config,
|
|
})
|
|
}
|
|
|
|
async fn create_client(
|
|
&self,
|
|
mut locked: MutexGuard<'_, Option<ClientMux<ProviderWispTransportWrite>>>,
|
|
) -> Result<(), EpoxyError> {
|
|
let extensions_vec: Vec<AnyProtocolExtensionBuilder> =
|
|
vec![AnyProtocolExtensionBuilder::new(
|
|
UdpProtocolExtensionBuilder,
|
|
)];
|
|
let extensions = if self.wisp_v2 {
|
|
Some(WispV2Handshake::new(extensions_vec))
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let (read, write) = (self.wisp_generator)(self.wisp_v2).await?;
|
|
|
|
let client = ClientMux::create(read, write, extensions).await?;
|
|
let (mux, fut) = if self.udp_extension {
|
|
client.with_udp_extension_required().await?
|
|
} else {
|
|
client.with_no_required_extensions()
|
|
};
|
|
locked.replace(mux);
|
|
let current_client = self.current_client.clone();
|
|
spawn_local(async move {
|
|
match fut.await {
|
|
Ok(()) => console_log!("epoxy: wisp multiplexor task ended successfully"),
|
|
Err(x) => console_error!(
|
|
"epoxy: wisp multiplexor task ended with an error: {} {:?}",
|
|
x,
|
|
x
|
|
),
|
|
}
|
|
current_client.lock().await.take();
|
|
});
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn replace_client(&self) -> Result<(), EpoxyError> {
|
|
self.create_client(self.current_client.lock().await).await
|
|
}
|
|
|
|
pub async fn get_stream(
|
|
&self,
|
|
stream_type: StreamType,
|
|
host: String,
|
|
port: u16,
|
|
) -> Result<ProviderUnencryptedStream, EpoxyError> {
|
|
Box::pin(async {
|
|
let locked = self.current_client.lock().await;
|
|
if let Some(mux) = locked.as_ref() {
|
|
let stream = mux.client_new_stream(stream_type, host, port).await?;
|
|
Ok(stream.into_io())
|
|
} else {
|
|
self.create_client(locked).await?;
|
|
self.get_stream(stream_type, host, port).await
|
|
}
|
|
})
|
|
.await
|
|
}
|
|
|
|
pub async fn get_asyncread(
|
|
&self,
|
|
stream_type: StreamType,
|
|
host: String,
|
|
port: u16,
|
|
) -> Result<ProviderUnencryptedAsyncRW, EpoxyError> {
|
|
Ok(self
|
|
.get_stream(stream_type, host, port)
|
|
.await?
|
|
.into_asyncrw())
|
|
}
|
|
|
|
pub async fn get_tls_stream(
|
|
&self,
|
|
host: String,
|
|
port: u16,
|
|
http: bool,
|
|
) -> Result<ProviderTlsAsyncRW, EpoxyError> {
|
|
let stream = self
|
|
.get_asyncread(StreamType::Tcp, host.clone(), port)
|
|
.await?;
|
|
let connector = TlsConnector::from(if http {
|
|
self.h2_config.clone()
|
|
} else {
|
|
self.client_config.clone()
|
|
});
|
|
let ret = connector
|
|
.connect(host.try_into()?, stream)
|
|
.into_fallible()
|
|
.await;
|
|
match ret {
|
|
Ok(stream) => {
|
|
let h2_negotiated = stream
|
|
.get_ref()
|
|
.1
|
|
.alpn_protocol()
|
|
.is_some_and(|x| x == "h2".as_bytes());
|
|
Ok(IgnoreCloseNotify {
|
|
inner: stream.into(),
|
|
h2_negotiated,
|
|
})
|
|
}
|
|
Err((err, stream)) => {
|
|
if matches!(err.kind(), ErrorKind::UnexpectedEof) {
|
|
// maybe actually a wisp error?
|
|
if let Some(reason) = stream.get_close_reason() {
|
|
return Err(EpoxyError::WispCloseReason(reason, err));
|
|
}
|
|
}
|
|
Err(err.into())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pin_project! {
    /// Wrapper around a provider stream on which the hyper I/O traits are
    /// implemented.
    pub struct HyperIo {
        // Structurally pinned so the wrapped stream can be polled through a projection.
        #[pin]
        inner: ProviderAsyncRW,
    }
}
|
|
|
|
impl hyper::rt::Read for HyperIo {
|
|
fn poll_read(
|
|
self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
mut buf: hyper::rt::ReadBufCursor<'_>,
|
|
) -> Poll<Result<(), std::io::Error>> {
|
|
let buf_slice: &mut [u8] = unsafe {
|
|
&mut *(std::ptr::from_mut::<[std::mem::MaybeUninit<u8>]>(buf.as_mut()) as *mut [u8])
|
|
};
|
|
match self.project().inner.poll_read(cx, buf_slice) {
|
|
Poll::Ready(bytes_read) => {
|
|
let bytes_read = bytes_read?;
|
|
unsafe {
|
|
buf.advance(bytes_read);
|
|
}
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
Poll::Pending => Poll::Pending,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl hyper::rt::Write for HyperIo {
|
|
fn poll_write(
|
|
self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<Result<usize, std::io::Error>> {
|
|
self.project().inner.poll_write(cx, buf)
|
|
}
|
|
|
|
fn poll_flush(
|
|
self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> Poll<Result<(), std::io::Error>> {
|
|
self.project().inner.poll_flush(cx)
|
|
}
|
|
|
|
fn poll_shutdown(
|
|
self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> Poll<Result<(), std::io::Error>> {
|
|
self.project().inner.poll_close(cx)
|
|
}
|
|
|
|
fn poll_write_vectored(
|
|
self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
bufs: &[std::io::IoSlice<'_>],
|
|
) -> Poll<Result<usize, std::io::Error>> {
|
|
self.project().inner.poll_write_vectored(cx, bufs)
|
|
}
|
|
}
|
|
|
|
impl Connection for HyperIo {
|
|
fn connected(&self) -> Connected {
|
|
let conn = Connected::new();
|
|
if let Either::Left(tls_stream) = &self.inner {
|
|
if tls_stream.h2_negotiated {
|
|
conn.negotiated_h2()
|
|
} else {
|
|
conn
|
|
}
|
|
} else {
|
|
conn
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct StreamProviderService(pub Arc<StreamProvider>);
|
|
|
|
impl ConnectSvc for StreamProviderService {
|
|
type Connection = HyperIo;
|
|
type Error = EpoxyError;
|
|
type Future = Pin<Box<impl Future<Output = Result<Self::Connection, Self::Error>>>>;
|
|
|
|
fn connect(self, req: hyper::Uri) -> Self::Future {
|
|
let provider = self.0.clone();
|
|
Box::pin(async move {
|
|
let scheme = req.scheme_str().ok_or(EpoxyError::InvalidUrlScheme(None))?;
|
|
let host = req.host().ok_or(EpoxyError::NoUrlHost)?.to_string();
|
|
let port = req.port_u16().map_or_else(
|
|
|| match scheme {
|
|
"https" | "wss" => Ok(443),
|
|
"http" | "ws" => Ok(80),
|
|
_ => Err(EpoxyError::NoUrlPort),
|
|
},
|
|
Ok,
|
|
)?;
|
|
Ok(HyperIo {
|
|
inner: match scheme {
|
|
"https" => Either::Left(provider.get_tls_stream(host, port, true).await?),
|
|
"wss" => Either::Left(provider.get_tls_stream(host, port, false).await?),
|
|
"http" | "ws" => {
|
|
Either::Right(provider.get_asyncread(StreamType::Tcp, host, port).await?)
|
|
}
|
|
_ => return Err(EpoxyError::InvalidUrlScheme(Some(scheme.to_string()))),
|
|
},
|
|
})
|
|
})
|
|
}
|
|
}
|