feat(udp): Unleash max performance for udp server

This commit is contained in:
Σrebe - Romain GERARD 2023-10-29 14:52:02 +01:00
parent a832004783
commit 9af089d0b3
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4

View file

@ -1,5 +1,5 @@
use anyhow::Context; use anyhow::Context;
use futures_util::{pin_mut, stream, Stream}; use futures_util::{stream, Stream};
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use pin_project::{pin_project, pinned_drop}; use pin_project::{pin_project, pinned_drop};
@ -16,15 +16,20 @@ use std::task::{ready, Poll, Waker};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tokio::sync::futures::Notified;
use tokio::sync::Notify; use tokio::sync::Notify;
use tokio::time::Sleep; use tokio::time::Sleep;
use tracing::{debug, error, info}; use tracing::{debug, error, info};
type IoInner = Arc<(Notify, Mutex<Option<Waker>>, Notify)>; struct IoInner {
has_data_to_read: &'static Notify,
waker: Mutex<Option<Waker>>,
has_read_data: Notify,
}
struct UdpServer { struct UdpServer {
listener: Arc<UdpSocket>, listener: Arc<UdpSocket>,
peers: HashMap<SocketAddr, IoInner, ahash::RandomState>, peers: HashMap<SocketAddr, Arc<IoInner>, ahash::RandomState>,
keys_to_delete: Arc<RwLock<Vec<SocketAddr>>>, keys_to_delete: Arc<RwLock<Vec<SocketAddr>>>,
pub cnx_timeout: Option<Duration>, pub cnx_timeout: Option<Duration>,
} }
@ -47,7 +52,16 @@ impl UdpServer {
debug!("Cleaning {} dead udp peers", nb_key_to_delete); debug!("Cleaning {} dead udp peers", nb_key_to_delete);
let mut keys_to_delete = self.keys_to_delete.write(); let mut keys_to_delete = self.keys_to_delete.write();
for key in keys_to_delete.iter() { for key in keys_to_delete.iter() {
self.peers.remove(key); let Some(peer) = self.peers.remove(key) else {
continue;
};
#[allow(mutable_transmutes)]
unsafe {
let _ = Box::from_raw(std::mem::transmute::<&Notify, &mut Notify>(
peer.has_data_to_read,
));
}
} }
keys_to_delete.clear(); keys_to_delete.clear();
} }
@ -63,7 +77,9 @@ pub struct UdpStream {
#[pin] #[pin]
deadline: Option<Sleep>, deadline: Option<Sleep>,
has_been_notified: bool, has_been_notified: bool,
io: IoInner, #[pin]
pending_notification: Option<Notified<'static>>,
io: Arc<IoInner>,
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>, keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
} }
@ -73,6 +89,8 @@ impl PinnedDrop for UdpStream {
if let Some(keys_to_delete) = self.keys_to_delete.upgrade() { if let Some(keys_to_delete) = self.keys_to_delete.upgrade() {
keys_to_delete.write().push(self.peer); keys_to_delete.write().push(self.peer);
} }
self.io.has_read_data.notify_one();
} }
} }
@ -82,7 +100,7 @@ impl AsyncRead for UdpStream {
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
obuf: &mut ReadBuf<'_>, obuf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> { ) -> Poll<io::Result<()>> {
let project = self.project(); let mut project = self.project();
if let Some(deadline) = project.deadline.as_pin_mut() { if let Some(deadline) = project.deadline.as_pin_mut() {
if deadline.poll(cx).is_ready() { if deadline.poll(cx).is_ready() {
return Poll::Ready(Err(Error::new( return Poll::Ready(Err(Error::new(
@ -92,19 +110,20 @@ impl AsyncRead for UdpStream {
} }
} }
if !*project.has_been_notified { if let Some(notified) = project.pending_notification.as_mut().as_pin_mut() {
let notified = project.io.0.notified();
pin_mut!(notified);
if !notified.poll(cx).is_ready() { if !notified.poll(cx).is_ready() {
project.io.1.lock().replace(cx.waker().clone()); project.io.waker.lock().replace(cx.waker().clone());
return Poll::Pending; return Poll::Pending;
} }
*project.has_been_notified = true; project.pending_notification.as_mut().set(None);
} }
let _ = ready!(project.socket.poll_recv(cx, obuf)); let _ = ready!(project.socket.poll_recv(cx, obuf));
*project.has_been_notified = false; project
project.io.2.notify_one(); .pending_notification
.as_mut()
.set(Some(project.io.has_data_to_read.notified()));
project.io.has_read_data.notify_one();
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
} }
@ -149,8 +168,8 @@ pub async fn run_server(
let udp_server = UdpServer::new(Arc::new(listener), timeout); let udp_server = UdpServer::new(Arc::new(listener), timeout);
let stream = stream::unfold(udp_server, |mut server| async { let stream = stream::unfold(udp_server, |mut server| async {
server.clean_dead_keys();
loop { loop {
server.clean_dead_keys();
let peer_addr = match server.listener.peek_sender().await { let peer_addr = match server.listener.peek_sender().await {
Ok(ret) => ret, Ok(ret) => ret,
Err(err) => { Err(err) => {
@ -161,22 +180,24 @@ pub async fn run_server(
match server.peers.entry(peer_addr) { match server.peers.entry(peer_addr) {
Entry::Occupied(mut peer) => { Entry::Occupied(mut peer) => {
peer.get().2.notified().await; let io = peer.get_mut();
{ io.has_read_data.notified().await;
peer.get().0.notify_one(); io.has_data_to_read.notify_one();
let mut guard = peer.get_mut().1.lock(); let waker = io.waker.lock().deref_mut().take();
let waker = guard.deref_mut(); if let Some(waker) = waker {
if let Some(waker) = waker.take() { waker.wake();
drop(guard);
waker.wake();
}
} }
} }
Entry::Vacant(peer) => { Entry::Vacant(peer) => {
let notify = Notify::new(); let has_data_to_read: &'static Notify = Box::leak(Box::new(Notify::new()));
let has_read = Notify::new(); let pending_notification = has_data_to_read.notified();
notify.notify_one(); let has_read_data = Notify::new();
let io = Arc::new((notify, Mutex::new(None), has_read)); has_data_to_read.notify_one();
let io = Arc::new(IoInner {
has_data_to_read,
waker: Mutex::new(None),
has_read_data,
});
peer.insert(io.clone()); peer.insert(io.clone());
let udp_client = UdpStream { let udp_client = UdpStream {
socket: server.clone_socket(), socket: server.clone_socket(),
@ -187,6 +208,7 @@ pub async fn run_server(
.map(tokio::time::sleep_until), .map(tokio::time::sleep_until),
keys_to_delete: Arc::downgrade(&server.keys_to_delete), keys_to_delete: Arc::downgrade(&server.keys_to_delete),
has_been_notified: false, has_been_notified: false,
pending_notification: Some(pending_notification),
io, io,
}; };
return Some((Ok(udp_client), (server))); return Some((Ok(udp_client), (server)));