switch to draft 5 handshake

This commit is contained in:
Toshit Chawda 2024-10-24 00:10:15 -07:00
parent c8de5524b4
commit cda7ed2190
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
7 changed files with 194 additions and 136 deletions

View file

@ -1,53 +1,37 @@
mod client;
mod server;
use std::{future::Future, pin::Pin, time::Duration};
use std::{future::Future, pin::Pin};
pub use client::ClientMux;
use futures::{select, FutureExt};
use futures_timer::Delay;
pub use server::ServerMux;
use crate::{
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder},
ws::{Frame, LockedWebSocketWrite, WebSocketRead},
ws::{Frame, LockedWebSocketWrite},
CloseReason, Packet, PacketType, Role, WispError,
};
async fn maybe_wisp_v2<R>(
read: &mut R,
write: &LockedWebSocketWrite,
role: Role,
builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<(Vec<AnyProtocolExtension>, Option<Frame<'static>>, bool), WispError>
where
R: WebSocketRead + Send,
{
let mut supported_extensions = Vec::new();
let mut extra_packet: Option<Frame<'static>> = None;
let mut downgraded = true;
struct WispHandshakeResult {
kind: WispHandshakeResultKind,
downgraded: bool,
}
let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
if let Some(frame) = select! {
x = read.wisp_read_frame(write).fuse() => Some(x?),
_ = Delay::new(Duration::from_secs(5)).fuse() => None
} {
let packet = Packet::maybe_parse_info(frame, role, builders)?;
if let PacketType::Info(info) = packet.packet_type {
supported_extensions = info
.extensions
.into_iter()
.filter(|x| extension_ids.contains(&x.get_id()))
.collect();
downgraded = false;
} else {
extra_packet.replace(Frame::from(packet).clone());
enum WispHandshakeResultKind {
V2 {
extensions: Vec<AnyProtocolExtension>,
},
V1 {
frame: Option<Frame<'static>>,
},
}
impl WispHandshakeResultKind {
pub fn into_parts(self) -> (Vec<AnyProtocolExtension>, Option<Frame<'static>>) {
match self {
Self::V2 { extensions } => (extensions, None),
Self::V1 { frame } => (vec![UdpProtocolExtension.into()], frame),
}
}
for extension in supported_extensions.iter_mut() {
extension.handle_handshake(read, write).await?;
}
Ok((supported_extensions, extra_packet, downgraded))
}
async fn send_info_packet(
@ -67,19 +51,42 @@ async fn send_info_packet(
.await
}
fn validate_continue_packet(packet: Packet<'_>) -> Result<u32, WispError> {
if packet.stream_id != 0 {
return Err(WispError::InvalidStreamId);
}
let PacketType::Continue(continue_packet) = packet.packet_type else {
return Err(WispError::InvalidPacketType);
};
Ok(continue_packet.buffer_remaining)
}
fn get_supported_extensions(
extensions: Vec<AnyProtocolExtension>,
builders: &mut [AnyProtocolExtensionBuilder],
) -> Vec<AnyProtocolExtension> {
let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
extensions
.into_iter()
.filter(|x| extension_ids.contains(&x.get_id()))
.collect()
}
trait Multiplexor {
fn has_extension(&self, extension_id: u8) -> bool;
async fn exit(&self, reason: CloseReason) -> Result<(), WispError>;
}
/// Result of creating a multiplexor. Helps require protocol extensions.
#[allow(private_bounds)]
#[expect(private_bounds)]
pub struct MuxResult<M, F>(M, F)
where
M: Multiplexor,
F: Future<Output = Result<(), WispError>> + Send;
#[allow(private_bounds)]
#[expect(private_bounds)]
impl<M, F> MuxResult<M, F>
where
M: Multiplexor,