Add support for custom dns resolver on server

This commit is contained in:
Σrebe - Romain GERARD 2023-12-19 22:41:11 +01:00
parent d1de41646f
commit d456c67f19
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
8 changed files with 354 additions and 26 deletions

32
src/dns.rs Normal file
View file

@ -0,0 +1,32 @@
use anyhow::Context;
use hickory_resolver::TokioAsyncResolver;
use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6};
#[derive(Clone)]
pub enum DnsResolver {
System,
TrustDns(TokioAsyncResolver),
}
impl DnsResolver {
pub async fn lookup_host(&self, domain: &str, port: u16) -> anyhow::Result<Vec<SocketAddr>> {
let addrs: Vec<SocketAddr> = match self {
DnsResolver::System => tokio::net::lookup_host(format!("{}:{}", domain, port))
.await
.with_context(|| format!("cannot resolve domain: {}", domain))?
.collect(),
DnsResolver::TrustDns(dns_resolver) => dns_resolver
.lookup_ip(domain)
.await
.with_context(|| format!("cannot resolve domain: {}", domain))?
.into_iter()
.map(|ip| match ip {
IpAddr::V4(ip) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
IpAddr::V6(ip) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
})
.collect(),
};
Ok(addrs)
}
}

View file

@ -1,3 +1,4 @@
mod dns;
mod embedded_certificate;
mod socks5;
mod stdio;
@ -9,13 +10,14 @@ mod udp;
use base64::Engine;
use clap::Parser;
use futures_util::{stream, TryStreamExt};
use hickory_resolver::config::{NameServerConfig, ResolverConfig, ResolverOpts};
use hyper::header::HOST;
use hyper::http::{HeaderName, HeaderValue};
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use std::fmt::{Debug, Formatter};
use std::io::ErrorKind;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
@ -27,6 +29,7 @@ use tokio_rustls::rustls::{Certificate, PrivateKey, ServerName};
use tracing::{error, info, Level};
use crate::dns::DnsResolver;
use crate::tunnel::to_host_port;
use tracing_subscriber::EnvFilter;
use url::{Host, Url};
@ -157,6 +160,16 @@ struct Server {
#[arg(long, value_name = "DEST:PORT", verbatim_doc_comment)]
restrict_to: Option<Vec<String>>,
/// Dns resolver to use to lookup ips of domain name
/// This option is not going to work if you use transparent proxy
/// Can be specified multiple time
/// Example:
/// dns://1.1.1.1 for using udp
/// dns+https://1.1.1.1 for using dns over HTTPS
/// dns+tls://8.8.8.8 for using dns over TLS
#[arg(long, verbatim_doc_comment)]
dns_resolver: Option<Vec<Url>>,
/// Server will only accept connection from if this specific path prefix is used during websocket upgrade.
/// Useful if you specify in the client a custom path prefix and you want the server to only allow this one.
/// The path prefix act as a secret to authenticate clients
@ -445,6 +458,7 @@ pub struct WsServerConfig {
pub timeout_connect: Duration,
pub websocket_mask_frame: bool,
pub tls: Option<TlsServerConfig>,
pub dns_resolver: DnsResolver,
}
impl Debug for WsServerConfig {
@ -592,7 +606,14 @@ async fn main() {
let remote = tunnel.remote.clone();
let cfg = client_config.clone();
let connect_to_dest = |_| async {
tcp::connect(&remote.0, remote.1, cfg.socket_so_mark, cfg.timeout_connect).await
tcp::connect(
&remote.0,
remote.1,
cfg.socket_so_mark,
cfg.timeout_connect,
&DnsResolver::System,
)
.await
};
if let Err(err) =
@ -608,8 +629,9 @@ async fn main() {
tokio::spawn(async move {
let cfg = client_config.clone();
let remote = tunnel.remote.clone();
let connect_to_dest =
|_| async { udp::connect(&remote.0, remote.1, cfg.timeout_connect).await };
let connect_to_dest = |_| async {
udp::connect(&remote.0, remote.1, cfg.timeout_connect, &DnsResolver::System).await
};
if let Err(err) =
tunnel::client::run_reverse_tunnel(client_config, tunnel, connect_to_dest).await
@ -625,7 +647,9 @@ async fn main() {
let connect_to_dest = |remote: (Host, u16)| {
let so_mark = cfg.socket_so_mark;
let timeout = cfg.timeout_connect;
async move { tcp::connect(&remote.0, remote.1, so_mark, timeout).await }
async move {
tcp::connect(&remote.0, remote.1, so_mark, timeout, &DnsResolver::System).await
}
};
if let Err(err) =
@ -770,6 +794,29 @@ async fn main() {
None
};
let dns_resolver = match args.dns_resolver {
None => DnsResolver::System,
Some(resolvers) => {
let mut cfg = ResolverConfig::new();
for resolver in resolvers {
let (protocol, port) = match resolver.scheme() {
"dns" => (hickory_resolver::config::Protocol::Udp, resolver.port().unwrap_or(53)),
"dns+https" => (hickory_resolver::config::Protocol::Https, resolver.port().unwrap_or(853)),
"dns+tls" => (hickory_resolver::config::Protocol::Tls, resolver.port().unwrap_or(12)),
_ => panic!("invalid protocol for dns resolver"),
};
let sock = match resolver.host().unwrap() {
Host::Domain(_) => panic!("Dns resolver must be an ip address"),
Host::Ipv4(ip) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
Host::Ipv6(ip) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
};
cfg.add_name_server(NameServerConfig::new(sock, protocol))
}
let opts = ResolverOpts::default();
DnsResolver::TrustDns(hickory_resolver::AsyncResolver::tokio(cfg, opts))
}
};
let server_config = WsServerConfig {
socket_so_mark: args.socket_so_mark,
bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0],
@ -779,6 +826,7 @@ async fn main() {
timeout_connect: Duration::from_secs(10),
websocket_mask_frame: args.websocket_mask_frame,
tls: tls_config,
dns_resolver,
};
info!(

View file

@ -1,6 +1,7 @@
use anyhow::{anyhow, Context};
use std::{io, vec};
use crate::dns::DnsResolver;
use base64::Engine;
use bytes::BytesMut;
use log::warn;
@ -41,14 +42,15 @@ pub async fn connect(
port: u16,
so_mark: Option<u32>,
connect_timeout: Duration,
dns_resolver: &DnsResolver,
) -> Result<TcpStream, anyhow::Error> {
info!("Opening TCP connection to {}:{}", host, port);
let socket_addrs: Vec<SocketAddr> = match host {
Host::Domain(domain) => tokio::net::lookup_host(format!("{}:{}", domain, port))
Host::Domain(domain) => dns_resolver
.lookup_host(domain.as_str(), port)
.await
.with_context(|| format!("cannot resolve domain: {}", domain))?
.collect(),
.with_context(|| format!("cannot resolve domain: {}", domain))?,
Host::Ipv4(ip) => vec![SocketAddr::V4(SocketAddrV4::new(*ip, port))],
Host::Ipv6(ip) => vec![SocketAddr::V6(SocketAddrV6::new(*ip, port, 0, 0))],
};
@ -104,7 +106,7 @@ pub async fn connect_with_http_proxy(
let proxy_host = proxy.host().context("Cannot parse proxy host")?.to_owned();
let proxy_port = proxy.port_or_known_default().unwrap_or(80);
let mut socket = connect(&proxy_host, proxy_port, so_mark, connect_timeout).await?;
let mut socket = connect(&proxy_host, proxy_port, so_mark, connect_timeout, &DnsResolver::System).await?;
info!("Connected to http proxy {}:{}", proxy_host, proxy_port);
let authorization = if let Some((user, password)) = proxy.password().map(|p| (proxy.username(), p)) {

View file

@ -2,6 +2,7 @@ pub mod client;
mod io;
pub mod server;
use crate::dns::DnsResolver;
use crate::{tcp, tls, LocalProtocol, LocalToRemote, WsClientConfig};
use async_trait::async_trait;
use bb8::ManageConnection;
@ -126,7 +127,7 @@ impl ManageConnection for WsClientConfig {
let tcp_stream = if let Some(http_proxy) = &self.http_proxy {
tcp::connect_with_http_proxy(http_proxy, host, *port, so_mark, timeout).await?
} else {
tcp::connect(host, *port, so_mark, timeout).await?
tcp::connect(host, *port, so_mark, timeout, &DnsResolver::System).await?
};
match &self.tls {

View file

@ -67,7 +67,13 @@ async fn from_query(
match jwt.claims.p {
LocalProtocol::Udp { timeout, .. } => {
let host = Host::parse(&jwt.claims.r)?;
let cnx = udp::connect(&host, jwt.claims.rp, timeout.unwrap_or(Duration::from_secs(10))).await?;
let cnx = udp::connect(
&host,
jwt.claims.rp,
timeout.unwrap_or(Duration::from_secs(10)),
&server_config.dns_resolver,
)
.await?;
Ok((
LocalProtocol::Udp { timeout: None },
host,
@ -79,9 +85,15 @@ async fn from_query(
LocalProtocol::Tcp => {
let host = Host::parse(&jwt.claims.r)?;
let port = jwt.claims.rp;
let (rx, tx) = tcp::connect(&host, port, server_config.socket_so_mark, Duration::from_secs(10))
.await?
.into_split();
let (rx, tx) = tcp::connect(
&host,
port,
server_config.socket_so_mark,
Duration::from_secs(10),
&server_config.dns_resolver,
)
.await?
.into_split();
Ok((jwt.claims.p, host, port, Box::pin(rx), Box::pin(tx)))
}

View file

@ -18,6 +18,7 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::UdpSocket;
use tokio::sync::futures::Notified;
use crate::dns::DnsResolver;
use tokio::sync::Notify;
use tokio::time::{timeout, Interval};
use tracing::{debug, error, info};
@ -295,16 +296,21 @@ impl AsyncWrite for MyUdpSocket {
}
}
pub async fn connect(host: &Host<String>, port: u16, connect_timeout: Duration) -> anyhow::Result<MyUdpSocket> {
pub async fn connect(
host: &Host<String>,
port: u16,
connect_timeout: Duration,
dns_resolver: &DnsResolver,
) -> anyhow::Result<MyUdpSocket> {
info!("Opening UDP connection to {}:{}", host, port);
let socket_addrs: Vec<SocketAddr> = match host {
Host::Domain(domain) => timeout(connect_timeout, tokio::net::lookup_host(format!("{}:{}", domain, port)))
.await
.with_context(|| format!("cannot resolve domain: {}", domain))??
.collect(),
Host::Ipv4(ip) => vec![SocketAddr::V4(SocketAddrV4::new(*ip, port))],
Host::Ipv6(ip) => vec![SocketAddr::V6(SocketAddrV6::new(*ip, port, 0, 0))],
Host::Domain(domain) => dns_resolver
.lookup_host(domain.as_str(), port)
.await
.with_context(|| format!("cannot resolve domain: {}", domain))?,
};
let mut cnx = None;