mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
switch to draft 5 handshake
This commit is contained in:
parent
c8de5524b4
commit
cda7ed2190
7 changed files with 194 additions and 136 deletions
|
@ -12,12 +12,71 @@ use futures::channel::oneshot;
|
|||
use crate::{
|
||||
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
|
||||
inner::{MuxInner, WsEvent},
|
||||
mux::send_info_packet,
|
||||
ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||
CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType,
|
||||
WispError,
|
||||
};
|
||||
|
||||
use super::{maybe_wisp_v2, send_info_packet, Multiplexor, MuxResult, WispV2Extensions};
|
||||
use super::{
|
||||
get_supported_extensions, validate_continue_packet, Multiplexor, MuxResult,
|
||||
WispHandshakeResult, WispHandshakeResultKind, WispV2Extensions,
|
||||
};
|
||||
|
||||
async fn handshake<R: WebSocketRead>(
|
||||
rx: &mut R,
|
||||
tx: &LockedWebSocketWrite,
|
||||
v2_info: Option<WispV2Extensions>,
|
||||
) -> Result<(WispHandshakeResult, u32), WispError> {
|
||||
if let Some(WispV2Extensions {
|
||||
mut builders,
|
||||
closure,
|
||||
}) = v2_info
|
||||
{
|
||||
let packet =
|
||||
Packet::maybe_parse_info(rx.wisp_read_frame(tx).await?, Role::Client, &mut builders)?;
|
||||
|
||||
if let PacketType::Info(info) = packet.packet_type {
|
||||
// v2 server
|
||||
let buffer_size = validate_continue_packet(rx.wisp_read_frame(tx).await?.try_into()?)?;
|
||||
|
||||
(closure)(&mut builders).await?;
|
||||
send_info_packet(tx, &mut builders).await?;
|
||||
|
||||
Ok((
|
||||
WispHandshakeResult {
|
||||
kind: WispHandshakeResultKind::V2 {
|
||||
extensions: get_supported_extensions(info.extensions, &mut builders),
|
||||
},
|
||||
downgraded: false,
|
||||
},
|
||||
buffer_size,
|
||||
))
|
||||
} else {
|
||||
// downgrade to v1
|
||||
let buffer_size = validate_continue_packet(packet)?;
|
||||
|
||||
Ok((
|
||||
WispHandshakeResult {
|
||||
kind: WispHandshakeResultKind::V1 { frame: None },
|
||||
downgraded: true,
|
||||
},
|
||||
buffer_size,
|
||||
))
|
||||
}
|
||||
} else {
|
||||
// user asked for a v1 client
|
||||
let buffer_size = validate_continue_packet(rx.wisp_read_frame(tx).await?.try_into()?)?;
|
||||
|
||||
Ok((
|
||||
WispHandshakeResult {
|
||||
kind: WispHandshakeResultKind::V1 { frame: None },
|
||||
downgraded: false,
|
||||
},
|
||||
buffer_size,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Client side multiplexor.
|
||||
pub struct ClientMux {
|
||||
|
@ -48,49 +107,29 @@ impl ClientMux {
|
|||
W: WebSocketWrite + Send + 'static,
|
||||
{
|
||||
let tx = LockedWebSocketWrite::new(Box::new(tx));
|
||||
let first_packet = Packet::try_from(rx.wisp_read_frame(&tx).await?)?;
|
||||
|
||||
if first_packet.stream_id != 0 {
|
||||
return Err(WispError::InvalidStreamId);
|
||||
}
|
||||
let (handshake_result, buffer_size) = handshake(&mut rx, &tx, wisp_v2).await?;
|
||||
let (extensions, frame) = handshake_result.kind.into_parts();
|
||||
|
||||
if let PacketType::Continue(packet) = first_packet.packet_type {
|
||||
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
|
||||
mut builders,
|
||||
closure,
|
||||
}) = wisp_v2
|
||||
{
|
||||
let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?;
|
||||
// if not downgraded
|
||||
if !res.2 {
|
||||
(closure)(&mut builders).await?;
|
||||
send_info_packet(&tx, &mut builders).await?;
|
||||
}
|
||||
res
|
||||
} else {
|
||||
(Vec::new(), None, true)
|
||||
};
|
||||
let mux_inner = MuxInner::new_client(
|
||||
AppendingWebSocketRead(frame, rx),
|
||||
tx.clone(),
|
||||
extensions.clone(),
|
||||
buffer_size,
|
||||
);
|
||||
|
||||
let mux_result = MuxInner::new_client(
|
||||
AppendingWebSocketRead(extra_packet, rx),
|
||||
tx.clone(),
|
||||
supported_extensions.clone(),
|
||||
packet.buffer_remaining,
|
||||
);
|
||||
Ok(MuxResult(
|
||||
Self {
|
||||
actor_tx: mux_inner.actor_tx,
|
||||
actor_exited: mux_inner.actor_exited,
|
||||
|
||||
Ok(MuxResult(
|
||||
Self {
|
||||
actor_tx: mux_result.actor_tx,
|
||||
downgraded,
|
||||
supported_extensions,
|
||||
tx,
|
||||
actor_exited: mux_result.actor_exited,
|
||||
},
|
||||
mux_result.mux.into_future(),
|
||||
))
|
||||
} else {
|
||||
Err(WispError::InvalidPacketType)
|
||||
}
|
||||
tx,
|
||||
|
||||
downgraded: handshake_result.downgraded,
|
||||
supported_extensions: extensions,
|
||||
},
|
||||
mux_inner.mux.into_future(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Create a new stream, multiplexed through Wisp.
|
||||
|
|
|
@ -1,53 +1,37 @@
|
|||
mod client;
|
||||
mod server;
|
||||
use std::{future::Future, pin::Pin, time::Duration};
|
||||
use std::{future::Future, pin::Pin};
|
||||
|
||||
pub use client::ClientMux;
|
||||
use futures::{select, FutureExt};
|
||||
use futures_timer::Delay;
|
||||
pub use server::ServerMux;
|
||||
|
||||
use crate::{
|
||||
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder},
|
||||
ws::{Frame, LockedWebSocketWrite, WebSocketRead},
|
||||
ws::{Frame, LockedWebSocketWrite},
|
||||
CloseReason, Packet, PacketType, Role, WispError,
|
||||
};
|
||||
|
||||
async fn maybe_wisp_v2<R>(
|
||||
read: &mut R,
|
||||
write: &LockedWebSocketWrite,
|
||||
role: Role,
|
||||
builders: &mut [AnyProtocolExtensionBuilder],
|
||||
) -> Result<(Vec<AnyProtocolExtension>, Option<Frame<'static>>, bool), WispError>
|
||||
where
|
||||
R: WebSocketRead + Send,
|
||||
{
|
||||
let mut supported_extensions = Vec::new();
|
||||
let mut extra_packet: Option<Frame<'static>> = None;
|
||||
let mut downgraded = true;
|
||||
struct WispHandshakeResult {
|
||||
kind: WispHandshakeResultKind,
|
||||
downgraded: bool,
|
||||
}
|
||||
|
||||
let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
|
||||
if let Some(frame) = select! {
|
||||
x = read.wisp_read_frame(write).fuse() => Some(x?),
|
||||
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
||||
} {
|
||||
let packet = Packet::maybe_parse_info(frame, role, builders)?;
|
||||
if let PacketType::Info(info) = packet.packet_type {
|
||||
supported_extensions = info
|
||||
.extensions
|
||||
.into_iter()
|
||||
.filter(|x| extension_ids.contains(&x.get_id()))
|
||||
.collect();
|
||||
downgraded = false;
|
||||
} else {
|
||||
extra_packet.replace(Frame::from(packet).clone());
|
||||
enum WispHandshakeResultKind {
|
||||
V2 {
|
||||
extensions: Vec<AnyProtocolExtension>,
|
||||
},
|
||||
V1 {
|
||||
frame: Option<Frame<'static>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl WispHandshakeResultKind {
|
||||
pub fn into_parts(self) -> (Vec<AnyProtocolExtension>, Option<Frame<'static>>) {
|
||||
match self {
|
||||
Self::V2 { extensions } => (extensions, None),
|
||||
Self::V1 { frame } => (vec![UdpProtocolExtension.into()], frame),
|
||||
}
|
||||
}
|
||||
|
||||
for extension in supported_extensions.iter_mut() {
|
||||
extension.handle_handshake(read, write).await?;
|
||||
}
|
||||
Ok((supported_extensions, extra_packet, downgraded))
|
||||
}
|
||||
|
||||
async fn send_info_packet(
|
||||
|
@ -67,19 +51,42 @@ async fn send_info_packet(
|
|||
.await
|
||||
}
|
||||
|
||||
fn validate_continue_packet(packet: Packet<'_>) -> Result<u32, WispError> {
|
||||
if packet.stream_id != 0 {
|
||||
return Err(WispError::InvalidStreamId);
|
||||
}
|
||||
|
||||
let PacketType::Continue(continue_packet) = packet.packet_type else {
|
||||
return Err(WispError::InvalidPacketType);
|
||||
};
|
||||
|
||||
Ok(continue_packet.buffer_remaining)
|
||||
}
|
||||
|
||||
fn get_supported_extensions(
|
||||
extensions: Vec<AnyProtocolExtension>,
|
||||
builders: &mut [AnyProtocolExtensionBuilder],
|
||||
) -> Vec<AnyProtocolExtension> {
|
||||
let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
|
||||
extensions
|
||||
.into_iter()
|
||||
.filter(|x| extension_ids.contains(&x.get_id()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
trait Multiplexor {
|
||||
fn has_extension(&self, extension_id: u8) -> bool;
|
||||
async fn exit(&self, reason: CloseReason) -> Result<(), WispError>;
|
||||
}
|
||||
|
||||
/// Result of creating a multiplexor. Helps require protocol extensions.
|
||||
#[allow(private_bounds)]
|
||||
#[expect(private_bounds)]
|
||||
pub struct MuxResult<M, F>(M, F)
|
||||
where
|
||||
M: Multiplexor,
|
||||
F: Future<Output = Result<(), WispError>> + Send;
|
||||
|
||||
#[allow(private_bounds)]
|
||||
#[expect(private_bounds)]
|
||||
impl<M, F> MuxResult<M, F>
|
||||
where
|
||||
M: Multiplexor,
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
use std::{
|
||||
future::Future,
|
||||
ops::DerefMut,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
|
@ -14,10 +13,63 @@ use crate::{
|
|||
extensions::AnyProtocolExtension,
|
||||
inner::{MuxInner, WsEvent},
|
||||
ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||
CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, Role, WispError,
|
||||
CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role,
|
||||
WispError,
|
||||
};
|
||||
|
||||
use super::{maybe_wisp_v2, send_info_packet, Multiplexor, MuxResult, WispV2Extensions};
|
||||
use super::{
|
||||
get_supported_extensions, send_info_packet, Multiplexor, MuxResult, WispHandshakeResult,
|
||||
WispHandshakeResultKind, WispV2Extensions,
|
||||
};
|
||||
|
||||
async fn handshake<R: WebSocketRead>(
|
||||
rx: &mut R,
|
||||
tx: &LockedWebSocketWrite,
|
||||
buffer_size: u32,
|
||||
v2_info: Option<WispV2Extensions>,
|
||||
) -> Result<WispHandshakeResult, WispError> {
|
||||
if let Some(WispV2Extensions {
|
||||
mut builders,
|
||||
closure,
|
||||
}) = v2_info
|
||||
{
|
||||
send_info_packet(tx, &mut builders).await?;
|
||||
tx.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||
.await?;
|
||||
|
||||
(closure)(&mut builders).await?;
|
||||
|
||||
let packet =
|
||||
Packet::maybe_parse_info(rx.wisp_read_frame(tx).await?, Role::Server, &mut builders)?;
|
||||
|
||||
if let PacketType::Info(info) = packet.packet_type {
|
||||
// v2 client
|
||||
Ok(WispHandshakeResult {
|
||||
kind: WispHandshakeResultKind::V2 {
|
||||
extensions: get_supported_extensions(info.extensions, &mut builders),
|
||||
},
|
||||
downgraded: false,
|
||||
})
|
||||
} else {
|
||||
// downgrade to v1
|
||||
Ok(WispHandshakeResult {
|
||||
kind: WispHandshakeResultKind::V1 {
|
||||
frame: Some(packet.into()),
|
||||
},
|
||||
downgraded: true,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// user asked for v1 server
|
||||
tx.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||
.await?;
|
||||
|
||||
Ok(WispHandshakeResult {
|
||||
kind: WispHandshakeResultKind::V1 { frame: None },
|
||||
downgraded: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Server-side multiplexor.
|
||||
pub struct ServerMux {
|
||||
|
@ -52,36 +104,26 @@ impl ServerMux {
|
|||
let tx = LockedWebSocketWrite::new(Box::new(tx));
|
||||
let ret_tx = tx.clone();
|
||||
let ret = async {
|
||||
tx.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||
.await?;
|
||||
|
||||
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
|
||||
mut builders,
|
||||
closure,
|
||||
}) = wisp_v2
|
||||
{
|
||||
send_info_packet(&tx, builders.deref_mut()).await?;
|
||||
(closure)(builders.deref_mut()).await?;
|
||||
maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await?
|
||||
} else {
|
||||
(Vec::new(), None, true)
|
||||
};
|
||||
let handshake_result = handshake(&mut rx, &tx, buffer_size, wisp_v2).await?;
|
||||
let (extensions, extra_packet) = handshake_result.kind.into_parts();
|
||||
|
||||
let (mux_result, muxstream_recv) = MuxInner::new_server(
|
||||
AppendingWebSocketRead(extra_packet, rx),
|
||||
tx.clone(),
|
||||
supported_extensions.clone(),
|
||||
extensions.clone(),
|
||||
buffer_size,
|
||||
);
|
||||
|
||||
Ok(MuxResult(
|
||||
Self {
|
||||
muxstream_recv,
|
||||
actor_tx: mux_result.actor_tx,
|
||||
downgraded,
|
||||
supported_extensions,
|
||||
tx,
|
||||
actor_exited: mux_result.actor_exited,
|
||||
muxstream_recv,
|
||||
|
||||
tx,
|
||||
|
||||
downgraded: handshake_result.downgraded,
|
||||
supported_extensions: extensions,
|
||||
},
|
||||
mux_result.mux.into_future(),
|
||||
))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue