diff --git a/server/src/handle/wisp/mod.rs b/server/src/handle/wisp/mod.rs index ed4e999..2e28fce 100644 --- a/server/src/handle/wisp/mod.rs +++ b/server/src/handle/wisp/mod.rs @@ -2,9 +2,10 @@ pub mod twisp; pub mod utils; -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use anyhow::Context; +use bytes::BytesMut; use cfg_if::cfg_if; use event_listener::Event; use futures_util::FutureExt; @@ -13,12 +14,12 @@ use tokio::{ io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, net::tcp::{OwnedReadHalf, OwnedWriteHalf}, select, - task::JoinSet, + task::JoinSet, time::interval, }; use tokio_util::compat::FuturesAsyncReadCompatExt; use uuid::Uuid; use wisp_mux::{ - CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite, ServerMux, + ws::Payload, CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite, ServerMux }; use crate::{ @@ -237,6 +238,7 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> { .context("failed to create server multiplexor")? .with_required_extensions(&required_extensions) .await?; + let mux = Arc::new(mux); debug!( "new wisp client id {:?} connected with extensions {:?}", @@ -255,6 +257,14 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> { trace!("wisp client id {:?} multiplexor result {:?}", mux_id, x) }))); + let ping_mux = mux.clone(); + set.spawn(async move { + let mut interval = interval(Duration::from_secs(30)); + while ping_mux.send_ping(Payload::Bytes(BytesMut::new())).await.is_ok() { + interval.tick().await; + } + }); + while let Some((connect, stream)) = mux.server_new_stream().await { set.spawn(handle_stream( connect, diff --git a/wisp/src/inner.rs b/wisp/src/inner.rs index 142714c..4bbae67 100644 --- a/wisp/src/inner.rs +++ b/wisp/src/inner.rs @@ -23,6 +23,8 @@ pub(crate) enum WsEvent { u16, oneshot::Sender>, ), + SendPing(Payload<'static>, oneshot::Sender>), + SendPong(Payload<'static>), WispMessage(Option>, Option>), EndFut(Option), } @@ -234,6 +236,8 @@ impl MuxInner { let (mut frame, optional_frame) = msg?; if frame.opcode == OpCode::Close { return Ok(None); + } else if frame.opcode == OpCode::Ping { + return Ok(Some(WsEvent::SendPong(frame.payload))); } if let Some(ref extra_frame) = optional_frame { @@ -308,6 +312,12 @@ impl MuxInner { let _ = channel.send(Err(WispError::InvalidStreamId)); } } + WsEvent::SendPing(payload, channel) => { + let _ = channel.send(self.tx.write_frame(Frame::new(OpCode::Ping, payload, true)).await); + } + WsEvent::SendPong(payload) => { + self.tx.write_frame(Frame::new(OpCode::Pong, payload, true)).await?; + } WsEvent::EndFut(x) => { if let Some(reason) = x { let _ = self diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index f30c8db..d27912a 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -31,7 +31,7 @@ use std::{ Arc, }, }; -use ws::{AppendingWebSocketRead, LockedWebSocketWrite}; +use ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload}; /// Wisp version supported by this crate. pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 }; @@ -363,6 +363,19 @@ impl ServerMux { self.muxstream_recv.recv_async().await.ok() } + /// Send a ping to the client. + pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> { + if self.actor_exited.load(Ordering::Acquire) { + return Err(WispError::MuxTaskEnded); + } + let (tx, rx) = oneshot::channel(); + self.actor_tx + .send_async(WsEvent::SendPing(payload, tx)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? + } + async fn close_internal(&self, reason: Option) -> Result<(), WispError> { if self.actor_exited.load(Ordering::Acquire) { return Err(WispError::MuxTaskEnded); @@ -554,6 +567,19 @@ impl ClientMux { rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? } + /// Send a ping to the server. + pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> { + if self.actor_exited.load(Ordering::Acquire) { + return Err(WispError::MuxTaskEnded); + } + let (tx, rx) = oneshot::channel(); + self.actor_tx + .send_async(WsEvent::SendPing(payload, tx)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? + } + async fn close_internal(&self, reason: Option) -> Result<(), WispError> { if self.actor_exited.load(Ordering::Acquire) { return Err(WispError::MuxTaskEnded); diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index 9ca9342..d75b694 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -120,6 +120,15 @@ pub struct Frame<'a> { } impl<'a> Frame<'a> { + /// Create a new frame. + pub fn new(opcode: OpCode, payload: Payload<'a>, finished: bool) -> Self { + Self { + finished, + opcode, + payload, + } + } + /// Create a new text frame. pub fn text(payload: Payload<'a>) -> Self { Self {