From dc4eadb8f948e90d1f94e90110a7550eca7ae02a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Thu, 11 Jan 2024 09:19:32 +0100 Subject: [PATCH] Support proxy protocol for tcp connection --- src/main.rs | 22 +++++++----- src/socks5.rs | 2 +- src/tunnel/mod.rs | 27 ++++++++++----- src/tunnel/server.rs | 79 ++++++++++++++++++++++++++------------------ 4 files changed, 80 insertions(+), 50 deletions(-) diff --git a/src/main.rs b/src/main.rs index 0e5b404..9364409 100644 --- a/src/main.rs +++ b/src/main.rs @@ -90,6 +90,8 @@ struct Client { /// Listen on local and forwards traffic from remote. Can be specified multiple times /// examples: /// 'tcp://1212:google.com:443' => listen locally on tcp on port 1212 and forward to google.com on port 443 + /// 'tcp://2:n.lan:4?proxy_protocol' => listen locally on tcp on port 2 and forward to n.lan on port 4 + /// Send a proxy protocol header v2 when establishing connection to n.lan /// /// 'udp://1212:1.1.1.1:53' => listen locally on udp on port 1212 and forward to cloudflare dns 1.1.1.1 on port 53 /// 'udp://1212:1.1.1.1:53?timeout_sec=10' timeout_sec on udp force close the tunnel after 10sec. Set it to 0 to disable the timeout [default: 30] @@ -258,7 +260,7 @@ struct Server { #[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] enum LocalProtocol { - Tcp, + Tcp { proxy_protocol: bool }, Udp { timeout: Option }, Stdio, Socks5 { timeout: Option }, @@ -367,9 +369,10 @@ fn parse_tunnel_arg(arg: &str) -> Result { match &arg[..6] { "tcp://" => { let (local_bind, remaining) = parse_local_bind(&arg[6..])?; - let (dest_host, dest_port, _options) = parse_tunnel_dest(remaining)?; + let (dest_host, dest_port, options) = parse_tunnel_dest(remaining)?; + let proxy_protocol = options.contains_key("proxy_protocol"); Ok(LocalToRemote { - local_protocol: LocalProtocol::Tcp, + local_protocol: LocalProtocol::Tcp { proxy_protocol }, local: local_bind, remote: (dest_host, dest_port), }) @@ -701,7 +704,7 @@ async fn main() { for tunnel in args.remote_to_local.into_iter() { let client_config = client_config.clone(); match &tunnel.local_protocol { - LocalProtocol::Tcp => { + LocalProtocol::Tcp { proxy_protocol: _ } => { tokio::spawn(async move { let remote = tunnel.remote.clone(); let cfg = client_config.clone(); @@ -775,7 +778,7 @@ async fn main() { }; match remote.protocol { - LocalProtocol::Tcp => { + LocalProtocol::Tcp { proxy_protocol: _ } => { tcp::connect(&remote.host, remote.port, so_mark, timeout, dns_resolver) .await .map(|s| Box::new(s) as Box) @@ -805,7 +808,8 @@ async fn main() { let client_config = client_config.clone(); match &tunnel.local_protocol { - LocalProtocol::Tcp => { + LocalProtocol::Tcp { proxy_protocol } => { + let proxy_protocol = *proxy_protocol; let remote = tunnel.remote.clone(); let server = tcp::run_server(tunnel.local, false) .await @@ -813,7 +817,7 @@ async fn main() { .map_err(anyhow::Error::new) .map_ok(move |stream| { let remote = RemoteAddr { - protocol: LocalProtocol::Tcp, + protocol: LocalProtocol::Tcp { proxy_protocol }, host: remote.0.clone(), port: remote.1, }; @@ -836,7 +840,7 @@ async fn main() { // In TProxy mode local destination is the final ip:port destination let (host, port) = to_host_port(stream.local_addr().unwrap()); let remote = RemoteAddr { - protocol: LocalProtocol::Tcp, + protocol: LocalProtocol::Tcp { proxy_protocol: false }, host, port, }; @@ -931,7 +935,7 @@ async fn main() { client_config, stream::once(async move { let remote = RemoteAddr { - protocol: LocalProtocol::Tcp, + protocol: LocalProtocol::Tcp { proxy_protocol: false }, host: tunnel.remote.0, port: tunnel.remote.1, }; diff --git a/src/socks5.rs b/src/socks5.rs index b413221..341d840 100644 --- a/src/socks5.rs +++ b/src/socks5.rs @@ -29,7 +29,7 @@ pub enum Socks5Stream { impl Socks5Stream { pub fn local_protocol(&self) -> LocalProtocol { match self { - Socks5Stream::Tcp(_) => LocalProtocol::Tcp, + Socks5Stream::Tcp(_) => LocalProtocol::Tcp { proxy_protocol: false }, Socks5Stream::Udp(s) => LocalProtocol::Udp { timeout: s.watchdog_deadline.as_ref().map(|x| x.period()), }, diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 37ca84d..b9da6ef 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -24,10 +24,10 @@ use uuid::Uuid; #[derive(Debug, Clone, Serialize, Deserialize)] struct JwtTunnelConfig { - pub id: String, - pub p: LocalProtocol, - pub r: String, - pub rp: u16, + pub id: String, // tunnel id + pub p: LocalProtocol, // protocol to use + pub r: String, // remote host + pub rp: u16, // remote port } impl JwtTunnelConfig { @@ -35,14 +35,14 @@ impl JwtTunnelConfig { Self { id: request_id.to_string(), p: match dest.protocol { - LocalProtocol::Tcp => LocalProtocol::Tcp, + LocalProtocol::Tcp { .. } => dest.protocol, LocalProtocol::Udp { .. } => dest.protocol, - LocalProtocol::Stdio => LocalProtocol::Tcp, - LocalProtocol::Socks5 { .. } => LocalProtocol::Tcp, + LocalProtocol::Stdio => LocalProtocol::Tcp { proxy_protocol: false }, + LocalProtocol::Socks5 { .. } => LocalProtocol::Tcp { proxy_protocol: false }, LocalProtocol::ReverseTcp => LocalProtocol::ReverseTcp, LocalProtocol::ReverseUdp { .. } => dest.protocol, LocalProtocol::ReverseSocks5 => LocalProtocol::ReverseSocks5, - LocalProtocol::TProxyTcp => LocalProtocol::Tcp, + LocalProtocol::TProxyTcp => LocalProtocol::Tcp { proxy_protocol: false }, LocalProtocol::TProxyUdp { timeout } => LocalProtocol::Udp { timeout }, }, r: dest.host.to_string(), @@ -75,6 +75,17 @@ pub struct RemoteAddr { pub port: u16, } +impl TryFrom for RemoteAddr { + type Error = anyhow::Error; + fn try_from(jwt: JwtTunnelConfig) -> anyhow::Result { + Ok(Self { + protocol: jwt.p, + host: Host::parse(&jwt.r)?, + port: jwt.rp, + }) + } +} + pub enum TransportStream { Plain(TcpStream), Tls(TlsStream), diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 404c80f..4281644 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -4,6 +4,7 @@ use futures_util::{pin_mut, FutureExt, Stream, StreamExt}; use std::cmp::min; use std::fmt::Debug; use std::future::Future; +use std::net::{IpAddr, SocketAddr}; use std::ops::{Deref, Not}; use std::pin::Pin; use std::sync::Arc; @@ -24,7 +25,7 @@ use parking_lot::Mutex; use crate::socks5::Socks5Stream; use crate::tunnel::tls_reloader::TlsReloader; use crate::udp::UdpStream; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::select; use tokio::sync::{mpsc, oneshot}; @@ -36,43 +37,44 @@ use uuid::Uuid; async fn run_tunnel( server_config: &WsServerConfig, jwt: TokenData, + client_address: SocketAddr, ) -> anyhow::Result<(RemoteAddr, Pin>, Pin>)> { match jwt.claims.p { LocalProtocol::Udp { timeout, .. } => { - let host = Host::parse(&jwt.claims.r)?; + let remote = RemoteAddr::try_from(jwt.claims)?; let cnx = udp::connect( - &host, - jwt.claims.rp, + &remote.host, + remote.port, timeout.unwrap_or(Duration::from_secs(10)), &server_config.dns_resolver, ) .await?; - let remote = RemoteAddr { - protocol: jwt.claims.p, - host, - port: jwt.claims.rp, - }; Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx))) } - LocalProtocol::Tcp => { - let host = Host::parse(&jwt.claims.r)?; - let port = jwt.claims.rp; - let (rx, tx) = tcp::connect( - &host, - port, + LocalProtocol::Tcp { proxy_protocol } => { + let remote = RemoteAddr::try_from(jwt.claims)?; + let mut socket = tcp::connect( + &remote.host, + remote.port, server_config.socket_so_mark, Duration::from_secs(10), &server_config.dns_resolver, ) - .await? - .into_split(); + .await?; - let remote = RemoteAddr { - protocol: jwt.claims.p, - host, - port, - }; + if proxy_protocol { + let header = ppp::v2::Builder::with_addresses( + ppp::v2::Version::Two | ppp::v2::Command::Proxy, + ppp::v2::Protocol::Stream, + (client_address, socket.local_addr().unwrap()), + ) + .build() + .unwrap(); + let _ = socket.write_all(&header).await; + } + + let (rx, tx) = socket.into_split(); Ok((remote, Box::pin(rx), Box::pin(tx))) } LocalProtocol::ReverseTcp => { @@ -194,12 +196,16 @@ where } #[inline] -fn extract_x_forwarded_for(req: &Request) -> Result, Response> { +fn extract_x_forwarded_for(req: &Request) -> Result, Response> { let Some(x_forward_for) = req.headers().get("X-Forwarded-For") else { return Ok(None); }; - Ok(Some(x_forward_for.to_str().unwrap_or_default())) + // X-Forwarded-For: , , + let x_forward_for = x_forward_for.to_str().unwrap_or_default(); + let x_forward_for = x_forward_for.split_once(',').map(|x| x.0).unwrap_or(x_forward_for); + let ip: Option = x_forward_for.parse().ok(); + Ok(ip.map(|ip| (ip, x_forward_for))) } #[inline] @@ -288,7 +294,11 @@ fn validate_destination( Ok(()) } -async fn server_upgrade(server_config: Arc, mut req: Request) -> Response { +async fn server_upgrade( + server_config: Arc, + mut client_addr: SocketAddr, + mut req: Request, +) -> Response { if !fastwebsockets::upgrade::is_upgrade_request(&req) { warn!("Rejecting connection with bad upgrade request: {}", req.uri()); return http::Response::builder() @@ -298,13 +308,14 @@ async fn server_upgrade(server_config: Arc, mut req: Request { + Ok(Some((x_forward_for, x_forward_for_str))) => { info!("Request X-Forwarded-For: {:?}", x_forward_for); - Span::current().record("forwarded_for", x_forward_for); + Span::current().record("forwarded_for", x_forward_for_str); + client_addr.set_ip(x_forward_for); } Ok(_) => {} Err(err) => return err, - } + }; if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) { return err; @@ -323,7 +334,7 @@ async fn server_upgrade(server_config: Arc, mut req: Request ret, Err(err) => { warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); @@ -406,8 +417,12 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() info!("Starting wstunnel server listening on {}", server_config.bind); // setup upgrade request handler - let config = server_config.clone(); - let upgrade_fn = move |req: Request| server_upgrade(config.clone(), req).map::, _>(Ok); + // FIXME: Avoid double clone of the arc for each request + let mk_upgrade_fn = |server_config: Arc, client_addr: SocketAddr| { + move |req: Request| { + server_upgrade(server_config.clone(), client_addr, req).map::, _>(Ok) + } + }; // Init TLS if needed let mut tls_context = if let Some(tls_config) = &server_config.tls { @@ -443,7 +458,7 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() ); info!("Accepting connection"); - let upgrade_fn = upgrade_fn.clone(); + let upgrade_fn = mk_upgrade_fn(server_config.clone(), peer_addr); // TLS if let Some(tls) = tls_context.as_mut() { // Reload TLS certificate if needed