add alpn for h2

This commit is contained in:
Toshit Chawda 2024-10-20 10:28:38 -07:00
parent 19ad891b1d
commit 9e7f05e381
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 43 additions and 9 deletions

View file

@ -56,6 +56,7 @@ pub struct StreamProvider {
current_client: Arc<Mutex<Option<ClientMux>>>,
h2_config: Arc<ClientConfig>,
client_config: Arc<ClientConfig>,
}
@ -64,7 +65,7 @@ impl StreamProvider {
wisp_generator: ProviderWispTransportGenerator,
options: &EpoxyClientOptions,
) -> Result<Self, EpoxyError> {
let client_config = if options.disable_certificate_validation {
let mut client_config = if options.disable_certificate_validation {
ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoCertificateVerification::new(
@ -91,6 +92,12 @@ impl StreamProvider {
ClientConfig::builder().with_root_certificates(certstore)
}
.with_no_client_auth();
let no_alpn_client_config = Arc::new(client_config.clone());
#[cfg(feature = "full")]
{
client_config.alpn_protocols =
vec!["h2".as_bytes().to_vec(), "http/1.1".as_bytes().to_vec()];
}
let client_config = Arc::new(client_config);
Ok(Self {
@ -98,7 +105,8 @@ impl StreamProvider {
current_client: Arc::new(Mutex::new(None)),
wisp_v2: options.wisp_v2,
udp_extension: options.udp_extension_required,
client_config,
h2_config: client_config,
client_config: no_alpn_client_config,
})
}
@ -172,19 +180,33 @@ impl StreamProvider {
&self,
host: String,
port: u16,
http: bool,
) -> Result<ProviderTlsAsyncRW, EpoxyError> {
let stream = self
.get_asyncread(StreamType::Tcp, host.clone(), port)
.await?;
let connector = TlsConnector::from(self.client_config.clone());
let connector = TlsConnector::from(if http {
self.h2_config.clone()
} else {
self.client_config.clone()
});
let ret = connector
.connect(host.try_into()?, stream)
.into_fallible()
.await;
match ret {
Ok(stream) => Ok(IgnoreCloseNotify {
inner: stream.into(),
}),
Ok(stream) => {
let h2_negotiated = stream
.get_ref()
.1
.alpn_protocol()
.map(|x| x == "h2".as_bytes())
.unwrap_or(false);
Ok(IgnoreCloseNotify {
inner: stream.into(),
h2_negotiated,
})
}
Err((err, stream)) => {
if matches!(err.kind(), ErrorKind::UnexpectedEof) {
// maybe actually a wisp error?
@ -259,7 +281,17 @@ impl hyper::rt::Write for HyperIo {
impl Connection for HyperIo {
fn connected(&self) -> Connected {
Connected::new()
let conn = Connected::new();
let conn = if let Either::Left(tls_stream) = &self.inner {
if tls_stream.h2_negotiated {
conn.negotiated_h2()
} else {
conn
}
} else {
conn
};
conn
}
}
@ -283,7 +315,8 @@ impl ConnectSvc for StreamProviderService {
})?;
Ok(HyperIo {
inner: match scheme {
"https" | "wss" => Either::Left(provider.get_tls_stream(host, port).await?),
"https" => Either::Left(provider.get_tls_stream(host, port, true).await?),
"wss" => Either::Left(provider.get_tls_stream(host, port, false).await?),
"http" | "ws" => {
Either::Right(provider.get_asyncread(StreamType::Tcp, host, port).await?)
}