From 97251804dd26e6ebb142a3f3fe5e804b1de95324 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Sat, 20 Jul 2024 15:55:08 +0200 Subject: [PATCH] feat(udp): set SO_MARK for udp cnx --- src/dns.rs | 2 +- src/main.rs | 12 +++++++++--- src/tunnel/server.rs | 1 + src/udp.rs | 9 +++++++++ 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/dns.rs b/src/dns.rs index c9da957..760b404 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -1,4 +1,4 @@ -use crate::{tcp}; +use crate::tcp; use anyhow::{anyhow, Context}; use futures_util::{FutureExt, TryFutureExt}; use hickory_resolver::config::{LookupIpStrategy, NameServerConfig, Protocol, ResolverConfig, ResolverOpts}; diff --git a/src/main.rs b/src/main.rs index 59226a2..1474a1a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -991,8 +991,14 @@ async fn main() { port, }; let connect_to_dest = |_| async { - udp::connect(&tunnel.remote.0, tunnel.remote.1, cfg.timeout_connect, &cfg.dns_resolver) - .await + udp::connect( + &tunnel.remote.0, + tunnel.remote.1, + cfg.timeout_connect, + cfg.socket_so_mark, + &cfg.dns_resolver, + ) + .await }; if let Err(err) = @@ -1033,7 +1039,7 @@ async fn main() { .map(|s| Box::new(s) as Box) } LocalProtocol::Udp { .. } => { - udp::connect(&remote.host, remote.port, timeout, dns_resolver) + udp::connect(&remote.host, remote.port, timeout, so_mark, dns_resolver) .await .map(|s| Box::new(s) as Box) } diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 9fde71c..f0f6004 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -60,6 +60,7 @@ async fn run_tunnel( &remote.host, remote.port, timeout.unwrap_or(Duration::from_secs(10)), + server_config.socket_so_mark, &server_config.dns_resolver, ) .await?; diff --git a/src/udp.rs b/src/udp.rs index ceb8ad6..f04c4c1 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -11,6 +11,7 @@ use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use tokio::task::JoinSet; use log::warn; +use socket2::SockRef; use std::pin::{pin, Pin}; use std::sync::{Arc, Weak}; use std::task::{ready, Poll}; @@ -323,6 +324,7 @@ pub async fn connect( host: &Host, port: u16, connect_timeout: Duration, + so_mark: Option, dns_resolver: &DnsResolver, ) -> anyhow::Result { info!("Opening UDP connection to {}:{}", host, port); @@ -354,6 +356,13 @@ pub async fn connect( } }; + #[cfg(target_os = "linux")] + if let Some(so_mark) = so_mark { + SockRef::from(&socket) + .set_mark(so_mark) + .with_context(|| format!("cannot set SO_MARK on socket: {:?}", io::Error::last_os_error()))?; + } + // Spawn the connection attempt in the join set. // We include a delay of ix * 250 milliseconds, as per RFC8305. // See https://datatracker.ietf.org/doc/html/rfc8305#section-5