Support proxy protocol for tcp connection
This commit is contained in:
parent
79a50b654e
commit
dc4eadb8f9
4 changed files with 80 additions and 50 deletions
|
@ -4,6 +4,7 @@ use futures_util::{pin_mut, FutureExt, Stream, StreamExt};
|
|||
use std::cmp::min;
|
||||
use std::fmt::Debug;
|
||||
use std::future::Future;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::ops::{Deref, Not};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
@ -24,7 +25,7 @@ use parking_lot::Mutex;
|
|||
use crate::socks5::Socks5Stream;
|
||||
use crate::tunnel::tls_reloader::TlsReloader;
|
||||
use crate::udp::UdpStream;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::select;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
@ -36,43 +37,44 @@ use uuid::Uuid;
|
|||
async fn run_tunnel(
|
||||
server_config: &WsServerConfig,
|
||||
jwt: TokenData<JwtTunnelConfig>,
|
||||
client_address: SocketAddr,
|
||||
) -> anyhow::Result<(RemoteAddr, Pin<Box<dyn AsyncRead + Send>>, Pin<Box<dyn AsyncWrite + Send>>)> {
|
||||
match jwt.claims.p {
|
||||
LocalProtocol::Udp { timeout, .. } => {
|
||||
let host = Host::parse(&jwt.claims.r)?;
|
||||
let remote = RemoteAddr::try_from(jwt.claims)?;
|
||||
let cnx = udp::connect(
|
||||
&host,
|
||||
jwt.claims.rp,
|
||||
&remote.host,
|
||||
remote.port,
|
||||
timeout.unwrap_or(Duration::from_secs(10)),
|
||||
&server_config.dns_resolver,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let remote = RemoteAddr {
|
||||
protocol: jwt.claims.p,
|
||||
host,
|
||||
port: jwt.claims.rp,
|
||||
};
|
||||
Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx)))
|
||||
}
|
||||
LocalProtocol::Tcp => {
|
||||
let host = Host::parse(&jwt.claims.r)?;
|
||||
let port = jwt.claims.rp;
|
||||
let (rx, tx) = tcp::connect(
|
||||
&host,
|
||||
port,
|
||||
LocalProtocol::Tcp { proxy_protocol } => {
|
||||
let remote = RemoteAddr::try_from(jwt.claims)?;
|
||||
let mut socket = tcp::connect(
|
||||
&remote.host,
|
||||
remote.port,
|
||||
server_config.socket_so_mark,
|
||||
Duration::from_secs(10),
|
||||
&server_config.dns_resolver,
|
||||
)
|
||||
.await?
|
||||
.into_split();
|
||||
.await?;
|
||||
|
||||
let remote = RemoteAddr {
|
||||
protocol: jwt.claims.p,
|
||||
host,
|
||||
port,
|
||||
};
|
||||
if proxy_protocol {
|
||||
let header = ppp::v2::Builder::with_addresses(
|
||||
ppp::v2::Version::Two | ppp::v2::Command::Proxy,
|
||||
ppp::v2::Protocol::Stream,
|
||||
(client_address, socket.local_addr().unwrap()),
|
||||
)
|
||||
.build()
|
||||
.unwrap();
|
||||
let _ = socket.write_all(&header).await;
|
||||
}
|
||||
|
||||
let (rx, tx) = socket.into_split();
|
||||
Ok((remote, Box::pin(rx), Box::pin(tx)))
|
||||
}
|
||||
LocalProtocol::ReverseTcp => {
|
||||
|
@ -194,12 +196,16 @@ where
|
|||
}
|
||||
|
||||
#[inline]
|
||||
fn extract_x_forwarded_for(req: &Request<Incoming>) -> Result<Option<&str>, Response<String>> {
|
||||
fn extract_x_forwarded_for(req: &Request<Incoming>) -> Result<Option<(IpAddr, &str)>, Response<String>> {
|
||||
let Some(x_forward_for) = req.headers().get("X-Forwarded-For") else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
Ok(Some(x_forward_for.to_str().unwrap_or_default()))
|
||||
// X-Forwarded-For: <client>, <proxy1>, <proxy2>
|
||||
let x_forward_for = x_forward_for.to_str().unwrap_or_default();
|
||||
let x_forward_for = x_forward_for.split_once(',').map(|x| x.0).unwrap_or(x_forward_for);
|
||||
let ip: Option<IpAddr> = x_forward_for.parse().ok();
|
||||
Ok(ip.map(|ip| (ip, x_forward_for)))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -288,7 +294,11 @@ fn validate_destination(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn server_upgrade(server_config: Arc<WsServerConfig>, mut req: Request<Incoming>) -> Response<String> {
|
||||
async fn server_upgrade(
|
||||
server_config: Arc<WsServerConfig>,
|
||||
mut client_addr: SocketAddr,
|
||||
mut req: Request<Incoming>,
|
||||
) -> Response<String> {
|
||||
if !fastwebsockets::upgrade::is_upgrade_request(&req) {
|
||||
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
|
||||
return http::Response::builder()
|
||||
|
@ -298,13 +308,14 @@ async fn server_upgrade(server_config: Arc<WsServerConfig>, mut req: Request<Inc
|
|||
}
|
||||
|
||||
match extract_x_forwarded_for(&req) {
|
||||
Ok(Some(x_forward_for)) => {
|
||||
Ok(Some((x_forward_for, x_forward_for_str))) => {
|
||||
info!("Request X-Forwarded-For: {:?}", x_forward_for);
|
||||
Span::current().record("forwarded_for", x_forward_for);
|
||||
Span::current().record("forwarded_for", x_forward_for_str);
|
||||
client_addr.set_ip(x_forward_for);
|
||||
}
|
||||
Ok(_) => {}
|
||||
Err(err) => return err,
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) {
|
||||
return err;
|
||||
|
@ -323,7 +334,7 @@ async fn server_upgrade(server_config: Arc<WsServerConfig>, mut req: Request<Inc
|
|||
}
|
||||
|
||||
let req_protocol = jwt.claims.p;
|
||||
let tunnel = match run_tunnel(&server_config, jwt).await {
|
||||
let tunnel = match run_tunnel(&server_config, jwt, client_addr).await {
|
||||
Ok(ret) => ret,
|
||||
Err(err) => {
|
||||
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
||||
|
@ -406,8 +417,12 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
|
|||
info!("Starting wstunnel server listening on {}", server_config.bind);
|
||||
|
||||
// setup upgrade request handler
|
||||
let config = server_config.clone();
|
||||
let upgrade_fn = move |req: Request<Incoming>| server_upgrade(config.clone(), req).map::<anyhow::Result<_>, _>(Ok);
|
||||
// FIXME: Avoid double clone of the arc for each request
|
||||
let mk_upgrade_fn = |server_config: Arc<WsServerConfig>, client_addr: SocketAddr| {
|
||||
move |req: Request<Incoming>| {
|
||||
server_upgrade(server_config.clone(), client_addr, req).map::<anyhow::Result<_>, _>(Ok)
|
||||
}
|
||||
};
|
||||
|
||||
// Init TLS if needed
|
||||
let mut tls_context = if let Some(tls_config) = &server_config.tls {
|
||||
|
@ -443,7 +458,7 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
|
|||
);
|
||||
|
||||
info!("Accepting connection");
|
||||
let upgrade_fn = upgrade_fn.clone();
|
||||
let upgrade_fn = mk_upgrade_fn(server_config.clone(), peer_addr);
|
||||
// TLS
|
||||
if let Some(tls) = tls_context.as_mut() {
|
||||
// Reload TLS certificate if needed
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue