Auto break connection after too many pings
This commit is contained in:
parent
21c4f7ffc6
commit
f55643550b
5 changed files with 165 additions and 24 deletions
|
@ -11,6 +11,7 @@ use hyper::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL};
|
|||
use hyper::{Request, Response};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::{error, warn, Instrument, Span};
|
||||
|
||||
|
@ -45,24 +46,32 @@ pub(super) async fn ws_server_upgrade(
|
|||
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let (ws_rx, mut ws_tx) = match fut.await {
|
||||
Ok(ws) => ws.split(tokio::io::split),
|
||||
let (ws_rx, ws_tx) = match fut.await {
|
||||
Ok(mut ws) => {
|
||||
ws.set_auto_pong(false);
|
||||
ws.set_auto_close(false);
|
||||
ws.set_auto_apply_mask(mask_frame);
|
||||
ws.split(tokio::io::split)
|
||||
}
|
||||
Err(err) => {
|
||||
error!("Error during http upgrade request: {:?}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
let (close_tx, close_rx) = oneshot::channel::<()>();
|
||||
ws_tx.set_auto_apply_mask(mask_frame);
|
||||
|
||||
let (ws_rx, pending_ops) = WebsocketTunnelRead::new(ws_rx);
|
||||
tokio::task::spawn(
|
||||
transport::io::propagate_remote_to_local(local_tx, WebsocketTunnelRead::new(ws_rx), close_rx)
|
||||
.instrument(Span::current()),
|
||||
transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).instrument(Span::current()),
|
||||
);
|
||||
|
||||
let _ =
|
||||
transport::io::propagate_local_to_remote(local_rx, WebsocketTunnelWrite::new(ws_tx), close_tx, None)
|
||||
.await;
|
||||
let _ = transport::io::propagate_local_to_remote(
|
||||
local_rx,
|
||||
WebsocketTunnelWrite::new(ws_tx, pending_ops),
|
||||
close_tx,
|
||||
Some(Duration::from_secs(30)),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
.instrument(Span::current()),
|
||||
);
|
||||
|
|
|
@ -436,6 +436,7 @@ impl WsServer {
|
|||
let websocket_upgrade_fn =
|
||||
mk_websocket_upgrade_fn(server, restrictions.clone(), restrict_path, peer_addr);
|
||||
let conn_fut = http1::Builder::new()
|
||||
.header_read_timeout(Duration::from_secs(10))
|
||||
.serve_connection(tls_stream, service_fn(websocket_upgrade_fn))
|
||||
.with_upgrades();
|
||||
|
||||
|
|
|
@ -12,11 +12,14 @@ use hyper::http::response::Parts;
|
|||
use hyper::Request;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
|
||||
use log::{debug, error, warn};
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use std::ops::DerefMut;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncWrite, AsyncWriteExt};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use uuid::Uuid;
|
||||
|
@ -97,6 +100,14 @@ impl TunnelWrite for Http2TunnelWrite {
|
|||
async fn close(&mut self) -> Result<(), io::Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn pending_operations_notify(&mut self) -> Arc<Notify> {
|
||||
Arc::new(Notify::new())
|
||||
}
|
||||
|
||||
fn handle_pending_operations(&mut self) -> impl Future<Output = Result<(), io::Error>> + Send {
|
||||
std::future::ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
|
@ -177,6 +188,7 @@ pub async fn connect(
|
|||
.timer(TokioTimer::new())
|
||||
.adaptive_window(true)
|
||||
.keep_alive_interval(client.config.websocket_ping_frequency)
|
||||
.keep_alive_timeout(Duration::from_secs(10))
|
||||
.keep_alive_while_idle(false)
|
||||
.handshake(TokioIo::new(transport))
|
||||
.await
|
||||
|
|
|
@ -3,10 +3,12 @@ use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWr
|
|||
use bytes::{BufMut, BytesMut};
|
||||
use futures_util::{pin_mut, FutureExt};
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||
use tokio::select;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::{oneshot, Notify};
|
||||
use tokio::time::Instant;
|
||||
use tracing::log::debug;
|
||||
use tracing::{error, info, warn};
|
||||
|
@ -18,6 +20,8 @@ pub trait TunnelWrite: Send + 'static {
|
|||
fn write(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
|
||||
fn ping(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
|
||||
fn close(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
|
||||
fn pending_operations_notify(&mut self) -> Arc<Notify>;
|
||||
fn handle_pending_operations(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
|
||||
}
|
||||
|
||||
pub trait TunnelRead: Send + 'static {
|
||||
|
@ -74,6 +78,20 @@ impl TunnelWrite for TunnelWriter {
|
|||
Self::Http2(s) => s.close().await,
|
||||
}
|
||||
}
|
||||
|
||||
fn pending_operations_notify(&mut self) -> Arc<Notify> {
|
||||
match self {
|
||||
Self::Websocket(s) => s.pending_operations_notify(),
|
||||
Self::Http2(s) => s.pending_operations_notify(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_pending_operations(&mut self) -> Result<(), std::io::Error> {
|
||||
match self {
|
||||
Self::Websocket(s) => s.handle_pending_operations().await,
|
||||
Self::Http2(s) => s.handle_pending_operations().await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn propagate_local_to_remote(
|
||||
|
@ -94,6 +112,9 @@ pub async fn propagate_local_to_remote(
|
|||
let start_at = Instant::now().checked_add(frequency).unwrap_or_else(Instant::now);
|
||||
let timeout = tokio::time::interval_at(start_at, frequency);
|
||||
let should_close = close_tx.closed().fuse();
|
||||
let notify = ws_tx.pending_operations_notify();
|
||||
let mut has_pending_operations = notify.notified();
|
||||
let mut has_pending_operations_pin = unsafe { Pin::new_unchecked(&mut has_pending_operations) };
|
||||
|
||||
pin_mut!(timeout);
|
||||
pin_mut!(should_close);
|
||||
|
@ -108,9 +129,21 @@ pub async fn propagate_local_to_remote(
|
|||
biased;
|
||||
|
||||
read_len = local_rx.read_buf(ws_tx.buf_mut()) => read_len,
|
||||
|
||||
|
||||
_ = &mut should_close => break,
|
||||
|
||||
_ = &mut has_pending_operations_pin => {
|
||||
has_pending_operations = notify.notified();
|
||||
has_pending_operations_pin = unsafe { Pin::new_unchecked(&mut has_pending_operations) };
|
||||
match ws_tx.handle_pending_operations().await {
|
||||
Ok(_) => continue,
|
||||
Err(err) => {
|
||||
warn!("error while handling pending operations {}", err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
_ = timeout.tick(), if ping_frequency.is_some() => {
|
||||
debug!("sending ping to keep connection alive");
|
||||
ws_tx.ping().await?;
|
||||
|
|
|
@ -5,7 +5,7 @@ use crate::tunnel::transport::jwt::{tunnel_to_jwt_token, JWT_HEADER_PREFIX};
|
|||
use crate::tunnel::RemoteAddr;
|
||||
use anyhow::{anyhow, Context};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
|
||||
use fastwebsockets::{CloseCode, Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
|
||||
use http_body_util::Empty;
|
||||
use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE};
|
||||
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
|
||||
|
@ -14,24 +14,38 @@ use hyper::upgrade::Upgraded;
|
|||
use hyper::Request;
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use log::debug;
|
||||
use log::{debug, info};
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use std::ops::DerefMut;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering::Relaxed;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
|
||||
use tokio::sync::mpsc::{Receiver, Sender};
|
||||
use tokio::sync::Notify;
|
||||
use tracing::trace;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub struct WebsocketTunnelWrite {
|
||||
inner: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
|
||||
buf: BytesMut,
|
||||
pending_operations: Receiver<Frame<'static>>,
|
||||
pending_ops_notify: Arc<Notify>,
|
||||
in_flight_ping: AtomicUsize,
|
||||
}
|
||||
|
||||
impl WebsocketTunnelWrite {
|
||||
pub fn new(ws: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>) -> Self {
|
||||
pub fn new(
|
||||
ws: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
|
||||
(pending_operations, notify): (Receiver<Frame<'static>>, Arc<Notify>),
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: ws,
|
||||
buf: BytesMut::with_capacity(MAX_PACKET_LENGTH),
|
||||
pending_operations,
|
||||
pending_ops_notify: notify,
|
||||
in_flight_ping: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -76,6 +90,13 @@ impl TunnelWrite for WebsocketTunnelWrite {
|
|||
}
|
||||
|
||||
async fn ping(&mut self) -> Result<(), io::Error> {
|
||||
if self.in_flight_ping.fetch_add(1, Relaxed) >= 3 {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::ConnectionAborted,
|
||||
"too many in flight/un-answered pings",
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(err) = self
|
||||
.inner
|
||||
.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut [])))
|
||||
|
@ -94,15 +115,54 @@ impl TunnelWrite for WebsocketTunnelWrite {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn pending_operations_notify(&mut self) -> Arc<Notify> {
|
||||
self.pending_ops_notify.clone()
|
||||
}
|
||||
|
||||
async fn handle_pending_operations(&mut self) -> Result<(), io::Error> {
|
||||
while let Ok(frame) = self.pending_operations.try_recv() {
|
||||
info!("received frame {:?}", frame.opcode);
|
||||
match frame.opcode {
|
||||
OpCode::Close => {
|
||||
if self.inner.write_frame(frame).await.is_err() {
|
||||
return Err(io::Error::new(ErrorKind::ConnectionAborted, "cannot send close frame"));
|
||||
}
|
||||
}
|
||||
OpCode::Ping => {
|
||||
if self.inner.write_frame(Frame::pong(frame.payload)).await.is_err() {
|
||||
return Err(io::Error::new(ErrorKind::ConnectionAborted, "cannot send pong frame"));
|
||||
}
|
||||
}
|
||||
OpCode::Pong => {
|
||||
self.in_flight_ping.fetch_sub(1, Relaxed);
|
||||
}
|
||||
OpCode::Continuation | OpCode::Text | OpCode::Binary => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WebsocketTunnelRead {
|
||||
inner: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>,
|
||||
pending_operations: Sender<Frame<'static>>,
|
||||
notify_pending_ops: Arc<Notify>,
|
||||
}
|
||||
|
||||
impl WebsocketTunnelRead {
|
||||
pub const fn new(ws: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>) -> Self {
|
||||
Self { inner: ws }
|
||||
pub fn new(ws: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>) -> (Self, (Receiver<Frame<'static>>, Arc<Notify>)) {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||
let notify = Arc::new(Notify::new());
|
||||
(
|
||||
Self {
|
||||
inner: ws,
|
||||
pending_operations: tx,
|
||||
notify_pending_ops: notify.clone(),
|
||||
},
|
||||
(rx, notify),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -127,9 +187,36 @@ impl TunnelRead for WebsocketTunnelRead {
|
|||
Err(err) => Err(io::Error::new(ErrorKind::ConnectionAborted, err)),
|
||||
}
|
||||
}
|
||||
OpCode::Close => return Err(io::Error::new(ErrorKind::NotConnected, "websocket close")),
|
||||
OpCode::Ping => continue,
|
||||
OpCode::Pong => continue,
|
||||
OpCode::Close => {
|
||||
let _ = self
|
||||
.pending_operations
|
||||
.send(Frame::close(CloseCode::Normal.into(), &[]))
|
||||
.await;
|
||||
self.notify_pending_ops.notify_waiters();
|
||||
return Err(io::Error::new(ErrorKind::NotConnected, "websocket close"));
|
||||
}
|
||||
OpCode::Ping => {
|
||||
if self
|
||||
.pending_operations
|
||||
.send(Frame::new(true, msg.opcode, None, Payload::Owned(msg.payload.to_owned())))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Err(io::Error::new(ErrorKind::ConnectionAborted, "cannot send ping"));
|
||||
}
|
||||
self.notify_pending_ops.notify_waiters();
|
||||
}
|
||||
OpCode::Pong => {
|
||||
if self
|
||||
.pending_operations
|
||||
.send(Frame::pong(Payload::Borrowed(&[])))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Err(io::Error::new(ErrorKind::ConnectionAborted, "cannot send pong"));
|
||||
}
|
||||
self.notify_pending_ops.notify_waiters();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -196,12 +283,11 @@ pub async fn connect(
|
|||
.with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?;
|
||||
|
||||
ws.set_auto_apply_mask(client_cfg.websocket_mask_frame);
|
||||
ws.set_auto_close(false);
|
||||
ws.set_auto_pong(false);
|
||||
|
||||
let (ws_rx, ws_tx) = ws.split(tokio::io::split);
|
||||
|
||||
Ok((
|
||||
WebsocketTunnelRead::new(ws_rx),
|
||||
WebsocketTunnelWrite::new(ws_tx),
|
||||
response.into_parts().0,
|
||||
))
|
||||
let (ws_rx, pending_ops) = WebsocketTunnelRead::new(ws_rx);
|
||||
Ok((ws_rx, WebsocketTunnelWrite::new(ws_tx, pending_ops), response.into_parts().0))
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue