refacto: split into modules

This commit is contained in:
Σrebe - Romain GERARD 2024-07-28 13:14:08 +02:00
parent 6a07201de1
commit 38cb7ed5f8
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
35 changed files with 745 additions and 596 deletions

3
src/protocols/dns/mod.rs Normal file
View file

@ -0,0 +1,3 @@
mod resolver;
pub use resolver::DnsResolver;

View file

@ -0,0 +1,286 @@
use crate::protocols;
use anyhow::{anyhow, Context};
use futures_util::{FutureExt, TryFutureExt};
use hickory_resolver::config::{LookupIpStrategy, NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
use hickory_resolver::name_server::{GenericConnector, RuntimeProvider, TokioRuntimeProvider};
use hickory_resolver::proto::iocompat::AsyncIoTokioAsStd;
use hickory_resolver::proto::TokioTime;
use hickory_resolver::{AsyncResolver, TokioHandle};
use log::warn;
use std::future::Future;
use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpStream, UdpSocket};
use url::{Host, Url};
// Interleave v4 and v6 addresses as per RFC8305.
// The first address is v6 if we have any v6 addresses.
#[inline]
fn sort_socket_addrs(socket_addrs: &[SocketAddr], prefer_ipv6: bool) -> impl Iterator<Item = &'_ SocketAddr> {
let mut pick_v6 = !prefer_ipv6;
let mut v6 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V6(_)));
let mut v4 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V4(_)));
std::iter::from_fn(move || {
pick_v6 = !pick_v6;
if pick_v6 {
v6.next().or_else(|| v4.next())
} else {
v4.next().or_else(|| v6.next())
}
})
}
#[derive(Clone)]
pub enum DnsResolver {
System,
TrustDns {
resolver: AsyncResolver<GenericConnector<TokioRuntimeProviderWithSoMark>>,
prefer_ipv6: bool,
},
}
impl DnsResolver {
pub async fn lookup_host(&self, domain: &str, port: u16) -> anyhow::Result<Vec<SocketAddr>> {
let addrs: Vec<SocketAddr> = match self {
Self::System => tokio::net::lookup_host(format!("{}:{}", domain, port)).await?.collect(),
Self::TrustDns { resolver, prefer_ipv6 } => {
let addrs: Vec<_> = resolver
.lookup_ip(domain)
.await?
.into_iter()
.map(|ip| match ip {
IpAddr::V4(ip) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
IpAddr::V6(ip) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
})
.collect();
sort_socket_addrs(&addrs, *prefer_ipv6).copied().collect()
}
};
Ok(addrs)
}
pub fn new_from_urls(
resolvers: &[Url],
proxy: Option<Url>,
so_mark: Option<u32>,
prefer_ipv6: bool,
) -> anyhow::Result<Self> {
fn mk_resolver(
cfg: ResolverConfig,
mut opts: ResolverOpts,
proxy: Option<Url>,
so_mark: Option<u32>,
) -> AsyncResolver<GenericConnector<TokioRuntimeProviderWithSoMark>> {
opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
opts.timeout = Duration::from_secs(1);
// Windows end-up with too many dns resolvers, which causes a performance issue
// https://github.com/hickory-dns/hickory-dns/issues/1968
#[cfg(target_os = "windows")]
{
opts.cache_size = 1024;
opts.num_concurrent_reqs = cfg.name_servers().len();
}
AsyncResolver::new(
cfg,
opts,
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(proxy, so_mark)),
)
}
fn get_sni(resolver: &Url) -> anyhow::Result<String> {
Ok(resolver
.query_pairs()
.find(|(k, _)| k == "sni")
.with_context(|| "Missing `sni` query parameter for dns over https")?
.1
.to_string())
}
fn url_to_ns_config(resolver: &Url) -> anyhow::Result<NameServerConfig> {
let (protocol, port, tls_sni) = match resolver.scheme() {
"dns" => (Protocol::Udp, resolver.port().unwrap_or(53), None),
"dns+https" => (Protocol::Https, resolver.port().unwrap_or(443), Some(get_sni(resolver)?)),
"dns+tls" => (Protocol::Tls, resolver.port().unwrap_or(853), Some(get_sni(resolver)?)),
_ => return Err(anyhow!("invalid protocol for dns resolver")),
};
let host = resolver
.host()
.ok_or_else(|| anyhow!("Invalid dns resolver host: {}", resolver))?;
let sock = match host {
Host::Domain(host) => match Host::parse(host) {
Ok(Host::Ipv4(ip)) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
Ok(Host::Ipv6(ip)) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
Ok(Host::Domain(_)) | Err(_) => {
return Err(anyhow!("Dns resolver must be an ip address, got {}", host));
}
},
Host::Ipv4(ip) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
Host::Ipv6(ip) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
};
let mut ns = NameServerConfig::new(sock, protocol);
ns.tls_dns_name = tls_sni;
Ok(ns)
}
// no dns resolver specified, fall-back to default one
if resolvers.is_empty() {
let Ok((cfg, opts)) = hickory_resolver::system_conf::read_system_conf() else {
warn!("Fall-backing to system dns resolver. You should consider specifying a dns resolver. To avoid performance issue");
return Ok(Self::System);
};
return Ok(Self::TrustDns {
resolver: mk_resolver(cfg, opts, proxy, so_mark),
prefer_ipv6,
});
};
// if one is specified as system, use the default one from libc
if resolvers.iter().any(|r| r.scheme() == "system") {
return Ok(Self::System);
}
// otherwise, use the specified resolvers
let mut cfg = ResolverConfig::new();
for resolver in resolvers.iter() {
cfg.add_name_server(url_to_ns_config(resolver)?);
}
Ok(Self::TrustDns {
resolver: mk_resolver(cfg, ResolverOpts::default(), proxy, so_mark),
prefer_ipv6,
})
}
}
#[derive(Clone)]
pub struct TokioRuntimeProviderWithSoMark {
runtime: TokioRuntimeProvider,
proxy: Option<Arc<Url>>,
#[cfg(target_os = "linux")]
so_mark: Option<u32>,
}
impl TokioRuntimeProviderWithSoMark {
fn new(proxy: Option<Url>, so_mark: Option<u32>) -> Self {
Self {
runtime: TokioRuntimeProvider::default(),
proxy: proxy.map(Arc::new),
#[cfg(target_os = "linux")]
so_mark,
}
}
}
impl RuntimeProvider for TokioRuntimeProviderWithSoMark {
type Handle = TokioHandle;
type Timer = TokioTime;
type Udp = UdpSocket;
type Tcp = AsyncIoTokioAsStd<TcpStream>;
#[inline]
fn create_handle(&self) -> Self::Handle {
self.runtime.create_handle()
}
#[inline]
fn connect_tcp(&self, server_addr: SocketAddr) -> Pin<Box<dyn Send + Future<Output = std::io::Result<Self::Tcp>>>> {
#[cfg(not(target_os = "linux"))]
let so_mark = None;
#[cfg(target_os = "linux")]
let so_mark = self.so_mark;
let proxy = self.proxy.clone();
let socket = async move {
let host = match server_addr.ip() {
IpAddr::V4(addr) => Host::Ipv4(addr),
IpAddr::V6(addr) => Host::Ipv6(addr),
};
if let Some(proxy) = &proxy {
protocols::tcp::connect_with_http_proxy(
proxy,
&host,
server_addr.port(),
so_mark,
Duration::from_secs(10),
&DnsResolver::System, // not going to be used as host is directly an ip address
)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
.map(|s| s.map(AsyncIoTokioAsStd))
.await
} else {
protocols::tcp::connect(
&host,
server_addr.port(),
so_mark,
Duration::from_secs(10),
&DnsResolver::System, // not going to be used as host is directly an ip address
)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
.map(|s| s.map(AsyncIoTokioAsStd))
.await
}
};
Box::pin(socket)
}
fn bind_udp(
&self,
local_addr: SocketAddr,
_server_addr: SocketAddr,
) -> Pin<Box<dyn Send + Future<Output = std::io::Result<Self::Udp>>>> {
let socket = UdpSocket::bind(local_addr);
#[cfg(target_os = "linux")]
let socket = {
use socket2::SockRef;
socket.map({
let so_mark = self.so_mark;
move |sock| {
if let (Ok(sock), Some(so_mark)) = (&sock, so_mark) {
SockRef::from(sock).set_mark(so_mark)?;
}
sock
}
})
};
Box::pin(socket)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
#[test]
fn test_sort_socket_addrs() {
let addrs = [
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
];
let expected = [
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
];
let actual: Vec<_> = sort_socket_addrs(&addrs, true).copied().collect();
assert_eq!(expected, *actual);
}
}

View file

@ -0,0 +1,4 @@
mod server;
pub use server::run_server;
pub use server::HttpProxyListener;

View file

@ -0,0 +1,145 @@
use anyhow::Context;
use std::future::Future;
use bytes::Bytes;
use log::{debug, error};
use std::net::{Ipv4Addr, SocketAddr};
use std::pin::Pin;
use base64::Engine;
use futures_util::{future, stream, Stream};
use http_body_util::Empty;
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Request, Response};
use hyper_util::rt::TokioTimer;
use parking_lot::Mutex;
use std::time::Duration;
use tokio::net::{TcpListener, TcpStream};
use tracing::log::info;
use url::Host;
#[allow(clippy::type_complexity)]
pub struct HttpProxyListener {
listener: Pin<Box<dyn Stream<Item = anyhow::Result<(TcpStream, (Host, u16))>> + Send>>,
}
impl Stream for HttpProxyListener {
type Item = anyhow::Result<(TcpStream, (Host, u16))>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
unsafe { self.map_unchecked_mut(|x| &mut x.listener) }.poll_next(cx)
}
}
fn handle_request(
credentials: &Option<String>,
dest: &Mutex<(Host, u16)>,
req: Request<Incoming>,
) -> impl Future<Output = Result<Response<Empty<Bytes>>, &'static str>> {
const PROXY_AUTHORIZATION_PREFIX: &str = "Basic ";
let ok_response = |forward_to: (Host, u16)| -> Result<Response<Empty<Bytes>>, _> {
*dest.lock() = forward_to;
Ok(Response::builder().status(200).body(Empty::new()).unwrap())
};
fn err_response() -> Result<Response<Empty<Bytes>>, &'static str> {
info!("Un-authorized connection to http proxy");
Err("Un-authorized")
}
if req.method() != hyper::Method::CONNECT {
return future::ready(err_response());
}
debug!("HTTP Proxy CONNECT request to {}", req.uri());
let forward_to = (
Host::parse(req.uri().host().unwrap_or_default()).unwrap_or(Host::Ipv4(Ipv4Addr::new(0, 0, 0, 0))),
req.uri().port_u16().unwrap_or(443),
);
let Some(token) = credentials else {
return future::ready(ok_response(forward_to));
};
let Some(auth) = req.headers().get(hyper::header::PROXY_AUTHORIZATION) else {
return future::ready(err_response());
};
let auth = auth.to_str().unwrap_or_default().trim();
if auth.starts_with(PROXY_AUTHORIZATION_PREFIX) && &auth[PROXY_AUTHORIZATION_PREFIX.len()..] == token {
return future::ready(ok_response(forward_to));
}
future::ready(err_response())
}
pub async fn run_server(
bind: SocketAddr,
timeout: Option<Duration>,
credentials: Option<(String, String)>,
) -> Result<HttpProxyListener, anyhow::Error> {
info!(
"Starting http proxy server listening cnx on {} with credentials {:?}",
bind, credentials
);
let listener = TcpListener::bind(bind)
.await
.with_context(|| format!("Cannot create TCP server {:?}", bind))?;
let http1 = {
let mut builder = http1::Builder::new();
builder
.timer(TokioTimer::new())
.header_read_timeout(timeout)
.keep_alive(false);
builder
};
let auth_header =
credentials.map(|(user, pass)| base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", user, pass)));
let listener = stream::unfold((listener, http1, auth_header), |(listener, http1, auth_header)| async {
loop {
let (mut stream, _) = match listener.accept().await {
Ok(v) => v,
Err(err) => {
error!("Error while accepting connection {:?}", err);
continue;
}
};
let forward_to = Mutex::new((Host::Ipv4(Ipv4Addr::new(0, 0, 0, 0)), 0));
let conn_fut = http1.serve_connection(
hyper_util::rt::TokioIo::new(&mut stream),
service_fn(|req| handle_request(&auth_header, &forward_to, req)),
);
match conn_fut.await {
Ok(_) => return Some((Ok((stream, forward_to.into_inner())), (listener, http1, auth_header))),
Err(err) => {
info!("Error while serving connection: {}", err);
continue;
}
}
}
});
Ok(HttpProxyListener {
listener: Box::pin(listener),
})
}
//#[cfg(test)]
//mod tests {
// use super::*;
// use tracing::level_filters::LevelFilter;
//
// #[tokio::test]
// async fn test_run_server() {
// tracing_subscriber::fmt()
// .with_ansi(true)
// .with_max_level(LevelFilter::TRACE)
// .init();
// let x = run_server("127.0.0.1:1212".parse().unwrap(), None, None).await;
// }
//}

9
src/protocols/mod.rs Normal file
View file

@ -0,0 +1,9 @@
pub mod dns;
pub mod http_proxy;
pub mod socks5;
pub mod stdio;
pub mod tcp;
pub mod tls;
pub mod udp;
#[cfg(unix)]
pub mod unix_sock;

View file

@ -0,0 +1,6 @@
mod tcp_server;
mod udp_server;
pub use tcp_server::run_server;
pub use tcp_server::Socks5Listener;
pub use tcp_server::Socks5Stream;

View file

@ -0,0 +1,274 @@
use super::udp_server::Socks5UdpStream;
use crate::LocalProtocol;
use anyhow::Context;
use fast_socks5::server::{Config, DenyAuthentication, SimpleUserPassword, Socks5Server};
use fast_socks5::util::target_addr::TargetAddr;
use fast_socks5::{consts, ReplyError};
use futures_util::{stream, Stream, StreamExt};
use std::io::{Error, IoSlice};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::pin::Pin;
use std::task::Poll;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::net::TcpStream;
use tokio::select;
use tracing::{info, warn};
use url::Host;
#[allow(clippy::type_complexity)]
pub struct Socks5Listener {
socks_server: Pin<Box<dyn Stream<Item = anyhow::Result<(Socks5Stream, (Host, u16))>> + Send>>,
}
pub enum Socks5Stream {
Tcp(TcpStream),
Udp(Socks5UdpStream),
}
impl Socks5Stream {
pub fn local_protocol(&self) -> LocalProtocol {
match self {
Self::Tcp(_) => LocalProtocol::Tcp { proxy_protocol: false }, // TODO: Implement proxy protocol
Self::Udp(s) => LocalProtocol::Udp {
timeout: s.watchdog_deadline.as_ref().map(|x| x.period()),
},
}
}
}
impl Stream for Socks5Listener {
type Item = anyhow::Result<(Socks5Stream, (Host, u16))>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
unsafe { self.map_unchecked_mut(|x| &mut x.socks_server) }.poll_next(cx)
}
}
pub async fn run_server(
bind: SocketAddr,
timeout: Option<Duration>,
credentials: Option<(String, String)>,
) -> Result<Socks5Listener, anyhow::Error> {
info!(
"Starting SOCKS5 server listening cnx on {} with credentials {:?}",
bind, credentials
);
let server = Socks5Server::<DenyAuthentication>::bind(bind)
.await
.with_context(|| format!("Cannot create socks5 server {:?}", bind))?;
let mut cfg = Config::default();
cfg = if let Some((username, password)) = credentials {
cfg.set_allow_no_auth(false);
cfg.with_authentication(SimpleUserPassword { username, password })
} else {
cfg.set_allow_no_auth(true);
cfg
};
cfg.set_dns_resolve(false);
cfg.set_execute_command(false);
cfg.set_udp_support(true);
let udp_server = super::udp_server::run_server(bind, timeout).await?;
let server = server.with_config(cfg);
let stream = stream::unfold((server, Box::pin(udp_server)), move |(server, mut udp_server)| async move {
let mut acceptor = server.incoming();
loop {
let cnx = select! {
biased;
cnx = acceptor.next() => match cnx {
None => return None,
Some(Err(err)) => {
drop(acceptor);
return Some((Err(anyhow::Error::new(err)), (server, udp_server)));
}
Some(Ok(cnx)) => cnx,
},
// new incoming udp stream
udp_conn = udp_server.next() => {
drop(acceptor);
return match udp_conn {
Some(Ok(stream)) => {
let dest = stream.destination();
Some((Ok((Socks5Stream::Udp(stream), dest)), (server, udp_server)))
}
Some(Err(err)) => {
Some((Err(anyhow::Error::new(err)), (server, udp_server)))
}
None => {
None
}
};
}
};
let cnx = match cnx.upgrade_to_socks5().await {
Ok(cnx) => cnx,
Err(err) => {
warn!("Rejecting socks5 cnx: {}", err);
continue;
}
};
let Some(target) = cnx.target_addr() else {
warn!("Rejecting socks5 cnx: no target addr");
continue;
};
let (host, port) = match target {
TargetAddr::Ip(SocketAddr::V4(ip)) => (Host::Ipv4(*ip.ip()), ip.port()),
TargetAddr::Ip(SocketAddr::V6(ip)) => (Host::Ipv6(*ip.ip()), ip.port()),
TargetAddr::Domain(host, port) => (Host::Domain(host.clone()), *port),
};
// Special case for UDP Associate where we return the bind addr of the udp server
if matches!(cnx.cmd(), Some(fast_socks5::Socks5Command::UDPAssociate)) {
let mut cnx = cnx.into_inner();
let ret = cnx.write_all(&new_reply(&ReplyError::Succeeded, bind)).await;
if let Err(err) = ret {
warn!("Cannot reply to socks5 udp client: {}", err);
continue;
}
tokio::spawn(async move {
let mut buf = [0u8; 8];
loop {
match cnx.read(&mut buf).await {
Ok(0) => return,
Err(_) => return,
_ => {}
}
}
});
continue;
};
let mut cnx = cnx.into_inner();
let ret = cnx
.write_all(&new_reply(
&ReplyError::Succeeded,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0),
))
.await;
if let Err(err) = ret {
warn!("Cannot reply to socks5 client: {}", err);
continue;
}
drop(acceptor);
return Some((Ok((Socks5Stream::Tcp(cnx), (host, port))), (server, udp_server)));
}
});
let listener = Socks5Listener {
socks_server: Box::pin(stream),
};
Ok(listener)
}
fn new_reply(error: &ReplyError, sock_addr: SocketAddr) -> Vec<u8> {
let (addr_type, mut ip_oct, mut port) = match sock_addr {
SocketAddr::V4(sock) => (
consts::SOCKS5_ADDR_TYPE_IPV4,
sock.ip().octets().to_vec(),
sock.port().to_be_bytes().to_vec(),
),
SocketAddr::V6(sock) => (
consts::SOCKS5_ADDR_TYPE_IPV6,
sock.ip().octets().to_vec(),
sock.port().to_be_bytes().to_vec(),
),
};
let mut reply = vec![
consts::SOCKS5_VERSION,
error.as_u8(), // transform the error into byte code
0x00, // reserved
addr_type, // address type (ipv4, v6, domain)
];
reply.append(&mut ip_oct);
reply.append(&mut port);
reply
}
impl Unpin for Socks5Stream {}
impl AsyncRead for Socks5Stream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.get_mut() {
Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf),
Self::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf),
}
}
}
impl AsyncWrite for Socks5Stream {
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
match self.get_mut() {
Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf),
Self::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
match self.get_mut() {
Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_flush(cx),
Self::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_flush(cx),
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
match self.get_mut() {
Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx),
Self::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx),
}
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, Error>> {
match self.get_mut() {
Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write_vectored(cx, bufs),
Self::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_write_vectored(cx, bufs),
}
}
fn is_write_vectored(&self) -> bool {
match self {
Self::Tcp(s) => s.is_write_vectored(),
Self::Udp(s) => s.is_write_vectored(),
}
}
}
//#[cfg(test)]
//mod test {
// use super::*;
// use futures_util::StreamExt;
// use std::str::FromStr;
//
// #[tokio::test]
// async fn socks5_server() {
// let mut x = run_server(SocketAddr::from_str("[::]:4343").unwrap())
// .await
// .unwrap();
//
// loop {
// let cnx = x.next().await.unwrap().unwrap();
// eprintln!("{:?}", cnx);
// }
// }
//}

View file

@ -0,0 +1,285 @@
use anyhow::Context;
use futures_util::{stream, Stream};
use parking_lot::RwLock;
use pin_project::{pin_project, pinned_drop};
use std::collections::HashMap;
use std::io;
use std::io::{Error, ErrorKind};
use std::net::SocketAddr;
use crate::tunnel::to_host_port;
use bytes::{Buf, Bytes, BytesMut};
use fast_socks5::new_udp_header;
use fast_socks5::util::target_addr::TargetAddr;
use log::warn;
use std::pin::{pin, Pin};
use std::sync::{Arc, Weak};
use std::task::{ready, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use tokio::time::Interval;
use tracing::{debug, error, info};
use url::Host;
type PeerMapKey = (SocketAddr, TargetAddr);
struct IoInner {
sender: mpsc::Sender<Bytes>,
}
struct Socks5UdpServer {
listener: Arc<UdpSocket>,
peers: HashMap<PeerMapKey, Pin<Arc<IoInner>>, ahash::RandomState>,
keys_to_delete: Arc<RwLock<Vec<PeerMapKey>>>,
cnx_timeout: Option<Duration>,
}
impl Socks5UdpServer {
pub fn new(listener: UdpSocket, timeout: Option<Duration>) -> Self {
let socket = socket2::SockRef::from(&listener);
// Increase receive buffer
const BUF_SIZES: [usize; 7] = [64usize, 32usize, 16usize, 8usize, 4usize, 2usize, 1usize];
for size in BUF_SIZES.iter() {
if let Err(err) = socket.set_recv_buffer_size(size * 1024 * 1024) {
warn!("Cannot increase UDP server recv buffer to {} Mib: {}", size, err);
warn!("This is not fatal, but can lead to packet loss if you have too much throughput. You must monitor packet loss in this case");
continue;
}
if *size != BUF_SIZES[0] {
info!("Increased UDP server recv buffer to {} Mib", size);
}
break;
}
for size in BUF_SIZES.iter() {
if let Err(err) = socket.set_send_buffer_size(size * 1024 * 1024) {
warn!("Cannot increase UDP server send buffer to {} Mib: {}", size, err);
warn!("This is not fatal, but can lead to packet loss if you have too much throughput. You must monitor packet loss in this case");
continue;
}
if *size != BUF_SIZES[0] {
info!("Increased UDP server send buffer to {} Mib", size);
}
break;
}
Self {
listener: Arc::new(listener),
peers: HashMap::with_hasher(ahash::RandomState::new()),
keys_to_delete: Default::default(),
cnx_timeout: timeout,
}
}
#[inline]
pub fn clean_dead_keys(&mut self) {
let nb_key_to_delete = self.keys_to_delete.read().len();
if nb_key_to_delete == 0 {
return;
}
debug!("Cleaning {} dead udp peers", nb_key_to_delete);
let mut keys_to_delete = self.keys_to_delete.write();
for key in keys_to_delete.iter() {
self.peers.remove(key);
}
keys_to_delete.clear();
}
}
#[pin_project(PinnedDrop)]
pub struct Socks5UdpStream {
#[pin]
recv_data: mpsc::Receiver<Bytes>,
send_socket: Arc<UdpSocket>,
destination: TargetAddr,
peer: SocketAddr,
udp_header: Vec<u8>,
#[pin]
pub watchdog_deadline: Option<Interval>,
data_read_before_deadline: bool,
io: Pin<Arc<IoInner>>,
keys_to_delete: Weak<RwLock<Vec<PeerMapKey>>>,
}
#[pinned_drop]
impl PinnedDrop for Socks5UdpStream {
fn drop(self: Pin<&mut Self>) {
if let Some(keys_to_delete) = self.keys_to_delete.upgrade() {
keys_to_delete.write().push((self.peer, self.destination.clone()));
}
}
}
impl Socks5UdpStream {
fn new(
send_socket: Arc<UdpSocket>,
peer: SocketAddr,
destination: TargetAddr,
watchdog_deadline: Option<Duration>,
keys_to_delete: Weak<RwLock<Vec<PeerMapKey>>>,
) -> (Self, Pin<Arc<IoInner>>) {
let (tx, rx) = mpsc::channel(1024);
let io = Arc::pin(IoInner { sender: tx });
let udp_header = match &destination {
TargetAddr::Ip(ip) => new_udp_header(*ip).unwrap(),
TargetAddr::Domain(h, p) => new_udp_header((h.as_str(), *p)).unwrap(),
};
let s = Self {
recv_data: rx,
send_socket,
peer,
destination,
watchdog_deadline: watchdog_deadline
.map(|timeout| tokio::time::interval_at(tokio::time::Instant::now() + timeout, timeout)),
data_read_before_deadline: false,
io: io.clone(),
keys_to_delete,
udp_header,
};
(s, io)
}
pub fn destination(&self) -> (Host, u16) {
match &self.destination {
TargetAddr::Ip(sock_addr) => to_host_port(*sock_addr),
TargetAddr::Domain(h, p) => (Host::Domain(h.clone()), *p),
}
}
}
impl AsyncRead for Socks5UdpStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
obuf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let mut project = self.project();
// Look that the timeout for client has not elapsed
if let Some(mut deadline) = project.watchdog_deadline.as_pin_mut() {
if deadline.poll_tick(cx).is_ready() {
if !*project.data_read_before_deadline {
return Poll::Ready(Err(Error::new(
ErrorKind::TimedOut,
format!("UDP stream timeout with {}", project.peer),
)));
};
*project.data_read_before_deadline = false;
while deadline.poll_tick(cx).is_ready() {}
}
}
let Some(data) = ready!(project.recv_data.poll_recv(cx)) else {
return Poll::Ready(Err(Error::from(ErrorKind::UnexpectedEof)));
};
if obuf.remaining() < data.len() {
return Poll::Ready(Err(Error::new(
ErrorKind::InvalidData,
"udp dst buffer does not have enough space left. Can't fragment",
)));
}
obuf.put_slice(data.chunk());
*project.data_read_before_deadline = true;
Poll::Ready(Ok(()))
}
}
impl AsyncWrite for Socks5UdpStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
let this = self.project();
let header_len = this.udp_header.len();
this.udp_header.extend_from_slice(buf);
let ret = this
.send_socket
.poll_send_to(cx, this.udp_header.as_slice(), *this.peer);
this.udp_header.truncate(header_len);
ret.map(|r| r.map(|write_len| write_len - header_len))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
self.send_socket.poll_send_ready(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
}
pub async fn run_server(
bind: SocketAddr,
timeout: Option<Duration>,
) -> Result<impl Stream<Item = io::Result<Socks5UdpStream>>, anyhow::Error> {
let listener = UdpSocket::bind(bind)
.await
.with_context(|| format!("Cannot create UDP server {:?}", bind))?;
let udp_server = Socks5UdpServer::new(listener, timeout);
static MAX_PACKET_LENGTH: usize = 64 * 1024;
let buffer = BytesMut::with_capacity(MAX_PACKET_LENGTH * 10);
let stream = stream::unfold((udp_server, buffer), |(mut server, mut buf)| async move {
loop {
server.clean_dead_keys();
buf.reserve(MAX_PACKET_LENGTH);
let peer_addr = match server.listener.recv_buf_from(&mut buf).await {
Ok((_read_len, peer_addr)) => peer_addr,
Err(err) => {
error!("Cannot read from UDP server. Closing server: {}", err);
return None;
}
};
let (destination_addr, data) = {
let payload = buf.split().freeze();
let (frag, destination_addr, data) = match fast_socks5::parse_udp_request(payload.chunk()).await {
Ok((frag, addr, data)) => (frag, addr, data),
Err(err) => {
warn!("Skipping invalid UDP socks5 request: {} ", err);
debug!("Invalid UDP socks5 request: {:?}", payload.chunk());
continue;
}
};
// We don't support udp fragmentation
if frag != 0 {
warn!("dropping UDP socks5 fragmented");
continue;
}
(destination_addr, payload.slice_ref(data))
};
let addr = (peer_addr, destination_addr);
match server.peers.get(&addr) {
Some(io) => {
if io.sender.send(data).await.is_err() {
server.peers.remove(&addr);
}
}
None => {
info!("New UDP connection for {}", addr.1);
let (udp_client, io) = Socks5UdpStream::new(
server.listener.clone(),
addr.0,
addr.1.clone(),
server.cnx_timeout,
Arc::downgrade(&server.keys_to_delete),
);
let _ = io.sender.send(data).await;
server.peers.insert(addr, io);
return Some((Ok(udp_client), (server, buf)));
}
}
}
});
Ok(stream)
}

View file

@ -0,0 +1,9 @@
#[cfg(unix)]
mod server_unix;
#[cfg(not(unix))]
mod server_windows;
#[cfg(unix)]
pub use server_unix::run_server;
#[cfg(not(unix))]
pub use server_windows::run_server;

View file

@ -0,0 +1,27 @@
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};
use tokio::sync::oneshot;
use tokio_fd::AsyncFd;
use tracing::info;
pub struct WsStdin {
stdin: AsyncFd,
_receiver: oneshot::Receiver<()>,
}
impl AsyncRead for WsStdin {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
unsafe { self.map_unchecked_mut(|s| &mut s.stdin) }.poll_read(cx, buf)
}
}
pub async fn run_server() -> Result<((WsStdin, AsyncFd), oneshot::Sender<()>), anyhow::Error> {
info!("Starting STDIO server");
let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?;
let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?;
let (tx, rx) = oneshot::channel::<()>();
Ok(((WsStdin { stdin, _receiver: rx }, stdout), tx))
}

View file

@ -0,0 +1,82 @@
use bytes::BytesMut;
use log::error;
use parking_lot::Mutex;
use scopeguard::guard;
use std::io::{Read, Write};
use std::sync::Arc;
use std::{io, thread};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::sync::oneshot;
use tokio::task::LocalSet;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::io::StreamReader;
use tracing::info;
pub async fn run_server() -> Result<((impl AsyncRead, impl AsyncWrite), oneshot::Sender<()>), anyhow::Error> {
info!("Starting STDIO server. Press ctrl+c twice to exit");
crossterm::terminal::enable_raw_mode()?;
let stdin = io::stdin();
let (send, recv) = tokio::sync::mpsc::unbounded_channel();
let (abort_tx, abort_rx) = oneshot::channel::<()>();
let abort_rx = Arc::new(Mutex::new(abort_rx));
let abort_rx2 = abort_rx.clone();
thread::spawn(move || {
let _restore_terminal = guard((), move |_| {
let _ = crossterm::terminal::disable_raw_mode();
abort_rx.lock().close();
});
let stdin = stdin;
let mut stdin = stdin.lock();
let mut buf = [0u8; 65536];
loop {
let n = stdin.read(&mut buf).unwrap_or(0);
if n == 0 || (n == 1 && buf[0] == 3) {
// ctrl+c send char 3
break;
}
if let Err(err) = send.send(Result::<_, io::Error>::Ok(BytesMut::from(&buf[..n]))) {
error!("Failed send inout: {:?}", err);
break;
}
}
});
let stdin = StreamReader::new(UnboundedReceiverStream::new(recv));
let (stdout, mut recv) = tokio::io::duplex(65536);
let rt = tokio::runtime::Handle::current();
thread::spawn(move || {
let task = async move {
let _restore_terminal = guard((), move |_| {
let _ = crossterm::terminal::disable_raw_mode();
abort_rx2.lock().close();
});
let mut stdout = io::stdout().lock();
let mut buf = [0u8; 65536];
loop {
let Ok(n) = recv.read(&mut buf).await else {
break;
};
if n == 0 {
break;
}
if let Err(err) = stdout.write_all(&buf[..n]) {
error!("Failed to write to stdout: {:?}", err);
break;
};
let _ = stdout.flush();
}
};
let local = LocalSet::new();
local.spawn_local(task);
rt.block_on(local);
});
Ok(((stdin, stdout), abort_tx))
}

6
src/protocols/tcp/mod.rs Normal file
View file

@ -0,0 +1,6 @@
mod server;
pub use server::configure_socket;
pub use server::connect;
pub use server::connect_with_http_proxy;
pub use server::run_server;

302
src/protocols/tcp/server.rs Normal file
View file

@ -0,0 +1,302 @@
use anyhow::{anyhow, Context};
use std::{io, vec};
use tokio::task::JoinSet;
use base64::Engine;
use bytes::BytesMut;
use log::warn;
use socket2::{SockRef, TcpKeepalive};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use crate::protocols::dns::DnsResolver;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::time::{sleep, timeout};
use tokio_stream::wrappers::TcpListenerStream;
use tracing::log::info;
use tracing::{debug, instrument};
use url::{Host, Url};
pub fn configure_socket(socket: SockRef, so_mark: &Option<u32>) -> Result<(), anyhow::Error> {
socket
.set_nodelay(true)
.with_context(|| format!("cannot set no_delay on socket: {:?}", io::Error::last_os_error()))?;
#[cfg(not(any(target_os = "windows", target_os = "openbsd")))]
let tcp_keepalive = TcpKeepalive::new()
.with_time(Duration::from_secs(60))
.with_interval(Duration::from_secs(10))
.with_retries(3);
#[cfg(target_os = "windows")]
let tcp_keepalive = TcpKeepalive::new()
.with_time(Duration::from_secs(60))
.with_interval(Duration::from_secs(10));
#[cfg(target_os = "openbsd")]
let tcp_keepalive = TcpKeepalive::new().with_time(Duration::from_secs(60));
socket
.set_tcp_keepalive(&tcp_keepalive)
.with_context(|| format!("cannot set tcp_keepalive on socket: {:?}", io::Error::last_os_error()))?;
#[cfg(target_os = "linux")]
if let Some(so_mark) = so_mark {
socket
.set_mark(*so_mark)
.with_context(|| format!("cannot set SO_MARK on socket: {:?}", io::Error::last_os_error()))?;
}
Ok(())
}
pub async fn connect(
host: &Host<String>,
port: u16,
so_mark: Option<u32>,
connect_timeout: Duration,
dns_resolver: &DnsResolver,
) -> Result<TcpStream, anyhow::Error> {
info!("Opening TCP connection to {}:{}", host, port);
let socket_addrs: Vec<SocketAddr> = match host {
Host::Domain(domain) => dns_resolver
.lookup_host(domain.as_str(), port)
.await
.with_context(|| format!("cannot resolve domain: {}", domain))?,
Host::Ipv4(ip) => vec![SocketAddr::V4(SocketAddrV4::new(*ip, port))],
Host::Ipv6(ip) => vec![SocketAddr::V6(SocketAddrV6::new(*ip, port, 0, 0))],
};
let mut cnx = None;
let mut last_err = None;
let mut join_set = JoinSet::new();
for (ix, addr) in socket_addrs.into_iter().enumerate() {
let socket = match &addr {
SocketAddr::V4(_) => TcpSocket::new_v4(),
SocketAddr::V6(_) => TcpSocket::new_v6(),
};
let socket = match socket {
Ok(s) => s,
Err(err) => {
last_err = Some(err);
continue;
}
};
configure_socket(socket2::SockRef::from(&socket), &so_mark)?;
// Spawn the connection attempt in the join set.
// We include a delay of ix * 250 milliseconds, as per RFC8305.
// See https://datatracker.ietf.org/doc/html/rfc8305#section-5
let fut = async move {
if ix > 0 {
sleep(Duration::from_millis(250 * ix as u64)).await;
}
debug!("Connecting to {}", addr);
match timeout(connect_timeout, socket.connect(addr)).await {
Ok(Ok(s)) => Ok(Ok(s)),
Ok(Err(e)) => Ok(Err((addr, e))),
Err(e) => Err((addr, e)),
}
};
join_set.spawn(fut);
}
// Wait for the next future that finishes in the join set, until we got one
// that resulted in a successful connection.
// If cnx is no longer None, we exit the loop, since this means that we got
// a successful connection.
while let (None, Some(res)) = (&cnx, join_set.join_next().await) {
match res? {
Ok(Ok(stream)) => {
// We've got a successful connection, so we can abort all other
// ongoing attempts.
join_set.abort_all();
debug!(
"Connected to tcp endpoint {}, aborted all other connection attempts",
stream.peer_addr()?
);
cnx = Some(stream);
}
Ok(Err((addr, err))) => {
debug!("Cannot connect to tcp endpoint {addr} reason {err}");
last_err = Some(err);
}
Err((addr, _)) => {
warn!(
"Cannot connect to tcp endpoint {addr} due to timeout of {}s elapsed",
connect_timeout.as_secs()
);
}
}
}
cnx.ok_or_else(|| anyhow!("Cannot connect to tcp endpoint {}:{} reason {:?}", host, port, last_err))
}
#[instrument(level = "info", name = "http_proxy", skip_all)]
pub async fn connect_with_http_proxy(
proxy: &Url,
host: &Host<String>,
port: u16,
so_mark: Option<u32>,
connect_timeout: Duration,
dns_resolver: &DnsResolver,
) -> Result<TcpStream, anyhow::Error> {
let proxy_host = proxy.host().context("Cannot parse proxy host")?.to_owned();
let proxy_port = proxy.port_or_known_default().unwrap_or(80);
info!("Connecting to http proxy {}:{}", proxy_host, proxy_port);
let mut socket = connect(&proxy_host, proxy_port, so_mark, connect_timeout, dns_resolver).await?;
debug!("Connected to http proxy {}", socket.peer_addr().unwrap());
let authorization = if let Some((user, password)) = proxy.password().map(|p| (proxy.username(), p)) {
let user = urlencoding::decode(user).with_context(|| format!("Cannot urldecode proxy user: {}", user))?;
let password =
urlencoding::decode(password).with_context(|| format!("Cannot urldecode proxy password: {}", password))?;
let creds = base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", user, password));
format!("Proxy-Authorization: Basic {}\r\n", creds)
} else {
"".to_string()
};
let connect_request = format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n");
debug!("Sending request:\n{}", connect_request);
socket.write_all(connect_request.as_bytes()).await?;
let mut buf = BytesMut::with_capacity(1024);
loop {
let nb_bytes = tokio::time::timeout(connect_timeout, socket.read_buf(&mut buf)).await;
match nb_bytes {
Ok(Ok(0)) => {
return Err(anyhow!(
"Cannot connect to http proxy. Proxy closed the connection without returning any response"
));
}
Ok(Ok(_)) => {}
Ok(Err(err)) => {
return Err(anyhow!("Cannot connect to http proxy. {err}"));
}
Err(_) => {
return Err(anyhow!("Cannot connect to http proxy. Proxy took too long to connect"));
}
};
static END_HTTP_RESPONSE: &[u8; 4] = b"\r\n\r\n"; // It is reversed from \r\n\r\n as we reverse scan the buffer
if buf.len() > 50 * 1024
|| buf
.windows(END_HTTP_RESPONSE.len())
.any(|window| window == END_HTTP_RESPONSE)
{
break;
}
}
static OK_RESPONSE_10: &[u8] = b"HTTP/1.0 200 ";
static OK_RESPONSE_11: &[u8] = b"HTTP/1.1 200 ";
if !buf
.windows(OK_RESPONSE_10.len())
.any(|window| window == OK_RESPONSE_10 || window == OK_RESPONSE_11)
{
return Err(anyhow!(
"Cannot connect to http proxy. Proxy returned an invalid response: {}",
String::from_utf8_lossy(&buf)
));
}
debug!("Got response from proxy:\n{}", String::from_utf8_lossy(&buf));
info!("Http proxy accepted connection to remote host {}:{}", host, port);
Ok(socket)
}
pub async fn run_server(bind: SocketAddr, ip_transparent: bool) -> Result<TcpListenerStream, anyhow::Error> {
info!("Starting TCP server listening cnx on {}", bind);
let listener = TcpListener::bind(bind)
.await
.with_context(|| format!("Cannot create TCP server {:?}", bind))?;
#[cfg(target_os = "linux")]
if ip_transparent {
info!("TCP server listening in TProxy mode");
socket2::SockRef::from(&listener).set_ip_transparent(ip_transparent)?;
}
Ok(TcpListenerStream::new(listener))
}
#[cfg(test)]
mod tests {
use super::*;
use futures_util::pin_mut;
use std::net::SocketAddr;
use testcontainers::core::WaitFor;
use testcontainers::runners::AsyncRunner;
use testcontainers::{ContainerAsync, Image, ImageArgs, RunnableImage};
#[derive(Debug, Clone, Default)]
pub struct MitmProxy {}
impl ImageArgs for MitmProxy {
fn into_iterator(self) -> Box<dyn Iterator<Item = String>> {
Box::new(vec!["mitmdump".to_string()].into_iter())
}
}
impl Image for MitmProxy {
type Args = Self;
fn name(&self) -> String {
"mitmproxy/mitmproxy".to_string()
}
fn tag(&self) -> String {
"10.1.1".to_string()
}
fn ready_conditions(&self) -> Vec<WaitFor> {
vec![WaitFor::Duration {
length: Duration::from_secs(5),
}]
}
}
#[tokio::test]
async fn test_proxy_connection() {
let server_addr: SocketAddr = "[::1]:1236".parse().unwrap();
let server = TcpListener::bind(server_addr).await.unwrap();
let _mitm_proxy: ContainerAsync<MitmProxy> = RunnableImage::from(MitmProxy {})
.with_network("host".to_string())
.start()
.await
.unwrap();
let mut client = connect_with_http_proxy(
&"http://localhost:8080".parse().unwrap(),
&Host::Domain("[::1]".to_string()),
1236,
None,
Duration::from_secs(1),
&DnsResolver::System,
)
.await
.unwrap();
client.write_all(b"GET / HTTP/1.1\r\n\r\n".as_slice()).await.unwrap();
let client_srv = server.accept().await.unwrap().0;
pin_mut!(client_srv);
let mut buf = [0u8; 25];
let ret = client_srv.read(&mut buf).await;
assert!(matches!(ret, Ok(18)));
client_srv.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap();
client_srv.get_mut().shutdown().await.unwrap();
let _ = client.read(&mut buf).await.unwrap();
assert!(buf.starts_with(b"HTTP/1.1 200 OK\r\n\r\n"));
}
}

10
src/protocols/tls/mod.rs Normal file
View file

@ -0,0 +1,10 @@
mod server;
mod utils;
pub use server::connect;
pub use server::load_certificates_from_pem;
pub use server::load_private_key_from_file;
pub use server::tls_acceptor;
pub use server::tls_connector;
pub use utils::cn_from_certificate;
pub use utils::find_leaf_certificate;

205
src/protocols/tls/server.rs Normal file
View file

@ -0,0 +1,205 @@
use crate::{TlsServerConfig, WsClientConfig};
use anyhow::{anyhow, Context};
use std::fs::File;
use log::warn;
use std::io::BufReader;
use std::path::Path;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use crate::tunnel::TransportAddr;
use tokio_rustls::rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
use tokio_rustls::rustls::server::WebPkiClientVerifier;
use tokio_rustls::rustls::{ClientConfig, DigitallySignedStruct, Error, KeyLogFile, SignatureScheme};
use tokio_rustls::{rustls, TlsAcceptor, TlsConnector};
use tracing::info;
#[derive(Debug)]
struct NullVerifier;
impl ServerCertVerifier for NullVerifier {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp_response: &[u8],
_now: UnixTime,
) -> Result<ServerCertVerified, Error> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
vec![
SignatureScheme::RSA_PKCS1_SHA1,
SignatureScheme::ECDSA_SHA1_Legacy,
SignatureScheme::RSA_PKCS1_SHA256,
SignatureScheme::ECDSA_NISTP256_SHA256,
SignatureScheme::RSA_PKCS1_SHA384,
SignatureScheme::ECDSA_NISTP384_SHA384,
SignatureScheme::RSA_PKCS1_SHA512,
SignatureScheme::ECDSA_NISTP521_SHA512,
SignatureScheme::RSA_PSS_SHA256,
SignatureScheme::RSA_PSS_SHA384,
SignatureScheme::RSA_PSS_SHA512,
SignatureScheme::ED25519,
SignatureScheme::ED448,
]
}
}
pub fn load_certificates_from_pem(path: &Path) -> anyhow::Result<Vec<CertificateDer<'static>>> {
info!("Loading tls certificate from {:?}", path);
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let certs = rustls_pemfile::certs(&mut reader);
Ok(certs
.into_iter()
.filter_map(|cert| match cert {
Ok(cert) => Some(cert),
Err(err) => {
warn!("Error while parsing tls certificate: {:?}", err);
None
}
})
.collect())
}
pub fn load_private_key_from_file(path: &Path) -> anyhow::Result<PrivateKeyDer<'static>> {
info!("Loading tls private key from {:?}", path);
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let Some(private_key) = rustls_pemfile::private_key(&mut reader)? else {
return Err(anyhow!("No private key found in {path:?}"));
};
Ok(private_key)
}
pub fn tls_connector(
tls_verify_certificate: bool,
alpn_protocols: Vec<Vec<u8>>,
enable_sni: bool,
tls_client_certificate: Option<Vec<CertificateDer<'static>>>,
tls_client_key: Option<PrivateKeyDer<'static>>,
) -> anyhow::Result<TlsConnector> {
let mut root_store = rustls::RootCertStore::empty();
// Load system certificates and add them to the root store
let certs = rustls_native_certs::load_native_certs().with_context(|| "Cannot load system certificates")?;
for cert in certs {
if let Err(err) = root_store.add(cert) {
warn!("cannot load a system certificate: {:?}", err);
continue;
}
}
let config_builder = ClientConfig::builder().with_root_certificates(root_store);
let mut config = match (tls_client_certificate, tls_client_key) {
(Some(tls_client_certificate), Some(tls_client_key)) => config_builder
.with_client_auth_cert(tls_client_certificate, tls_client_key)
.with_context(|| "Error setting up mTLS")?,
_ => config_builder.with_no_client_auth(),
};
config.enable_sni = enable_sni;
config.key_log = Arc::new(KeyLogFile::new());
// To bypass certificate verification
if !tls_verify_certificate {
config.dangerous().set_certificate_verifier(Arc::new(NullVerifier));
}
config.alpn_protocols = alpn_protocols;
let tls_connector = TlsConnector::from(Arc::new(config));
Ok(tls_connector)
}
pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsAcceptor> {
let client_cert_verifier = if let Some(tls_client_ca_certificates) = &tls_cfg.tls_client_ca_certificates {
let mut root_store = rustls::RootCertStore::empty();
for tls_client_ca_certificate in tls_client_ca_certificates.lock().iter() {
root_store
.add(tls_client_ca_certificate.clone())
.with_context(|| "Failed to add mTLS client CA certificate")?;
}
WebPkiClientVerifier::builder(Arc::new(root_store))
.build()
.map_err(|err| anyhow!("Failed to build mTLS client verifier: {:?}", err))?
} else {
WebPkiClientVerifier::no_client_auth()
};
let mut config = rustls::ServerConfig::builder()
.with_client_cert_verifier(client_cert_verifier)
.with_single_cert(tls_cfg.tls_certificate.lock().clone(), tls_cfg.tls_key.lock().clone_key())
.with_context(|| "invalid tls certificate or private key")?;
config.key_log = Arc::new(KeyLogFile::new());
if let Some(alpn_protocols) = alpn_protocols {
config.alpn_protocols = alpn_protocols;
}
Ok(TlsAcceptor::from(Arc::new(config)))
}
pub async fn connect(client_cfg: &WsClientConfig, tcp_stream: TcpStream) -> anyhow::Result<TlsStream<TcpStream>> {
let sni = client_cfg.tls_server_name();
let (tls_connector, sni_disabled) = match &client_cfg.remote_addr {
TransportAddr::Wss { tls, .. } => (tls.tls_connector(), tls.tls_sni_disabled),
TransportAddr::Https { tls, .. } => (tls.tls_connector(), tls.tls_sni_disabled),
TransportAddr::Http { .. } | TransportAddr::Ws { .. } => {
return Err(anyhow!("Transport does not support TLS: {}", client_cfg.remote_addr.scheme()))
}
};
if sni_disabled {
info!(
"Doing TLS handshake without SNI with the server {}:{}",
client_cfg.remote_addr.host(),
client_cfg.remote_addr.port()
);
} else {
info!(
"Doing TLS handshake using SNI {sni:?} with the server {}:{}",
client_cfg.remote_addr.host(),
client_cfg.remote_addr.port()
);
}
let tls_stream = tls_connector.connect(sni, tcp_stream).await.with_context(|| {
format!(
"failed to do TLS handshake with the server {}:{}",
client_cfg.remote_addr.host(),
client_cfg.remote_addr.port()
)
})?;
Ok(tls_stream)
}

