feat(udp): Use activity based timeout instead of hard deadline

This commit is contained in:
Σrebe - Romain GERARD 2023-11-01 11:20:45 +01:00
parent 9a5ba4783b
commit 9883b8b32b
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
2 changed files with 32 additions and 22 deletions

View file

@ -1,7 +1,6 @@
use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite}; use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite};
use futures_util::pin_mut; use futures_util::pin_mut;
use hyper::upgrade::Upgraded; use hyper::upgrade::Upgraded;
use std::pin::Pin;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
@ -25,8 +24,8 @@ pub(super) async fn propagate_read(
// We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration // We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration
// We reuse the future to avoid creating a timer in the tight loop // We reuse the future to avoid creating a timer in the tight loop
let mut timeout_unpin = tokio::time::sleep(ping_frequency); let timeout = tokio::time::interval_at(tokio::time::Instant::now() + ping_frequency, ping_frequency);
let mut timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) }; pin_mut!(timeout);
pin_mut!(local_rx); pin_mut!(local_rx);
loop { loop {
@ -37,12 +36,10 @@ pub(super) async fn propagate_read(
_ = close_tx.closed() => break, _ = close_tx.closed() => break,
_ = &mut timeout => { _ = timeout.tick() => {
debug!("sending ping to keep websocket connection alive"); debug!("sending ping to keep websocket connection alive");
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?; ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?;
timeout_unpin = tokio::time::sleep(ping_frequency);
timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) };
continue; continue;
} }
}; };
@ -51,7 +48,7 @@ pub(super) async fn propagate_read(
Ok(0) => break, Ok(0) => break,
Ok(read_len) => read_len, Ok(read_len) => read_len,
Err(err) => { Err(err) => {
warn!("error while reading incoming bytes from local tx tunnel {}", err); warn!("error while reading incoming bytes from local tx tunnel: {}", err);
break; break;
} }
}; };

View file

@ -19,7 +19,7 @@ use tokio::net::UdpSocket;
use tokio::sync::futures::Notified; use tokio::sync::futures::Notified;
use tokio::sync::Notify; use tokio::sync::Notify;
use tokio::time::{timeout, Sleep}; use tokio::time::{timeout, Interval};
use tracing::{debug, error, info}; use tracing::{debug, error, info};
use url::Host; use url::Host;
@ -67,7 +67,8 @@ pub struct UdpStream {
socket: Arc<UdpSocket>, socket: Arc<UdpSocket>,
peer: SocketAddr, peer: SocketAddr,
#[pin] #[pin]
deadline: Option<Sleep>, watchdog_deadline: Option<Interval>,
data_read_before_deadline: bool,
has_been_notified: bool, has_been_notified: bool,
#[pin] #[pin]
pending_notification: Option<Notified<'static>>, pending_notification: Option<Notified<'static>>,
@ -94,12 +95,11 @@ impl UdpStream {
fn new( fn new(
socket: Arc<UdpSocket>, socket: Arc<UdpSocket>,
peer: SocketAddr, peer: SocketAddr,
deadline: Option<Sleep>, watchdog_deadline: Option<Duration>,
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>, keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
) -> (Self, Arc<IoInner>) { ) -> (Self, Arc<IoInner>) {
let has_data_to_read = Notify::new(); let has_data_to_read = Notify::new();
let has_read_data = Notify::new(); let has_read_data = Notify::new();
has_data_to_read.notify_one();
let io = Arc::new(IoInner { let io = Arc::new(IoInner {
has_data_to_read, has_data_to_read,
has_read_data, has_read_data,
@ -107,7 +107,9 @@ impl UdpStream {
let mut s = Self { let mut s = Self {
socket, socket,
peer, peer,
deadline, watchdog_deadline: watchdog_deadline
.map(|timeout| tokio::time::interval_at(tokio::time::Instant::now() + timeout, timeout)),
data_read_before_deadline: false,
has_been_notified: false, has_been_notified: false,
pending_notification: None, pending_notification: None,
io: io.clone(), io: io.clone(),
@ -128,12 +130,19 @@ impl AsyncRead for UdpStream {
obuf: &mut ReadBuf<'_>, obuf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> { ) -> Poll<io::Result<()>> {
let mut project = self.project(); let mut project = self.project();
if let Some(deadline) = project.deadline.as_pin_mut() { // Look that the timeout for client has not elapsed
if deadline.poll(cx).is_ready() { if let Some(mut deadline) = project.watchdog_deadline.as_pin_mut() {
return Poll::Ready(Err(Error::new( if deadline.poll_tick(cx).is_ready() {
return if *project.data_read_before_deadline {
*project.data_read_before_deadline = false;
let _ = deadline.poll_tick(cx);
Poll::Pending
} else {
Poll::Ready(Err(Error::new(
ErrorKind::TimedOut, ErrorKind::TimedOut,
format!("UDP stream timeout with {}", project.peer), format!("UDP stream timeout with {}", project.peer),
))); )))
};
} }
} }
@ -144,6 +153,7 @@ impl AsyncRead for UdpStream {
let peer = ready!(project.socket.poll_recv_from(cx, obuf))?; let peer = ready!(project.socket.poll_recv_from(cx, obuf))?;
debug_assert_eq!(peer, *project.peer); debug_assert_eq!(peer, *project.peer);
*project.data_read_before_deadline = true;
let notified: Notified<'static> = unsafe { std::mem::transmute(project.io.has_data_to_read.notified()) }; let notified: Notified<'static> = unsafe { std::mem::transmute(project.io.has_data_to_read.notified()) };
project.pending_notification.as_mut().set(Some(notified)); project.pending_notification.as_mut().set(Some(notified));
project.io.has_read_data.notify_one(); project.io.has_read_data.notify_one();
@ -184,7 +194,9 @@ pub async fn run_server(
// New returned peer hasn't read its data yet, await for it. // New returned peer hasn't read its data yet, await for it.
if let Some(await_peer) = peer_with_data { if let Some(await_peer) = peer_with_data {
if let Some(peer) = server.peers.get(&await_peer) { if let Some(peer) = server.peers.get(&await_peer) {
info!("waiting for peer {} to read its first data", await_peer.port());
peer.has_read_data.notified().await; peer.has_read_data.notified().await;
info!("peer {} to read its first data", await_peer.port());
} }
}; };
@ -200,19 +212,20 @@ pub async fn run_server(
match server.peers.get(&peer_addr) { match server.peers.get(&peer_addr) {
Some(io) => { Some(io) => {
info!("waiting for peer {} to read its data", peer_addr.port());
io.has_data_to_read.notify_one(); io.has_data_to_read.notify_one();
io.has_read_data.notified().await; io.has_read_data.notified().await;
info!("peer {} to read its data", peer_addr.port());
} }
None => { None => {
info!("New UDP connection from {}", peer_addr);
let (udp_client, io) = UdpStream::new( let (udp_client, io) = UdpStream::new(
server.clone_socket(), server.clone_socket(),
peer_addr, peer_addr,
server server.cnx_timeout,
.cnx_timeout
.and_then(|timeout| tokio::time::Instant::now().checked_add(timeout))
.map(tokio::time::sleep_until),
Arc::downgrade(&server.keys_to_delete), Arc::downgrade(&server.keys_to_delete),
); );
io.has_data_to_read.notify_waiters();
server.peers.insert(peer_addr, io); server.peers.insert(peer_addr, io);
return Some((Ok(udp_client), (server, Some(peer_addr)))); return Some((Ok(udp_client), (server, Some(peer_addr))));
} }