From b705484d9f5183772ce4396d1fc5d00529a4cdfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Tue, 2 Jan 2024 19:32:47 +0100 Subject: [PATCH] Dont use libc dns resolver by default + By default libc dns resolution is blocking. Which force async runtime to spawn blocking thread for it which lead to heavy memory usage --- README.md | 2 ++ justfile | 2 +- src/main.rs | 82 ++++++++++++++++++++++++++++++----------------- src/tcp.rs | 3 +- src/tunnel/mod.rs | 5 ++- 5 files changed, 59 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index df0aa30..6f38043 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,8 @@ Options: dns://1.1.1.1 for using udp dns+https://1.1.1.1 for using dns over HTTPS dns+tls://8.8.8.8 for using dns over TLS + To use libc resolver, use + system://0.0.0.0 -r, --restrict-http-upgrade-path-prefix Server will only accept connection from if this specific path prefix is used during websocket upgrade. Useful if you specify in the client a custom path prefix and you want the server to only allow this one. diff --git a/justfile b/justfile index 44c339e..2d184f3 100644 --- a/justfile +++ b/justfile @@ -10,7 +10,7 @@ make_release $VERSION $FORCE="": git add Cargo.* git commit -m 'Bump version v'$VERSION git tag $FORCE v$VERSION -m 'version v'$VERSION - git push + git push $FORCE git push $FORCE origin v$VERSION @just docker_release v$VERSION diff --git a/src/main.rs b/src/main.rs index 022734b..76929ce 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,7 @@ use futures_util::{stream, TryStreamExt}; use hickory_resolver::config::{NameServerConfig, ResolverConfig, ResolverOpts}; use hyper::header::HOST; use hyper::http::{HeaderName, HeaderValue}; +use log::{debug, warn}; use parking_lot::Mutex; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; @@ -202,6 +203,8 @@ struct Server { /// dns://1.1.1.1 for using udp /// dns+https://1.1.1.1 for using dns over HTTPS /// dns+tls://8.8.8.8 for using dns over TLS + /// To use libc resolver, use + /// system://0.0.0.0 #[arg(long, verbatim_doc_comment)] dns_resolver: Option>, @@ -519,7 +522,7 @@ impl Debug for WsServerConfig { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct WsClientConfig { pub remote_addr: (Host, u16), pub socket_so_mark: Option, @@ -533,6 +536,7 @@ pub struct WsClientConfig { pub websocket_mask_frame: bool, pub http_proxy: Option, cnx_pool: Option>, + pub dns_resolver: DnsResolver, } impl WsClientConfig { @@ -626,6 +630,12 @@ async fn main() { websocket_mask_frame: args.websocket_mask_frame, http_proxy: args.http_proxy, cnx_pool: None, + dns_resolver: if let Ok(resolver) = hickory_resolver::AsyncResolver::tokio_from_system_conf() { + DnsResolver::TrustDns(resolver) + } else { + debug!("Fall-backing to system dns resolver"); + DnsResolver::System + }, }; let pool = bb8::Pool::builder() @@ -654,7 +664,7 @@ async fn main() { remote.1, cfg.socket_so_mark, cfg.timeout_connect, - &DnsResolver::System, + &cfg.dns_resolver, ) .await }; @@ -673,7 +683,7 @@ async fn main() { let cfg = client_config.clone(); let remote = tunnel.remote.clone(); let connect_to_dest = |_| async { - udp::connect(&remote.0, remote.1, cfg.timeout_connect, &DnsResolver::System).await + udp::connect(&remote.0, remote.1, cfg.timeout_connect, &cfg.dns_resolver).await }; if let Err(err) = @@ -690,9 +700,8 @@ async fn main() { let connect_to_dest = |remote: (Host, u16)| { let so_mark = cfg.socket_so_mark; let timeout = cfg.timeout_connect; - async move { - tcp::connect(&remote.0, remote.1, so_mark, timeout, &DnsResolver::System).await - } + let dns_resolver = &cfg.dns_resolver; + async move { tcp::connect(&remote.0, remote.1, so_mark, timeout, dns_resolver).await } }; if let Err(err) = @@ -841,32 +850,45 @@ async fn main() { }; let dns_resolver = match args.dns_resolver { - None => DnsResolver::System, - Some(resolvers) => { - let mut cfg = ResolverConfig::new(); - for resolver in resolvers { - let (protocol, port) = match resolver.scheme() { - "dns" => (hickory_resolver::config::Protocol::Udp, resolver.port().unwrap_or(53)), - "dns+https" => (hickory_resolver::config::Protocol::Https, resolver.port().unwrap_or(853)), - "dns+tls" => (hickory_resolver::config::Protocol::Tls, resolver.port().unwrap_or(12)), - _ => panic!("invalid protocol for dns resolver"), - }; - let sock = match resolver.host().unwrap() { - Host::Domain(host) => match Host::parse(host) { - Ok(Host::Ipv4(ip)) => SocketAddr::V4(SocketAddrV4::new(ip, port)), - Ok(Host::Ipv6(ip)) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)), - Ok(Host::Domain(_)) | Err(_) => { - panic!("Dns resolver must be an ip address, got {}", host) - } - }, - Host::Ipv4(ip) => SocketAddr::V4(SocketAddrV4::new(ip, port)), - Host::Ipv6(ip) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)), - }; - cfg.add_name_server(NameServerConfig::new(sock, protocol)) + None => { + if let Ok(resolver) = hickory_resolver::AsyncResolver::tokio_from_system_conf() { + DnsResolver::TrustDns(resolver) + } else { + warn!("Fall-backing to system dns resolver. You should consider specifying a dns resolver. To avoid performance issue"); + DnsResolver::System } + } + Some(resolvers) => { + if resolvers.iter().any(|r| r.scheme() == "system") { + DnsResolver::System + } else { + let mut cfg = ResolverConfig::new(); + for resolver in resolvers { + let (protocol, port) = match resolver.scheme() { + "dns" => (hickory_resolver::config::Protocol::Udp, resolver.port().unwrap_or(53)), + "dns+https" => { + (hickory_resolver::config::Protocol::Https, resolver.port().unwrap_or(443)) + } + "dns+tls" => (hickory_resolver::config::Protocol::Tls, resolver.port().unwrap_or(853)), + _ => panic!("invalid protocol for dns resolver"), + }; + let sock = match resolver.host().unwrap() { + Host::Domain(host) => match Host::parse(host) { + Ok(Host::Ipv4(ip)) => SocketAddr::V4(SocketAddrV4::new(ip, port)), + Ok(Host::Ipv6(ip)) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)), + Ok(Host::Domain(_)) | Err(_) => { + panic!("Dns resolver must be an ip address, got {}", host) + } + }, + Host::Ipv4(ip) => SocketAddr::V4(SocketAddrV4::new(ip, port)), + Host::Ipv6(ip) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)), + }; + cfg.add_name_server(NameServerConfig::new(sock, protocol)) + } - let opts = ResolverOpts::default(); - DnsResolver::TrustDns(hickory_resolver::AsyncResolver::tokio(cfg, opts)) + let opts = ResolverOpts::default(); + DnsResolver::TrustDns(hickory_resolver::AsyncResolver::tokio(cfg, opts)) + } } }; let server_config = WsServerConfig { diff --git a/src/tcp.rs b/src/tcp.rs index c6f8a19..d064374 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -102,11 +102,12 @@ pub async fn connect_with_http_proxy( port: u16, so_mark: Option, connect_timeout: Duration, + dns_resolver: &DnsResolver, ) -> Result { let proxy_host = proxy.host().context("Cannot parse proxy host")?.to_owned(); let proxy_port = proxy.port_or_known_default().unwrap_or(80); - let mut socket = connect(&proxy_host, proxy_port, so_mark, connect_timeout, &DnsResolver::System).await?; + let mut socket = connect(&proxy_host, proxy_port, so_mark, connect_timeout, dns_resolver).await?; info!("Connected to http proxy {}:{}", proxy_host, proxy_port); let authorization = if let Some((user, password)) = proxy.password().map(|p| (proxy.username(), p)) { diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 5aeb704..bdcb4cf 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -3,7 +3,6 @@ mod io; pub mod server; mod tls_reloader; -use crate::dns::DnsResolver; use crate::{tcp, tls, LocalProtocol, LocalToRemote, WsClientConfig}; use async_trait::async_trait; use bb8::ManageConnection; @@ -127,9 +126,9 @@ impl ManageConnection for WsClientConfig { let timeout = self.timeout_connect; let tcp_stream = if let Some(http_proxy) = &self.http_proxy { - tcp::connect_with_http_proxy(http_proxy, host, *port, so_mark, timeout).await? + tcp::connect_with_http_proxy(http_proxy, host, *port, so_mark, timeout, &self.dns_resolver).await? } else { - tcp::connect(host, *port, so_mark, timeout, &DnsResolver::System).await? + tcp::connect(host, *port, so_mark, timeout, &self.dns_resolver).await? }; match &self.tls {