feat(linux): Add SO_MARK support for DNS request
This commit is contained in:
parent
7eaf7dc43e
commit
fb378d29d5
3 changed files with 102 additions and 9 deletions
101
src/dns.rs
101
src/dns.rs
|
@ -1,14 +1,21 @@
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
|
use futures_util::FutureExt;
|
||||||
use hickory_resolver::config::{NameServerConfig, ResolverConfig, ResolverOpts};
|
use hickory_resolver::config::{NameServerConfig, ResolverConfig, ResolverOpts};
|
||||||
use hickory_resolver::TokioAsyncResolver;
|
use hickory_resolver::name_server::{GenericConnector, RuntimeProvider, TokioRuntimeProvider};
|
||||||
|
use hickory_resolver::proto::iocompat::AsyncIoTokioAsStd;
|
||||||
|
use hickory_resolver::proto::TokioTime;
|
||||||
|
use hickory_resolver::{AsyncResolver, TokioHandle};
|
||||||
use log::warn;
|
use log::warn;
|
||||||
|
use std::future::Future;
|
||||||
use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||||
|
use std::pin::Pin;
|
||||||
|
use tokio::net::{TcpStream, UdpSocket};
|
||||||
use url::{Host, Url};
|
use url::{Host, Url};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub enum DnsResolver {
|
pub enum DnsResolver {
|
||||||
System,
|
System,
|
||||||
TrustDns(TokioAsyncResolver),
|
TrustDns(AsyncResolver<GenericConnector<TokioRuntimeProviderWithSoMark>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DnsResolver {
|
impl DnsResolver {
|
||||||
|
@ -29,7 +36,7 @@ impl DnsResolver {
|
||||||
Ok(addrs)
|
Ok(addrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_from_urls(resolvers: &[Url]) -> anyhow::Result<Self> {
|
pub fn new_from_urls(resolvers: &[Url], so_mark: Option<u32>) -> anyhow::Result<Self> {
|
||||||
if resolvers.is_empty() {
|
if resolvers.is_empty() {
|
||||||
// no dns resolver specified, fall-back to default one
|
// no dns resolver specified, fall-back to default one
|
||||||
let Ok((cfg, mut opts)) = hickory_resolver::system_conf::read_system_conf() else {
|
let Ok((cfg, mut opts)) = hickory_resolver::system_conf::read_system_conf() else {
|
||||||
|
@ -45,7 +52,11 @@ impl DnsResolver {
|
||||||
opts.cache_size = 1024;
|
opts.cache_size = 1024;
|
||||||
opts.num_concurrent_reqs = cfg.name_servers().len();
|
opts.num_concurrent_reqs = cfg.name_servers().len();
|
||||||
}
|
}
|
||||||
return Ok(Self::TrustDns(hickory_resolver::AsyncResolver::tokio(cfg, opts)));
|
return Ok(Self::TrustDns(AsyncResolver::new(
|
||||||
|
cfg,
|
||||||
|
opts,
|
||||||
|
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(so_mark)),
|
||||||
|
)));
|
||||||
};
|
};
|
||||||
|
|
||||||
// if one is specified as system, use the default one from libc
|
// if one is specified as system, use the default one from libc
|
||||||
|
@ -81,6 +92,86 @@ impl DnsResolver {
|
||||||
|
|
||||||
let mut opts = ResolverOpts::default();
|
let mut opts = ResolverOpts::default();
|
||||||
opts.timeout = std::time::Duration::from_secs(1);
|
opts.timeout = std::time::Duration::from_secs(1);
|
||||||
Ok(Self::TrustDns(hickory_resolver::AsyncResolver::tokio(cfg, opts)))
|
Ok(Self::TrustDns(AsyncResolver::new(
|
||||||
|
cfg,
|
||||||
|
opts,
|
||||||
|
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(so_mark)),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct TokioRuntimeProviderWithSoMark {
|
||||||
|
runtime: TokioRuntimeProvider,
|
||||||
|
#[cfg(target_os = "linux")]
|
||||||
|
so_mark: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TokioRuntimeProviderWithSoMark {
|
||||||
|
fn new(so_mark: Option<u32>) -> Self {
|
||||||
|
Self {
|
||||||
|
runtime: TokioRuntimeProvider::default(),
|
||||||
|
#[cfg(target_os = "linux")]
|
||||||
|
so_mark,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RuntimeProvider for TokioRuntimeProviderWithSoMark {
|
||||||
|
type Handle = TokioHandle;
|
||||||
|
type Timer = TokioTime;
|
||||||
|
type Udp = UdpSocket;
|
||||||
|
type Tcp = AsyncIoTokioAsStd<TcpStream>;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn create_handle(&self) -> Self::Handle {
|
||||||
|
self.runtime.create_handle()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn connect_tcp(&self, server_addr: SocketAddr) -> Pin<Box<dyn Send + Future<Output = std::io::Result<Self::Tcp>>>> {
|
||||||
|
let socket = TcpStream::connect(server_addr);
|
||||||
|
|
||||||
|
#[cfg(target_os = "linux")]
|
||||||
|
let socket = {
|
||||||
|
use socket2::SockRef;
|
||||||
|
|
||||||
|
socket.map({
|
||||||
|
let so_mark = self.so_mark;
|
||||||
|
move |sock| {
|
||||||
|
if let (Ok(sock), Some(so_mark)) = (&sock, so_mark) {
|
||||||
|
SockRef::from(sock).set_mark(so_mark)?;
|
||||||
|
}
|
||||||
|
sock
|
||||||
|
}
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
Box::pin(socket.map(|s| s.map(AsyncIoTokioAsStd)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bind_udp(
|
||||||
|
&self,
|
||||||
|
local_addr: SocketAddr,
|
||||||
|
_server_addr: SocketAddr,
|
||||||
|
) -> Pin<Box<dyn Send + Future<Output = std::io::Result<Self::Udp>>>> {
|
||||||
|
let socket = UdpSocket::bind(local_addr);
|
||||||
|
|
||||||
|
#[cfg(target_os = "linux")]
|
||||||
|
let socket = {
|
||||||
|
use socket2::SockRef;
|
||||||
|
|
||||||
|
socket.map({
|
||||||
|
let so_mark = self.so_mark;
|
||||||
|
move |sock| {
|
||||||
|
if let (Ok(sock), Some(so_mark)) = (&sock, so_mark) {
|
||||||
|
SockRef::from(sock).set_mark(so_mark)?;
|
||||||
|
}
|
||||||
|
sock
|
||||||
|
}
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
Box::pin(socket)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -877,7 +877,8 @@ async fn main() {
|
||||||
},
|
},
|
||||||
cnx_pool: None,
|
cnx_pool: None,
|
||||||
tls_reloader: None,
|
tls_reloader: None,
|
||||||
dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver).expect("cannot create dns resolver"),
|
dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver, args.socket_so_mark)
|
||||||
|
.expect("cannot create dns resolver"),
|
||||||
};
|
};
|
||||||
|
|
||||||
let tls_reloader =
|
let tls_reloader =
|
||||||
|
@ -1295,7 +1296,8 @@ async fn main() {
|
||||||
timeout_connect: Duration::from_secs(10),
|
timeout_connect: Duration::from_secs(10),
|
||||||
websocket_mask_frame: args.websocket_mask_frame,
|
websocket_mask_frame: args.websocket_mask_frame,
|
||||||
tls: tls_config,
|
tls: tls_config,
|
||||||
dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver).expect("Cannot create DNS resolver"),
|
dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver, args.socket_so_mark)
|
||||||
|
.expect("Cannot create DNS resolver"),
|
||||||
restriction_config: args.restrict_config,
|
restriction_config: args.restrict_config,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -36,8 +36,8 @@ pub mod server {
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use scopeguard::guard;
|
use scopeguard::guard;
|
||||||
use std::io::{Read, Write};
|
use std::io::{Read, Write};
|
||||||
use std::sync::{mpsc, Arc};
|
use std::sync::Arc;
|
||||||
use std::{io, process, thread};
|
use std::{io, thread};
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
use tokio::task::LocalSet;
|
use tokio::task::LocalSet;
|
||||||
|
|
Loading…
Reference in a new issue