diff --git a/src/udp.rs b/src/udp.rs index 8cdb739..7be195b 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -3,7 +3,7 @@ use bytes::{Buf, BytesMut}; use futures_util::{stream, Stream}; use pin_project::{pin_project, pinned_drop}; use std::collections::hash_map::Entry; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::future::Future; use std::io; use std::io::{Error, ErrorKind}; @@ -20,7 +20,7 @@ use tracing::{debug, error, info}; const DEFAULT_UDP_BUFFER_SIZE: usize = 32 * 1024; // 32kb -type IoInner = Arc)>>; +type IoInner = Arc, VecDeque)>>; struct UdpServer { listener: Arc, peers: HashMap, @@ -91,11 +91,16 @@ impl AsyncRead for UdpStream { } let mut guard = project.io.lock().unwrap(); - let (ibuf, waker) = guard.deref_mut(); - if ibuf.has_remaining() { - let max = ibuf.remaining().min(obuf.remaining()); - obuf.put_slice(&ibuf[..max]); - ibuf.advance(max); + let (ibuf, waker, read_lens) = guard.deref_mut(); + if let Some(read_len) = read_lens.pop_front() { + if read_len > obuf.remaining() { + read_lens.push_front(read_len); + waker.replace(cx.waker().clone()); + return Poll::Pending; + } + + obuf.put_slice(&ibuf[..read_len]); + ibuf.advance(read_len); Poll::Ready(Ok(())) } else { waker.replace(cx.waker().clone()); @@ -157,13 +162,15 @@ pub async fn run_server( match server.peers.entry(peer_addr) { Entry::Occupied(mut peer) => { let mut guard = peer.get_mut().lock().unwrap(); - let (buf, waker) = guard.deref_mut(); + let (buf, waker, read_lens) = guard.deref_mut(); // As we have done a peek_sender before, we are sure that there is pending read data // and we don't want to wait to avoid holding the lock across await point match server.listener.try_recv_buf(buf) { Ok(0) => {} // don't wake if nothing was read - Ok(_) => { + Ok(len) => { + read_lens.push_back(len); if let Some(waker) = waker.take() { + drop(guard); waker.wake() } } @@ -175,12 +182,15 @@ pub async fn run_server( } Entry::Vacant(peer) => { let mut buf = BytesMut::with_capacity(DEFAULT_UDP_BUFFER_SIZE); - match server.listener.recv_buf(&mut buf).await { - Ok(0) | Err(_) => continue, + let len = match server.listener.recv_buf(&mut buf).await { + Ok(0) | Err(_) => { + continue; + } Ok(len) => len, }; - - let io = Arc::new(Mutex::new((buf, None))); + let mut read_lens = VecDeque::with_capacity(64); + read_lens.push_back(len); + let io = Arc::new(Mutex::new((buf, None, read_lens))); peer.insert(io.clone()); let udp_client = UdpStream { socket: server.clone_socket(),