add password protocol extension, simplify protocol extension api

This commit is contained in:
Toshit Chawda 2024-04-13 16:29:20 -07:00
parent b0d1038a3c
commit 481128e4f5
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 343 additions and 58 deletions

View file

@ -458,13 +458,11 @@ pub struct ServerMux {
impl ServerMux {
/// Create a new server-side multiplexor.
///
/// If either extensions or extension_builders are None a Wisp v1 connection is created
/// otherwise a Wisp v2 connection is created.
/// If extension_builders is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
pub async fn new<R, W>(
mut read: R,
write: W,
buffer_size: u32,
extensions: Option<Vec<AnyProtocolExtension>>,
extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>,
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError>
where
@ -483,28 +481,29 @@ impl ServerMux {
let mut extra_packet = Vec::with_capacity(1);
let mut downgraded = true;
if let Some(extensions) = extensions {
if let Some(builders) = extension_builders {
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
write
.write_frame(Packet::new_info(extensions).into())
.await?;
if let Some(frame) = select! {
x = read.wisp_read_frame(&write).fuse() => Some(x?),
// TODO change this to correct timeout once draft 2 is out
_ = Delay::new(Duration::from_secs(5)).fuse() => None
} {
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
if let PacketType::Info(info) = packet.packet_type {
supported_extensions = info
.extensions
.into_iter()
.filter(|x| extension_ids.contains(&x.get_id()))
.collect();
downgraded = false;
} else {
extra_packet.push(packet.into());
}
if let Some(builders) = extension_builders {
let extensions: Vec<_> = builders
.iter()
.map(|x| x.build_to_extension(Role::Server))
.collect();
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
write
.write_frame(Packet::new_info(extensions).into())
.await?;
if let Some(frame) = select! {
x = read.wisp_read_frame(&write).fuse() => Some(x?),
_ = Delay::new(Duration::from_secs(5)).fuse() => None
} {
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
if let PacketType::Info(info) = packet.packet_type {
supported_extensions = info
.extensions
.into_iter()
.filter(|x| extension_ids.contains(&x.get_id()))
.collect();
downgraded = false;
} else {
extra_packet.push(packet.into());
}
}
}
@ -574,12 +573,10 @@ pub struct ClientMux {
impl ClientMux {
/// Create a new client side multiplexor.
///
/// If either extensions or extension_builders are None a Wisp v1 connection is created
/// otherwise a Wisp v2 connection is created.
/// If extension_builders is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
pub async fn new<R, W>(
mut read: R,
write: W,
extensions: Option<Vec<AnyProtocolExtension>>,
extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>,
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError>
where
@ -596,28 +593,29 @@ impl ClientMux {
let mut extra_packet = Vec::with_capacity(1);
let mut downgraded = true;
if let Some(extensions) = extensions {
if let Some(builders) = extension_builders {
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
if let Some(frame) = select! {
x = read.wisp_read_frame(&write).fuse() => Some(x?),
// TODO change this to correct timeout once draft 2 is out
_ = Delay::new(Duration::from_secs(5)).fuse() => None
} {
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
if let PacketType::Info(info) = packet.packet_type {
supported_extensions = info
.extensions
.into_iter()
.filter(|x| extension_ids.contains(&x.get_id()))
.collect();
write
.write_frame(Packet::new_info(extensions).into())
.await?;
downgraded = false;
} else {
extra_packet.push(packet.into());
}
if let Some(builders) = extension_builders {
let extensions: Vec<_> = builders
.iter()
.map(|x| x.build_to_extension(Role::Client))
.collect();
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
if let Some(frame) = select! {
x = read.wisp_read_frame(&write).fuse() => Some(x?),
_ = Delay::new(Duration::from_secs(5)).fuse() => None
} {
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
if let PacketType::Info(info) = packet.packet_type {
supported_extensions = info
.extensions
.into_iter()
.filter(|x| extension_ids.contains(&x.get_id()))
.collect();
write
.write_frame(Packet::new_info(extensions).into())
.await?;
downgraded = false;
} else {
extra_packet.push(packet.into());
}
}
}