From 9e7f05e381e1291ac999774cb2fe8cca51119151 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sun, 20 Oct 2024 10:28:38 -0700 Subject: [PATCH] add alpn for h2 --- client/src/lib.rs | 2 +- client/src/stream_provider.rs | 49 +++++++++++++++++++++++++++++------ client/src/utils.rs | 1 + 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/client/src/lib.rs b/client/src/lib.rs index 49301be..d4069ab 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -485,7 +485,7 @@ impl EpoxyClient { let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?; let stream = self .stream_provider - .get_tls_stream(host.to_string(), port) + .get_tls_stream(host.to_string(), port, false) .await?; Ok(iostream_from_asyncrw( Either::Left(stream), diff --git a/client/src/stream_provider.rs b/client/src/stream_provider.rs index 4bc6ec3..288965f 100644 --- a/client/src/stream_provider.rs +++ b/client/src/stream_provider.rs @@ -56,6 +56,7 @@ pub struct StreamProvider { current_client: Arc>>, + h2_config: Arc, client_config: Arc, } @@ -64,7 +65,7 @@ impl StreamProvider { wisp_generator: ProviderWispTransportGenerator, options: &EpoxyClientOptions, ) -> Result { - let client_config = if options.disable_certificate_validation { + let mut client_config = if options.disable_certificate_validation { ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(NoCertificateVerification::new( @@ -91,6 +92,12 @@ impl StreamProvider { ClientConfig::builder().with_root_certificates(certstore) } .with_no_client_auth(); + let no_alpn_client_config = Arc::new(client_config.clone()); + #[cfg(feature = "full")] + { + client_config.alpn_protocols = + vec!["h2".as_bytes().to_vec(), "http/1.1".as_bytes().to_vec()]; + } let client_config = Arc::new(client_config); Ok(Self { @@ -98,7 +105,8 @@ impl StreamProvider { current_client: Arc::new(Mutex::new(None)), wisp_v2: options.wisp_v2, udp_extension: options.udp_extension_required, - client_config, + h2_config: client_config, + client_config: no_alpn_client_config, }) } @@ -172,19 +180,33 @@ impl StreamProvider { &self, host: String, port: u16, + http: bool, ) -> Result { let stream = self .get_asyncread(StreamType::Tcp, host.clone(), port) .await?; - let connector = TlsConnector::from(self.client_config.clone()); + let connector = TlsConnector::from(if http { + self.h2_config.clone() + } else { + self.client_config.clone() + }); let ret = connector .connect(host.try_into()?, stream) .into_fallible() .await; match ret { - Ok(stream) => Ok(IgnoreCloseNotify { - inner: stream.into(), - }), + Ok(stream) => { + let h2_negotiated = stream + .get_ref() + .1 + .alpn_protocol() + .map(|x| x == "h2".as_bytes()) + .unwrap_or(false); + Ok(IgnoreCloseNotify { + inner: stream.into(), + h2_negotiated, + }) + } Err((err, stream)) => { if matches!(err.kind(), ErrorKind::UnexpectedEof) { // maybe actually a wisp error? @@ -259,7 +281,17 @@ impl hyper::rt::Write for HyperIo { impl Connection for HyperIo { fn connected(&self) -> Connected { - Connected::new() + let conn = Connected::new(); + let conn = if let Either::Left(tls_stream) = &self.inner { + if tls_stream.h2_negotiated { + conn.negotiated_h2() + } else { + conn + } + } else { + conn + }; + conn } } @@ -283,7 +315,8 @@ impl ConnectSvc for StreamProviderService { })?; Ok(HyperIo { inner: match scheme { - "https" | "wss" => Either::Left(provider.get_tls_stream(host, port).await?), + "https" => Either::Left(provider.get_tls_stream(host, port, true).await?), + "wss" => Either::Left(provider.get_tls_stream(host, port, false).await?), "http" | "ws" => { Either::Right(provider.get_asyncread(StreamType::Tcp, host, port).await?) } diff --git a/client/src/utils.rs b/client/src/utils.rs index b57c152..a79d2f5 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -269,6 +269,7 @@ pin_project! { pub struct IgnoreCloseNotify { #[pin] pub inner: TlsStream, + pub h2_negotiated: bool, } }