feat(reverse-tunnel): allow multiple waiters & auto-shutdown

This commit is contained in:
Σrebe - Romain GERARD 2024-08-04 11:19:26 +02:00
parent 811a1e6adf
commit a468428791
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
9 changed files with 244 additions and 165 deletions

View file

@ -9,7 +9,7 @@ use crate::restrictions::types::RestrictionsRules;
use crate::tunnel::client::{TlsClientConfig, WsClient, WsClientConfig};
use crate::tunnel::connectors::{Socks5TunnelConnector, TcpTunnelConnector, UdpTunnelConnector};
use crate::tunnel::listeners::{
new_stdio_listener, new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener,
new_stdio_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener, UdpTunnelListener,
};
use crate::tunnel::server::{TlsServerConfig, WsServer, WsServerConfig};
use crate::tunnel::{to_host_port, LocalProtocol, RemoteAddr, TransportAddr, TransportScheme};
@ -1004,7 +1004,7 @@ async fn main() -> anyhow::Result<()> {
panic!("Transparent proxy is not available for non Linux platform")
}
LocalProtocol::Udp { timeout } => {
let server = new_udp_listener(tunnel.local, tunnel.remote.clone(), *timeout).await?;
let server = UdpTunnelListener::new(tunnel.local, tunnel.remote.clone(), *timeout).await?;
tokio::spawn(async move {
if let Err(err) = client.run_tunnel(server).await {

View file

@ -18,7 +18,7 @@ pub use http_proxy::HttpProxyTunnelListener;
pub use socks5::Socks5TunnelListener;
pub use stdio::new_stdio_listener;
pub use tcp::TcpTunnelListener;
pub use udp::new_udp_listener;
pub use udp::UdpTunnelListener;
#[cfg(unix)]
pub use unix_sock::UnixTunnelListener;
@ -30,7 +30,6 @@ use tokio_stream::Stream;
pub trait TunnelListener: Stream<Item = anyhow::Result<((Self::Reader, Self::Writer), RemoteAddr)>> {
type Reader: AsyncRead + Send + 'static;
type Writer: AsyncWrite + Send + 'static;
type OkReturn; // = ((Self::Reader, Self::Writer), RemoteAddr);
}
impl<T, R, W> TunnelListener for T
@ -41,5 +40,4 @@ where
{
type Reader = R;
type Writer = W;
type OkReturn = ((R, W), RemoteAddr);
}

View file

@ -10,35 +10,31 @@ use std::time::Duration;
use tokio_stream::Stream;
use url::Host;
pub struct UdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
listener: S,
pub struct UdpTunnelListener {
listener: Pin<Box<dyn Stream<Item = io::Result<UdpStream>> + Send>>,
dest: (Host, u16),
timeout: Option<Duration>,
}
pub async fn new_udp_listener(
bind_addr: SocketAddr,
dest: (Host, u16),
timeout: Option<Duration>,
) -> anyhow::Result<UdpTunnelListener<impl Stream<Item = io::Result<UdpStream>>>> {
let listener = udp::run_server(bind_addr, timeout, |_| Ok(()), |s| Ok(s.clone()))
.await
.with_context(|| anyhow!("Cannot start UDP server on {}", bind_addr))?;
impl UdpTunnelListener {
pub async fn new(
bind_addr: SocketAddr,
dest: (Host, u16),
timeout: Option<Duration>,
) -> anyhow::Result<UdpTunnelListener> {
let listener = udp::run_server(bind_addr, timeout, |_| Ok(()), |s| Ok(s.clone()))
.await
.with_context(|| anyhow!("Cannot start UDP server on {}", bind_addr))?;
Ok(UdpTunnelListener {
listener,
dest,
timeout,
})
Ok(UdpTunnelListener {
listener: Box::pin(listener),
dest,
timeout,
})
}
}
impl<S> Stream for UdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
impl Stream for UdpTunnelListener {
type Item = anyhow::Result<((UdpStream, UdpStreamWriter), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {

View file

@ -1,6 +1,7 @@
#![allow(clippy::module_inception)]
mod handler_http2;
mod handler_websocket;
mod reverse_tunnel;
mod server;
mod utils;

View file

@ -0,0 +1,118 @@
use crate::tunnel::listeners::TunnelListener;
use crate::tunnel::RemoteAddr;
use ahash::{HashMap, HashMapExt};
use anyhow::anyhow;
use futures_util::{pin_mut, StreamExt};
use log::warn;
use parking_lot::Mutex;
use std::future::Future;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::{select, time};
use tracing::{info, Instrument, Span};
struct ReverseTunnelItem<T: TunnelListener> {
#[allow(clippy::type_complexity)]
receiver: async_channel::Receiver<((<T as TunnelListener>::Reader, <T as TunnelListener>::Writer), RemoteAddr)>,
nb_seen_clients: Arc<AtomicUsize>,
}
impl<T: TunnelListener> Clone for ReverseTunnelItem<T> {
fn clone(&self) -> Self {
Self {
receiver: self.receiver.clone(),
nb_seen_clients: self.nb_seen_clients.clone(),
}
}
}
pub struct ReverseTunnelServer<T: TunnelListener> {
servers: Arc<Mutex<HashMap<SocketAddr, ReverseTunnelItem<T>>>>,
}
impl<T: TunnelListener> ReverseTunnelServer<T> {
pub fn new() -> Self {
Self {
servers: Arc::new(Mutex::new(HashMap::with_capacity(1))),
}
}
pub async fn run_listening_server(
&self,
bind_addr: SocketAddr,
gen_listening_server: impl Future<Output = anyhow::Result<T>>,
) -> anyhow::Result<((<T as TunnelListener>::Reader, <T as TunnelListener>::Writer), RemoteAddr)>
where
T: TunnelListener + Send + 'static,
{
let listening_server = self.servers.lock().get(&bind_addr).cloned();
let item = if let Some(listening_server) = listening_server {
listening_server
} else {
let listening_server = gen_listening_server.await?;
let send_timeout = Duration::from_secs(60 * 3);
let (tx, rx) = async_channel::bounded(10);
let nb_seen_clients = Arc::new(AtomicUsize::new(0));
let seen_clients = nb_seen_clients.clone();
let server = self.servers.clone();
let local_srv2 = bind_addr;
let fut = async move {
scopeguard::defer!({
server.lock().remove(&local_srv2);
});
let mut timer = time::interval(send_timeout);
pin_mut!(listening_server);
loop {
select! {
biased;
cnx = listening_server.next() => {
match cnx {
None => break,
Some(Err(err)) => {
warn!("Error while listening for incoming connections {err:?}");
continue;
}
Some(Ok(cnx)) => {
if time::timeout(send_timeout, tx.send(cnx)).await.is_err() {
info!("New reverse connection failed to be picked by client after {}s. Closing reverse tunnel server", send_timeout.as_secs());
break;
}
}
}
},
_ = timer.tick() => {
// if no client connected to the reverse tunnel server, close it
// <= 1 because the server itself has a receiver
if seen_clients.swap(0, Ordering::Relaxed) == 0 && tx.receiver_count() <= 1 {
info!("No client connected to reverse tunnel server for {}s. Closing reverse tunnel server", send_timeout.as_secs());
break;
}
},
}
}
info!("Stopping listening reverse server");
};
tokio::spawn(fut.instrument(Span::current()));
let item = ReverseTunnelItem {
receiver: rx,
nb_seen_clients,
};
self.servers.lock().insert(bind_addr, item.clone());
item
};
item.nb_seen_clients.fetch_add(1, Ordering::Relaxed);
let cnx = item
.receiver
.recv()
.await
.map_err(|_| anyhow!("listening reverse server stopped"))?;
Ok(cnx)
}
}

View file

@ -1,15 +1,12 @@
use ahash::{HashMap, HashMapExt};
use anyhow::anyhow;
use futures_util::{pin_mut, FutureExt, StreamExt};
use futures_util::FutureExt;
use http_body_util::Either;
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use std::net::SocketAddr;
use std::ops::Deref;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
@ -28,27 +25,25 @@ use socket2::SockRef;
use crate::protocols::dns::DnsResolver;
use crate::protocols::tls;
use crate::protocols::udp::{UdpStream, UdpStreamWriter};
use crate::restrictions::config_reloader::RestrictionsRulesReloader;
use crate::restrictions::types::{RestrictionConfig, RestrictionsRules};
use crate::tunnel::connectors::{TcpTunnelConnector, TunnelConnector, UdpTunnelConnector};
use crate::tunnel::listeners::{
new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener, TunnelListener,
};
use crate::tunnel::listeners::{HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener, UdpTunnelListener};
use crate::tunnel::server::handler_http2::http_server_upgrade;
use crate::tunnel::server::handler_websocket::ws_server_upgrade;
use crate::tunnel::server::reverse_tunnel::ReverseTunnelServer;
use crate::tunnel::server::utils::{
bad_request, extract_path_prefix, extract_tunnel_info, extract_x_forwarded_for, find_mapped_port, validate_tunnel,
bad_request, extract_path_prefix, extract_tunnel_info, extract_x_forwarded_for, find_mapped_port, try_to_sock_aadr,
validate_tunnel,
};
use crate::tunnel::tls_reloader::TlsReloader;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::select;
use tokio::sync::mpsc;
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio_rustls::TlsAcceptor;
use tracing::{error, info, span, warn, Instrument, Level, Span};
use url::{Host, Url};
use url::Url;
#[derive(Debug)]
pub struct TlsServerConfig {
@ -214,85 +209,59 @@ impl WsServer {
Ok((remote, Box::pin(rx), Box::pin(tx)))
}
LocalProtocol::ReverseTcp => {
type Item = <TcpTunnelListener as TunnelListener>::OkReturn;
#[allow(clippy::type_complexity)]
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<Item>>>> =
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
static SERVERS: Lazy<ReverseTunnelServer<TcpTunnelListener>> = Lazy::new(ReverseTunnelServer::new);
let remote_port = find_mapped_port(remote.port, restriction);
let local_srv = (remote.host, remote_port);
let listening_server = async {
let bind = format!("{}:{}", local_srv.0, local_srv.1);
TcpTunnelListener::new(bind.parse()?, local_srv.clone(), false).await
};
let ((local_rx, local_tx), remote) =
run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
let bind = try_to_sock_aadr(local_srv.clone())?;
let listening_server = async { TcpTunnelListener::new(bind, local_srv.clone(), false).await };
let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?;
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
}
LocalProtocol::ReverseUdp { timeout } => {
type Item = ((UdpStream, UdpStreamWriter), RemoteAddr);
#[allow(clippy::type_complexity)]
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<Item>>>> =
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
static SERVERS: Lazy<ReverseTunnelServer<UdpTunnelListener>> = Lazy::new(ReverseTunnelServer::new);
let remote_port = find_mapped_port(remote.port, restriction);
let local_srv = (remote.host, remote_port);
let listening_server = async {
let bind = format!("{}:{}", local_srv.0, local_srv.1);
new_udp_listener(bind.parse()?, local_srv.clone(), timeout).await
};
let ((local_rx, local_tx), remote) =
run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
let bind = try_to_sock_aadr(local_srv.clone())?;
let listening_server = async { UdpTunnelListener::new(bind, local_srv.clone(), timeout).await };
let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?;
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
}
LocalProtocol::ReverseSocks5 { timeout, credentials } => {
type Item = <Socks5TunnelListener as TunnelListener>::OkReturn;
#[allow(clippy::type_complexity)]
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<Item>>>> =
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
static SERVERS: Lazy<ReverseTunnelServer<Socks5TunnelListener>> = Lazy::new(ReverseTunnelServer::new);
let remote_port = find_mapped_port(remote.port, restriction);
let local_srv = (remote.host, remote_port);
let listening_server = async {
let bind = format!("{}:{}", local_srv.0, local_srv.1);
Socks5TunnelListener::new(bind.parse()?, timeout, credentials).await
};
let ((local_rx, local_tx), remote) =
run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
let bind = try_to_sock_aadr(local_srv.clone())?;
let listening_server = async { Socks5TunnelListener::new(bind, timeout, credentials).await };
let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?;
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
}
LocalProtocol::ReverseHttpProxy { timeout, credentials } => {
type Item = <HttpProxyTunnelListener as TunnelListener>::OkReturn;
#[allow(clippy::type_complexity)]
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<Item>>>> =
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
static SERVERS: Lazy<ReverseTunnelServer<HttpProxyTunnelListener>> =
Lazy::new(ReverseTunnelServer::new);
let remote_port = find_mapped_port(remote.port, restriction);
let local_srv = (remote.host, remote_port);
let listening_server = async {
let bind = format!("{}:{}", local_srv.0, local_srv.1);
HttpProxyTunnelListener::new(bind.parse()?, timeout, credentials, false).await
};
let ((local_rx, local_tx), remote) =
run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
let bind = try_to_sock_aadr(local_srv.clone())?;
let listening_server = async { HttpProxyTunnelListener::new(bind, timeout, credentials, false).await };
let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?;
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
}
#[cfg(unix)]
LocalProtocol::ReverseUnix { ref path } => {
use crate::tunnel::listeners::UnixTunnelListener;
type Item = <UnixTunnelListener as TunnelListener>::OkReturn;
#[allow(clippy::type_complexity)]
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<Item>>>> =
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
static SERVERS: Lazy<ReverseTunnelServer<UnixTunnelListener>> = Lazy::new(ReverseTunnelServer::new);
let remote_port = find_mapped_port(remote.port, restriction);
let local_srv = (remote.host, remote_port);
let bind = try_to_sock_aadr(local_srv.clone())?;
let listening_server = async { UnixTunnelListener::new(path, local_srv.clone(), false).await };
let ((local_rx, local_tx), remote) =
run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?;
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
}
@ -551,65 +520,3 @@ impl TlsContext<'_> {
&self.tls_acceptor
}
}
#[allow(clippy::type_complexity)]
async fn run_listening_server<T>(
local_srv: &(Host, u16),
servers: &Mutex<
HashMap<
(Host<String>, u16),
mpsc::Receiver<((<T as TunnelListener>::Reader, <T as TunnelListener>::Writer), RemoteAddr)>,
>,
>,
gen_listening_server: impl Future<Output = anyhow::Result<T>>,
) -> anyhow::Result<((<T as TunnelListener>::Reader, <T as TunnelListener>::Writer), RemoteAddr)>
where
T: TunnelListener + Send + 'static,
{
let listening_server = servers.lock().remove(local_srv);
let mut listening_server = if let Some(listening_server) = listening_server {
listening_server
} else {
let listening_server = gen_listening_server.await?;
let send_timeout = Duration::from_secs(60 * 3);
let (tx, rx) = mpsc::channel(1);
let fut = async move {
pin_mut!(listening_server);
loop {
select! {
biased;
cnx = listening_server.next() => {
match cnx {
None => break,
Some(Err(err)) => {
warn!("Error while listening for incoming connections {err:?}");
continue;
}
Some(Ok(cnx)) => {
if tx.send_timeout(cnx, send_timeout).await.is_err() {
info!("New reverse connection failed to be picked by client after {}s. Closing reverse tunnel server", send_timeout.as_secs());
break;
}
}
}
},
_ = tx.closed() => {
break;
}
}
}
info!("Stopping listening reverse server");
};
tokio::spawn(fut.instrument(Span::current()));
rx
};
let cnx = listening_server
.recv()
.await
.ok_or_else(|| anyhow!("listening reverse server stopped"))?;
servers.lock().insert(local_srv.clone(), listening_server);
Ok(cnx)
}

View file

@ -10,7 +10,7 @@ use hyper::header::{HeaderValue, COOKIE, SEC_WEBSOCKET_PROTOCOL};
use hyper::{http, Request, Response, StatusCode};
use jsonwebtoken::TokenData;
use std::cmp::min;
use std::net::IpAddr;
use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::ops::Deref;
use tracing::{error, info, warn};
use url::Host;
@ -218,3 +218,11 @@ pub(super) fn inject_cookie(response: &mut http::Response<impl Body>, remote_add
Ok(())
}
pub fn try_to_sock_aadr((host, port): (Host, u16)) -> anyhow::Result<SocketAddr> {
match host {
Host::Domain(_) => Err(anyhow::anyhow!("Cannot convert domain to socket address")),
Host::Ipv4(ip) => Ok(SocketAddr::V4(SocketAddrV4::new(ip, port))),
Host::Ipv6(ip) => Ok(SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0))),
}
}