From 9883b8b32b109881f7dbfbc94446ecdf5cba2191 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Wed, 1 Nov 2023 11:20:45 +0100 Subject: [PATCH] feat(udp): Use activity based timeout instead of hard deadline --- src/tunnel/io.rs | 11 ++++------- src/udp.rs | 43 ++++++++++++++++++++++++++++--------------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/src/tunnel/io.rs b/src/tunnel/io.rs index 0caa459..427259c 100644 --- a/src/tunnel/io.rs +++ b/src/tunnel/io.rs @@ -1,7 +1,6 @@ use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite}; use futures_util::pin_mut; use hyper::upgrade::Upgraded; -use std::pin::Pin; use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; @@ -25,8 +24,8 @@ pub(super) async fn propagate_read( // We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration // We reuse the future to avoid creating a timer in the tight loop - let mut timeout_unpin = tokio::time::sleep(ping_frequency); - let mut timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) }; + let timeout = tokio::time::interval_at(tokio::time::Instant::now() + ping_frequency, ping_frequency); + pin_mut!(timeout); pin_mut!(local_rx); loop { @@ -37,12 +36,10 @@ pub(super) async fn propagate_read( _ = close_tx.closed() => break, - _ = &mut timeout => { + _ = timeout.tick() => { debug!("sending ping to keep websocket connection alive"); ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?; - timeout_unpin = tokio::time::sleep(ping_frequency); - timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) }; continue; } }; @@ -51,7 +48,7 @@ pub(super) async fn propagate_read( Ok(0) => break, Ok(read_len) => read_len, Err(err) => { - warn!("error while reading incoming bytes from local tx tunnel {}", err); + warn!("error while reading incoming bytes from local tx tunnel: {}", err); break; } }; diff --git a/src/udp.rs b/src/udp.rs index ed29bb5..d1a9bb0 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -19,7 +19,7 @@ use tokio::net::UdpSocket; use tokio::sync::futures::Notified; use tokio::sync::Notify; -use tokio::time::{timeout, Sleep}; +use tokio::time::{timeout, Interval}; use tracing::{debug, error, info}; use url::Host; @@ -67,7 +67,8 @@ pub struct UdpStream { socket: Arc, peer: SocketAddr, #[pin] - deadline: Option, + watchdog_deadline: Option, + data_read_before_deadline: bool, has_been_notified: bool, #[pin] pending_notification: Option>, @@ -94,12 +95,11 @@ impl UdpStream { fn new( socket: Arc, peer: SocketAddr, - deadline: Option, + watchdog_deadline: Option, keys_to_delete: Weak>>, ) -> (Self, Arc) { let has_data_to_read = Notify::new(); let has_read_data = Notify::new(); - has_data_to_read.notify_one(); let io = Arc::new(IoInner { has_data_to_read, has_read_data, @@ -107,7 +107,9 @@ impl UdpStream { let mut s = Self { socket, peer, - deadline, + watchdog_deadline: watchdog_deadline + .map(|timeout| tokio::time::interval_at(tokio::time::Instant::now() + timeout, timeout)), + data_read_before_deadline: false, has_been_notified: false, pending_notification: None, io: io.clone(), @@ -128,12 +130,19 @@ impl AsyncRead for UdpStream { obuf: &mut ReadBuf<'_>, ) -> Poll> { let mut project = self.project(); - if let Some(deadline) = project.deadline.as_pin_mut() { - if deadline.poll(cx).is_ready() { - return Poll::Ready(Err(Error::new( - ErrorKind::TimedOut, - format!("UDP stream timeout with {}", project.peer), - ))); + // Look that the timeout for client has not elapsed + if let Some(mut deadline) = project.watchdog_deadline.as_pin_mut() { + if deadline.poll_tick(cx).is_ready() { + return if *project.data_read_before_deadline { + *project.data_read_before_deadline = false; + let _ = deadline.poll_tick(cx); + Poll::Pending + } else { + Poll::Ready(Err(Error::new( + ErrorKind::TimedOut, + format!("UDP stream timeout with {}", project.peer), + ))) + }; } } @@ -144,6 +153,7 @@ impl AsyncRead for UdpStream { let peer = ready!(project.socket.poll_recv_from(cx, obuf))?; debug_assert_eq!(peer, *project.peer); + *project.data_read_before_deadline = true; let notified: Notified<'static> = unsafe { std::mem::transmute(project.io.has_data_to_read.notified()) }; project.pending_notification.as_mut().set(Some(notified)); project.io.has_read_data.notify_one(); @@ -184,7 +194,9 @@ pub async fn run_server( // New returned peer hasn't read its data yet, await for it. if let Some(await_peer) = peer_with_data { if let Some(peer) = server.peers.get(&await_peer) { + info!("waiting for peer {} to read its first data", await_peer.port()); peer.has_read_data.notified().await; + info!("peer {} to read its first data", await_peer.port()); } }; @@ -200,19 +212,20 @@ pub async fn run_server( match server.peers.get(&peer_addr) { Some(io) => { + info!("waiting for peer {} to read its data", peer_addr.port()); io.has_data_to_read.notify_one(); io.has_read_data.notified().await; + info!("peer {} to read its data", peer_addr.port()); } None => { + info!("New UDP connection from {}", peer_addr); let (udp_client, io) = UdpStream::new( server.clone_socket(), peer_addr, - server - .cnx_timeout - .and_then(|timeout| tokio::time::Instant::now().checked_add(timeout)) - .map(tokio::time::sleep_until), + server.cnx_timeout, Arc::downgrade(&server.keys_to_delete), ); + io.has_data_to_read.notify_waiters(); server.peers.insert(peer_addr, io); return Some((Ok(udp_client), (server, Some(peer_addr)))); }