feat(linux): Add SO_MARK support for DNS request
This commit is contained in:
parent
7eaf7dc43e
commit
fb378d29d5
3 changed files with 102 additions and 9 deletions
101
src/dns.rs
101
src/dns.rs
|
@ -1,14 +1,21 @@
|
|||
use anyhow::anyhow;
|
||||
use futures_util::FutureExt;
|
||||
use hickory_resolver::config::{NameServerConfig, ResolverConfig, ResolverOpts};
|
||||
use hickory_resolver::TokioAsyncResolver;
|
||||
use hickory_resolver::name_server::{GenericConnector, RuntimeProvider, TokioRuntimeProvider};
|
||||
use hickory_resolver::proto::iocompat::AsyncIoTokioAsStd;
|
||||
use hickory_resolver::proto::TokioTime;
|
||||
use hickory_resolver::{AsyncResolver, TokioHandle};
|
||||
use log::warn;
|
||||
use std::future::Future;
|
||||
use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
use std::pin::Pin;
|
||||
use tokio::net::{TcpStream, UdpSocket};
|
||||
use url::{Host, Url};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum DnsResolver {
|
||||
System,
|
||||
TrustDns(TokioAsyncResolver),
|
||||
TrustDns(AsyncResolver<GenericConnector<TokioRuntimeProviderWithSoMark>>),
|
||||
}
|
||||
|
||||
impl DnsResolver {
|
||||
|
@ -29,7 +36,7 @@ impl DnsResolver {
|
|||
Ok(addrs)
|
||||
}
|
||||
|
||||
pub fn new_from_urls(resolvers: &[Url]) -> anyhow::Result<Self> {
|
||||
pub fn new_from_urls(resolvers: &[Url], so_mark: Option<u32>) -> anyhow::Result<Self> {
|
||||
if resolvers.is_empty() {
|
||||
// no dns resolver specified, fall-back to default one
|
||||
let Ok((cfg, mut opts)) = hickory_resolver::system_conf::read_system_conf() else {
|
||||
|
@ -45,7 +52,11 @@ impl DnsResolver {
|
|||
opts.cache_size = 1024;
|
||||
opts.num_concurrent_reqs = cfg.name_servers().len();
|
||||
}
|
||||
return Ok(Self::TrustDns(hickory_resolver::AsyncResolver::tokio(cfg, opts)));
|
||||
return Ok(Self::TrustDns(AsyncResolver::new(
|
||||
cfg,
|
||||
opts,
|
||||
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(so_mark)),
|
||||
)));
|
||||
};
|
||||
|
||||
// if one is specified as system, use the default one from libc
|
||||
|
@ -81,6 +92,86 @@ impl DnsResolver {
|
|||
|
||||
let mut opts = ResolverOpts::default();
|
||||
opts.timeout = std::time::Duration::from_secs(1);
|
||||
Ok(Self::TrustDns(hickory_resolver::AsyncResolver::tokio(cfg, opts)))
|
||||
Ok(Self::TrustDns(AsyncResolver::new(
|
||||
cfg,
|
||||
opts,
|
||||
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(so_mark)),
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TokioRuntimeProviderWithSoMark {
|
||||
runtime: TokioRuntimeProvider,
|
||||
#[cfg(target_os = "linux")]
|
||||
so_mark: Option<u32>,
|
||||
}
|
||||
|
||||
impl TokioRuntimeProviderWithSoMark {
|
||||
fn new(so_mark: Option<u32>) -> Self {
|
||||
Self {
|
||||
runtime: TokioRuntimeProvider::default(),
|
||||
#[cfg(target_os = "linux")]
|
||||
so_mark,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RuntimeProvider for TokioRuntimeProviderWithSoMark {
|
||||
type Handle = TokioHandle;
|
||||
type Timer = TokioTime;
|
||||
type Udp = UdpSocket;
|
||||
type Tcp = AsyncIoTokioAsStd<TcpStream>;
|
||||
|
||||
#[inline]
|
||||
fn create_handle(&self) -> Self::Handle {
|
||||
self.runtime.create_handle()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn connect_tcp(&self, server_addr: SocketAddr) -> Pin<Box<dyn Send + Future<Output = std::io::Result<Self::Tcp>>>> {
|
||||
let socket = TcpStream::connect(server_addr);
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
let socket = {
|
||||
use socket2::SockRef;
|
||||
|
||||
socket.map({
|
||||
let so_mark = self.so_mark;
|
||||
move |sock| {
|
||||
if let (Ok(sock), Some(so_mark)) = (&sock, so_mark) {
|
||||
SockRef::from(sock).set_mark(so_mark)?;
|
||||
}
|
||||
sock
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
Box::pin(socket.map(|s| s.map(AsyncIoTokioAsStd)))
|
||||
}
|
||||
|
||||
fn bind_udp(
|
||||
&self,
|
||||
local_addr: SocketAddr,
|
||||
_server_addr: SocketAddr,
|
||||
) -> Pin<Box<dyn Send + Future<Output = std::io::Result<Self::Udp>>>> {
|
||||
let socket = UdpSocket::bind(local_addr);
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
let socket = {
|
||||
use socket2::SockRef;
|
||||
|
||||
socket.map({
|
||||
let so_mark = self.so_mark;
|
||||
move |sock| {
|
||||
if let (Ok(sock), Some(so_mark)) = (&sock, so_mark) {
|
||||
SockRef::from(sock).set_mark(so_mark)?;
|
||||
}
|
||||
sock
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
Box::pin(socket)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -877,7 +877,8 @@ async fn main() {
|
|||
},
|
||||
cnx_pool: None,
|
||||
tls_reloader: None,
|
||||
dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver).expect("cannot create dns resolver"),
|
||||
dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver, args.socket_so_mark)
|
||||
.expect("cannot create dns resolver"),
|
||||
};
|
||||
|
||||
let tls_reloader =
|
||||
|
@ -1295,7 +1296,8 @@ async fn main() {
|
|||
timeout_connect: Duration::from_secs(10),
|
||||
websocket_mask_frame: args.websocket_mask_frame,
|
||||
tls: tls_config,
|
||||
dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver).expect("Cannot create DNS resolver"),
|
||||
dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver, args.socket_so_mark)
|
||||
.expect("Cannot create DNS resolver"),
|
||||
restriction_config: args.restrict_config,
|
||||
};
|
||||
|
||||
|
|
|
@ -36,8 +36,8 @@ pub mod server {
|
|||
use parking_lot::Mutex;
|
||||
use scopeguard::guard;
|
||||
use std::io::{Read, Write};
|
||||
use std::sync::{mpsc, Arc};
|
||||
use std::{io, process, thread};
|
||||
use std::sync::Arc;
|
||||
use std::{io, thread};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task::LocalSet;
|
||||
|
|
Loading…
Reference in a new issue