Support proxy protocol for tcp connection
This commit is contained in:
parent
79a50b654e
commit
dc4eadb8f9
4 changed files with 80 additions and 50 deletions
22
src/main.rs
22
src/main.rs
|
@ -90,6 +90,8 @@ struct Client {
|
||||||
/// Listen on local and forwards traffic from remote. Can be specified multiple times
|
/// Listen on local and forwards traffic from remote. Can be specified multiple times
|
||||||
/// examples:
|
/// examples:
|
||||||
/// 'tcp://1212:google.com:443' => listen locally on tcp on port 1212 and forward to google.com on port 443
|
/// 'tcp://1212:google.com:443' => listen locally on tcp on port 1212 and forward to google.com on port 443
|
||||||
|
/// 'tcp://2:n.lan:4?proxy_protocol' => listen locally on tcp on port 2 and forward to n.lan on port 4
|
||||||
|
/// Send a proxy protocol header v2 when establishing connection to n.lan
|
||||||
///
|
///
|
||||||
/// 'udp://1212:1.1.1.1:53' => listen locally on udp on port 1212 and forward to cloudflare dns 1.1.1.1 on port 53
|
/// 'udp://1212:1.1.1.1:53' => listen locally on udp on port 1212 and forward to cloudflare dns 1.1.1.1 on port 53
|
||||||
/// 'udp://1212:1.1.1.1:53?timeout_sec=10' timeout_sec on udp force close the tunnel after 10sec. Set it to 0 to disable the timeout [default: 30]
|
/// 'udp://1212:1.1.1.1:53?timeout_sec=10' timeout_sec on udp force close the tunnel after 10sec. Set it to 0 to disable the timeout [default: 30]
|
||||||
|
@ -258,7 +260,7 @@ struct Server {
|
||||||
|
|
||||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||||
enum LocalProtocol {
|
enum LocalProtocol {
|
||||||
Tcp,
|
Tcp { proxy_protocol: bool },
|
||||||
Udp { timeout: Option<Duration> },
|
Udp { timeout: Option<Duration> },
|
||||||
Stdio,
|
Stdio,
|
||||||
Socks5 { timeout: Option<Duration> },
|
Socks5 { timeout: Option<Duration> },
|
||||||
|
@ -367,9 +369,10 @@ fn parse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
|
||||||
match &arg[..6] {
|
match &arg[..6] {
|
||||||
"tcp://" => {
|
"tcp://" => {
|
||||||
let (local_bind, remaining) = parse_local_bind(&arg[6..])?;
|
let (local_bind, remaining) = parse_local_bind(&arg[6..])?;
|
||||||
let (dest_host, dest_port, _options) = parse_tunnel_dest(remaining)?;
|
let (dest_host, dest_port, options) = parse_tunnel_dest(remaining)?;
|
||||||
|
let proxy_protocol = options.contains_key("proxy_protocol");
|
||||||
Ok(LocalToRemote {
|
Ok(LocalToRemote {
|
||||||
local_protocol: LocalProtocol::Tcp,
|
local_protocol: LocalProtocol::Tcp { proxy_protocol },
|
||||||
local: local_bind,
|
local: local_bind,
|
||||||
remote: (dest_host, dest_port),
|
remote: (dest_host, dest_port),
|
||||||
})
|
})
|
||||||
|
@ -701,7 +704,7 @@ async fn main() {
|
||||||
for tunnel in args.remote_to_local.into_iter() {
|
for tunnel in args.remote_to_local.into_iter() {
|
||||||
let client_config = client_config.clone();
|
let client_config = client_config.clone();
|
||||||
match &tunnel.local_protocol {
|
match &tunnel.local_protocol {
|
||||||
LocalProtocol::Tcp => {
|
LocalProtocol::Tcp { proxy_protocol: _ } => {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let remote = tunnel.remote.clone();
|
let remote = tunnel.remote.clone();
|
||||||
let cfg = client_config.clone();
|
let cfg = client_config.clone();
|
||||||
|
@ -775,7 +778,7 @@ async fn main() {
|
||||||
};
|
};
|
||||||
|
|
||||||
match remote.protocol {
|
match remote.protocol {
|
||||||
LocalProtocol::Tcp => {
|
LocalProtocol::Tcp { proxy_protocol: _ } => {
|
||||||
tcp::connect(&remote.host, remote.port, so_mark, timeout, dns_resolver)
|
tcp::connect(&remote.host, remote.port, so_mark, timeout, dns_resolver)
|
||||||
.await
|
.await
|
||||||
.map(|s| Box::new(s) as Box<dyn T>)
|
.map(|s| Box::new(s) as Box<dyn T>)
|
||||||
|
@ -805,7 +808,8 @@ async fn main() {
|
||||||
let client_config = client_config.clone();
|
let client_config = client_config.clone();
|
||||||
|
|
||||||
match &tunnel.local_protocol {
|
match &tunnel.local_protocol {
|
||||||
LocalProtocol::Tcp => {
|
LocalProtocol::Tcp { proxy_protocol } => {
|
||||||
|
let proxy_protocol = *proxy_protocol;
|
||||||
let remote = tunnel.remote.clone();
|
let remote = tunnel.remote.clone();
|
||||||
let server = tcp::run_server(tunnel.local, false)
|
let server = tcp::run_server(tunnel.local, false)
|
||||||
.await
|
.await
|
||||||
|
@ -813,7 +817,7 @@ async fn main() {
|
||||||
.map_err(anyhow::Error::new)
|
.map_err(anyhow::Error::new)
|
||||||
.map_ok(move |stream| {
|
.map_ok(move |stream| {
|
||||||
let remote = RemoteAddr {
|
let remote = RemoteAddr {
|
||||||
protocol: LocalProtocol::Tcp,
|
protocol: LocalProtocol::Tcp { proxy_protocol },
|
||||||
host: remote.0.clone(),
|
host: remote.0.clone(),
|
||||||
port: remote.1,
|
port: remote.1,
|
||||||
};
|
};
|
||||||
|
@ -836,7 +840,7 @@ async fn main() {
|
||||||
// In TProxy mode local destination is the final ip:port destination
|
// In TProxy mode local destination is the final ip:port destination
|
||||||
let (host, port) = to_host_port(stream.local_addr().unwrap());
|
let (host, port) = to_host_port(stream.local_addr().unwrap());
|
||||||
let remote = RemoteAddr {
|
let remote = RemoteAddr {
|
||||||
protocol: LocalProtocol::Tcp,
|
protocol: LocalProtocol::Tcp { proxy_protocol: false },
|
||||||
host,
|
host,
|
||||||
port,
|
port,
|
||||||
};
|
};
|
||||||
|
@ -931,7 +935,7 @@ async fn main() {
|
||||||
client_config,
|
client_config,
|
||||||
stream::once(async move {
|
stream::once(async move {
|
||||||
let remote = RemoteAddr {
|
let remote = RemoteAddr {
|
||||||
protocol: LocalProtocol::Tcp,
|
protocol: LocalProtocol::Tcp { proxy_protocol: false },
|
||||||
host: tunnel.remote.0,
|
host: tunnel.remote.0,
|
||||||
port: tunnel.remote.1,
|
port: tunnel.remote.1,
|
||||||
};
|
};
|
||||||
|
|
|
@ -29,7 +29,7 @@ pub enum Socks5Stream {
|
||||||
impl Socks5Stream {
|
impl Socks5Stream {
|
||||||
pub fn local_protocol(&self) -> LocalProtocol {
|
pub fn local_protocol(&self) -> LocalProtocol {
|
||||||
match self {
|
match self {
|
||||||
Socks5Stream::Tcp(_) => LocalProtocol::Tcp,
|
Socks5Stream::Tcp(_) => LocalProtocol::Tcp { proxy_protocol: false },
|
||||||
Socks5Stream::Udp(s) => LocalProtocol::Udp {
|
Socks5Stream::Udp(s) => LocalProtocol::Udp {
|
||||||
timeout: s.watchdog_deadline.as_ref().map(|x| x.period()),
|
timeout: s.watchdog_deadline.as_ref().map(|x| x.period()),
|
||||||
},
|
},
|
||||||
|
|
|
@ -24,10 +24,10 @@ use uuid::Uuid;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
struct JwtTunnelConfig {
|
struct JwtTunnelConfig {
|
||||||
pub id: String,
|
pub id: String, // tunnel id
|
||||||
pub p: LocalProtocol,
|
pub p: LocalProtocol, // protocol to use
|
||||||
pub r: String,
|
pub r: String, // remote host
|
||||||
pub rp: u16,
|
pub rp: u16, // remote port
|
||||||
}
|
}
|
||||||
|
|
||||||
impl JwtTunnelConfig {
|
impl JwtTunnelConfig {
|
||||||
|
@ -35,14 +35,14 @@ impl JwtTunnelConfig {
|
||||||
Self {
|
Self {
|
||||||
id: request_id.to_string(),
|
id: request_id.to_string(),
|
||||||
p: match dest.protocol {
|
p: match dest.protocol {
|
||||||
LocalProtocol::Tcp => LocalProtocol::Tcp,
|
LocalProtocol::Tcp { .. } => dest.protocol,
|
||||||
LocalProtocol::Udp { .. } => dest.protocol,
|
LocalProtocol::Udp { .. } => dest.protocol,
|
||||||
LocalProtocol::Stdio => LocalProtocol::Tcp,
|
LocalProtocol::Stdio => LocalProtocol::Tcp { proxy_protocol: false },
|
||||||
LocalProtocol::Socks5 { .. } => LocalProtocol::Tcp,
|
LocalProtocol::Socks5 { .. } => LocalProtocol::Tcp { proxy_protocol: false },
|
||||||
LocalProtocol::ReverseTcp => LocalProtocol::ReverseTcp,
|
LocalProtocol::ReverseTcp => LocalProtocol::ReverseTcp,
|
||||||
LocalProtocol::ReverseUdp { .. } => dest.protocol,
|
LocalProtocol::ReverseUdp { .. } => dest.protocol,
|
||||||
LocalProtocol::ReverseSocks5 => LocalProtocol::ReverseSocks5,
|
LocalProtocol::ReverseSocks5 => LocalProtocol::ReverseSocks5,
|
||||||
LocalProtocol::TProxyTcp => LocalProtocol::Tcp,
|
LocalProtocol::TProxyTcp => LocalProtocol::Tcp { proxy_protocol: false },
|
||||||
LocalProtocol::TProxyUdp { timeout } => LocalProtocol::Udp { timeout },
|
LocalProtocol::TProxyUdp { timeout } => LocalProtocol::Udp { timeout },
|
||||||
},
|
},
|
||||||
r: dest.host.to_string(),
|
r: dest.host.to_string(),
|
||||||
|
@ -75,6 +75,17 @@ pub struct RemoteAddr {
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl TryFrom<JwtTunnelConfig> for RemoteAddr {
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
fn try_from(jwt: JwtTunnelConfig) -> anyhow::Result<Self> {
|
||||||
|
Ok(Self {
|
||||||
|
protocol: jwt.p,
|
||||||
|
host: Host::parse(&jwt.r)?,
|
||||||
|
port: jwt.rp,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub enum TransportStream {
|
pub enum TransportStream {
|
||||||
Plain(TcpStream),
|
Plain(TcpStream),
|
||||||
Tls(TlsStream<TcpStream>),
|
Tls(TlsStream<TcpStream>),
|
||||||
|
|
|
@ -4,6 +4,7 @@ use futures_util::{pin_mut, FutureExt, Stream, StreamExt};
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::ops::{Deref, Not};
|
use std::ops::{Deref, Not};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
@ -24,7 +25,7 @@ use parking_lot::Mutex;
|
||||||
use crate::socks5::Socks5Stream;
|
use crate::socks5::Socks5Stream;
|
||||||
use crate::tunnel::tls_reloader::TlsReloader;
|
use crate::tunnel::tls_reloader::TlsReloader;
|
||||||
use crate::udp::UdpStream;
|
use crate::udp::UdpStream;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio::select;
|
use tokio::select;
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
|
@ -36,43 +37,44 @@ use uuid::Uuid;
|
||||||
async fn run_tunnel(
|
async fn run_tunnel(
|
||||||
server_config: &WsServerConfig,
|
server_config: &WsServerConfig,
|
||||||
jwt: TokenData<JwtTunnelConfig>,
|
jwt: TokenData<JwtTunnelConfig>,
|
||||||
|
client_address: SocketAddr,
|
||||||
) -> anyhow::Result<(RemoteAddr, Pin<Box<dyn AsyncRead + Send>>, Pin<Box<dyn AsyncWrite + Send>>)> {
|
) -> anyhow::Result<(RemoteAddr, Pin<Box<dyn AsyncRead + Send>>, Pin<Box<dyn AsyncWrite + Send>>)> {
|
||||||
match jwt.claims.p {
|
match jwt.claims.p {
|
||||||
LocalProtocol::Udp { timeout, .. } => {
|
LocalProtocol::Udp { timeout, .. } => {
|
||||||
let host = Host::parse(&jwt.claims.r)?;
|
let remote = RemoteAddr::try_from(jwt.claims)?;
|
||||||
let cnx = udp::connect(
|
let cnx = udp::connect(
|
||||||
&host,
|
&remote.host,
|
||||||
jwt.claims.rp,
|
remote.port,
|
||||||
timeout.unwrap_or(Duration::from_secs(10)),
|
timeout.unwrap_or(Duration::from_secs(10)),
|
||||||
&server_config.dns_resolver,
|
&server_config.dns_resolver,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let remote = RemoteAddr {
|
|
||||||
protocol: jwt.claims.p,
|
|
||||||
host,
|
|
||||||
port: jwt.claims.rp,
|
|
||||||
};
|
|
||||||
Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx)))
|
Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx)))
|
||||||
}
|
}
|
||||||
LocalProtocol::Tcp => {
|
LocalProtocol::Tcp { proxy_protocol } => {
|
||||||
let host = Host::parse(&jwt.claims.r)?;
|
let remote = RemoteAddr::try_from(jwt.claims)?;
|
||||||
let port = jwt.claims.rp;
|
let mut socket = tcp::connect(
|
||||||
let (rx, tx) = tcp::connect(
|
&remote.host,
|
||||||
&host,
|
remote.port,
|
||||||
port,
|
|
||||||
server_config.socket_so_mark,
|
server_config.socket_so_mark,
|
||||||
Duration::from_secs(10),
|
Duration::from_secs(10),
|
||||||
&server_config.dns_resolver,
|
&server_config.dns_resolver,
|
||||||
)
|
)
|
||||||
.await?
|
.await?;
|
||||||
.into_split();
|
|
||||||
|
|
||||||
let remote = RemoteAddr {
|
if proxy_protocol {
|
||||||
protocol: jwt.claims.p,
|
let header = ppp::v2::Builder::with_addresses(
|
||||||
host,
|
ppp::v2::Version::Two | ppp::v2::Command::Proxy,
|
||||||
port,
|
ppp::v2::Protocol::Stream,
|
||||||
};
|
(client_address, socket.local_addr().unwrap()),
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
let _ = socket.write_all(&header).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let (rx, tx) = socket.into_split();
|
||||||
Ok((remote, Box::pin(rx), Box::pin(tx)))
|
Ok((remote, Box::pin(rx), Box::pin(tx)))
|
||||||
}
|
}
|
||||||
LocalProtocol::ReverseTcp => {
|
LocalProtocol::ReverseTcp => {
|
||||||
|
@ -194,12 +196,16 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn extract_x_forwarded_for(req: &Request<Incoming>) -> Result<Option<&str>, Response<String>> {
|
fn extract_x_forwarded_for(req: &Request<Incoming>) -> Result<Option<(IpAddr, &str)>, Response<String>> {
|
||||||
let Some(x_forward_for) = req.headers().get("X-Forwarded-For") else {
|
let Some(x_forward_for) = req.headers().get("X-Forwarded-For") else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Some(x_forward_for.to_str().unwrap_or_default()))
|
// X-Forwarded-For: <client>, <proxy1>, <proxy2>
|
||||||
|
let x_forward_for = x_forward_for.to_str().unwrap_or_default();
|
||||||
|
let x_forward_for = x_forward_for.split_once(',').map(|x| x.0).unwrap_or(x_forward_for);
|
||||||
|
let ip: Option<IpAddr> = x_forward_for.parse().ok();
|
||||||
|
Ok(ip.map(|ip| (ip, x_forward_for)))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
|
@ -288,7 +294,11 @@ fn validate_destination(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn server_upgrade(server_config: Arc<WsServerConfig>, mut req: Request<Incoming>) -> Response<String> {
|
async fn server_upgrade(
|
||||||
|
server_config: Arc<WsServerConfig>,
|
||||||
|
mut client_addr: SocketAddr,
|
||||||
|
mut req: Request<Incoming>,
|
||||||
|
) -> Response<String> {
|
||||||
if !fastwebsockets::upgrade::is_upgrade_request(&req) {
|
if !fastwebsockets::upgrade::is_upgrade_request(&req) {
|
||||||
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
|
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
|
||||||
return http::Response::builder()
|
return http::Response::builder()
|
||||||
|
@ -298,13 +308,14 @@ async fn server_upgrade(server_config: Arc<WsServerConfig>, mut req: Request<Inc
|
||||||
}
|
}
|
||||||
|
|
||||||
match extract_x_forwarded_for(&req) {
|
match extract_x_forwarded_for(&req) {
|
||||||
Ok(Some(x_forward_for)) => {
|
Ok(Some((x_forward_for, x_forward_for_str))) => {
|
||||||
info!("Request X-Forwarded-For: {:?}", x_forward_for);
|
info!("Request X-Forwarded-For: {:?}", x_forward_for);
|
||||||
Span::current().record("forwarded_for", x_forward_for);
|
Span::current().record("forwarded_for", x_forward_for_str);
|
||||||
|
client_addr.set_ip(x_forward_for);
|
||||||
}
|
}
|
||||||
Ok(_) => {}
|
Ok(_) => {}
|
||||||
Err(err) => return err,
|
Err(err) => return err,
|
||||||
}
|
};
|
||||||
|
|
||||||
if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) {
|
if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) {
|
||||||
return err;
|
return err;
|
||||||
|
@ -323,7 +334,7 @@ async fn server_upgrade(server_config: Arc<WsServerConfig>, mut req: Request<Inc
|
||||||
}
|
}
|
||||||
|
|
||||||
let req_protocol = jwt.claims.p;
|
let req_protocol = jwt.claims.p;
|
||||||
let tunnel = match run_tunnel(&server_config, jwt).await {
|
let tunnel = match run_tunnel(&server_config, jwt, client_addr).await {
|
||||||
Ok(ret) => ret,
|
Ok(ret) => ret,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
||||||
|
@ -406,8 +417,12 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
|
||||||
info!("Starting wstunnel server listening on {}", server_config.bind);
|
info!("Starting wstunnel server listening on {}", server_config.bind);
|
||||||
|
|
||||||
// setup upgrade request handler
|
// setup upgrade request handler
|
||||||
let config = server_config.clone();
|
// FIXME: Avoid double clone of the arc for each request
|
||||||
let upgrade_fn = move |req: Request<Incoming>| server_upgrade(config.clone(), req).map::<anyhow::Result<_>, _>(Ok);
|
let mk_upgrade_fn = |server_config: Arc<WsServerConfig>, client_addr: SocketAddr| {
|
||||||
|
move |req: Request<Incoming>| {
|
||||||
|
server_upgrade(server_config.clone(), client_addr, req).map::<anyhow::Result<_>, _>(Ok)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Init TLS if needed
|
// Init TLS if needed
|
||||||
let mut tls_context = if let Some(tls_config) = &server_config.tls {
|
let mut tls_context = if let Some(tls_config) = &server_config.tls {
|
||||||
|
@ -443,7 +458,7 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
|
||||||
);
|
);
|
||||||
|
|
||||||
info!("Accepting connection");
|
info!("Accepting connection");
|
||||||
let upgrade_fn = upgrade_fn.clone();
|
let upgrade_fn = mk_upgrade_fn(server_config.clone(), peer_addr);
|
||||||
// TLS
|
// TLS
|
||||||
if let Some(tls) = tls_context.as_mut() {
|
if let Some(tls) = tls_context.as_mut() {
|
||||||
// Reload TLS certificate if needed
|
// Reload TLS certificate if needed
|
||||||
|
|
Loading…
Reference in a new issue