refacto: split into modules
This commit is contained in:
parent
6a07201de1
commit
38cb7ed5f8
35 changed files with 745 additions and 596 deletions
3
src/protocols/dns/mod.rs
Normal file
3
src/protocols/dns/mod.rs
Normal file
|
@ -0,0 +1,3 @@
|
|||
mod resolver;
|
||||
|
||||
pub use resolver::DnsResolver;
|
286
src/protocols/dns/resolver.rs
Normal file
286
src/protocols/dns/resolver.rs
Normal file
|
@ -0,0 +1,286 @@
|
|||
use crate::protocols;
|
||||
use anyhow::{anyhow, Context};
|
||||
use futures_util::{FutureExt, TryFutureExt};
|
||||
use hickory_resolver::config::{LookupIpStrategy, NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
|
||||
use hickory_resolver::name_server::{GenericConnector, RuntimeProvider, TokioRuntimeProvider};
|
||||
use hickory_resolver::proto::iocompat::AsyncIoTokioAsStd;
|
||||
use hickory_resolver::proto::TokioTime;
|
||||
use hickory_resolver::{AsyncResolver, TokioHandle};
|
||||
use log::warn;
|
||||
use std::future::Future;
|
||||
use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::{TcpStream, UdpSocket};
|
||||
use url::{Host, Url};
|
||||
|
||||
// Interleave v4 and v6 addresses as per RFC8305.
|
||||
// The first address is v6 if we have any v6 addresses.
|
||||
#[inline]
|
||||
fn sort_socket_addrs(socket_addrs: &[SocketAddr], prefer_ipv6: bool) -> impl Iterator<Item = &'_ SocketAddr> {
|
||||
let mut pick_v6 = !prefer_ipv6;
|
||||
let mut v6 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V6(_)));
|
||||
let mut v4 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V4(_)));
|
||||
std::iter::from_fn(move || {
|
||||
pick_v6 = !pick_v6;
|
||||
if pick_v6 {
|
||||
v6.next().or_else(|| v4.next())
|
||||
} else {
|
||||
v4.next().or_else(|| v6.next())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum DnsResolver {
|
||||
System,
|
||||
TrustDns {
|
||||
resolver: AsyncResolver<GenericConnector<TokioRuntimeProviderWithSoMark>>,
|
||||
prefer_ipv6: bool,
|
||||
},
|
||||
}
|
||||
|
||||
impl DnsResolver {
|
||||
pub async fn lookup_host(&self, domain: &str, port: u16) -> anyhow::Result<Vec<SocketAddr>> {
|
||||
let addrs: Vec<SocketAddr> = match self {
|
||||
Self::System => tokio::net::lookup_host(format!("{}:{}", domain, port)).await?.collect(),
|
||||
Self::TrustDns { resolver, prefer_ipv6 } => {
|
||||
let addrs: Vec<_> = resolver
|
||||
.lookup_ip(domain)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|ip| match ip {
|
||||
IpAddr::V4(ip) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
|
||||
IpAddr::V6(ip) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
|
||||
})
|
||||
.collect();
|
||||
sort_socket_addrs(&addrs, *prefer_ipv6).copied().collect()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(addrs)
|
||||
}
|
||||
|
||||
pub fn new_from_urls(
|
||||
resolvers: &[Url],
|
||||
proxy: Option<Url>,
|
||||
so_mark: Option<u32>,
|
||||
prefer_ipv6: bool,
|
||||
) -> anyhow::Result<Self> {
|
||||
fn mk_resolver(
|
||||
cfg: ResolverConfig,
|
||||
mut opts: ResolverOpts,
|
||||
proxy: Option<Url>,
|
||||
so_mark: Option<u32>,
|
||||
) -> AsyncResolver<GenericConnector<TokioRuntimeProviderWithSoMark>> {
|
||||
opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
|
||||
opts.timeout = Duration::from_secs(1);
|
||||
|
||||
// Windows end-up with too many dns resolvers, which causes a performance issue
|
||||
// https://github.com/hickory-dns/hickory-dns/issues/1968
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
opts.cache_size = 1024;
|
||||
opts.num_concurrent_reqs = cfg.name_servers().len();
|
||||
}
|
||||
|
||||
AsyncResolver::new(
|
||||
cfg,
|
||||
opts,
|
||||
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(proxy, so_mark)),
|
||||
)
|
||||
}
|
||||
|
||||
fn get_sni(resolver: &Url) -> anyhow::Result<String> {
|
||||
Ok(resolver
|
||||
.query_pairs()
|
||||
.find(|(k, _)| k == "sni")
|
||||
.with_context(|| "Missing `sni` query parameter for dns over https")?
|
||||
.1
|
||||
.to_string())
|
||||
}
|
||||
|
||||
fn url_to_ns_config(resolver: &Url) -> anyhow::Result<NameServerConfig> {
|
||||
let (protocol, port, tls_sni) = match resolver.scheme() {
|
||||
"dns" => (Protocol::Udp, resolver.port().unwrap_or(53), None),
|
||||
"dns+https" => (Protocol::Https, resolver.port().unwrap_or(443), Some(get_sni(resolver)?)),
|
||||
"dns+tls" => (Protocol::Tls, resolver.port().unwrap_or(853), Some(get_sni(resolver)?)),
|
||||
_ => return Err(anyhow!("invalid protocol for dns resolver")),
|
||||
};
|
||||
let host = resolver
|
||||
.host()
|
||||
.ok_or_else(|| anyhow!("Invalid dns resolver host: {}", resolver))?;
|
||||
let sock = match host {
|
||||
Host::Domain(host) => match Host::parse(host) {
|
||||
Ok(Host::Ipv4(ip)) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
|
||||
Ok(Host::Ipv6(ip)) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
|
||||
Ok(Host::Domain(_)) | Err(_) => {
|
||||
return Err(anyhow!("Dns resolver must be an ip address, got {}", host));
|
||||
}
|
||||
},
|
||||
Host::Ipv4(ip) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
|
||||
Host::Ipv6(ip) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
|
||||
};
|
||||
|
||||
let mut ns = NameServerConfig::new(sock, protocol);
|
||||
ns.tls_dns_name = tls_sni;
|
||||
|
||||
Ok(ns)
|
||||
}
|
||||
|
||||
// no dns resolver specified, fall-back to default one
|
||||
if resolvers.is_empty() {
|
||||
let Ok((cfg, opts)) = hickory_resolver::system_conf::read_system_conf() else {
|
||||
warn!("Fall-backing to system dns resolver. You should consider specifying a dns resolver. To avoid performance issue");
|
||||
return Ok(Self::System);
|
||||
};
|
||||
|
||||
return Ok(Self::TrustDns {
|
||||
resolver: mk_resolver(cfg, opts, proxy, so_mark),
|
||||
prefer_ipv6,
|
||||
});
|
||||
};
|
||||
|
||||
// if one is specified as system, use the default one from libc
|
||||
if resolvers.iter().any(|r| r.scheme() == "system") {
|
||||
return Ok(Self::System);
|
||||
}
|
||||
|
||||
// otherwise, use the specified resolvers
|
||||
let mut cfg = ResolverConfig::new();
|
||||
for resolver in resolvers.iter() {
|
||||
cfg.add_name_server(url_to_ns_config(resolver)?);
|
||||
}
|
||||
|
||||
Ok(Self::TrustDns {
|
||||
resolver: mk_resolver(cfg, ResolverOpts::default(), proxy, so_mark),
|
||||
prefer_ipv6,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TokioRuntimeProviderWithSoMark {
|
||||
runtime: TokioRuntimeProvider,
|
||||
proxy: Option<Arc<Url>>,
|
||||
#[cfg(target_os = "linux")]
|
||||
so_mark: Option<u32>,
|
||||
}
|
||||
|
||||
impl TokioRuntimeProviderWithSoMark {
|
||||
fn new(proxy: Option<Url>, so_mark: Option<u32>) -> Self {
|
||||
Self {
|
||||
runtime: TokioRuntimeProvider::default(),
|
||||
proxy: proxy.map(Arc::new),
|
||||
#[cfg(target_os = "linux")]
|
||||
so_mark,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RuntimeProvider for TokioRuntimeProviderWithSoMark {
|
||||
type Handle = TokioHandle;
|
||||
type Timer = TokioTime;
|
||||
type Udp = UdpSocket;
|
||||
type Tcp = AsyncIoTokioAsStd<TcpStream>;
|
||||
|
||||
#[inline]
|
||||
fn create_handle(&self) -> Self::Handle {
|
||||
self.runtime.create_handle()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn connect_tcp(&self, server_addr: SocketAddr) -> Pin<Box<dyn Send + Future<Output = std::io::Result<Self::Tcp>>>> {
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
let so_mark = None;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
let so_mark = self.so_mark;
|
||||
let proxy = self.proxy.clone();
|
||||
let socket = async move {
|
||||
let host = match server_addr.ip() {
|
||||
IpAddr::V4(addr) => Host::Ipv4(addr),
|
||||
IpAddr::V6(addr) => Host::Ipv6(addr),
|
||||
};
|
||||
|
||||
if let Some(proxy) = &proxy {
|
||||
protocols::tcp::connect_with_http_proxy(
|
||||
proxy,
|
||||
&host,
|
||||
server_addr.port(),
|
||||
so_mark,
|
||||
Duration::from_secs(10),
|
||||
&DnsResolver::System, // not going to be used as host is directly an ip address
|
||||
)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
|
||||
.map(|s| s.map(AsyncIoTokioAsStd))
|
||||
.await
|
||||
} else {
|
||||
protocols::tcp::connect(
|
||||
&host,
|
||||
server_addr.port(),
|
||||
so_mark,
|
||||
Duration::from_secs(10),
|
||||
&DnsResolver::System, // not going to be used as host is directly an ip address
|
||||
)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
|
||||
.map(|s| s.map(AsyncIoTokioAsStd))
|
||||
.await
|
||||
}
|
||||
};
|
||||
|
||||
Box::pin(socket)
|
||||
}
|
||||
|
||||
fn bind_udp(
|
||||
&self,
|
||||
local_addr: SocketAddr,
|
||||
_server_addr: SocketAddr,
|
||||
) -> Pin<Box<dyn Send + Future<Output = std::io::Result<Self::Udp>>>> {
|
||||
let socket = UdpSocket::bind(local_addr);
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
let socket = {
|
||||
use socket2::SockRef;
|
||||
|
||||
socket.map({
|
||||
let so_mark = self.so_mark;
|
||||
move |sock| {
|
||||
if let (Ok(sock), Some(so_mark)) = (&sock, so_mark) {
|
||||
SockRef::from(sock).set_mark(so_mark)?;
|
||||
}
|
||||
sock
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
Box::pin(socket)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
|
||||
#[test]
|
||||
fn test_sort_socket_addrs() {
|
||||
let addrs = [
|
||||
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
|
||||
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
|
||||
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
|
||||
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
|
||||
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
|
||||
];
|
||||
let expected = [
|
||||
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
|
||||
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
|
||||
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
|
||||
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
|
||||
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
|
||||
];
|
||||
let actual: Vec<_> = sort_socket_addrs(&addrs, true).copied().collect();
|
||||
assert_eq!(expected, *actual);
|
||||
}
|
||||
}
|
4
src/protocols/http_proxy/mod.rs
Normal file
4
src/protocols/http_proxy/mod.rs
Normal file
|
@ -0,0 +1,4 @@
|
|||
mod server;
|
||||
|
||||
pub use server::run_server;
|
||||
pub use server::HttpProxyListener;
|
145
src/protocols/http_proxy/server.rs
Normal file
145
src/protocols/http_proxy/server.rs
Normal file
|
@ -0,0 +1,145 @@
|
|||
use anyhow::Context;
|
||||
use std::future::Future;
|
||||
|
||||
use bytes::Bytes;
|
||||
use log::{debug, error};
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::pin::Pin;
|
||||
|
||||
use base64::Engine;
|
||||
use futures_util::{future, stream, Stream};
|
||||
use http_body_util::Empty;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::server::conn::http1;
|
||||
use hyper::service::service_fn;
|
||||
use hyper::{Request, Response};
|
||||
use hyper_util::rt::TokioTimer;
|
||||
use parking_lot::Mutex;
|
||||
use std::time::Duration;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tracing::log::info;
|
||||
use url::Host;
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub struct HttpProxyListener {
|
||||
listener: Pin<Box<dyn Stream<Item = anyhow::Result<(TcpStream, (Host, u16))>> + Send>>,
|
||||
}
|
||||
|
||||
impl Stream for HttpProxyListener {
|
||||
type Item = anyhow::Result<(TcpStream, (Host, u16))>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
|
||||
unsafe { self.map_unchecked_mut(|x| &mut x.listener) }.poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_request(
|
||||
credentials: &Option<String>,
|
||||
dest: &Mutex<(Host, u16)>,
|
||||
req: Request<Incoming>,
|
||||
) -> impl Future<Output = Result<Response<Empty<Bytes>>, &'static str>> {
|
||||
const PROXY_AUTHORIZATION_PREFIX: &str = "Basic ";
|
||||
let ok_response = |forward_to: (Host, u16)| -> Result<Response<Empty<Bytes>>, _> {
|
||||
*dest.lock() = forward_to;
|
||||
Ok(Response::builder().status(200).body(Empty::new()).unwrap())
|
||||
};
|
||||
fn err_response() -> Result<Response<Empty<Bytes>>, &'static str> {
|
||||
info!("Un-authorized connection to http proxy");
|
||||
Err("Un-authorized")
|
||||
}
|
||||
|
||||
if req.method() != hyper::Method::CONNECT {
|
||||
return future::ready(err_response());
|
||||
}
|
||||
|
||||
debug!("HTTP Proxy CONNECT request to {}", req.uri());
|
||||
let forward_to = (
|
||||
Host::parse(req.uri().host().unwrap_or_default()).unwrap_or(Host::Ipv4(Ipv4Addr::new(0, 0, 0, 0))),
|
||||
req.uri().port_u16().unwrap_or(443),
|
||||
);
|
||||
|
||||
let Some(token) = credentials else {
|
||||
return future::ready(ok_response(forward_to));
|
||||
};
|
||||
|
||||
let Some(auth) = req.headers().get(hyper::header::PROXY_AUTHORIZATION) else {
|
||||
return future::ready(err_response());
|
||||
};
|
||||
|
||||
let auth = auth.to_str().unwrap_or_default().trim();
|
||||
if auth.starts_with(PROXY_AUTHORIZATION_PREFIX) && &auth[PROXY_AUTHORIZATION_PREFIX.len()..] == token {
|
||||
return future::ready(ok_response(forward_to));
|
||||
}
|
||||
|
||||
future::ready(err_response())
|
||||
}
|
||||
|
||||
pub async fn run_server(
|
||||
bind: SocketAddr,
|
||||
timeout: Option<Duration>,
|
||||
credentials: Option<(String, String)>,
|
||||
) -> Result<HttpProxyListener, anyhow::Error> {
|
||||
info!(
|
||||
"Starting http proxy server listening cnx on {} with credentials {:?}",
|
||||
bind, credentials
|
||||
);
|
||||
|
||||
let listener = TcpListener::bind(bind)
|
||||
.await
|
||||
.with_context(|| format!("Cannot create TCP server {:?}", bind))?;
|
||||
|
||||
let http1 = {
|
||||
let mut builder = http1::Builder::new();
|
||||
builder
|
||||
.timer(TokioTimer::new())
|
||||
.header_read_timeout(timeout)
|
||||
.keep_alive(false);
|
||||
builder
|
||||
};
|
||||
let auth_header =
|
||||
credentials.map(|(user, pass)| base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", user, pass)));
|
||||
|
||||
let listener = stream::unfold((listener, http1, auth_header), |(listener, http1, auth_header)| async {
|
||||
loop {
|
||||
let (mut stream, _) = match listener.accept().await {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
error!("Error while accepting connection {:?}", err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let forward_to = Mutex::new((Host::Ipv4(Ipv4Addr::new(0, 0, 0, 0)), 0));
|
||||
let conn_fut = http1.serve_connection(
|
||||
hyper_util::rt::TokioIo::new(&mut stream),
|
||||
service_fn(|req| handle_request(&auth_header, &forward_to, req)),
|
||||
);
|
||||
match conn_fut.await {
|
||||
Ok(_) => return Some((Ok((stream, forward_to.into_inner())), (listener, http1, auth_header))),
|
||||
Err(err) => {
|
||||
info!("Error while serving connection: {}", err);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(HttpProxyListener {
|
||||
listener: Box::pin(listener),
|
||||
})
|
||||
}
|
||||
|
||||
//#[cfg(test)]
|
||||
//mod tests {
|
||||
// use super::*;
|
||||
// use tracing::level_filters::LevelFilter;
|
||||
//
|
||||
// #[tokio::test]
|
||||
// async fn test_run_server() {
|
||||
// tracing_subscriber::fmt()
|
||||
// .with_ansi(true)
|
||||
// .with_max_level(LevelFilter::TRACE)
|
||||
// .init();
|
||||
// let x = run_server("127.0.0.1:1212".parse().unwrap(), None, None).await;
|
||||
// }
|
||||
//}
|
9
src/protocols/mod.rs
Normal file
9
src/protocols/mod.rs
Normal file
|
@ -0,0 +1,9 @@
|
|||
pub mod dns;
|
||||
pub mod http_proxy;
|
||||
pub mod socks5;
|
||||
pub mod stdio;
|
||||
pub mod tcp;
|
||||
pub mod tls;
|
||||
pub mod udp;
|
||||
#[cfg(unix)]
|
||||
pub mod unix_sock;
|
6
src/protocols/socks5/mod.rs
Normal file
6
src/protocols/socks5/mod.rs
Normal file
|
@ -0,0 +1,6 @@
|
|||
mod tcp_server;
|
||||
mod udp_server;
|
||||
|
||||
pub use tcp_server::run_server;
|
||||
pub use tcp_server::Socks5Listener;
|
||||
pub use tcp_server::Socks5Stream;
|
274
src/protocols/socks5/tcp_server.rs
Normal file
274
src/protocols/socks5/tcp_server.rs
Normal file
|
@ -0,0 +1,274 @@
|
|||
use super::udp_server::Socks5UdpStream;
|
||||
use crate::LocalProtocol;
|
||||
use anyhow::Context;
|
||||
use fast_socks5::server::{Config, DenyAuthentication, SimpleUserPassword, Socks5Server};
|
||||
use fast_socks5::util::target_addr::TargetAddr;
|
||||
use fast_socks5::{consts, ReplyError};
|
||||
use futures_util::{stream, Stream, StreamExt};
|
||||
use std::io::{Error, IoSlice};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::pin::Pin;
|
||||
use std::task::Poll;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::select;
|
||||
use tracing::{info, warn};
|
||||
use url::Host;
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub struct Socks5Listener {
|
||||
socks_server: Pin<Box<dyn Stream<Item = anyhow::Result<(Socks5Stream, (Host, u16))>> + Send>>,
|
||||
}
|
||||
|
||||
pub enum Socks5Stream {
|
||||
Tcp(TcpStream),
|
||||
Udp(Socks5UdpStream),
|
||||
}
|
||||
|
||||
impl Socks5Stream {
|
||||
pub fn local_protocol(&self) -> LocalProtocol {
|
||||
match self {
|
||||
Self::Tcp(_) => LocalProtocol::Tcp { proxy_protocol: false }, // TODO: Implement proxy protocol
|
||||
Self::Udp(s) => LocalProtocol::Udp {
|
||||
timeout: s.watchdog_deadline.as_ref().map(|x| x.period()),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for Socks5Listener {
|
||||
type Item = anyhow::Result<(Socks5Stream, (Host, u16))>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
unsafe { self.map_unchecked_mut(|x| &mut x.socks_server) }.poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_server(
|
||||
bind: SocketAddr,
|
||||
timeout: Option<Duration>,
|
||||
credentials: Option<(String, String)>,
|
||||
) -> Result<Socks5Listener, anyhow::Error> {
|
||||
info!(
|
||||
"Starting SOCKS5 server listening cnx on {} with credentials {:?}",
|
||||
bind, credentials
|
||||
);
|
||||
|
||||
let server = Socks5Server::<DenyAuthentication>::bind(bind)
|
||||
.await
|
||||
.with_context(|| format!("Cannot create socks5 server {:?}", bind))?;
|
||||
|
||||
let mut cfg = Config::default();
|
||||
cfg = if let Some((username, password)) = credentials {
|
||||
cfg.set_allow_no_auth(false);
|
||||
cfg.with_authentication(SimpleUserPassword { username, password })
|
||||
} else {
|
||||
cfg.set_allow_no_auth(true);
|
||||
cfg
|
||||
};
|
||||
|
||||
cfg.set_dns_resolve(false);
|
||||
cfg.set_execute_command(false);
|
||||
cfg.set_udp_support(true);
|
||||
|
||||
let udp_server = super::udp_server::run_server(bind, timeout).await?;
|
||||
let server = server.with_config(cfg);
|
||||
let stream = stream::unfold((server, Box::pin(udp_server)), move |(server, mut udp_server)| async move {
|
||||
let mut acceptor = server.incoming();
|
||||
loop {
|
||||
let cnx = select! {
|
||||
biased;
|
||||
|
||||
cnx = acceptor.next() => match cnx {
|
||||
None => return None,
|
||||
Some(Err(err)) => {
|
||||
drop(acceptor);
|
||||
return Some((Err(anyhow::Error::new(err)), (server, udp_server)));
|
||||
}
|
||||
Some(Ok(cnx)) => cnx,
|
||||
},
|
||||
|
||||
// new incoming udp stream
|
||||
udp_conn = udp_server.next() => {
|
||||
drop(acceptor);
|
||||
return match udp_conn {
|
||||
Some(Ok(stream)) => {
|
||||
let dest = stream.destination();
|
||||
Some((Ok((Socks5Stream::Udp(stream), dest)), (server, udp_server)))
|
||||
}
|
||||
Some(Err(err)) => {
|
||||
Some((Err(anyhow::Error::new(err)), (server, udp_server)))
|
||||
}
|
||||
None => {
|
||||
None
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
let cnx = match cnx.upgrade_to_socks5().await {
|
||||
Ok(cnx) => cnx,
|
||||
Err(err) => {
|
||||
warn!("Rejecting socks5 cnx: {}", err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let Some(target) = cnx.target_addr() else {
|
||||
warn!("Rejecting socks5 cnx: no target addr");
|
||||
continue;
|
||||
};
|
||||
|
||||
let (host, port) = match target {
|
||||
TargetAddr::Ip(SocketAddr::V4(ip)) => (Host::Ipv4(*ip.ip()), ip.port()),
|
||||
TargetAddr::Ip(SocketAddr::V6(ip)) => (Host::Ipv6(*ip.ip()), ip.port()),
|
||||
TargetAddr::Domain(host, port) => (Host::Domain(host.clone()), *port),
|
||||
};
|
||||
|
||||
// Special case for UDP Associate where we return the bind addr of the udp server
|
||||
if matches!(cnx.cmd(), Some(fast_socks5::Socks5Command::UDPAssociate)) {
|
||||
let mut cnx = cnx.into_inner();
|
||||
let ret = cnx.write_all(&new_reply(&ReplyError::Succeeded, bind)).await;
|
||||
|
||||
if let Err(err) = ret {
|
||||
warn!("Cannot reply to socks5 udp client: {}", err);
|
||||
continue;
|
||||
}
|
||||
tokio::spawn(async move {
|
||||
let mut buf = [0u8; 8];
|
||||
loop {
|
||||
match cnx.read(&mut buf).await {
|
||||
Ok(0) => return,
|
||||
Err(_) => return,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
continue;
|
||||
};
|
||||
|
||||
let mut cnx = cnx.into_inner();
|
||||
let ret = cnx
|
||||
.write_all(&new_reply(
|
||||
&ReplyError::Succeeded,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0),
|
||||
))
|
||||
.await;
|
||||
|
||||
if let Err(err) = ret {
|
||||
warn!("Cannot reply to socks5 client: {}", err);
|
||||
continue;
|
||||
}
|
||||
|
||||
drop(acceptor);
|
||||
return Some((Ok((Socks5Stream::Tcp(cnx), (host, port))), (server, udp_server)));
|
||||
}
|
||||
});
|
||||
|
||||
let listener = Socks5Listener {
|
||||
socks_server: Box::pin(stream),
|
||||
};
|
||||
|
||||
Ok(listener)
|
||||
}
|
||||
|
||||
fn new_reply(error: &ReplyError, sock_addr: SocketAddr) -> Vec<u8> {
|
||||
let (addr_type, mut ip_oct, mut port) = match sock_addr {
|
||||
SocketAddr::V4(sock) => (
|
||||
consts::SOCKS5_ADDR_TYPE_IPV4,
|
||||
sock.ip().octets().to_vec(),
|
||||
sock.port().to_be_bytes().to_vec(),
|
||||
),
|
||||
SocketAddr::V6(sock) => (
|
||||
consts::SOCKS5_ADDR_TYPE_IPV6,
|
||||
sock.ip().octets().to_vec(),
|
||||
sock.port().to_be_bytes().to_vec(),
|
||||
),
|
||||
};
|
||||
|
||||
let mut reply = vec![
|
||||
consts::SOCKS5_VERSION,
|
||||
error.as_u8(), // transform the error into byte code
|
||||
0x00, // reserved
|
||||
addr_type, // address type (ipv4, v6, domain)
|
||||
];
|
||||
reply.append(&mut ip_oct);
|
||||
reply.append(&mut port);
|
||||
|
||||
reply
|
||||
}
|
||||
|
||||
impl Unpin for Socks5Stream {}
|
||||
impl AsyncRead for Socks5Stream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf),
|
||||
Self::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for Socks5Stream {
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||
match self.get_mut() {
|
||||
Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf),
|
||||
Self::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||
match self.get_mut() {
|
||||
Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_flush(cx),
|
||||
Self::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||
match self.get_mut() {
|
||||
Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx),
|
||||
Self::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
bufs: &[IoSlice<'_>],
|
||||
) -> Poll<Result<usize, Error>> {
|
||||
match self.get_mut() {
|
||||
Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write_vectored(cx, bufs),
|
||||
Self::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_write_vectored(cx, bufs),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
match self {
|
||||
Self::Tcp(s) => s.is_write_vectored(),
|
||||
Self::Udp(s) => s.is_write_vectored(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//#[cfg(test)]
|
||||
//mod test {
|
||||
// use super::*;
|
||||
// use futures_util::StreamExt;
|
||||
// use std::str::FromStr;
|
||||
//
|
||||
// #[tokio::test]
|
||||
// async fn socks5_server() {
|
||||
// let mut x = run_server(SocketAddr::from_str("[::]:4343").unwrap())
|
||||
// .await
|
||||
// .unwrap();
|
||||
//
|
||||
// loop {
|
||||
// let cnx = x.next().await.unwrap().unwrap();
|
||||
// eprintln!("{:?}", cnx);
|
||||
// }
|
||||
// }
|
||||
//}
|
285
src/protocols/socks5/udp_server.rs
Normal file
285
src/protocols/socks5/udp_server.rs
Normal file
|
@ -0,0 +1,285 @@
|
|||
use anyhow::Context;
|
||||
use futures_util::{stream, Stream};
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use pin_project::{pin_project, pinned_drop};
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::io::{Error, ErrorKind};
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use crate::tunnel::to_host_port;
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use fast_socks5::new_udp_header;
|
||||
use fast_socks5::util::target_addr::TargetAddr;
|
||||
use log::warn;
|
||||
use std::pin::{pin, Pin};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::task::{ready, Poll};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Interval;
|
||||
use tracing::{debug, error, info};
|
||||
use url::Host;
|
||||
|
||||
type PeerMapKey = (SocketAddr, TargetAddr);
|
||||
|
||||
struct IoInner {
|
||||
sender: mpsc::Sender<Bytes>,
|
||||
}
|
||||
struct Socks5UdpServer {
|
||||
listener: Arc<UdpSocket>,
|
||||
peers: HashMap<PeerMapKey, Pin<Arc<IoInner>>, ahash::RandomState>,
|
||||
keys_to_delete: Arc<RwLock<Vec<PeerMapKey>>>,
|
||||
cnx_timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
impl Socks5UdpServer {
|
||||
pub fn new(listener: UdpSocket, timeout: Option<Duration>) -> Self {
|
||||
let socket = socket2::SockRef::from(&listener);
|
||||
|
||||
// Increase receive buffer
|
||||
const BUF_SIZES: [usize; 7] = [64usize, 32usize, 16usize, 8usize, 4usize, 2usize, 1usize];
|
||||
for size in BUF_SIZES.iter() {
|
||||
if let Err(err) = socket.set_recv_buffer_size(size * 1024 * 1024) {
|
||||
warn!("Cannot increase UDP server recv buffer to {} Mib: {}", size, err);
|
||||
warn!("This is not fatal, but can lead to packet loss if you have too much throughput. You must monitor packet loss in this case");
|
||||
continue;
|
||||
}
|
||||
|
||||
if *size != BUF_SIZES[0] {
|
||||
info!("Increased UDP server recv buffer to {} Mib", size);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
for size in BUF_SIZES.iter() {
|
||||
if let Err(err) = socket.set_send_buffer_size(size * 1024 * 1024) {
|
||||
warn!("Cannot increase UDP server send buffer to {} Mib: {}", size, err);
|
||||
warn!("This is not fatal, but can lead to packet loss if you have too much throughput. You must monitor packet loss in this case");
|
||||
continue;
|
||||
}
|
||||
|
||||
if *size != BUF_SIZES[0] {
|
||||
info!("Increased UDP server send buffer to {} Mib", size);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
Self {
|
||||
listener: Arc::new(listener),
|
||||
peers: HashMap::with_hasher(ahash::RandomState::new()),
|
||||
keys_to_delete: Default::default(),
|
||||
cnx_timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn clean_dead_keys(&mut self) {
|
||||
let nb_key_to_delete = self.keys_to_delete.read().len();
|
||||
if nb_key_to_delete == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
debug!("Cleaning {} dead udp peers", nb_key_to_delete);
|
||||
let mut keys_to_delete = self.keys_to_delete.write();
|
||||
for key in keys_to_delete.iter() {
|
||||
self.peers.remove(key);
|
||||
}
|
||||
keys_to_delete.clear();
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project(PinnedDrop)]
|
||||
pub struct Socks5UdpStream {
|
||||
#[pin]
|
||||
recv_data: mpsc::Receiver<Bytes>,
|
||||
send_socket: Arc<UdpSocket>,
|
||||
destination: TargetAddr,
|
||||
peer: SocketAddr,
|
||||
udp_header: Vec<u8>,
|
||||
#[pin]
|
||||
pub watchdog_deadline: Option<Interval>,
|
||||
data_read_before_deadline: bool,
|
||||
io: Pin<Arc<IoInner>>,
|
||||
keys_to_delete: Weak<RwLock<Vec<PeerMapKey>>>,
|
||||
}
|
||||
|
||||
#[pinned_drop]
|
||||
impl PinnedDrop for Socks5UdpStream {
|
||||
fn drop(self: Pin<&mut Self>) {
|
||||
if let Some(keys_to_delete) = self.keys_to_delete.upgrade() {
|
||||
keys_to_delete.write().push((self.peer, self.destination.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Socks5UdpStream {
|
||||
fn new(
|
||||
send_socket: Arc<UdpSocket>,
|
||||
peer: SocketAddr,
|
||||
destination: TargetAddr,
|
||||
watchdog_deadline: Option<Duration>,
|
||||
keys_to_delete: Weak<RwLock<Vec<PeerMapKey>>>,
|
||||
) -> (Self, Pin<Arc<IoInner>>) {
|
||||
let (tx, rx) = mpsc::channel(1024);
|
||||
let io = Arc::pin(IoInner { sender: tx });
|
||||
let udp_header = match &destination {
|
||||
TargetAddr::Ip(ip) => new_udp_header(*ip).unwrap(),
|
||||
TargetAddr::Domain(h, p) => new_udp_header((h.as_str(), *p)).unwrap(),
|
||||
};
|
||||
let s = Self {
|
||||
recv_data: rx,
|
||||
send_socket,
|
||||
peer,
|
||||
destination,
|
||||
watchdog_deadline: watchdog_deadline
|
||||
.map(|timeout| tokio::time::interval_at(tokio::time::Instant::now() + timeout, timeout)),
|
||||
data_read_before_deadline: false,
|
||||
io: io.clone(),
|
||||
keys_to_delete,
|
||||
udp_header,
|
||||
};
|
||||
|
||||
(s, io)
|
||||
}
|
||||
|
||||
pub fn destination(&self) -> (Host, u16) {
|
||||
match &self.destination {
|
||||
TargetAddr::Ip(sock_addr) => to_host_port(*sock_addr),
|
||||
TargetAddr::Domain(h, p) => (Host::Domain(h.clone()), *p),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for Socks5UdpStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
obuf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let mut project = self.project();
|
||||
// Look that the timeout for client has not elapsed
|
||||
if let Some(mut deadline) = project.watchdog_deadline.as_pin_mut() {
|
||||
if deadline.poll_tick(cx).is_ready() {
|
||||
if !*project.data_read_before_deadline {
|
||||
return Poll::Ready(Err(Error::new(
|
||||
ErrorKind::TimedOut,
|
||||
format!("UDP stream timeout with {}", project.peer),
|
||||
)));
|
||||
};
|
||||
|
||||
*project.data_read_before_deadline = false;
|
||||
while deadline.poll_tick(cx).is_ready() {}
|
||||
}
|
||||
}
|
||||
|
||||
let Some(data) = ready!(project.recv_data.poll_recv(cx)) else {
|
||||
return Poll::Ready(Err(Error::from(ErrorKind::UnexpectedEof)));
|
||||
};
|
||||
if obuf.remaining() < data.len() {
|
||||
return Poll::Ready(Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"udp dst buffer does not have enough space left. Can't fragment",
|
||||
)));
|
||||
}
|
||||
|
||||
obuf.put_slice(data.chunk());
|
||||
*project.data_read_before_deadline = true;
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for Socks5UdpStream {
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||
let this = self.project();
|
||||
let header_len = this.udp_header.len();
|
||||
this.udp_header.extend_from_slice(buf);
|
||||
let ret = this
|
||||
.send_socket
|
||||
.poll_send_to(cx, this.udp_header.as_slice(), *this.peer);
|
||||
this.udp_header.truncate(header_len);
|
||||
ret.map(|r| r.map(|write_len| write_len - header_len))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||
self.send_socket.poll_send_ready(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_server(
|
||||
bind: SocketAddr,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<impl Stream<Item = io::Result<Socks5UdpStream>>, anyhow::Error> {
|
||||
let listener = UdpSocket::bind(bind)
|
||||
.await
|
||||
.with_context(|| format!("Cannot create UDP server {:?}", bind))?;
|
||||
|
||||
let udp_server = Socks5UdpServer::new(listener, timeout);
|
||||
static MAX_PACKET_LENGTH: usize = 64 * 1024;
|
||||
let buffer = BytesMut::with_capacity(MAX_PACKET_LENGTH * 10);
|
||||
let stream = stream::unfold((udp_server, buffer), |(mut server, mut buf)| async move {
|
||||
loop {
|
||||
server.clean_dead_keys();
|
||||
buf.reserve(MAX_PACKET_LENGTH);
|
||||
|
||||
let peer_addr = match server.listener.recv_buf_from(&mut buf).await {
|
||||
Ok((_read_len, peer_addr)) => peer_addr,
|
||||
Err(err) => {
|
||||
error!("Cannot read from UDP server. Closing server: {}", err);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let (destination_addr, data) = {
|
||||
let payload = buf.split().freeze();
|
||||
let (frag, destination_addr, data) = match fast_socks5::parse_udp_request(payload.chunk()).await {
|
||||
Ok((frag, addr, data)) => (frag, addr, data),
|
||||
Err(err) => {
|
||||
warn!("Skipping invalid UDP socks5 request: {} ", err);
|
||||
debug!("Invalid UDP socks5 request: {:?}", payload.chunk());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
// We don't support udp fragmentation
|
||||
if frag != 0 {
|
||||
warn!("dropping UDP socks5 fragmented");
|
||||
continue;
|
||||
}
|
||||
(destination_addr, payload.slice_ref(data))
|
||||
};
|
||||
|
||||
let addr = (peer_addr, destination_addr);
|
||||
match server.peers.get(&addr) {
|
||||
Some(io) => {
|
||||
if io.sender.send(data).await.is_err() {
|
||||
server.peers.remove(&addr);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
info!("New UDP connection for {}", addr.1);
|
||||
let (udp_client, io) = Socks5UdpStream::new(
|
||||
server.listener.clone(),
|
||||
addr.0,
|
||||
addr.1.clone(),
|
||||
server.cnx_timeout,
|
||||
Arc::downgrade(&server.keys_to_delete),
|
||||
);
|
||||
let _ = io.sender.send(data).await;
|
||||
server.peers.insert(addr, io);
|
||||
return Some((Ok(udp_client), (server, buf)));
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(stream)
|
||||
}
|
9
src/protocols/stdio/mod.rs
Normal file
9
src/protocols/stdio/mod.rs
Normal file
|
@ -0,0 +1,9 @@
|
|||
#[cfg(unix)]
|
||||
mod server_unix;
|
||||
#[cfg(not(unix))]
|
||||
mod server_windows;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use server_unix::run_server;
|
||||
#[cfg(not(unix))]
|
||||
pub use server_windows::run_server;
|
27
src/protocols/stdio/server_unix.rs
Normal file
27
src/protocols/stdio/server_unix.rs
Normal file
|
@ -0,0 +1,27 @@
|
|||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, ReadBuf};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_fd::AsyncFd;
|
||||
use tracing::info;
|
||||
|
||||
pub struct WsStdin {
|
||||
stdin: AsyncFd,
|
||||
_receiver: oneshot::Receiver<()>,
|
||||
}
|
||||
|
||||
impl AsyncRead for WsStdin {
|
||||
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
||||
unsafe { self.map_unchecked_mut(|s| &mut s.stdin) }.poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_server() -> Result<((WsStdin, AsyncFd), oneshot::Sender<()>), anyhow::Error> {
|
||||
info!("Starting STDIO server");
|
||||
|
||||
let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?;
|
||||
let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?;
|
||||
let (tx, rx) = oneshot::channel::<()>();
|
||||
|
||||
Ok(((WsStdin { stdin, _receiver: rx }, stdout), tx))
|
||||
}
|
82
src/protocols/stdio/server_windows.rs
Normal file
82
src/protocols/stdio/server_windows.rs
Normal file
|
@ -0,0 +1,82 @@
|
|||
use bytes::BytesMut;
|
||||
use log::error;
|
||||
use parking_lot::Mutex;
|
||||
use scopeguard::guard;
|
||||
use std::io::{Read, Write};
|
||||
use std::sync::Arc;
|
||||
use std::{io, thread};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task::LocalSet;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tokio_util::io::StreamReader;
|
||||
use tracing::info;
|
||||
|
||||
pub async fn run_server() -> Result<((impl AsyncRead, impl AsyncWrite), oneshot::Sender<()>), anyhow::Error> {
|
||||
info!("Starting STDIO server. Press ctrl+c twice to exit");
|
||||
|
||||
crossterm::terminal::enable_raw_mode()?;
|
||||
|
||||
let stdin = io::stdin();
|
||||
let (send, recv) = tokio::sync::mpsc::unbounded_channel();
|
||||
let (abort_tx, abort_rx) = oneshot::channel::<()>();
|
||||
let abort_rx = Arc::new(Mutex::new(abort_rx));
|
||||
let abort_rx2 = abort_rx.clone();
|
||||
thread::spawn(move || {
|
||||
let _restore_terminal = guard((), move |_| {
|
||||
let _ = crossterm::terminal::disable_raw_mode();
|
||||
abort_rx.lock().close();
|
||||
});
|
||||
let stdin = stdin;
|
||||
let mut stdin = stdin.lock();
|
||||
let mut buf = [0u8; 65536];
|
||||
|
||||
loop {
|
||||
let n = stdin.read(&mut buf).unwrap_or(0);
|
||||
if n == 0 || (n == 1 && buf[0] == 3) {
|
||||
// ctrl+c send char 3
|
||||
break;
|
||||
}
|
||||
if let Err(err) = send.send(Result::<_, io::Error>::Ok(BytesMut::from(&buf[..n]))) {
|
||||
error!("Failed send inout: {:?}", err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
let stdin = StreamReader::new(UnboundedReceiverStream::new(recv));
|
||||
|
||||
let (stdout, mut recv) = tokio::io::duplex(65536);
|
||||
let rt = tokio::runtime::Handle::current();
|
||||
thread::spawn(move || {
|
||||
let task = async move {
|
||||
let _restore_terminal = guard((), move |_| {
|
||||
let _ = crossterm::terminal::disable_raw_mode();
|
||||
abort_rx2.lock().close();
|
||||
});
|
||||
let mut stdout = io::stdout().lock();
|
||||
let mut buf = [0u8; 65536];
|
||||
loop {
|
||||
let Ok(n) = recv.read(&mut buf).await else {
|
||||
break;
|
||||
};
|
||||
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Err(err) = stdout.write_all(&buf[..n]) {
|
||||
error!("Failed to write to stdout: {:?}", err);
|
||||
break;
|
||||
};
|
||||
let _ = stdout.flush();
|
||||
}
|
||||
};
|
||||
|
||||
let local = LocalSet::new();
|
||||
local.spawn_local(task);
|
||||
|
||||
rt.block_on(local);
|
||||
});
|
||||
|
||||
Ok(((stdin, stdout), abort_tx))
|
||||
}
|
6
src/protocols/tcp/mod.rs
Normal file
6
src/protocols/tcp/mod.rs
Normal file
|
@ -0,0 +1,6 @@
|
|||
mod server;
|
||||
|
||||
pub use server::configure_socket;
|
||||
pub use server::connect;
|
||||
pub use server::connect_with_http_proxy;
|
||||
pub use server::run_server;
|
302
src/protocols/tcp/server.rs
Normal file
302
src/protocols/tcp/server.rs
Normal file
|
@ -0,0 +1,302 @@
|
|||
use anyhow::{anyhow, Context};
|
||||
use std::{io, vec};
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
use base64::Engine;
|
||||
use bytes::BytesMut;
|
||||
use log::warn;
|
||||
use socket2::{SockRef, TcpKeepalive};
|
||||
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
|
||||
use crate::protocols::dns::DnsResolver;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||
use tokio::time::{sleep, timeout};
|
||||
use tokio_stream::wrappers::TcpListenerStream;
|
||||
use tracing::log::info;
|
||||
use tracing::{debug, instrument};
|
||||
use url::{Host, Url};
|
||||
|
||||
pub fn configure_socket(socket: SockRef, so_mark: &Option<u32>) -> Result<(), anyhow::Error> {
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
.with_context(|| format!("cannot set no_delay on socket: {:?}", io::Error::last_os_error()))?;
|
||||
|
||||
#[cfg(not(any(target_os = "windows", target_os = "openbsd")))]
|
||||
let tcp_keepalive = TcpKeepalive::new()
|
||||
.with_time(Duration::from_secs(60))
|
||||
.with_interval(Duration::from_secs(10))
|
||||
.with_retries(3);
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
let tcp_keepalive = TcpKeepalive::new()
|
||||
.with_time(Duration::from_secs(60))
|
||||
.with_interval(Duration::from_secs(10));
|
||||
|
||||
#[cfg(target_os = "openbsd")]
|
||||
let tcp_keepalive = TcpKeepalive::new().with_time(Duration::from_secs(60));
|
||||
|
||||
socket
|
||||
.set_tcp_keepalive(&tcp_keepalive)
|
||||
.with_context(|| format!("cannot set tcp_keepalive on socket: {:?}", io::Error::last_os_error()))?;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
if let Some(so_mark) = so_mark {
|
||||
socket
|
||||
.set_mark(*so_mark)
|
||||
.with_context(|| format!("cannot set SO_MARK on socket: {:?}", io::Error::last_os_error()))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
host: &Host<String>,
|
||||
port: u16,
|
||||
so_mark: Option<u32>,
|
||||
connect_timeout: Duration,
|
||||
dns_resolver: &DnsResolver,
|
||||
) -> Result<TcpStream, anyhow::Error> {
|
||||
info!("Opening TCP connection to {}:{}", host, port);
|
||||
|
||||
let socket_addrs: Vec<SocketAddr> = match host {
|
||||
Host::Domain(domain) => dns_resolver
|
||||
.lookup_host(domain.as_str(), port)
|
||||
.await
|
||||
.with_context(|| format!("cannot resolve domain: {}", domain))?,
|
||||
Host::Ipv4(ip) => vec![SocketAddr::V4(SocketAddrV4::new(*ip, port))],
|
||||
Host::Ipv6(ip) => vec![SocketAddr::V6(SocketAddrV6::new(*ip, port, 0, 0))],
|
||||
};
|
||||
|
||||
let mut cnx = None;
|
||||
let mut last_err = None;
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
for (ix, addr) in socket_addrs.into_iter().enumerate() {
|
||||
let socket = match &addr {
|
||||
SocketAddr::V4(_) => TcpSocket::new_v4(),
|
||||
SocketAddr::V6(_) => TcpSocket::new_v6(),
|
||||
};
|
||||
let socket = match socket {
|
||||
Ok(s) => s,
|
||||
Err(err) => {
|
||||
last_err = Some(err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
configure_socket(socket2::SockRef::from(&socket), &so_mark)?;
|
||||
|
||||
// Spawn the connection attempt in the join set.
|
||||
// We include a delay of ix * 250 milliseconds, as per RFC8305.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc8305#section-5
|
||||
let fut = async move {
|
||||
if ix > 0 {
|
||||
sleep(Duration::from_millis(250 * ix as u64)).await;
|
||||
}
|
||||
debug!("Connecting to {}", addr);
|
||||
match timeout(connect_timeout, socket.connect(addr)).await {
|
||||
Ok(Ok(s)) => Ok(Ok(s)),
|
||||
Ok(Err(e)) => Ok(Err((addr, e))),
|
||||
Err(e) => Err((addr, e)),
|
||||
}
|
||||
};
|
||||
join_set.spawn(fut);
|
||||
}
|
||||
|
||||
// Wait for the next future that finishes in the join set, until we got one
|
||||
// that resulted in a successful connection.
|
||||
// If cnx is no longer None, we exit the loop, since this means that we got
|
||||
// a successful connection.
|
||||
while let (None, Some(res)) = (&cnx, join_set.join_next().await) {
|
||||
match res? {
|
||||
Ok(Ok(stream)) => {
|
||||
// We've got a successful connection, so we can abort all other
|
||||
// ongoing attempts.
|
||||
join_set.abort_all();
|
||||
|
||||
debug!(
|
||||
"Connected to tcp endpoint {}, aborted all other connection attempts",
|
||||
stream.peer_addr()?
|
||||
);
|
||||
cnx = Some(stream);
|
||||
}
|
||||
Ok(Err((addr, err))) => {
|
||||
debug!("Cannot connect to tcp endpoint {addr} reason {err}");
|
||||
last_err = Some(err);
|
||||
}
|
||||
Err((addr, _)) => {
|
||||
warn!(
|
||||
"Cannot connect to tcp endpoint {addr} due to timeout of {}s elapsed",
|
||||
connect_timeout.as_secs()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cnx.ok_or_else(|| anyhow!("Cannot connect to tcp endpoint {}:{} reason {:?}", host, port, last_err))
|
||||
}
|
||||
|
||||
#[instrument(level = "info", name = "http_proxy", skip_all)]
|
||||
pub async fn connect_with_http_proxy(
|
||||
proxy: &Url,
|
||||
host: &Host<String>,
|
||||
port: u16,
|
||||
so_mark: Option<u32>,
|
||||
connect_timeout: Duration,
|
||||
dns_resolver: &DnsResolver,
|
||||
) -> Result<TcpStream, anyhow::Error> {
|
||||
let proxy_host = proxy.host().context("Cannot parse proxy host")?.to_owned();
|
||||
let proxy_port = proxy.port_or_known_default().unwrap_or(80);
|
||||
|
||||
info!("Connecting to http proxy {}:{}", proxy_host, proxy_port);
|
||||
let mut socket = connect(&proxy_host, proxy_port, so_mark, connect_timeout, dns_resolver).await?;
|
||||
debug!("Connected to http proxy {}", socket.peer_addr().unwrap());
|
||||
|
||||
let authorization = if let Some((user, password)) = proxy.password().map(|p| (proxy.username(), p)) {
|
||||
let user = urlencoding::decode(user).with_context(|| format!("Cannot urldecode proxy user: {}", user))?;
|
||||
let password =
|
||||
urlencoding::decode(password).with_context(|| format!("Cannot urldecode proxy password: {}", password))?;
|
||||
let creds = base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", user, password));
|
||||
format!("Proxy-Authorization: Basic {}\r\n", creds)
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
|
||||
let connect_request = format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n");
|
||||
debug!("Sending request:\n{}", connect_request);
|
||||
socket.write_all(connect_request.as_bytes()).await?;
|
||||
|
||||
let mut buf = BytesMut::with_capacity(1024);
|
||||
loop {
|
||||
let nb_bytes = tokio::time::timeout(connect_timeout, socket.read_buf(&mut buf)).await;
|
||||
match nb_bytes {
|
||||
Ok(Ok(0)) => {
|
||||
return Err(anyhow!(
|
||||
"Cannot connect to http proxy. Proxy closed the connection without returning any response"
|
||||
));
|
||||
}
|
||||
Ok(Ok(_)) => {}
|
||||
Ok(Err(err)) => {
|
||||
return Err(anyhow!("Cannot connect to http proxy. {err}"));
|
||||
}
|
||||
Err(_) => {
|
||||
return Err(anyhow!("Cannot connect to http proxy. Proxy took too long to connect"));
|
||||
}
|
||||
};
|
||||
|
||||
static END_HTTP_RESPONSE: &[u8; 4] = b"\r\n\r\n"; // It is reversed from \r\n\r\n as we reverse scan the buffer
|
||||
if buf.len() > 50 * 1024
|
||||
|| buf
|
||||
.windows(END_HTTP_RESPONSE.len())
|
||||
.any(|window| window == END_HTTP_RESPONSE)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static OK_RESPONSE_10: &[u8] = b"HTTP/1.0 200 ";
|
||||
static OK_RESPONSE_11: &[u8] = b"HTTP/1.1 200 ";
|
||||
if !buf
|
||||
.windows(OK_RESPONSE_10.len())
|
||||
.any(|window| window == OK_RESPONSE_10 || window == OK_RESPONSE_11)
|
||||
{
|
||||
return Err(anyhow!(
|
||||
"Cannot connect to http proxy. Proxy returned an invalid response: {}",
|
||||
String::from_utf8_lossy(&buf)
|
||||
));
|
||||
}
|
||||
|
||||
debug!("Got response from proxy:\n{}", String::from_utf8_lossy(&buf));
|
||||
info!("Http proxy accepted connection to remote host {}:{}", host, port);
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub async fn run_server(bind: SocketAddr, ip_transparent: bool) -> Result<TcpListenerStream, anyhow::Error> {
|
||||
info!("Starting TCP server listening cnx on {}", bind);
|
||||
|
||||
let listener = TcpListener::bind(bind)
|
||||
.await
|
||||
.with_context(|| format!("Cannot create TCP server {:?}", bind))?;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
if ip_transparent {
|
||||
info!("TCP server listening in TProxy mode");
|
||||
socket2::SockRef::from(&listener).set_ip_transparent(ip_transparent)?;
|
||||
}
|
||||
|
||||
Ok(TcpListenerStream::new(listener))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use futures_util::pin_mut;
|
||||
use std::net::SocketAddr;
|
||||
use testcontainers::core::WaitFor;
|
||||
use testcontainers::runners::AsyncRunner;
|
||||
use testcontainers::{ContainerAsync, Image, ImageArgs, RunnableImage};
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct MitmProxy {}
|
||||
|
||||
impl ImageArgs for MitmProxy {
|
||||
fn into_iterator(self) -> Box<dyn Iterator<Item = String>> {
|
||||
Box::new(vec!["mitmdump".to_string()].into_iter())
|
||||
}
|
||||
}
|
||||
|
||||
impl Image for MitmProxy {
|
||||
type Args = Self;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"mitmproxy/mitmproxy".to_string()
|
||||
}
|
||||
|
||||
fn tag(&self) -> String {
|
||||
"10.1.1".to_string()
|
||||
}
|
||||
|
||||
fn ready_conditions(&self) -> Vec<WaitFor> {
|
||||
vec![WaitFor::Duration {
|
||||
length: Duration::from_secs(5),
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_proxy_connection() {
|
||||
let server_addr: SocketAddr = "[::1]:1236".parse().unwrap();
|
||||
let server = TcpListener::bind(server_addr).await.unwrap();
|
||||
|
||||
let _mitm_proxy: ContainerAsync<MitmProxy> = RunnableImage::from(MitmProxy {})
|
||||
.with_network("host".to_string())
|
||||
.start()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut client = connect_with_http_proxy(
|
||||
&"http://localhost:8080".parse().unwrap(),
|
||||
&Host::Domain("[::1]".to_string()),
|
||||
1236,
|
||||
None,
|
||||
Duration::from_secs(1),
|
||||
&DnsResolver::System,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
client.write_all(b"GET / HTTP/1.1\r\n\r\n".as_slice()).await.unwrap();
|
||||
let client_srv = server.accept().await.unwrap().0;
|
||||
pin_mut!(client_srv);
|
||||
|
||||
let mut buf = [0u8; 25];
|
||||
let ret = client_srv.read(&mut buf).await;
|
||||
assert!(matches!(ret, Ok(18)));
|
||||
client_srv.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap();
|
||||
|
||||
client_srv.get_mut().shutdown().await.unwrap();
|
||||
let _ = client.read(&mut buf).await.unwrap();
|
||||
assert!(buf.starts_with(b"HTTP/1.1 200 OK\r\n\r\n"));
|
||||
}
|
||||
}
|
10
src/protocols/tls/mod.rs
Normal file
10
src/protocols/tls/mod.rs
Normal file
|
@ -0,0 +1,10 @@
|
|||
mod server;
|
||||
mod utils;
|
||||
|
||||
pub use server::connect;
|
||||
pub use server::load_certificates_from_pem;
|
||||
pub use server::load_private_key_from_file;
|
||||
pub use server::tls_acceptor;
|
||||
pub use server::tls_connector;
|
||||
pub use utils::cn_from_certificate;
|
||||
pub use utils::find_leaf_certificate;
|
205
src/protocols/tls/server.rs
Normal file
205
src/protocols/tls/server.rs
Normal file
|
@ -0,0 +1,205 @@
|
|||
use crate::{TlsServerConfig, WsClientConfig};
|
||||
use anyhow::{anyhow, Context};
|
||||
use std::fs::File;
|
||||
|
||||
use log::warn;
|
||||
use std::io::BufReader;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::client::TlsStream;
|
||||
|
||||
use crate::tunnel::TransportAddr;
|
||||
use tokio_rustls::rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
|
||||
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
|
||||
use tokio_rustls::rustls::server::WebPkiClientVerifier;
|
||||
use tokio_rustls::rustls::{ClientConfig, DigitallySignedStruct, Error, KeyLogFile, SignatureScheme};
|
||||
use tokio_rustls::{rustls, TlsAcceptor, TlsConnector};
|
||||
use tracing::info;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NullVerifier;
|
||||
|
||||
impl ServerCertVerifier for NullVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: UnixTime,
|
||||
) -> Result<ServerCertVerified, Error> {
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &DigitallySignedStruct,
|
||||
) -> Result<HandshakeSignatureValid, Error> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &DigitallySignedStruct,
|
||||
) -> Result<HandshakeSignatureValid, Error> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
|
||||
vec![
|
||||
SignatureScheme::RSA_PKCS1_SHA1,
|
||||
SignatureScheme::ECDSA_SHA1_Legacy,
|
||||
SignatureScheme::RSA_PKCS1_SHA256,
|
||||
SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
SignatureScheme::RSA_PKCS1_SHA384,
|
||||
SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
SignatureScheme::RSA_PKCS1_SHA512,
|
||||
SignatureScheme::ECDSA_NISTP521_SHA512,
|
||||
SignatureScheme::RSA_PSS_SHA256,
|
||||
SignatureScheme::RSA_PSS_SHA384,
|
||||
SignatureScheme::RSA_PSS_SHA512,
|
||||
SignatureScheme::ED25519,
|
||||
SignatureScheme::ED448,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_certificates_from_pem(path: &Path) -> anyhow::Result<Vec<CertificateDer<'static>>> {
|
||||
info!("Loading tls certificate from {:?}", path);
|
||||
|
||||
let file = File::open(path)?;
|
||||
let mut reader = BufReader::new(file);
|
||||
let certs = rustls_pemfile::certs(&mut reader);
|
||||
|
||||
Ok(certs
|
||||
.into_iter()
|
||||
.filter_map(|cert| match cert {
|
||||
Ok(cert) => Some(cert),
|
||||
Err(err) => {
|
||||
warn!("Error while parsing tls certificate: {:?}", err);
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub fn load_private_key_from_file(path: &Path) -> anyhow::Result<PrivateKeyDer<'static>> {
|
||||
info!("Loading tls private key from {:?}", path);
|
||||
|
||||
let file = File::open(path)?;
|
||||
let mut reader = BufReader::new(file);
|
||||
|
||||
let Some(private_key) = rustls_pemfile::private_key(&mut reader)? else {
|
||||
return Err(anyhow!("No private key found in {path:?}"));
|
||||
};
|
||||
|
||||
Ok(private_key)
|
||||
}
|
||||
|
||||
pub fn tls_connector(
|
||||
tls_verify_certificate: bool,
|
||||
alpn_protocols: Vec<Vec<u8>>,
|
||||
enable_sni: bool,
|
||||
tls_client_certificate: Option<Vec<CertificateDer<'static>>>,
|
||||
tls_client_key: Option<PrivateKeyDer<'static>>,
|
||||
) -> anyhow::Result<TlsConnector> {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
|
||||
// Load system certificates and add them to the root store
|
||||
let certs = rustls_native_certs::load_native_certs().with_context(|| "Cannot load system certificates")?;
|
||||
for cert in certs {
|
||||
if let Err(err) = root_store.add(cert) {
|
||||
warn!("cannot load a system certificate: {:?}", err);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let config_builder = ClientConfig::builder().with_root_certificates(root_store);
|
||||
|
||||
let mut config = match (tls_client_certificate, tls_client_key) {
|
||||
(Some(tls_client_certificate), Some(tls_client_key)) => config_builder
|
||||
.with_client_auth_cert(tls_client_certificate, tls_client_key)
|
||||
.with_context(|| "Error setting up mTLS")?,
|
||||
_ => config_builder.with_no_client_auth(),
|
||||
};
|
||||
|
||||
config.enable_sni = enable_sni;
|
||||
config.key_log = Arc::new(KeyLogFile::new());
|
||||
|
||||
// To bypass certificate verification
|
||||
if !tls_verify_certificate {
|
||||
config.dangerous().set_certificate_verifier(Arc::new(NullVerifier));
|
||||
}
|
||||
|
||||
config.alpn_protocols = alpn_protocols;
|
||||
let tls_connector = TlsConnector::from(Arc::new(config));
|
||||
Ok(tls_connector)
|
||||
}
|
||||
|
||||
pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsAcceptor> {
|
||||
let client_cert_verifier = if let Some(tls_client_ca_certificates) = &tls_cfg.tls_client_ca_certificates {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
for tls_client_ca_certificate in tls_client_ca_certificates.lock().iter() {
|
||||
root_store
|
||||
.add(tls_client_ca_certificate.clone())
|
||||
.with_context(|| "Failed to add mTLS client CA certificate")?;
|
||||
}
|
||||
|
||||
WebPkiClientVerifier::builder(Arc::new(root_store))
|
||||
.build()
|
||||
.map_err(|err| anyhow!("Failed to build mTLS client verifier: {:?}", err))?
|
||||
} else {
|
||||
WebPkiClientVerifier::no_client_auth()
|
||||
};
|
||||
|
||||
let mut config = rustls::ServerConfig::builder()
|
||||
.with_client_cert_verifier(client_cert_verifier)
|
||||
.with_single_cert(tls_cfg.tls_certificate.lock().clone(), tls_cfg.tls_key.lock().clone_key())
|
||||
.with_context(|| "invalid tls certificate or private key")?;
|
||||
|
||||
config.key_log = Arc::new(KeyLogFile::new());
|
||||
if let Some(alpn_protocols) = alpn_protocols {
|
||||
config.alpn_protocols = alpn_protocols;
|
||||
}
|
||||
Ok(TlsAcceptor::from(Arc::new(config)))
|
||||
}
|
||||
|
||||
pub async fn connect(client_cfg: &WsClientConfig, tcp_stream: TcpStream) -> anyhow::Result<TlsStream<TcpStream>> {
|
||||
let sni = client_cfg.tls_server_name();
|
||||
let (tls_connector, sni_disabled) = match &client_cfg.remote_addr {
|
||||
TransportAddr::Wss { tls, .. } => (tls.tls_connector(), tls.tls_sni_disabled),
|
||||
TransportAddr::Https { tls, .. } => (tls.tls_connector(), tls.tls_sni_disabled),
|
||||
TransportAddr::Http { .. } | TransportAddr::Ws { .. } => {
|
||||
return Err(anyhow!("Transport does not support TLS: {}", client_cfg.remote_addr.scheme()))
|
||||
}
|
||||
};
|
||||
|
||||
if sni_disabled {
|
||||
info!(
|
||||
"Doing TLS handshake without SNI with the server {}:{}",
|
||||
client_cfg.remote_addr.host(),
|
||||
client_cfg.remote_addr.port()
|
||||
);
|
||||
} else {
|
||||
info!(
|
||||
"Doing TLS handshake using SNI {sni:?} with the server {}:{}",
|
||||
client_cfg.remote_addr.host(),
|
||||
client_cfg.remote_addr.port()
|
||||
);
|
||||
}
|
||||
|
||||
let tls_stream = tls_connector.connect(sni, tcp_stream).await.with_context(|| {
|
||||
format!(
|
||||
"failed to do TLS handshake with the server {}:{}",
|
||||
client_cfg.remote_addr.host(),
|
||||
client_cfg.remote_addr.port()
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(tls_stream)
|
||||
}
|
27
src/protocols/tls/utils.rs
Normal file
27
src/protocols/tls/utils.rs
Normal file
|
@ -0,0 +1,27 @@
|
|||
use tokio_rustls::rustls::pki_types::CertificateDer;
|
||||
use x509_parser::parse_x509_certificate;
|
||||
use x509_parser::prelude::X509Certificate;
|
||||
|
||||
/// Find a leaf certificate in a vector of certificates. It is assumed only a single leaf certificate
|
||||
/// is present in the vector. The other certificates should be (intermediate) CA certificates.
|
||||
pub fn find_leaf_certificate<'a>(tls_certificates: &'a [CertificateDer<'static>]) -> Option<X509Certificate<'a>> {
|
||||
for tls_certificate in tls_certificates {
|
||||
if let Ok((_, tls_certificate_x509)) = parse_x509_certificate(tls_certificate) {
|
||||
if !tls_certificate_x509.is_ca() {
|
||||
return Some(tls_certificate_x509);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns the common name (CN) as specified in the supplied certificate.
|
||||
pub fn cn_from_certificate(tls_certificate_x509: &X509Certificate) -> Option<String> {
|
||||
tls_certificate_x509
|
||||
.tbs_certificate
|
||||
.subject
|
||||
.iter_common_name()
|
||||
.flat_map(|cn| cn.as_str().ok())
|
||||
.next()
|
||||
.map(|cn| cn.to_string())
|
||||
}
|
11
src/protocols/udp/mod.rs
Normal file
11
src/protocols/udp/mod.rs
Normal file
|
@ -0,0 +1,11 @@
|
|||
mod server;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
pub use server::configure_tproxy;
|
||||
pub use server::connect;
|
||||
#[cfg(target_os = "linux")]
|
||||
pub use server::mk_send_socket_tproxy;
|
||||
pub use server::run_server;
|
||||
pub use server::MyUdpSocket;
|
||||
pub use server::UdpStream;
|
||||
pub use server::UdpStreamWriter;
|
658
src/protocols/udp/server.rs
Normal file
658
src/protocols/udp/server.rs
Normal file
|
@ -0,0 +1,658 @@
|
|||
use anyhow::{anyhow, Context};
|
||||
use futures_util::{stream, Stream};
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use pin_project::{pin_project, pinned_drop};
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::io::{Error, ErrorKind};
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
use std::{io, task};
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
use log::warn;
|
||||
use socket2::SockRef;
|
||||
use std::pin::{pin, Pin};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::task::{ready, Poll};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::futures::Notified;
|
||||
|
||||
use crate::protocols::dns::DnsResolver;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::time::{sleep, timeout, Interval};
|
||||
use tracing::{debug, error, info};
|
||||
use url::Host;
|
||||
|
||||
struct IoInner {
|
||||
has_data_to_read: Notify,
|
||||
has_read_data: Notify,
|
||||
}
|
||||
struct UdpServer {
|
||||
listener: Arc<UdpSocket>,
|
||||
peers: HashMap<SocketAddr, Pin<Arc<IoInner>>, ahash::RandomState>,
|
||||
keys_to_delete: Arc<RwLock<Vec<SocketAddr>>>,
|
||||
cnx_timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
impl UdpServer {
|
||||
pub fn new(listener: UdpSocket, timeout: Option<Duration>) -> Self {
|
||||
let socket = socket2::SockRef::from(&listener);
|
||||
|
||||
// Increase receive buffer
|
||||
const BUF_SIZES: [usize; 7] = [64usize, 32usize, 16usize, 8usize, 4usize, 2usize, 1usize];
|
||||
for size in BUF_SIZES.iter() {
|
||||
if let Err(err) = socket.set_recv_buffer_size(size * 1024 * 1024) {
|
||||
warn!("Cannot increase UDP server recv buffer to {} Mib: {}", size, err);
|
||||
warn!("This is not fatal, but can lead to packet loss if you have too much throughput. You must monitor packet loss in this case");
|
||||
continue;
|
||||
}
|
||||
|
||||
if *size != BUF_SIZES[0] {
|
||||
info!("Increased UDP server recv buffer to {} Mib", size);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
for size in BUF_SIZES.iter() {
|
||||
if let Err(err) = socket.set_send_buffer_size(size * 1024 * 1024) {
|
||||
warn!("Cannot increase UDP server send buffer to {} Mib: {}", size, err);
|
||||
warn!("This is not fatal, but can lead to packet loss if you have too much throughput. You must monitor packet loss in this case");
|
||||
continue;
|
||||
}
|
||||
|
||||
if *size != BUF_SIZES[0] {
|
||||
info!("Increased UDP server send buffer to {} Mib", size);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
Self {
|
||||
listener: Arc::new(listener),
|
||||
peers: HashMap::with_hasher(ahash::RandomState::new()),
|
||||
keys_to_delete: Default::default(),
|
||||
cnx_timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn clean_dead_keys(&mut self) {
|
||||
let nb_key_to_delete = self.keys_to_delete.read().len();
|
||||
if nb_key_to_delete == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
debug!("Cleaning {} dead udp peers", nb_key_to_delete);
|
||||
let mut keys_to_delete = self.keys_to_delete.write();
|
||||
for key in keys_to_delete.iter() {
|
||||
self.peers.remove(key);
|
||||
}
|
||||
keys_to_delete.clear();
|
||||
}
|
||||
pub fn clone_socket(&self) -> Arc<UdpSocket> {
|
||||
self.listener.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project(PinnedDrop)]
|
||||
pub struct UdpStream {
|
||||
recv_socket: Arc<UdpSocket>,
|
||||
send_socket: Arc<UdpSocket>,
|
||||
peer: SocketAddr,
|
||||
#[pin]
|
||||
watchdog_deadline: Option<Interval>,
|
||||
data_read_before_deadline: bool,
|
||||
has_been_notified: bool,
|
||||
#[pin]
|
||||
pending_notification: Option<Notified<'static>>,
|
||||
io: Pin<Arc<IoInner>>,
|
||||
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
|
||||
}
|
||||
|
||||
#[pinned_drop]
|
||||
impl PinnedDrop for UdpStream {
|
||||
fn drop(self: Pin<&mut Self>) {
|
||||
if let Some(keys_to_delete) = self.keys_to_delete.upgrade() {
|
||||
keys_to_delete.write().push(self.peer);
|
||||
}
|
||||
|
||||
// safety: we are dropping the notification as we extend its lifetime to 'static unsafely
|
||||
// So it must be gone before we drop its parent. It should never happen but in case
|
||||
let mut project = self.project();
|
||||
project.pending_notification.as_mut().set(None);
|
||||
project.io.has_read_data.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpStream {
|
||||
fn new(
|
||||
recv_socket: Arc<UdpSocket>,
|
||||
send_socket: Arc<UdpSocket>,
|
||||
peer: SocketAddr,
|
||||
watchdog_deadline: Option<Duration>,
|
||||
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
|
||||
) -> (Self, Pin<Arc<IoInner>>) {
|
||||
let has_data_to_read = Notify::new();
|
||||
let has_read_data = Notify::new();
|
||||
let io = Arc::pin(IoInner {
|
||||
has_data_to_read,
|
||||
has_read_data,
|
||||
});
|
||||
let mut s = Self {
|
||||
recv_socket,
|
||||
send_socket,
|
||||
peer,
|
||||
watchdog_deadline: watchdog_deadline
|
||||
.map(|timeout| tokio::time::interval_at(tokio::time::Instant::now() + timeout, timeout)),
|
||||
data_read_before_deadline: false,
|
||||
has_been_notified: false,
|
||||
pending_notification: None,
|
||||
io: io.clone(),
|
||||
keys_to_delete,
|
||||
};
|
||||
|
||||
let pending_notification =
|
||||
unsafe { std::mem::transmute::<Notified<'_>, Notified<'static>>(s.io.has_data_to_read.notified()) };
|
||||
s.pending_notification = Some(pending_notification);
|
||||
|
||||
(s, io)
|
||||
}
|
||||
|
||||
pub fn local_addr(&self) -> io::Result<SocketAddr> {
|
||||
self.send_socket.local_addr()
|
||||
}
|
||||
pub fn writer(&self) -> UdpStreamWriter {
|
||||
UdpStreamWriter {
|
||||
send_socket: self.send_socket.clone(),
|
||||
peer: self.peer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for UdpStream {
|
||||
fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, obuf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
|
||||
let mut project = self.project();
|
||||
// Look that the timeout for client has not elapsed
|
||||
if let Some(mut deadline) = project.watchdog_deadline.as_pin_mut() {
|
||||
if deadline.poll_tick(cx).is_ready() {
|
||||
if !*project.data_read_before_deadline {
|
||||
return Poll::Ready(Err(Error::new(
|
||||
ErrorKind::TimedOut,
|
||||
format!("UDP stream timeout with {}", project.peer),
|
||||
)));
|
||||
};
|
||||
|
||||
*project.data_read_before_deadline = false;
|
||||
while deadline.poll_tick(cx).is_ready() {}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(notified) = project.pending_notification.as_mut().as_pin_mut() {
|
||||
ready!(notified.poll(cx));
|
||||
project.pending_notification.as_mut().set(None);
|
||||
}
|
||||
|
||||
let peer = ready!(project.recv_socket.poll_recv_from(cx, obuf))?;
|
||||
debug_assert_eq!(peer, *project.peer);
|
||||
*project.data_read_before_deadline = true;
|
||||
|
||||
// re-arm notification
|
||||
let notified: Notified<'static> = unsafe { std::mem::transmute(project.io.has_data_to_read.notified()) };
|
||||
project.pending_notification.as_mut().set(Some(notified));
|
||||
project.pending_notification.as_pin_mut().unwrap().enable();
|
||||
|
||||
// Let know server that we have read data
|
||||
project.io.has_read_data.notify_one();
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UdpStreamWriter {
|
||||
send_socket: Arc<UdpSocket>,
|
||||
peer: SocketAddr,
|
||||
}
|
||||
|
||||
impl AsyncWrite for UdpStreamWriter {
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||
self.send_socket.poll_send_to(cx, buf, self.peer)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||
self.send_socket.poll_send_ready(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_server(
|
||||
bind: SocketAddr,
|
||||
timeout: Option<Duration>,
|
||||
configure_listener: impl Fn(&UdpSocket) -> anyhow::Result<()>,
|
||||
mk_send_socket: impl Fn(&Arc<UdpSocket>) -> anyhow::Result<Arc<UdpSocket>>,
|
||||
) -> Result<impl Stream<Item = io::Result<UdpStream>>, anyhow::Error> {
|
||||
info!(
|
||||
"Starting UDP server listening cnx on {} with cnx timeout of {}s",
|
||||
bind,
|
||||
timeout.unwrap_or(Duration::from_secs(0)).as_secs()
|
||||
);
|
||||
|
||||
let listener = UdpSocket::bind(bind)
|
||||
.await
|
||||
.with_context(|| format!("Cannot create UDP server {:?}", bind))?;
|
||||
configure_listener(&listener)?;
|
||||
|
||||
let udp_server = UdpServer::new(listener, timeout);
|
||||
let stream = stream::unfold(
|
||||
(udp_server, None, mk_send_socket),
|
||||
|(mut server, peer_with_data, mk_send_socket)| async move {
|
||||
// New returned peer hasn't read its data yet, await for it.
|
||||
if let Some(await_peer) = peer_with_data {
|
||||
if let Some(peer) = server.peers.get(&await_peer) {
|
||||
peer.has_read_data.notified().await;
|
||||
}
|
||||
};
|
||||
|
||||
loop {
|
||||
server.clean_dead_keys();
|
||||
let peer_addr = match server.listener.peek_sender().await {
|
||||
Ok(ret) => ret,
|
||||
Err(err) => {
|
||||
error!("Cannot read from UDP server. Closing server: {}", err);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
match server.peers.get(&peer_addr) {
|
||||
Some(io) => {
|
||||
io.has_data_to_read.notify_one();
|
||||
io.has_read_data.notified().await;
|
||||
}
|
||||
None => {
|
||||
info!("New UDP connection from {}", peer_addr);
|
||||
let (udp_client, io) = UdpStream::new(
|
||||
server.clone_socket(),
|
||||
mk_send_socket(&server.listener).ok()?,
|
||||
peer_addr,
|
||||
server.cnx_timeout,
|
||||
Arc::downgrade(&server.keys_to_delete),
|
||||
);
|
||||
io.has_data_to_read.notify_waiters();
|
||||
server.peers.insert(peer_addr, io);
|
||||
return Some((Ok(udp_client), (server, Some(peer_addr), mk_send_socket)));
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MyUdpSocket {
|
||||
socket: Arc<UdpSocket>,
|
||||
}
|
||||
|
||||
impl MyUdpSocket {
|
||||
pub fn new(socket: Arc<UdpSocket>) -> Self {
|
||||
Self { socket }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for MyUdpSocket {
|
||||
fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
|
||||
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }
|
||||
.poll_recv_from(cx, buf)
|
||||
.map(|x| x.map(|_| ()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for MyUdpSocket {
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }.poll_send(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
host: &Host<String>,
|
||||
port: u16,
|
||||
connect_timeout: Duration,
|
||||
so_mark: Option<u32>,
|
||||
dns_resolver: &DnsResolver,
|
||||
) -> anyhow::Result<MyUdpSocket> {
|
||||
info!("Opening UDP connection to {}:{}", host, port);
|
||||
|
||||
let socket_addrs: Vec<SocketAddr> = match host {
|
||||
Host::Ipv4(ip) => vec![SocketAddr::V4(SocketAddrV4::new(*ip, port))],
|
||||
Host::Ipv6(ip) => vec![SocketAddr::V6(SocketAddrV6::new(*ip, port, 0, 0))],
|
||||
Host::Domain(domain) => dns_resolver
|
||||
.lookup_host(domain.as_str(), port)
|
||||
.await
|
||||
.with_context(|| format!("cannot resolve domain: {}", domain))?,
|
||||
};
|
||||
|
||||
let mut cnx = None;
|
||||
let mut last_err = None;
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
for (ix, addr) in socket_addrs.into_iter().enumerate() {
|
||||
let socket = match &addr {
|
||||
SocketAddr::V4(_) => UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)).await,
|
||||
SocketAddr::V6(_) => UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)).await,
|
||||
};
|
||||
|
||||
let socket = match socket {
|
||||
Ok(socket) => socket,
|
||||
Err(err) => {
|
||||
warn!("cannot bind udp socket {:?}", err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
if let Some(so_mark) = so_mark {
|
||||
SockRef::from(&socket)
|
||||
.set_mark(so_mark)
|
||||
.with_context(|| format!("cannot set SO_MARK on socket: {:?}", io::Error::last_os_error()))?;
|
||||
}
|
||||
|
||||
// Spawn the connection attempt in the join set.
|
||||
// We include a delay of ix * 250 milliseconds, as per RFC8305.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc8305#section-5
|
||||
let fut = async move {
|
||||
if ix > 0 {
|
||||
sleep(Duration::from_millis(250 * ix as u64)).await;
|
||||
}
|
||||
|
||||
debug!("connecting to {}", addr);
|
||||
match timeout(connect_timeout, socket.connect(addr)).await {
|
||||
Ok(Ok(())) => Ok(Ok(socket)),
|
||||
Ok(Err(e)) => Ok(Err((addr, e))),
|
||||
Err(e) => Err((addr, e)),
|
||||
}
|
||||
};
|
||||
join_set.spawn(fut);
|
||||
}
|
||||
|
||||
// Wait for the next future that finishes in the join set, until we got one
|
||||
// that resulted in a successful connection.
|
||||
// If cnx is no longer None, we exit the loop, since this means that we got
|
||||
// a successful connection.
|
||||
while let (None, Some(res)) = (&cnx, join_set.join_next().await) {
|
||||
match res? {
|
||||
Ok(Ok(socket)) => {
|
||||
// We've got a successful connection, so we can abort all other
|
||||
// ongoing attempts.
|
||||
join_set.abort_all();
|
||||
|
||||
debug!(
|
||||
"Connected to udp endpoint {}, aborted all other connection attempts",
|
||||
socket.peer_addr()?
|
||||
);
|
||||
cnx = Some(socket);
|
||||
}
|
||||
Ok(Err((addr, err))) => {
|
||||
debug!("Cannot connect to udp endpoint {addr} reason {err}");
|
||||
last_err = Some(err);
|
||||
}
|
||||
Err((addr, _)) => {
|
||||
warn!(
|
||||
"Cannot connect to udp endpoint {addr} due to timeout of {}s elapsed",
|
||||
connect_timeout.as_secs()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(cnx) = cnx {
|
||||
Ok(MyUdpSocket::new(Arc::new(cnx)))
|
||||
} else {
|
||||
Err(anyhow!("Cannot connect to udp peer {}:{} reason {:?}", host, port, last_err))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
pub fn configure_tproxy(listener: &UdpSocket) -> anyhow::Result<()> {
|
||||
use std::net::IpAddr;
|
||||
use std::os::fd::AsFd;
|
||||
|
||||
socket2::SockRef::from(&listener).set_ip_transparent(true)?;
|
||||
match listener.local_addr().unwrap().ip() {
|
||||
IpAddr::V4(_) => {
|
||||
nix::sys::socket::setsockopt(&listener.as_fd(), nix::sys::socket::sockopt::Ipv4OrigDstAddr, &true)?;
|
||||
}
|
||||
IpAddr::V6(_) => {
|
||||
nix::sys::socket::setsockopt(&listener.as_fd(), nix::sys::socket::sockopt::Ipv6OrigDstAddr, &true)?;
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
#[inline]
|
||||
pub fn mk_send_socket_tproxy(listener: &Arc<UdpSocket>) -> anyhow::Result<Arc<UdpSocket>> {
|
||||
use nix::cmsg_space;
|
||||
use nix::sys::socket::{ControlMessageOwned, RecvMsg, SockaddrIn};
|
||||
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
|
||||
use std::io::IoSliceMut;
|
||||
use std::net::IpAddr;
|
||||
use std::os::fd::AsRawFd;
|
||||
|
||||
let mut cmsg_space = cmsg_space!(nix::libc::sockaddr_in6);
|
||||
let mut buf = [0; 8];
|
||||
let mut io = [IoSliceMut::new(&mut buf)];
|
||||
let msg: RecvMsg<SockaddrIn> = nix::sys::socket::recvmsg(
|
||||
listener.as_raw_fd(),
|
||||
&mut io,
|
||||
Some(&mut cmsg_space),
|
||||
nix::sys::socket::MsgFlags::MSG_PEEK,
|
||||
)?;
|
||||
|
||||
let mut remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
|
||||
for cmsg in msg.cmsgs()? {
|
||||
match cmsg {
|
||||
ControlMessageOwned::Ipv4OrigDstAddr(ip) => {
|
||||
remote_addr = SocketAddr::new(
|
||||
IpAddr::V4(Ipv4Addr::from(u32::from_be(ip.sin_addr.s_addr))),
|
||||
u16::from_be(ip.sin_port),
|
||||
);
|
||||
}
|
||||
ControlMessageOwned::Ipv6OrigDstAddr(ip) => {
|
||||
remote_addr = SocketAddr::new(
|
||||
IpAddr::V6(Ipv6Addr::from(u128::from_be_bytes(ip.sin6_addr.s6_addr))),
|
||||
u16::from_be(ip.sin6_port),
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
warn!("Unknown control message {:?}", cmsg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let socket = Socket::new(Domain::for_address(remote_addr), Type::DGRAM, Some(Protocol::UDP))?;
|
||||
socket.set_ip_transparent(true)?;
|
||||
socket.set_reuse_address(true)?;
|
||||
socket.set_reuse_port(true)?;
|
||||
socket.bind(&SockAddr::from(remote_addr))?;
|
||||
socket.set_nonblocking(true)?;
|
||||
let socket = UdpSocket::from_std(std::net::UdpSocket::from(socket))?;
|
||||
|
||||
Ok(Arc::new(socket))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use futures_util::{pin_mut, StreamExt};
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::time::error::Elapsed;
|
||||
use tokio::time::timeout;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_udp_server() {
|
||||
let server_addr: SocketAddr = "[::1]:1234".parse().unwrap();
|
||||
let server = run_server(server_addr, None, |_| Ok(()), |l| Ok(l.clone()))
|
||||
.await
|
||||
.unwrap();
|
||||
pin_mut!(server);
|
||||
|
||||
// Should timeout
|
||||
let fut = timeout(Duration::from_millis(100), server.next()).await;
|
||||
assert!(matches!(fut, Err(Elapsed { .. })));
|
||||
|
||||
// Send some data to the server
|
||||
let client = UdpSocket::bind("[::1]:0").await.unwrap();
|
||||
assert!(client.send_to(b"hello".as_ref(), server_addr).await.is_ok());
|
||||
|
||||
// Should have a new connection
|
||||
let fut = timeout(Duration::from_millis(100), server.next()).await;
|
||||
assert!(matches!(fut, Ok(Some(Ok(_)))));
|
||||
|
||||
// Should timeout again, no new client
|
||||
let fut2 = timeout(Duration::from_millis(100), server.next()).await;
|
||||
assert!(matches!(fut2, Err(Elapsed { .. })));
|
||||
|
||||
// Take the stream of data
|
||||
let stream = fut.unwrap().unwrap().unwrap();
|
||||
pin_mut!(stream);
|
||||
|
||||
let mut buf = [0u8; 25];
|
||||
let ret = stream.read(&mut buf).await;
|
||||
assert!(matches!(ret, Ok(5)));
|
||||
assert_eq!(&buf[..6], b"hello\0");
|
||||
|
||||
assert!(client.send_to(b"world".as_ref(), server_addr).await.is_ok());
|
||||
assert!(client.send_to(b" test".as_ref(), server_addr).await.is_ok());
|
||||
|
||||
// Server need to be polled to feed the stream with needed data
|
||||
let _ = timeout(Duration::from_millis(100), server.next()).await;
|
||||
// Udp Server should respect framing from the client and not merge the two packets
|
||||
let ret = timeout(Duration::from_millis(100), stream.read(&mut buf[5..])).await;
|
||||
assert!(matches!(ret, Ok(Ok(5))));
|
||||
|
||||
let _ = timeout(Duration::from_millis(100), server.next()).await;
|
||||
let ret = timeout(Duration::from_millis(100), stream.read(&mut buf[10..])).await;
|
||||
assert!(matches!(ret, Ok(Ok(5))));
|
||||
assert_eq!(&buf[..16], b"helloworld test\0");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_client() {
|
||||
let server_addr: SocketAddr = "[::1]:1235".parse().unwrap();
|
||||
let mut server = Box::pin(
|
||||
run_server(server_addr, None, |_| Ok(()), |l| Ok(l.clone()))
|
||||
.await
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
// Send some data to the server
|
||||
let client = UdpSocket::bind("[::1]:0").await.unwrap();
|
||||
assert!(client.send_to(b"aaaaa".as_ref(), server_addr).await.is_ok());
|
||||
|
||||
let client2 = UdpSocket::bind("[::1]:0").await.unwrap();
|
||||
assert!(client2.send_to(b"bbbbb".as_ref(), server_addr).await.is_ok());
|
||||
|
||||
// Should have a new connection
|
||||
let fut = timeout(Duration::from_millis(100), server.next()).await;
|
||||
assert!(matches!(fut, Ok(Some(Ok(_)))));
|
||||
|
||||
// Take the stream of data
|
||||
let stream = fut.unwrap().unwrap().unwrap();
|
||||
pin_mut!(stream);
|
||||
|
||||
let mut buf = [0u8; 25];
|
||||
let ret = stream.read(&mut buf).await;
|
||||
assert!(matches!(ret, Ok(5)));
|
||||
assert_eq!(&buf[..6], b"aaaaa\0");
|
||||
|
||||
// make the server make progress
|
||||
let fut2 = timeout(Duration::from_millis(100), server.next()).await;
|
||||
assert!(matches!(fut2, Ok(Some(Ok(_)))));
|
||||
|
||||
let stream2 = fut2.unwrap().unwrap().unwrap();
|
||||
pin_mut!(stream2);
|
||||
|
||||
// let the server make progress
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let _ = server.next().await;
|
||||
}
|
||||
});
|
||||
|
||||
let ret = stream2.read(&mut buf).await;
|
||||
assert!(matches!(ret, Ok(5)));
|
||||
assert_eq!(&buf[..6], b"bbbbb\0");
|
||||
|
||||
assert!(client.send_to(b"ccccc".as_ref(), server_addr).await.is_ok());
|
||||
assert!(client2.send_to(b"ddddd".as_ref(), server_addr).await.is_ok());
|
||||
assert!(client2.send_to(b"eeeee".as_ref(), server_addr).await.is_ok());
|
||||
assert!(client.send_to(b"fffff".as_ref(), server_addr).await.is_ok());
|
||||
|
||||
let ret = stream.read(&mut buf).await;
|
||||
assert!(matches!(ret, Ok(5)));
|
||||
assert_eq!(&buf[..6], b"ccccc\0");
|
||||
|
||||
let ret = stream2.read(&mut buf).await;
|
||||
assert!(matches!(ret, Ok(5)));
|
||||
assert_eq!(&buf[..6], b"ddddd\0");
|
||||
|
||||
let ret = stream2.read(&mut buf).await;
|
||||
assert!(matches!(ret, Ok(5)));
|
||||
assert_eq!(&buf[..6], b"eeeee\0");
|
||||
|
||||
let ret = stream.read(&mut buf).await;
|
||||
assert!(matches!(ret, Ok(5)));
|
||||
assert_eq!(&buf[..6], b"fffff\0");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_udp_should_timeout() {
|
||||
let server_addr: SocketAddr = "[::1]:1237".parse().unwrap();
|
||||
let socket_timeout = Duration::from_secs(1);
|
||||
let server = run_server(server_addr, Some(socket_timeout), |_| Ok(()), |l| Ok(l.clone()))
|
||||
.await
|
||||
.unwrap();
|
||||
pin_mut!(server);
|
||||
|
||||
// Send some data to the server
|
||||
let client = UdpSocket::bind("[::1]:0").await.unwrap();
|
||||
assert!(client.send_to(b"hello".as_ref(), server_addr).await.is_ok());
|
||||
|
||||
// Should have a new connection
|
||||
let fut = timeout(Duration::from_millis(100), server.next()).await;
|
||||
assert!(matches!(fut, Ok(Some(Ok(_)))));
|
||||
|
||||
// Take the stream of data
|
||||
let stream = fut.unwrap().unwrap().unwrap();
|
||||
pin_mut!(stream);
|
||||
|
||||
let mut buf = [0u8; 25];
|
||||
let ret = stream.read(&mut buf).await;
|
||||
assert!(matches!(ret, Ok(5)));
|
||||
assert_eq!(&buf[..6], b"hello\0");
|
||||
|
||||
// Server need to be polled to feed the stream with need data
|
||||
let _ = timeout(Duration::from_millis(100), server.next()).await;
|
||||
let ret = timeout(Duration::from_millis(100), stream.read(&mut buf[5..])).await;
|
||||
assert!(ret.is_err());
|
||||
|
||||
// Stream should be closed after the timeout
|
||||
tokio::time::sleep(socket_timeout).await;
|
||||
let ret = stream.read(&mut buf[5..]).await;
|
||||
assert!(ret.is_err());
|
||||
}
|
||||
}
|
4
src/protocols/unix_sock/mod.rs
Normal file
4
src/protocols/unix_sock/mod.rs
Normal file
|
@ -0,0 +1,4 @@
|
|||
mod server;
|
||||
|
||||
pub use server::run_server;
|
||||
pub use server::UnixListenerStream;
|
58
src/protocols/unix_sock/server.rs
Normal file
58
src/protocols/unix_sock/server.rs
Normal file
|
@ -0,0 +1,58 @@
|
|||
use anyhow::Context;
|
||||
use futures_util::Stream;
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
use std::pin::Pin;
|
||||
use std::task::Poll;
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
use tracing::log::info;
|
||||
|
||||
pub struct UnixListenerStream {
|
||||
inner: UnixListener,
|
||||
path_to_delete: bool,
|
||||
}
|
||||
|
||||
impl UnixListenerStream {
|
||||
pub const fn new(listener: UnixListener, path_to_delete: bool) -> Self {
|
||||
Self {
|
||||
inner: listener,
|
||||
path_to_delete,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for UnixListenerStream {
|
||||
fn drop(&mut self) {
|
||||
if self.path_to_delete {
|
||||
let Ok(addr) = &self.inner.local_addr() else {
|
||||
return;
|
||||
};
|
||||
let Some(path) = addr.as_pathname() else {
|
||||
return;
|
||||
};
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for UnixListenerStream {
|
||||
type Item = io::Result<UnixStream>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<io::Result<UnixStream>>> {
|
||||
match self.inner.poll_accept(cx) {
|
||||
Poll::Ready(Ok((stream, _))) => Poll::Ready(Some(Ok(stream))),
|
||||
Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_server(socket_path: &Path) -> Result<UnixListenerStream, anyhow::Error> {
|
||||
info!("Starting Unix socket server listening cnx on {:?}", socket_path);
|
||||
|
||||
let path_to_delete = !socket_path.exists();
|
||||
let listener = UnixListener::bind(socket_path)
|
||||
.with_context(|| format!("Cannot create Unix socket server {:?}", socket_path))?;
|
||||
|
||||
Ok(UnixListenerStream::new(listener, path_to_delete))
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue