diff --git a/src/udp.rs b/src/udp.rs index 5b6c234..739e698 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -1,5 +1,5 @@ use anyhow::Context; -use futures_util::{pin_mut, stream, Stream}; +use futures_util::{stream, Stream}; use parking_lot::{Mutex, RwLock}; use pin_project::{pin_project, pinned_drop}; @@ -16,15 +16,20 @@ use std::task::{ready, Poll, Waker}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::UdpSocket; +use tokio::sync::futures::Notified; use tokio::sync::Notify; use tokio::time::Sleep; use tracing::{debug, error, info}; -type IoInner = Arc<(Notify, Mutex>, Notify)>; +struct IoInner { + has_data_to_read: &'static Notify, + waker: Mutex>, + has_read_data: Notify, +} struct UdpServer { listener: Arc, - peers: HashMap, + peers: HashMap, ahash::RandomState>, keys_to_delete: Arc>>, pub cnx_timeout: Option, } @@ -47,7 +52,16 @@ impl UdpServer { debug!("Cleaning {} dead udp peers", nb_key_to_delete); let mut keys_to_delete = self.keys_to_delete.write(); for key in keys_to_delete.iter() { - self.peers.remove(key); + let Some(peer) = self.peers.remove(key) else { + continue; + }; + + #[allow(mutable_transmutes)] + unsafe { + let _ = Box::from_raw(std::mem::transmute::<&Notify, &mut Notify>( + peer.has_data_to_read, + )); + } } keys_to_delete.clear(); } @@ -63,7 +77,9 @@ pub struct UdpStream { #[pin] deadline: Option, has_been_notified: bool, - io: IoInner, + #[pin] + pending_notification: Option>, + io: Arc, keys_to_delete: Weak>>, } @@ -73,6 +89,8 @@ impl PinnedDrop for UdpStream { if let Some(keys_to_delete) = self.keys_to_delete.upgrade() { keys_to_delete.write().push(self.peer); } + + self.io.has_read_data.notify_one(); } } @@ -82,7 +100,7 @@ impl AsyncRead for UdpStream { cx: &mut std::task::Context<'_>, obuf: &mut ReadBuf<'_>, ) -> Poll> { - let project = self.project(); + let mut project = self.project(); if let Some(deadline) = project.deadline.as_pin_mut() { if deadline.poll(cx).is_ready() { return Poll::Ready(Err(Error::new( @@ -92,19 +110,20 @@ impl AsyncRead for UdpStream { } } - if !*project.has_been_notified { - let notified = project.io.0.notified(); - pin_mut!(notified); + if let Some(notified) = project.pending_notification.as_mut().as_pin_mut() { if !notified.poll(cx).is_ready() { - project.io.1.lock().replace(cx.waker().clone()); + project.io.waker.lock().replace(cx.waker().clone()); return Poll::Pending; } - *project.has_been_notified = true; + project.pending_notification.as_mut().set(None); } let _ = ready!(project.socket.poll_recv(cx, obuf)); - *project.has_been_notified = false; - project.io.2.notify_one(); + project + .pending_notification + .as_mut() + .set(Some(project.io.has_data_to_read.notified())); + project.io.has_read_data.notify_one(); Poll::Ready(Ok(())) } } @@ -149,8 +168,8 @@ pub async fn run_server( let udp_server = UdpServer::new(Arc::new(listener), timeout); let stream = stream::unfold(udp_server, |mut server| async { - server.clean_dead_keys(); loop { + server.clean_dead_keys(); let peer_addr = match server.listener.peek_sender().await { Ok(ret) => ret, Err(err) => { @@ -161,22 +180,24 @@ pub async fn run_server( match server.peers.entry(peer_addr) { Entry::Occupied(mut peer) => { - peer.get().2.notified().await; - { - peer.get().0.notify_one(); - let mut guard = peer.get_mut().1.lock(); - let waker = guard.deref_mut(); - if let Some(waker) = waker.take() { - drop(guard); - waker.wake(); - } + let io = peer.get_mut(); + io.has_read_data.notified().await; + io.has_data_to_read.notify_one(); + let waker = io.waker.lock().deref_mut().take(); + if let Some(waker) = waker { + waker.wake(); } } Entry::Vacant(peer) => { - let notify = Notify::new(); - let has_read = Notify::new(); - notify.notify_one(); - let io = Arc::new((notify, Mutex::new(None), has_read)); + let has_data_to_read: &'static Notify = Box::leak(Box::new(Notify::new())); + let pending_notification = has_data_to_read.notified(); + let has_read_data = Notify::new(); + has_data_to_read.notify_one(); + let io = Arc::new(IoInner { + has_data_to_read, + waker: Mutex::new(None), + has_read_data, + }); peer.insert(io.clone()); let udp_client = UdpStream { socket: server.clone_socket(), @@ -187,6 +208,7 @@ pub async fn run_server( .map(tokio::time::sleep_until), keys_to_delete: Arc::downgrade(&server.keys_to_delete), has_been_notified: false, + pending_notification: Some(pending_notification), io, }; return Some((Ok(udp_client), (server)));