feat(udp): Unleash max performance for udp server
This commit is contained in:
parent
f54d170ed6
commit
a832004783
1 changed files with 38 additions and 55 deletions
93
src/udp.rs
93
src/udp.rs
|
@ -1,11 +1,10 @@
|
|||
use anyhow::Context;
|
||||
use bytes::{Buf, BytesMut};
|
||||
use futures_util::{stream, Stream};
|
||||
use futures_util::{pin_mut, stream, Stream};
|
||||
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use pin_project::{pin_project, pinned_drop};
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::io::{Error, ErrorKind};
|
||||
|
@ -13,17 +12,16 @@ use std::net::SocketAddr;
|
|||
use std::ops::DerefMut;
|
||||
use std::pin::{pin, Pin};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::task::{Poll, Waker};
|
||||
use std::task::{ready, Poll, Waker};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::net::UdpSocket;
|
||||
|
||||
use tokio::sync::Notify;
|
||||
use tokio::time::Sleep;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
const MAX_UDP_PAYLOAD: usize = 65536;
|
||||
const DEFAULT_UDP_BUFFER_SIZE: usize = 64 * 1024; // 64kb
|
||||
|
||||
type IoInner = Arc<Mutex<(BytesMut, Option<Waker>, VecDeque<usize>)>>;
|
||||
type IoInner = Arc<(Notify, Mutex<Option<Waker>>, Notify)>;
|
||||
struct UdpServer {
|
||||
listener: Arc<UdpSocket>,
|
||||
peers: HashMap<SocketAddr, IoInner, ahash::RandomState>,
|
||||
|
@ -64,6 +62,7 @@ pub struct UdpStream {
|
|||
peer: SocketAddr,
|
||||
#[pin]
|
||||
deadline: Option<Sleep>,
|
||||
has_been_notified: bool,
|
||||
io: IoInner,
|
||||
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
|
||||
}
|
||||
|
@ -93,22 +92,20 @@ impl AsyncRead for UdpStream {
|
|||
}
|
||||
}
|
||||
|
||||
let mut guard = project.io.lock();
|
||||
let (ibuf, waker, read_lens) = guard.deref_mut();
|
||||
if let Some(read_len) = read_lens.pop_front() {
|
||||
if read_len > obuf.remaining() {
|
||||
read_lens.push_front(read_len);
|
||||
waker.replace(cx.waker().clone());
|
||||
if !*project.has_been_notified {
|
||||
let notified = project.io.0.notified();
|
||||
pin_mut!(notified);
|
||||
if !notified.poll(cx).is_ready() {
|
||||
project.io.1.lock().replace(cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
obuf.put_slice(&ibuf[..read_len]);
|
||||
ibuf.advance(read_len);
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
waker.replace(cx.waker().clone());
|
||||
Poll::Pending
|
||||
*project.has_been_notified = true;
|
||||
}
|
||||
|
||||
let _ = ready!(project.socket.poll_recv(cx, obuf));
|
||||
*project.has_been_notified = false;
|
||||
project.io.2.notify_one();
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -164,39 +161,22 @@ pub async fn run_server(
|
|||
|
||||
match server.peers.entry(peer_addr) {
|
||||
Entry::Occupied(mut peer) => {
|
||||
let mut guard = peer.get_mut().lock();
|
||||
let (buf, waker, read_lens) = guard.deref_mut();
|
||||
// As we have done a peek_sender before, we are sure that there is pending read data
|
||||
// and we don't want to wait to avoid holding the lock across await point
|
||||
if buf.capacity() < MAX_UDP_PAYLOAD {
|
||||
buf.reserve(MAX_UDP_PAYLOAD * 2);
|
||||
}
|
||||
match server.listener.try_recv_buf(buf) {
|
||||
Ok(0) => {} // don't wake if nothing was read
|
||||
Ok(len) => {
|
||||
read_lens.push_back(len);
|
||||
if let Some(waker) = waker.take() {
|
||||
drop(guard);
|
||||
waker.wake()
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
peer.get().2.notified().await;
|
||||
{
|
||||
peer.get().0.notify_one();
|
||||
let mut guard = peer.get_mut().1.lock();
|
||||
let waker = guard.deref_mut();
|
||||
if let Some(waker) = waker.take() {
|
||||
drop(guard);
|
||||
server.keys_to_delete.write().push(peer_addr);
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
}
|
||||
Entry::Vacant(peer) => {
|
||||
let mut buf = BytesMut::with_capacity(DEFAULT_UDP_BUFFER_SIZE);
|
||||
let len = match server.listener.try_recv_buf(&mut buf) {
|
||||
Ok(0) | Err(_) => {
|
||||
continue;
|
||||
}
|
||||
Ok(len) => len,
|
||||
};
|
||||
let mut read_lens = VecDeque::with_capacity(64);
|
||||
read_lens.push_back(len);
|
||||
let io = Arc::new(Mutex::new((buf, None, read_lens)));
|
||||
let notify = Notify::new();
|
||||
let has_read = Notify::new();
|
||||
notify.notify_one();
|
||||
let io = Arc::new((notify, Mutex::new(None), has_read));
|
||||
peer.insert(io.clone());
|
||||
let udp_client = UdpStream {
|
||||
socket: server.clone_socket(),
|
||||
|
@ -206,6 +186,7 @@ pub async fn run_server(
|
|||
.and_then(|timeout| tokio::time::Instant::now().checked_add(timeout))
|
||||
.map(tokio::time::sleep_until),
|
||||
keys_to_delete: Arc::downgrade(&server.keys_to_delete),
|
||||
has_been_notified: false,
|
||||
io,
|
||||
};
|
||||
return Some((Ok(udp_client), (server)));
|
||||
|
@ -310,6 +291,8 @@ mod tests {
|
|||
// Udp Server should respect framing from the client and not merge the two packets
|
||||
let ret = timeout(Duration::from_millis(100), stream.read(&mut buf[5..])).await;
|
||||
assert!(matches!(ret, Ok(Ok(5))));
|
||||
|
||||
let _ = timeout(Duration::from_millis(100), server.next()).await;
|
||||
let ret = timeout(Duration::from_millis(100), stream.read(&mut buf[10..])).await;
|
||||
assert!(matches!(ret, Ok(Ok(5))));
|
||||
assert_eq!(&buf[..16], b"helloworld test\0");
|
||||
|
@ -335,21 +318,21 @@ mod tests {
|
|||
let fut = timeout(Duration::from_millis(100), server.next()).await;
|
||||
assert!(matches!(fut, Ok(Some(Ok(_)))));
|
||||
|
||||
let fut2 = timeout(Duration::from_millis(100), server.next()).await;
|
||||
assert!(matches!(fut, Ok(Some(Ok(_)))));
|
||||
|
||||
// Take the stream of data
|
||||
let stream = fut.unwrap().unwrap().unwrap();
|
||||
pin_mut!(stream);
|
||||
|
||||
let stream2 = fut2.unwrap().unwrap().unwrap();
|
||||
pin_mut!(stream2);
|
||||
|
||||
let mut buf = [0u8; 25];
|
||||
let ret = stream.read(&mut buf).await;
|
||||
assert!(matches!(ret, Ok(5)));
|
||||
assert_eq!(&buf[..6], b"aaaaa\0");
|
||||
|
||||
let fut2 = timeout(Duration::from_millis(100), server.next()).await;
|
||||
assert!(matches!(fut2, Ok(Some(Ok(_)))));
|
||||
|
||||
let stream2 = fut2.unwrap().unwrap().unwrap();
|
||||
pin_mut!(stream2);
|
||||
|
||||
let ret = stream2.read(&mut buf).await;
|
||||
assert!(matches!(ret, Ok(5)));
|
||||
assert_eq!(&buf[..6], b"bbbbb\0");
|
||||
|
|
Loading…
Reference in a new issue