feat(udp): Use activity based timeout instead of hard deadline

This commit is contained in:
Σrebe - Romain GERARD 2023-11-01 11:20:45 +01:00
parent 9a5ba4783b
commit 9883b8b32b
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
2 changed files with 32 additions and 22 deletions

View file

@ -1,7 +1,6 @@
use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite};
use futures_util::pin_mut;
use hyper::upgrade::Upgraded;
use std::pin::Pin;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
@ -25,8 +24,8 @@ pub(super) async fn propagate_read(
// We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration
// We reuse the future to avoid creating a timer in the tight loop
let mut timeout_unpin = tokio::time::sleep(ping_frequency);
let mut timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) };
let timeout = tokio::time::interval_at(tokio::time::Instant::now() + ping_frequency, ping_frequency);
pin_mut!(timeout);
pin_mut!(local_rx);
loop {
@ -37,12 +36,10 @@ pub(super) async fn propagate_read(
_ = close_tx.closed() => break,
_ = &mut timeout => {
_ = timeout.tick() => {
debug!("sending ping to keep websocket connection alive");
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?;
timeout_unpin = tokio::time::sleep(ping_frequency);
timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) };
continue;
}
};
@ -51,7 +48,7 @@ pub(super) async fn propagate_read(
Ok(0) => break,
Ok(read_len) => read_len,
Err(err) => {
warn!("error while reading incoming bytes from local tx tunnel {}", err);
warn!("error while reading incoming bytes from local tx tunnel: {}", err);
break;
}
};

View file

@ -19,7 +19,7 @@ use tokio::net::UdpSocket;
use tokio::sync::futures::Notified;
use tokio::sync::Notify;
use tokio::time::{timeout, Sleep};
use tokio::time::{timeout, Interval};
use tracing::{debug, error, info};
use url::Host;
@ -67,7 +67,8 @@ pub struct UdpStream {
socket: Arc<UdpSocket>,
peer: SocketAddr,
#[pin]
deadline: Option<Sleep>,
watchdog_deadline: Option<Interval>,
data_read_before_deadline: bool,
has_been_notified: bool,
#[pin]
pending_notification: Option<Notified<'static>>,
@ -94,12 +95,11 @@ impl UdpStream {
fn new(
socket: Arc<UdpSocket>,
peer: SocketAddr,
deadline: Option<Sleep>,
watchdog_deadline: Option<Duration>,
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
) -> (Self, Arc<IoInner>) {
let has_data_to_read = Notify::new();
let has_read_data = Notify::new();
has_data_to_read.notify_one();
let io = Arc::new(IoInner {
has_data_to_read,
has_read_data,
@ -107,7 +107,9 @@ impl UdpStream {
let mut s = Self {
socket,
peer,
deadline,
watchdog_deadline: watchdog_deadline
.map(|timeout| tokio::time::interval_at(tokio::time::Instant::now() + timeout, timeout)),
data_read_before_deadline: false,
has_been_notified: false,
pending_notification: None,
io: io.clone(),
@ -128,12 +130,19 @@ impl AsyncRead for UdpStream {
obuf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let mut project = self.project();
if let Some(deadline) = project.deadline.as_pin_mut() {
if deadline.poll(cx).is_ready() {
return Poll::Ready(Err(Error::new(
ErrorKind::TimedOut,
format!("UDP stream timeout with {}", project.peer),
)));
// Look that the timeout for client has not elapsed
if let Some(mut deadline) = project.watchdog_deadline.as_pin_mut() {
if deadline.poll_tick(cx).is_ready() {
return if *project.data_read_before_deadline {
*project.data_read_before_deadline = false;
let _ = deadline.poll_tick(cx);
Poll::Pending
} else {
Poll::Ready(Err(Error::new(
ErrorKind::TimedOut,
format!("UDP stream timeout with {}", project.peer),
)))
};
}
}
@ -144,6 +153,7 @@ impl AsyncRead for UdpStream {
let peer = ready!(project.socket.poll_recv_from(cx, obuf))?;
debug_assert_eq!(peer, *project.peer);
*project.data_read_before_deadline = true;
let notified: Notified<'static> = unsafe { std::mem::transmute(project.io.has_data_to_read.notified()) };
project.pending_notification.as_mut().set(Some(notified));
project.io.has_read_data.notify_one();
@ -184,7 +194,9 @@ pub async fn run_server(
// New returned peer hasn't read its data yet, await for it.
if let Some(await_peer) = peer_with_data {
if let Some(peer) = server.peers.get(&await_peer) {
info!("waiting for peer {} to read its first data", await_peer.port());
peer.has_read_data.notified().await;
info!("peer {} to read its first data", await_peer.port());
}
};
@ -200,19 +212,20 @@ pub async fn run_server(
match server.peers.get(&peer_addr) {
Some(io) => {
info!("waiting for peer {} to read its data", peer_addr.port());
io.has_data_to_read.notify_one();
io.has_read_data.notified().await;
info!("peer {} to read its data", peer_addr.port());
}
None => {
info!("New UDP connection from {}", peer_addr);
let (udp_client, io) = UdpStream::new(
server.clone_socket(),
peer_addr,
server
.cnx_timeout
.and_then(|timeout| tokio::time::Instant::now().checked_add(timeout))
.map(tokio::time::sleep_until),
server.cnx_timeout,
Arc::downgrade(&server.keys_to_delete),
);
io.has_data_to_read.notify_waiters();
server.peers.insert(peer_addr, io);
return Some((Ok(udp_client), (server, Some(peer_addr))));
}