diff --git a/src/dns.rs b/src/dns.rs index a0e3d3c..08d8a87 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -1,14 +1,21 @@ use anyhow::anyhow; +use futures_util::FutureExt; use hickory_resolver::config::{NameServerConfig, ResolverConfig, ResolverOpts}; -use hickory_resolver::TokioAsyncResolver; +use hickory_resolver::name_server::{GenericConnector, RuntimeProvider, TokioRuntimeProvider}; +use hickory_resolver::proto::iocompat::AsyncIoTokioAsStd; +use hickory_resolver::proto::TokioTime; +use hickory_resolver::{AsyncResolver, TokioHandle}; use log::warn; +use std::future::Future; use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::pin::Pin; +use tokio::net::{TcpStream, UdpSocket}; use url::{Host, Url}; #[derive(Clone)] pub enum DnsResolver { System, - TrustDns(TokioAsyncResolver), + TrustDns(AsyncResolver>), } impl DnsResolver { @@ -29,7 +36,7 @@ impl DnsResolver { Ok(addrs) } - pub fn new_from_urls(resolvers: &[Url]) -> anyhow::Result { + pub fn new_from_urls(resolvers: &[Url], so_mark: Option) -> anyhow::Result { if resolvers.is_empty() { // no dns resolver specified, fall-back to default one let Ok((cfg, mut opts)) = hickory_resolver::system_conf::read_system_conf() else { @@ -45,7 +52,11 @@ impl DnsResolver { opts.cache_size = 1024; opts.num_concurrent_reqs = cfg.name_servers().len(); } - return Ok(Self::TrustDns(hickory_resolver::AsyncResolver::tokio(cfg, opts))); + return Ok(Self::TrustDns(AsyncResolver::new( + cfg, + opts, + GenericConnector::new(TokioRuntimeProviderWithSoMark::new(so_mark)), + ))); }; // if one is specified as system, use the default one from libc @@ -81,6 +92,86 @@ impl DnsResolver { let mut opts = ResolverOpts::default(); opts.timeout = std::time::Duration::from_secs(1); - Ok(Self::TrustDns(hickory_resolver::AsyncResolver::tokio(cfg, opts))) + Ok(Self::TrustDns(AsyncResolver::new( + cfg, + opts, + GenericConnector::new(TokioRuntimeProviderWithSoMark::new(so_mark)), + ))) + } +} + +#[derive(Clone)] +pub struct TokioRuntimeProviderWithSoMark { + runtime: TokioRuntimeProvider, + #[cfg(target_os = "linux")] + so_mark: Option, +} + +impl TokioRuntimeProviderWithSoMark { + fn new(so_mark: Option) -> Self { + Self { + runtime: TokioRuntimeProvider::default(), + #[cfg(target_os = "linux")] + so_mark, + } + } +} + +impl RuntimeProvider for TokioRuntimeProviderWithSoMark { + type Handle = TokioHandle; + type Timer = TokioTime; + type Udp = UdpSocket; + type Tcp = AsyncIoTokioAsStd; + + #[inline] + fn create_handle(&self) -> Self::Handle { + self.runtime.create_handle() + } + + #[inline] + fn connect_tcp(&self, server_addr: SocketAddr) -> Pin>>> { + let socket = TcpStream::connect(server_addr); + + #[cfg(target_os = "linux")] + let socket = { + use socket2::SockRef; + + socket.map({ + let so_mark = self.so_mark; + move |sock| { + if let (Ok(sock), Some(so_mark)) = (&sock, so_mark) { + SockRef::from(sock).set_mark(so_mark)?; + } + sock + } + }) + }; + + Box::pin(socket.map(|s| s.map(AsyncIoTokioAsStd))) + } + + fn bind_udp( + &self, + local_addr: SocketAddr, + _server_addr: SocketAddr, + ) -> Pin>>> { + let socket = UdpSocket::bind(local_addr); + + #[cfg(target_os = "linux")] + let socket = { + use socket2::SockRef; + + socket.map({ + let so_mark = self.so_mark; + move |sock| { + if let (Ok(sock), Some(so_mark)) = (&sock, so_mark) { + SockRef::from(sock).set_mark(so_mark)?; + } + sock + } + }) + }; + + Box::pin(socket) } } diff --git a/src/main.rs b/src/main.rs index bc40ed0..541eb37 100644 --- a/src/main.rs +++ b/src/main.rs @@ -877,7 +877,8 @@ async fn main() { }, cnx_pool: None, tls_reloader: None, - dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver).expect("cannot create dns resolver"), + dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver, args.socket_so_mark) + .expect("cannot create dns resolver"), }; let tls_reloader = @@ -1295,7 +1296,8 @@ async fn main() { timeout_connect: Duration::from_secs(10), websocket_mask_frame: args.websocket_mask_frame, tls: tls_config, - dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver).expect("Cannot create DNS resolver"), + dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver, args.socket_so_mark) + .expect("Cannot create DNS resolver"), restriction_config: args.restrict_config, }; diff --git a/src/stdio.rs b/src/stdio.rs index add7f49..aa44435 100644 --- a/src/stdio.rs +++ b/src/stdio.rs @@ -36,8 +36,8 @@ pub mod server { use parking_lot::Mutex; use scopeguard::guard; use std::io::{Read, Write}; - use std::sync::{mpsc, Arc}; - use std::{io, process, thread}; + use std::sync::Arc; + use std::{io, thread}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::sync::oneshot; use tokio::task::LocalSet;