diff --git a/src/main.rs b/src/main.rs index 47aae04..54f32d1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,6 @@ mod protocols; mod restrictions; mod tunnel; -use anyhow::anyhow; use base64::Engine; use clap::Parser; use hyper::header::HOST; @@ -21,8 +20,6 @@ use std::str::FromStr; use std::sync::Arc; use std::time::Duration; use std::{fmt, io}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::TcpStream; use tokio::select; use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer, ServerName}; @@ -31,9 +28,9 @@ use tokio_rustls::TlsConnector; use tracing::{error, info}; use crate::protocols::dns::DnsResolver; -use crate::protocols::udp::MyUdpSocket; -use crate::protocols::{socks5, tls, udp}; +use crate::protocols::tls; use crate::restrictions::types::RestrictionsRules; +use crate::tunnel::connectors::{Socks5TunnelConnector, TcpTunnelConnector, UdpTunnelConnector}; use crate::tunnel::listeners::{ new_stdio_listener, new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener, }; @@ -989,19 +986,14 @@ async fn main() -> anyhow::Result<()> { match &tunnel.local_protocol { LocalProtocol::Tcp { proxy_protocol: _ } => { tokio::spawn(async move { - let remote = tunnel.remote.clone(); let cfg = client_config.clone(); - let connect_to_dest = |_| async { - protocols::tcp::connect( - &remote.0, - remote.1, - cfg.socket_so_mark, - cfg.timeout_connect, - &cfg.dns_resolver, - ) - .await - }; - + let tcp_connector = TcpTunnelConnector::new( + &tunnel.remote.0, + tunnel.remote.1, + cfg.socket_so_mark, + cfg.timeout_connect, + &cfg.dns_resolver, + ); let (host, port) = to_host_port(tunnel.local); let remote = RemoteAddr { protocol: LocalProtocol::ReverseTcp, @@ -1009,7 +1001,7 @@ async fn main() -> anyhow::Result<()> { port, }; if let Err(err) = - tunnel::client::run_reverse_tunnel(client_config, remote, connect_to_dest).await + tunnel::client::run_reverse_tunnel(client_config, remote, tcp_connector).await { error!("{:?}", err); } @@ -1026,29 +1018,22 @@ async fn main() -> anyhow::Result<()> { host, port, }; - let connect_to_dest = |_| async { - udp::connect( - &tunnel.remote.0, - tunnel.remote.1, - cfg.timeout_connect, - cfg.socket_so_mark, - &cfg.dns_resolver, - ) - .await - }; + let udp_connector = UdpTunnelConnector::new( + &remote.host, + remote.port, + cfg.socket_so_mark, + cfg.timeout_connect, + &cfg.dns_resolver, + ); if let Err(err) = - tunnel::client::run_reverse_tunnel(client_config, remote, connect_to_dest).await + tunnel::client::run_reverse_tunnel(client_config, remote.clone(), udp_connector).await { error!("{:?}", err); } }); } LocalProtocol::Socks5 { timeout, credentials } => { - trait T: AsyncWrite + AsyncRead + Unpin + Send {} - impl T for TcpStream {} - impl T for MyUdpSocket {} - let credentials = credentials.clone(); let timeout = *timeout; tokio::spawn(async move { @@ -1059,37 +1044,11 @@ async fn main() -> anyhow::Result<()> { host, port, }; - let connect_to_dest = |remote: Option| { - let so_mark = cfg.socket_so_mark; - let timeout = cfg.timeout_connect; - let dns_resolver = &cfg.dns_resolver; - async move { - let Some(remote) = remote else { - return Err(anyhow!("Missing remote destination for reverse socks5")); - }; - - match remote.protocol { - LocalProtocol::Tcp { proxy_protocol: _ } => protocols::tcp::connect( - &remote.host, - remote.port, - so_mark, - timeout, - dns_resolver, - ) - .await - .map(|s| Box::new(s) as Box), - LocalProtocol::Udp { .. } => { - udp::connect(&remote.host, remote.port, timeout, so_mark, dns_resolver) - .await - .map(|s| Box::new(s) as Box) - } - _ => Err(anyhow!("Invalid protocol for reverse socks5 {:?}", remote.protocol)), - } - } - }; + let socks_connector = + Socks5TunnelConnector::new(cfg.socket_so_mark, cfg.timeout_connect, &cfg.dns_resolver); if let Err(err) = - tunnel::client::run_reverse_tunnel(client_config, remote, connect_to_dest).await + tunnel::client::run_reverse_tunnel(client_config, remote, socks_connector).await { error!("{:?}", err); } @@ -1108,22 +1067,16 @@ async fn main() -> anyhow::Result<()> { host, port, }; - let connect_to_dest = |remote: Option| { - let so_mark = cfg.socket_so_mark; - let timeout = cfg.timeout_connect; - let dns_resolver = &cfg.dns_resolver; - async move { - let Some(remote) = remote else { - return Err(anyhow!("Missing remote destination for reverse socks5")); - }; - - protocols::tcp::connect(&remote.host, remote.port, so_mark, timeout, dns_resolver) - .await - } - }; + let tcp_connector = TcpTunnelConnector::new( + &remote.host, + remote.port, + cfg.socket_so_mark, + cfg.timeout_connect, + &cfg.dns_resolver, + ); if let Err(err) = - tunnel::client::run_reverse_tunnel(client_config, remote, connect_to_dest).await + tunnel::client::run_reverse_tunnel(client_config, remote.clone(), tcp_connector).await { error!("{:?}", err); } @@ -1133,18 +1086,14 @@ async fn main() -> anyhow::Result<()> { LocalProtocol::Unix { path } => { let path = path.clone(); tokio::spawn(async move { - let remote = tunnel.remote.clone(); let cfg = client_config.clone(); - let connect_to_dest = |_| async { - protocols::tcp::connect( - &remote.0, - remote.1, - cfg.socket_so_mark, - cfg.timeout_connect, - &cfg.dns_resolver, - ) - .await - }; + let tcp_connector = TcpTunnelConnector::new( + &tunnel.remote.0, + tunnel.remote.1, + cfg.socket_so_mark, + cfg.timeout_connect, + &cfg.dns_resolver, + ); let (host, port) = to_host_port(tunnel.local); let remote = RemoteAddr { @@ -1153,7 +1102,7 @@ async fn main() -> anyhow::Result<()> { port, }; if let Err(err) = - tunnel::client::run_reverse_tunnel(client_config, remote, connect_to_dest).await + tunnel::client::run_reverse_tunnel(client_config, remote, tcp_connector).await { error!("{:?}", err); } diff --git a/src/protocols/udp/mod.rs b/src/protocols/udp/mod.rs index e758958..b424dff 100644 --- a/src/protocols/udp/mod.rs +++ b/src/protocols/udp/mod.rs @@ -6,6 +6,6 @@ pub use server::connect; #[cfg(target_os = "linux")] pub use server::mk_send_socket_tproxy; pub use server::run_server; -pub use server::MyUdpSocket; pub use server::UdpStream; pub use server::UdpStreamWriter; +pub use server::WsUdpSocket; diff --git a/src/protocols/udp/server.rs b/src/protocols/udp/server.rs index a995f49..8d0deca 100644 --- a/src/protocols/udp/server.rs +++ b/src/protocols/udp/server.rs @@ -295,17 +295,17 @@ pub async fn run_server( } #[derive(Clone)] -pub struct MyUdpSocket { +pub struct WsUdpSocket { socket: Arc, } -impl MyUdpSocket { +impl WsUdpSocket { pub fn new(socket: Arc) -> Self { Self { socket } } } -impl AsyncRead for MyUdpSocket { +impl AsyncRead for WsUdpSocket { fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { unsafe { self.map_unchecked_mut(|x| &mut x.socket) } .poll_recv_from(cx, buf) @@ -313,7 +313,7 @@ impl AsyncRead for MyUdpSocket { } } -impl AsyncWrite for MyUdpSocket { +impl AsyncWrite for WsUdpSocket { fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll> { unsafe { self.map_unchecked_mut(|x| &mut x.socket) }.poll_send(cx, buf) } @@ -333,7 +333,7 @@ pub async fn connect( connect_timeout: Duration, so_mark: Option, dns_resolver: &DnsResolver, -) -> anyhow::Result { +) -> anyhow::Result { info!("Opening UDP connection to {}:{}", host, port); let socket_addrs: Vec = match host { @@ -419,7 +419,7 @@ pub async fn connect( } if let Some(cnx) = cnx { - Ok(MyUdpSocket::new(Arc::new(cnx))) + Ok(WsUdpSocket::new(Arc::new(cnx))) } else { Err(anyhow!("Cannot connect to udp peer {}:{} reason {:?}", host, port, last_err)) } diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index 2759d1d..b2de8c4 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -1,4 +1,5 @@ use super::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE}; +use crate::tunnel::connectors::TunnelConnector; use crate::tunnel::listeners::TunnelListener; use crate::tunnel::transport::{TunnelReader, TunnelWriter}; use crate::{tunnel, WsClientConfig}; @@ -6,7 +7,6 @@ use futures_util::pin_mut; use hyper::header::COOKIE; use jsonwebtoken::TokenData; use log::debug; -use std::future::Future; use std::ops::Deref; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; @@ -90,16 +90,11 @@ pub async fn run_tunnel(client_config: Arc, incoming_cnx: impl T Ok(()) } -pub async fn run_reverse_tunnel( +pub async fn run_reverse_tunnel( client_cfg: Arc, remote_addr: RemoteAddr, - connect_to_dest: F, -) -> anyhow::Result<()> -where - F: Fn(Option) -> Fut, - Fut: Future>, - T: AsyncRead + AsyncWrite + Send + 'static, -{ + connector: impl TunnelConnector, +) -> anyhow::Result<()> { loop { let client_config = client_cfg.clone(); let request_id = Uuid::now_v7(); @@ -156,17 +151,15 @@ where port: jwt.claims.rp, }); - let stream = match connect_to_dest(remote).instrument(span.clone()).await { + let (local_rx, local_tx) = match connector.connect(&remote).instrument(span.clone()).await { Ok(s) => s, Err(err) => { - event!(parent: &span, Level::ERROR, "Cannot connect to xxxx: {err:?}"); + event!(parent: &span, Level::ERROR, "Cannot connect to {remote:?}: {err:?}"); continue; } }; - let (local_rx, local_tx) = tokio::io::split(stream); let (close_tx, close_rx) = oneshot::channel::<()>(); - let tunnel = async move { let ping_frequency = client_config.websocket_ping_frequency; tokio::spawn( diff --git a/src/tunnel/connectors/mod.rs b/src/tunnel/connectors/mod.rs new file mode 100644 index 0000000..f129f41 --- /dev/null +++ b/src/tunnel/connectors/mod.rs @@ -0,0 +1,17 @@ +mod sock5; +mod tcp; +mod udp; + +pub use sock5::Socks5TunnelConnector; +pub use tcp::TcpTunnelConnector; +pub use udp::UdpTunnelConnector; + +use crate::tunnel::RemoteAddr; +use tokio::io::{AsyncRead, AsyncWrite}; + +pub trait TunnelConnector { + type Reader: AsyncRead + Send + 'static; + type Writer: AsyncWrite + Send + 'static; + + async fn connect(&self, remote: &Option) -> anyhow::Result<(Self::Reader, Self::Writer)>; +} diff --git a/src/tunnel/connectors/sock5.rs b/src/tunnel/connectors/sock5.rs new file mode 100644 index 0000000..694f0a4 --- /dev/null +++ b/src/tunnel/connectors/sock5.rs @@ -0,0 +1,124 @@ +use std::io::{Error, IoSlice}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +use anyhow::anyhow; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; + +use crate::protocols::dns::DnsResolver; +use crate::protocols::udp; +use crate::protocols::udp::WsUdpSocket; +use crate::tunnel::connectors::TunnelConnector; +use crate::tunnel::RemoteAddr; +use crate::{protocols, LocalProtocol}; + +pub struct Socks5TunnelConnector<'a> { + so_mark: Option, + connect_timeout: Duration, + dns_resolver: &'a DnsResolver, +} + +impl Socks5TunnelConnector<'_> { + pub fn new(so_mark: Option, connect_timeout: Duration, dns_resolver: &DnsResolver) -> Socks5TunnelConnector { + Socks5TunnelConnector { + so_mark, + connect_timeout, + dns_resolver, + } + } +} + +impl TunnelConnector for Socks5TunnelConnector<'_> { + type Reader = Socks5Reader; + type Writer = Socks5Writer; + + async fn connect(&self, remote: &Option) -> anyhow::Result<(Self::Reader, Self::Writer)> { + let Some(remote) = remote else { + return Err(anyhow!("Missing remote destination for reverse socks5")); + }; + + match remote.protocol { + LocalProtocol::Tcp { proxy_protocol: _ } => { + let stream = protocols::tcp::connect( + &remote.host, + remote.port, + self.so_mark, + self.connect_timeout, + self.dns_resolver, + ) + .await?; + let (reader, writer) = stream.into_split(); + Ok((Socks5Reader::Tcp(reader), Socks5Writer::Tcp(writer))) + } + LocalProtocol::Udp { .. } => { + let stream = + udp::connect(&remote.host, remote.port, self.connect_timeout, self.so_mark, self.dns_resolver) + .await?; + Ok((Socks5Reader::Udp(stream.clone()), Socks5Writer::Udp(stream))) + } + _ => Err(anyhow!("Invalid protocol for reverse socks5 {:?}", remote.protocol)), + } + } +} + +pub enum Socks5Reader { + Tcp(OwnedReadHalf), + Udp(WsUdpSocket), +} + +impl AsyncRead for Socks5Reader { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + match self.get_mut() { + Socks5Reader::Tcp(reader) => Pin::new(reader).poll_read(cx, buf), + Socks5Reader::Udp(reader) => Pin::new(reader).poll_read(cx, buf), + } + } +} + +pub enum Socks5Writer { + Tcp(OwnedWriteHalf), + Udp(WsUdpSocket), +} + +impl AsyncWrite for Socks5Writer { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + match self.get_mut() { + Socks5Writer::Tcp(writer) => Pin::new(writer).poll_write(cx, buf), + Socks5Writer::Udp(wrtier) => Pin::new(wrtier).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Socks5Writer::Tcp(writer) => Pin::new(writer).poll_flush(cx), + Socks5Writer::Udp(wrtier) => Pin::new(wrtier).poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Socks5Writer::Tcp(writer) => Pin::new(writer).poll_shutdown(cx), + Socks5Writer::Udp(wrtier) => Pin::new(wrtier).poll_shutdown(cx), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + Socks5Writer::Tcp(writer) => Pin::new(writer).poll_write_vectored(cx, bufs), + Socks5Writer::Udp(wrtier) => Pin::new(wrtier).poll_write_vectored(cx, bufs), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + Socks5Writer::Tcp(v) => v.is_write_vectored(), + Socks5Writer::Udp(v) => v.is_write_vectored(), + } + } +} diff --git a/src/tunnel/connectors/tcp.rs b/src/tunnel/connectors/tcp.rs new file mode 100644 index 0000000..96a07aa --- /dev/null +++ b/src/tunnel/connectors/tcp.rs @@ -0,0 +1,48 @@ +use crate::protocols; +use crate::protocols::dns::DnsResolver; +use crate::tunnel::connectors::TunnelConnector; +use crate::tunnel::RemoteAddr; +use std::time::Duration; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use url::Host; + +pub struct TcpTunnelConnector<'a> { + host: &'a Host, + port: u16, + so_mark: Option, + connect_timeout: Duration, + dns_resolver: &'a DnsResolver, +} + +impl<'a> TcpTunnelConnector<'a> { + pub fn new( + host: &'a Host, + port: u16, + so_mark: Option, + connect_timeout: Duration, + dns_resolver: &'a DnsResolver, + ) -> TcpTunnelConnector<'a> { + TcpTunnelConnector { + host, + port, + so_mark, + connect_timeout, + dns_resolver, + } + } +} + +impl TunnelConnector for TcpTunnelConnector<'_> { + type Reader = OwnedReadHalf; + type Writer = OwnedWriteHalf; + + async fn connect(&self, remote: &Option) -> anyhow::Result<(Self::Reader, Self::Writer)> { + let (host, port) = match remote { + Some(remote) => (&remote.host, remote.port), + None => (self.host, self.port), + }; + + let stream = protocols::tcp::connect(host, port, self.so_mark, self.connect_timeout, self.dns_resolver).await?; + Ok(stream.into_split()) + } +} diff --git a/src/tunnel/connectors/udp.rs b/src/tunnel/connectors/udp.rs new file mode 100644 index 0000000..341dc91 --- /dev/null +++ b/src/tunnel/connectors/udp.rs @@ -0,0 +1,46 @@ +use crate::protocols; +use crate::protocols::dns::DnsResolver; +use crate::protocols::udp::WsUdpSocket; +use crate::tunnel::connectors::TunnelConnector; +use crate::tunnel::RemoteAddr; +use std::time::Duration; +use url::Host; + +pub struct UdpTunnelConnector<'a> { + host: &'a Host, + port: u16, + so_mark: Option, + connect_timeout: Duration, + dns_resolver: &'a DnsResolver, +} + +impl<'a> UdpTunnelConnector<'a> { + pub fn new( + host: &'a Host, + port: u16, + so_mark: Option, + connect_timeout: Duration, + dns_resolver: &'a DnsResolver, + ) -> UdpTunnelConnector<'a> { + UdpTunnelConnector { + host, + port, + so_mark, + connect_timeout, + dns_resolver, + } + } +} + +impl TunnelConnector for UdpTunnelConnector<'_> { + type Reader = WsUdpSocket; + type Writer = WsUdpSocket; + + async fn connect(&self, _: &Option) -> anyhow::Result<(Self::Reader, Self::Writer)> { + let stream = + protocols::udp::connect(self.host, self.port, self.connect_timeout, self.so_mark, self.dns_resolver) + .await?; + + Ok((stream.clone(), stream)) + } +} diff --git a/src/tunnel/listeners/mod.rs b/src/tunnel/listeners/mod.rs index daa7a9b..d42d03c 100644 --- a/src/tunnel/listeners/mod.rs +++ b/src/tunnel/listeners/mod.rs @@ -23,22 +23,23 @@ pub use udp::new_udp_listener; #[cfg(unix)] pub use unix_sock::UnixTunnelListener; +use crate::tunnel::RemoteAddr; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_stream::Stream; -pub trait TunnelListener: - Stream> -{ +pub trait TunnelListener: Stream> { type Reader: AsyncRead + Send + 'static; type Writer: AsyncWrite + Send + 'static; + type OkReturn; // = ((Self::Reader, Self::Writer), RemoteAddr); } impl TunnelListener for T where - T: Stream>, + T: Stream>, R: AsyncRead + Send + 'static, W: AsyncWrite + Send + 'static, { type Reader = R; type Writer = W; + type OkReturn = ((R, W), RemoteAddr); } diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 413784b..dfc3e4c 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -1,4 +1,5 @@ pub mod client; +pub mod connectors; pub mod listeners; pub mod server; pub mod tls_reloader; @@ -76,7 +77,7 @@ static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| { (validation, DecodingKey::from_secret(JWT_SECRET)) }); -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct RemoteAddr { pub protocol: LocalProtocol, pub host: Host, diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 60a552f..d9eef43 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -1,11 +1,10 @@ use ahash::{HashMap, HashMapExt}; use anyhow::anyhow; use bytes::Bytes; -use futures_util::{pin_mut, FutureExt, Stream, StreamExt}; +use futures_util::{pin_mut, FutureExt, StreamExt}; use http_body_util::combinators::BoxBody; use http_body_util::{BodyStream, Either, StreamBody}; use std::cmp::min; -use std::fmt::Debug; use std::future::Future; use std::net::{IpAddr, SocketAddr}; @@ -15,7 +14,7 @@ use std::sync::Arc; use std::time::Duration; use super::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; -use crate::{protocols, socks5, LocalProtocol, TlsServerConfig, WsServerConfig}; +use crate::{protocols, LocalProtocol, TlsServerConfig, WsServerConfig}; use hyper::body::{Frame, Incoming}; use hyper::header::{CONTENT_TYPE, COOKIE, SEC_WEBSOCKET_PROTOCOL}; use hyper::http::HeaderValue; @@ -28,18 +27,21 @@ use once_cell::sync::Lazy; use parking_lot::Mutex; use socket2::SockRef; -use crate::protocols::udp::UdpStream; -use crate::protocols::{http_proxy, tls, udp}; +use crate::protocols::tls; +use crate::protocols::udp::{UdpStream, UdpStreamWriter}; use crate::restrictions::config_reloader::RestrictionsRulesReloader; use crate::restrictions::types::{ AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol, }; -use crate::socks5::Socks5Stream; +use crate::tunnel::connectors::{TcpTunnelConnector, TunnelConnector, UdpTunnelConnector}; +use crate::tunnel::listeners::{ + new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener, TunnelListener, +}; use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tokio::net::{TcpListener, TcpStream}; +use tokio::net::TcpListener; use tokio::select; use tokio::sync::{mpsc, oneshot}; use tokio_rustls::TlsAcceptor; @@ -56,140 +58,123 @@ async fn run_tunnel( ) -> anyhow::Result<(RemoteAddr, Pin>, Pin>)> { match remote.protocol { LocalProtocol::Udp { timeout, .. } => { - let cnx = udp::connect( + let (rx, tx) = UdpTunnelConnector::new( &remote.host, remote.port, - timeout.unwrap_or(Duration::from_secs(10)), server_config.socket_so_mark, + timeout.unwrap_or(Duration::from_secs(10)), &server_config.dns_resolver, ) + .connect(&None) .await?; - Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx))) + Ok((remote, Box::pin(rx), Box::pin(tx))) } LocalProtocol::Tcp { proxy_protocol } => { - let mut socket = protocols::tcp::connect( + let (rx, mut tx) = TcpTunnelConnector::new( &remote.host, remote.port, server_config.socket_so_mark, Duration::from_secs(10), &server_config.dns_resolver, ) + .connect(&None) .await?; if proxy_protocol { let header = ppp::v2::Builder::with_addresses( ppp::v2::Version::Two | ppp::v2::Command::Proxy, ppp::v2::Protocol::Stream, - (client_address, socket.local_addr().unwrap()), + (client_address, tx.local_addr().unwrap()), ) .build() .unwrap(); - let _ = socket.write_all(&header).await; + let _ = tx.write_all(&header).await; } - let (rx, tx) = socket.into_split(); Ok((remote, Box::pin(rx), Box::pin(tx))) } LocalProtocol::ReverseTcp => { + type Item = ::OkReturn; #[allow(clippy::type_complexity)] - static SERVERS: Lazy, u16), mpsc::Receiver>>> = + static SERVERS: Lazy, u16), mpsc::Receiver>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); let remote_port = find_mapped_port(remote.port, restriction); let local_srv = (remote.host, remote_port); - let bind = format!("{}:{}", local_srv.0, local_srv.1); - let listening_server = protocols::tcp::run_server(bind.parse()?, false); - let tcp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; - let (local_rx, local_tx) = tcp.into_split(); - - let remote = RemoteAddr { - protocol: remote.protocol, - host: local_srv.0, - port: local_srv.1, + let listening_server = async { + let bind = format!("{}:{}", local_srv.0, local_srv.1); + TcpTunnelListener::new(bind.parse()?, local_srv.clone(), false).await }; + let ((local_rx, local_tx), remote) = + run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; + Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) } LocalProtocol::ReverseUdp { timeout } => { + type Item = ((UdpStream, UdpStreamWriter), RemoteAddr); #[allow(clippy::type_complexity)] - static SERVERS: Lazy, u16), mpsc::Receiver>>> = + static SERVERS: Lazy, u16), mpsc::Receiver>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); let remote_port = find_mapped_port(remote.port, restriction); let local_srv = (remote.host, remote_port); - let bind = format!("{}:{}", local_srv.0, local_srv.1); - let listening_server = - udp::run_server(bind.parse()?, timeout, |_| Ok(()), |send_socket| Ok(send_socket.clone())); - let udp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; - let udp_writer = udp.writer(); - let (local_rx, local_tx) = (udp, udp_writer); - - let remote = RemoteAddr { - protocol: remote.protocol, - host: local_srv.0, - port: local_srv.1, + let listening_server = async { + let bind = format!("{}:{}", local_srv.0, local_srv.1); + new_udp_listener(bind.parse()?, local_srv.clone(), timeout).await }; + let ((local_rx, local_tx), remote) = + run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) } LocalProtocol::ReverseSocks5 { timeout, credentials } => { + type Item = ::OkReturn; #[allow(clippy::type_complexity)] - static SERVERS: Lazy, u16), mpsc::Receiver<(Socks5Stream, (Host, u16))>>>> = + static SERVERS: Lazy, u16), mpsc::Receiver>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); let remote_port = find_mapped_port(remote.port, restriction); let local_srv = (remote.host, remote_port); - let bind = format!("{}:{}", local_srv.0, local_srv.1); - let listening_server = socks5::run_server(bind.parse()?, timeout, credentials); - let (stream, local_srv) = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; - let protocol = stream.local_protocol(); - let (local_rx, local_tx) = tokio::io::split(stream); - - let remote = RemoteAddr { - protocol, - host: local_srv.0, - port: local_srv.1, + let listening_server = async { + let bind = format!("{}:{}", local_srv.0, local_srv.1); + Socks5TunnelListener::new(bind.parse()?, timeout, credentials).await }; + let ((local_rx, local_tx), remote) = + run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; + Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) } LocalProtocol::ReverseHttpProxy { timeout, credentials } => { + type Item = ::OkReturn; #[allow(clippy::type_complexity)] - static SERVERS: Lazy, u16), mpsc::Receiver<(TcpStream, (Host, u16))>>>> = + static SERVERS: Lazy, u16), mpsc::Receiver>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); let remote_port = find_mapped_port(remote.port, restriction); let local_srv = (remote.host, remote_port); - let bind = format!("{}:{}", local_srv.0, local_srv.1); - let listening_server = http_proxy::run_server(bind.parse()?, timeout, credentials); - let (stream, local_srv) = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; - let (local_rx, local_tx) = tokio::io::split(stream); - - let remote = RemoteAddr { - protocol: LocalProtocol::Tcp { proxy_protocol: false }, - host: local_srv.0, - port: local_srv.1, + let listening_server = async { + let bind = format!("{}:{}", local_srv.0, local_srv.1); + HttpProxyTunnelListener::new(bind.parse()?, timeout, credentials, false).await }; + let ((local_rx, local_tx), remote) = + run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; + Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) } #[cfg(unix)] LocalProtocol::ReverseUnix { ref path } => { - use protocols::unix_sock; - use tokio::net::UnixStream; - + use crate::tunnel::listeners::UnixTunnelListener; + type Item = ::OkReturn; #[allow(clippy::type_complexity)] - static SERVERS: Lazy, u16), mpsc::Receiver>>> = + static SERVERS: Lazy, u16), mpsc::Receiver>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); let remote_port = find_mapped_port(remote.port, restriction); let local_srv = (remote.host, remote_port); - let listening_server = unix_sock::run_server(path); - let stream = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; - let (local_rx, local_tx) = stream.into_split(); + let listening_server = async { UnixTunnelListener::new(path, local_srv.clone(), false).await }; + let ((local_rx, local_tx), remote) = + run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; - let remote = RemoteAddr { - protocol: remote.protocol, - host: local_srv.0, - port: local_srv.1, - }; Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) } #[cfg(not(unix))] @@ -232,66 +217,6 @@ fn find_mapped_port(req_port: u16, restriction: &RestrictionConfig) -> u16 { remote_port } -#[allow(clippy::type_complexity)] -async fn run_listening_server( - local_srv: &(Host, u16), - servers: &Mutex, u16), mpsc::Receiver>>, - gen_listening_server: Fut, -) -> anyhow::Result -where - Fut: Future>, - FutOut: Stream> + Send + 'static, - E: Debug + Send, - T: Send + 'static, -{ - let listening_server = servers.lock().remove(local_srv); - let mut listening_server = if let Some(listening_server) = listening_server { - listening_server - } else { - let listening_server = gen_listening_server.await?; - let send_timeout = Duration::from_secs(60 * 3); - let (tx, rx) = mpsc::channel::(1); - let fut = async move { - pin_mut!(listening_server); - loop { - select! { - biased; - cnx = listening_server.next() => { - match cnx { - None => break, - Some(Err(err)) => { - warn!("Error while listening for incoming connections {err:?}"); - continue; - } - Some(Ok(cnx)) => { - if tx.send_timeout(cnx, send_timeout).await.is_err() { - info!("New reverse connection failed to be picked by client after {}s. Closing reverse tunnel server", send_timeout.as_secs()); - break; - } - } - } - }, - - _ = tx.closed() => { - break; - } - } - } - info!("Stopping listening reverse server"); - }; - - tokio::spawn(fut.instrument(Span::current())); - rx - }; - - let cnx = listening_server - .recv() - .await - .ok_or_else(|| anyhow!("listening reverse server stopped"))?; - servers.lock().insert(local_srv.clone(), listening_server); - Ok(cnx) -} - #[inline] fn extract_x_forwarded_for(req: &Request) -> Result, Response> { let Some(x_forward_for) = req.headers().get("X-Forwarded-For") else { @@ -957,3 +882,65 @@ pub async fn run_server(server_config: Arc, restrictions: Restri } } } + +#[allow(clippy::type_complexity)] +async fn run_listening_server( + local_srv: &(Host, u16), + servers: &Mutex< + HashMap< + (Host, u16), + mpsc::Receiver<((::Reader, ::Writer), RemoteAddr)>, + >, + >, + gen_listening_server: impl Future>, +) -> anyhow::Result<((::Reader, ::Writer), RemoteAddr)> +where + T: TunnelListener + Send + 'static, +{ + let listening_server = servers.lock().remove(local_srv); + let mut listening_server = if let Some(listening_server) = listening_server { + listening_server + } else { + let listening_server = gen_listening_server.await?; + let send_timeout = Duration::from_secs(60 * 3); + let (tx, rx) = mpsc::channel(1); + let fut = async move { + pin_mut!(listening_server); + loop { + select! { + biased; + cnx = listening_server.next() => { + match cnx { + None => break, + Some(Err(err)) => { + warn!("Error while listening for incoming connections {err:?}"); + continue; + } + Some(Ok(cnx)) => { + if tx.send_timeout(cnx, send_timeout).await.is_err() { + info!("New reverse connection failed to be picked by client after {}s. Closing reverse tunnel server", send_timeout.as_secs()); + break; + } + } + } + }, + + _ = tx.closed() => { + break; + } + } + } + info!("Stopping listening reverse server"); + }; + + tokio::spawn(fut.instrument(Span::current())); + rx + }; + + let cnx = listening_server + .recv() + .await + .ok_or_else(|| anyhow!("listening reverse server stopped"))?; + servers.lock().insert(local_srv.clone(), listening_server); + Ok(cnx) +}