Respect UDP framing for wireguard
This commit is contained in:
parent
a9d3cf0ab5
commit
02dcab74ec
1 changed files with 23 additions and 13 deletions
36
src/udp.rs
36
src/udp.rs
|
@ -3,7 +3,7 @@ use bytes::{Buf, BytesMut};
|
|||
use futures_util::{stream, Stream};
|
||||
use pin_project::{pin_project, pinned_drop};
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::io::{Error, ErrorKind};
|
||||
|
@ -20,7 +20,7 @@ use tracing::{debug, error, info};
|
|||
|
||||
const DEFAULT_UDP_BUFFER_SIZE: usize = 32 * 1024; // 32kb
|
||||
|
||||
type IoInner = Arc<Mutex<(BytesMut, Option<Waker>)>>;
|
||||
type IoInner = Arc<Mutex<(BytesMut, Option<Waker>, VecDeque<usize>)>>;
|
||||
struct UdpServer {
|
||||
listener: Arc<UdpSocket>,
|
||||
peers: HashMap<SocketAddr, IoInner, ahash::RandomState>,
|
||||
|
@ -91,11 +91,16 @@ impl AsyncRead for UdpStream {
|
|||
}
|
||||
|
||||
let mut guard = project.io.lock().unwrap();
|
||||
let (ibuf, waker) = guard.deref_mut();
|
||||
if ibuf.has_remaining() {
|
||||
let max = ibuf.remaining().min(obuf.remaining());
|
||||
obuf.put_slice(&ibuf[..max]);
|
||||
ibuf.advance(max);
|
||||
let (ibuf, waker, read_lens) = guard.deref_mut();
|
||||
if let Some(read_len) = read_lens.pop_front() {
|
||||
if read_len > obuf.remaining() {
|
||||
read_lens.push_front(read_len);
|
||||
waker.replace(cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
obuf.put_slice(&ibuf[..read_len]);
|
||||
ibuf.advance(read_len);
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
waker.replace(cx.waker().clone());
|
||||
|
@ -157,13 +162,15 @@ pub async fn run_server(
|
|||
match server.peers.entry(peer_addr) {
|
||||
Entry::Occupied(mut peer) => {
|
||||
let mut guard = peer.get_mut().lock().unwrap();
|
||||
let (buf, waker) = guard.deref_mut();
|
||||
let (buf, waker, read_lens) = guard.deref_mut();
|
||||
// As we have done a peek_sender before, we are sure that there is pending read data
|
||||
// and we don't want to wait to avoid holding the lock across await point
|
||||
match server.listener.try_recv_buf(buf) {
|
||||
Ok(0) => {} // don't wake if nothing was read
|
||||
Ok(_) => {
|
||||
Ok(len) => {
|
||||
read_lens.push_back(len);
|
||||
if let Some(waker) = waker.take() {
|
||||
drop(guard);
|
||||
waker.wake()
|
||||
}
|
||||
}
|
||||
|
@ -175,12 +182,15 @@ pub async fn run_server(
|
|||
}
|
||||
Entry::Vacant(peer) => {
|
||||
let mut buf = BytesMut::with_capacity(DEFAULT_UDP_BUFFER_SIZE);
|
||||
match server.listener.recv_buf(&mut buf).await {
|
||||
Ok(0) | Err(_) => continue,
|
||||
let len = match server.listener.recv_buf(&mut buf).await {
|
||||
Ok(0) | Err(_) => {
|
||||
continue;
|
||||
}
|
||||
Ok(len) => len,
|
||||
};
|
||||
|
||||
let io = Arc::new(Mutex::new((buf, None)));
|
||||
let mut read_lens = VecDeque::with_capacity(64);
|
||||
read_lens.push_back(len);
|
||||
let io = Arc::new(Mutex::new((buf, None, read_lens)));
|
||||
peer.insert(io.clone());
|
||||
let udp_client = UdpStream {
|
||||
socket: server.clone_socket(),
|
||||
|
|
Loading…
Reference in a new issue