mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
make requiring protocol extensions easy
This commit is contained in:
parent
063b527914
commit
01d7ac5002
6 changed files with 143 additions and 90 deletions
103
wisp/src/lib.rs
103
wisp/src/lib.rs
|
@ -80,8 +80,8 @@ pub enum WispError {
|
|||
ExtensionImplError(Box<dyn std::error::Error + Sync + Send>),
|
||||
/// The protocol extension implementation did not support the action.
|
||||
ExtensionImplNotSupported,
|
||||
/// The UDP protocol extension is not supported by the server.
|
||||
UdpExtensionNotSupported,
|
||||
/// The specified protocol extensions are not supported by the server.
|
||||
ExtensionsNotSupported(Vec<u8>),
|
||||
/// The string was invalid UTF-8.
|
||||
Utf8Error(std::str::Utf8Error),
|
||||
/// The integer failed to convert.
|
||||
|
@ -137,7 +137,9 @@ impl std::fmt::Display for WispError {
|
|||
"Protocol extension implementation error: unsupported feature"
|
||||
)
|
||||
}
|
||||
Self::UdpExtensionNotSupported => write!(f, "UDP protocol extension not supported"),
|
||||
Self::ExtensionsNotSupported(list) => {
|
||||
write!(f, "Protocol extensions {:?} not supported", list)
|
||||
}
|
||||
Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
|
||||
Self::TryFromIntError(err) => write!(f, "Integer conversion error: {}", err),
|
||||
Self::Other(err) => write!(f, "Other error: {}", err),
|
||||
|
@ -483,12 +485,12 @@ impl ServerMux {
|
|||
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||
/// if the extensions you need are available after the multiplexor has been created.
|
||||
pub async fn new<R, W>(
|
||||
pub async fn create<R, W>(
|
||||
mut read: R,
|
||||
write: W,
|
||||
buffer_size: u32,
|
||||
extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
|
||||
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError>
|
||||
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||
where
|
||||
R: ws::WebSocketRead + Send,
|
||||
W: ws::WebSocketWrite + Send + 'static,
|
||||
|
@ -532,7 +534,7 @@ impl ServerMux {
|
|||
}
|
||||
}
|
||||
|
||||
Ok((
|
||||
Ok(ServerMuxResult(
|
||||
Self {
|
||||
muxstream_recv: rx,
|
||||
close_tx: close_tx.clone(),
|
||||
|
@ -590,6 +592,48 @@ impl Drop for ServerMux {
|
|||
}
|
||||
}
|
||||
|
||||
/// Result of `ServerMux::new`.
|
||||
pub struct ServerMuxResult<F>(ServerMux, F)
|
||||
where
|
||||
F: Future<Output = Result<(), WispError>> + Send;
|
||||
|
||||
impl<F> ServerMuxResult<F>
|
||||
where
|
||||
F: Future<Output = Result<(), WispError>> + Send,
|
||||
{
|
||||
/// Require no protocol extensions.
|
||||
pub fn with_no_required_extensions(self) -> (ServerMux, F) {
|
||||
(self.0, self.1)
|
||||
}
|
||||
|
||||
/// Require protocol extensions by their ID. Will close the multiplexor connection if
|
||||
/// extensions are not supported.
|
||||
pub async fn with_required_extensions(
|
||||
self,
|
||||
extensions: &[u8],
|
||||
) -> Result<(ServerMux, F), WispError> {
|
||||
let mut unsupported_extensions = Vec::new();
|
||||
for extension in extensions {
|
||||
if !self.0.supported_extension_ids.contains(extension) {
|
||||
unsupported_extensions.push(*extension);
|
||||
}
|
||||
}
|
||||
if unsupported_extensions.is_empty() {
|
||||
Ok((self.0, self.1))
|
||||
} else {
|
||||
self.0.close_extension_incompat().await?;
|
||||
self.1.await?;
|
||||
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
|
||||
}
|
||||
}
|
||||
|
||||
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
|
||||
pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> {
|
||||
self.with_required_extensions(&[UdpProtocolExtension::ID])
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// Client side multiplexor.
|
||||
///
|
||||
/// # Example
|
||||
|
@ -620,11 +664,11 @@ impl ClientMux {
|
|||
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||
/// if the extensions you need are available after the multiplexor has been created.
|
||||
pub async fn new<R, W>(
|
||||
pub async fn create<R, W>(
|
||||
mut read: R,
|
||||
write: W,
|
||||
extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
|
||||
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError>
|
||||
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||
where
|
||||
R: ws::WebSocketRead + Send,
|
||||
W: ws::WebSocketWrite + Send + 'static,
|
||||
|
@ -671,7 +715,7 @@ impl ClientMux {
|
|||
}
|
||||
|
||||
let (tx, rx) = mpsc::bounded::<WsEvent>(256);
|
||||
Ok((
|
||||
Ok(ClientMuxResult(
|
||||
Self {
|
||||
stream_tx: tx.clone(),
|
||||
downgraded,
|
||||
|
@ -710,7 +754,9 @@ impl ClientMux {
|
|||
.iter()
|
||||
.any(|x| *x == UdpProtocolExtension::ID)
|
||||
{
|
||||
return Err(WispError::UdpExtensionNotSupported);
|
||||
return Err(WispError::ExtensionsNotSupported(vec![
|
||||
UdpProtocolExtension::ID,
|
||||
]));
|
||||
}
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.stream_tx
|
||||
|
@ -750,3 +796,40 @@ impl Drop for ClientMux {
|
|||
let _ = self.stream_tx.send(WsEvent::EndFut(None));
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of `ClientMux::new`.
|
||||
pub struct ClientMuxResult<F>(ClientMux, F)
|
||||
where
|
||||
F: Future<Output = Result<(), WispError>> + Send;
|
||||
|
||||
impl<F> ClientMuxResult<F>
|
||||
where
|
||||
F: Future<Output = Result<(), WispError>> + Send,
|
||||
{
|
||||
/// Require no protocol extensions.
|
||||
pub fn with_no_required_extensions(self) -> (ClientMux, F) {
|
||||
(self.0, self.1)
|
||||
}
|
||||
|
||||
/// Require protocol extensions by their ID.
|
||||
pub async fn with_required_extensions(self, extensions: &[u8]) -> Result<(ClientMux, F), WispError> {
|
||||
let mut unsupported_extensions = Vec::new();
|
||||
for extension in extensions {
|
||||
if !self.0.supported_extension_ids.contains(extension) {
|
||||
unsupported_extensions.push(*extension);
|
||||
}
|
||||
}
|
||||
if unsupported_extensions.is_empty() {
|
||||
Ok((self.0, self.1))
|
||||
} else {
|
||||
self.0.close_extension_incompat().await?;
|
||||
self.1.await?;
|
||||
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
|
||||
}
|
||||
}
|
||||
|
||||
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
|
||||
pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> {
|
||||
self.with_required_extensions(&[UdpProtocolExtension::ID]).await
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue