Do DNS queries for both A and AAAA simultaneously (#302)
* Do DNS queries for both A and AAAA simultaneously We implement a basic version of RFC8305 (happy eyeballs) to establish the connection afterwards. * Try to connect to UDP sockets simultaneously
This commit is contained in:
parent
4f570dc48b
commit
90d378e768
3 changed files with 125 additions and 24 deletions
19
src/dns.rs
19
src/dns.rs
|
@ -1,7 +1,7 @@
|
||||||
use crate::tcp;
|
use crate::tcp;
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::{anyhow, Context};
|
||||||
use futures_util::{FutureExt, TryFutureExt};
|
use futures_util::{FutureExt, TryFutureExt};
|
||||||
use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
|
use hickory_resolver::config::{LookupIpStrategy, NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
|
||||||
use hickory_resolver::name_server::{GenericConnector, RuntimeProvider, TokioRuntimeProvider};
|
use hickory_resolver::name_server::{GenericConnector, RuntimeProvider, TokioRuntimeProvider};
|
||||||
use hickory_resolver::proto::iocompat::AsyncIoTokioAsStd;
|
use hickory_resolver::proto::iocompat::AsyncIoTokioAsStd;
|
||||||
use hickory_resolver::proto::TokioTime;
|
use hickory_resolver::proto::TokioTime;
|
||||||
|
@ -15,6 +15,22 @@ use std::time::Duration;
|
||||||
use tokio::net::{TcpStream, UdpSocket};
|
use tokio::net::{TcpStream, UdpSocket};
|
||||||
use url::{Host, Url};
|
use url::{Host, Url};
|
||||||
|
|
||||||
|
// Interweave v4 and v6 addresses as per RFC8305.
|
||||||
|
// The first address is v6 if we have any v6 addresses.
|
||||||
|
pub fn sort_socket_addrs(socket_addrs: &[SocketAddr]) -> impl Iterator<Item = &'_ SocketAddr> {
|
||||||
|
let mut pick_v6 = false;
|
||||||
|
let mut v6 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V6(_)));
|
||||||
|
let mut v4 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V4(_)));
|
||||||
|
std::iter::from_fn(move || {
|
||||||
|
pick_v6 = !pick_v6;
|
||||||
|
if pick_v6 {
|
||||||
|
v6.next().or_else(|| v4.next())
|
||||||
|
} else {
|
||||||
|
v4.next().or_else(|| v6.next())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub enum DnsResolver {
|
pub enum DnsResolver {
|
||||||
System,
|
System,
|
||||||
|
@ -112,6 +128,7 @@ impl DnsResolver {
|
||||||
|
|
||||||
let mut opts = ResolverOpts::default();
|
let mut opts = ResolverOpts::default();
|
||||||
opts.timeout = std::time::Duration::from_secs(1);
|
opts.timeout = std::time::Duration::from_secs(1);
|
||||||
|
opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
|
||||||
Ok(Self::TrustDns(AsyncResolver::new(
|
Ok(Self::TrustDns(AsyncResolver::new(
|
||||||
cfg,
|
cfg,
|
||||||
opts,
|
opts,
|
||||||
|
|
74
src/tcp.rs
74
src/tcp.rs
|
@ -1,7 +1,8 @@
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::{anyhow, Context};
|
||||||
use std::{io, vec};
|
use std::{io, vec};
|
||||||
|
use tokio::task::JoinSet;
|
||||||
|
|
||||||
use crate::dns::DnsResolver;
|
use crate::dns::{self, DnsResolver};
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
use bytes::BytesMut;
|
use bytes::BytesMut;
|
||||||
use log::warn;
|
use log::warn;
|
||||||
|
@ -11,7 +12,7 @@ use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||||
use tokio::time::timeout;
|
use tokio::time::{sleep, timeout};
|
||||||
use tokio_stream::wrappers::TcpListenerStream;
|
use tokio_stream::wrappers::TcpListenerStream;
|
||||||
use tracing::log::info;
|
use tracing::log::info;
|
||||||
use tracing::{debug, instrument};
|
use tracing::{debug, instrument};
|
||||||
|
@ -70,7 +71,9 @@ pub async fn connect(
|
||||||
|
|
||||||
let mut cnx = None;
|
let mut cnx = None;
|
||||||
let mut last_err = None;
|
let mut last_err = None;
|
||||||
for addr in socket_addrs {
|
let mut join_set = JoinSet::new();
|
||||||
|
|
||||||
|
for (ix, addr) in dns::sort_socket_addrs(&socket_addrs).copied().enumerate() {
|
||||||
debug!("Connecting to {}", addr);
|
debug!("Connecting to {}", addr);
|
||||||
|
|
||||||
let socket = match &addr {
|
let socket = match &addr {
|
||||||
|
@ -79,16 +82,45 @@ pub async fn connect(
|
||||||
};
|
};
|
||||||
|
|
||||||
configure_socket(socket2::SockRef::from(&socket), &so_mark)?;
|
configure_socket(socket2::SockRef::from(&socket), &so_mark)?;
|
||||||
match timeout(connect_timeout, socket.connect(addr)).await {
|
|
||||||
Ok(Ok(stream)) => {
|
// Spawn the connection attempt in the join set.
|
||||||
cnx = Some(stream);
|
// We include a delay of ix * 250 milliseconds, as per RFC8305.
|
||||||
break;
|
// See https://datatracker.ietf.org/doc/html/rfc8305#section-5
|
||||||
|
let fut = async move {
|
||||||
|
if ix > 0 {
|
||||||
|
sleep(Duration::from_millis(250 * ix as u64)).await;
|
||||||
}
|
}
|
||||||
Ok(Err(err)) => {
|
match timeout(connect_timeout, socket.connect(addr)).await {
|
||||||
warn!("Cannot connect to tcp endpoint {addr} reason {err}");
|
Ok(Ok(s)) => Ok(Ok(s)),
|
||||||
|
Ok(Err(e)) => Ok(Err((addr, e))),
|
||||||
|
Err(e) => Err((addr, e)),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
join_set.spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the next future that finishes in the join set, until we got one
|
||||||
|
// that resulted in a successful connection.
|
||||||
|
// If cnx is no longer None, we exit the loop, since this means that we got
|
||||||
|
// a successful connection.
|
||||||
|
while let (None, Some(res)) = (&cnx, join_set.join_next().await) {
|
||||||
|
match res? {
|
||||||
|
Ok(Ok(stream)) => {
|
||||||
|
// We've got a successful connection, so we can abort all other
|
||||||
|
// on-going attempts.
|
||||||
|
join_set.abort_all();
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Connected to tcp endpoint {}, aborted all other connection attempts",
|
||||||
|
stream.peer_addr()?
|
||||||
|
);
|
||||||
|
cnx = Some(stream);
|
||||||
|
}
|
||||||
|
Ok(Err((addr, err))) => {
|
||||||
|
debug!("Cannot connect to tcp endpoint {addr} reason {err}");
|
||||||
last_err = Some(err);
|
last_err = Some(err);
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err((addr, _)) => {
|
||||||
warn!(
|
warn!(
|
||||||
"Cannot connect to tcp endpoint {addr} due to timeout of {}s elapsed",
|
"Cannot connect to tcp endpoint {addr} due to timeout of {}s elapsed",
|
||||||
connect_timeout.as_secs()
|
connect_timeout.as_secs()
|
||||||
|
@ -195,7 +227,7 @@ pub async fn run_server(bind: SocketAddr, ip_transparent: bool) -> Result<TcpLis
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use futures_util::pin_mut;
|
use futures_util::pin_mut;
|
||||||
use std::net::SocketAddr;
|
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||||
use testcontainers::core::WaitFor;
|
use testcontainers::core::WaitFor;
|
||||||
use testcontainers::runners::AsyncRunner;
|
use testcontainers::runners::AsyncRunner;
|
||||||
use testcontainers::{ContainerAsync, Image, ImageArgs, RunnableImage};
|
use testcontainers::{ContainerAsync, Image, ImageArgs, RunnableImage};
|
||||||
|
@ -227,6 +259,26 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sort_socket_addrs() {
|
||||||
|
let addrs = [
|
||||||
|
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
|
||||||
|
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
|
||||||
|
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
|
||||||
|
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
|
||||||
|
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
|
||||||
|
];
|
||||||
|
let expected = [
|
||||||
|
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
|
||||||
|
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
|
||||||
|
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
|
||||||
|
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
|
||||||
|
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
|
||||||
|
];
|
||||||
|
let actual: Vec<_> = dns::sort_socket_addrs(&addrs).copied().collect();
|
||||||
|
assert_eq!(expected, *actual);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_proxy_connection() {
|
async fn test_proxy_connection() {
|
||||||
let server_addr: SocketAddr = "[::1]:1236".parse().unwrap();
|
let server_addr: SocketAddr = "[::1]:1236".parse().unwrap();
|
||||||
|
|
56
src/udp.rs
56
src/udp.rs
|
@ -8,6 +8,7 @@ use std::future::Future;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::io::{Error, ErrorKind};
|
use std::io::{Error, ErrorKind};
|
||||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||||
|
use tokio::task::JoinSet;
|
||||||
|
|
||||||
use log::warn;
|
use log::warn;
|
||||||
use std::pin::{pin, Pin};
|
use std::pin::{pin, Pin};
|
||||||
|
@ -18,9 +19,9 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
use tokio::net::UdpSocket;
|
use tokio::net::UdpSocket;
|
||||||
use tokio::sync::futures::Notified;
|
use tokio::sync::futures::Notified;
|
||||||
|
|
||||||
use crate::dns::DnsResolver;
|
use crate::dns::{self, DnsResolver};
|
||||||
use tokio::sync::Notify;
|
use tokio::sync::Notify;
|
||||||
use tokio::time::{timeout, Interval};
|
use tokio::time::{sleep, timeout, Interval};
|
||||||
use tracing::{debug, error, info};
|
use tracing::{debug, error, info};
|
||||||
use url::Host;
|
use url::Host;
|
||||||
|
|
||||||
|
@ -337,7 +338,9 @@ pub async fn connect(
|
||||||
|
|
||||||
let mut cnx = None;
|
let mut cnx = None;
|
||||||
let mut last_err = None;
|
let mut last_err = None;
|
||||||
for addr in socket_addrs {
|
let mut join_set = JoinSet::new();
|
||||||
|
|
||||||
|
for (ix, addr) in dns::sort_socket_addrs(&socket_addrs).copied().enumerate() {
|
||||||
debug!("connecting to {}", addr);
|
debug!("connecting to {}", addr);
|
||||||
|
|
||||||
let socket = match &addr {
|
let socket = match &addr {
|
||||||
|
@ -353,18 +356,47 @@ pub async fn connect(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match timeout(connect_timeout, socket.connect(addr)).await {
|
// Spawn the connection attempt in the join set.
|
||||||
Ok(Ok(_)) => {
|
// We include a delay of ix * 250 milliseconds, as per RFC8305.
|
||||||
cnx = Some(socket);
|
// See https://datatracker.ietf.org/doc/html/rfc8305#section-5
|
||||||
break;
|
let fut = async move {
|
||||||
|
if ix > 0 {
|
||||||
|
sleep(Duration::from_millis(250 * ix as u64)).await;
|
||||||
}
|
}
|
||||||
Ok(Err(err)) => {
|
|
||||||
debug!("Cannot connect udp socket to specified peer {addr} reason {err}");
|
match timeout(connect_timeout, socket.connect(addr)).await {
|
||||||
|
Ok(Ok(())) => Ok(Ok(socket)),
|
||||||
|
Ok(Err(e)) => Ok(Err((addr, e))),
|
||||||
|
Err(e) => Err((addr, e)),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
join_set.spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the next future that finishes in the join set, until we got one
|
||||||
|
// that resulted in a successful connection.
|
||||||
|
// If cnx is no longer None, we exit the loop, since this means that we got
|
||||||
|
// a successful connection.
|
||||||
|
while let (None, Some(res)) = (&cnx, join_set.join_next().await) {
|
||||||
|
match res? {
|
||||||
|
Ok(Ok(socket)) => {
|
||||||
|
// We've got a successful connection, so we can abort all other
|
||||||
|
// on-going attempts.
|
||||||
|
join_set.abort_all();
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Connected to udp endpoint {}, aborted all other connection attempts",
|
||||||
|
socket.peer_addr()?
|
||||||
|
);
|
||||||
|
cnx = Some(socket);
|
||||||
|
}
|
||||||
|
Ok(Err((addr, err))) => {
|
||||||
|
debug!("Cannot connect to udp endpoint {addr} reason {err}");
|
||||||
last_err = Some(err);
|
last_err = Some(err);
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err((addr, _)) => {
|
||||||
debug!(
|
warn!(
|
||||||
"Cannot connect udp socket to specified peer {addr} due to timeout of {}s elapsed",
|
"Cannot connect to udp endpoint {addr} due to timeout of {}s elapsed",
|
||||||
connect_timeout.as_secs()
|
connect_timeout.as_secs()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue