diff --git a/src/dns.rs b/src/dns.rs index e8728e0..718819f 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -15,10 +15,11 @@ use std::time::Duration; use tokio::net::{TcpStream, UdpSocket}; use url::{Host, Url}; -// Interweave v4 and v6 addresses as per RFC8305. +// Interleave v4 and v6 addresses as per RFC8305. // The first address is v6 if we have any v6 addresses. -pub fn sort_socket_addrs(socket_addrs: &[SocketAddr]) -> impl Iterator { - let mut pick_v6 = false; +#[inline] +fn sort_socket_addrs(socket_addrs: &[SocketAddr], prefer_ipv6: bool) -> impl Iterator { + let mut pick_v6 = !prefer_ipv6; let mut v6 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V6(_))); let mut v4 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V4(_))); std::iter::from_fn(move || { @@ -34,28 +35,39 @@ pub fn sort_socket_addrs(socket_addrs: &[SocketAddr]) -> impl Iterator>), + TrustDns { + resolver: AsyncResolver>, + prefer_ipv6: bool, + }, } impl DnsResolver { pub async fn lookup_host(&self, domain: &str, port: u16) -> anyhow::Result> { let addrs: Vec = match self { Self::System => tokio::net::lookup_host(format!("{}:{}", domain, port)).await?.collect(), - Self::TrustDns(dns_resolver) => dns_resolver - .lookup_ip(domain) - .await? - .into_iter() - .map(|ip| match ip { - IpAddr::V4(ip) => SocketAddr::V4(SocketAddrV4::new(ip, port)), - IpAddr::V6(ip) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)), - }) - .collect(), + Self::TrustDns { resolver, prefer_ipv6 } => { + let addrs: Vec<_> = resolver + .lookup_ip(domain) + .await? + .into_iter() + .map(|ip| match ip { + IpAddr::V4(ip) => SocketAddr::V4(SocketAddrV4::new(ip, port)), + IpAddr::V6(ip) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)), + }) + .collect(); + sort_socket_addrs(&addrs, *prefer_ipv6).copied().collect() + } }; Ok(addrs) } - pub fn new_from_urls(resolvers: &[Url], proxy: Option, so_mark: Option) -> anyhow::Result { + pub fn new_from_urls( + resolvers: &[Url], + proxy: Option, + so_mark: Option, + prefer_ipv6: bool, + ) -> anyhow::Result { if resolvers.is_empty() { // no dns resolver specified, fall-back to default one let Ok((cfg, mut opts)) = hickory_resolver::system_conf::read_system_conf() else { @@ -63,7 +75,8 @@ impl DnsResolver { return Ok(Self::System); }; - opts.timeout = std::time::Duration::from_secs(1); + opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6; + opts.timeout = Duration::from_secs(1); // Windows end-up with too many dns resolvers, which causes a performance issue // https://github.com/hickory-dns/hickory-dns/issues/1968 #[cfg(target_os = "windows")] @@ -71,11 +84,14 @@ impl DnsResolver { opts.cache_size = 1024; opts.num_concurrent_reqs = cfg.name_servers().len(); } - return Ok(Self::TrustDns(AsyncResolver::new( - cfg, - opts, - GenericConnector::new(TokioRuntimeProviderWithSoMark::new(proxy, so_mark)), - ))); + return Ok(Self::TrustDns { + resolver: AsyncResolver::new( + cfg, + opts, + GenericConnector::new(TokioRuntimeProviderWithSoMark::new(proxy, so_mark)), + ), + prefer_ipv6, + }); }; // if one is specified as system, use the default one from libc @@ -127,13 +143,16 @@ impl DnsResolver { } let mut opts = ResolverOpts::default(); - opts.timeout = std::time::Duration::from_secs(1); + opts.timeout = Duration::from_secs(1); opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6; - Ok(Self::TrustDns(AsyncResolver::new( - cfg, - opts, - GenericConnector::new(TokioRuntimeProviderWithSoMark::new(proxy, so_mark)), - ))) + Ok(Self::TrustDns { + resolver: AsyncResolver::new( + cfg, + opts, + GenericConnector::new(TokioRuntimeProviderWithSoMark::new(proxy, so_mark)), + ), + prefer_ipv6, + }) } } @@ -235,3 +254,29 @@ impl RuntimeProvider for TokioRuntimeProviderWithSoMark { Box::pin(socket) } } + +#[cfg(test)] +mod tests { + use crate::dns::sort_socket_addrs; + use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; + + #[test] + fn test_sort_socket_addrs() { + let addrs = [ + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)), + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)), + SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)), + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)), + SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)), + ]; + let expected = [ + SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)), + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)), + SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)), + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)), + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)), + ]; + let actual: Vec<_> = sort_socket_addrs(&addrs, true).copied().collect(); + assert_eq!(expected, *actual); + } +} diff --git a/src/main.rs b/src/main.rs index 2351b85..f7489cf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -258,6 +258,17 @@ struct Client { /// **WARN** On windows you may want to specify explicitly the DNS resolver to avoid excessive DNS queries #[arg(long, verbatim_doc_comment)] dns_resolver: Vec, + + /// Enable if you prefer the dns resolver to prioritize IPv4 over IPv6 + /// This is useful if you have a broken IPv6 connection, and want to avoid the delay of trying to connect to IPv6 + /// If you don't have any IPv6 this does not change anything. + #[arg( + long, + default_value = "false", + env = "WSTUNNEL_DNS_PREFER_IPV4", + verbatim_doc_comment + )] + dns_resolver_prefer_ipv4: bool, } #[derive(clap::Args, Debug)] @@ -295,6 +306,17 @@ struct Server { #[arg(long, verbatim_doc_comment)] dns_resolver: Vec, + /// Enable if you prefer the dns resolver to prioritize IPv4 over IPv6 + /// This is useful if you have a broken IPv6 connection, and want to avoid the delay of trying to connect to IPv6 + /// If you don't have any IPv6 this does not change anything. + #[arg( + long, + default_value = "false", + env = "WSTUNNEL_DNS_PREFER_IPV4", + verbatim_doc_comment + )] + dns_resolver_prefer_ipv4: bool, + /// Server will only accept connection from the specified tunnel information. /// Can be specified multiple time /// Example: --restrict-to "google.com:443" --restrict-to "localhost:22" @@ -755,8 +777,13 @@ impl WsClientConfig { #[tokio::main] async fn main() { let args = Wstunnel::parse(); - let socket = UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)).await.unwrap(); - socket.connect("[2001:4810:0:3::78]:443".parse::().unwrap()).await.unwrap(); + let socket = UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)) + .await + .unwrap(); + socket + .connect("[2001:4810:0:3::78]:443".parse::().unwrap()) + .await + .unwrap(); // Setup logging let mut env_filter = EnvFilter::builder().parse(&args.log_lvl).expect("Invalid log level"); @@ -902,8 +929,13 @@ async fn main() { websocket_mask_frame: args.websocket_mask_frame, cnx_pool: None, tls_reloader: None, - dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver, http_proxy.clone(), args.socket_so_mark) - .expect("cannot create dns resolver"), + dns_resolver: DnsResolver::new_from_urls( + &args.dns_resolver, + http_proxy.clone(), + args.socket_so_mark, + !args.dns_resolver_prefer_ipv4, + ) + .expect("cannot create dns resolver"), http_proxy, }; @@ -1324,8 +1356,13 @@ async fn main() { timeout_connect: Duration::from_secs(10), websocket_mask_frame: args.websocket_mask_frame, tls: tls_config, - dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver, None, args.socket_so_mark) - .expect("Cannot create DNS resolver"), + dns_resolver: DnsResolver::new_from_urls( + &args.dns_resolver, + None, + args.socket_so_mark, + !args.dns_resolver_prefer_ipv4, + ) + .expect("Cannot create DNS resolver"), restriction_config: args.restrict_config, }; diff --git a/src/tcp.rs b/src/tcp.rs index b9c8001..4cf7fe4 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -2,7 +2,7 @@ use anyhow::{anyhow, Context}; use std::{io, vec}; use tokio::task::JoinSet; -use crate::dns::{self, DnsResolver}; +use crate::dns::DnsResolver; use base64::Engine; use bytes::BytesMut; use log::warn; @@ -73,14 +73,11 @@ pub async fn connect( let mut last_err = None; let mut join_set = JoinSet::new(); - for (ix, addr) in dns::sort_socket_addrs(&socket_addrs).copied().enumerate() { - debug!("Connecting to {}", addr); - + for (ix, addr) in socket_addrs.into_iter().enumerate() { let socket = match &addr { SocketAddr::V4(_) => TcpSocket::new_v4()?, SocketAddr::V6(_) => TcpSocket::new_v6()?, }; - configure_socket(socket2::SockRef::from(&socket), &so_mark)?; // Spawn the connection attempt in the join set. @@ -90,6 +87,7 @@ pub async fn connect( if ix > 0 { sleep(Duration::from_millis(250 * ix as u64)).await; } + debug!("Connecting to {}", addr); match timeout(connect_timeout, socket.connect(addr)).await { Ok(Ok(s)) => Ok(Ok(s)), Ok(Err(e)) => Ok(Err((addr, e))), @@ -107,7 +105,7 @@ pub async fn connect( match res? { Ok(Ok(stream)) => { // We've got a successful connection, so we can abort all other - // on-going attempts. + // ongoing attempts. join_set.abort_all(); debug!( @@ -227,7 +225,7 @@ pub async fn run_server(bind: SocketAddr, ip_transparent: bool) -> Result = dns::sort_socket_addrs(&addrs).copied().collect(); - assert_eq!(expected, *actual); - } - #[tokio::test] async fn test_proxy_connection() { let server_addr: SocketAddr = "[::1]:1236".parse().unwrap(); diff --git a/src/udp.rs b/src/udp.rs index 6e6deee..ceb8ad6 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -19,7 +19,7 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::UdpSocket; use tokio::sync::futures::Notified; -use crate::dns::{self, DnsResolver}; +use crate::dns::DnsResolver; use tokio::sync::Notify; use tokio::time::{sleep, timeout, Interval}; use tracing::{debug, error, info}; @@ -340,9 +340,7 @@ pub async fn connect( let mut last_err = None; let mut join_set = JoinSet::new(); - for (ix, addr) in dns::sort_socket_addrs(&socket_addrs).copied().enumerate() { - debug!("connecting to {}", addr); - + for (ix, addr) in socket_addrs.into_iter().enumerate() { let socket = match &addr { SocketAddr::V4(_) => UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)).await, SocketAddr::V6(_) => UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)).await, @@ -364,6 +362,7 @@ pub async fn connect( sleep(Duration::from_millis(250 * ix as u64)).await; } + debug!("connecting to {}", addr); match timeout(connect_timeout, socket.connect(addr)).await { Ok(Ok(())) => Ok(Ok(socket)), Ok(Err(e)) => Ok(Err((addr, e))), @@ -381,7 +380,7 @@ pub async fn connect( match res? { Ok(Ok(socket)) => { // We've got a successful connection, so we can abort all other - // on-going attempts. + // ongoing attempts. join_set.abort_all(); debug!(