From 995167c57c66572eacf7fc64876f137eaa42d00d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Thu, 26 Oct 2023 21:45:55 +0200 Subject: [PATCH] Improve udp server forwarding --- Cargo.lock | 1 + Cargo.toml | 1 + src/udp.rs | 66 +++++++++++++++++++++++++++++++++++++----------------- 3 files changed, 47 insertions(+), 21 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bdf0209..b29f30f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1544,6 +1544,7 @@ dependencies = [ "async-trait", "base64", "bb8", + "bytes", "clap", "fast-socks5", "fastwebsockets", diff --git a/Cargo.toml b/Cargo.toml index 0fa946c..8ff3b95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ scopeguard = "1.2.0" uuid = { version = "1.5.0", features = ["v7", "serde"] } jsonwebtoken = { version = "9.0.0", default-features = false } rustls-pemfile = { version = "1.0.3", features = [] } +bytes = { version = "1.5.0", features = [] } rustls-native-certs = { version = "0.6.3", features = [] } tokio = { version = "1.33.0", features = ["full"] } diff --git a/src/udp.rs b/src/udp.rs index b5945cc..47741c7 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -1,4 +1,5 @@ use anyhow::Context; +use bytes::{Buf, BytesMut}; use futures_util::{stream, Stream}; use pin_project::{pin_project, pinned_drop}; use std::collections::hash_map::Entry; @@ -7,21 +8,22 @@ use std::future::Future; use std::io; use std::io::{Error, ErrorKind}; use std::net::SocketAddr; +use std::ops::DerefMut; use std::pin::{pin, Pin}; -use std::sync::{Arc, RwLock, Weak}; -use std::task::Poll; +use std::sync::{Arc, Mutex, RwLock, Weak}; +use std::task::{Poll, Waker}; use std::time::Duration; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::UdpSocket; use tokio::time::Sleep; use tracing::{debug, error, info}; -const DEFAULT_UDP_BUFFER_SIZE: usize = 64 * 1024; // 64kb +const DEFAULT_UDP_BUFFER_SIZE: usize = 32 * 1024; // 32kb +type IoInner = Arc)>>; struct UdpServer { listener: Arc, - buffer: Vec, - peers: HashMap, + peers: HashMap, keys_to_delete: Arc>>, pub cnx_timeout: Option, } @@ -31,7 +33,6 @@ impl UdpServer { Self { listener, peers: HashMap::with_hasher(ahash::RandomState::new()), - buffer: vec![0u8; DEFAULT_UDP_BUFFER_SIZE], keys_to_delete: Default::default(), cnx_timeout: timeout, } @@ -62,8 +63,7 @@ pub struct UdpStream { peer: SocketAddr, #[pin] deadline: Option, - #[pin] - io: DuplexStream, + io: IoInner, keys_to_delete: Weak>>, } @@ -92,7 +92,17 @@ impl AsyncRead for UdpStream { } } - project.io.poll_read(cx, buf) + let mut guard = project.io.lock().unwrap(); + let (inner, waker) = guard.deref_mut(); + if inner.has_remaining() { + let max = inner.remaining().min(buf.remaining()); + buf.put_slice(&inner[..max]); + inner.advance(max); + Poll::Ready(Ok(())) + } else { + waker.replace(cx.waker().clone()); + Poll::Pending + } } } @@ -138,7 +148,7 @@ pub async fn run_server( let stream = stream::unfold(udp_server, |mut server| async { loop { server.clean_dead_keys(); - let (nb_bytes, peer_addr) = match server.listener.recv_from(&mut server.buffer).await { + let peer_addr = match server.listener.peek_sender().await { Ok(ret) => ret, Err(err) => { error!("Cannot read from UDP server. Closing server: {}", err); @@ -148,18 +158,32 @@ pub async fn run_server( match server.peers.entry(peer_addr) { Entry::Occupied(mut peer) => { - let ret = peer.get_mut().write_all(&server.buffer[..nb_bytes]).await; - if let Err(err) = ret { - info!("Peer {:?} disconnected {:?}", peer_addr, err); - peer.remove(); + let mut guard = peer.get_mut().lock().unwrap(); + let (buf, waker) = guard.deref_mut(); + // As we have done a peek_sender before, we are sure that there is pending read data + // and we don't want to wait to avoid holding the lock across await point + match server.listener.try_recv_buf(buf) { + Ok(0) => {} // don't wake if nothing was read + Ok(_) => { + if let Some(waker) = waker.take() { + waker.wake() + } + } + Err(_) => { + drop(guard); + server.keys_to_delete.write().unwrap().push(peer_addr); + } } } Entry::Vacant(peer) => { - let (mut rx, tx) = tokio::io::duplex(DEFAULT_UDP_BUFFER_SIZE); - rx.write_all(&server.buffer[..nb_bytes]) - .await - .unwrap_or_default(); // should never fail - peer.insert(rx); + let mut buf = BytesMut::with_capacity(DEFAULT_UDP_BUFFER_SIZE); + match server.listener.recv_buf(&mut buf).await { + Ok(0) | Err(_) => continue, + Ok(len) => len, + }; + + let io = Arc::new(Mutex::new((buf, None))); + peer.insert(io.clone()); let udp_client = UdpStream { socket: server.clone_socket(), peer: peer_addr, @@ -168,7 +192,7 @@ pub async fn run_server( .and_then(|timeout| tokio::time::Instant::now().checked_add(timeout)) .map(tokio::time::sleep_until), keys_to_delete: Arc::downgrade(&server.keys_to_delete), - io: tx, + io, }; return Some((Ok(udp_client), (server))); }