add alpn for h2

This commit is contained in:
Toshit Chawda 2024-10-20 10:28:38 -07:00
parent 19ad891b1d
commit 9e7f05e381
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 43 additions and 9 deletions

View file

@ -485,7 +485,7 @@ impl EpoxyClient {
let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?; let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?;
let stream = self let stream = self
.stream_provider .stream_provider
.get_tls_stream(host.to_string(), port) .get_tls_stream(host.to_string(), port, false)
.await?; .await?;
Ok(iostream_from_asyncrw( Ok(iostream_from_asyncrw(
Either::Left(stream), Either::Left(stream),

View file

@ -56,6 +56,7 @@ pub struct StreamProvider {
current_client: Arc<Mutex<Option<ClientMux>>>, current_client: Arc<Mutex<Option<ClientMux>>>,
h2_config: Arc<ClientConfig>,
client_config: Arc<ClientConfig>, client_config: Arc<ClientConfig>,
} }
@ -64,7 +65,7 @@ impl StreamProvider {
wisp_generator: ProviderWispTransportGenerator, wisp_generator: ProviderWispTransportGenerator,
options: &EpoxyClientOptions, options: &EpoxyClientOptions,
) -> Result<Self, EpoxyError> { ) -> Result<Self, EpoxyError> {
let client_config = if options.disable_certificate_validation { let mut client_config = if options.disable_certificate_validation {
ClientConfig::builder() ClientConfig::builder()
.dangerous() .dangerous()
.with_custom_certificate_verifier(Arc::new(NoCertificateVerification::new( .with_custom_certificate_verifier(Arc::new(NoCertificateVerification::new(
@ -91,6 +92,12 @@ impl StreamProvider {
ClientConfig::builder().with_root_certificates(certstore) ClientConfig::builder().with_root_certificates(certstore)
} }
.with_no_client_auth(); .with_no_client_auth();
let no_alpn_client_config = Arc::new(client_config.clone());
#[cfg(feature = "full")]
{
client_config.alpn_protocols =
vec!["h2".as_bytes().to_vec(), "http/1.1".as_bytes().to_vec()];
}
let client_config = Arc::new(client_config); let client_config = Arc::new(client_config);
Ok(Self { Ok(Self {
@ -98,7 +105,8 @@ impl StreamProvider {
current_client: Arc::new(Mutex::new(None)), current_client: Arc::new(Mutex::new(None)),
wisp_v2: options.wisp_v2, wisp_v2: options.wisp_v2,
udp_extension: options.udp_extension_required, udp_extension: options.udp_extension_required,
client_config, h2_config: client_config,
client_config: no_alpn_client_config,
}) })
} }
@ -172,19 +180,33 @@ impl StreamProvider {
&self, &self,
host: String, host: String,
port: u16, port: u16,
http: bool,
) -> Result<ProviderTlsAsyncRW, EpoxyError> { ) -> Result<ProviderTlsAsyncRW, EpoxyError> {
let stream = self let stream = self
.get_asyncread(StreamType::Tcp, host.clone(), port) .get_asyncread(StreamType::Tcp, host.clone(), port)
.await?; .await?;
let connector = TlsConnector::from(self.client_config.clone()); let connector = TlsConnector::from(if http {
self.h2_config.clone()
} else {
self.client_config.clone()
});
let ret = connector let ret = connector
.connect(host.try_into()?, stream) .connect(host.try_into()?, stream)
.into_fallible() .into_fallible()
.await; .await;
match ret { match ret {
Ok(stream) => Ok(IgnoreCloseNotify { Ok(stream) => {
inner: stream.into(), let h2_negotiated = stream
}), .get_ref()
.1
.alpn_protocol()
.map(|x| x == "h2".as_bytes())
.unwrap_or(false);
Ok(IgnoreCloseNotify {
inner: stream.into(),
h2_negotiated,
})
}
Err((err, stream)) => { Err((err, stream)) => {
if matches!(err.kind(), ErrorKind::UnexpectedEof) { if matches!(err.kind(), ErrorKind::UnexpectedEof) {
// maybe actually a wisp error? // maybe actually a wisp error?
@ -259,7 +281,17 @@ impl hyper::rt::Write for HyperIo {
impl Connection for HyperIo { impl Connection for HyperIo {
fn connected(&self) -> Connected { fn connected(&self) -> Connected {
Connected::new() let conn = Connected::new();
let conn = if let Either::Left(tls_stream) = &self.inner {
if tls_stream.h2_negotiated {
conn.negotiated_h2()
} else {
conn
}
} else {
conn
};
conn
} }
} }
@ -283,7 +315,8 @@ impl ConnectSvc for StreamProviderService {
})?; })?;
Ok(HyperIo { Ok(HyperIo {
inner: match scheme { inner: match scheme {
"https" | "wss" => Either::Left(provider.get_tls_stream(host, port).await?), "https" => Either::Left(provider.get_tls_stream(host, port, true).await?),
"wss" => Either::Left(provider.get_tls_stream(host, port, false).await?),
"http" | "ws" => { "http" | "ws" => {
Either::Right(provider.get_asyncread(StreamType::Tcp, host, port).await?) Either::Right(provider.get_asyncread(StreamType::Tcp, host, port).await?)
} }

View file

@ -269,6 +269,7 @@ pin_project! {
pub struct IgnoreCloseNotify { pub struct IgnoreCloseNotify {
#[pin] #[pin]
pub inner: TlsStream<ProviderUnencryptedAsyncRW>, pub inner: TlsStream<ProviderUnencryptedAsyncRW>,
pub h2_negotiated: bool,
} }
} }