feat(linux): Add SO_MARK support for DNS request

This commit is contained in:
Σrebe - Romain GERARD 2024-06-13 22:53:28 +02:00
parent 7eaf7dc43e
commit fb378d29d5
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
3 changed files with 102 additions and 9 deletions

View file

@ -1,14 +1,21 @@
use anyhow::anyhow; use anyhow::anyhow;
use futures_util::FutureExt;
use hickory_resolver::config::{NameServerConfig, ResolverConfig, ResolverOpts}; use hickory_resolver::config::{NameServerConfig, ResolverConfig, ResolverOpts};
use hickory_resolver::TokioAsyncResolver; use hickory_resolver::name_server::{GenericConnector, RuntimeProvider, TokioRuntimeProvider};
use hickory_resolver::proto::iocompat::AsyncIoTokioAsStd;
use hickory_resolver::proto::TokioTime;
use hickory_resolver::{AsyncResolver, TokioHandle};
use log::warn; use log::warn;
use std::future::Future;
use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::pin::Pin;
use tokio::net::{TcpStream, UdpSocket};
use url::{Host, Url}; use url::{Host, Url};
#[derive(Clone)] #[derive(Clone)]
pub enum DnsResolver { pub enum DnsResolver {
System, System,
TrustDns(TokioAsyncResolver), TrustDns(AsyncResolver<GenericConnector<TokioRuntimeProviderWithSoMark>>),
} }
impl DnsResolver { impl DnsResolver {
@ -29,7 +36,7 @@ impl DnsResolver {
Ok(addrs) Ok(addrs)
} }
pub fn new_from_urls(resolvers: &[Url]) -> anyhow::Result<Self> { pub fn new_from_urls(resolvers: &[Url], so_mark: Option<u32>) -> anyhow::Result<Self> {
if resolvers.is_empty() { if resolvers.is_empty() {
// no dns resolver specified, fall-back to default one // no dns resolver specified, fall-back to default one
let Ok((cfg, mut opts)) = hickory_resolver::system_conf::read_system_conf() else { let Ok((cfg, mut opts)) = hickory_resolver::system_conf::read_system_conf() else {
@ -45,7 +52,11 @@ impl DnsResolver {
opts.cache_size = 1024; opts.cache_size = 1024;
opts.num_concurrent_reqs = cfg.name_servers().len(); opts.num_concurrent_reqs = cfg.name_servers().len();
} }
return Ok(Self::TrustDns(hickory_resolver::AsyncResolver::tokio(cfg, opts))); return Ok(Self::TrustDns(AsyncResolver::new(
cfg,
opts,
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(so_mark)),
)));
}; };
// if one is specified as system, use the default one from libc // if one is specified as system, use the default one from libc
@ -81,6 +92,86 @@ impl DnsResolver {
let mut opts = ResolverOpts::default(); let mut opts = ResolverOpts::default();
opts.timeout = std::time::Duration::from_secs(1); opts.timeout = std::time::Duration::from_secs(1);
Ok(Self::TrustDns(hickory_resolver::AsyncResolver::tokio(cfg, opts))) Ok(Self::TrustDns(AsyncResolver::new(
cfg,
opts,
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(so_mark)),
)))
}
}
#[derive(Clone)]
pub struct TokioRuntimeProviderWithSoMark {
runtime: TokioRuntimeProvider,
#[cfg(target_os = "linux")]
so_mark: Option<u32>,
}
impl TokioRuntimeProviderWithSoMark {
fn new(so_mark: Option<u32>) -> Self {
Self {
runtime: TokioRuntimeProvider::default(),
#[cfg(target_os = "linux")]
so_mark,
}
}
}
impl RuntimeProvider for TokioRuntimeProviderWithSoMark {
type Handle = TokioHandle;
type Timer = TokioTime;
type Udp = UdpSocket;
type Tcp = AsyncIoTokioAsStd<TcpStream>;
#[inline]
fn create_handle(&self) -> Self::Handle {
self.runtime.create_handle()
}
#[inline]
fn connect_tcp(&self, server_addr: SocketAddr) -> Pin<Box<dyn Send + Future<Output = std::io::Result<Self::Tcp>>>> {
let socket = TcpStream::connect(server_addr);
#[cfg(target_os = "linux")]
let socket = {
use socket2::SockRef;
socket.map({
let so_mark = self.so_mark;
move |sock| {
if let (Ok(sock), Some(so_mark)) = (&sock, so_mark) {
SockRef::from(sock).set_mark(so_mark)?;
}
sock
}
})
};
Box::pin(socket.map(|s| s.map(AsyncIoTokioAsStd)))
}
fn bind_udp(
&self,
local_addr: SocketAddr,
_server_addr: SocketAddr,
) -> Pin<Box<dyn Send + Future<Output = std::io::Result<Self::Udp>>>> {
let socket = UdpSocket::bind(local_addr);
#[cfg(target_os = "linux")]
let socket = {
use socket2::SockRef;
socket.map({
let so_mark = self.so_mark;
move |sock| {
if let (Ok(sock), Some(so_mark)) = (&sock, so_mark) {
SockRef::from(sock).set_mark(so_mark)?;
}
sock
}
})
};
Box::pin(socket)
} }
} }

View file

@ -877,7 +877,8 @@ async fn main() {
}, },
cnx_pool: None, cnx_pool: None,
tls_reloader: None, tls_reloader: None,
dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver).expect("cannot create dns resolver"), dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver, args.socket_so_mark)
.expect("cannot create dns resolver"),
}; };
let tls_reloader = let tls_reloader =
@ -1295,7 +1296,8 @@ async fn main() {
timeout_connect: Duration::from_secs(10), timeout_connect: Duration::from_secs(10),
websocket_mask_frame: args.websocket_mask_frame, websocket_mask_frame: args.websocket_mask_frame,
tls: tls_config, tls: tls_config,
dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver).expect("Cannot create DNS resolver"), dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver, args.socket_so_mark)
.expect("Cannot create DNS resolver"),
restriction_config: args.restrict_config, restriction_config: args.restrict_config,
}; };

View file

@ -36,8 +36,8 @@ pub mod server {
use parking_lot::Mutex; use parking_lot::Mutex;
use scopeguard::guard; use scopeguard::guard;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::sync::{mpsc, Arc}; use std::sync::Arc;
use std::{io, process, thread}; use std::{io, thread};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tokio::task::LocalSet; use tokio::task::LocalSet;