Respect UDP framing for wireguard
This commit is contained in:
parent
a9d3cf0ab5
commit
02dcab74ec
1 changed files with 23 additions and 13 deletions
36
src/udp.rs
36
src/udp.rs
|
@ -3,7 +3,7 @@ use bytes::{Buf, BytesMut};
|
||||||
use futures_util::{stream, Stream};
|
use futures_util::{stream, Stream};
|
||||||
use pin_project::{pin_project, pinned_drop};
|
use pin_project::{pin_project, pinned_drop};
|
||||||
use std::collections::hash_map::Entry;
|
use std::collections::hash_map::Entry;
|
||||||
use std::collections::HashMap;
|
use std::collections::{HashMap, VecDeque};
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::io::{Error, ErrorKind};
|
use std::io::{Error, ErrorKind};
|
||||||
|
@ -20,7 +20,7 @@ use tracing::{debug, error, info};
|
||||||
|
|
||||||
const DEFAULT_UDP_BUFFER_SIZE: usize = 32 * 1024; // 32kb
|
const DEFAULT_UDP_BUFFER_SIZE: usize = 32 * 1024; // 32kb
|
||||||
|
|
||||||
type IoInner = Arc<Mutex<(BytesMut, Option<Waker>)>>;
|
type IoInner = Arc<Mutex<(BytesMut, Option<Waker>, VecDeque<usize>)>>;
|
||||||
struct UdpServer {
|
struct UdpServer {
|
||||||
listener: Arc<UdpSocket>,
|
listener: Arc<UdpSocket>,
|
||||||
peers: HashMap<SocketAddr, IoInner, ahash::RandomState>,
|
peers: HashMap<SocketAddr, IoInner, ahash::RandomState>,
|
||||||
|
@ -91,11 +91,16 @@ impl AsyncRead for UdpStream {
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut guard = project.io.lock().unwrap();
|
let mut guard = project.io.lock().unwrap();
|
||||||
let (ibuf, waker) = guard.deref_mut();
|
let (ibuf, waker, read_lens) = guard.deref_mut();
|
||||||
if ibuf.has_remaining() {
|
if let Some(read_len) = read_lens.pop_front() {
|
||||||
let max = ibuf.remaining().min(obuf.remaining());
|
if read_len > obuf.remaining() {
|
||||||
obuf.put_slice(&ibuf[..max]);
|
read_lens.push_front(read_len);
|
||||||
ibuf.advance(max);
|
waker.replace(cx.waker().clone());
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
|
||||||
|
obuf.put_slice(&ibuf[..read_len]);
|
||||||
|
ibuf.advance(read_len);
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
} else {
|
} else {
|
||||||
waker.replace(cx.waker().clone());
|
waker.replace(cx.waker().clone());
|
||||||
|
@ -157,13 +162,15 @@ pub async fn run_server(
|
||||||
match server.peers.entry(peer_addr) {
|
match server.peers.entry(peer_addr) {
|
||||||
Entry::Occupied(mut peer) => {
|
Entry::Occupied(mut peer) => {
|
||||||
let mut guard = peer.get_mut().lock().unwrap();
|
let mut guard = peer.get_mut().lock().unwrap();
|
||||||
let (buf, waker) = guard.deref_mut();
|
let (buf, waker, read_lens) = guard.deref_mut();
|
||||||
// As we have done a peek_sender before, we are sure that there is pending read data
|
// As we have done a peek_sender before, we are sure that there is pending read data
|
||||||
// and we don't want to wait to avoid holding the lock across await point
|
// and we don't want to wait to avoid holding the lock across await point
|
||||||
match server.listener.try_recv_buf(buf) {
|
match server.listener.try_recv_buf(buf) {
|
||||||
Ok(0) => {} // don't wake if nothing was read
|
Ok(0) => {} // don't wake if nothing was read
|
||||||
Ok(_) => {
|
Ok(len) => {
|
||||||
|
read_lens.push_back(len);
|
||||||
if let Some(waker) = waker.take() {
|
if let Some(waker) = waker.take() {
|
||||||
|
drop(guard);
|
||||||
waker.wake()
|
waker.wake()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -175,12 +182,15 @@ pub async fn run_server(
|
||||||
}
|
}
|
||||||
Entry::Vacant(peer) => {
|
Entry::Vacant(peer) => {
|
||||||
let mut buf = BytesMut::with_capacity(DEFAULT_UDP_BUFFER_SIZE);
|
let mut buf = BytesMut::with_capacity(DEFAULT_UDP_BUFFER_SIZE);
|
||||||
match server.listener.recv_buf(&mut buf).await {
|
let len = match server.listener.recv_buf(&mut buf).await {
|
||||||
Ok(0) | Err(_) => continue,
|
Ok(0) | Err(_) => {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
Ok(len) => len,
|
Ok(len) => len,
|
||||||
};
|
};
|
||||||
|
let mut read_lens = VecDeque::with_capacity(64);
|
||||||
let io = Arc::new(Mutex::new((buf, None)));
|
read_lens.push_back(len);
|
||||||
|
let io = Arc::new(Mutex::new((buf, None, read_lens)));
|
||||||
peer.insert(io.clone());
|
peer.insert(io.clone());
|
||||||
let udp_client = UdpStream {
|
let udp_client = UdpStream {
|
||||||
socket: server.clone_socket(),
|
socket: server.clone_socket(),
|
||||||
|
|
Loading…
Reference in a new issue