From f149b8190b02876cbe8d15c0a67242edfec499d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Wed, 31 Jul 2024 21:57:25 +0200 Subject: [PATCH] cleanup --- src/main.rs | 150 ++++++----------------- src/protocols/socks5/tcp_server.rs | 2 +- src/restrictions/types.rs | 2 +- src/tunnel/client/cnx_pool.rs | 2 +- src/tunnel/client/l4_transport_stream.rs | 61 +++++++++ src/tunnel/client/mod.rs | 1 + src/tunnel/connectors/sock5.rs | 4 +- src/tunnel/listeners/http_proxy.rs | 3 +- src/tunnel/listeners/stdio.rs | 3 +- src/tunnel/listeners/tcp.rs | 4 +- src/tunnel/listeners/tproxy.rs | 4 +- src/tunnel/listeners/udp.rs | 3 +- src/tunnel/listeners/unix_sock.rs | 3 +- src/tunnel/mod.rs | 120 +++++++++--------- src/tunnel/server/server.rs | 4 +- 15 files changed, 171 insertions(+), 195 deletions(-) create mode 100644 src/tunnel/client/l4_transport_stream.rs diff --git a/src/main.rs b/src/main.rs index 2c4d26f..c8b9d77 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,14 +12,14 @@ use crate::tunnel::listeners::{ new_stdio_listener, new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener, }; use crate::tunnel::server::{TlsServerConfig, WsServer, WsServerConfig}; -use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme}; +use crate::tunnel::{to_host_port, LocalProtocol, RemoteAddr, TransportAddr, TransportScheme}; +use anyhow::{anyhow, Context}; use base64::Engine; use clap::Parser; use hyper::header::HOST; use hyper::http::{HeaderName, HeaderValue}; use log::debug; use parking_lot::{Mutex, RwLock}; -use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; use std::fmt::Debug; use std::io; @@ -376,66 +376,11 @@ struct Server { http_proxy_password: Option, } -#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] -enum LocalProtocol { - Tcp { - proxy_protocol: bool, - }, - Udp { - timeout: Option, - }, - Stdio, - Socks5 { - timeout: Option, - credentials: Option<(String, String)>, - }, - TProxyTcp, - TProxyUdp { - timeout: Option, - }, - HttpProxy { - timeout: Option, - credentials: Option<(String, String)>, - proxy_protocol: bool, - }, - ReverseTcp, - ReverseUdp { - timeout: Option, - }, - ReverseSocks5 { - timeout: Option, - credentials: Option<(String, String)>, - }, - ReverseHttpProxy { - timeout: Option, - credentials: Option<(String, String)>, - }, - ReverseUnix { - path: PathBuf, - }, - Unix { - path: PathBuf, - }, -} - -impl LocalProtocol { - pub const fn is_reverse_tunnel(&self) -> bool { - matches!( - self, - Self::ReverseTcp - | Self::ReverseUdp { .. } - | Self::ReverseSocks5 { .. } - | Self::ReverseUnix { .. } - | Self::ReverseHttpProxy { .. } - ) - } -} - #[derive(Clone, Debug)] pub struct LocalToRemote { local_protocol: LocalProtocol, local: SocketAddr, - remote: (Host, u16), + remote: (Host, u16), } fn parse_duration_sec(arg: &str) -> Result { @@ -773,24 +718,7 @@ async fn main() -> anyhow::Result<()> { TransportScheme::from_str(args.remote_addr.scheme()).expect("invalid scheme in server url"); let tls = match transport_scheme { TransportScheme::Ws | TransportScheme::Http => None, - TransportScheme::Wss => Some(TlsClientConfig { - tls_connector: Arc::new(RwLock::new( - tls::tls_connector( - args.tls_verify_certificate, - transport_scheme.alpn_protocols(), - !args.tls_sni_disable, - tls_certificate, - tls_key, - ) - .expect("Cannot create tls connector"), - )), - tls_sni_override: args.tls_sni_override, - tls_verify_certificate: args.tls_verify_certificate, - tls_sni_disabled: args.tls_sni_disable, - tls_certificate_path: args.tls_certificate.clone(), - tls_key_path: args.tls_private_key.clone(), - }), - TransportScheme::Https => Some(TlsClientConfig { + TransportScheme::Wss | TransportScheme::Https => Some(TlsClientConfig { tls_connector: Arc::new(RwLock::new( tls::tls_connector( args.tls_verify_certificate, @@ -824,25 +752,8 @@ async fn main() -> anyhow::Result<()> { panic!("http headers file does not exists: {}", path.display()); } } - let http_proxy = if let Some(proxy) = args.http_proxy { - let mut proxy = if proxy.starts_with("http://") { - Url::parse(&proxy).expect("Invalid http proxy url") - } else { - Url::parse(&format!("http://{}", proxy)).expect("Invalid http proxy url") - }; - if let Some(login) = args.http_proxy_login { - proxy.set_username(login.as_str()).expect("Cannot set http proxy login"); - } - if let Some(password) = args.http_proxy_password { - proxy - .set_password(Some(password.as_str())) - .expect("Cannot set http proxy password"); - } - Some(proxy) - } else { - None - }; + let http_proxy = mk_http_proxy(args.http_proxy, args.http_proxy_login, args.http_proxy_password)?; let client_config = WsClientConfig { remote_addr: TransportAddr::new( TransportScheme::from_str(args.remote_addr.scheme()).unwrap(), @@ -1176,26 +1087,7 @@ async fn main() -> anyhow::Result<()> { restriction_cfg }; - let http_proxy = if let Some(proxy) = args.http_proxy { - let mut proxy = if proxy.starts_with("http://") { - Url::parse(&proxy).expect("Invalid http proxy url") - } else { - Url::parse(&format!("http://{}", proxy)).expect("Invalid http proxy url") - }; - - if let Some(login) = args.http_proxy_login { - proxy.set_username(login.as_str()).expect("Cannot set http proxy login"); - } - if let Some(password) = args.http_proxy_password { - proxy - .set_password(Some(password.as_str())) - .expect("Cannot set http proxy password"); - } - Some(proxy) - } else { - None - }; - + let http_proxy = mk_http_proxy(args.http_proxy, args.http_proxy_login, args.http_proxy_password)?; let server_config = WsServerConfig { socket_so_mark: args.socket_so_mark, bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0], @@ -1230,3 +1122,33 @@ async fn main() -> anyhow::Result<()> { tokio::signal::ctrl_c().await.unwrap(); Ok(()) } + +fn mk_http_proxy( + http_proxy: Option, + proxy_login: Option, + proxy_password: Option, +) -> anyhow::Result> { + let Some(proxy) = http_proxy else { + return Ok(None); + }; + + let mut proxy = if proxy.starts_with("http://") { + Url::parse(&proxy).with_context(|| "Invalid http proxy url")? + } else { + Url::parse(&format!("http://{}", proxy)).with_context(|| "Invalid http proxy url")? + }; + + if let Some(login) = proxy_login { + proxy + .set_username(login.as_str()) + .map_err(|_| anyhow!("Cannot set http proxy login"))?; + } + + if let Some(password) = proxy_password { + proxy + .set_password(Some(password.as_str())) + .map_err(|_| anyhow!("Cannot set http proxy password"))?; + } + + Ok(Some(proxy)) +} diff --git a/src/protocols/socks5/tcp_server.rs b/src/protocols/socks5/tcp_server.rs index 5a45729..115860a 100644 --- a/src/protocols/socks5/tcp_server.rs +++ b/src/protocols/socks5/tcp_server.rs @@ -1,5 +1,5 @@ use super::udp_server::{Socks5UdpStream, Socks5UdpStreamWriter}; -use crate::LocalProtocol; +use crate::tunnel::LocalProtocol; use anyhow::Context; use fast_socks5::server::{Config, DenyAuthentication, SimpleUserPassword, Socks5Server}; use fast_socks5::util::target_addr::TargetAddr; diff --git a/src/restrictions/types.rs b/src/restrictions/types.rs index 9637528..f58897e 100644 --- a/src/restrictions/types.rs +++ b/src/restrictions/types.rs @@ -1,4 +1,4 @@ -use crate::LocalProtocol; +use crate::tunnel::LocalProtocol; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use regex::Regex; use serde::{Deserialize, Deserializer}; diff --git a/src/tunnel/client/cnx_pool.rs b/src/tunnel/client/cnx_pool.rs index 6a04985..0cef719 100644 --- a/src/tunnel/client/cnx_pool.rs +++ b/src/tunnel/client/cnx_pool.rs @@ -1,7 +1,7 @@ use crate::protocols; use crate::protocols::tls; +use crate::tunnel::client::l4_transport_stream::TransportStream; use crate::tunnel::client::WsClientConfig; -use crate::tunnel::TransportStream; use async_trait::async_trait; use bb8::ManageConnection; use std::ops::Deref; diff --git a/src/tunnel/client/l4_transport_stream.rs b/src/tunnel/client/l4_transport_stream.rs new file mode 100644 index 0000000..bbf55e1 --- /dev/null +++ b/src/tunnel/client/l4_transport_stream.rs @@ -0,0 +1,61 @@ +use std::io::{Error, IoSlice}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::net::TcpStream; +use tokio_rustls::client::TlsStream; + +pub enum TransportStream { + Plain(TcpStream), + Tls(TlsStream), +} + +impl AsyncRead for TransportStream { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + match self.get_mut() { + Self::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf), + Self::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TransportStream { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + match self.get_mut() { + Self::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf), + Self::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Plain(cnx) => Pin::new(cnx).poll_flush(cx), + Self::Tls(cnx) => Pin::new(cnx).poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Plain(cnx) => Pin::new(cnx).poll_shutdown(cx), + Self::Tls(cnx) => Pin::new(cnx).poll_shutdown(cx), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + Self::Plain(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs), + Self::Tls(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs), + } + } + + fn is_write_vectored(&self) -> bool { + match &self { + Self::Plain(cnx) => cnx.is_write_vectored(), + Self::Tls(cnx) => cnx.is_write_vectored(), + } + } +} diff --git a/src/tunnel/client/mod.rs b/src/tunnel/client/mod.rs index c001fec..0a74ffd 100644 --- a/src/tunnel/client/mod.rs +++ b/src/tunnel/client/mod.rs @@ -2,6 +2,7 @@ mod client; mod cnx_pool; mod config; +pub mod l4_transport_stream; pub use client::WsClient; pub use config::TlsClientConfig; diff --git a/src/tunnel/connectors/sock5.rs b/src/tunnel/connectors/sock5.rs index 9837a52..66ac590 100644 --- a/src/tunnel/connectors/sock5.rs +++ b/src/tunnel/connectors/sock5.rs @@ -8,12 +8,12 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use url::Url; +use crate::protocols; use crate::protocols::dns::DnsResolver; use crate::protocols::udp; use crate::protocols::udp::WsUdpSocket; use crate::tunnel::connectors::TunnelConnector; -use crate::tunnel::RemoteAddr; -use crate::{protocols, LocalProtocol}; +use crate::tunnel::{LocalProtocol, RemoteAddr}; pub struct Socks5TunnelConnector<'a> { so_mark: Option, diff --git a/src/tunnel/listeners/http_proxy.rs b/src/tunnel/listeners/http_proxy.rs index e22a627..51f1f4d 100644 --- a/src/tunnel/listeners/http_proxy.rs +++ b/src/tunnel/listeners/http_proxy.rs @@ -1,7 +1,6 @@ use crate::protocols::http_proxy; use crate::protocols::http_proxy::HttpProxyListener; -use crate::tunnel::RemoteAddr; -use crate::LocalProtocol; +use crate::tunnel::{LocalProtocol, RemoteAddr}; use anyhow::{anyhow, Context}; use std::net::SocketAddr; use std::pin::Pin; diff --git a/src/tunnel/listeners/stdio.rs b/src/tunnel/listeners/stdio.rs index 3960975..b9eb01e 100644 --- a/src/tunnel/listeners/stdio.rs +++ b/src/tunnel/listeners/stdio.rs @@ -1,6 +1,5 @@ use crate::protocols::stdio; -use crate::tunnel::RemoteAddr; -use crate::LocalProtocol; +use crate::tunnel::{LocalProtocol, RemoteAddr}; use anyhow::{anyhow, Context}; use std::pin::Pin; use std::task::Poll; diff --git a/src/tunnel/listeners/tcp.rs b/src/tunnel/listeners/tcp.rs index 7358b7c..efb184c 100644 --- a/src/tunnel/listeners/tcp.rs +++ b/src/tunnel/listeners/tcp.rs @@ -1,5 +1,5 @@ -use crate::tunnel::RemoteAddr; -use crate::{protocols, LocalProtocol}; +use crate::protocols; +use crate::tunnel::{LocalProtocol, RemoteAddr}; use anyhow::{anyhow, Context}; use std::net::SocketAddr; use std::pin::Pin; diff --git a/src/tunnel/listeners/tproxy.rs b/src/tunnel/listeners/tproxy.rs index 073a3b9..d912582 100644 --- a/src/tunnel/listeners/tproxy.rs +++ b/src/tunnel/listeners/tproxy.rs @@ -1,7 +1,7 @@ +use crate::protocols; use crate::protocols::udp; use crate::protocols::udp::{UdpStream, UdpStreamWriter}; -use crate::tunnel::{to_host_port, RemoteAddr}; -use crate::{protocols, LocalProtocol}; +use crate::tunnel::{to_host_port, LocalProtocol, RemoteAddr}; use anyhow::{anyhow, Context}; use std::io; use std::net::SocketAddr; diff --git a/src/tunnel/listeners/udp.rs b/src/tunnel/listeners/udp.rs index e23027e..2df81b8 100644 --- a/src/tunnel/listeners/udp.rs +++ b/src/tunnel/listeners/udp.rs @@ -1,7 +1,6 @@ use crate::protocols::udp; use crate::protocols::udp::{UdpStream, UdpStreamWriter}; -use crate::tunnel::RemoteAddr; -use crate::LocalProtocol; +use crate::tunnel::{LocalProtocol, RemoteAddr}; use anyhow::{anyhow, Context}; use std::io; use std::net::SocketAddr; diff --git a/src/tunnel/listeners/unix_sock.rs b/src/tunnel/listeners/unix_sock.rs index 884956d..392fd4f 100644 --- a/src/tunnel/listeners/unix_sock.rs +++ b/src/tunnel/listeners/unix_sock.rs @@ -1,7 +1,6 @@ use crate::protocols::unix_sock; use crate::protocols::unix_sock::UnixListenerStream; -use crate::tunnel::RemoteAddr; -use crate::LocalProtocol; +use crate::tunnel::{LocalProtocol, RemoteAddr}; use anyhow::{anyhow, Context}; use std::path::Path; use std::pin::Pin; diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index bac5dfc..15cc12a 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -5,21 +5,17 @@ pub mod server; mod tls_reloader; mod transport; -use crate::{LocalProtocol, TlsClientConfig}; +use crate::TlsClientConfig; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation}; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use std::fmt::{Debug, Display, Formatter}; -use std::io::{Error, IoSlice}; use std::net::{IpAddr, SocketAddr}; use std::ops::Deref; -use std::pin::Pin; +use std::path::PathBuf; use std::str::FromStr; -use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio::net::TcpStream; -use tokio_rustls::client::TlsStream; +use std::time::Duration; use url::Host; use uuid::Uuid; @@ -73,6 +69,61 @@ static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| { (validation, DecodingKey::from_secret(JWT_SECRET)) }); +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +pub enum LocalProtocol { + Tcp { + proxy_protocol: bool, + }, + Udp { + timeout: Option, + }, + Stdio, + Socks5 { + timeout: Option, + credentials: Option<(String, String)>, + }, + TProxyTcp, + TProxyUdp { + timeout: Option, + }, + HttpProxy { + timeout: Option, + credentials: Option<(String, String)>, + proxy_protocol: bool, + }, + ReverseTcp, + ReverseUdp { + timeout: Option, + }, + ReverseSocks5 { + timeout: Option, + credentials: Option<(String, String)>, + }, + ReverseHttpProxy { + timeout: Option, + credentials: Option<(String, String)>, + }, + ReverseUnix { + path: PathBuf, + }, + Unix { + path: PathBuf, + }, +} + +impl LocalProtocol { + pub const fn is_reverse_tunnel(&self) -> bool { + matches!( + self, + Self::ReverseTcp + | Self::ReverseUdp { .. } + | Self::ReverseSocks5 { .. } + | Self::ReverseUnix { .. } + | Self::ReverseHttpProxy { .. } + ) + } +} + #[derive(Debug, Clone)] pub struct RemoteAddr { pub protocol: LocalProtocol, @@ -245,61 +296,6 @@ impl TryFrom for RemoteAddr { } } -pub enum TransportStream { - Plain(TcpStream), - Tls(TlsStream), -} - -impl AsyncRead for TransportStream { - fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { - match self.get_mut() { - Self::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf), - Self::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf), - } - } -} - -impl AsyncWrite for TransportStream { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - match self.get_mut() { - Self::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf), - Self::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf), - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.get_mut() { - Self::Plain(cnx) => Pin::new(cnx).poll_flush(cx), - Self::Tls(cnx) => Pin::new(cnx).poll_flush(cx), - } - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.get_mut() { - Self::Plain(cnx) => Pin::new(cnx).poll_shutdown(cx), - Self::Tls(cnx) => Pin::new(cnx).poll_shutdown(cx), - } - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - match self.get_mut() { - Self::Plain(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs), - Self::Tls(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs), - } - } - - fn is_write_vectored(&self) -> bool { - match &self { - Self::Plain(cnx) => cnx.is_write_vectored(), - Self::Tls(cnx) => cnx.is_write_vectored(), - } - } -} - pub fn to_host_port(addr: SocketAddr) -> (Host, u16) { match addr.ip() { IpAddr::V4(ip) => (Host::Ipv4(ip), addr.port()), diff --git a/src/tunnel/server/server.rs b/src/tunnel/server/server.rs index d0b258e..5914e62 100644 --- a/src/tunnel/server/server.rs +++ b/src/tunnel/server/server.rs @@ -15,8 +15,8 @@ use std::pin::Pin; use std::sync::Arc; use std::time::Duration; -use crate::tunnel::RemoteAddr; -use crate::{protocols, LocalProtocol}; +use crate::protocols; +use crate::tunnel::{LocalProtocol, RemoteAddr}; use hyper::body::Incoming; use hyper::server::conn::{http1, http2}; use hyper::service::service_fn;