diff --git a/src/main.rs b/src/main.rs index b5aef5f..d76a657 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,7 +4,8 @@ use bytes::Bytes; use fastwebsockets::{upgrade, FragmentCollector, Frame, OpCode, Payload, WebSocketError}; use futures_util::{SinkExt, StreamExt}; use hyper::{ - body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode, + body::Incoming, header::HeaderValue, server::conn::http1, service::service_fn, Request, + Response, StatusCode, }; use hyper_util::rt::TokioIo; use tokio::net::{TcpListener, TcpStream}; @@ -35,18 +36,34 @@ async fn main() -> Result<(), Error> { Ok(()) } -async fn accept_http(mut req: Request, addr: String) -> Result, WebSocketError> { +async fn accept_http( + mut req: Request, + addr: String, +) -> Result, WebSocketError> { if upgrade::is_upgrade_request(&req) { let uri = req.uri().clone(); - let (res, fut) = upgrade::upgrade(&mut req)?; + let (mut res, fut) = upgrade::upgrade(&mut req)?; tokio::spawn(async move { - if let Err(e) = accept_ws(fut, uri.path().to_string(), addr.clone()).await - { + if let Err(e) = accept_ws(fut, uri.path().to_string(), addr.clone()).await { println!("{:?}: error in ws: {:?}", addr, e); } }); + if let Some(protocol) = req.headers().get("Sec-Websocket-Protocol") { + let first_protocol = protocol + .to_str() + .expect("failed to get protocol") + .split(',') + .next() + .expect("failed to get first protocol") + .trim(); + res.headers_mut().insert( + "Sec-Websocket-Protocol", + HeaderValue::from_str(first_protocol).unwrap(), + ); + } + Ok(res) } else { Ok(Response::builder() @@ -59,7 +76,7 @@ async fn accept_http(mut req: Request, addr: String) -> Result Result<(), Box> { let mut ws_stream = FragmentCollector::new(fut.await?);