custom wisp transport support

This commit is contained in:
Toshit Chawda 2024-08-16 23:29:33 -07:00
parent 80b68f1cee
commit 16268905fc
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
10 changed files with 313 additions and 135 deletions

View file

@ -1,10 +1,4 @@
use std::{
io::ErrorKind,
ops::{Deref, DerefMut},
pin::Pin,
sync::Arc,
task::Poll,
};
use std::{io::ErrorKind, pin::Pin, sync::Arc, task::Poll};
use futures_rustls::{
rustls::{ClientConfig, RootCertStore},
@ -22,10 +16,11 @@ use wasm_bindgen_futures::spawn_local;
use webpki_roots::TLS_SERVER_ROOTS;
use wisp_mux::{
extensions::{udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder},
ClientMux, MuxStreamAsyncRW, MuxStreamCloser, MuxStreamIo, StreamType,
ws::{WebSocketRead, WebSocketWrite},
ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType,
};
use crate::{console_log, ws_wrapper::WebSocketWrapper, EpoxyClientOptions, EpoxyError};
use crate::{console_log, EpoxyClientOptions, EpoxyError};
lazy_static! {
static ref CLIENT_CONFIG: Arc<ClientConfig> = {
@ -38,117 +33,45 @@ lazy_static! {
};
}
pin_project! {
pub struct CloserWrapper<T> {
#[pin]
pub inner: T,
pub closer: MuxStreamCloser,
}
}
impl<T> CloserWrapper<T> {
pub fn new(inner: T, closer: MuxStreamCloser) -> Self {
Self { inner, closer }
}
pub fn into_inner(self) -> T {
self.inner
}
}
impl<T> Deref for CloserWrapper<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<T> DerefMut for CloserWrapper<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl<T: AsyncRead> AsyncRead for CloserWrapper<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_read(cx, buf)
}
fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &mut [std::io::IoSliceMut<'_>],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_read_vectored(cx, bufs)
}
}
impl<T: AsyncWrite> AsyncWrite for CloserWrapper<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_write(cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_write_vectored(cx, bufs)
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.project().inner.poll_close(cx)
}
}
impl From<CloserWrapper<MuxStreamIo>> for CloserWrapper<MuxStreamAsyncRW> {
fn from(value: CloserWrapper<MuxStreamIo>) -> Self {
let CloserWrapper { inner, closer } = value;
CloserWrapper::new(inner.into_asyncrw(), closer)
}
}
pub type ProviderUnencryptedStream = MuxStreamIo;
pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW;
pub type ProviderTlsAsyncRW = TlsStream<ProviderUnencryptedAsyncRW>;
pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>;
pub type ProviderWispTransportGenerator = Box<
dyn Fn() -> Pin<
Box<
dyn Future<
Output = Result<
(
Box<dyn WebSocketRead + Sync + Send>,
Box<dyn WebSocketWrite + Sync + Send>,
),
EpoxyError,
>,
> + Sync + Send,
>,
> + Sync + Send,
>;
pub struct StreamProvider {
wisp_url: String,
wisp_generator: ProviderWispTransportGenerator,
wisp_v2: bool,
udp_extension: bool,
websocket_protocols: Vec<String>,
current_client: Arc<Mutex<Option<ClientMux>>>,
}
pub type ProviderUnencryptedStream = CloserWrapper<MuxStreamIo>;
pub type ProviderUnencryptedAsyncRW = CloserWrapper<MuxStreamAsyncRW>;
pub type ProviderTlsAsyncRW = TlsStream<ProviderUnencryptedAsyncRW>;
pub type ProviderAsyncRW = Either<ProviderTlsAsyncRW, ProviderUnencryptedAsyncRW>;
impl StreamProvider {
pub fn new(wisp_url: String, options: &EpoxyClientOptions) -> Result<Self, EpoxyError> {
pub fn new(
wisp_generator: ProviderWispTransportGenerator,
options: &EpoxyClientOptions,
) -> Result<Self, EpoxyError> {
Ok(Self {
wisp_url,
wisp_generator,
current_client: Arc::new(Mutex::new(None)),
wisp_v2: options.wisp_v2,
udp_extension: options.udp_extension_required,
websocket_protocols: options.websocket_protocols.clone(),
})
}
@ -163,10 +86,9 @@ impl StreamProvider {
} else {
None
};
let (write, read) = WebSocketWrapper::connect(&self.wisp_url, &self.websocket_protocols)?;
if !write.wait_for_open().await {
return Err(EpoxyError::WebSocketConnectFailed);
}
let (read, write) = (self.wisp_generator)().await?;
let client = ClientMux::create(read, write, extensions).await?;
let (mux, fut) = if self.udp_extension {
client.with_udp_extension_required().await?
@ -196,8 +118,7 @@ impl StreamProvider {
let locked = self.current_client.lock().await;
if let Some(mux) = locked.as_ref() {
let stream = mux.client_new_stream(stream_type, host, port).await?;
let closer = stream.get_close_handle();
Ok(CloserWrapper::new(stream.into_io(), closer))
Ok(stream.into_io())
} else {
self.create_client(locked).await?;
self.get_stream(stream_type, host, port).await
@ -212,7 +133,10 @@ impl StreamProvider {
host: String,
port: u16,
) -> Result<ProviderUnencryptedAsyncRW, EpoxyError> {
Ok(self.get_stream(stream_type, host, port).await?.into())
Ok(self
.get_stream(stream_type, host, port)
.await?
.into_asyncrw())
}
pub async fn get_tls_stream(
@ -233,7 +157,7 @@ impl StreamProvider {
Err((err, stream)) => {
if matches!(err.kind(), ErrorKind::UnexpectedEof) {
// maybe actually a wisp error?
if let Some(reason) = stream.closer.get_close_reason() {
if let Some(reason) = stream.get_close_reason() {
return Err(reason.into());
}
}