From a83200478317e890a29522c88209cb8e211d42d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Sun, 29 Oct 2023 13:52:37 +0100 Subject: [PATCH] feat(udp): Unleash max performance for udp server --- src/udp.rs | 93 ++++++++++++++++++++++-------------------------------- 1 file changed, 38 insertions(+), 55 deletions(-) diff --git a/src/udp.rs b/src/udp.rs index 79d0ee6..5b6c234 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -1,11 +1,10 @@ use anyhow::Context; -use bytes::{Buf, BytesMut}; -use futures_util::{stream, Stream}; +use futures_util::{pin_mut, stream, Stream}; use parking_lot::{Mutex, RwLock}; use pin_project::{pin_project, pinned_drop}; use std::collections::hash_map::Entry; -use std::collections::{HashMap, VecDeque}; +use std::collections::HashMap; use std::future::Future; use std::io; use std::io::{Error, ErrorKind}; @@ -13,17 +12,16 @@ use std::net::SocketAddr; use std::ops::DerefMut; use std::pin::{pin, Pin}; use std::sync::{Arc, Weak}; -use std::task::{Poll, Waker}; +use std::task::{ready, Poll, Waker}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::UdpSocket; + +use tokio::sync::Notify; use tokio::time::Sleep; use tracing::{debug, error, info}; -const MAX_UDP_PAYLOAD: usize = 65536; -const DEFAULT_UDP_BUFFER_SIZE: usize = 64 * 1024; // 64kb - -type IoInner = Arc, VecDeque)>>; +type IoInner = Arc<(Notify, Mutex>, Notify)>; struct UdpServer { listener: Arc, peers: HashMap, @@ -64,6 +62,7 @@ pub struct UdpStream { peer: SocketAddr, #[pin] deadline: Option, + has_been_notified: bool, io: IoInner, keys_to_delete: Weak>>, } @@ -93,22 +92,20 @@ impl AsyncRead for UdpStream { } } - let mut guard = project.io.lock(); - let (ibuf, waker, read_lens) = guard.deref_mut(); - if let Some(read_len) = read_lens.pop_front() { - if read_len > obuf.remaining() { - read_lens.push_front(read_len); - waker.replace(cx.waker().clone()); + if !*project.has_been_notified { + let notified = project.io.0.notified(); + pin_mut!(notified); + if !notified.poll(cx).is_ready() { + project.io.1.lock().replace(cx.waker().clone()); return Poll::Pending; } - - obuf.put_slice(&ibuf[..read_len]); - ibuf.advance(read_len); - Poll::Ready(Ok(())) - } else { - waker.replace(cx.waker().clone()); - Poll::Pending + *project.has_been_notified = true; } + + let _ = ready!(project.socket.poll_recv(cx, obuf)); + *project.has_been_notified = false; + project.io.2.notify_one(); + Poll::Ready(Ok(())) } } @@ -164,39 +161,22 @@ pub async fn run_server( match server.peers.entry(peer_addr) { Entry::Occupied(mut peer) => { - let mut guard = peer.get_mut().lock(); - let (buf, waker, read_lens) = guard.deref_mut(); - // As we have done a peek_sender before, we are sure that there is pending read data - // and we don't want to wait to avoid holding the lock across await point - if buf.capacity() < MAX_UDP_PAYLOAD { - buf.reserve(MAX_UDP_PAYLOAD * 2); - } - match server.listener.try_recv_buf(buf) { - Ok(0) => {} // don't wake if nothing was read - Ok(len) => { - read_lens.push_back(len); - if let Some(waker) = waker.take() { - drop(guard); - waker.wake() - } - } - Err(_) => { + peer.get().2.notified().await; + { + peer.get().0.notify_one(); + let mut guard = peer.get_mut().1.lock(); + let waker = guard.deref_mut(); + if let Some(waker) = waker.take() { drop(guard); - server.keys_to_delete.write().push(peer_addr); + waker.wake(); } } } Entry::Vacant(peer) => { - let mut buf = BytesMut::with_capacity(DEFAULT_UDP_BUFFER_SIZE); - let len = match server.listener.try_recv_buf(&mut buf) { - Ok(0) | Err(_) => { - continue; - } - Ok(len) => len, - }; - let mut read_lens = VecDeque::with_capacity(64); - read_lens.push_back(len); - let io = Arc::new(Mutex::new((buf, None, read_lens))); + let notify = Notify::new(); + let has_read = Notify::new(); + notify.notify_one(); + let io = Arc::new((notify, Mutex::new(None), has_read)); peer.insert(io.clone()); let udp_client = UdpStream { socket: server.clone_socket(), @@ -206,6 +186,7 @@ pub async fn run_server( .and_then(|timeout| tokio::time::Instant::now().checked_add(timeout)) .map(tokio::time::sleep_until), keys_to_delete: Arc::downgrade(&server.keys_to_delete), + has_been_notified: false, io, }; return Some((Ok(udp_client), (server))); @@ -310,6 +291,8 @@ mod tests { // Udp Server should respect framing from the client and not merge the two packets let ret = timeout(Duration::from_millis(100), stream.read(&mut buf[5..])).await; assert!(matches!(ret, Ok(Ok(5)))); + + let _ = timeout(Duration::from_millis(100), server.next()).await; let ret = timeout(Duration::from_millis(100), stream.read(&mut buf[10..])).await; assert!(matches!(ret, Ok(Ok(5)))); assert_eq!(&buf[..16], b"helloworld test\0"); @@ -335,21 +318,21 @@ mod tests { let fut = timeout(Duration::from_millis(100), server.next()).await; assert!(matches!(fut, Ok(Some(Ok(_))))); - let fut2 = timeout(Duration::from_millis(100), server.next()).await; - assert!(matches!(fut, Ok(Some(Ok(_))))); - // Take the stream of data let stream = fut.unwrap().unwrap().unwrap(); pin_mut!(stream); - let stream2 = fut2.unwrap().unwrap().unwrap(); - pin_mut!(stream2); - let mut buf = [0u8; 25]; let ret = stream.read(&mut buf).await; assert!(matches!(ret, Ok(5))); assert_eq!(&buf[..6], b"aaaaa\0"); + let fut2 = timeout(Duration::from_millis(100), server.next()).await; + assert!(matches!(fut2, Ok(Some(Ok(_))))); + + let stream2 = fut2.unwrap().unwrap().unwrap(); + pin_mut!(stream2); + let ret = stream2.read(&mut buf).await; assert!(matches!(ret, Ok(5))); assert_eq!(&buf[..6], b"bbbbb\0");