View file

@ -0,0 +1,27 @@
use tokio_rustls::rustls::pki_types::CertificateDer;
use x509_parser::parse_x509_certificate;
use x509_parser::prelude::X509Certificate;
/// Find a leaf certificate in a vector of certificates. It is assumed only a single leaf certificate
/// is present in the vector. The other certificates should be (intermediate) CA certificates.
pub fn find_leaf_certificate<'a>(tls_certificates: &'a [CertificateDer<'static>]) -> Option<X509Certificate<'a>> {
for tls_certificate in tls_certificates {
if let Ok((_, tls_certificate_x509)) = parse_x509_certificate(tls_certificate) {
if !tls_certificate_x509.is_ca() {
return Some(tls_certificate_x509);
}
}
}
None
}
/// Returns the common name (CN) as specified in the supplied certificate.
pub fn cn_from_certificate(tls_certificate_x509: &X509Certificate) -> Option<String> {
tls_certificate_x509
.tbs_certificate
.subject
.iter_common_name()
.flat_map(|cn| cn.as_str().ok())
.next()
.map(|cn| cn.to_string())
}

11
src/protocols/udp/mod.rs Normal file
View file

@ -0,0 +1,11 @@
mod server;
#[cfg(target_os = "linux")]
pub use server::configure_tproxy;
pub use server::connect;
#[cfg(target_os = "linux")]
pub use server::mk_send_socket_tproxy;
pub use server::run_server;
pub use server::MyUdpSocket;
pub use server::UdpStream;
pub use server::UdpStreamWriter;

658
src/protocols/udp/server.rs Normal file
View file

@ -0,0 +1,658 @@
use anyhow::{anyhow, Context};
use futures_util::{stream, Stream};
use parking_lot::RwLock;
use pin_project::{pin_project, pinned_drop};
use std::collections::HashMap;
use std::future::Future;
use std::io::{Error, ErrorKind};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::{io, task};
use tokio::task::JoinSet;
use log::warn;
use socket2::SockRef;
use std::pin::{pin, Pin};
use std::sync::{Arc, Weak};
use std::task::{ready, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::UdpSocket;
use tokio::sync::futures::Notified;
use crate::protocols::dns::DnsResolver;
use tokio::sync::Notify;
use tokio::time::{sleep, timeout, Interval};
use tracing::{debug, error, info};
use url::Host;
struct IoInner {
has_data_to_read: Notify,
has_read_data: Notify,
}
struct UdpServer {
listener: Arc<UdpSocket>,
peers: HashMap<SocketAddr, Pin<Arc<IoInner>>, ahash::RandomState>,
keys_to_delete: Arc<RwLock<Vec<SocketAddr>>>,
cnx_timeout: Option<Duration>,
}
impl UdpServer {
pub fn new(listener: UdpSocket, timeout: Option<Duration>) -> Self {
let socket = socket2::SockRef::from(&listener);
// Increase receive buffer
const BUF_SIZES: [usize; 7] = [64usize, 32usize, 16usize, 8usize, 4usize, 2usize, 1usize];
for size in BUF_SIZES.iter() {
if let Err(err) = socket.set_recv_buffer_size(size * 1024 * 1024) {
warn!("Cannot increase UDP server recv buffer to {} Mib: {}", size, err);
warn!("This is not fatal, but can lead to packet loss if you have too much throughput. You must monitor packet loss in this case");
continue;
}
if *size != BUF_SIZES[0] {
info!("Increased UDP server recv buffer to {} Mib", size);
}
break;
}
for size in BUF_SIZES.iter() {
if let Err(err) = socket.set_send_buffer_size(size * 1024 * 1024) {
warn!("Cannot increase UDP server send buffer to {} Mib: {}", size, err);
warn!("This is not fatal, but can lead to packet loss if you have too much throughput. You must monitor packet loss in this case");
continue;
}
if *size != BUF_SIZES[0] {
info!("Increased UDP server send buffer to {} Mib", size);
}
break;
}
Self {
listener: Arc::new(listener),
peers: HashMap::with_hasher(ahash::RandomState::new()),
keys_to_delete: Default::default(),
cnx_timeout: timeout,
}
}
#[inline]
pub fn clean_dead_keys(&mut self) {
let nb_key_to_delete = self.keys_to_delete.read().len();
if nb_key_to_delete == 0 {
return;
}
debug!("Cleaning {} dead udp peers", nb_key_to_delete);
let mut keys_to_delete = self.keys_to_delete.write();
for key in keys_to_delete.iter() {
self.peers.remove(key);
}
keys_to_delete.clear();
}
pub fn clone_socket(&self) -> Arc<UdpSocket> {
self.listener.clone()
}
}
#[pin_project(PinnedDrop)]
pub struct UdpStream {
recv_socket: Arc<UdpSocket>,
send_socket: Arc<UdpSocket>,
peer: SocketAddr,
#[pin]
watchdog_deadline: Option<Interval>,
data_read_before_deadline: bool,
has_been_notified: bool,
#[pin]
pending_notification: Option<Notified<'static>>,
io: Pin<Arc<IoInner>>,
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
}
#[pinned_drop]
impl PinnedDrop for UdpStream {
fn drop(self: Pin<&mut Self>) {
if let Some(keys_to_delete) = self.keys_to_delete.upgrade() {
keys_to_delete.write().push(self.peer);
}
// safety: we are dropping the notification as we extend its lifetime to 'static unsafely
// So it must be gone before we drop its parent. It should never happen but in case
let mut project = self.project();
project.pending_notification.as_mut().set(None);
project.io.has_read_data.notify_one();
}
}
impl UdpStream {
fn new(
recv_socket: Arc<UdpSocket>,
send_socket: Arc<UdpSocket>,
peer: SocketAddr,
watchdog_deadline: Option<Duration>,
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
) -> (Self, Pin<Arc<IoInner>>) {
let has_data_to_read = Notify::new();
let has_read_data = Notify::new();
let io = Arc::pin(IoInner {
has_data_to_read,
has_read_data,
});
let mut s = Self {
recv_socket,
send_socket,
peer,
watchdog_deadline: watchdog_deadline
.map(|timeout| tokio::time::interval_at(tokio::time::Instant::now() + timeout, timeout)),
data_read_before_deadline: false,
has_been_notified: false,
pending_notification: None,
io: io.clone(),
keys_to_delete,
};
let pending_notification =
unsafe { std::mem::transmute::<Notified<'_>, Notified<'static>>(s.io.has_data_to_read.notified()) };
s.pending_notification = Some(pending_notification);
(s, io)
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.send_socket.local_addr()
}
pub fn writer(&self) -> UdpStreamWriter {
UdpStreamWriter {
send_socket: self.send_socket.clone(),
peer: self.peer,
}
}
}
impl AsyncRead for UdpStream {
fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, obuf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
let mut project = self.project();
// Look that the timeout for client has not elapsed
if let Some(mut deadline) = project.watchdog_deadline.as_pin_mut() {
if deadline.poll_tick(cx).is_ready() {
if !*project.data_read_before_deadline {
return Poll::Ready(Err(Error::new(
ErrorKind::TimedOut,
format!("UDP stream timeout with {}", project.peer),
)));
};
*project.data_read_before_deadline = false;
while deadline.poll_tick(cx).is_ready() {}
}
}
if let Some(notified) = project.pending_notification.as_mut().as_pin_mut() {
ready!(notified.poll(cx));
project.pending_notification.as_mut().set(None);
}
let peer = ready!(project.recv_socket.poll_recv_from(cx, obuf))?;
debug_assert_eq!(peer, *project.peer);
*project.data_read_before_deadline = true;
// re-arm notification
let notified: Notified<'static> = unsafe { std::mem::transmute(project.io.has_data_to_read.notified()) };
project.pending_notification.as_mut().set(Some(notified));
project.pending_notification.as_pin_mut().unwrap().enable();
// Let know server that we have read data
project.io.has_read_data.notify_one();
Poll::Ready(Ok(()))
}
}
pub struct UdpStreamWriter {
send_socket: Arc<UdpSocket>,
peer: SocketAddr,
}
impl AsyncWrite for UdpStreamWriter {
fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
self.send_socket.poll_send_to(cx, buf, self.peer)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> {
self.send_socket.poll_send_ready(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
}
pub async fn run_server(
bind: SocketAddr,
timeout: Option<Duration>,
configure_listener: impl Fn(&UdpSocket) -> anyhow::Result<()>,
mk_send_socket: impl Fn(&Arc<UdpSocket>) -> anyhow::Result<Arc<UdpSocket>>,
) -> Result<impl Stream<Item = io::Result<UdpStream>>, anyhow::Error> {
info!(
"Starting UDP server listening cnx on {} with cnx timeout of {}s",
bind,
timeout.unwrap_or(Duration::from_secs(0)).as_secs()
);
let listener = UdpSocket::bind(bind)
.await
.with_context(|| format!("Cannot create UDP server {:?}", bind))?;
configure_listener(&listener)?;
let udp_server = UdpServer::new(listener, timeout);
let stream = stream::unfold(
(udp_server, None, mk_send_socket),
|(mut server, peer_with_data, mk_send_socket)| async move {
// New returned peer hasn't read its data yet, await for it.
if let Some(await_peer) = peer_with_data {
if let Some(peer) = server.peers.get(&await_peer) {
peer.has_read_data.notified().await;
}
};
loop {
server.clean_dead_keys();
let peer_addr = match server.listener.peek_sender().await {
Ok(ret) => ret,
Err(err) => {
error!("Cannot read from UDP server. Closing server: {}", err);
return None;
}
};
match server.peers.get(&peer_addr) {
Some(io) => {
io.has_data_to_read.notify_one();
io.has_read_data.notified().await;
}
None => {
info!("New UDP connection from {}", peer_addr);
let (udp_client, io) = UdpStream::new(
server.clone_socket(),
mk_send_socket(&server.listener).ok()?,
peer_addr,
server.cnx_timeout,
Arc::downgrade(&server.keys_to_delete),
);
io.has_data_to_read.notify_waiters();
server.peers.insert(peer_addr, io);
return Some((Ok(udp_client), (server, Some(peer_addr), mk_send_socket)));
}
}
}
},
);
Ok(stream)
}
#[derive(Clone)]
pub struct MyUdpSocket {
socket: Arc<UdpSocket>,
}
impl MyUdpSocket {
pub fn new(socket: Arc<UdpSocket>) -> Self {
Self { socket }
}
}
impl AsyncRead for MyUdpSocket {
fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }
.poll_recv_from(cx, buf)
.map(|x| x.map(|_| ()))
}
}
impl AsyncWrite for MyUdpSocket {
fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }.poll_send(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
}
pub async fn connect(
host: &Host<String>,
port: u16,
connect_timeout: Duration,
so_mark: Option<u32>,
dns_resolver: &DnsResolver,
) -> anyhow::Result<MyUdpSocket> {
info!("Opening UDP connection to {}:{}", host, port);
let socket_addrs: Vec<SocketAddr> = match host {
Host::Ipv4(ip) => vec![SocketAddr::V4(SocketAddrV4::new(*ip, port))],
Host::Ipv6(ip) => vec![SocketAddr::V6(SocketAddrV6::new(*ip, port, 0, 0))],
Host::Domain(domain) => dns_resolver
.lookup_host(domain.as_str(), port)
.await
.with_context(|| format!("cannot resolve domain: {}", domain))?,
};
let mut cnx = None;
let mut last_err = None;
let mut join_set = JoinSet::new();
for (ix, addr) in socket_addrs.into_iter().enumerate() {
let socket = match &addr {
SocketAddr::V4(_) => UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)).await,
SocketAddr::V6(_) => UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)).await,
};
let socket = match socket {
Ok(socket) => socket,
Err(err) => {
warn!("cannot bind udp socket {:?}", err);
continue;
}
};
#[cfg(target_os = "linux")]
if let Some(so_mark) = so_mark {
SockRef::from(&socket)
.set_mark(so_mark)
.with_context(|| format!("cannot set SO_MARK on socket: {:?}", io::Error::last_os_error()))?;
}
// Spawn the connection attempt in the join set.
// We include a delay of ix * 250 milliseconds, as per RFC8305.
// See https://datatracker.ietf.org/doc/html/rfc8305#section-5
let fut = async move {
if ix > 0 {
sleep(Duration::from_millis(250 * ix as u64)).await;
}
debug!("connecting to {}", addr);
match timeout(connect_timeout, socket.connect(addr)).await {
Ok(Ok(())) => Ok(Ok(socket)),
Ok(Err(e)) => Ok(Err((addr, e))),
Err(e) => Err((addr, e)),
}
};
join_set.spawn(fut);
}
// Wait for the next future that finishes in the join set, until we got one
// that resulted in a successful connection.
// If cnx is no longer None, we exit the loop, since this means that we got
// a successful connection.
while let (None, Some(res)) = (&cnx, join_set.join_next().await) {
match res? {
Ok(Ok(socket)) => {
// We've got a successful connection, so we can abort all other
// ongoing attempts.
join_set.abort_all();
debug!(
"Connected to udp endpoint {}, aborted all other connection attempts",
socket.peer_addr()?
);
cnx = Some(socket);
}
Ok(Err((addr, err))) => {
debug!("Cannot connect to udp endpoint {addr} reason {err}");
last_err = Some(err);
}
Err((addr, _)) => {
warn!(
"Cannot connect to udp endpoint {addr} due to timeout of {}s elapsed",
connect_timeout.as_secs()
);
}
}
}
if let Some(cnx) = cnx {
Ok(MyUdpSocket::new(Arc::new(cnx)))
} else {
Err(anyhow!("Cannot connect to udp peer {}:{} reason {:?}", host, port, last_err))
}
}
#[cfg(target_os = "linux")]
pub fn configure_tproxy(listener: &UdpSocket) -> anyhow::Result<()> {
use std::net::IpAddr;
use std::os::fd::AsFd;
socket2::SockRef::from(&listener).set_ip_transparent(true)?;
match listener.local_addr().unwrap().ip() {
IpAddr::V4(_) => {
nix::sys::socket::setsockopt(&listener.as_fd(), nix::sys::socket::sockopt::Ipv4OrigDstAddr, &true)?;
}
IpAddr::V6(_) => {
nix::sys::socket::setsockopt(&listener.as_fd(), nix::sys::socket::sockopt::Ipv6OrigDstAddr, &true)?;
}
};
Ok(())
}
#[cfg(target_os = "linux")]
#[inline]
pub fn mk_send_socket_tproxy(listener: &Arc<UdpSocket>) -> anyhow::Result<Arc<UdpSocket>> {
use nix::cmsg_space;
use nix::sys::socket::{ControlMessageOwned, RecvMsg, SockaddrIn};
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
use std::io::IoSliceMut;
use std::net::IpAddr;
use std::os::fd::AsRawFd;
let mut cmsg_space = cmsg_space!(nix::libc::sockaddr_in6);
let mut buf = [0; 8];
let mut io = [IoSliceMut::new(&mut buf)];
let msg: RecvMsg<SockaddrIn> = nix::sys::socket::recvmsg(
listener.as_raw_fd(),
&mut io,
Some(&mut cmsg_space),
nix::sys::socket::MsgFlags::MSG_PEEK,
)?;
let mut remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
for cmsg in msg.cmsgs()? {
match cmsg {
ControlMessageOwned::Ipv4OrigDstAddr(ip) => {
remote_addr = SocketAddr::new(
IpAddr::V4(Ipv4Addr::from(u32::from_be(ip.sin_addr.s_addr))),
u16::from_be(ip.sin_port),
);
}
ControlMessageOwned::Ipv6OrigDstAddr(ip) => {
remote_addr = SocketAddr::new(
IpAddr::V6(Ipv6Addr::from(u128::from_be_bytes(ip.sin6_addr.s6_addr))),
u16::from_be(ip.sin6_port),
);
}
_ => {
warn!("Unknown control message {:?}", cmsg);
}
}
}
let socket = Socket::new(Domain::for_address(remote_addr), Type::DGRAM, Some(Protocol::UDP))?;
socket.set_ip_transparent(true)?;
socket.set_reuse_address(true)?;
socket.set_reuse_port(true)?;
socket.bind(&SockAddr::from(remote_addr))?;
socket.set_nonblocking(true)?;
let socket = UdpSocket::from_std(std::net::UdpSocket::from(socket))?;
Ok(Arc::new(socket))
}
#[cfg(test)]
mod tests {
use super::*;
use futures_util::{pin_mut, StreamExt};
use tokio::io::AsyncReadExt;
use tokio::time::error::Elapsed;
use tokio::time::timeout;
#[tokio::test]
async fn test_udp_server() {
let server_addr: SocketAddr = "[::1]:1234".parse().unwrap();
let server = run_server(server_addr, None, |_| Ok(()), |l| Ok(l.clone()))
.await
.unwrap();
pin_mut!(server);
// Should timeout
let fut = timeout(Duration::from_millis(100), server.next()).await;
assert!(matches!(fut, Err(Elapsed { .. })));
// Send some data to the server
let client = UdpSocket::bind("[::1]:0").await.unwrap();
assert!(client.send_to(b"hello".as_ref(), server_addr).await.is_ok());
// Should have a new connection
let fut = timeout(Duration::from_millis(100), server.next()).await;
assert!(matches!(fut, Ok(Some(Ok(_)))));
// Should timeout again, no new client
let fut2 = timeout(Duration::from_millis(100), server.next()).await;
assert!(matches!(fut2, Err(Elapsed { .. })));
// Take the stream of data
let stream = fut.unwrap().unwrap().unwrap();
pin_mut!(stream);
let mut buf = [0u8; 25];
let ret = stream.read(&mut buf).await;
assert!(matches!(ret, Ok(5)));
assert_eq!(&buf[..6], b"hello\0");
assert!(client.send_to(b"world".as_ref(), server_addr).await.is_ok());
assert!(client.send_to(b" test".as_ref(), server_addr).await.is_ok());
// Server need to be polled to feed the stream with needed data
let _ = timeout(Duration::from_millis(100), server.next()).await;
// Udp Server should respect framing from the client and not merge the two packets
let ret = timeout(Duration::from_millis(100), stream.read(&mut buf[5..])).await;
assert!(matches!(ret, Ok(Ok(5))));
let _ = timeout(Duration::from_millis(100), server.next()).await;
let ret = timeout(Duration::from_millis(100), stream.read(&mut buf[10..])).await;
assert!(matches!(ret, Ok(Ok(5))));
assert_eq!(&buf[..16], b"helloworld test\0");
}
#[tokio::test]
async fn test_multiple_client() {
let server_addr: SocketAddr = "[::1]:1235".parse().unwrap();
let mut server = Box::pin(
run_server(server_addr, None, |_| Ok(()), |l| Ok(l.clone()))
.await
.unwrap(),
);
// Send some data to the server
let client = UdpSocket::bind("[::1]:0").await.unwrap();
assert!(client.send_to(b"aaaaa".as_ref(), server_addr).await.is_ok());
let client2 = UdpSocket::bind("[::1]:0").await.unwrap();
assert!(client2.send_to(b"bbbbb".as_ref(), server_addr).await.is_ok());
// Should have a new connection
let fut = timeout(Duration::from_millis(100), server.next()).await;
assert!(matches!(fut, Ok(Some(Ok(_)))));
// Take the stream of data
let stream = fut.unwrap().unwrap().unwrap();
pin_mut!(stream);
let mut buf = [0u8; 25];
let ret = stream.read(&mut buf).await;
assert!(matches!(ret, Ok(5)));
assert_eq!(&buf[..6], b"aaaaa\0");
// make the server make progress
let fut2 = timeout(Duration::from_millis(100), server.next()).await;
assert!(matches!(fut2, Ok(Some(Ok(_)))));
let stream2 = fut2.unwrap().unwrap().unwrap();
pin_mut!(stream2);
// let the server make progress
tokio::spawn(async move {
loop {
let _ = server.next().await;
}
});
let ret = stream2.read(&mut buf).await;
assert!(matches!(ret, Ok(5)));
assert_eq!(&buf[..6], b"bbbbb\0");
assert!(client.send_to(b"ccccc".as_ref(), server_addr).await.is_ok());
assert!(client2.send_to(b"ddddd".as_ref(), server_addr).await.is_ok());
assert!(client2.send_to(b"eeeee".as_ref(), server_addr).await.is_ok());
assert!(client.send_to(b"fffff".as_ref(), server_addr).await.is_ok());
let ret = stream.read(&mut buf).await;
assert!(matches!(ret, Ok(5)));
assert_eq!(&buf[..6], b"ccccc\0");
let ret = stream2.read(&mut buf).await;
assert!(matches!(ret, Ok(5)));
assert_eq!(&buf[..6], b"ddddd\0");
let ret = stream2.read(&mut buf).await;
assert!(matches!(ret, Ok(5)));
assert_eq!(&buf[..6], b"eeeee\0");
let ret = stream.read(&mut buf).await;
assert!(matches!(ret, Ok(5)));
assert_eq!(&buf[..6], b"fffff\0");
}
#[tokio::test]
async fn test_udp_should_timeout() {
let server_addr: SocketAddr = "[::1]:1237".parse().unwrap();
let socket_timeout = Duration::from_secs(1);
let server = run_server(server_addr, Some(socket_timeout), |_| Ok(()), |l| Ok(l.clone()))
.await
.unwrap();
pin_mut!(server);
// Send some data to the server
let client = UdpSocket::bind("[::1]:0").await.unwrap();
assert!(client.send_to(b"hello".as_ref(), server_addr).await.is_ok());
// Should have a new connection
let fut = timeout(Duration::from_millis(100), server.next()).await;
assert!(matches!(fut, Ok(Some(Ok(_)))));
// Take the stream of data
let stream = fut.unwrap().unwrap().unwrap();
pin_mut!(stream);
let mut buf = [0u8; 25];
let ret = stream.read(&mut buf).await;
assert!(matches!(ret, Ok(5)));
assert_eq!(&buf[..6], b"hello\0");
// Server need to be polled to feed the stream with need data
let _ = timeout(Duration::from_millis(100), server.next()).await;
let ret = timeout(Duration::from_millis(100), stream.read(&mut buf[5..])).await;
assert!(ret.is_err());
// Stream should be closed after the timeout
tokio::time::sleep(socket_timeout).await;
let ret = stream.read(&mut buf[5..]).await;
assert!(ret.is_err());
}
}

