diff --git a/src/tunnel/server/handler_websocket.rs b/src/tunnel/server/handler_websocket.rs index 0c40ffc..d3bdc27 100644 --- a/src/tunnel/server/handler_websocket.rs +++ b/src/tunnel/server/handler_websocket.rs @@ -11,6 +11,7 @@ use hyper::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL}; use hyper::{Request, Response}; use std::net::SocketAddr; use std::sync::Arc; +use std::time::Duration; use tokio::sync::oneshot; use tracing::{error, warn, Instrument, Span}; @@ -45,24 +46,32 @@ pub(super) async fn ws_server_upgrade( tokio::spawn( async move { - let (ws_rx, mut ws_tx) = match fut.await { - Ok(ws) => ws.split(tokio::io::split), + let (ws_rx, ws_tx) = match fut.await { + Ok(mut ws) => { + ws.set_auto_pong(false); + ws.set_auto_close(false); + ws.set_auto_apply_mask(mask_frame); + ws.split(tokio::io::split) + } Err(err) => { error!("Error during http upgrade request: {:?}", err); return; } }; let (close_tx, close_rx) = oneshot::channel::<()>(); - ws_tx.set_auto_apply_mask(mask_frame); + let (ws_rx, pending_ops) = WebsocketTunnelRead::new(ws_rx); tokio::task::spawn( - transport::io::propagate_remote_to_local(local_tx, WebsocketTunnelRead::new(ws_rx), close_rx) - .instrument(Span::current()), + transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).instrument(Span::current()), ); - let _ = - transport::io::propagate_local_to_remote(local_rx, WebsocketTunnelWrite::new(ws_tx), close_tx, None) - .await; + let _ = transport::io::propagate_local_to_remote( + local_rx, + WebsocketTunnelWrite::new(ws_tx, pending_ops), + close_tx, + Some(Duration::from_secs(30)), + ) + .await; } .instrument(Span::current()), ); diff --git a/src/tunnel/server/server.rs b/src/tunnel/server/server.rs index ee6f9a7..e9f0d02 100644 --- a/src/tunnel/server/server.rs +++ b/src/tunnel/server/server.rs @@ -436,6 +436,7 @@ impl WsServer { let websocket_upgrade_fn = mk_websocket_upgrade_fn(server, restrictions.clone(), restrict_path, peer_addr); let conn_fut = http1::Builder::new() + .header_read_timeout(Duration::from_secs(10)) .serve_connection(tls_stream, service_fn(websocket_upgrade_fn)) .with_upgrades(); diff --git a/src/tunnel/transport/http2.rs b/src/tunnel/transport/http2.rs index 8d51e4e..86c6a04 100644 --- a/src/tunnel/transport/http2.rs +++ b/src/tunnel/transport/http2.rs @@ -12,11 +12,14 @@ use hyper::http::response::Parts; use hyper::Request; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use log::{debug, error, warn}; +use std::future::Future; use std::io; use std::io::ErrorKind; use std::ops::DerefMut; +use std::sync::Arc; +use std::time::Duration; use tokio::io::{AsyncWrite, AsyncWriteExt}; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, Notify}; use tokio_stream::wrappers::ReceiverStream; use tokio_stream::StreamExt; use uuid::Uuid; @@ -97,6 +100,14 @@ impl TunnelWrite for Http2TunnelWrite { async fn close(&mut self) -> Result<(), io::Error> { Ok(()) } + + fn pending_operations_notify(&mut self) -> Arc { + Arc::new(Notify::new()) + } + + fn handle_pending_operations(&mut self) -> impl Future> + Send { + std::future::ready(Ok(())) + } } pub async fn connect( @@ -177,6 +188,7 @@ pub async fn connect( .timer(TokioTimer::new()) .adaptive_window(true) .keep_alive_interval(client.config.websocket_ping_frequency) + .keep_alive_timeout(Duration::from_secs(10)) .keep_alive_while_idle(false) .handshake(TokioIo::new(transport)) .await diff --git a/src/tunnel/transport/io.rs b/src/tunnel/transport/io.rs index 14457a7..214b0aa 100644 --- a/src/tunnel/transport/io.rs +++ b/src/tunnel/transport/io.rs @@ -3,10 +3,12 @@ use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWr use bytes::{BufMut, BytesMut}; use futures_util::{pin_mut, FutureExt}; use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::select; -use tokio::sync::oneshot; +use tokio::sync::{oneshot, Notify}; use tokio::time::Instant; use tracing::log::debug; use tracing::{error, info, warn}; @@ -18,6 +20,8 @@ pub trait TunnelWrite: Send + 'static { fn write(&mut self) -> impl Future> + Send; fn ping(&mut self) -> impl Future> + Send; fn close(&mut self) -> impl Future> + Send; + fn pending_operations_notify(&mut self) -> Arc; + fn handle_pending_operations(&mut self) -> impl Future> + Send; } pub trait TunnelRead: Send + 'static { @@ -74,6 +78,20 @@ impl TunnelWrite for TunnelWriter { Self::Http2(s) => s.close().await, } } + + fn pending_operations_notify(&mut self) -> Arc { + match self { + Self::Websocket(s) => s.pending_operations_notify(), + Self::Http2(s) => s.pending_operations_notify(), + } + } + + async fn handle_pending_operations(&mut self) -> Result<(), std::io::Error> { + match self { + Self::Websocket(s) => s.handle_pending_operations().await, + Self::Http2(s) => s.handle_pending_operations().await, + } + } } pub async fn propagate_local_to_remote( @@ -94,6 +112,9 @@ pub async fn propagate_local_to_remote( let start_at = Instant::now().checked_add(frequency).unwrap_or_else(Instant::now); let timeout = tokio::time::interval_at(start_at, frequency); let should_close = close_tx.closed().fuse(); + let notify = ws_tx.pending_operations_notify(); + let mut has_pending_operations = notify.notified(); + let mut has_pending_operations_pin = unsafe { Pin::new_unchecked(&mut has_pending_operations) }; pin_mut!(timeout); pin_mut!(should_close); @@ -108,9 +129,21 @@ pub async fn propagate_local_to_remote( biased; read_len = local_rx.read_buf(ws_tx.buf_mut()) => read_len, - + _ = &mut should_close => break, + _ = &mut has_pending_operations_pin => { + has_pending_operations = notify.notified(); + has_pending_operations_pin = unsafe { Pin::new_unchecked(&mut has_pending_operations) }; + match ws_tx.handle_pending_operations().await { + Ok(_) => continue, + Err(err) => { + warn!("error while handling pending operations {}", err); + break; + } + } + }, + _ = timeout.tick(), if ping_frequency.is_some() => { debug!("sending ping to keep connection alive"); ws_tx.ping().await?; diff --git a/src/tunnel/transport/websocket.rs b/src/tunnel/transport/websocket.rs index 317a5db..46f094f 100644 --- a/src/tunnel/transport/websocket.rs +++ b/src/tunnel/transport/websocket.rs @@ -5,7 +5,7 @@ use crate::tunnel::transport::jwt::{tunnel_to_jwt_token, JWT_HEADER_PREFIX}; use crate::tunnel::RemoteAddr; use anyhow::{anyhow, Context}; use bytes::{Bytes, BytesMut}; -use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite}; +use fastwebsockets::{CloseCode, Frame, OpCode, Payload, WebSocketRead, WebSocketWrite}; use http_body_util::Empty; use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE}; use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY}; @@ -14,24 +14,38 @@ use hyper::upgrade::Upgraded; use hyper::Request; use hyper_util::rt::TokioExecutor; use hyper_util::rt::TokioIo; -use log::debug; +use log::{debug, info}; use std::io; use std::io::ErrorKind; use std::ops::DerefMut; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering::Relaxed; +use std::sync::Arc; use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::Notify; use tracing::trace; use uuid::Uuid; pub struct WebsocketTunnelWrite { inner: WebSocketWrite>>, buf: BytesMut, + pending_operations: Receiver>, + pending_ops_notify: Arc, + in_flight_ping: AtomicUsize, } impl WebsocketTunnelWrite { - pub fn new(ws: WebSocketWrite>>) -> Self { + pub fn new( + ws: WebSocketWrite>>, + (pending_operations, notify): (Receiver>, Arc), + ) -> Self { Self { inner: ws, buf: BytesMut::with_capacity(MAX_PACKET_LENGTH), + pending_operations, + pending_ops_notify: notify, + in_flight_ping: AtomicUsize::new(0), } } } @@ -76,6 +90,13 @@ impl TunnelWrite for WebsocketTunnelWrite { } async fn ping(&mut self) -> Result<(), io::Error> { + if self.in_flight_ping.fetch_add(1, Relaxed) >= 3 { + return Err(io::Error::new( + ErrorKind::ConnectionAborted, + "too many in flight/un-answered pings", + )); + } + if let Err(err) = self .inner .write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))) @@ -94,15 +115,54 @@ impl TunnelWrite for WebsocketTunnelWrite { Ok(()) } + + fn pending_operations_notify(&mut self) -> Arc { + self.pending_ops_notify.clone() + } + + async fn handle_pending_operations(&mut self) -> Result<(), io::Error> { + while let Ok(frame) = self.pending_operations.try_recv() { + info!("received frame {:?}", frame.opcode); + match frame.opcode { + OpCode::Close => { + if self.inner.write_frame(frame).await.is_err() { + return Err(io::Error::new(ErrorKind::ConnectionAborted, "cannot send close frame")); + } + } + OpCode::Ping => { + if self.inner.write_frame(Frame::pong(frame.payload)).await.is_err() { + return Err(io::Error::new(ErrorKind::ConnectionAborted, "cannot send pong frame")); + } + } + OpCode::Pong => { + self.in_flight_ping.fetch_sub(1, Relaxed); + } + OpCode::Continuation | OpCode::Text | OpCode::Binary => unreachable!(), + } + } + + Ok(()) + } } pub struct WebsocketTunnelRead { inner: WebSocketRead>>, + pending_operations: Sender>, + notify_pending_ops: Arc, } impl WebsocketTunnelRead { - pub const fn new(ws: WebSocketRead>>) -> Self { - Self { inner: ws } + pub fn new(ws: WebSocketRead>>) -> (Self, (Receiver>, Arc)) { + let (tx, rx) = tokio::sync::mpsc::channel(10); + let notify = Arc::new(Notify::new()); + ( + Self { + inner: ws, + pending_operations: tx, + notify_pending_ops: notify.clone(), + }, + (rx, notify), + ) } } @@ -127,9 +187,36 @@ impl TunnelRead for WebsocketTunnelRead { Err(err) => Err(io::Error::new(ErrorKind::ConnectionAborted, err)), } } - OpCode::Close => return Err(io::Error::new(ErrorKind::NotConnected, "websocket close")), - OpCode::Ping => continue, - OpCode::Pong => continue, + OpCode::Close => { + let _ = self + .pending_operations + .send(Frame::close(CloseCode::Normal.into(), &[])) + .await; + self.notify_pending_ops.notify_waiters(); + return Err(io::Error::new(ErrorKind::NotConnected, "websocket close")); + } + OpCode::Ping => { + if self + .pending_operations + .send(Frame::new(true, msg.opcode, None, Payload::Owned(msg.payload.to_owned()))) + .await + .is_err() + { + return Err(io::Error::new(ErrorKind::ConnectionAborted, "cannot send ping")); + } + self.notify_pending_ops.notify_waiters(); + } + OpCode::Pong => { + if self + .pending_operations + .send(Frame::pong(Payload::Borrowed(&[]))) + .await + .is_err() + { + return Err(io::Error::new(ErrorKind::ConnectionAborted, "cannot send pong")); + } + self.notify_pending_ops.notify_waiters(); + } }; } } @@ -196,12 +283,11 @@ pub async fn connect( .with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?; ws.set_auto_apply_mask(client_cfg.websocket_mask_frame); + ws.set_auto_close(false); + ws.set_auto_pong(false); let (ws_rx, ws_tx) = ws.split(tokio::io::split); - Ok(( - WebsocketTunnelRead::new(ws_rx), - WebsocketTunnelWrite::new(ws_tx), - response.into_parts().0, - )) + let (ws_rx, pending_ops) = WebsocketTunnelRead::new(ws_rx); + Ok((ws_rx, WebsocketTunnelWrite::new(ws_tx, pending_ops), response.into_parts().0)) }