Improve udp server forwarding

This commit is contained in:
Σrebe - Romain GERARD 2023-10-26 21:45:55 +02:00
parent 5ec9bbaf38
commit 995167c57c
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
3 changed files with 47 additions and 21 deletions

1
Cargo.lock generated
View file

@ -1544,6 +1544,7 @@ dependencies = [
"async-trait", "async-trait",
"base64", "base64",
"bb8", "bb8",
"bytes",
"clap", "clap",
"fast-socks5", "fast-socks5",
"fastwebsockets", "fastwebsockets",

View file

@ -22,6 +22,7 @@ scopeguard = "1.2.0"
uuid = { version = "1.5.0", features = ["v7", "serde"] } uuid = { version = "1.5.0", features = ["v7", "serde"] }
jsonwebtoken = { version = "9.0.0", default-features = false } jsonwebtoken = { version = "9.0.0", default-features = false }
rustls-pemfile = { version = "1.0.3", features = [] } rustls-pemfile = { version = "1.0.3", features = [] }
bytes = { version = "1.5.0", features = [] }
rustls-native-certs = { version = "0.6.3", features = [] } rustls-native-certs = { version = "0.6.3", features = [] }
tokio = { version = "1.33.0", features = ["full"] } tokio = { version = "1.33.0", features = ["full"] }

View file

@ -1,4 +1,5 @@
use anyhow::Context; use anyhow::Context;
use bytes::{Buf, BytesMut};
use futures_util::{stream, Stream}; use futures_util::{stream, Stream};
use pin_project::{pin_project, pinned_drop}; use pin_project::{pin_project, pinned_drop};
use std::collections::hash_map::Entry; use std::collections::hash_map::Entry;
@ -7,21 +8,22 @@ use std::future::Future;
use std::io; use std::io;
use std::io::{Error, ErrorKind}; use std::io::{Error, ErrorKind};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::ops::DerefMut;
use std::pin::{pin, Pin}; use std::pin::{pin, Pin};
use std::sync::{Arc, RwLock, Weak}; use std::sync::{Arc, Mutex, RwLock, Weak};
use std::task::Poll; use std::task::{Poll, Waker};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tokio::time::Sleep; use tokio::time::Sleep;
use tracing::{debug, error, info}; use tracing::{debug, error, info};
const DEFAULT_UDP_BUFFER_SIZE: usize = 64 * 1024; // 64kb const DEFAULT_UDP_BUFFER_SIZE: usize = 32 * 1024; // 32kb
type IoInner = Arc<Mutex<(BytesMut, Option<Waker>)>>;
struct UdpServer { struct UdpServer {
listener: Arc<UdpSocket>, listener: Arc<UdpSocket>,
buffer: Vec<u8>, peers: HashMap<SocketAddr, IoInner, ahash::RandomState>,
peers: HashMap<SocketAddr, DuplexStream, ahash::RandomState>,
keys_to_delete: Arc<RwLock<Vec<SocketAddr>>>, keys_to_delete: Arc<RwLock<Vec<SocketAddr>>>,
pub cnx_timeout: Option<Duration>, pub cnx_timeout: Option<Duration>,
} }
@ -31,7 +33,6 @@ impl UdpServer {
Self { Self {
listener, listener,
peers: HashMap::with_hasher(ahash::RandomState::new()), peers: HashMap::with_hasher(ahash::RandomState::new()),
buffer: vec![0u8; DEFAULT_UDP_BUFFER_SIZE],
keys_to_delete: Default::default(), keys_to_delete: Default::default(),
cnx_timeout: timeout, cnx_timeout: timeout,
} }
@ -62,8 +63,7 @@ pub struct UdpStream {
peer: SocketAddr, peer: SocketAddr,
#[pin] #[pin]
deadline: Option<Sleep>, deadline: Option<Sleep>,
#[pin] io: IoInner,
io: DuplexStream,
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>, keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
} }
@ -92,7 +92,17 @@ impl AsyncRead for UdpStream {
} }
} }
project.io.poll_read(cx, buf) let mut guard = project.io.lock().unwrap();
let (inner, waker) = guard.deref_mut();
if inner.has_remaining() {
let max = inner.remaining().min(buf.remaining());
buf.put_slice(&inner[..max]);
inner.advance(max);
Poll::Ready(Ok(()))
} else {
waker.replace(cx.waker().clone());
Poll::Pending
}
} }
} }
@ -138,7 +148,7 @@ pub async fn run_server(
let stream = stream::unfold(udp_server, |mut server| async { let stream = stream::unfold(udp_server, |mut server| async {
loop { loop {
server.clean_dead_keys(); server.clean_dead_keys();
let (nb_bytes, peer_addr) = match server.listener.recv_from(&mut server.buffer).await { let peer_addr = match server.listener.peek_sender().await {
Ok(ret) => ret, Ok(ret) => ret,
Err(err) => { Err(err) => {
error!("Cannot read from UDP server. Closing server: {}", err); error!("Cannot read from UDP server. Closing server: {}", err);
@ -148,18 +158,32 @@ pub async fn run_server(
match server.peers.entry(peer_addr) { match server.peers.entry(peer_addr) {
Entry::Occupied(mut peer) => { Entry::Occupied(mut peer) => {
let ret = peer.get_mut().write_all(&server.buffer[..nb_bytes]).await; let mut guard = peer.get_mut().lock().unwrap();
if let Err(err) = ret { let (buf, waker) = guard.deref_mut();
info!("Peer {:?} disconnected {:?}", peer_addr, err); // As we have done a peek_sender before, we are sure that there is pending read data
peer.remove(); // and we don't want to wait to avoid holding the lock across await point
match server.listener.try_recv_buf(buf) {
Ok(0) => {} // don't wake if nothing was read
Ok(_) => {
if let Some(waker) = waker.take() {
waker.wake()
}
}
Err(_) => {
drop(guard);
server.keys_to_delete.write().unwrap().push(peer_addr);
}
} }
} }
Entry::Vacant(peer) => { Entry::Vacant(peer) => {
let (mut rx, tx) = tokio::io::duplex(DEFAULT_UDP_BUFFER_SIZE); let mut buf = BytesMut::with_capacity(DEFAULT_UDP_BUFFER_SIZE);
rx.write_all(&server.buffer[..nb_bytes]) match server.listener.recv_buf(&mut buf).await {
.await Ok(0) | Err(_) => continue,
.unwrap_or_default(); // should never fail Ok(len) => len,
peer.insert(rx); };
let io = Arc::new(Mutex::new((buf, None)));
peer.insert(io.clone());
let udp_client = UdpStream { let udp_client = UdpStream {
socket: server.clone_socket(), socket: server.clone_socket(),
peer: peer_addr, peer: peer_addr,
@ -168,7 +192,7 @@ pub async fn run_server(
.and_then(|timeout| tokio::time::Instant::now().checked_add(timeout)) .and_then(|timeout| tokio::time::Instant::now().checked_add(timeout))
.map(tokio::time::sleep_until), .map(tokio::time::sleep_until),
keys_to_delete: Arc::downgrade(&server.keys_to_delete), keys_to_delete: Arc::downgrade(&server.keys_to_delete),
io: tx, io,
}; };
return Some((Ok(udp_client), (server))); return Some((Ok(udp_client), (server)));
} }