View file

@ -0,0 +1,4 @@
mod server;
pub use server::run_server;
pub use server::UnixListenerStream;

View file

@ -0,0 +1,58 @@
use anyhow::Context;
use futures_util::Stream;
use std::io;
use std::path::Path;
use std::pin::Pin;
use std::task::Poll;
use tokio::net::{UnixListener, UnixStream};
use tracing::log::info;
pub struct UnixListenerStream {
inner: UnixListener,
path_to_delete: bool,
}
impl UnixListenerStream {
pub const fn new(listener: UnixListener, path_to_delete: bool) -> Self {
Self {
inner: listener,
path_to_delete,
}
}
}
impl Drop for UnixListenerStream {
fn drop(&mut self) {
if self.path_to_delete {
let Ok(addr) = &self.inner.local_addr() else {
return;
};
let Some(path) = addr.as_pathname() else {
return;
};
let _ = std::fs::remove_file(path);
}
}
}
impl Stream for UnixListenerStream {
type Item = io::Result<UnixStream>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<io::Result<UnixStream>>> {
match self.inner.poll_accept(cx) {
Poll::Ready(Ok((stream, _))) => Poll::Ready(Some(Ok(stream))),
Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))),
Poll::Pending => Poll::Pending,
}
}
}
pub async fn run_server(socket_path: &Path) -> Result<UnixListenerStream, anyhow::Error> {
info!("Starting Unix socket server listening cnx on {:?}", socket_path);
let path_to_delete = !socket_path.exists();
let listener = UnixListener::bind(socket_path)
.with_context(|| format!("Cannot create Unix socket server {:?}", socket_path))?;
Ok(UnixListenerStream::new(listener, path_to_delete))
}