Respect UDP framing for wireguard

This commit is contained in:
Σrebe - Romain GERARD 2023-10-28 19:58:25 +02:00
parent a9d3cf0ab5
commit 02dcab74ec
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4

View file

@ -3,7 +3,7 @@ use bytes::{Buf, BytesMut};
use futures_util::{stream, Stream};
use pin_project::{pin_project, pinned_drop};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::io;
use std::io::{Error, ErrorKind};
@ -20,7 +20,7 @@ use tracing::{debug, error, info};
const DEFAULT_UDP_BUFFER_SIZE: usize = 32 * 1024; // 32kb
type IoInner = Arc<Mutex<(BytesMut, Option<Waker>)>>;
type IoInner = Arc<Mutex<(BytesMut, Option<Waker>, VecDeque<usize>)>>;
struct UdpServer {
listener: Arc<UdpSocket>,
peers: HashMap<SocketAddr, IoInner, ahash::RandomState>,
@ -91,11 +91,16 @@ impl AsyncRead for UdpStream {
}
let mut guard = project.io.lock().unwrap();
let (ibuf, waker) = guard.deref_mut();
if ibuf.has_remaining() {
let max = ibuf.remaining().min(obuf.remaining());
obuf.put_slice(&ibuf[..max]);
ibuf.advance(max);
let (ibuf, waker, read_lens) = guard.deref_mut();
if let Some(read_len) = read_lens.pop_front() {
if read_len > obuf.remaining() {
read_lens.push_front(read_len);
waker.replace(cx.waker().clone());
return Poll::Pending;
}
obuf.put_slice(&ibuf[..read_len]);
ibuf.advance(read_len);
Poll::Ready(Ok(()))
} else {
waker.replace(cx.waker().clone());
@ -157,13 +162,15 @@ pub async fn run_server(
match server.peers.entry(peer_addr) {
Entry::Occupied(mut peer) => {
let mut guard = peer.get_mut().lock().unwrap();
let (buf, waker) = guard.deref_mut();
let (buf, waker, read_lens) = guard.deref_mut();
// As we have done a peek_sender before, we are sure that there is pending read data
// and we don't want to wait to avoid holding the lock across await point
match server.listener.try_recv_buf(buf) {
Ok(0) => {} // don't wake if nothing was read
Ok(_) => {
Ok(len) => {
read_lens.push_back(len);
if let Some(waker) = waker.take() {
drop(guard);
waker.wake()
}
}
@ -175,12 +182,15 @@ pub async fn run_server(
}
Entry::Vacant(peer) => {
let mut buf = BytesMut::with_capacity(DEFAULT_UDP_BUFFER_SIZE);
match server.listener.recv_buf(&mut buf).await {
Ok(0) | Err(_) => continue,
let len = match server.listener.recv_buf(&mut buf).await {
Ok(0) | Err(_) => {
continue;
}
Ok(len) => len,
};
let io = Arc::new(Mutex::new((buf, None)));
let mut read_lens = VecDeque::with_capacity(64);
read_lens.push_back(len);
let io = Arc::new(Mutex::new((buf, None, read_lens)));
peer.insert(io.clone());
let udp_client = UdpStream {
socket: server.clone_socket(),