mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 14:00:01 -04:00
ed25519 auth
This commit is contained in:
parent
a1963f53f1
commit
b2435b554a
11 changed files with 450 additions and 118 deletions
75
Cargo.lock
generated
75
Cargo.lock
generated
|
@ -252,6 +252,12 @@ version = "0.22.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
||||
|
||||
[[package]]
|
||||
name = "base64ct"
|
||||
version = "1.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "2.6.0"
|
||||
|
@ -421,6 +427,12 @@ dependencies = [
|
|||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "const-oid"
|
||||
version = "0.9.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8"
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.2.13"
|
||||
|
@ -484,6 +496,17 @@ version = "2.6.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2"
|
||||
|
||||
[[package]]
|
||||
name = "der"
|
||||
version = "0.7.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0"
|
||||
dependencies = [
|
||||
"const-oid",
|
||||
"pem-rfc7468",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
|
@ -494,6 +517,17 @@ dependencies = [
|
|||
"crypto-common",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ed25519"
|
||||
version = "2.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53"
|
||||
dependencies = [
|
||||
"pkcs8",
|
||||
"signature",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.13.0"
|
||||
|
@ -1387,6 +1421,15 @@ version = "1.0.15"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||
|
||||
[[package]]
|
||||
name = "pem-rfc7468"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412"
|
||||
dependencies = [
|
||||
"base64ct",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.3.1"
|
||||
|
@ -1425,6 +1468,16 @@ version = "0.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
|
||||
|
||||
[[package]]
|
||||
name = "pkcs8"
|
||||
version = "0.10.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7"
|
||||
dependencies = [
|
||||
"der",
|
||||
"spki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.20"
|
||||
|
@ -1801,6 +1854,15 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signature"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de"
|
||||
dependencies = [
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simdutf8"
|
||||
version = "0.1.4"
|
||||
|
@ -1869,6 +1931,16 @@ dependencies = [
|
|||
"lock_api",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spki"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d"
|
||||
dependencies = [
|
||||
"base64ct",
|
||||
"der",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.1"
|
||||
|
@ -2597,12 +2669,15 @@ version = "5.1.0"
|
|||
dependencies = [
|
||||
"async-trait",
|
||||
"atomic_enum",
|
||||
"bitflags",
|
||||
"bytes",
|
||||
"ed25519",
|
||||
"event-listener",
|
||||
"fastwebsockets",
|
||||
"flume",
|
||||
"futures",
|
||||
"futures-timer",
|
||||
"getrandom",
|
||||
"nohash-hasher",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
|
|
|
@ -109,7 +109,7 @@ impl StreamProvider {
|
|||
let extensions_vec: Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>> =
|
||||
vec![Box::new(UdpProtocolExtensionBuilder)];
|
||||
let extensions = if self.wisp_v2 {
|
||||
Some(extensions_vec.as_slice())
|
||||
Some(extensions_vec)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
|
|
@ -230,7 +230,7 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> {
|
|||
}
|
||||
}
|
||||
|
||||
let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions.as_deref())
|
||||
let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions)
|
||||
.await
|
||||
.context("failed to create server multiplexor")?
|
||||
.with_no_required_extensions();
|
||||
|
|
|
@ -163,7 +163,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
|||
.await?
|
||||
.with_no_required_extensions()
|
||||
} else {
|
||||
ClientMux::create(rx, tx, Some(extensions.as_slice()))
|
||||
ClientMux::create(rx, tx, Some(extensions))
|
||||
.await?
|
||||
.with_required_extensions(extension_ids.as_slice())
|
||||
.await?
|
||||
|
|
|
@ -13,21 +13,25 @@ categories = ["network-programming", "asynchronous", "web-programming::websocket
|
|||
[dependencies]
|
||||
async-trait = "0.1.81"
|
||||
atomic_enum = "0.3.0"
|
||||
bitflags = { version = "2.6.0", optional = true, features = ["std"] }
|
||||
bytes = "1.7.1"
|
||||
ed25519 = { version = "2.2.3", optional = true, features = ["pem", "zeroize"] }
|
||||
event-listener = "5.3.1"
|
||||
fastwebsockets = { version = "0.8.0", features = ["unstable-split"], optional = true }
|
||||
flume = "0.11.0"
|
||||
futures = "0.3.30"
|
||||
futures-timer = "3.0.3"
|
||||
getrandom = { version = "0.2.15", features = ["std"], optional = true }
|
||||
nohash-hasher = "0.2.0"
|
||||
pin-project-lite = "0.2.14"
|
||||
tokio = { version = "1.39.3", optional = true, default-features = false }
|
||||
|
||||
[features]
|
||||
default = ["generic_stream"]
|
||||
default = ["generic_stream", "certificate"]
|
||||
fastwebsockets = ["dep:fastwebsockets", "dep:tokio"]
|
||||
generic_stream = []
|
||||
wasm = ["futures-timer/wasm-bindgen"]
|
||||
wasm = ["futures-timer/wasm-bindgen", "getrandom/js"]
|
||||
certificate = ["dep:ed25519", "dep:bitflags", "dep:getrandom"]
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
|
|
319
wisp/src/extensions/cert.rs
Normal file
319
wisp/src/extensions/cert.rs
Normal file
|
@ -0,0 +1,319 @@
|
|||
//! Certificate authentication protocol extension.
|
||||
//!
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use ed25519::{
|
||||
signature::{Signer, Verifier},
|
||||
Signature,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||
Role, WispError,
|
||||
};
|
||||
|
||||
use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
|
||||
|
||||
/// Certificate authentication protocol extension error.
|
||||
#[derive(Debug)]
|
||||
pub enum CertAuthError {
|
||||
/// Invalid or unsupported certificate type
|
||||
InvalidCertType,
|
||||
/// Invalid signature
|
||||
InvalidSignature,
|
||||
/// ED25519 error
|
||||
Ed25519(ed25519::Error),
|
||||
/// Getrandom error
|
||||
Getrandom(getrandom::Error),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CertAuthError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::InvalidCertType => write!(f, "Invalid or unsupported certificate type"),
|
||||
Self::InvalidSignature => write!(f, "Invalid signature"),
|
||||
Self::Ed25519(x) => write!(f, "ED25519: {:?}", x),
|
||||
Self::Getrandom(x) => write!(f, "getrandom: {:?}", x),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl std::error::Error for CertAuthError {}
|
||||
|
||||
impl From<ed25519::Error> for CertAuthError {
|
||||
fn from(value: ed25519::Error) -> Self {
|
||||
CertAuthError::Ed25519(value)
|
||||
}
|
||||
}
|
||||
impl From<getrandom::Error> for CertAuthError {
|
||||
fn from(value: getrandom::Error) -> Self {
|
||||
CertAuthError::Getrandom(value)
|
||||
}
|
||||
}
|
||||
impl From<CertAuthError> for WispError {
|
||||
fn from(value: CertAuthError) -> Self {
|
||||
WispError::ExtensionImplError(Box::new(value))
|
||||
}
|
||||
}
|
||||
|
||||
bitflags::bitflags! {
|
||||
/// Supported certificate types for certificate authentication.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct SupportedCertificateTypes: u8 {
|
||||
/// ED25519 certificate.
|
||||
const Ed25519 = 0b00000001;
|
||||
}
|
||||
}
|
||||
|
||||
/// Verification key.
|
||||
#[derive(Clone)]
|
||||
pub struct VerifyKey {
|
||||
/// Certificate type of the keypair.
|
||||
pub cert_type: SupportedCertificateTypes,
|
||||
/// SHA-512 hash of the public key.
|
||||
pub hash: [u8; 64],
|
||||
/// Verifier.
|
||||
pub verifier: Arc<dyn Verifier<Signature>>,
|
||||
}
|
||||
|
||||
/// Signing key.
|
||||
#[derive(Clone)]
|
||||
pub struct SigningKey {
|
||||
/// Certificate type of the keypair.
|
||||
pub cert_type: SupportedCertificateTypes,
|
||||
/// SHA-512 hash of the public key.
|
||||
pub hash: [u8; 64],
|
||||
/// Signer.
|
||||
pub signer: Arc<dyn Signer<Signature>>,
|
||||
}
|
||||
|
||||
/// Certificate authentication protocol extension.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum CertAuthProtocolExtension {
|
||||
/// Server variant of certificate authentication protocol extension.
|
||||
Server {
|
||||
/// Supported certificate types on the server.
|
||||
cert_types: SupportedCertificateTypes,
|
||||
/// Random challenge for the client.
|
||||
challenge: Bytes,
|
||||
},
|
||||
/// Client variant of certificate authentication protocol extension.
|
||||
Client {
|
||||
/// Chosen certificate type.
|
||||
cert_type: SupportedCertificateTypes,
|
||||
/// Hash of public key.
|
||||
hash: [u8; 64],
|
||||
/// Signature of challenge.
|
||||
signature: Bytes,
|
||||
},
|
||||
/// Marker that client has successfully signed the challenge.
|
||||
ClientSigned,
|
||||
/// Marker that server has successfully verified the client.
|
||||
ServerVerified,
|
||||
}
|
||||
|
||||
impl CertAuthProtocolExtension {
|
||||
/// ID of certificate authentication protocol extension.
|
||||
pub const ID: u8 = 0x03;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProtocolExtension for CertAuthProtocolExtension {
|
||||
fn get_id(&self) -> u8 {
|
||||
Self::ID
|
||||
}
|
||||
|
||||
fn get_supported_packets(&self) -> &'static [u8] {
|
||||
&[]
|
||||
}
|
||||
fn get_congestion_stream_types(&self) -> &'static [u8] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn encode(&self) -> Bytes {
|
||||
match self {
|
||||
Self::Server {
|
||||
cert_types,
|
||||
challenge,
|
||||
} => {
|
||||
let mut out = BytesMut::with_capacity(1 + challenge.len());
|
||||
out.put_u8(cert_types.bits());
|
||||
out.extend_from_slice(challenge);
|
||||
out.freeze()
|
||||
}
|
||||
Self::Client {
|
||||
cert_type,
|
||||
hash,
|
||||
signature,
|
||||
} => {
|
||||
let mut out = BytesMut::with_capacity(1 + signature.len());
|
||||
out.put_u8(cert_type.bits());
|
||||
out.extend_from_slice(hash);
|
||||
out.extend_from_slice(signature);
|
||||
out.freeze()
|
||||
}
|
||||
Self::ClientSigned => Bytes::new(),
|
||||
Self::ServerVerified => Bytes::new(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_handshake(
|
||||
&mut self,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_packet(
|
||||
&mut self,
|
||||
_: Bytes,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CertAuthProtocolExtension> for AnyProtocolExtension {
|
||||
fn from(value: CertAuthProtocolExtension) -> Self {
|
||||
AnyProtocolExtension(Box::new(value))
|
||||
}
|
||||
}
|
||||
|
||||
/// Certificate authentication protocol extension builder.
|
||||
pub enum CertAuthProtocolExtensionBuilder {
|
||||
/// Server variant of certificate authentication protocol extension before the challenge has
|
||||
/// been sent.
|
||||
ServerBeforeChallenge {
|
||||
/// Keypair verifiers.
|
||||
verifiers: Vec<VerifyKey>,
|
||||
},
|
||||
/// Server variant of certificate authentication protocol extension after the challenge has
|
||||
/// been sent.
|
||||
ServerAfterChallenge {
|
||||
/// Keypair verifiers.
|
||||
verifiers: Vec<VerifyKey>,
|
||||
/// Challenge to verify against.
|
||||
challenge: Bytes,
|
||||
},
|
||||
/// Client variant of certificate authentication protocol extension before the challenge has
|
||||
/// been recieved.
|
||||
ClientBeforeChallenge {
|
||||
/// Keypair signer.
|
||||
signer: SigningKey,
|
||||
},
|
||||
/// Client variant of certificate authentication protocol extension after the challenge has
|
||||
/// been recieved.
|
||||
ClientAfterChallenge {
|
||||
/// Keypair signer.
|
||||
signer: SigningKey,
|
||||
/// Signature of challenge recieved from the server.
|
||||
signature: Bytes,
|
||||
},
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
|
||||
fn get_id(&self) -> u8 {
|
||||
CertAuthProtocolExtension::ID
|
||||
}
|
||||
|
||||
// client: 1
|
||||
// server: 2
|
||||
fn build_from_bytes(
|
||||
&mut self,
|
||||
mut bytes: Bytes,
|
||||
_: Role,
|
||||
) -> Result<AnyProtocolExtension, WispError> {
|
||||
match self {
|
||||
// server should have already sent the challenge before recieving a response to parse
|
||||
Self::ServerBeforeChallenge { .. } => Err(WispError::ExtensionImplNotSupported),
|
||||
Self::ServerAfterChallenge {
|
||||
verifiers,
|
||||
challenge,
|
||||
} => {
|
||||
// validate and parse response
|
||||
let cert_type = SupportedCertificateTypes::from_bits(bytes.get_u8())
|
||||
.ok_or(CertAuthError::InvalidCertType)?;
|
||||
let hash = bytes.split_to(64);
|
||||
let sig = Signature::from_slice(&bytes).map_err(CertAuthError::from)?;
|
||||
let is_valid = verifiers
|
||||
.iter()
|
||||
.filter(|x| x.cert_type == cert_type && x.hash == *hash)
|
||||
.any(|x| x.verifier.verify(challenge, &sig).is_ok());
|
||||
|
||||
if is_valid {
|
||||
Ok(CertAuthProtocolExtension::ServerVerified.into())
|
||||
} else {
|
||||
Err(CertAuthError::InvalidSignature.into())
|
||||
}
|
||||
}
|
||||
Self::ClientBeforeChallenge { signer } => {
|
||||
// sign challenge
|
||||
let cert_types = SupportedCertificateTypes::from_bits(bytes.get_u8())
|
||||
.ok_or(CertAuthError::InvalidCertType)?;
|
||||
if !cert_types.iter().any(|x| x == signer.cert_type) {
|
||||
return Err(CertAuthError::InvalidCertType.into());
|
||||
}
|
||||
|
||||
let signed: Bytes = signer
|
||||
.signer
|
||||
.try_sign(&bytes)
|
||||
.map_err(CertAuthError::from)?
|
||||
.to_vec()
|
||||
.into();
|
||||
|
||||
*self = Self::ClientAfterChallenge {
|
||||
signer: signer.clone(),
|
||||
signature: signed,
|
||||
};
|
||||
|
||||
Ok(CertAuthProtocolExtension::ClientSigned.into())
|
||||
}
|
||||
// client has already recieved a challenge
|
||||
Self::ClientAfterChallenge { .. } => Err(WispError::ExtensionImplNotSupported),
|
||||
}
|
||||
}
|
||||
|
||||
// client: 2
|
||||
// server: 1
|
||||
fn build_to_extension(&mut self, _: Role) -> Result<AnyProtocolExtension, WispError> {
|
||||
match self {
|
||||
Self::ServerBeforeChallenge { verifiers } => {
|
||||
let mut challenge = BytesMut::with_capacity(64);
|
||||
getrandom::getrandom(&mut challenge).map_err(CertAuthError::from)?;
|
||||
let challenge = challenge.freeze();
|
||||
|
||||
*self = Self::ServerAfterChallenge {
|
||||
verifiers: verifiers.to_vec(),
|
||||
challenge: challenge.clone(),
|
||||
};
|
||||
|
||||
Ok(CertAuthProtocolExtension::Server {
|
||||
cert_types: SupportedCertificateTypes::Ed25519,
|
||||
challenge,
|
||||
}
|
||||
.into())
|
||||
}
|
||||
// server has already sent a challenge
|
||||
Self::ServerAfterChallenge { .. } => Err(WispError::ExtensionImplNotSupported),
|
||||
// client needs to recieve a challenge
|
||||
Self::ClientBeforeChallenge { .. } => Err(WispError::ExtensionImplNotSupported),
|
||||
Self::ClientAfterChallenge { signer, signature } => {
|
||||
Ok(CertAuthProtocolExtension::Client {
|
||||
cert_type: signer.cert_type,
|
||||
hash: signer.hash,
|
||||
signature: signature.clone(),
|
||||
}
|
||||
.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,6 +1,8 @@
|
|||
//! Wisp protocol extensions.
|
||||
pub mod password;
|
||||
pub mod udp;
|
||||
#[cfg(feature = "certificate")]
|
||||
pub mod cert;
|
||||
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
|
@ -102,9 +104,9 @@ pub trait ProtocolExtensionBuilder {
|
|||
fn get_id(&self) -> u8;
|
||||
|
||||
/// Build a protocol extension from the extension's metadata.
|
||||
fn build_from_bytes(&self, bytes: Bytes, role: Role)
|
||||
fn build_from_bytes(&mut self, bytes: Bytes, role: Role)
|
||||
-> Result<AnyProtocolExtension, WispError>;
|
||||
|
||||
/// Build a protocol extension to send to the other side.
|
||||
fn build_to_extension(&self, role: Role) -> AnyProtocolExtension;
|
||||
fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError>;
|
||||
}
|
||||
|
|
|
@ -2,33 +2,6 @@
|
|||
//!
|
||||
//! Passwords are sent in plain text!!
|
||||
//!
|
||||
//! # Example
|
||||
//! Server:
|
||||
//! ```
|
||||
//! let mut passwords = HashMap::new();
|
||||
//! passwords.insert("user1".to_string(), "pw".to_string());
|
||||
//! let (mux, fut) = ServerMux::new(
|
||||
//! rx,
|
||||
//! tx,
|
||||
//! 128,
|
||||
//! Some(&[Box::new(PasswordProtocolExtensionBuilder::new_server(passwords))])
|
||||
//! );
|
||||
//! ```
|
||||
//!
|
||||
//! Client:
|
||||
//! ```
|
||||
//! let (mux, fut) = ClientMux::new(
|
||||
//! rx,
|
||||
//! tx,
|
||||
//! 128,
|
||||
//! Some(&[
|
||||
//! Box::new(PasswordProtocolExtensionBuilder::new_client(
|
||||
//! "user1".to_string(),
|
||||
//! "pw".to_string()
|
||||
//! ))
|
||||
//! ])
|
||||
//! );
|
||||
//! ```
|
||||
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication)
|
||||
|
||||
use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error};
|
||||
|
@ -223,7 +196,7 @@ impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder {
|
|||
}
|
||||
|
||||
fn build_from_bytes(
|
||||
&self,
|
||||
&mut self,
|
||||
mut payload: Bytes,
|
||||
role: crate::Role,
|
||||
) -> Result<AnyProtocolExtension, WispError> {
|
||||
|
@ -268,13 +241,13 @@ impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder {
|
|||
}
|
||||
}
|
||||
|
||||
fn build_to_extension(&self, role: Role) -> AnyProtocolExtension {
|
||||
match role {
|
||||
fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError> {
|
||||
Ok(match role {
|
||||
Role::Server => PasswordProtocolExtension::new_server(),
|
||||
Role::Client => {
|
||||
PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone())
|
||||
}
|
||||
}
|
||||
.into()
|
||||
.into())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,14 +1,5 @@
|
|||
//! UDP protocol extension.
|
||||
//!
|
||||
//! # Example
|
||||
//! ```
|
||||
//! let (mux, fut) = ServerMux::new(
|
||||
//! rx,
|
||||
//! tx,
|
||||
//! 128,
|
||||
//! Some(&[Box::new(UdpProtocolExtensionBuilder)])
|
||||
//! );
|
||||
//! ```
|
||||
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x01---udp)
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
|
@ -84,14 +75,14 @@ impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder {
|
|||
}
|
||||
|
||||
fn build_from_bytes(
|
||||
&self,
|
||||
&mut self,
|
||||
_: Bytes,
|
||||
_: crate::Role,
|
||||
) -> Result<AnyProtocolExtension, WispError> {
|
||||
Ok(UdpProtocolExtension.into())
|
||||
}
|
||||
|
||||
fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension {
|
||||
UdpProtocolExtension.into()
|
||||
fn build_to_extension(&mut self, _: crate::Role) -> Result<AnyProtocolExtension, WispError> {
|
||||
Ok(UdpProtocolExtension.into())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
#![deny(missing_docs)]
|
||||
#![deny(missing_docs, clippy::todo)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
//! A library for easily creating [Wisp] clients and servers.
|
||||
//!
|
||||
|
@ -162,7 +162,7 @@ impl std::error::Error for WispError {}
|
|||
async fn maybe_wisp_v2<R>(
|
||||
read: &mut R,
|
||||
write: &LockedWebSocketWrite,
|
||||
builders: &[Box<dyn ProtocolExtensionBuilder + Sync + Send>],
|
||||
builders: &mut [Box<dyn ProtocolExtensionBuilder + Sync + Send>],
|
||||
) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame<'static>>, bool), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead + Send,
|
||||
|
@ -195,25 +195,24 @@ where
|
|||
Ok((supported_extensions, extra_packet, downgraded))
|
||||
}
|
||||
|
||||
async fn send_info_packet(
|
||||
write: &LockedWebSocketWrite,
|
||||
builders: &mut [Box<dyn ProtocolExtensionBuilder + Sync + Send>],
|
||||
) -> Result<(), WispError> {
|
||||
write
|
||||
.write_frame(
|
||||
Packet::new_info(
|
||||
builders
|
||||
.iter_mut()
|
||||
.map(|x| x.build_to_extension(Role::Server))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
)
|
||||
.into(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Server-side multiplexor.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use wisp_mux::ServerMux;
|
||||
///
|
||||
/// let (mux, fut) = ServerMux::new(rx, tx, 128, Some([]));
|
||||
/// tokio::spawn(async move {
|
||||
/// if let Err(e) = fut.await {
|
||||
/// println!("error in multiplexor: {:?}", e);
|
||||
/// }
|
||||
/// });
|
||||
/// while let Some((packet, stream)) = mux.server_new_stream().await {
|
||||
/// tokio::spawn(async move {
|
||||
/// let url = format!("{}:{}", packet.destination_hostname, packet.destination_port);
|
||||
/// // do something with `url` and `packet.stream_type`
|
||||
/// });
|
||||
/// }
|
||||
/// ```
|
||||
pub struct ServerMux {
|
||||
/// Whether the connection was downgraded to Wisp v1.
|
||||
///
|
||||
|
@ -237,7 +236,7 @@ impl ServerMux {
|
|||
mut rx: R,
|
||||
tx: W,
|
||||
buffer_size: u32,
|
||||
extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
|
||||
extension_builders: Option<Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>>>,
|
||||
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||
where
|
||||
R: ws::WebSocketRead + Send,
|
||||
|
@ -249,18 +248,9 @@ impl ServerMux {
|
|||
.await?;
|
||||
|
||||
let (supported_extensions, extra_packet, downgraded) =
|
||||
if let Some(builders) = extension_builders {
|
||||
tx.write_frame(
|
||||
Packet::new_info(
|
||||
builders
|
||||
.iter()
|
||||
.map(|x| x.build_to_extension(Role::Client))
|
||||
.collect(),
|
||||
)
|
||||
.into(),
|
||||
)
|
||||
.await?;
|
||||
maybe_wisp_v2(&mut rx, &tx, builders).await?
|
||||
if let Some(mut builders) = extension_builders {
|
||||
send_info_packet(&tx, &mut builders).await?;
|
||||
maybe_wisp_v2(&mut rx, &tx, &mut builders).await?
|
||||
} else {
|
||||
(Vec::new(), None, true)
|
||||
};
|
||||
|
@ -367,7 +357,9 @@ where
|
|||
} else {
|
||||
self.0.close_extension_incompat().await?;
|
||||
self.1.await?;
|
||||
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
|
||||
Err(WispError::ExtensionsNotSupported(
|
||||
unsupported_extensions,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -379,19 +371,6 @@ where
|
|||
}
|
||||
|
||||
/// Client side multiplexor.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use wisp_mux::{ClientMux, StreamType};
|
||||
///
|
||||
/// let (mux, fut) = ClientMux::new(rx, tx, Some([])).await?;
|
||||
/// tokio::spawn(async move {
|
||||
/// if let Err(e) = fut.await {
|
||||
/// println!("error in multiplexor: {:?}", e);
|
||||
/// }
|
||||
/// });
|
||||
/// let stream = mux.client_new_stream(StreamType::Tcp, "google.com", 80);
|
||||
/// ```
|
||||
pub struct ClientMux {
|
||||
/// Whether the connection was downgraded to Wisp v1.
|
||||
///
|
||||
|
@ -413,7 +392,7 @@ impl ClientMux {
|
|||
pub async fn create<R, W>(
|
||||
mut rx: R,
|
||||
tx: W,
|
||||
extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
|
||||
extension_builders: Option<Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>>>,
|
||||
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||
where
|
||||
R: ws::WebSocketRead + Send,
|
||||
|
@ -428,22 +407,13 @@ impl ClientMux {
|
|||
|
||||
if let PacketType::Continue(packet) = first_packet.packet_type {
|
||||
let (supported_extensions, extra_packet, downgraded) =
|
||||
if let Some(builders) = extension_builders {
|
||||
let x = maybe_wisp_v2(&mut rx, &tx, builders).await?;
|
||||
if let Some(mut builders) = extension_builders {
|
||||
let res = maybe_wisp_v2(&mut rx, &tx, &mut builders).await?;
|
||||
// if not downgraded
|
||||
if !x.2 {
|
||||
tx.write_frame(
|
||||
Packet::new_info(
|
||||
builders
|
||||
.iter()
|
||||
.map(|x| x.build_to_extension(Role::Client))
|
||||
.collect(),
|
||||
)
|
||||
.into(),
|
||||
)
|
||||
.await?;
|
||||
if !res.2 {
|
||||
send_info_packet(&tx, &mut builders).await?;
|
||||
}
|
||||
x
|
||||
res
|
||||
} else {
|
||||
(Vec::new(), None, true)
|
||||
};
|
||||
|
|
|
@ -428,7 +428,7 @@ impl<'a> Packet<'a> {
|
|||
pub(crate) fn maybe_parse_info(
|
||||
frame: Frame<'a>,
|
||||
role: Role,
|
||||
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
||||
extension_builders: &mut [Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
||||
) -> Result<Self, WispError> {
|
||||
if !frame.finished {
|
||||
return Err(WispError::WsFrameNotFinished);
|
||||
|
@ -502,7 +502,7 @@ impl<'a> Packet<'a> {
|
|||
fn parse_info(
|
||||
mut bytes: Payload<'a>,
|
||||
role: Role,
|
||||
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
||||
extension_builders: &mut [Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
||||
) -> Result<Self, WispError> {
|
||||
// packet type is already read by code that calls this
|
||||
if bytes.remaining() < 4 + 2 {
|
||||
|
@ -530,10 +530,8 @@ impl<'a> Packet<'a> {
|
|||
if bytes.remaining() < length {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) {
|
||||
if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role) {
|
||||
extensions.push(extension)
|
||||
}
|
||||
if let Some(builder) = extension_builders.iter_mut().find(|x| x.get_id() == id) {
|
||||
extensions.push(builder.build_from_bytes(bytes.copy_to_bytes(length), role)?)
|
||||
} else {
|
||||
bytes.advance(length)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue