Improve udp server forwarding
This commit is contained in:
parent
5ec9bbaf38
commit
995167c57c
3 changed files with 47 additions and 21 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -1544,6 +1544,7 @@ dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"base64",
|
"base64",
|
||||||
"bb8",
|
"bb8",
|
||||||
|
"bytes",
|
||||||
"clap",
|
"clap",
|
||||||
"fast-socks5",
|
"fast-socks5",
|
||||||
"fastwebsockets",
|
"fastwebsockets",
|
||||||
|
|
|
@ -22,6 +22,7 @@ scopeguard = "1.2.0"
|
||||||
uuid = { version = "1.5.0", features = ["v7", "serde"] }
|
uuid = { version = "1.5.0", features = ["v7", "serde"] }
|
||||||
jsonwebtoken = { version = "9.0.0", default-features = false }
|
jsonwebtoken = { version = "9.0.0", default-features = false }
|
||||||
rustls-pemfile = { version = "1.0.3", features = [] }
|
rustls-pemfile = { version = "1.0.3", features = [] }
|
||||||
|
bytes = { version = "1.5.0", features = [] }
|
||||||
|
|
||||||
rustls-native-certs = { version = "0.6.3", features = [] }
|
rustls-native-certs = { version = "0.6.3", features = [] }
|
||||||
tokio = { version = "1.33.0", features = ["full"] }
|
tokio = { version = "1.33.0", features = ["full"] }
|
||||||
|
|
66
src/udp.rs
66
src/udp.rs
|
@ -1,4 +1,5 @@
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
|
use bytes::{Buf, BytesMut};
|
||||||
use futures_util::{stream, Stream};
|
use futures_util::{stream, Stream};
|
||||||
use pin_project::{pin_project, pinned_drop};
|
use pin_project::{pin_project, pinned_drop};
|
||||||
use std::collections::hash_map::Entry;
|
use std::collections::hash_map::Entry;
|
||||||
|
@ -7,21 +8,22 @@ use std::future::Future;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::io::{Error, ErrorKind};
|
use std::io::{Error, ErrorKind};
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
use std::ops::DerefMut;
|
||||||
use std::pin::{pin, Pin};
|
use std::pin::{pin, Pin};
|
||||||
use std::sync::{Arc, RwLock, Weak};
|
use std::sync::{Arc, Mutex, RwLock, Weak};
|
||||||
use std::task::Poll;
|
use std::task::{Poll, Waker};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
use tokio::net::UdpSocket;
|
use tokio::net::UdpSocket;
|
||||||
use tokio::time::Sleep;
|
use tokio::time::Sleep;
|
||||||
use tracing::{debug, error, info};
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
const DEFAULT_UDP_BUFFER_SIZE: usize = 64 * 1024; // 64kb
|
const DEFAULT_UDP_BUFFER_SIZE: usize = 32 * 1024; // 32kb
|
||||||
|
|
||||||
|
type IoInner = Arc<Mutex<(BytesMut, Option<Waker>)>>;
|
||||||
struct UdpServer {
|
struct UdpServer {
|
||||||
listener: Arc<UdpSocket>,
|
listener: Arc<UdpSocket>,
|
||||||
buffer: Vec<u8>,
|
peers: HashMap<SocketAddr, IoInner, ahash::RandomState>,
|
||||||
peers: HashMap<SocketAddr, DuplexStream, ahash::RandomState>,
|
|
||||||
keys_to_delete: Arc<RwLock<Vec<SocketAddr>>>,
|
keys_to_delete: Arc<RwLock<Vec<SocketAddr>>>,
|
||||||
pub cnx_timeout: Option<Duration>,
|
pub cnx_timeout: Option<Duration>,
|
||||||
}
|
}
|
||||||
|
@ -31,7 +33,6 @@ impl UdpServer {
|
||||||
Self {
|
Self {
|
||||||
listener,
|
listener,
|
||||||
peers: HashMap::with_hasher(ahash::RandomState::new()),
|
peers: HashMap::with_hasher(ahash::RandomState::new()),
|
||||||
buffer: vec![0u8; DEFAULT_UDP_BUFFER_SIZE],
|
|
||||||
keys_to_delete: Default::default(),
|
keys_to_delete: Default::default(),
|
||||||
cnx_timeout: timeout,
|
cnx_timeout: timeout,
|
||||||
}
|
}
|
||||||
|
@ -62,8 +63,7 @@ pub struct UdpStream {
|
||||||
peer: SocketAddr,
|
peer: SocketAddr,
|
||||||
#[pin]
|
#[pin]
|
||||||
deadline: Option<Sleep>,
|
deadline: Option<Sleep>,
|
||||||
#[pin]
|
io: IoInner,
|
||||||
io: DuplexStream,
|
|
||||||
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
|
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,7 +92,17 @@ impl AsyncRead for UdpStream {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
project.io.poll_read(cx, buf)
|
let mut guard = project.io.lock().unwrap();
|
||||||
|
let (inner, waker) = guard.deref_mut();
|
||||||
|
if inner.has_remaining() {
|
||||||
|
let max = inner.remaining().min(buf.remaining());
|
||||||
|
buf.put_slice(&inner[..max]);
|
||||||
|
inner.advance(max);
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
} else {
|
||||||
|
waker.replace(cx.waker().clone());
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,7 +148,7 @@ pub async fn run_server(
|
||||||
let stream = stream::unfold(udp_server, |mut server| async {
|
let stream = stream::unfold(udp_server, |mut server| async {
|
||||||
loop {
|
loop {
|
||||||
server.clean_dead_keys();
|
server.clean_dead_keys();
|
||||||
let (nb_bytes, peer_addr) = match server.listener.recv_from(&mut server.buffer).await {
|
let peer_addr = match server.listener.peek_sender().await {
|
||||||
Ok(ret) => ret,
|
Ok(ret) => ret,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
error!("Cannot read from UDP server. Closing server: {}", err);
|
error!("Cannot read from UDP server. Closing server: {}", err);
|
||||||
|
@ -148,18 +158,32 @@ pub async fn run_server(
|
||||||
|
|
||||||
match server.peers.entry(peer_addr) {
|
match server.peers.entry(peer_addr) {
|
||||||
Entry::Occupied(mut peer) => {
|
Entry::Occupied(mut peer) => {
|
||||||
let ret = peer.get_mut().write_all(&server.buffer[..nb_bytes]).await;
|
let mut guard = peer.get_mut().lock().unwrap();
|
||||||
if let Err(err) = ret {
|
let (buf, waker) = guard.deref_mut();
|
||||||
info!("Peer {:?} disconnected {:?}", peer_addr, err);
|
// As we have done a peek_sender before, we are sure that there is pending read data
|
||||||
peer.remove();
|
// and we don't want to wait to avoid holding the lock across await point
|
||||||
|
match server.listener.try_recv_buf(buf) {
|
||||||
|
Ok(0) => {} // don't wake if nothing was read
|
||||||
|
Ok(_) => {
|
||||||
|
if let Some(waker) = waker.take() {
|
||||||
|
waker.wake()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
drop(guard);
|
||||||
|
server.keys_to_delete.write().unwrap().push(peer_addr);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Entry::Vacant(peer) => {
|
Entry::Vacant(peer) => {
|
||||||
let (mut rx, tx) = tokio::io::duplex(DEFAULT_UDP_BUFFER_SIZE);
|
let mut buf = BytesMut::with_capacity(DEFAULT_UDP_BUFFER_SIZE);
|
||||||
rx.write_all(&server.buffer[..nb_bytes])
|
match server.listener.recv_buf(&mut buf).await {
|
||||||
.await
|
Ok(0) | Err(_) => continue,
|
||||||
.unwrap_or_default(); // should never fail
|
Ok(len) => len,
|
||||||
peer.insert(rx);
|
};
|
||||||
|
|
||||||
|
let io = Arc::new(Mutex::new((buf, None)));
|
||||||
|
peer.insert(io.clone());
|
||||||
let udp_client = UdpStream {
|
let udp_client = UdpStream {
|
||||||
socket: server.clone_socket(),
|
socket: server.clone_socket(),
|
||||||
peer: peer_addr,
|
peer: peer_addr,
|
||||||
|
@ -168,7 +192,7 @@ pub async fn run_server(
|
||||||
.and_then(|timeout| tokio::time::Instant::now().checked_add(timeout))
|
.and_then(|timeout| tokio::time::Instant::now().checked_add(timeout))
|
||||||
.map(tokio::time::sleep_until),
|
.map(tokio::time::sleep_until),
|
||||||
keys_to_delete: Arc::downgrade(&server.keys_to_delete),
|
keys_to_delete: Arc::downgrade(&server.keys_to_delete),
|
||||||
io: tx,
|
io,
|
||||||
};
|
};
|
||||||
return Some((Ok(udp_client), (server)));
|
return Some((Ok(udp_client), (server)));
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue