From 297176293caa344d97d5cf8e04e94517cc0c0e71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Fri, 3 Nov 2023 09:17:56 +0100 Subject: [PATCH] cleanup --- src/tunnel/client.rs | 2 +- src/tunnel/io.rs | 11 +++++------ src/tunnel/server.rs | 3 +-- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index 6710f04..96a202a 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -102,7 +102,7 @@ where // Forward local tx to websocket tx let ping_frequency = client_cfg.websocket_ping_frequency; - tokio::spawn(super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency).instrument(Span::current())); + tokio::spawn(super::io::propagate_read(local_rx, ws_tx, close_tx, Some(ping_frequency)).instrument(Span::current())); // Forward websocket rx to local rx let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await; diff --git a/src/tunnel/io.rs b/src/tunnel/io.rs index 537f503..36a0447 100644 --- a/src/tunnel/io.rs +++ b/src/tunnel/io.rs @@ -14,7 +14,7 @@ pub(super) async fn propagate_read( local_rx: impl AsyncRead, mut ws_tx: WebSocketWrite>, mut close_tx: oneshot::Sender<()>, - ping_frequency: Duration, + ping_frequency: Option, ) -> Result<(), WebSocketError> { let _guard = scopeguard::guard((), |_| { info!("Closing local tx ==> websocket tx tunnel"); @@ -25,10 +25,9 @@ pub(super) async fn propagate_read( // We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration // We reuse the future to avoid creating a timer in the tight loop - let start_at = Instant::now() - .checked_add(ping_frequency) - .unwrap_or(Instant::now() + Duration::from_secs(3600 * 24)); - let timeout = tokio::time::interval_at(start_at, ping_frequency); + let frequency = ping_frequency.unwrap_or(Duration::from_secs(u64::MAX)); + let start_at = Instant::now().checked_add(frequency).unwrap_or(Instant::now()); + let timeout = tokio::time::interval_at(start_at, frequency); pin_mut!(timeout); pin_mut!(local_rx); @@ -40,7 +39,7 @@ pub(super) async fn propagate_read( _ = close_tx.closed() => break, - _ = timeout.tick() => { + _ = timeout.tick(), if ping_frequency.is_some() => { debug!("sending ping to keep websocket connection alive"); ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?; diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 5f683a6..4b7a4c7 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -143,12 +143,11 @@ async fn server_upgrade( } }; let (close_tx, close_rx) = oneshot::channel::<()>(); - let ping_frequency = server_config.websocket_ping_frequency.unwrap_or(Duration::MAX); ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame); tokio::task::spawn(super::io::propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current())); - let _ = super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency).await; + let _ = super::io::propagate_read(local_rx, ws_tx, close_tx, None).await; } .instrument(Span::current()), );