mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
add a new Payload struct to allow for one-copy writes and cargo fmt
This commit is contained in:
parent
314c1bfa75
commit
d6353bd5a9
18 changed files with 3533 additions and 3395 deletions
|
@ -8,8 +8,8 @@ use async_trait::async_trait;
|
|||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
|
||||
use crate::{
|
||||
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||
Role, WispError,
|
||||
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||
Role, WispError,
|
||||
};
|
||||
|
||||
/// Type-erased protocol extension that implements Clone.
|
||||
|
@ -17,90 +17,90 @@ use crate::{
|
|||
pub struct AnyProtocolExtension(Box<dyn ProtocolExtension + Sync + Send>);
|
||||
|
||||
impl AnyProtocolExtension {
|
||||
/// Create a new type-erased protocol extension.
|
||||
pub fn new<T: ProtocolExtension + Sync + Send + 'static>(extension: T) -> Self {
|
||||
Self(Box::new(extension))
|
||||
}
|
||||
/// Create a new type-erased protocol extension.
|
||||
pub fn new<T: ProtocolExtension + Sync + Send + 'static>(extension: T) -> Self {
|
||||
Self(Box::new(extension))
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for AnyProtocolExtension {
|
||||
type Target = dyn ProtocolExtension;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0.deref()
|
||||
}
|
||||
type Target = dyn ProtocolExtension;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0.deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for AnyProtocolExtension {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.0.deref_mut()
|
||||
}
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.0.deref_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for AnyProtocolExtension {
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.box_clone())
|
||||
}
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.box_clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AnyProtocolExtension> for Bytes {
|
||||
fn from(value: AnyProtocolExtension) -> Self {
|
||||
let mut bytes = BytesMut::with_capacity(5);
|
||||
let payload = value.encode();
|
||||
bytes.put_u8(value.get_id());
|
||||
bytes.put_u32_le(payload.len() as u32);
|
||||
bytes.extend(payload);
|
||||
bytes.freeze()
|
||||
}
|
||||
fn from(value: AnyProtocolExtension) -> Self {
|
||||
let mut bytes = BytesMut::with_capacity(5);
|
||||
let payload = value.encode();
|
||||
bytes.put_u8(value.get_id());
|
||||
bytes.put_u32_le(payload.len() as u32);
|
||||
bytes.extend(payload);
|
||||
bytes.freeze()
|
||||
}
|
||||
}
|
||||
|
||||
/// A Wisp protocol extension.
|
||||
///
|
||||
/// See [the
|
||||
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#protocol-extensions).
|
||||
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#protocol-extensions).
|
||||
#[async_trait]
|
||||
pub trait ProtocolExtension: std::fmt::Debug {
|
||||
/// Get the protocol extension ID.
|
||||
fn get_id(&self) -> u8;
|
||||
/// Get the protocol extension's supported packets.
|
||||
///
|
||||
/// Used to decide whether to call the protocol extension's packet handler.
|
||||
fn get_supported_packets(&self) -> &'static [u8];
|
||||
/// Get the protocol extension ID.
|
||||
fn get_id(&self) -> u8;
|
||||
/// Get the protocol extension's supported packets.
|
||||
///
|
||||
/// Used to decide whether to call the protocol extension's packet handler.
|
||||
fn get_supported_packets(&self) -> &'static [u8];
|
||||
|
||||
/// Encode self into Bytes.
|
||||
fn encode(&self) -> Bytes;
|
||||
/// Encode self into Bytes.
|
||||
fn encode(&self) -> Bytes;
|
||||
|
||||
/// Handle the handshake part of a Wisp connection.
|
||||
///
|
||||
/// This should be used to send or receive data before any streams are created.
|
||||
async fn handle_handshake(
|
||||
&mut self,
|
||||
read: &mut dyn WebSocketRead,
|
||||
write: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError>;
|
||||
/// Handle the handshake part of a Wisp connection.
|
||||
///
|
||||
/// This should be used to send or receive data before any streams are created.
|
||||
async fn handle_handshake(
|
||||
&mut self,
|
||||
read: &mut dyn WebSocketRead,
|
||||
write: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError>;
|
||||
|
||||
/// Handle receiving a packet.
|
||||
async fn handle_packet(
|
||||
&mut self,
|
||||
packet: Bytes,
|
||||
read: &mut dyn WebSocketRead,
|
||||
write: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError>;
|
||||
/// Handle receiving a packet.
|
||||
async fn handle_packet(
|
||||
&mut self,
|
||||
packet: Bytes,
|
||||
read: &mut dyn WebSocketRead,
|
||||
write: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError>;
|
||||
|
||||
/// Clone the protocol extension.
|
||||
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send>;
|
||||
/// Clone the protocol extension.
|
||||
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send>;
|
||||
}
|
||||
|
||||
/// Trait to build a Wisp protocol extension from a payload.
|
||||
pub trait ProtocolExtensionBuilder {
|
||||
/// Get the protocol extension ID.
|
||||
///
|
||||
/// Used to decide whether this builder should be used.
|
||||
fn get_id(&self) -> u8;
|
||||
/// Get the protocol extension ID.
|
||||
///
|
||||
/// Used to decide whether this builder should be used.
|
||||
fn get_id(&self) -> u8;
|
||||
|
||||
/// Build a protocol extension from the extension's metadata.
|
||||
fn build_from_bytes(&self, bytes: Bytes, role: Role)
|
||||
-> Result<AnyProtocolExtension, WispError>;
|
||||
/// Build a protocol extension from the extension's metadata.
|
||||
fn build_from_bytes(&self, bytes: Bytes, role: Role)
|
||||
-> Result<AnyProtocolExtension, WispError>;
|
||||
|
||||
/// Build a protocol extension to send to the other side.
|
||||
fn build_to_extension(&self, role: Role) -> AnyProtocolExtension;
|
||||
/// Build a protocol extension to send to the other side.
|
||||
fn build_to_extension(&self, role: Role) -> AnyProtocolExtension;
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
//! ])
|
||||
//! );
|
||||
//! ```
|
||||
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x02---password-authentication)
|
||||
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication)
|
||||
|
||||
use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error};
|
||||
|
||||
|
@ -37,8 +37,8 @@ use async_trait::async_trait;
|
|||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
|
||||
use crate::{
|
||||
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||
Role, WispError,
|
||||
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||
Role, WispError,
|
||||
};
|
||||
|
||||
use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
|
||||
|
@ -50,227 +50,227 @@ use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
|
|||
/// **This extension will panic when encoding if the username's length does not fit within a u8
|
||||
/// or the password's length does not fit within a u16.**
|
||||
pub struct PasswordProtocolExtension {
|
||||
/// The username to log in with.
|
||||
///
|
||||
/// This string's length must fit within a u8.
|
||||
pub username: String,
|
||||
/// The password to log in with.
|
||||
///
|
||||
/// This string's length must fit within a u16.
|
||||
pub password: String,
|
||||
role: Role,
|
||||
/// The username to log in with.
|
||||
///
|
||||
/// This string's length must fit within a u8.
|
||||
pub username: String,
|
||||
/// The password to log in with.
|
||||
///
|
||||
/// This string's length must fit within a u16.
|
||||
pub password: String,
|
||||
role: Role,
|
||||
}
|
||||
|
||||
impl PasswordProtocolExtension {
|
||||
/// Password protocol extension ID.
|
||||
pub const ID: u8 = 0x02;
|
||||
/// Password protocol extension ID.
|
||||
pub const ID: u8 = 0x02;
|
||||
|
||||
/// Create a new password protocol extension for the server.
|
||||
///
|
||||
/// This signifies that the server requires a password.
|
||||
pub fn new_server() -> Self {
|
||||
Self {
|
||||
username: String::new(),
|
||||
password: String::new(),
|
||||
role: Role::Server,
|
||||
}
|
||||
}
|
||||
/// Create a new password protocol extension for the server.
|
||||
///
|
||||
/// This signifies that the server requires a password.
|
||||
pub fn new_server() -> Self {
|
||||
Self {
|
||||
username: String::new(),
|
||||
password: String::new(),
|
||||
role: Role::Server,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new password protocol extension for the client, with a username and password.
|
||||
///
|
||||
/// The username's length must fit within a u8. The password's length must fit within a
|
||||
/// u16.
|
||||
pub fn new_client(username: String, password: String) -> Self {
|
||||
Self {
|
||||
username,
|
||||
password,
|
||||
role: Role::Client,
|
||||
}
|
||||
}
|
||||
/// Create a new password protocol extension for the client, with a username and password.
|
||||
///
|
||||
/// The username's length must fit within a u8. The password's length must fit within a
|
||||
/// u16.
|
||||
pub fn new_client(username: String, password: String) -> Self {
|
||||
Self {
|
||||
username,
|
||||
password,
|
||||
role: Role::Client,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProtocolExtension for PasswordProtocolExtension {
|
||||
fn get_id(&self) -> u8 {
|
||||
Self::ID
|
||||
}
|
||||
fn get_id(&self) -> u8 {
|
||||
Self::ID
|
||||
}
|
||||
|
||||
fn get_supported_packets(&self) -> &'static [u8] {
|
||||
&[]
|
||||
}
|
||||
fn get_supported_packets(&self) -> &'static [u8] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn encode(&self) -> Bytes {
|
||||
match self.role {
|
||||
Role::Server => Bytes::new(),
|
||||
Role::Client => {
|
||||
let username = Bytes::from(self.username.clone().into_bytes());
|
||||
let password = Bytes::from(self.password.clone().into_bytes());
|
||||
let username_len = u8::try_from(username.len()).expect("username was too long");
|
||||
let password_len = u16::try_from(password.len()).expect("password was too long");
|
||||
fn encode(&self) -> Bytes {
|
||||
match self.role {
|
||||
Role::Server => Bytes::new(),
|
||||
Role::Client => {
|
||||
let username = Bytes::from(self.username.clone().into_bytes());
|
||||
let password = Bytes::from(self.password.clone().into_bytes());
|
||||
let username_len = u8::try_from(username.len()).expect("username was too long");
|
||||
let password_len = u16::try_from(password.len()).expect("password was too long");
|
||||
|
||||
let mut bytes =
|
||||
BytesMut::with_capacity(3 + username_len as usize + password_len as usize);
|
||||
bytes.put_u8(username_len);
|
||||
bytes.put_u16_le(password_len);
|
||||
bytes.extend(username);
|
||||
bytes.extend(password);
|
||||
bytes.freeze()
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut bytes =
|
||||
BytesMut::with_capacity(3 + username_len as usize + password_len as usize);
|
||||
bytes.put_u8(username_len);
|
||||
bytes.put_u16_le(password_len);
|
||||
bytes.extend(username);
|
||||
bytes.extend(password);
|
||||
bytes.freeze()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_handshake(
|
||||
&mut self,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
async fn handle_handshake(
|
||||
&mut self,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_packet(
|
||||
&mut self,
|
||||
_: Bytes,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
async fn handle_packet(
|
||||
&mut self,
|
||||
_: Bytes,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum PasswordProtocolExtensionError {
|
||||
Utf8Error(FromUtf8Error),
|
||||
InvalidUsername,
|
||||
InvalidPassword,
|
||||
Utf8Error(FromUtf8Error),
|
||||
InvalidUsername,
|
||||
InvalidPassword,
|
||||
}
|
||||
|
||||
impl Display for PasswordProtocolExtensionError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
use PasswordProtocolExtensionError as E;
|
||||
match self {
|
||||
E::Utf8Error(e) => write!(f, "{}", e),
|
||||
E::InvalidUsername => write!(f, "Invalid username"),
|
||||
E::InvalidPassword => write!(f, "Invalid password"),
|
||||
}
|
||||
}
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
use PasswordProtocolExtensionError as E;
|
||||
match self {
|
||||
E::Utf8Error(e) => write!(f, "{}", e),
|
||||
E::InvalidUsername => write!(f, "Invalid username"),
|
||||
E::InvalidPassword => write!(f, "Invalid password"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for PasswordProtocolExtensionError {}
|
||||
|
||||
impl From<PasswordProtocolExtensionError> for WispError {
|
||||
fn from(value: PasswordProtocolExtensionError) -> Self {
|
||||
WispError::ExtensionImplError(Box::new(value))
|
||||
}
|
||||
fn from(value: PasswordProtocolExtensionError) -> Self {
|
||||
WispError::ExtensionImplError(Box::new(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<FromUtf8Error> for PasswordProtocolExtensionError {
|
||||
fn from(value: FromUtf8Error) -> Self {
|
||||
PasswordProtocolExtensionError::Utf8Error(value)
|
||||
}
|
||||
fn from(value: FromUtf8Error) -> Self {
|
||||
PasswordProtocolExtensionError::Utf8Error(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PasswordProtocolExtension> for AnyProtocolExtension {
|
||||
fn from(value: PasswordProtocolExtension) -> Self {
|
||||
AnyProtocolExtension(Box::new(value))
|
||||
}
|
||||
fn from(value: PasswordProtocolExtension) -> Self {
|
||||
AnyProtocolExtension(Box::new(value))
|
||||
}
|
||||
}
|
||||
|
||||
/// Password protocol extension builder.
|
||||
///
|
||||
/// **Passwords are sent in plain text!!**
|
||||
pub struct PasswordProtocolExtensionBuilder {
|
||||
/// Map of users and their passwords to allow. Only used on server.
|
||||
pub users: HashMap<String, String>,
|
||||
/// Username to authenticate with. Only used on client.
|
||||
pub username: String,
|
||||
/// Password to authenticate with. Only used on client.
|
||||
pub password: String,
|
||||
/// Map of users and their passwords to allow. Only used on server.
|
||||
pub users: HashMap<String, String>,
|
||||
/// Username to authenticate with. Only used on client.
|
||||
pub username: String,
|
||||
/// Password to authenticate with. Only used on client.
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
impl PasswordProtocolExtensionBuilder {
|
||||
/// Create a new password protocol extension builder for the server, with a map of users
|
||||
/// and passwords to allow.
|
||||
pub fn new_server(users: HashMap<String, String>) -> Self {
|
||||
Self {
|
||||
users,
|
||||
username: String::new(),
|
||||
password: String::new(),
|
||||
}
|
||||
}
|
||||
/// Create a new password protocol extension builder for the server, with a map of users
|
||||
/// and passwords to allow.
|
||||
pub fn new_server(users: HashMap<String, String>) -> Self {
|
||||
Self {
|
||||
users,
|
||||
username: String::new(),
|
||||
password: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new password protocol extension builder for the client, with a username and
|
||||
/// password to authenticate with.
|
||||
pub fn new_client(username: String, password: String) -> Self {
|
||||
Self {
|
||||
users: HashMap::new(),
|
||||
username,
|
||||
password,
|
||||
}
|
||||
}
|
||||
/// Create a new password protocol extension builder for the client, with a username and
|
||||
/// password to authenticate with.
|
||||
pub fn new_client(username: String, password: String) -> Self {
|
||||
Self {
|
||||
users: HashMap::new(),
|
||||
username,
|
||||
password,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder {
|
||||
fn get_id(&self) -> u8 {
|
||||
PasswordProtocolExtension::ID
|
||||
}
|
||||
fn get_id(&self) -> u8 {
|
||||
PasswordProtocolExtension::ID
|
||||
}
|
||||
|
||||
fn build_from_bytes(
|
||||
&self,
|
||||
mut payload: Bytes,
|
||||
role: crate::Role,
|
||||
) -> Result<AnyProtocolExtension, WispError> {
|
||||
match role {
|
||||
Role::Server => {
|
||||
if payload.remaining() < 3 {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
fn build_from_bytes(
|
||||
&self,
|
||||
mut payload: Bytes,
|
||||
role: crate::Role,
|
||||
) -> Result<AnyProtocolExtension, WispError> {
|
||||
match role {
|
||||
Role::Server => {
|
||||
if payload.remaining() < 3 {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
|
||||
let username_len = payload.get_u8();
|
||||
let password_len = payload.get_u16_le();
|
||||
if payload.remaining() < (password_len + username_len as u16) as usize {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
let username_len = payload.get_u8();
|
||||
let password_len = payload.get_u16_le();
|
||||
if payload.remaining() < (password_len + username_len as u16) as usize {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
|
||||
use PasswordProtocolExtensionError as EError;
|
||||
let username =
|
||||
String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec())
|
||||
.map_err(|x| WispError::from(EError::from(x)))?;
|
||||
let password =
|
||||
String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec())
|
||||
.map_err(|x| WispError::from(EError::from(x)))?;
|
||||
use PasswordProtocolExtensionError as EError;
|
||||
let username =
|
||||
String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec())
|
||||
.map_err(|x| WispError::from(EError::from(x)))?;
|
||||
let password =
|
||||
String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec())
|
||||
.map_err(|x| WispError::from(EError::from(x)))?;
|
||||
|
||||
let Some(user) = self.users.iter().find(|x| *x.0 == username) else {
|
||||
return Err(EError::InvalidUsername.into());
|
||||
};
|
||||
let Some(user) = self.users.iter().find(|x| *x.0 == username) else {
|
||||
return Err(EError::InvalidUsername.into());
|
||||
};
|
||||
|
||||
if *user.1 != password {
|
||||
return Err(EError::InvalidPassword.into());
|
||||
}
|
||||
if *user.1 != password {
|
||||
return Err(EError::InvalidPassword.into());
|
||||
}
|
||||
|
||||
Ok(PasswordProtocolExtension {
|
||||
username,
|
||||
password,
|
||||
role,
|
||||
}
|
||||
.into())
|
||||
}
|
||||
Role::Client => {
|
||||
Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(PasswordProtocolExtension {
|
||||
username,
|
||||
password,
|
||||
role,
|
||||
}
|
||||
.into())
|
||||
}
|
||||
Role::Client => {
|
||||
Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_to_extension(&self, role: Role) -> AnyProtocolExtension {
|
||||
match role {
|
||||
Role::Server => PasswordProtocolExtension::new_server(),
|
||||
Role::Client => {
|
||||
PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone())
|
||||
}
|
||||
}
|
||||
.into()
|
||||
}
|
||||
fn build_to_extension(&self, role: Role) -> AnyProtocolExtension {
|
||||
match role {
|
||||
Role::Server => PasswordProtocolExtension::new_server(),
|
||||
Role::Client => {
|
||||
PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone())
|
||||
}
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,88 +6,88 @@
|
|||
//! rx,
|
||||
//! tx,
|
||||
//! 128,
|
||||
//! Some(&[Box::new(UdpProtocolExtensionBuilder())])
|
||||
//! Some(&[Box::new(UdpProtocolExtensionBuilder)])
|
||||
//! );
|
||||
//! ```
|
||||
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp)
|
||||
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x01---udp)
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
|
||||
use crate::{
|
||||
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||
WispError,
|
||||
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||
WispError,
|
||||
};
|
||||
|
||||
use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
|
||||
|
||||
#[derive(Debug)]
|
||||
/// UDP protocol extension.
|
||||
pub struct UdpProtocolExtension();
|
||||
pub struct UdpProtocolExtension;
|
||||
|
||||
impl UdpProtocolExtension {
|
||||
/// UDP protocol extension ID.
|
||||
pub const ID: u8 = 0x01;
|
||||
/// UDP protocol extension ID.
|
||||
pub const ID: u8 = 0x01;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProtocolExtension for UdpProtocolExtension {
|
||||
fn get_id(&self) -> u8 {
|
||||
Self::ID
|
||||
}
|
||||
fn get_id(&self) -> u8 {
|
||||
Self::ID
|
||||
}
|
||||
|
||||
fn get_supported_packets(&self) -> &'static [u8] {
|
||||
&[]
|
||||
}
|
||||
fn get_supported_packets(&self) -> &'static [u8] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn encode(&self) -> Bytes {
|
||||
Bytes::new()
|
||||
}
|
||||
fn encode(&self) -> Bytes {
|
||||
Bytes::new()
|
||||
}
|
||||
|
||||
async fn handle_handshake(
|
||||
&mut self,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
async fn handle_handshake(
|
||||
&mut self,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_packet(
|
||||
&mut self,
|
||||
_: Bytes,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
async fn handle_packet(
|
||||
&mut self,
|
||||
_: Bytes,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
|
||||
Box::new(Self())
|
||||
}
|
||||
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
|
||||
Box::new(Self)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UdpProtocolExtension> for AnyProtocolExtension {
|
||||
fn from(value: UdpProtocolExtension) -> Self {
|
||||
AnyProtocolExtension(Box::new(value))
|
||||
}
|
||||
fn from(value: UdpProtocolExtension) -> Self {
|
||||
AnyProtocolExtension(Box::new(value))
|
||||
}
|
||||
}
|
||||
|
||||
/// UDP protocol extension builder.
|
||||
pub struct UdpProtocolExtensionBuilder();
|
||||
pub struct UdpProtocolExtensionBuilder;
|
||||
|
||||
impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder {
|
||||
fn get_id(&self) -> u8 {
|
||||
UdpProtocolExtension::ID
|
||||
}
|
||||
fn get_id(&self) -> u8 {
|
||||
UdpProtocolExtension::ID
|
||||
}
|
||||
|
||||
fn build_from_bytes(
|
||||
&self,
|
||||
_: Bytes,
|
||||
_: crate::Role,
|
||||
) -> Result<AnyProtocolExtension, WispError> {
|
||||
Ok(UdpProtocolExtension().into())
|
||||
}
|
||||
fn build_from_bytes(
|
||||
&self,
|
||||
_: Bytes,
|
||||
_: crate::Role,
|
||||
) -> Result<AnyProtocolExtension, WispError> {
|
||||
Ok(UdpProtocolExtension.into())
|
||||
}
|
||||
|
||||
fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension {
|
||||
UdpProtocolExtension().into()
|
||||
}
|
||||
fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension {
|
||||
UdpProtocolExtension.into()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,93 +3,100 @@ use std::ops::Deref;
|
|||
use async_trait::async_trait;
|
||||
use bytes::BytesMut;
|
||||
use fastwebsockets::{
|
||||
CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite,
|
||||
CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite,
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::{ws::LockedWebSocketWrite, WispError};
|
||||
|
||||
fn match_payload(payload: Payload) -> BytesMut {
|
||||
fn match_payload<'a>(payload: Payload<'a>) -> crate::ws::Payload<'a> {
|
||||
match payload {
|
||||
Payload::Bytes(x) => x,
|
||||
Payload::Owned(x) => BytesMut::from(x.deref()),
|
||||
Payload::BorrowedMut(x) => BytesMut::from(x.deref()),
|
||||
Payload::Borrowed(x) => BytesMut::from(x),
|
||||
Payload::Bytes(x) => crate::ws::Payload::Bytes(x),
|
||||
Payload::Owned(x) => crate::ws::Payload::Bytes(BytesMut::from(x.deref())),
|
||||
Payload::BorrowedMut(x) => crate::ws::Payload::Borrowed(&*x),
|
||||
Payload::Borrowed(x) => crate::ws::Payload::Borrowed(x),
|
||||
}
|
||||
}
|
||||
|
||||
fn match_payload_reverse<'a>(payload: crate::ws::Payload<'a>) -> Payload<'a> {
|
||||
match payload {
|
||||
crate::ws::Payload::Bytes(x) => Payload::Bytes(x),
|
||||
crate::ws::Payload::Borrowed(x) => Payload::Borrowed(x),
|
||||
}
|
||||
}
|
||||
|
||||
impl From<OpCode> for crate::ws::OpCode {
|
||||
fn from(opcode: OpCode) -> Self {
|
||||
use OpCode::*;
|
||||
match opcode {
|
||||
Continuation => {
|
||||
unreachable!("continuation should never be recieved when using a fragmentcollector")
|
||||
}
|
||||
Text => Self::Text,
|
||||
Binary => Self::Binary,
|
||||
Close => Self::Close,
|
||||
Ping => Self::Ping,
|
||||
Pong => Self::Pong,
|
||||
}
|
||||
}
|
||||
fn from(opcode: OpCode) -> Self {
|
||||
use OpCode::*;
|
||||
match opcode {
|
||||
Continuation => {
|
||||
unreachable!("continuation should never be recieved when using a fragmentcollector")
|
||||
}
|
||||
Text => Self::Text,
|
||||
Binary => Self::Binary,
|
||||
Close => Self::Close,
|
||||
Ping => Self::Ping,
|
||||
Pong => Self::Pong,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Frame<'_>> for crate::ws::Frame {
|
||||
fn from(frame: Frame) -> Self {
|
||||
Self {
|
||||
finished: frame.fin,
|
||||
opcode: frame.opcode.into(),
|
||||
payload: match_payload(frame.payload),
|
||||
}
|
||||
}
|
||||
impl<'a> From<Frame<'a>> for crate::ws::Frame<'a> {
|
||||
fn from(frame: Frame<'a>) -> Self {
|
||||
Self {
|
||||
finished: frame.fin,
|
||||
opcode: frame.opcode.into(),
|
||||
payload: match_payload(frame.payload),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<crate::ws::Frame> for Frame<'a> {
|
||||
fn from(frame: crate::ws::Frame) -> Self {
|
||||
use crate::ws::OpCode::*;
|
||||
let payload = Payload::Bytes(frame.payload);
|
||||
match frame.opcode {
|
||||
Text => Self::text(payload),
|
||||
Binary => Self::binary(payload),
|
||||
Close => Self::close_raw(payload),
|
||||
Ping => Self::new(true, OpCode::Ping, None, payload),
|
||||
Pong => Self::pong(payload),
|
||||
}
|
||||
}
|
||||
impl<'a> From<crate::ws::Frame<'a>> for Frame<'a> {
|
||||
fn from(frame: crate::ws::Frame<'a>) -> Self {
|
||||
use crate::ws::OpCode::*;
|
||||
let payload = match_payload_reverse(frame.payload);
|
||||
match frame.opcode {
|
||||
Text => Self::text(payload),
|
||||
Binary => Self::binary(payload),
|
||||
Close => Self::close_raw(payload),
|
||||
Ping => Self::new(true, OpCode::Ping, None, payload),
|
||||
Pong => Self::pong(payload),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WebSocketError> for crate::WispError {
|
||||
fn from(err: WebSocketError) -> Self {
|
||||
if let WebSocketError::ConnectionClosed = err {
|
||||
Self::WsImplSocketClosed
|
||||
} else {
|
||||
Self::WsImplError(Box::new(err))
|
||||
}
|
||||
}
|
||||
fn from(err: WebSocketError) -> Self {
|
||||
if let WebSocketError::ConnectionClosed = err {
|
||||
Self::WsImplSocketClosed
|
||||
} else {
|
||||
Self::WsImplError(Box::new(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
|
||||
async fn wisp_read_frame(
|
||||
&mut self,
|
||||
tx: &LockedWebSocketWrite,
|
||||
) -> Result<crate::ws::Frame, WispError> {
|
||||
Ok(self
|
||||
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
|
||||
.await?
|
||||
.into())
|
||||
}
|
||||
async fn wisp_read_frame(
|
||||
&mut self,
|
||||
tx: &LockedWebSocketWrite,
|
||||
) -> Result<crate::ws::Frame<'static>, WispError> {
|
||||
Ok(self
|
||||
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
|
||||
.await?
|
||||
.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
|
||||
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), WispError> {
|
||||
self.write_frame(frame.into()).await.map_err(|e| e.into())
|
||||
}
|
||||
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame<'_>) -> Result<(), WispError> {
|
||||
self.write_frame(frame.into()).await.map_err(|e| e.into())
|
||||
}
|
||||
|
||||
async fn wisp_close(&mut self) -> Result<(), WispError> {
|
||||
self.write_frame(Frame::close(CloseCode::Normal.into(), b""))
|
||||
.await
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
async fn wisp_close(&mut self) -> Result<(), WispError> {
|
||||
self.write_frame(Frame::close(CloseCode::Normal.into(), b""))
|
||||
.await
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
}
|
||||
|
|
1438
wisp/src/lib.rs
1438
wisp/src/lib.rs
File diff suppressed because it is too large
Load diff
|
@ -1,41 +1,41 @@
|
|||
use crate::{
|
||||
extensions::{AnyProtocolExtension, ProtocolExtensionBuilder},
|
||||
ws::{self, Frame, LockedWebSocketWrite, OpCode, WebSocketRead},
|
||||
Role, WispError, WISP_VERSION,
|
||||
extensions::{AnyProtocolExtension, ProtocolExtensionBuilder},
|
||||
ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
|
||||
Role, WispError, WISP_VERSION,
|
||||
};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
|
||||
/// Wisp stream type.
|
||||
#[derive(Debug, PartialEq, Copy, Clone)]
|
||||
pub enum StreamType {
|
||||
/// TCP Wisp stream.
|
||||
Tcp,
|
||||
/// UDP Wisp stream.
|
||||
Udp,
|
||||
/// Unknown Wisp stream type used for custom streams by protocol extensions.
|
||||
Unknown(u8),
|
||||
/// TCP Wisp stream.
|
||||
Tcp,
|
||||
/// UDP Wisp stream.
|
||||
Udp,
|
||||
/// Unknown Wisp stream type used for custom streams by protocol extensions.
|
||||
Unknown(u8),
|
||||
}
|
||||
|
||||
impl From<u8> for StreamType {
|
||||
fn from(value: u8) -> Self {
|
||||
use StreamType as S;
|
||||
match value {
|
||||
0x01 => S::Tcp,
|
||||
0x02 => S::Udp,
|
||||
x => S::Unknown(x),
|
||||
}
|
||||
}
|
||||
fn from(value: u8) -> Self {
|
||||
use StreamType as S;
|
||||
match value {
|
||||
0x01 => S::Tcp,
|
||||
0x02 => S::Udp,
|
||||
x => S::Unknown(x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StreamType> for u8 {
|
||||
fn from(value: StreamType) -> Self {
|
||||
use StreamType as S;
|
||||
match value {
|
||||
S::Tcp => 0x01,
|
||||
S::Udp => 0x02,
|
||||
S::Unknown(x) => x,
|
||||
}
|
||||
}
|
||||
fn from(value: StreamType) -> Self {
|
||||
use StreamType as S;
|
||||
match value {
|
||||
S::Tcp => 0x01,
|
||||
S::Udp => 0x02,
|
||||
S::Unknown(x) => x,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Close reason.
|
||||
|
@ -44,56 +44,56 @@ impl From<StreamType> for u8 {
|
|||
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#clientserver-close-reasons)
|
||||
#[derive(Debug, PartialEq, Copy, Clone)]
|
||||
pub enum CloseReason {
|
||||
/// Reason unspecified or unknown.
|
||||
Unknown = 0x01,
|
||||
/// Voluntary stream closure.
|
||||
Voluntary = 0x02,
|
||||
/// Unexpected stream closure due to a network error.
|
||||
Unexpected = 0x03,
|
||||
/// Incompatible extensions. Only used during the handshake.
|
||||
IncompatibleExtensions = 0x04,
|
||||
/// Stream creation failed due to invalid information.
|
||||
ServerStreamInvalidInfo = 0x41,
|
||||
/// Stream creation failed due to an unreachable destination host.
|
||||
ServerStreamUnreachable = 0x42,
|
||||
/// Stream creation timed out due to the destination server not responding.
|
||||
ServerStreamConnectionTimedOut = 0x43,
|
||||
/// Stream creation failed due to the destination server refusing the connection.
|
||||
ServerStreamConnectionRefused = 0x44,
|
||||
/// TCP data transfer timed out.
|
||||
ServerStreamTimedOut = 0x47,
|
||||
/// Stream destination address/domain is intentionally blocked by the proxy server.
|
||||
ServerStreamBlockedAddress = 0x48,
|
||||
/// Connection throttled by the server.
|
||||
ServerStreamThrottled = 0x49,
|
||||
/// The client has encountered an unexpected error.
|
||||
ClientUnexpected = 0x81,
|
||||
/// Reason unspecified or unknown.
|
||||
Unknown = 0x01,
|
||||
/// Voluntary stream closure.
|
||||
Voluntary = 0x02,
|
||||
/// Unexpected stream closure due to a network error.
|
||||
Unexpected = 0x03,
|
||||
/// Incompatible extensions. Only used during the handshake.
|
||||
IncompatibleExtensions = 0x04,
|
||||
/// Stream creation failed due to invalid information.
|
||||
ServerStreamInvalidInfo = 0x41,
|
||||
/// Stream creation failed due to an unreachable destination host.
|
||||
ServerStreamUnreachable = 0x42,
|
||||
/// Stream creation timed out due to the destination server not responding.
|
||||
ServerStreamConnectionTimedOut = 0x43,
|
||||
/// Stream creation failed due to the destination server refusing the connection.
|
||||
ServerStreamConnectionRefused = 0x44,
|
||||
/// TCP data transfer timed out.
|
||||
ServerStreamTimedOut = 0x47,
|
||||
/// Stream destination address/domain is intentionally blocked by the proxy server.
|
||||
ServerStreamBlockedAddress = 0x48,
|
||||
/// Connection throttled by the server.
|
||||
ServerStreamThrottled = 0x49,
|
||||
/// The client has encountered an unexpected error.
|
||||
ClientUnexpected = 0x81,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for CloseReason {
|
||||
type Error = WispError;
|
||||
fn try_from(close_reason: u8) -> Result<Self, Self::Error> {
|
||||
use CloseReason as R;
|
||||
match close_reason {
|
||||
0x01 => Ok(R::Unknown),
|
||||
0x02 => Ok(R::Voluntary),
|
||||
0x03 => Ok(R::Unexpected),
|
||||
0x04 => Ok(R::IncompatibleExtensions),
|
||||
0x41 => Ok(R::ServerStreamInvalidInfo),
|
||||
0x42 => Ok(R::ServerStreamUnreachable),
|
||||
0x43 => Ok(R::ServerStreamConnectionTimedOut),
|
||||
0x44 => Ok(R::ServerStreamConnectionRefused),
|
||||
0x47 => Ok(R::ServerStreamTimedOut),
|
||||
0x48 => Ok(R::ServerStreamBlockedAddress),
|
||||
0x49 => Ok(R::ServerStreamThrottled),
|
||||
0x81 => Ok(R::ClientUnexpected),
|
||||
_ => Err(Self::Error::InvalidCloseReason),
|
||||
}
|
||||
}
|
||||
type Error = WispError;
|
||||
fn try_from(close_reason: u8) -> Result<Self, Self::Error> {
|
||||
use CloseReason as R;
|
||||
match close_reason {
|
||||
0x01 => Ok(R::Unknown),
|
||||
0x02 => Ok(R::Voluntary),
|
||||
0x03 => Ok(R::Unexpected),
|
||||
0x04 => Ok(R::IncompatibleExtensions),
|
||||
0x41 => Ok(R::ServerStreamInvalidInfo),
|
||||
0x42 => Ok(R::ServerStreamUnreachable),
|
||||
0x43 => Ok(R::ServerStreamConnectionTimedOut),
|
||||
0x44 => Ok(R::ServerStreamConnectionRefused),
|
||||
0x47 => Ok(R::ServerStreamTimedOut),
|
||||
0x48 => Ok(R::ServerStreamBlockedAddress),
|
||||
0x49 => Ok(R::ServerStreamThrottled),
|
||||
0x81 => Ok(R::ClientUnexpected),
|
||||
_ => Err(Self::Error::InvalidCloseReason),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait Encode {
|
||||
fn encode(self, bytes: &mut BytesMut);
|
||||
fn encode(self, bytes: &mut BytesMut);
|
||||
}
|
||||
|
||||
/// Packet used to create a new stream.
|
||||
|
@ -101,49 +101,49 @@ trait Encode {
|
|||
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---connect).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConnectPacket {
|
||||
/// Whether the new stream should use a TCP or UDP socket.
|
||||
pub stream_type: StreamType,
|
||||
/// Destination TCP/UDP port for the new stream.
|
||||
pub destination_port: u16,
|
||||
/// Destination hostname, in a UTF-8 string.
|
||||
pub destination_hostname: String,
|
||||
/// Whether the new stream should use a TCP or UDP socket.
|
||||
pub stream_type: StreamType,
|
||||
/// Destination TCP/UDP port for the new stream.
|
||||
pub destination_port: u16,
|
||||
/// Destination hostname, in a UTF-8 string.
|
||||
pub destination_hostname: String,
|
||||
}
|
||||
|
||||
impl ConnectPacket {
|
||||
/// Create a new connect packet.
|
||||
pub fn new(
|
||||
stream_type: StreamType,
|
||||
destination_port: u16,
|
||||
destination_hostname: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
stream_type,
|
||||
destination_port,
|
||||
destination_hostname,
|
||||
}
|
||||
}
|
||||
/// Create a new connect packet.
|
||||
pub fn new(
|
||||
stream_type: StreamType,
|
||||
destination_port: u16,
|
||||
destination_hostname: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
stream_type,
|
||||
destination_port,
|
||||
destination_hostname,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<BytesMut> for ConnectPacket {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < (1 + 2) {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
Ok(Self {
|
||||
stream_type: bytes.get_u8().into(),
|
||||
destination_port: bytes.get_u16_le(),
|
||||
destination_hostname: std::str::from_utf8(&bytes)?.to_string(),
|
||||
})
|
||||
}
|
||||
impl TryFrom<Payload<'_>> for ConnectPacket {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < (1 + 2) {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
Ok(Self {
|
||||
stream_type: bytes.get_u8().into(),
|
||||
destination_port: bytes.get_u16_le(),
|
||||
destination_hostname: std::str::from_utf8(&bytes)?.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode for ConnectPacket {
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
bytes.put_u8(self.stream_type.into());
|
||||
bytes.put_u16_le(self.destination_port);
|
||||
bytes.extend(self.destination_hostname.bytes());
|
||||
}
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
bytes.put_u8(self.stream_type.into());
|
||||
bytes.put_u16_le(self.destination_port);
|
||||
bytes.extend(self.destination_hostname.bytes());
|
||||
}
|
||||
}
|
||||
|
||||
/// Packet used for Wisp TCP stream flow control.
|
||||
|
@ -151,33 +151,33 @@ impl Encode for ConnectPacket {
|
|||
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x03---continue).
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct ContinuePacket {
|
||||
/// Number of packets that the server can buffer for the current stream.
|
||||
pub buffer_remaining: u32,
|
||||
/// Number of packets that the server can buffer for the current stream.
|
||||
pub buffer_remaining: u32,
|
||||
}
|
||||
|
||||
impl ContinuePacket {
|
||||
/// Create a new continue packet.
|
||||
pub fn new(buffer_remaining: u32) -> Self {
|
||||
Self { buffer_remaining }
|
||||
}
|
||||
/// Create a new continue packet.
|
||||
pub fn new(buffer_remaining: u32) -> Self {
|
||||
Self { buffer_remaining }
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<BytesMut> for ContinuePacket {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < 4 {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
Ok(Self {
|
||||
buffer_remaining: bytes.get_u32_le(),
|
||||
})
|
||||
}
|
||||
impl TryFrom<Payload<'_>> for ContinuePacket {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < 4 {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
Ok(Self {
|
||||
buffer_remaining: bytes.get_u32_le(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode for ContinuePacket {
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
bytes.put_u32_le(self.buffer_remaining);
|
||||
}
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
bytes.put_u32_le(self.buffer_remaining);
|
||||
}
|
||||
}
|
||||
|
||||
/// Packet used to close a stream.
|
||||
|
@ -186,42 +186,42 @@ impl Encode for ContinuePacket {
|
|||
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x04---close).
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct ClosePacket {
|
||||
/// The close reason.
|
||||
pub reason: CloseReason,
|
||||
/// The close reason.
|
||||
pub reason: CloseReason,
|
||||
}
|
||||
|
||||
impl ClosePacket {
|
||||
/// Create a new close packet.
|
||||
pub fn new(reason: CloseReason) -> Self {
|
||||
Self { reason }
|
||||
}
|
||||
/// Create a new close packet.
|
||||
pub fn new(reason: CloseReason) -> Self {
|
||||
Self { reason }
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<BytesMut> for ClosePacket {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < 1 {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
Ok(Self {
|
||||
reason: bytes.get_u8().try_into()?,
|
||||
})
|
||||
}
|
||||
impl TryFrom<Payload<'_>> for ClosePacket {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < 1 {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
Ok(Self {
|
||||
reason: bytes.get_u8().try_into()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode for ClosePacket {
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
bytes.put_u8(self.reason as u8);
|
||||
}
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
bytes.put_u8(self.reason as u8);
|
||||
}
|
||||
}
|
||||
|
||||
/// Wisp version sent in the handshake.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WispVersion {
|
||||
/// Major Wisp version according to semver.
|
||||
pub major: u8,
|
||||
/// Minor Wisp version according to semver.
|
||||
pub minor: u8,
|
||||
/// Major Wisp version according to semver.
|
||||
pub major: u8,
|
||||
/// Minor Wisp version according to semver.
|
||||
pub minor: u8,
|
||||
}
|
||||
|
||||
/// Packet used in the initial handshake.
|
||||
|
@ -229,325 +229,327 @@ pub struct WispVersion {
|
|||
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x05---info)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InfoPacket {
|
||||
/// Wisp version sent in the packet.
|
||||
pub version: WispVersion,
|
||||
/// List of protocol extensions sent in the packet.
|
||||
pub extensions: Vec<AnyProtocolExtension>,
|
||||
/// Wisp version sent in the packet.
|
||||
pub version: WispVersion,
|
||||
/// List of protocol extensions sent in the packet.
|
||||
pub extensions: Vec<AnyProtocolExtension>,
|
||||
}
|
||||
|
||||
impl Encode for InfoPacket {
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
bytes.put_u8(self.version.major);
|
||||
bytes.put_u8(self.version.minor);
|
||||
for extension in self.extensions {
|
||||
bytes.extend_from_slice(&Bytes::from(extension));
|
||||
}
|
||||
}
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
bytes.put_u8(self.version.major);
|
||||
bytes.put_u8(self.version.minor);
|
||||
for extension in self.extensions {
|
||||
bytes.extend_from_slice(&Bytes::from(extension));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Type of packet recieved.
|
||||
pub enum PacketType {
|
||||
/// Connect packet.
|
||||
Connect(ConnectPacket),
|
||||
/// Data packet.
|
||||
Data(Bytes),
|
||||
/// Continue packet.
|
||||
Continue(ContinuePacket),
|
||||
/// Close packet.
|
||||
Close(ClosePacket),
|
||||
/// Info packet.
|
||||
Info(InfoPacket),
|
||||
pub enum PacketType<'a> {
|
||||
/// Connect packet.
|
||||
Connect(ConnectPacket),
|
||||
/// Data packet.
|
||||
Data(Payload<'a>),
|
||||
/// Continue packet.
|
||||
Continue(ContinuePacket),
|
||||
/// Close packet.
|
||||
Close(ClosePacket),
|
||||
/// Info packet.
|
||||
Info(InfoPacket),
|
||||
}
|
||||
|
||||
impl PacketType {
|
||||
/// Get the packet type used in the protocol.
|
||||
pub fn as_u8(&self) -> u8 {
|
||||
use PacketType as P;
|
||||
match self {
|
||||
P::Connect(_) => 0x01,
|
||||
P::Data(_) => 0x02,
|
||||
P::Continue(_) => 0x03,
|
||||
P::Close(_) => 0x04,
|
||||
P::Info(_) => 0x05,
|
||||
}
|
||||
}
|
||||
impl PacketType<'_> {
|
||||
/// Get the packet type used in the protocol.
|
||||
pub fn as_u8(&self) -> u8 {
|
||||
use PacketType as P;
|
||||
match self {
|
||||
P::Connect(_) => 0x01,
|
||||
P::Data(_) => 0x02,
|
||||
P::Continue(_) => 0x03,
|
||||
P::Close(_) => 0x04,
|
||||
P::Info(_) => 0x05,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_packet_size(&self) -> usize {
|
||||
use PacketType as P;
|
||||
match self {
|
||||
P::Connect(p) => 1 + 2 + p.destination_hostname.len(),
|
||||
P::Data(p) => p.len(),
|
||||
P::Continue(_) => 4,
|
||||
P::Close(_) => 1,
|
||||
P::Info(_) => 2,
|
||||
}
|
||||
}
|
||||
pub(crate) fn get_packet_size(&self) -> usize {
|
||||
use PacketType as P;
|
||||
match self {
|
||||
P::Connect(p) => 1 + 2 + p.destination_hostname.len(),
|
||||
P::Data(p) => p.len(),
|
||||
P::Continue(_) => 4,
|
||||
P::Close(_) => 1,
|
||||
P::Info(_) => 2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode for PacketType {
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
use PacketType as P;
|
||||
match self {
|
||||
P::Connect(x) => x.encode(bytes),
|
||||
P::Data(x) => bytes.extend_from_slice(&x),
|
||||
P::Continue(x) => x.encode(bytes),
|
||||
P::Close(x) => x.encode(bytes),
|
||||
P::Info(x) => x.encode(bytes),
|
||||
};
|
||||
}
|
||||
impl Encode for PacketType<'_> {
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
use PacketType as P;
|
||||
match self {
|
||||
P::Connect(x) => x.encode(bytes),
|
||||
P::Data(x) => bytes.extend_from_slice(&x),
|
||||
P::Continue(x) => x.encode(bytes),
|
||||
P::Close(x) => x.encode(bytes),
|
||||
P::Info(x) => x.encode(bytes),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// Wisp protocol packet.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Packet {
|
||||
/// Stream this packet is associated with.
|
||||
pub stream_id: u32,
|
||||
/// Packet type recieved.
|
||||
pub packet_type: PacketType,
|
||||
pub struct Packet<'a> {
|
||||
/// Stream this packet is associated with.
|
||||
pub stream_id: u32,
|
||||
/// Packet type recieved.
|
||||
pub packet_type: PacketType<'a>,
|
||||
}
|
||||
|
||||
impl Packet {
|
||||
/// Create a new packet.
|
||||
///
|
||||
/// The helper functions should be used for most use cases.
|
||||
pub fn new(stream_id: u32, packet: PacketType) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
packet_type: packet,
|
||||
}
|
||||
}
|
||||
impl<'a> Packet<'a> {
|
||||
/// Create a new packet.
|
||||
///
|
||||
/// The helper functions should be used for most use cases.
|
||||
pub fn new(stream_id: u32, packet: PacketType<'a>) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
packet_type: packet,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new connect packet.
|
||||
pub fn new_connect(
|
||||
stream_id: u32,
|
||||
stream_type: StreamType,
|
||||
destination_port: u16,
|
||||
destination_hostname: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
packet_type: PacketType::Connect(ConnectPacket::new(
|
||||
stream_type,
|
||||
destination_port,
|
||||
destination_hostname,
|
||||
)),
|
||||
}
|
||||
}
|
||||
/// Create a new connect packet.
|
||||
pub fn new_connect(
|
||||
stream_id: u32,
|
||||
stream_type: StreamType,
|
||||
destination_port: u16,
|
||||
destination_hostname: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
packet_type: PacketType::Connect(ConnectPacket::new(
|
||||
stream_type,
|
||||
destination_port,
|
||||
destination_hostname,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new data packet.
|
||||
pub fn new_data(stream_id: u32, data: Bytes) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
packet_type: PacketType::Data(data),
|
||||
}
|
||||
}
|
||||
/// Create a new data packet.
|
||||
pub fn new_data(stream_id: u32, data: Payload<'a>) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
packet_type: PacketType::Data(data),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new continue packet.
|
||||
pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
packet_type: PacketType::Continue(ContinuePacket::new(buffer_remaining)),
|
||||
}
|
||||
}
|
||||
/// Create a new continue packet.
|
||||
pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
packet_type: PacketType::Continue(ContinuePacket::new(buffer_remaining)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new close packet.
|
||||
pub fn new_close(stream_id: u32, reason: CloseReason) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
packet_type: PacketType::Close(ClosePacket::new(reason)),
|
||||
}
|
||||
}
|
||||
/// Create a new close packet.
|
||||
pub fn new_close(stream_id: u32, reason: CloseReason) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
packet_type: PacketType::Close(ClosePacket::new(reason)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_info(extensions: Vec<AnyProtocolExtension>) -> Self {
|
||||
Self {
|
||||
stream_id: 0,
|
||||
packet_type: PacketType::Info(InfoPacket {
|
||||
version: WISP_VERSION,
|
||||
extensions,
|
||||
}),
|
||||
}
|
||||
}
|
||||
pub(crate) fn new_info(extensions: Vec<AnyProtocolExtension>) -> Self {
|
||||
Self {
|
||||
stream_id: 0,
|
||||
packet_type: PacketType::Info(InfoPacket {
|
||||
version: WISP_VERSION,
|
||||
extensions,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_packet(packet_type: u8, mut bytes: BytesMut) -> Result<Self, WispError> {
|
||||
use PacketType as P;
|
||||
Ok(Self {
|
||||
stream_id: bytes.get_u32_le(),
|
||||
packet_type: match packet_type {
|
||||
0x01 => P::Connect(ConnectPacket::try_from(bytes)?),
|
||||
0x02 => P::Data(bytes.freeze()),
|
||||
0x03 => P::Continue(ContinuePacket::try_from(bytes)?),
|
||||
0x04 => P::Close(ClosePacket::try_from(bytes)?),
|
||||
// 0x05 is handled seperately
|
||||
_ => return Err(WispError::InvalidPacketType),
|
||||
},
|
||||
})
|
||||
}
|
||||
fn parse_packet(packet_type: u8, mut bytes: Payload<'a>) -> Result<Self, WispError> {
|
||||
use PacketType as P;
|
||||
Ok(Self {
|
||||
stream_id: bytes.get_u32_le(),
|
||||
packet_type: match packet_type {
|
||||
0x01 => P::Connect(ConnectPacket::try_from(bytes)?),
|
||||
0x02 => P::Data(bytes),
|
||||
0x03 => P::Continue(ContinuePacket::try_from(bytes)?),
|
||||
0x04 => P::Close(ClosePacket::try_from(bytes)?),
|
||||
// 0x05 is handled seperately
|
||||
_ => return Err(WispError::InvalidPacketType),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn maybe_parse_info(
|
||||
frame: Frame,
|
||||
role: Role,
|
||||
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
||||
) -> Result<Self, WispError> {
|
||||
if !frame.finished {
|
||||
return Err(WispError::WsFrameNotFinished);
|
||||
}
|
||||
if frame.opcode != OpCode::Binary {
|
||||
return Err(WispError::WsFrameInvalidType);
|
||||
}
|
||||
let mut bytes = frame.payload;
|
||||
if bytes.remaining() < 1 {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
let packet_type = bytes.get_u8();
|
||||
if packet_type == 0x05 {
|
||||
Self::parse_info(bytes, role, extension_builders)
|
||||
} else {
|
||||
Self::parse_packet(packet_type, bytes)
|
||||
}
|
||||
}
|
||||
pub(crate) fn maybe_parse_info(
|
||||
frame: Frame<'a>,
|
||||
role: Role,
|
||||
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
||||
) -> Result<Self, WispError> {
|
||||
if !frame.finished {
|
||||
return Err(WispError::WsFrameNotFinished);
|
||||
}
|
||||
if frame.opcode != OpCode::Binary {
|
||||
return Err(WispError::WsFrameInvalidType);
|
||||
}
|
||||
let mut bytes = frame.payload;
|
||||
if bytes.remaining() < 1 {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
let packet_type = bytes.get_u8();
|
||||
if packet_type == 0x05 {
|
||||
Self::parse_info(bytes, role, extension_builders)
|
||||
} else {
|
||||
Self::parse_packet(packet_type, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn maybe_handle_extension(
|
||||
frame: Frame,
|
||||
extensions: &mut [AnyProtocolExtension],
|
||||
read: &mut (dyn WebSocketRead + Send),
|
||||
write: &LockedWebSocketWrite,
|
||||
) -> Result<Option<Self>, WispError> {
|
||||
if !frame.finished {
|
||||
return Err(WispError::WsFrameNotFinished);
|
||||
}
|
||||
if frame.opcode != OpCode::Binary {
|
||||
return Err(WispError::WsFrameInvalidType);
|
||||
}
|
||||
let mut bytes = frame.payload;
|
||||
if bytes.remaining() < 5 {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
let packet_type = bytes.get_u8();
|
||||
match packet_type {
|
||||
0x01 => Ok(Some(Self {
|
||||
stream_id: bytes.get_u32_le(),
|
||||
packet_type: PacketType::Connect(bytes.try_into()?),
|
||||
})),
|
||||
0x02 => Ok(Some(Self {
|
||||
stream_id: bytes.get_u32_le(),
|
||||
packet_type: PacketType::Data(bytes.freeze()),
|
||||
})),
|
||||
0x03 => Ok(Some(Self {
|
||||
stream_id: bytes.get_u32_le(),
|
||||
packet_type: PacketType::Continue(bytes.try_into()?),
|
||||
})),
|
||||
0x04 => Ok(Some(Self {
|
||||
stream_id: bytes.get_u32_le(),
|
||||
packet_type: PacketType::Close(bytes.try_into()?),
|
||||
})),
|
||||
0x05 => Ok(None),
|
||||
packet_type => {
|
||||
if let Some(extension) = extensions
|
||||
.iter_mut()
|
||||
.find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
|
||||
{
|
||||
extension.handle_packet(bytes.freeze(), read, write).await?;
|
||||
Ok(None)
|
||||
} else {
|
||||
Err(WispError::InvalidPacketType)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
pub(crate) async fn maybe_handle_extension(
|
||||
frame: Frame<'a>,
|
||||
extensions: &mut [AnyProtocolExtension],
|
||||
read: &mut (dyn WebSocketRead + Send),
|
||||
write: &LockedWebSocketWrite,
|
||||
) -> Result<Option<Self>, WispError> {
|
||||
if !frame.finished {
|
||||
return Err(WispError::WsFrameNotFinished);
|
||||
}
|
||||
if frame.opcode != OpCode::Binary {
|
||||
return Err(WispError::WsFrameInvalidType);
|
||||
}
|
||||
let mut bytes = frame.payload;
|
||||
if bytes.remaining() < 5 {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
let packet_type = bytes.get_u8();
|
||||
match packet_type {
|
||||
0x01 => Ok(Some(Self {
|
||||
stream_id: bytes.get_u32_le(),
|
||||
packet_type: PacketType::Connect(bytes.try_into()?),
|
||||
})),
|
||||
0x02 => Ok(Some(Self {
|
||||
stream_id: bytes.get_u32_le(),
|
||||
packet_type: PacketType::Data(bytes),
|
||||
})),
|
||||
0x03 => Ok(Some(Self {
|
||||
stream_id: bytes.get_u32_le(),
|
||||
packet_type: PacketType::Continue(bytes.try_into()?),
|
||||
})),
|
||||
0x04 => Ok(Some(Self {
|
||||
stream_id: bytes.get_u32_le(),
|
||||
packet_type: PacketType::Close(bytes.try_into()?),
|
||||
})),
|
||||
0x05 => Ok(None),
|
||||
packet_type => {
|
||||
if let Some(extension) = extensions
|
||||
.iter_mut()
|
||||
.find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
|
||||
{
|
||||
extension
|
||||
.handle_packet(BytesMut::from(bytes).freeze(), read, write)
|
||||
.await?;
|
||||
Ok(None)
|
||||
} else {
|
||||
Err(WispError::InvalidPacketType)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_info(
|
||||
mut bytes: BytesMut,
|
||||
role: Role,
|
||||
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
||||
) -> Result<Self, WispError> {
|
||||
// packet type is already read by code that calls this
|
||||
if bytes.remaining() < 4 + 2 {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
if bytes.get_u32_le() != 0 {
|
||||
return Err(WispError::InvalidStreamId);
|
||||
}
|
||||
fn parse_info(
|
||||
mut bytes: Payload<'a>,
|
||||
role: Role,
|
||||
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
||||
) -> Result<Self, WispError> {
|
||||
// packet type is already read by code that calls this
|
||||
if bytes.remaining() < 4 + 2 {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
if bytes.get_u32_le() != 0 {
|
||||
return Err(WispError::InvalidStreamId);
|
||||
}
|
||||
|
||||
let version = WispVersion {
|
||||
major: bytes.get_u8(),
|
||||
minor: bytes.get_u8(),
|
||||
};
|
||||
let version = WispVersion {
|
||||
major: bytes.get_u8(),
|
||||
minor: bytes.get_u8(),
|
||||
};
|
||||
|
||||
if version.major != WISP_VERSION.major {
|
||||
return Err(WispError::IncompatibleProtocolVersion);
|
||||
}
|
||||
if version.major != WISP_VERSION.major {
|
||||
return Err(WispError::IncompatibleProtocolVersion);
|
||||
}
|
||||
|
||||
let mut extensions = Vec::new();
|
||||
let mut extensions = Vec::new();
|
||||
|
||||
while bytes.remaining() > 4 {
|
||||
// We have some extensions
|
||||
let id = bytes.get_u8();
|
||||
let length = usize::try_from(bytes.get_u32_le())?;
|
||||
if bytes.remaining() < length {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) {
|
||||
if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role) {
|
||||
extensions.push(extension)
|
||||
}
|
||||
} else {
|
||||
bytes.advance(length)
|
||||
}
|
||||
}
|
||||
while bytes.remaining() > 4 {
|
||||
// We have some extensions
|
||||
let id = bytes.get_u8();
|
||||
let length = usize::try_from(bytes.get_u32_le())?;
|
||||
if bytes.remaining() < length {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) {
|
||||
if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role) {
|
||||
extensions.push(extension)
|
||||
}
|
||||
} else {
|
||||
bytes.advance(length)
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
stream_id: 0,
|
||||
packet_type: PacketType::Info(InfoPacket {
|
||||
version,
|
||||
extensions,
|
||||
}),
|
||||
})
|
||||
}
|
||||
Ok(Self {
|
||||
stream_id: 0,
|
||||
packet_type: PacketType::Info(InfoPacket {
|
||||
version,
|
||||
extensions,
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode for Packet {
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
bytes.put_u8(self.packet_type.as_u8());
|
||||
bytes.put_u32_le(self.stream_id);
|
||||
self.packet_type.encode(bytes);
|
||||
}
|
||||
impl Encode for Packet<'_> {
|
||||
fn encode(self, bytes: &mut BytesMut) {
|
||||
bytes.put_u8(self.packet_type.as_u8());
|
||||
bytes.put_u32_le(self.stream_id);
|
||||
self.packet_type.encode(bytes);
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<BytesMut> for Packet {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < 1 {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
let packet_type = bytes.get_u8();
|
||||
Self::parse_packet(packet_type, bytes)
|
||||
}
|
||||
impl<'a> TryFrom<Payload<'a>> for Packet<'a> {
|
||||
type Error = WispError;
|
||||
fn try_from(mut bytes: Payload<'a>) -> Result<Self, Self::Error> {
|
||||
if bytes.remaining() < 1 {
|
||||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
let packet_type = bytes.get_u8();
|
||||
Self::parse_packet(packet_type, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Packet> for BytesMut {
|
||||
fn from(packet: Packet) -> Self {
|
||||
let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size());
|
||||
packet.encode(&mut encoded);
|
||||
encoded
|
||||
}
|
||||
impl From<Packet<'_>> for BytesMut {
|
||||
fn from(packet: Packet) -> Self {
|
||||
let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size());
|
||||
packet.encode(&mut encoded);
|
||||
encoded
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ws::Frame> for Packet {
|
||||
type Error = WispError;
|
||||
fn try_from(frame: ws::Frame) -> Result<Self, Self::Error> {
|
||||
if !frame.finished {
|
||||
return Err(Self::Error::WsFrameNotFinished);
|
||||
}
|
||||
if frame.opcode != ws::OpCode::Binary {
|
||||
return Err(Self::Error::WsFrameInvalidType);
|
||||
}
|
||||
Packet::try_from(frame.payload)
|
||||
}
|
||||
impl<'a> TryFrom<ws::Frame<'a>> for Packet<'a> {
|
||||
type Error = WispError;
|
||||
fn try_from(frame: ws::Frame<'a>) -> Result<Self, Self::Error> {
|
||||
if !frame.finished {
|
||||
return Err(Self::Error::WsFrameNotFinished);
|
||||
}
|
||||
if frame.opcode != ws::OpCode::Binary {
|
||||
return Err(Self::Error::WsFrameInvalidType);
|
||||
}
|
||||
Packet::try_from(frame.payload)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Packet> for ws::Frame {
|
||||
fn from(packet: Packet) -> Self {
|
||||
Self::binary(BytesMut::from(packet))
|
||||
}
|
||||
impl From<Packet<'_>> for ws::Frame<'static> {
|
||||
fn from(packet: Packet) -> Self {
|
||||
Self::binary(Payload::Bytes(BytesMut::from(packet)))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,146 +1,146 @@
|
|||
//! futures sink unfold with a close function
|
||||
use core::{future::Future, pin::Pin};
|
||||
use futures::{
|
||||
ready,
|
||||
task::{Context, Poll},
|
||||
Sink,
|
||||
ready,
|
||||
task::{Context, Poll},
|
||||
Sink,
|
||||
};
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
pin_project! {
|
||||
/// UnfoldState used for stream and sink unfolds
|
||||
#[project = UnfoldStateProj]
|
||||
#[project_replace = UnfoldStateProjReplace]
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum UnfoldState<T, Fut> {
|
||||
Value {
|
||||
value: T,
|
||||
},
|
||||
Future {
|
||||
#[pin]
|
||||
future: Fut,
|
||||
},
|
||||
Empty,
|
||||
}
|
||||
/// UnfoldState used for stream and sink unfolds
|
||||
#[project = UnfoldStateProj]
|
||||
#[project_replace = UnfoldStateProjReplace]
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum UnfoldState<T, Fut> {
|
||||
Value {
|
||||
value: T,
|
||||
},
|
||||
Future {
|
||||
#[pin]
|
||||
future: Fut,
|
||||
},
|
||||
Empty,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, Fut> UnfoldState<T, Fut> {
|
||||
pub(crate) fn project_future(self: Pin<&mut Self>) -> Option<Pin<&mut Fut>> {
|
||||
match self.project() {
|
||||
UnfoldStateProj::Future { future } => Some(future),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
pub(crate) fn project_future(self: Pin<&mut Self>) -> Option<Pin<&mut Fut>> {
|
||||
match self.project() {
|
||||
UnfoldStateProj::Future { future } => Some(future),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn take_value(self: Pin<&mut Self>) -> Option<T> {
|
||||
match &*self {
|
||||
Self::Value { .. } => match self.project_replace(Self::Empty) {
|
||||
UnfoldStateProjReplace::Value { value } => Some(value),
|
||||
_ => unreachable!(),
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
pub(crate) fn take_value(self: Pin<&mut Self>) -> Option<T> {
|
||||
match &*self {
|
||||
Self::Value { .. } => match self.project_replace(Self::Empty) {
|
||||
UnfoldStateProjReplace::Value { value } => Some(value),
|
||||
_ => unreachable!(),
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// Sink for the [`unfold`] function.
|
||||
#[derive(Debug)]
|
||||
#[must_use = "sinks do nothing unless polled"]
|
||||
pub struct Unfold<T, F, R, CT, CF, CR> {
|
||||
function: F,
|
||||
close_function: CF,
|
||||
#[pin]
|
||||
state: UnfoldState<T, R>,
|
||||
#[pin]
|
||||
close_state: UnfoldState<CT, CR>
|
||||
}
|
||||
/// Sink for the [`unfold`] function.
|
||||
#[derive(Debug)]
|
||||
#[must_use = "sinks do nothing unless polled"]
|
||||
pub struct Unfold<T, F, R, CT, CF, CR> {
|
||||
function: F,
|
||||
close_function: CF,
|
||||
#[pin]
|
||||
state: UnfoldState<T, R>,
|
||||
#[pin]
|
||||
close_state: UnfoldState<CT, CR>
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unfold<T, F, R, CT, CF, CR, Item, E>(
|
||||
init: T,
|
||||
function: F,
|
||||
close_init: CT,
|
||||
close_function: CF,
|
||||
init: T,
|
||||
function: F,
|
||||
close_init: CT,
|
||||
close_function: CF,
|
||||
) -> Unfold<T, F, R, CT, CF, CR>
|
||||
where
|
||||
F: FnMut(T, Item) -> R,
|
||||
R: Future<Output = Result<T, E>>,
|
||||
CF: FnMut(CT) -> CR,
|
||||
CR: Future<Output = Result<CT, E>>,
|
||||
F: FnMut(T, Item) -> R,
|
||||
R: Future<Output = Result<T, E>>,
|
||||
CF: FnMut(CT) -> CR,
|
||||
CR: Future<Output = Result<CT, E>>,
|
||||
{
|
||||
Unfold {
|
||||
function,
|
||||
close_function,
|
||||
state: UnfoldState::Value { value: init },
|
||||
close_state: UnfoldState::Value { value: close_init },
|
||||
}
|
||||
Unfold {
|
||||
function,
|
||||
close_function,
|
||||
state: UnfoldState::Value { value: init },
|
||||
close_state: UnfoldState::Value { value: close_init },
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, F, R, CT, CF, CR, Item, E> Sink<Item> for Unfold<T, F, R, CT, CF, CR>
|
||||
where
|
||||
F: FnMut(T, Item) -> R,
|
||||
R: Future<Output = Result<T, E>>,
|
||||
CF: FnMut(CT) -> CR,
|
||||
CR: Future<Output = Result<CT, E>>,
|
||||
F: FnMut(T, Item) -> R,
|
||||
R: Future<Output = Result<T, E>>,
|
||||
CF: FnMut(CT) -> CR,
|
||||
CR: Future<Output = Result<CT, E>>,
|
||||
{
|
||||
type Error = E;
|
||||
type Error = E;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.poll_flush(cx)
|
||||
}
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
|
||||
let mut this = self.project();
|
||||
let future = match this.state.as_mut().take_value() {
|
||||
Some(value) => (this.function)(value, item),
|
||||
None => panic!("start_send called without poll_ready being called first"),
|
||||
};
|
||||
this.state.set(UnfoldState::Future { future });
|
||||
Ok(())
|
||||
}
|
||||
fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
|
||||
let mut this = self.project();
|
||||
let future = match this.state.as_mut().take_value() {
|
||||
Some(value) => (this.function)(value, item),
|
||||
None => panic!("start_send called without poll_ready being called first"),
|
||||
};
|
||||
this.state.set(UnfoldState::Future { future });
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
let mut this = self.project();
|
||||
Poll::Ready(if let Some(future) = this.state.as_mut().project_future() {
|
||||
match ready!(future.poll(cx)) {
|
||||
Ok(state) => {
|
||||
this.state.set(UnfoldState::Value { value: state });
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
this.state.set(UnfoldState::Empty);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
let mut this = self.project();
|
||||
Poll::Ready(if let Some(future) = this.state.as_mut().project_future() {
|
||||
match ready!(future.poll(cx)) {
|
||||
Ok(state) => {
|
||||
this.state.set(UnfoldState::Value { value: state });
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
this.state.set(UnfoldState::Empty);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
ready!(self.as_mut().poll_flush(cx))?;
|
||||
let mut this = self.project();
|
||||
Poll::Ready(
|
||||
if let Some(future) = this.close_state.as_mut().project_future() {
|
||||
match ready!(future.poll(cx)) {
|
||||
Ok(state) => {
|
||||
this.close_state.set(UnfoldState::Value { value: state });
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
this.close_state.set(UnfoldState::Empty);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let future = match this.close_state.as_mut().take_value() {
|
||||
Some(value) => (this.close_function)(value),
|
||||
None => panic!("start_send called without poll_ready being called first"),
|
||||
};
|
||||
this.close_state.set(UnfoldState::Future { future });
|
||||
return Poll::Pending;
|
||||
},
|
||||
)
|
||||
}
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
ready!(self.as_mut().poll_flush(cx))?;
|
||||
let mut this = self.project();
|
||||
Poll::Ready(
|
||||
if let Some(future) = this.close_state.as_mut().project_future() {
|
||||
match ready!(future.poll(cx)) {
|
||||
Ok(state) => {
|
||||
this.close_state.set(UnfoldState::Value { value: state });
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
this.close_state.set(UnfoldState::Empty);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let future = match this.close_state.as_mut().take_value() {
|
||||
Some(value) => (this.close_function)(value),
|
||||
None => panic!("start_send called without poll_ready being called first"),
|
||||
};
|
||||
this.close_state.set(UnfoldState::Future { future });
|
||||
return Poll::Pending;
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
sink_unfold,
|
||||
ws::{Frame, LockedWebSocketWrite},
|
||||
ws::{Frame, LockedWebSocketWrite, Payload},
|
||||
CloseReason, Packet, Role, StreamType, WispError,
|
||||
};
|
||||
|
||||
|
@ -9,9 +9,10 @@ use event_listener::Event;
|
|||
use flume as mpsc;
|
||||
use futures::{
|
||||
channel::oneshot,
|
||||
ready, select, stream::{self, IntoAsyncRead},
|
||||
ready, select,
|
||||
stream::{self, IntoAsyncRead},
|
||||
task::{noop_waker_ref, Context, Poll},
|
||||
AsyncBufRead, AsyncRead, AsyncWrite, FutureExt, Sink, Stream, TryStreamExt,
|
||||
AsyncBufRead, AsyncRead, AsyncWrite, Future, FutureExt, Sink, Stream, TryStreamExt,
|
||||
};
|
||||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
|
@ -23,7 +24,7 @@ use std::{
|
|||
};
|
||||
|
||||
pub(crate) enum WsEvent {
|
||||
Close(Packet, oneshot::Sender<Result<(), WispError>>),
|
||||
Close(Packet<'static>, oneshot::Sender<Result<(), WispError>>),
|
||||
CreateStream(
|
||||
StreamType,
|
||||
String,
|
||||
|
@ -100,8 +101,10 @@ pub struct MuxStreamWrite {
|
|||
}
|
||||
|
||||
impl MuxStreamWrite {
|
||||
/// Write data to the stream.
|
||||
pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
|
||||
pub(crate) async fn write_payload_internal(
|
||||
&self,
|
||||
frame: Frame<'static>,
|
||||
) -> Result<(), WispError> {
|
||||
if self.role == Role::Client
|
||||
&& self.stream_type == StreamType::Tcp
|
||||
&& self.flow_control.load(Ordering::Acquire) == 0
|
||||
|
@ -112,9 +115,7 @@ impl MuxStreamWrite {
|
|||
return Err(WispError::StreamAlreadyClosed);
|
||||
}
|
||||
|
||||
self.tx
|
||||
.write_frame(Frame::from(Packet::new_data(self.stream_id, data)))
|
||||
.await?;
|
||||
self.tx.write_frame(frame).await?;
|
||||
|
||||
if self.role == Role::Client && self.stream_type == StreamType::Tcp {
|
||||
self.flow_control.store(
|
||||
|
@ -125,6 +126,20 @@ impl MuxStreamWrite {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Write a payload to the stream.
|
||||
pub fn write_payload<'a>(
|
||||
&'a self,
|
||||
data: Payload<'_>,
|
||||
) -> impl Future<Output = Result<(), WispError>> + 'a {
|
||||
let frame: Frame<'static> = Frame::from(Packet::new_data(self.stream_id, data));
|
||||
self.write_payload_internal(frame)
|
||||
}
|
||||
|
||||
/// Write data to the stream.
|
||||
pub async fn write<D: AsRef<[u8]>>(&self, data: D) -> Result<(), WispError> {
|
||||
self.write_payload(Payload::Borrowed(data.as_ref())).await
|
||||
}
|
||||
|
||||
/// Get a handle to close the connection.
|
||||
///
|
||||
/// Useful to close the connection without having access to the stream.
|
||||
|
@ -173,16 +188,16 @@ impl MuxStreamWrite {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn into_sink(self) -> Pin<Box<dyn Sink<Bytes, Error = WispError> + Send>> {
|
||||
pub(crate) fn into_sink(self) -> Pin<Box<dyn Sink<Frame<'static>, Error = WispError> + Send>> {
|
||||
let handle = self.get_close_handle();
|
||||
Box::pin(sink_unfold::unfold(
|
||||
self,
|
||||
|tx, data| async move {
|
||||
tx.write(data).await?;
|
||||
tx.write_payload_internal(data).await?;
|
||||
Ok(tx)
|
||||
},
|
||||
handle,
|
||||
move |handle| async {
|
||||
|handle| async move {
|
||||
handle.close(CloseReason::Unknown).await?;
|
||||
Ok(handle)
|
||||
},
|
||||
|
@ -258,8 +273,13 @@ impl MuxStream {
|
|||
self.rx.read().await
|
||||
}
|
||||
|
||||
/// Write a payload to the stream.
|
||||
pub async fn write_payload(&self, data: Payload<'_>) -> Result<(), WispError> {
|
||||
self.tx.write_payload(data).await
|
||||
}
|
||||
|
||||
/// Write data to the stream.
|
||||
pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
|
||||
pub async fn write<D: AsRef<[u8]>>(&self, data: D) -> Result<(), WispError> {
|
||||
self.tx.write(data).await
|
||||
}
|
||||
|
||||
|
@ -301,6 +321,7 @@ impl MuxStream {
|
|||
},
|
||||
tx: MuxStreamIoSink {
|
||||
tx: self.tx.into_sink(),
|
||||
stream_id: self.stream_id,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -355,7 +376,9 @@ impl MuxProtocolExtensionStream {
|
|||
encoded.put_u8(packet_type);
|
||||
encoded.put_u32_le(self.stream_id);
|
||||
encoded.extend(data);
|
||||
self.tx.write_frame(Frame::binary(encoded)).await
|
||||
self.tx
|
||||
.write_frame(Frame::binary(Payload::Bytes(encoded)))
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -391,12 +414,12 @@ impl Stream for MuxStreamIo {
|
|||
}
|
||||
}
|
||||
|
||||
impl Sink<Bytes> for MuxStreamIo {
|
||||
impl Sink<&[u8]> for MuxStreamIo {
|
||||
type Error = std::io::Error;
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().tx.poll_ready(cx)
|
||||
}
|
||||
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
|
||||
fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
|
||||
self.project().tx.start_send(item)
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
|
@ -433,7 +456,8 @@ pin_project! {
|
|||
/// Write side of a multiplexor stream that implements futures `Sink`.
|
||||
pub struct MuxStreamIoSink {
|
||||
#[pin]
|
||||
tx: Pin<Box<dyn Sink<Bytes, Error = WispError> + Send>>,
|
||||
tx: Pin<Box<dyn Sink<Frame<'static>, Error = WispError> + Send>>,
|
||||
stream_id: u32,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -444,7 +468,7 @@ impl MuxStreamIoSink {
|
|||
}
|
||||
}
|
||||
|
||||
impl Sink<Bytes> for MuxStreamIoSink {
|
||||
impl Sink<&[u8]> for MuxStreamIoSink {
|
||||
type Error = std::io::Error;
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project()
|
||||
|
@ -452,10 +476,14 @@ impl Sink<Bytes> for MuxStreamIoSink {
|
|||
.poll_ready(cx)
|
||||
.map_err(std::io::Error::other)
|
||||
}
|
||||
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
|
||||
fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
|
||||
let stream_id = self.stream_id;
|
||||
self.project()
|
||||
.tx
|
||||
.start_send(item)
|
||||
.start_send(Frame::from(Packet::new_data(
|
||||
stream_id,
|
||||
Payload::Borrowed(item),
|
||||
)))
|
||||
.map_err(std::io::Error::other)
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
|
@ -564,10 +592,10 @@ impl AsyncRead for MuxStreamAsyncRead {
|
|||
}
|
||||
impl AsyncBufRead for MuxStreamAsyncRead {
|
||||
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
|
||||
self.project().rx.poll_fill_buf(cx)
|
||||
self.project().rx.poll_fill_buf(cx)
|
||||
}
|
||||
fn consume(self: Pin<&mut Self>, amt: usize) {
|
||||
self.project().rx.consume(amt)
|
||||
self.project().rx.consume(amt)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -582,7 +610,10 @@ pin_project! {
|
|||
|
||||
impl MuxStreamAsyncWrite {
|
||||
pub(crate) fn new(sink: MuxStreamIoSink) -> Self {
|
||||
Self { tx: sink, error: None }
|
||||
Self {
|
||||
tx: sink,
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -599,7 +630,7 @@ impl AsyncWrite for MuxStreamAsyncWrite {
|
|||
let mut this = self.as_mut().project();
|
||||
|
||||
ready!(this.tx.as_mut().poll_ready(cx))?;
|
||||
match this.tx.as_mut().start_send(Bytes::copy_from_slice(buf)) {
|
||||
match this.tx.as_mut().start_send(buf) {
|
||||
Ok(()) => {
|
||||
let mut cx = Context::from_waker(noop_waker_ref());
|
||||
let cx = &mut cx;
|
||||
|
|
230
wisp/src/ws.rs
230
wisp/src/ws.rs
|
@ -4,83 +4,168 @@
|
|||
//! for other WebSocket implementations.
|
||||
//!
|
||||
//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs
|
||||
use std::sync::Arc;
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
|
||||
use crate::WispError;
|
||||
use async_trait::async_trait;
|
||||
use bytes::BytesMut;
|
||||
use bytes::{Buf, BytesMut};
|
||||
use futures::lock::Mutex;
|
||||
|
||||
/// Payload of the websocket frame.
|
||||
#[derive(Debug)]
|
||||
pub enum Payload<'a> {
|
||||
/// Borrowed payload. Currently used when writing data.
|
||||
Borrowed(&'a [u8]),
|
||||
/// BytesMut payload. Currently used when reading data.
|
||||
Bytes(BytesMut),
|
||||
}
|
||||
|
||||
impl From<BytesMut> for Payload<'static> {
|
||||
fn from(value: BytesMut) -> Self {
|
||||
Self::Bytes(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a [u8]> for Payload<'a> {
|
||||
fn from(value: &'a [u8]) -> Self {
|
||||
Self::Borrowed(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Payload<'_> {
|
||||
/// Turn a Payload<'a> into a Payload<'static> by copying the data.
|
||||
pub fn into_owned(self) -> Self {
|
||||
match self {
|
||||
Self::Bytes(x) => Self::Bytes(x),
|
||||
Self::Borrowed(x) => Self::Bytes(BytesMut::from(x)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Payload<'_>> for BytesMut {
|
||||
fn from(value: Payload<'_>) -> Self {
|
||||
match value {
|
||||
Payload::Bytes(x) => x,
|
||||
Payload::Borrowed(x) => x.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Payload<'_> {
|
||||
type Target = [u8];
|
||||
fn deref(&self) -> &Self::Target {
|
||||
match self {
|
||||
Self::Bytes(x) => x.deref(),
|
||||
Self::Borrowed(x) => x,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for Payload<'_> {
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
Self::Bytes(x) => Self::Bytes(x.clone()),
|
||||
Self::Borrowed(x) => Self::Bytes(BytesMut::from(*x)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Buf for Payload<'_> {
|
||||
fn remaining(&self) -> usize {
|
||||
match self {
|
||||
Self::Bytes(x) => x.remaining(),
|
||||
Self::Borrowed(x) => x.remaining(),
|
||||
}
|
||||
}
|
||||
|
||||
fn chunk(&self) -> &[u8] {
|
||||
match self {
|
||||
Self::Bytes(x) => x.chunk(),
|
||||
Self::Borrowed(x) => x.chunk(),
|
||||
}
|
||||
}
|
||||
|
||||
fn advance(&mut self, cnt: usize) {
|
||||
match self {
|
||||
Self::Bytes(x) => x.advance(cnt),
|
||||
Self::Borrowed(x) => x.advance(cnt),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Opcode of the WebSocket frame.
|
||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||
pub enum OpCode {
|
||||
/// Text frame.
|
||||
Text,
|
||||
/// Binary frame.
|
||||
Binary,
|
||||
/// Close frame.
|
||||
Close,
|
||||
/// Ping frame.
|
||||
Ping,
|
||||
/// Pong frame.
|
||||
Pong,
|
||||
/// Text frame.
|
||||
Text,
|
||||
/// Binary frame.
|
||||
Binary,
|
||||
/// Close frame.
|
||||
Close,
|
||||
/// Ping frame.
|
||||
Ping,
|
||||
/// Pong frame.
|
||||
Pong,
|
||||
}
|
||||
|
||||
/// WebSocket frame.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Frame {
|
||||
/// Whether the frame is finished or not.
|
||||
pub finished: bool,
|
||||
/// Opcode of the WebSocket frame.
|
||||
pub opcode: OpCode,
|
||||
/// Payload of the WebSocket frame.
|
||||
pub payload: BytesMut,
|
||||
pub struct Frame<'a> {
|
||||
/// Whether the frame is finished or not.
|
||||
pub finished: bool,
|
||||
/// Opcode of the WebSocket frame.
|
||||
pub opcode: OpCode,
|
||||
/// Payload of the WebSocket frame.
|
||||
pub payload: Payload<'a>,
|
||||
}
|
||||
|
||||
impl Frame {
|
||||
/// Create a new text frame.
|
||||
pub fn text(payload: BytesMut) -> Self {
|
||||
Self {
|
||||
finished: true,
|
||||
opcode: OpCode::Text,
|
||||
payload,
|
||||
}
|
||||
}
|
||||
impl<'a> Frame<'a> {
|
||||
/// Create a new text frame.
|
||||
pub fn text(payload: Payload<'a>) -> Self {
|
||||
Self {
|
||||
finished: true,
|
||||
opcode: OpCode::Text,
|
||||
payload,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new binary frame.
|
||||
pub fn binary(payload: BytesMut) -> Self {
|
||||
Self {
|
||||
finished: true,
|
||||
opcode: OpCode::Binary,
|
||||
payload,
|
||||
}
|
||||
}
|
||||
/// Create a new binary frame.
|
||||
pub fn binary(payload: Payload<'a>) -> Self {
|
||||
Self {
|
||||
finished: true,
|
||||
opcode: OpCode::Binary,
|
||||
payload,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new close frame.
|
||||
pub fn close(payload: BytesMut) -> Self {
|
||||
Self {
|
||||
finished: true,
|
||||
opcode: OpCode::Close,
|
||||
payload,
|
||||
}
|
||||
}
|
||||
/// Create a new close frame.
|
||||
pub fn close(payload: Payload<'a>) -> Self {
|
||||
Self {
|
||||
finished: true,
|
||||
opcode: OpCode::Close,
|
||||
payload,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generic WebSocket read trait.
|
||||
#[async_trait]
|
||||
pub trait WebSocketRead {
|
||||
/// Read a frame from the socket.
|
||||
async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result<Frame, WispError>;
|
||||
/// Read a frame from the socket.
|
||||
async fn wisp_read_frame(
|
||||
&mut self,
|
||||
tx: &LockedWebSocketWrite,
|
||||
) -> Result<Frame<'static>, WispError>;
|
||||
}
|
||||
|
||||
/// Generic WebSocket write trait.
|
||||
#[async_trait]
|
||||
pub trait WebSocketWrite {
|
||||
/// Write a frame to the socket.
|
||||
async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError>;
|
||||
/// Write a frame to the socket.
|
||||
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError>;
|
||||
|
||||
/// Close the socket.
|
||||
async fn wisp_close(&mut self) -> Result<(), WispError>;
|
||||
/// Close the socket.
|
||||
async fn wisp_close(&mut self) -> Result<(), WispError>;
|
||||
}
|
||||
|
||||
/// Locked WebSocket.
|
||||
|
@ -88,35 +173,38 @@ pub trait WebSocketWrite {
|
|||
pub struct LockedWebSocketWrite(Arc<Mutex<Box<dyn WebSocketWrite + Send>>>);
|
||||
|
||||
impl LockedWebSocketWrite {
|
||||
/// Create a new locked websocket.
|
||||
pub fn new(ws: Box<dyn WebSocketWrite + Send>) -> Self {
|
||||
Self(Mutex::new(ws).into())
|
||||
}
|
||||
/// Create a new locked websocket.
|
||||
pub fn new(ws: Box<dyn WebSocketWrite + Send>) -> Self {
|
||||
Self(Mutex::new(ws).into())
|
||||
}
|
||||
|
||||
/// Write a frame to the websocket.
|
||||
pub async fn write_frame(&self, frame: Frame) -> Result<(), WispError> {
|
||||
self.0.lock().await.wisp_write_frame(frame).await
|
||||
}
|
||||
/// Write a frame to the websocket.
|
||||
pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WispError> {
|
||||
self.0.lock().await.wisp_write_frame(frame).await
|
||||
}
|
||||
|
||||
/// Close the websocket.
|
||||
pub async fn close(&self) -> Result<(), WispError> {
|
||||
self.0.lock().await.wisp_close().await
|
||||
}
|
||||
/// Close the websocket.
|
||||
pub async fn close(&self) -> Result<(), WispError> {
|
||||
self.0.lock().await.wisp_close().await
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct AppendingWebSocketRead<R>(pub Option<Frame>, pub R)
|
||||
pub(crate) struct AppendingWebSocketRead<R>(pub Option<Frame<'static>>, pub R)
|
||||
where
|
||||
R: WebSocketRead + Send;
|
||||
R: WebSocketRead + Send;
|
||||
|
||||
#[async_trait]
|
||||
impl<R> WebSocketRead for AppendingWebSocketRead<R>
|
||||
where
|
||||
R: WebSocketRead + Send,
|
||||
R: WebSocketRead + Send,
|
||||
{
|
||||
async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result<Frame, WispError> {
|
||||
if let Some(x) = self.0.take() {
|
||||
return Ok(x);
|
||||
}
|
||||
return self.1.wisp_read_frame(tx).await;
|
||||
}
|
||||
async fn wisp_read_frame(
|
||||
&mut self,
|
||||
tx: &LockedWebSocketWrite,
|
||||
) -> Result<Frame<'static>, WispError> {
|
||||
if let Some(x) = self.0.take() {
|
||||
return Ok(x);
|
||||
}
|
||||
return self.1.wisp_read_frame(tx).await;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue