chore: refacto use dedicated trait
This commit is contained in:
parent
ef016f0467
commit
5e74ed233d
11 changed files with 408 additions and 242 deletions
|
@ -1,11 +1,10 @@
|
|||
use ahash::{HashMap, HashMapExt};
|
||||
use anyhow::anyhow;
|
||||
use bytes::Bytes;
|
||||
use futures_util::{pin_mut, FutureExt, Stream, StreamExt};
|
||||
use futures_util::{pin_mut, FutureExt, StreamExt};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyStream, Either, StreamBody};
|
||||
use std::cmp::min;
|
||||
use std::fmt::Debug;
|
||||
use std::future::Future;
|
||||
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
|
@ -15,7 +14,7 @@ use std::sync::Arc;
|
|||
use std::time::Duration;
|
||||
|
||||
use super::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX};
|
||||
use crate::{protocols, socks5, LocalProtocol, TlsServerConfig, WsServerConfig};
|
||||
use crate::{protocols, LocalProtocol, TlsServerConfig, WsServerConfig};
|
||||
use hyper::body::{Frame, Incoming};
|
||||
use hyper::header::{CONTENT_TYPE, COOKIE, SEC_WEBSOCKET_PROTOCOL};
|
||||
use hyper::http::HeaderValue;
|
||||
|
@ -28,18 +27,21 @@ use once_cell::sync::Lazy;
|
|||
use parking_lot::Mutex;
|
||||
use socket2::SockRef;
|
||||
|
||||
use crate::protocols::udp::UdpStream;
|
||||
use crate::protocols::{http_proxy, tls, udp};
|
||||
use crate::protocols::tls;
|
||||
use crate::protocols::udp::{UdpStream, UdpStreamWriter};
|
||||
use crate::restrictions::config_reloader::RestrictionsRulesReloader;
|
||||
use crate::restrictions::types::{
|
||||
AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol,
|
||||
};
|
||||
use crate::socks5::Socks5Stream;
|
||||
use crate::tunnel::connectors::{TcpTunnelConnector, TunnelConnector, UdpTunnelConnector};
|
||||
use crate::tunnel::listeners::{
|
||||
new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener, TunnelListener,
|
||||
};
|
||||
use crate::tunnel::tls_reloader::TlsReloader;
|
||||
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
|
||||
use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::select;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
@ -56,140 +58,123 @@ async fn run_tunnel(
|
|||
) -> anyhow::Result<(RemoteAddr, Pin<Box<dyn AsyncRead + Send>>, Pin<Box<dyn AsyncWrite + Send>>)> {
|
||||
match remote.protocol {
|
||||
LocalProtocol::Udp { timeout, .. } => {
|
||||
let cnx = udp::connect(
|
||||
let (rx, tx) = UdpTunnelConnector::new(
|
||||
&remote.host,
|
||||
remote.port,
|
||||
timeout.unwrap_or(Duration::from_secs(10)),
|
||||
server_config.socket_so_mark,
|
||||
timeout.unwrap_or(Duration::from_secs(10)),
|
||||
&server_config.dns_resolver,
|
||||
)
|
||||
.connect(&None)
|
||||
.await?;
|
||||
|
||||
Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx)))
|
||||
Ok((remote, Box::pin(rx), Box::pin(tx)))
|
||||
}
|
||||
LocalProtocol::Tcp { proxy_protocol } => {
|
||||
let mut socket = protocols::tcp::connect(
|
||||
let (rx, mut tx) = TcpTunnelConnector::new(
|
||||
&remote.host,
|
||||
remote.port,
|
||||
server_config.socket_so_mark,
|
||||
Duration::from_secs(10),
|
||||
&server_config.dns_resolver,
|
||||
)
|
||||
.connect(&None)
|
||||
.await?;
|
||||
|
||||
if proxy_protocol {
|
||||
let header = ppp::v2::Builder::with_addresses(
|
||||
ppp::v2::Version::Two | ppp::v2::Command::Proxy,
|
||||
ppp::v2::Protocol::Stream,
|
||||
(client_address, socket.local_addr().unwrap()),
|
||||
(client_address, tx.local_addr().unwrap()),
|
||||
)
|
||||
.build()
|
||||
.unwrap();
|
||||
let _ = socket.write_all(&header).await;
|
||||
let _ = tx.write_all(&header).await;
|
||||
}
|
||||
|
||||
let (rx, tx) = socket.into_split();
|
||||
Ok((remote, Box::pin(rx), Box::pin(tx)))
|
||||
}
|
||||
LocalProtocol::ReverseTcp => {
|
||||
type Item = <TcpTunnelListener as TunnelListener>::OkReturn;
|
||||
#[allow(clippy::type_complexity)]
|
||||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<TcpStream>>>> =
|
||||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<Item>>>> =
|
||||
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
||||
|
||||
let remote_port = find_mapped_port(remote.port, restriction);
|
||||
let local_srv = (remote.host, remote_port);
|
||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||
let listening_server = protocols::tcp::run_server(bind.parse()?, false);
|
||||
let tcp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||
let (local_rx, local_tx) = tcp.into_split();
|
||||
|
||||
let remote = RemoteAddr {
|
||||
protocol: remote.protocol,
|
||||
host: local_srv.0,
|
||||
port: local_srv.1,
|
||||
let listening_server = async {
|
||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||
TcpTunnelListener::new(bind.parse()?, local_srv.clone(), false).await
|
||||
};
|
||||
let ((local_rx, local_tx), remote) =
|
||||
run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||
|
||||
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
|
||||
}
|
||||
LocalProtocol::ReverseUdp { timeout } => {
|
||||
type Item = ((UdpStream, UdpStreamWriter), RemoteAddr);
|
||||
#[allow(clippy::type_complexity)]
|
||||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<UdpStream>>>> =
|
||||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<Item>>>> =
|
||||
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
||||
|
||||
let remote_port = find_mapped_port(remote.port, restriction);
|
||||
let local_srv = (remote.host, remote_port);
|
||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||
let listening_server =
|
||||
udp::run_server(bind.parse()?, timeout, |_| Ok(()), |send_socket| Ok(send_socket.clone()));
|
||||
let udp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||
let udp_writer = udp.writer();
|
||||
let (local_rx, local_tx) = (udp, udp_writer);
|
||||
|
||||
let remote = RemoteAddr {
|
||||
protocol: remote.protocol,
|
||||
host: local_srv.0,
|
||||
port: local_srv.1,
|
||||
let listening_server = async {
|
||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||
new_udp_listener(bind.parse()?, local_srv.clone(), timeout).await
|
||||
};
|
||||
let ((local_rx, local_tx), remote) =
|
||||
run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
|
||||
}
|
||||
LocalProtocol::ReverseSocks5 { timeout, credentials } => {
|
||||
type Item = <Socks5TunnelListener as TunnelListener>::OkReturn;
|
||||
#[allow(clippy::type_complexity)]
|
||||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<(Socks5Stream, (Host, u16))>>>> =
|
||||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<Item>>>> =
|
||||
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
||||
|
||||
let remote_port = find_mapped_port(remote.port, restriction);
|
||||
let local_srv = (remote.host, remote_port);
|
||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||
let listening_server = socks5::run_server(bind.parse()?, timeout, credentials);
|
||||
let (stream, local_srv) = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||
let protocol = stream.local_protocol();
|
||||
let (local_rx, local_tx) = tokio::io::split(stream);
|
||||
|
||||
let remote = RemoteAddr {
|
||||
protocol,
|
||||
host: local_srv.0,
|
||||
port: local_srv.1,
|
||||
let listening_server = async {
|
||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||
Socks5TunnelListener::new(bind.parse()?, timeout, credentials).await
|
||||
};
|
||||
let ((local_rx, local_tx), remote) =
|
||||
run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||
|
||||
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
|
||||
}
|
||||
LocalProtocol::ReverseHttpProxy { timeout, credentials } => {
|
||||
type Item = <HttpProxyTunnelListener as TunnelListener>::OkReturn;
|
||||
#[allow(clippy::type_complexity)]
|
||||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<(TcpStream, (Host, u16))>>>> =
|
||||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<Item>>>> =
|
||||
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
||||
|
||||
let remote_port = find_mapped_port(remote.port, restriction);
|
||||
let local_srv = (remote.host, remote_port);
|
||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||
let listening_server = http_proxy::run_server(bind.parse()?, timeout, credentials);
|
||||
let (stream, local_srv) = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||
let (local_rx, local_tx) = tokio::io::split(stream);
|
||||
|
||||
let remote = RemoteAddr {
|
||||
protocol: LocalProtocol::Tcp { proxy_protocol: false },
|
||||
host: local_srv.0,
|
||||
port: local_srv.1,
|
||||
let listening_server = async {
|
||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||
HttpProxyTunnelListener::new(bind.parse()?, timeout, credentials, false).await
|
||||
};
|
||||
let ((local_rx, local_tx), remote) =
|
||||
run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||
|
||||
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
|
||||
}
|
||||
#[cfg(unix)]
|
||||
LocalProtocol::ReverseUnix { ref path } => {
|
||||
use protocols::unix_sock;
|
||||
use tokio::net::UnixStream;
|
||||
|
||||
use crate::tunnel::listeners::UnixTunnelListener;
|
||||
type Item = <UnixTunnelListener as TunnelListener>::OkReturn;
|
||||
#[allow(clippy::type_complexity)]
|
||||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<UnixStream>>>> =
|
||||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<Item>>>> =
|
||||
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
||||
|
||||
let remote_port = find_mapped_port(remote.port, restriction);
|
||||
let local_srv = (remote.host, remote_port);
|
||||
let listening_server = unix_sock::run_server(path);
|
||||
let stream = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||
let (local_rx, local_tx) = stream.into_split();
|
||||
let listening_server = async { UnixTunnelListener::new(path, local_srv.clone(), false).await };
|
||||
let ((local_rx, local_tx), remote) =
|
||||
run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||
|
||||
let remote = RemoteAddr {
|
||||
protocol: remote.protocol,
|
||||
host: local_srv.0,
|
||||
port: local_srv.1,
|
||||
};
|
||||
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
|
@ -232,66 +217,6 @@ fn find_mapped_port(req_port: u16, restriction: &RestrictionConfig) -> u16 {
|
|||
remote_port
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
async fn run_listening_server<T, Fut, FutOut, E>(
|
||||
local_srv: &(Host, u16),
|
||||
servers: &Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<T>>>,
|
||||
gen_listening_server: Fut,
|
||||
) -> anyhow::Result<T>
|
||||
where
|
||||
Fut: Future<Output = anyhow::Result<FutOut>>,
|
||||
FutOut: Stream<Item = Result<T, E>> + Send + 'static,
|
||||
E: Debug + Send,
|
||||
T: Send + 'static,
|
||||
{
|
||||
let listening_server = servers.lock().remove(local_srv);
|
||||
let mut listening_server = if let Some(listening_server) = listening_server {
|
||||
listening_server
|
||||
} else {
|
||||
let listening_server = gen_listening_server.await?;
|
||||
let send_timeout = Duration::from_secs(60 * 3);
|
||||
let (tx, rx) = mpsc::channel::<T>(1);
|
||||
let fut = async move {
|
||||
pin_mut!(listening_server);
|
||||
loop {
|
||||
select! {
|
||||
biased;
|
||||
cnx = listening_server.next() => {
|
||||
match cnx {
|
||||
None => break,
|
||||
Some(Err(err)) => {
|
||||
warn!("Error while listening for incoming connections {err:?}");
|
||||
continue;
|
||||
}
|
||||
Some(Ok(cnx)) => {
|
||||
if tx.send_timeout(cnx, send_timeout).await.is_err() {
|
||||
info!("New reverse connection failed to be picked by client after {}s. Closing reverse tunnel server", send_timeout.as_secs());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
_ = tx.closed() => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("Stopping listening reverse server");
|
||||
};
|
||||
|
||||
tokio::spawn(fut.instrument(Span::current()));
|
||||
rx
|
||||
};
|
||||
|
||||
let cnx = listening_server
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("listening reverse server stopped"))?;
|
||||
servers.lock().insert(local_srv.clone(), listening_server);
|
||||
Ok(cnx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn extract_x_forwarded_for(req: &Request<Incoming>) -> Result<Option<(IpAddr, &str)>, Response<String>> {
|
||||
let Some(x_forward_for) = req.headers().get("X-Forwarded-For") else {
|
||||
|
@ -957,3 +882,65 @@ pub async fn run_server(server_config: Arc<WsServerConfig>, restrictions: Restri
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
async fn run_listening_server<T>(
|
||||
local_srv: &(Host, u16),
|
||||
servers: &Mutex<
|
||||
HashMap<
|
||||
(Host<String>, u16),
|
||||
mpsc::Receiver<((<T as TunnelListener>::Reader, <T as TunnelListener>::Writer), RemoteAddr)>,
|
||||
>,
|
||||
>,
|
||||
gen_listening_server: impl Future<Output = anyhow::Result<T>>,
|
||||
) -> anyhow::Result<((<T as TunnelListener>::Reader, <T as TunnelListener>::Writer), RemoteAddr)>
|
||||
where
|
||||
T: TunnelListener + Send + 'static,
|
||||
{
|
||||
let listening_server = servers.lock().remove(local_srv);
|
||||
let mut listening_server = if let Some(listening_server) = listening_server {
|
||||
listening_server
|
||||
} else {
|
||||
let listening_server = gen_listening_server.await?;
|
||||
let send_timeout = Duration::from_secs(60 * 3);
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
let fut = async move {
|
||||
pin_mut!(listening_server);
|
||||
loop {
|
||||
select! {
|
||||
biased;
|
||||
cnx = listening_server.next() => {
|
||||
match cnx {
|
||||
None => break,
|
||||
Some(Err(err)) => {
|
||||
warn!("Error while listening for incoming connections {err:?}");
|
||||
continue;
|
||||
}
|
||||
Some(Ok(cnx)) => {
|
||||
if tx.send_timeout(cnx, send_timeout).await.is_err() {
|
||||
info!("New reverse connection failed to be picked by client after {}s. Closing reverse tunnel server", send_timeout.as_secs());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
_ = tx.closed() => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("Stopping listening reverse server");
|
||||
};
|
||||
|
||||
tokio::spawn(fut.instrument(Span::current()));
|
||||
rx
|
||||
};
|
||||
|
||||
let cnx = listening_server
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("listening reverse server stopped"))?;
|
||||
servers.lock().insert(local_srv.clone(), listening_server);
|
||||
Ok(cnx)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue