make extensions owned

This commit is contained in:
Toshit Chawda 2024-04-13 23:45:40 -07:00
parent 76da9fd619
commit ace9bf380d
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
6 changed files with 21 additions and 23 deletions

View file

@ -204,7 +204,7 @@ pub async fn make_mux(
.await .await
.map_err(|_| WispError::WsImplSocketClosed)?; .map_err(|_| WispError::WsImplSocketClosed)?;
wtx.wait_for_open().await; wtx.wait_for_open().await;
let mux = ClientMux::new(wrx, wtx, Some(&[&UdpProtocolExtensionBuilder()])).await?; let mux = ClientMux::new(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await?;
Ok(mux) Ok(mux)
} }

View file

@ -22,7 +22,7 @@ use tokio_util::either::Either;
use wisp_mux::{ use wisp_mux::{
extensions::{ extensions::{
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
udp::UdpProtocolExtensionBuilder, udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder,
}, },
CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError,
}; };
@ -72,7 +72,7 @@ struct MuxOptions {
pub block_udp: bool, pub block_udp: bool,
pub block_non_http: bool, pub block_non_http: bool,
pub enforce_auth: bool, pub enforce_auth: bool,
pub auth: Arc<PasswordProtocolExtensionBuilder>, pub auth: Arc<Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>>>,
} }
#[cfg(not(unix))] #[cfg(not(unix))]
@ -176,13 +176,13 @@ async fn main() -> Result<(), Error> {
auth.insert(username.to_string(), password.to_string()); auth.insert(username.to_string(), password.to_string());
} }
} }
let pw_ext = Arc::new(PasswordProtocolExtensionBuilder::new_server(auth)); let pw_ext = PasswordProtocolExtensionBuilder::new_server(auth);
let mux_options = MuxOptions { let mux_options = MuxOptions {
block_local: opt.block_local, block_local: opt.block_local,
block_non_http: opt.block_non_http, block_non_http: opt.block_non_http,
block_udp: opt.block_udp, block_udp: opt.block_udp,
auth: pw_ext, auth: Arc::new(vec![Box::new(UdpProtocolExtensionBuilder()), Box::new(pw_ext)]),
enforce_auth, enforce_auth,
}; };
@ -314,7 +314,7 @@ async fn accept_ws(
rx, rx,
tx, tx,
u32::MAX, u32::MAX,
Some(&[&UdpProtocolExtensionBuilder(), mux_options.auth.as_ref()]), Some(mux_options.auth.as_slice()),
) )
.await?; .await?;
if !mux if !mux
@ -331,7 +331,7 @@ async fn accept_ws(
} }
(mux, fut) (mux, fut)
} else { } else {
ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await? ServerMux::new(rx, tx, u32::MAX, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await?
}; };
println!( println!(

View file

@ -152,7 +152,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let (rx, tx) = ws.split(tokio::io::split); let (rx, tx) = ws.split(tokio::io::split);
let rx = FragmentCollectorRead::new(rx); let rx = FragmentCollectorRead::new(rx);
let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Sync)>> = Vec::new(); let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new();
if opts.udp { if opts.udp {
extensions.push(Box::new(UdpProtocolExtensionBuilder())); extensions.push(Box::new(UdpProtocolExtensionBuilder()));
} }
@ -160,10 +160,8 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
if let Some(auth) = auth { if let Some(auth) = auth {
extensions.push(Box::new(auth)); extensions.push(Box::new(auth));
} }
let extensions_mapped: Vec<&(dyn ProtocolExtensionBuilder + Sync)> =
extensions.iter().map(|x| x.as_ref()).collect();
let (mut mux, fut) = ClientMux::new(rx, tx, Some(&extensions_mapped)).await?; let (mut mux, fut) = ClientMux::new(rx, tx, Some(extensions.as_slice())).await?;
if opts.udp if opts.udp
&& !mux && !mux
.supported_extension_ids .supported_extension_ids

View file

@ -112,7 +112,7 @@ pub mod udp {
//! rx, //! rx,
//! tx, //! tx,
//! 128, //! 128,
//! Some(&[&UdpProtocolExtensionBuilder()]) //! Some(&[Box::new(UdpProtocolExtensionBuilder())])
//! ); //! );
//! ``` //! ```
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp) //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp)
@ -213,7 +213,7 @@ pub mod password {
//! rx, //! rx,
//! tx, //! tx,
//! 128, //! 128,
//! Some(&[&PasswordProtocolExtensionBuilder::new_server(passwords)]) //! Some(&[Box::new(PasswordProtocolExtensionBuilder::new_server(passwords))])
//! ); //! );
//! ``` //! ```
//! //!
@ -224,10 +224,10 @@ pub mod password {
//! tx, //! tx,
//! 128, //! 128,
//! Some(&[ //! Some(&[
//! &PasswordProtocolExtensionBuilder::new_client( //! Box::new(PasswordProtocolExtensionBuilder::new_client(
//! "user1".to_string(), //! "user1".to_string(),
//! "pw".to_string() //! "pw".to_string()
//! ) //! ))
//! ]) //! ])
//! ); //! );
//! ``` //! ```

View file

@ -445,7 +445,7 @@ impl MuxInner {
/// ``` /// ```
/// use wisp_mux::ServerMux; /// use wisp_mux::ServerMux;
/// ///
/// let (mux, fut) = ServerMux::new(rx, tx, 128, Some(vec![]), Some([])); /// let (mux, fut) = ServerMux::new(rx, tx, 128, Some([]));
/// tokio::spawn(async move { /// tokio::spawn(async move {
/// if let Err(e) = fut.await { /// if let Err(e) = fut.await {
/// println!("error in multiplexor: {:?}", e); /// println!("error in multiplexor: {:?}", e);
@ -472,14 +472,14 @@ pub struct ServerMux {
impl ServerMux { impl ServerMux {
/// Create a new server-side multiplexor. /// Create a new server-side multiplexor.
/// ///
/// If extension_builders is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. /// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check /// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created. /// if the extensions you need are available after the multiplexor has been created.
pub async fn new<R, W>( pub async fn new<R, W>(
mut read: R, mut read: R,
write: W, write: W,
buffer_size: u32, buffer_size: u32,
extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>, extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError> ) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError>
where where
R: ws::WebSocketRead + Send, R: ws::WebSocketRead + Send,
@ -581,7 +581,7 @@ impl ServerMux {
/// ``` /// ```
/// use wisp_mux::{ClientMux, StreamType}; /// use wisp_mux::{ClientMux, StreamType};
/// ///
/// let (mux, fut) = ClientMux::new(rx, tx, Some(vec![]), []).await?; /// let (mux, fut) = ClientMux::new(rx, tx, Some([])).await?;
/// tokio::spawn(async move { /// tokio::spawn(async move {
/// if let Err(e) = fut.await { /// if let Err(e) = fut.await {
/// println!("error in multiplexor: {:?}", e); /// println!("error in multiplexor: {:?}", e);
@ -602,13 +602,13 @@ pub struct ClientMux {
impl ClientMux { impl ClientMux {
/// Create a new client side multiplexor. /// Create a new client side multiplexor.
/// ///
/// If extension_builders is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. /// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check /// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created. /// if the extensions you need are available after the multiplexor has been created.
pub async fn new<R, W>( pub async fn new<R, W>(
mut read: R, mut read: R,
write: W, write: W,
extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>, extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError> ) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError>
where where
R: ws::WebSocketRead + Send, R: ws::WebSocketRead + Send,

View file

@ -380,7 +380,7 @@ impl Packet {
pub(crate) fn maybe_parse_info( pub(crate) fn maybe_parse_info(
frame: Frame, frame: Frame,
role: Role, role: Role,
extension_builders: &[&(dyn ProtocolExtensionBuilder + Sync)], extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
) -> Result<Self, WispError> { ) -> Result<Self, WispError> {
if !frame.finished { if !frame.finished {
return Err(WispError::WsFrameNotFinished); return Err(WispError::WsFrameNotFinished);
@ -431,7 +431,7 @@ impl Packet {
fn parse_info( fn parse_info(
mut bytes: Bytes, mut bytes: Bytes,
role: Role, role: Role,
extension_builders: &[&(dyn ProtocolExtensionBuilder + Sync)], extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
) -> Result<Self, WispError> { ) -> Result<Self, WispError> {
// packet type is already read by code that calls this // packet type is already read by code that calls this
if bytes.remaining() < 4 + 2 { if bytes.remaining() < 4 + 2 {