From 38cb7ed5f84fea638450604f47e46147ca12c2d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Sun, 28 Jul 2024 13:14:08 +0200 Subject: [PATCH] refacto: split into modules --- src/main.rs | 109 ++--- src/protocols/dns/mod.rs | 3 + src/{dns.rs => protocols/dns/resolver.rs} | 8 +- src/protocols/http_proxy/mod.rs | 4 + .../http_proxy/server.rs} | 4 +- src/protocols/mod.rs | 9 + src/protocols/socks5/mod.rs | 6 + .../socks5/tcp_server.rs} | 8 +- .../socks5/udp_server.rs} | 0 src/protocols/stdio/mod.rs | 9 + src/protocols/stdio/server_unix.rs | 27 ++ src/protocols/stdio/server_windows.rs | 82 ++++ src/protocols/tcp/mod.rs | 6 + src/{tcp.rs => protocols/tcp/server.rs} | 2 +- src/protocols/tls/mod.rs | 10 + src/{tls.rs => protocols/tls/server.rs} | 0 src/{tls_utils.rs => protocols/tls/utils.rs} | 0 src/protocols/udp/mod.rs | 11 + src/{udp.rs => protocols/udp/server.rs} | 10 +- src/protocols/unix_sock/mod.rs | 4 + .../unix_sock/server.rs} | 0 src/stdio.rs | 116 ------ src/tunnel/client.rs | 5 +- src/tunnel/listeners/http_proxy.rs | 54 +++ src/tunnel/listeners/mod.rs | 44 +++ src/tunnel/listeners/socks5.rs | 47 +++ src/tunnel/listeners/stdio.rs | 70 ++++ src/tunnel/listeners/tcp.rs | 57 +++ src/tunnel/listeners/tproxy.rs | 107 +++++ src/tunnel/listeners/udp.rs | 66 ++++ src/tunnel/listeners/unix_sock.rs | 58 +++ src/tunnel/mod.rs | 8 +- src/tunnel/server.rs | 20 +- src/tunnel/tls_reloader.rs | 3 +- src/types.rs | 374 ------------------ 35 files changed, 745 insertions(+), 596 deletions(-) create mode 100644 src/protocols/dns/mod.rs rename src/{dns.rs => protocols/dns/resolver.rs} (98%) create mode 100644 src/protocols/http_proxy/mod.rs rename src/{http_proxy.rs => protocols/http_proxy/server.rs} (97%) create mode 100644 src/protocols/mod.rs create mode 100644 src/protocols/socks5/mod.rs rename src/{socks5.rs => protocols/socks5/tcp_server.rs} (98%) rename src/{socks5_udp.rs => protocols/socks5/udp_server.rs} (100%) create mode 100644 src/protocols/stdio/mod.rs create mode 100644 src/protocols/stdio/server_unix.rs create mode 100644 src/protocols/stdio/server_windows.rs create mode 100644 src/protocols/tcp/mod.rs rename src/{tcp.rs => protocols/tcp/server.rs} (99%) create mode 100644 src/protocols/tls/mod.rs rename src/{tls.rs => protocols/tls/server.rs} (100%) rename src/{tls_utils.rs => protocols/tls/utils.rs} (100%) create mode 100644 src/protocols/udp/mod.rs rename src/{udp.rs => protocols/udp/server.rs} (99%) create mode 100644 src/protocols/unix_sock/mod.rs rename src/{unix_socket.rs => protocols/unix_sock/server.rs} (100%) delete mode 100644 src/stdio.rs create mode 100644 src/tunnel/listeners/http_proxy.rs create mode 100644 src/tunnel/listeners/mod.rs create mode 100644 src/tunnel/listeners/socks5.rs create mode 100644 src/tunnel/listeners/stdio.rs create mode 100644 src/tunnel/listeners/tcp.rs create mode 100644 src/tunnel/listeners/tproxy.rs create mode 100644 src/tunnel/listeners/udp.rs create mode 100644 src/tunnel/listeners/unix_sock.rs delete mode 100644 src/types.rs diff --git a/src/main.rs b/src/main.rs index 1e3821c..3977779 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,18 +1,7 @@ -mod dns; mod embedded_certificate; -mod http_proxy; +mod protocols; mod restrictions; -mod socks5; -mod socks5_udp; -mod stdio; -mod tcp; -mod tls; -mod tls_utils; mod tunnel; -mod types; -mod udp; -#[cfg(unix)] -mod unix_socket; use anyhow::anyhow; use base64::Engine; @@ -41,16 +30,15 @@ use tokio_rustls::TlsConnector; use tracing::{error, info}; -use crate::dns::DnsResolver; +use crate::protocols::dns::DnsResolver; +use crate::protocols::udp::MyUdpSocket; +use crate::protocols::{socks5, tls, udp}; use crate::restrictions::types::RestrictionsRules; -use crate::tls_utils::{cn_from_certificate, find_leaf_certificate}; +use crate::tunnel::listeners::{ + new_stdio_listener, new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener, +}; use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme}; -use crate::types::{ - HttpProxyTunnelListener, Socks5TunnelListener, StdioTunnelListener, TProxyUdpTunnelListener, TcpTunnelListener, - TproxyTcpTunnelListener, UdpTunnelListener, UnixTunnelListener, -}; -use crate::udp::MyUdpSocket; use tracing_subscriber::filter::Directive; use tracing_subscriber::EnvFilter; use url::{Host, Url}; @@ -823,7 +811,7 @@ impl WsClientConfig { } #[tokio::main] -async fn main() { +async fn main() -> anyhow::Result<()> { let args = Wstunnel::parse(); // Setup logging @@ -870,8 +858,8 @@ async fn main() { // to be the common name (CN) of the client's certificate. tls_certificate .as_ref() - .and_then(|certs| find_leaf_certificate(certs.as_slice())) - .and_then(|leaf_cert| cn_from_certificate(&leaf_cert)) + .and_then(|certs| tls::find_leaf_certificate(certs.as_slice())) + .and_then(|leaf_cert| tls::cn_from_certificate(&leaf_cert)) .unwrap_or(args.http_upgrade_path_prefix) } else { args.http_upgrade_path_prefix @@ -1004,7 +992,7 @@ async fn main() { let remote = tunnel.remote.clone(); let cfg = client_config.clone(); let connect_to_dest = |_| async { - tcp::connect( + protocols::tcp::connect( &remote.0, remote.1, cfg.socket_so_mark, @@ -1081,11 +1069,15 @@ async fn main() { }; match remote.protocol { - LocalProtocol::Tcp { proxy_protocol: _ } => { - tcp::connect(&remote.host, remote.port, so_mark, timeout, dns_resolver) - .await - .map(|s| Box::new(s) as Box) - } + LocalProtocol::Tcp { proxy_protocol: _ } => protocols::tcp::connect( + &remote.host, + remote.port, + so_mark, + timeout, + dns_resolver, + ) + .await + .map(|s| Box::new(s) as Box), LocalProtocol::Udp { .. } => { udp::connect(&remote.host, remote.port, timeout, so_mark, dns_resolver) .await @@ -1125,7 +1117,8 @@ async fn main() { return Err(anyhow!("Missing remote destination for reverse socks5")); }; - tcp::connect(&remote.host, remote.port, so_mark, timeout, dns_resolver).await + protocols::tcp::connect(&remote.host, remote.port, so_mark, timeout, dns_resolver) + .await } }; @@ -1143,7 +1136,7 @@ async fn main() { let remote = tunnel.remote.clone(); let cfg = client_config.clone(); let connect_to_dest = |_| async { - tcp::connect( + protocols::tcp::connect( &remote.0, remote.1, cfg.socket_so_mark, @@ -1188,10 +1181,8 @@ async fn main() { match &tunnel.local_protocol { LocalProtocol::Tcp { proxy_protocol } => { - let server = tcp::run_server(tunnel.local, false) - .await - .unwrap_or_else(|err| panic!("Cannot start TCP server on {}: {}", tunnel.local, err)); - let server = TcpTunnelListener::new(server, tunnel.remote.clone(), *proxy_protocol); + let server = + TcpTunnelListener::new(tunnel.local, tunnel.remote.clone(), *proxy_protocol).await?; tokio::spawn(async move { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { error!("{:?}", err); @@ -1200,10 +1191,8 @@ async fn main() { } #[cfg(target_os = "linux")] LocalProtocol::TProxyTcp => { - let server = tcp::run_server(tunnel.local, true).await.unwrap_or_else(|err| { - panic!("Cannot start TProxy TCP server on {}: {}", tunnel.local, err) - }); - let server = TproxyTcpTunnelListener::new(server, false); // TODO: support proxy protocol + use crate::tunnel::listeners::TproxyTcpTunnelListener; + let server = TproxyTcpTunnelListener::new(tunnel.local, false).await?; tokio::spawn(async move { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { @@ -1213,11 +1202,8 @@ async fn main() { } #[cfg(unix)] LocalProtocol::Unix { path } => { - let server = unix_socket::run_server(path).await.unwrap_or_else(|err| { - panic!("Cannot start Unix domain server on {}: {}", tunnel.local, err) - }); - - let server = UnixTunnelListener::new(server, tunnel.remote.clone(), false); // TODO: support proxy protocol + use crate::tunnel::listeners::UnixTunnelListener; + let server = UnixTunnelListener::new(path, tunnel.remote.clone(), false).await?; // TODO: support proxy protocol tokio::spawn(async move { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { error!("{:?}", err); @@ -1231,14 +1217,8 @@ async fn main() { #[cfg(target_os = "linux")] LocalProtocol::TProxyUdp { timeout } => { - let server = - udp::run_server(tunnel.local, *timeout, udp::configure_tproxy, udp::mk_send_socket_tproxy) - .await - .unwrap_or_else(|err| { - panic!("Cannot start TProxy UDP server on {}: {}", tunnel.local, err) - }); - - let server = TProxyUdpTunnelListener::new(server, *timeout); + use crate::tunnel::listeners::new_tproxy_udp; + let server = new_tproxy_udp(tunnel.local, *timeout).await?; tokio::spawn(async move { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { error!("{:?}", err); @@ -1250,10 +1230,7 @@ async fn main() { panic!("Transparent proxy is not available for non Linux platform") } LocalProtocol::Udp { timeout } => { - let server = udp::run_server(tunnel.local, *timeout, |_| Ok(()), |s| Ok(s.clone())) - .await - .unwrap_or_else(|err| panic!("Cannot start UDP server on {}: {}", tunnel.local, err)); - let server = UdpTunnelListener::new(server, tunnel.remote.clone(), *timeout); + let server = new_udp_listener(tunnel.local, tunnel.remote.clone(), *timeout).await?; tokio::spawn(async move { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { @@ -1262,11 +1239,7 @@ async fn main() { }); } LocalProtocol::Socks5 { timeout, credentials } => { - let server = socks5::run_server(tunnel.local, *timeout, credentials.clone()) - .await - .unwrap_or_else(|err| panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err)); - - let server = Socks5TunnelListener::new(server); + let server = Socks5TunnelListener::new(tunnel.local, *timeout, credentials.clone()).await?; tokio::spawn(async move { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { error!("{:?}", err); @@ -1278,13 +1251,9 @@ async fn main() { credentials, proxy_protocol, } => { - let server = http_proxy::run_server(tunnel.local, *timeout, credentials.clone()) - .await - .unwrap_or_else(|err| { - panic!("Cannot start http proxy server on {}: {}", tunnel.local, err) - }); - - let server = HttpProxyTunnelListener::new(server, *proxy_protocol); + let server = + HttpProxyTunnelListener::new(tunnel.local, *timeout, credentials.clone(), *proxy_protocol) + .await?; tokio::spawn(async move { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { error!("{:?}", err); @@ -1293,10 +1262,7 @@ async fn main() { } LocalProtocol::Stdio => { - let (server, mut handle) = stdio::server::run_server().await.unwrap_or_else(|err| { - panic!("Cannot start STDIO server: {}", err); - }); - let server = StdioTunnelListener::new(server, tunnel.remote.clone(), false); + let (server, mut handle) = new_stdio_listener(tunnel.remote.clone(), false).await?; // TODO: support proxy protocol tokio::spawn(async move { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { error!("{:?}", err); @@ -1410,4 +1376,5 @@ async fn main() { } tokio::signal::ctrl_c().await.unwrap(); + Ok(()) } diff --git a/src/protocols/dns/mod.rs b/src/protocols/dns/mod.rs new file mode 100644 index 0000000..27d638e --- /dev/null +++ b/src/protocols/dns/mod.rs @@ -0,0 +1,3 @@ +mod resolver; + +pub use resolver::DnsResolver; diff --git a/src/dns.rs b/src/protocols/dns/resolver.rs similarity index 98% rename from src/dns.rs rename to src/protocols/dns/resolver.rs index 760b404..486ed19 100644 --- a/src/dns.rs +++ b/src/protocols/dns/resolver.rs @@ -1,4 +1,4 @@ -use crate::tcp; +use crate::protocols; use anyhow::{anyhow, Context}; use futures_util::{FutureExt, TryFutureExt}; use hickory_resolver::config::{LookupIpStrategy, NameServerConfig, Protocol, ResolverConfig, ResolverOpts}; @@ -205,7 +205,7 @@ impl RuntimeProvider for TokioRuntimeProviderWithSoMark { }; if let Some(proxy) = &proxy { - tcp::connect_with_http_proxy( + protocols::tcp::connect_with_http_proxy( proxy, &host, server_addr.port(), @@ -217,7 +217,7 @@ impl RuntimeProvider for TokioRuntimeProviderWithSoMark { .map(|s| s.map(AsyncIoTokioAsStd)) .await } else { - tcp::connect( + protocols::tcp::connect( &host, server_addr.port(), so_mark, @@ -261,7 +261,7 @@ impl RuntimeProvider for TokioRuntimeProviderWithSoMark { #[cfg(test)] mod tests { - use crate::dns::sort_socket_addrs; + use super::*; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; #[test] diff --git a/src/protocols/http_proxy/mod.rs b/src/protocols/http_proxy/mod.rs new file mode 100644 index 0000000..4ec7988 --- /dev/null +++ b/src/protocols/http_proxy/mod.rs @@ -0,0 +1,4 @@ +mod server; + +pub use server::run_server; +pub use server::HttpProxyListener; diff --git a/src/http_proxy.rs b/src/protocols/http_proxy/server.rs similarity index 97% rename from src/http_proxy.rs rename to src/protocols/http_proxy/server.rs index ab81961..81af076 100644 --- a/src/http_proxy.rs +++ b/src/protocols/http_proxy/server.rs @@ -33,7 +33,7 @@ impl Stream for HttpProxyListener { } } -fn server_client( +fn handle_request( credentials: &Option, dest: &Mutex<(Host, u16)>, req: Request, @@ -112,7 +112,7 @@ pub async fn run_server( let forward_to = Mutex::new((Host::Ipv4(Ipv4Addr::new(0, 0, 0, 0)), 0)); let conn_fut = http1.serve_connection( hyper_util::rt::TokioIo::new(&mut stream), - service_fn(|req| server_client(&auth_header, &forward_to, req)), + service_fn(|req| handle_request(&auth_header, &forward_to, req)), ); match conn_fut.await { Ok(_) => return Some((Ok((stream, forward_to.into_inner())), (listener, http1, auth_header))), diff --git a/src/protocols/mod.rs b/src/protocols/mod.rs new file mode 100644 index 0000000..9177924 --- /dev/null +++ b/src/protocols/mod.rs @@ -0,0 +1,9 @@ +pub mod dns; +pub mod http_proxy; +pub mod socks5; +pub mod stdio; +pub mod tcp; +pub mod tls; +pub mod udp; +#[cfg(unix)] +pub mod unix_sock; diff --git a/src/protocols/socks5/mod.rs b/src/protocols/socks5/mod.rs new file mode 100644 index 0000000..3214c0b --- /dev/null +++ b/src/protocols/socks5/mod.rs @@ -0,0 +1,6 @@ +mod tcp_server; +mod udp_server; + +pub use tcp_server::run_server; +pub use tcp_server::Socks5Listener; +pub use tcp_server::Socks5Stream; diff --git a/src/socks5.rs b/src/protocols/socks5/tcp_server.rs similarity index 98% rename from src/socks5.rs rename to src/protocols/socks5/tcp_server.rs index 084d7eb..158fca8 100644 --- a/src/socks5.rs +++ b/src/protocols/socks5/tcp_server.rs @@ -1,5 +1,5 @@ -use crate::socks5_udp::Socks5UdpStream; -use crate::{socks5_udp, LocalProtocol}; +use super::udp_server::Socks5UdpStream; +use crate::LocalProtocol; use anyhow::Context; use fast_socks5::server::{Config, DenyAuthentication, SimpleUserPassword, Socks5Server}; use fast_socks5::util::target_addr::TargetAddr; @@ -29,7 +29,7 @@ pub enum Socks5Stream { impl Socks5Stream { pub fn local_protocol(&self) -> LocalProtocol { match self { - Self::Tcp(_) => LocalProtocol::Tcp { proxy_protocol: false }, + Self::Tcp(_) => LocalProtocol::Tcp { proxy_protocol: false }, // TODO: Implement proxy protocol Self::Udp(s) => LocalProtocol::Udp { timeout: s.watchdog_deadline.as_ref().map(|x| x.period()), }, @@ -72,7 +72,7 @@ pub async fn run_server( cfg.set_execute_command(false); cfg.set_udp_support(true); - let udp_server = socks5_udp::run_server(bind, timeout).await?; + let udp_server = super::udp_server::run_server(bind, timeout).await?; let server = server.with_config(cfg); let stream = stream::unfold((server, Box::pin(udp_server)), move |(server, mut udp_server)| async move { let mut acceptor = server.incoming(); diff --git a/src/socks5_udp.rs b/src/protocols/socks5/udp_server.rs similarity index 100% rename from src/socks5_udp.rs rename to src/protocols/socks5/udp_server.rs diff --git a/src/protocols/stdio/mod.rs b/src/protocols/stdio/mod.rs new file mode 100644 index 0000000..05820c5 --- /dev/null +++ b/src/protocols/stdio/mod.rs @@ -0,0 +1,9 @@ +#[cfg(unix)] +mod server_unix; +#[cfg(not(unix))] +mod server_windows; + +#[cfg(unix)] +pub use server_unix::run_server; +#[cfg(not(unix))] +pub use server_windows::run_server; diff --git a/src/protocols/stdio/server_unix.rs b/src/protocols/stdio/server_unix.rs new file mode 100644 index 0000000..dfa4292 --- /dev/null +++ b/src/protocols/stdio/server_unix.rs @@ -0,0 +1,27 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, ReadBuf}; +use tokio::sync::oneshot; +use tokio_fd::AsyncFd; +use tracing::info; + +pub struct WsStdin { + stdin: AsyncFd, + _receiver: oneshot::Receiver<()>, +} + +impl AsyncRead for WsStdin { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + unsafe { self.map_unchecked_mut(|s| &mut s.stdin) }.poll_read(cx, buf) + } +} + +pub async fn run_server() -> Result<((WsStdin, AsyncFd), oneshot::Sender<()>), anyhow::Error> { + info!("Starting STDIO server"); + + let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?; + let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?; + let (tx, rx) = oneshot::channel::<()>(); + + Ok(((WsStdin { stdin, _receiver: rx }, stdout), tx)) +} diff --git a/src/protocols/stdio/server_windows.rs b/src/protocols/stdio/server_windows.rs new file mode 100644 index 0000000..e5a3c27 --- /dev/null +++ b/src/protocols/stdio/server_windows.rs @@ -0,0 +1,82 @@ +use bytes::BytesMut; +use log::error; +use parking_lot::Mutex; +use scopeguard::guard; +use std::io::{Read, Write}; +use std::sync::Arc; +use std::{io, thread}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; +use tokio::sync::oneshot; +use tokio::task::LocalSet; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_util::io::StreamReader; +use tracing::info; + +pub async fn run_server() -> Result<((impl AsyncRead, impl AsyncWrite), oneshot::Sender<()>), anyhow::Error> { + info!("Starting STDIO server. Press ctrl+c twice to exit"); + + crossterm::terminal::enable_raw_mode()?; + + let stdin = io::stdin(); + let (send, recv) = tokio::sync::mpsc::unbounded_channel(); + let (abort_tx, abort_rx) = oneshot::channel::<()>(); + let abort_rx = Arc::new(Mutex::new(abort_rx)); + let abort_rx2 = abort_rx.clone(); + thread::spawn(move || { + let _restore_terminal = guard((), move |_| { + let _ = crossterm::terminal::disable_raw_mode(); + abort_rx.lock().close(); + }); + let stdin = stdin; + let mut stdin = stdin.lock(); + let mut buf = [0u8; 65536]; + + loop { + let n = stdin.read(&mut buf).unwrap_or(0); + if n == 0 || (n == 1 && buf[0] == 3) { + // ctrl+c send char 3 + break; + } + if let Err(err) = send.send(Result::<_, io::Error>::Ok(BytesMut::from(&buf[..n]))) { + error!("Failed send inout: {:?}", err); + break; + } + } + }); + let stdin = StreamReader::new(UnboundedReceiverStream::new(recv)); + + let (stdout, mut recv) = tokio::io::duplex(65536); + let rt = tokio::runtime::Handle::current(); + thread::spawn(move || { + let task = async move { + let _restore_terminal = guard((), move |_| { + let _ = crossterm::terminal::disable_raw_mode(); + abort_rx2.lock().close(); + }); + let mut stdout = io::stdout().lock(); + let mut buf = [0u8; 65536]; + loop { + let Ok(n) = recv.read(&mut buf).await else { + break; + }; + + if n == 0 { + break; + } + + if let Err(err) = stdout.write_all(&buf[..n]) { + error!("Failed to write to stdout: {:?}", err); + break; + }; + let _ = stdout.flush(); + } + }; + + let local = LocalSet::new(); + local.spawn_local(task); + + rt.block_on(local); + }); + + Ok(((stdin, stdout), abort_tx)) +} diff --git a/src/protocols/tcp/mod.rs b/src/protocols/tcp/mod.rs new file mode 100644 index 0000000..b3e6f8d --- /dev/null +++ b/src/protocols/tcp/mod.rs @@ -0,0 +1,6 @@ +mod server; + +pub use server::configure_socket; +pub use server::connect; +pub use server::connect_with_http_proxy; +pub use server::run_server; diff --git a/src/tcp.rs b/src/protocols/tcp/server.rs similarity index 99% rename from src/tcp.rs rename to src/protocols/tcp/server.rs index bd02ff3..2e6153a 100644 --- a/src/tcp.rs +++ b/src/protocols/tcp/server.rs @@ -2,13 +2,13 @@ use anyhow::{anyhow, Context}; use std::{io, vec}; use tokio::task::JoinSet; -use crate::dns::DnsResolver; use base64::Engine; use bytes::BytesMut; use log::warn; use socket2::{SockRef, TcpKeepalive}; use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; +use crate::protocols::dns::DnsResolver; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpSocket, TcpStream}; diff --git a/src/protocols/tls/mod.rs b/src/protocols/tls/mod.rs new file mode 100644 index 0000000..f3e9097 --- /dev/null +++ b/src/protocols/tls/mod.rs @@ -0,0 +1,10 @@ +mod server; +mod utils; + +pub use server::connect; +pub use server::load_certificates_from_pem; +pub use server::load_private_key_from_file; +pub use server::tls_acceptor; +pub use server::tls_connector; +pub use utils::cn_from_certificate; +pub use utils::find_leaf_certificate; diff --git a/src/tls.rs b/src/protocols/tls/server.rs similarity index 100% rename from src/tls.rs rename to src/protocols/tls/server.rs diff --git a/src/tls_utils.rs b/src/protocols/tls/utils.rs similarity index 100% rename from src/tls_utils.rs rename to src/protocols/tls/utils.rs diff --git a/src/protocols/udp/mod.rs b/src/protocols/udp/mod.rs new file mode 100644 index 0000000..e758958 --- /dev/null +++ b/src/protocols/udp/mod.rs @@ -0,0 +1,11 @@ +mod server; + +#[cfg(target_os = "linux")] +pub use server::configure_tproxy; +pub use server::connect; +#[cfg(target_os = "linux")] +pub use server::mk_send_socket_tproxy; +pub use server::run_server; +pub use server::MyUdpSocket; +pub use server::UdpStream; +pub use server::UdpStreamWriter; diff --git a/src/udp.rs b/src/protocols/udp/server.rs similarity index 99% rename from src/udp.rs rename to src/protocols/udp/server.rs index 4162e56..a995f49 100644 --- a/src/udp.rs +++ b/src/protocols/udp/server.rs @@ -5,9 +5,9 @@ use parking_lot::RwLock; use pin_project::{pin_project, pinned_drop}; use std::collections::HashMap; use std::future::Future; -use std::{io, task}; use std::io::{Error, ErrorKind}; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::{io, task}; use tokio::task::JoinSet; use log::warn; @@ -20,7 +20,7 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::UdpSocket; use tokio::sync::futures::Notified; -use crate::dns::DnsResolver; +use crate::protocols::dns::DnsResolver; use tokio::sync::Notify; use tokio::time::{sleep, timeout, Interval}; use tracing::{debug, error, info}; @@ -173,11 +173,7 @@ impl UdpStream { } impl AsyncRead for UdpStream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - obuf: &mut ReadBuf<'_>, - ) -> Poll> { + fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, obuf: &mut ReadBuf<'_>) -> Poll> { let mut project = self.project(); // Look that the timeout for client has not elapsed if let Some(mut deadline) = project.watchdog_deadline.as_pin_mut() { diff --git a/src/protocols/unix_sock/mod.rs b/src/protocols/unix_sock/mod.rs new file mode 100644 index 0000000..36aec1a --- /dev/null +++ b/src/protocols/unix_sock/mod.rs @@ -0,0 +1,4 @@ +mod server; + +pub use server::run_server; +pub use server::UnixListenerStream; diff --git a/src/unix_socket.rs b/src/protocols/unix_sock/server.rs similarity index 100% rename from src/unix_socket.rs rename to src/protocols/unix_sock/server.rs diff --git a/src/stdio.rs b/src/stdio.rs deleted file mode 100644 index aa44435..0000000 --- a/src/stdio.rs +++ /dev/null @@ -1,116 +0,0 @@ -#[cfg(unix)] -pub mod server { - use std::pin::Pin; - use std::task::{Context, Poll}; - use tokio::io::{AsyncRead, ReadBuf}; - use tokio::sync::oneshot; - use tokio_fd::AsyncFd; - use tracing::info; - - pub struct WsStdin { - stdin: AsyncFd, - _receiver: oneshot::Receiver<()>, - } - - impl AsyncRead for WsStdin { - fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { - unsafe { self.map_unchecked_mut(|s| &mut s.stdin) }.poll_read(cx, buf) - } - } - - pub async fn run_server() -> Result<((WsStdin, AsyncFd), oneshot::Sender<()>), anyhow::Error> { - info!("Starting STDIO server"); - - let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?; - let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?; - let (tx, rx) = oneshot::channel::<()>(); - - Ok(((WsStdin { stdin, _receiver: rx }, stdout), tx)) - } -} - -#[cfg(not(unix))] -pub mod server { - use bytes::BytesMut; - use log::error; - use parking_lot::Mutex; - use scopeguard::guard; - use std::io::{Read, Write}; - use std::sync::Arc; - use std::{io, thread}; - use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; - use tokio::sync::oneshot; - use tokio::task::LocalSet; - use tokio_stream::wrappers::UnboundedReceiverStream; - use tokio_util::io::StreamReader; - use tracing::info; - - pub async fn run_server() -> Result<((impl AsyncRead, impl AsyncWrite), oneshot::Sender<()>), anyhow::Error> { - info!("Starting STDIO server. Press ctrl+c twice to exit"); - - crossterm::terminal::enable_raw_mode()?; - - let stdin = io::stdin(); - let (send, recv) = tokio::sync::mpsc::unbounded_channel(); - let (abort_tx, abort_rx) = oneshot::channel::<()>(); - let abort_rx = Arc::new(Mutex::new(abort_rx)); - let abort_rx2 = abort_rx.clone(); - thread::spawn(move || { - let _restore_terminal = guard((), move |_| { - let _ = crossterm::terminal::disable_raw_mode(); - abort_rx.lock().close(); - }); - let stdin = stdin; - let mut stdin = stdin.lock(); - let mut buf = [0u8; 65536]; - - loop { - let n = stdin.read(&mut buf).unwrap_or(0); - if n == 0 || (n == 1 && buf[0] == 3) { - // ctrl+c send char 3 - break; - } - if let Err(err) = send.send(Result::<_, io::Error>::Ok(BytesMut::from(&buf[..n]))) { - error!("Failed send inout: {:?}", err); - break; - } - } - }); - let stdin = StreamReader::new(UnboundedReceiverStream::new(recv)); - - let (stdout, mut recv) = tokio::io::duplex(65536); - let rt = tokio::runtime::Handle::current(); - thread::spawn(move || { - let task = async move { - let _restore_terminal = guard((), move |_| { - let _ = crossterm::terminal::disable_raw_mode(); - abort_rx2.lock().close(); - }); - let mut stdout = io::stdout().lock(); - let mut buf = [0u8; 65536]; - loop { - let Ok(n) = recv.read(&mut buf).await else { - break; - }; - - if n == 0 { - break; - } - - if let Err(err) = stdout.write_all(&buf[..n]) { - error!("Failed to write to stdout: {:?}", err); - break; - }; - let _ = stdout.flush(); - } - }; - - let local = LocalSet::new(); - local.spawn_local(task); - - rt.block_on(local); - }); - - Ok(((stdin, stdout), abort_tx)) - } -} diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index 4dea7e6..2759d1d 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -1,6 +1,6 @@ use super::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE}; +use crate::tunnel::listeners::TunnelListener; use crate::tunnel::transport::{TunnelReader, TunnelWriter}; -use crate::types::TunnelListener; use crate::{tunnel, WsClientConfig}; use futures_util::pin_mut; use hyper::header::COOKIE; @@ -57,8 +57,7 @@ where Ok(()) } -pub async fn run_tunnel(client_config: Arc, incoming_cnx: impl TunnelListener) -> anyhow::Result<()> -{ +pub async fn run_tunnel(client_config: Arc, incoming_cnx: impl TunnelListener) -> anyhow::Result<()> { pin_mut!(incoming_cnx); while let Some(cnx) = incoming_cnx.next().await { let (cnx_stream, remote_addr) = match cnx { diff --git a/src/tunnel/listeners/http_proxy.rs b/src/tunnel/listeners/http_proxy.rs new file mode 100644 index 0000000..e22a627 --- /dev/null +++ b/src/tunnel/listeners/http_proxy.rs @@ -0,0 +1,54 @@ +use crate::protocols::http_proxy; +use crate::protocols::http_proxy::HttpProxyListener; +use crate::tunnel::RemoteAddr; +use crate::LocalProtocol; +use anyhow::{anyhow, Context}; +use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{ready, Poll}; +use std::time::Duration; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use tokio_stream::Stream; + +pub struct HttpProxyTunnelListener { + listener: HttpProxyListener, + proxy_protocol: bool, +} + +impl HttpProxyTunnelListener { + pub async fn new( + bind_addr: SocketAddr, + timeout: Option, + credentials: Option<(String, String)>, + proxy_protocol: bool, + ) -> anyhow::Result { + let listener = http_proxy::run_server(bind_addr, timeout, credentials) + .await + .with_context(|| anyhow!("Cannot start http proxy server on {}", bind_addr))?; + + Ok(Self { + listener, + proxy_protocol, + }) + } +} + +impl Stream for HttpProxyTunnelListener { + type Item = anyhow::Result<((OwnedReadHalf, OwnedWriteHalf), RemoteAddr)>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + let ret = ready!(Pin::new(&mut this.listener).poll_next(cx)); + let ret = match ret { + Some(Ok((stream, (host, port)))) => { + let protocol = LocalProtocol::Tcp { + proxy_protocol: this.proxy_protocol, + }; + Some(anyhow::Ok((stream.into_split(), RemoteAddr { protocol, host, port }))) + } + Some(Err(err)) => Some(Err(err)), + None => None, + }; + Poll::Ready(ret) + } +} diff --git a/src/tunnel/listeners/mod.rs b/src/tunnel/listeners/mod.rs new file mode 100644 index 0000000..daa7a9b --- /dev/null +++ b/src/tunnel/listeners/mod.rs @@ -0,0 +1,44 @@ +mod tcp; +#[cfg(target_os = "linux")] +mod tproxy; + +mod http_proxy; +mod socks5; +mod stdio; +mod udp; +#[cfg(unix)] +mod unix_sock; + +#[cfg(target_os = "linux")] +pub use tproxy::new_tproxy_udp; +#[cfg(target_os = "linux")] +pub use tproxy::TproxyTcpTunnelListener; + +pub use http_proxy::HttpProxyTunnelListener; +pub use socks5::Socks5TunnelListener; +pub use stdio::new_stdio_listener; +pub use tcp::TcpTunnelListener; +pub use udp::new_udp_listener; + +#[cfg(unix)] +pub use unix_sock::UnixTunnelListener; + +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_stream::Stream; + +pub trait TunnelListener: + Stream> +{ + type Reader: AsyncRead + Send + 'static; + type Writer: AsyncWrite + Send + 'static; +} + +impl TunnelListener for T +where + T: Stream>, + R: AsyncRead + Send + 'static, + W: AsyncWrite + Send + 'static, +{ + type Reader = R; + type Writer = W; +} diff --git a/src/tunnel/listeners/socks5.rs b/src/tunnel/listeners/socks5.rs new file mode 100644 index 0000000..9c30fde --- /dev/null +++ b/src/tunnel/listeners/socks5.rs @@ -0,0 +1,47 @@ +use crate::protocols::socks5; +use crate::protocols::socks5::{Socks5Listener, Socks5Stream}; +use crate::tunnel::RemoteAddr; +use anyhow::{anyhow, Context}; +use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{ready, Poll}; +use std::time::Duration; +use tokio::io::{ReadHalf, WriteHalf}; +use tokio_stream::Stream; + +pub struct Socks5TunnelListener { + listener: Socks5Listener, +} + +impl Socks5TunnelListener { + pub async fn new( + bind_addr: SocketAddr, + timeout: Option, + credentials: Option<(String, String)>, + ) -> anyhow::Result { + let listener = socks5::run_server(bind_addr, timeout, credentials) + .await + .with_context(|| anyhow!("Cannot start Socks5 server on {}", bind_addr))?; + + Ok(Self { listener }) + } +} + +impl Stream for Socks5TunnelListener { + type Item = anyhow::Result<((ReadHalf, WriteHalf), RemoteAddr)>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + let ret = ready!(Pin::new(&mut this.listener).poll_next(cx)); + // TODO: Check if tokio::io::split can be avoided + let ret = match ret { + Some(Ok((stream, (host, port)))) => { + let protocol = stream.local_protocol(); + Some(anyhow::Ok((tokio::io::split(stream), RemoteAddr { protocol, host, port }))) + } + Some(Err(err)) => Some(Err(err)), + None => None, + }; + Poll::Ready(ret) + } +} diff --git a/src/tunnel/listeners/stdio.rs b/src/tunnel/listeners/stdio.rs new file mode 100644 index 0000000..3960975 --- /dev/null +++ b/src/tunnel/listeners/stdio.rs @@ -0,0 +1,70 @@ +use crate::protocols::stdio; +use crate::tunnel::RemoteAddr; +use crate::LocalProtocol; +use anyhow::{anyhow, Context}; +use std::pin::Pin; +use std::task::Poll; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::sync::oneshot; +use tokio_stream::Stream; +use url::Host; + +pub struct StdioTunnelListener +where + R: AsyncRead + Send + 'static, + W: AsyncWrite + Send + 'static, +{ + listener: Option<(R, W)>, + dest: (Host, u16), + proxy_protocol: bool, +} + +pub async fn new_stdio_listener( + dest: (Host, u16), + proxy_protocol: bool, +) -> anyhow::Result<( + StdioTunnelListener, + oneshot::Sender<()>, +)> { + let (listener, handle) = stdio::run_server() + .await + .with_context(|| anyhow!("Cannot start STDIO server"))?; + Ok(( + StdioTunnelListener { + listener: Some(listener), + proxy_protocol, + dest, + }, + handle, + )) +} + +impl Stream for StdioTunnelListener +where + R: AsyncRead + Send + 'static, + W: AsyncWrite + Send + 'static, +{ + type Item = anyhow::Result<((R, W), RemoteAddr)>; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll> { + let this = unsafe { self.get_unchecked_mut() }; + let ret = match this.listener.take() { + None => None, + Some(stream) => { + let (host, port) = this.dest.clone(); + Some(Ok(( + stream, + RemoteAddr { + protocol: LocalProtocol::Tcp { + proxy_protocol: this.proxy_protocol, + }, + host, + port, + }, + ))) + } + }; + + Poll::Ready(ret) + } +} diff --git a/src/tunnel/listeners/tcp.rs b/src/tunnel/listeners/tcp.rs new file mode 100644 index 0000000..7358b7c --- /dev/null +++ b/src/tunnel/listeners/tcp.rs @@ -0,0 +1,57 @@ +use crate::tunnel::RemoteAddr; +use crate::{protocols, LocalProtocol}; +use anyhow::{anyhow, Context}; +use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{ready, Poll}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use tokio_stream::wrappers::TcpListenerStream; +use tokio_stream::Stream; +use url::Host; + +pub struct TcpTunnelListener { + listener: TcpListenerStream, + dest: (Host, u16), + proxy_protocol: bool, +} + +impl TcpTunnelListener { + pub async fn new(bind_addr: SocketAddr, dest: (Host, u16), proxy_protocol: bool) -> anyhow::Result { + let listener = protocols::tcp::run_server(bind_addr, false) + .await + .with_context(|| anyhow!("Cannot start TCP server on {}", bind_addr))?; + + Ok(Self { + listener, + dest, + proxy_protocol, + }) + } +} + +impl Stream for TcpTunnelListener { + type Item = anyhow::Result<((OwnedReadHalf, OwnedWriteHalf), RemoteAddr)>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + let ret = ready!(Pin::new(&mut this.listener).poll_next(cx)); + let ret = match ret { + Some(Ok(strean)) => { + let (host, port) = this.dest.clone(); + Some(anyhow::Ok(( + strean.into_split(), + RemoteAddr { + protocol: LocalProtocol::Tcp { + proxy_protocol: this.proxy_protocol, + }, + host, + port, + }, + ))) + } + Some(Err(err)) => Some(Err(anyhow::Error::new(err))), + None => None, + }; + Poll::Ready(ret) + } +} diff --git a/src/tunnel/listeners/tproxy.rs b/src/tunnel/listeners/tproxy.rs new file mode 100644 index 0000000..073a3b9 --- /dev/null +++ b/src/tunnel/listeners/tproxy.rs @@ -0,0 +1,107 @@ +use crate::protocols::udp; +use crate::protocols::udp::{UdpStream, UdpStreamWriter}; +use crate::tunnel::{to_host_port, RemoteAddr}; +use crate::{protocols, LocalProtocol}; +use anyhow::{anyhow, Context}; +use std::io; +use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{ready, Poll}; +use std::time::Duration; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use tokio_stream::wrappers::TcpListenerStream; +use tokio_stream::Stream; + +pub struct TproxyTcpTunnelListener { + listener: TcpListenerStream, + proxy_protocol: bool, +} + +impl TproxyTcpTunnelListener { + pub async fn new(bind_addr: SocketAddr, proxy_protocol: bool) -> anyhow::Result { + let listener = protocols::tcp::run_server(bind_addr, true) + .await + .with_context(|| anyhow!("Cannot start TProxy TCP server on {}", bind_addr))?; + + Ok(Self { + listener, + proxy_protocol, + }) + } +} + +impl Stream for TproxyTcpTunnelListener { + type Item = anyhow::Result<((OwnedReadHalf, OwnedWriteHalf), RemoteAddr)>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + let ret = ready!(Pin::new(&mut this.listener).poll_next(cx)); + let ret = match ret { + Some(Ok(stream)) => { + let (host, port) = to_host_port(stream.local_addr().unwrap()); + Some(anyhow::Ok(( + stream.into_split(), + RemoteAddr { + protocol: LocalProtocol::Tcp { + proxy_protocol: this.proxy_protocol, + }, + host, + port, + }, + ))) + } + Some(Err(err)) => Some(Err(anyhow::Error::new(err))), + None => None, + }; + Poll::Ready(ret) + } +} + +// TPROXY UDP +pub struct TProxyUdpTunnelListener +where + S: Stream>, +{ + listener: S, + timeout: Option, +} + +pub async fn new_tproxy_udp( + bind_addr: SocketAddr, + timeout: Option, +) -> anyhow::Result>>> { + let listener = udp::run_server(bind_addr, timeout, udp::configure_tproxy, udp::mk_send_socket_tproxy) + .await + .with_context(|| anyhow!("Cannot start TProxy UDP server on {}", bind_addr))?; + + Ok(TProxyUdpTunnelListener { listener, timeout }) +} + +impl Stream for TProxyUdpTunnelListener +where + S: Stream>, +{ + type Item = anyhow::Result<((UdpStream, UdpStreamWriter), RemoteAddr)>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = unsafe { self.get_unchecked_mut() }; + let ret = ready!(unsafe { Pin::new_unchecked(&mut this.listener) }.poll_next(cx)); + let ret = match ret { + Some(Ok(stream)) => { + let (host, port) = to_host_port(stream.local_addr().unwrap()); + let stream_writer = stream.writer(); + Some(anyhow::Ok(( + (stream, stream_writer), + RemoteAddr { + protocol: LocalProtocol::Udp { timeout: this.timeout }, + host, + port, + }, + ))) + } + Some(Err(err)) => Some(Err(anyhow::Error::new(err))), + None => None, + }; + Poll::Ready(ret) + } +} diff --git a/src/tunnel/listeners/udp.rs b/src/tunnel/listeners/udp.rs new file mode 100644 index 0000000..e23027e --- /dev/null +++ b/src/tunnel/listeners/udp.rs @@ -0,0 +1,66 @@ +use crate::protocols::udp; +use crate::protocols::udp::{UdpStream, UdpStreamWriter}; +use crate::tunnel::RemoteAddr; +use crate::LocalProtocol; +use anyhow::{anyhow, Context}; +use std::io; +use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{ready, Poll}; +use std::time::Duration; +use tokio_stream::Stream; +use url::Host; + +pub struct UdpTunnelListener +where + S: Stream>, +{ + listener: S, + dest: (Host, u16), + timeout: Option, +} + +pub async fn new_udp_listener( + bind_addr: SocketAddr, + dest: (Host, u16), + timeout: Option, +) -> anyhow::Result>>> { + let listener = udp::run_server(bind_addr, timeout, |_| Ok(()), |s| Ok(s.clone())) + .await + .with_context(|| anyhow!("Cannot start UDP server on {}", bind_addr))?; + + Ok(UdpTunnelListener { + listener, + dest, + timeout, + }) +} + +impl Stream for UdpTunnelListener +where + S: Stream>, +{ + type Item = anyhow::Result<((UdpStream, UdpStreamWriter), RemoteAddr)>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = unsafe { self.get_unchecked_mut() }; + let ret = ready!(unsafe { Pin::new_unchecked(&mut this.listener) }.poll_next(cx)); + let ret = match ret { + Some(Ok(stream)) => { + let (host, port) = this.dest.clone(); + let stream_writer = stream.writer(); + Some(anyhow::Ok(( + (stream, stream_writer), + RemoteAddr { + protocol: LocalProtocol::Udp { timeout: this.timeout }, + host, + port, + }, + ))) + } + Some(Err(err)) => Some(Err(anyhow::Error::new(err))), + None => None, + }; + Poll::Ready(ret) + } +} diff --git a/src/tunnel/listeners/unix_sock.rs b/src/tunnel/listeners/unix_sock.rs new file mode 100644 index 0000000..884956d --- /dev/null +++ b/src/tunnel/listeners/unix_sock.rs @@ -0,0 +1,58 @@ +use crate::protocols::unix_sock; +use crate::protocols::unix_sock::UnixListenerStream; +use crate::tunnel::RemoteAddr; +use crate::LocalProtocol; +use anyhow::{anyhow, Context}; +use std::path::Path; +use std::pin::Pin; +use std::task::{ready, Poll}; +use tokio::net::unix; +use tokio_stream::Stream; +use url::Host; + +pub struct UnixTunnelListener { + listener: UnixListenerStream, + dest: (Host, u16), + proxy_protocol: bool, +} + +impl UnixTunnelListener { + pub async fn new(path: &Path, dest: (Host, u16), proxy_protocol: bool) -> anyhow::Result { + let listener = unix_sock::run_server(path) + .await + .with_context(|| anyhow!("Cannot start Unix domain server on {}", path.display()))?; + + Ok(Self { + listener, + dest, + proxy_protocol, + }) + } +} +impl Stream for UnixTunnelListener { + type Item = anyhow::Result<((unix::OwnedReadHalf, unix::OwnedWriteHalf), RemoteAddr)>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + let ret = ready!(Pin::new(&mut this.listener).poll_next(cx)); + let ret = match ret { + Some(Ok(stream)) => { + let stream = stream.into_split(); + let (host, port) = this.dest.clone(); + Some(anyhow::Ok(( + stream, + RemoteAddr { + protocol: LocalProtocol::Tcp { + proxy_protocol: this.proxy_protocol, + }, + host, + port, + }, + ))) + } + Some(Err(err)) => Some(Err(anyhow::Error::new(err))), + None => None, + }; + Poll::Ready(ret) + } +} diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index a322444..413784b 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -1,9 +1,11 @@ pub mod client; +pub mod listeners; pub mod server; pub mod tls_reloader; mod transport; -use crate::{tcp, tls, LocalProtocol, TlsClientConfig, WsClientConfig}; +use crate::protocols::tls; +use crate::{protocols, LocalProtocol, TlsClientConfig, WsClientConfig}; use async_trait::async_trait; use bb8::ManageConnection; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation}; @@ -312,7 +314,7 @@ impl ManageConnection for WsClientConfig { let timeout = self.timeout_connect; let tcp_stream = if let Some(http_proxy) = &self.http_proxy { - tcp::connect_with_http_proxy( + protocols::tcp::connect_with_http_proxy( http_proxy, self.remote_addr.host(), self.remote_addr.port(), @@ -322,7 +324,7 @@ impl ManageConnection for WsClientConfig { ) .await? } else { - tcp::connect( + protocols::tcp::connect( self.remote_addr.host(), self.remote_addr.port(), so_mark, diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index c427786..60a552f 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -15,7 +15,7 @@ use std::sync::Arc; use std::time::Duration; use super::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; -use crate::{http_proxy, socks5, tcp, tls, udp, LocalProtocol, TlsServerConfig, WsServerConfig}; +use crate::{protocols, socks5, LocalProtocol, TlsServerConfig, WsServerConfig}; use hyper::body::{Frame, Incoming}; use hyper::header::{CONTENT_TYPE, COOKIE, SEC_WEBSOCKET_PROTOCOL}; use hyper::http::HeaderValue; @@ -28,16 +28,16 @@ use once_cell::sync::Lazy; use parking_lot::Mutex; use socket2::SockRef; +use crate::protocols::udp::UdpStream; +use crate::protocols::{http_proxy, tls, udp}; use crate::restrictions::config_reloader::RestrictionsRulesReloader; use crate::restrictions::types::{ AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol, }; use crate::socks5::Socks5Stream; -use crate::tls_utils::{cn_from_certificate, find_leaf_certificate}; use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; -use crate::udp::UdpStream; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::select; @@ -68,7 +68,7 @@ async fn run_tunnel( Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx))) } LocalProtocol::Tcp { proxy_protocol } => { - let mut socket = tcp::connect( + let mut socket = protocols::tcp::connect( &remote.host, remote.port, server_config.socket_so_mark, @@ -99,7 +99,7 @@ async fn run_tunnel( let remote_port = find_mapped_port(remote.port, restriction); let local_srv = (remote.host, remote_port); let bind = format!("{}:{}", local_srv.0, local_srv.1); - let listening_server = tcp::run_server(bind.parse()?, false); + let listening_server = protocols::tcp::run_server(bind.parse()?, false); let tcp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; let (local_rx, local_tx) = tcp.into_split(); @@ -172,7 +172,7 @@ async fn run_tunnel( } #[cfg(unix)] LocalProtocol::ReverseUnix { ref path } => { - use crate::unix_socket; + use protocols::unix_sock; use tokio::net::UnixStream; #[allow(clippy::type_complexity)] @@ -181,7 +181,7 @@ async fn run_tunnel( let remote_port = find_mapped_port(remote.port, restriction); let local_srv = (remote.host, remote_port); - let listening_server = unix_socket::run_server(path); + let listening_server = unix_sock::run_server(path); let stream = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; let (local_rx, local_tx) = stream.into_split(); @@ -862,7 +862,7 @@ pub async fn run_server(server_config: Arc, restrictions: Restri } }; - if let Err(err) = tcp::configure_socket(SockRef::from(&stream), &None) { + if let Err(err) = protocols::tcp::configure_socket(SockRef::from(&stream), &None) { warn!("Error while configuring server socket {:?}", err); } @@ -898,8 +898,8 @@ pub async fn run_server(server_config: Arc, restrictions: Restri // extract client certificate common name if any let restrict_path = tls_ctx .peer_certificates() - .and_then(find_leaf_certificate) - .and_then(|c| cn_from_certificate(&c)); + .and_then(tls::find_leaf_certificate) + .and_then(|c| tls::cn_from_certificate(&c)); match tls_ctx.alpn_protocol() { // http2 Some(b"h2") => { diff --git a/src/tunnel/tls_reloader.rs b/src/tunnel/tls_reloader.rs index 17cc761..0573099 100644 --- a/src/tunnel/tls_reloader.rs +++ b/src/tunnel/tls_reloader.rs @@ -1,5 +1,6 @@ +use crate::protocols::tls; use crate::tunnel::tls_reloader::TlsReloaderState::{Client, Server}; -use crate::{tls, WsClientConfig, WsServerConfig}; +use crate::{WsClientConfig, WsServerConfig}; use anyhow::Context; use log::trace; use notify::{EventKind, RecommendedWatcher, Watcher}; diff --git a/src/types.rs b/src/types.rs deleted file mode 100644 index 4832606..0000000 --- a/src/types.rs +++ /dev/null @@ -1,374 +0,0 @@ -use crate::http_proxy::HttpProxyListener; -use crate::socks5::{Socks5Listener, Socks5Stream}; -use crate::tunnel::{to_host_port, RemoteAddr}; -use crate::udp::{UdpStream, UdpStreamWriter}; -use crate::unix_socket::UnixListenerStream; -use crate::LocalProtocol; -use std::io; -use std::pin::Pin; -use std::task::{ready, Poll}; -use std::time::Duration; -use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; -use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; -use tokio::net::unix; -use tokio_stream::wrappers::TcpListenerStream; -use tokio_stream::Stream; -use url::Host; - -pub trait TunnelListener: Stream> { - type Reader: AsyncRead + Send + 'static; - type Writer: AsyncWrite + Send + 'static; -} - -impl TunnelListener for T -where - T: Stream>, - R: AsyncRead + Send + 'static, - W: AsyncWrite + Send + 'static, -{ - type Reader = R; - type Writer = W; -} - -pub struct TcpTunnelListener { - listener: TcpListenerStream, - dest: (Host, u16), - proxy_protocol: bool, -} - -impl TcpTunnelListener { - pub fn new(listener: TcpListenerStream, dest: (Host, u16), proxy_protocol: bool) -> Self { - Self { - listener, - dest, - proxy_protocol, - } - } -} - -impl Stream for TcpTunnelListener { - type Item = anyhow::Result<((OwnedReadHalf, OwnedWriteHalf), RemoteAddr)>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { - let this = self.get_mut(); - let ret = ready!(Pin::new(&mut this.listener).poll_next(cx)); - let ret = match ret { - Some(Ok(strean)) => { - let (host, port) = this.dest.clone(); - Some(anyhow::Ok(( - strean.into_split(), - RemoteAddr { - protocol: LocalProtocol::Tcp { - proxy_protocol: this.proxy_protocol, - }, - host, - port, - }, - ))) - } - Some(Err(err)) => Some(Err(anyhow::Error::new(err))), - None => None, - }; - Poll::Ready(ret) - } -} - -// TPROXY -pub struct TproxyTcpTunnelListener { - listener: TcpListenerStream, - proxy_protocol: bool, -} - -impl TproxyTcpTunnelListener { - pub fn new(listener: TcpListenerStream, proxy_protocol: bool) -> Self { - Self { - listener, - proxy_protocol, - } - } -} - -impl Stream for TproxyTcpTunnelListener { - type Item = anyhow::Result<((OwnedReadHalf, OwnedWriteHalf), RemoteAddr)>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { - let this = self.get_mut(); - let ret = ready!(Pin::new(&mut this.listener).poll_next(cx)); - let ret = match ret { - Some(Ok(stream)) => { - let (host, port) = to_host_port(stream.local_addr().unwrap()); - Some(anyhow::Ok(( - stream.into_split(), - RemoteAddr { - protocol: LocalProtocol::Tcp { - proxy_protocol: this.proxy_protocol, - }, - host, - port, - }, - ))) - } - Some(Err(err)) => Some(Err(anyhow::Error::new(err))), - None => None, - }; - Poll::Ready(ret) - } -} - -// UNIX -pub struct UnixTunnelListener { - listener: UnixListenerStream, - dest: (Host, u16), - proxy_protocol: bool, -} - -impl UnixTunnelListener { - pub fn new(listener: UnixListenerStream, dest: (Host, u16), proxy_protocol: bool) -> Self { - Self { - listener, - dest, - proxy_protocol, - } - } -} -impl Stream for UnixTunnelListener { - type Item = anyhow::Result<((unix::OwnedReadHalf, unix::OwnedWriteHalf), RemoteAddr)>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { - let this = self.get_mut(); - let ret = ready!(Pin::new(&mut this.listener).poll_next(cx)); - let ret = match ret { - Some(Ok(stream)) => { - let stream = stream.into_split(); - let (host, port) = this.dest.clone(); - Some(anyhow::Ok(( - stream, - RemoteAddr { - protocol: LocalProtocol::Tcp { - proxy_protocol: this.proxy_protocol, - }, - host, - port, - }, - ))) - } - Some(Err(err)) => Some(Err(anyhow::Error::new(err))), - None => None, - }; - Poll::Ready(ret) - } -} - -// TPROXY UDP -pub struct TProxyUdpTunnelListener -where - S: Stream>, -{ - listener: S, - timeout: Option, -} - -impl TProxyUdpTunnelListener -where - S: Stream>, -{ - pub fn new(listener: S, timeout: Option) -> Self { - Self { listener, timeout } - } -} - -impl Stream for TProxyUdpTunnelListener -where - S: Stream>, -{ - type Item = anyhow::Result<((UdpStream, UdpStreamWriter), RemoteAddr)>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { - let this = unsafe { self.get_unchecked_mut() }; - let ret = ready!(unsafe { Pin::new_unchecked(&mut this.listener) }.poll_next(cx)); - let ret = match ret { - Some(Ok(stream)) => { - let (host, port) = to_host_port(stream.local_addr().unwrap()); - let stream_writer = stream.writer(); - Some(anyhow::Ok(( - (stream, stream_writer), - RemoteAddr { - protocol: LocalProtocol::Udp { timeout: this.timeout }, - host, - port, - }, - ))) - } - Some(Err(err)) => Some(Err(anyhow::Error::new(err))), - None => None, - }; - Poll::Ready(ret) - } -} - -pub struct UdpTunnelListener -where - S: Stream>, -{ - listener: S, - dest: (Host, u16), - timeout: Option, -} - -impl UdpTunnelListener -where - S: Stream>, -{ - pub fn new(listener: S, dest: (Host, u16), timeout: Option) -> Self { - Self { - listener, - dest, - timeout, - } - } -} - -impl Stream for UdpTunnelListener -where - S: Stream>, -{ - type Item = anyhow::Result<((UdpStream, UdpStreamWriter), RemoteAddr)>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { - let this = unsafe { self.get_unchecked_mut() }; - let ret = ready!(unsafe { Pin::new_unchecked(&mut this.listener) }.poll_next(cx)); - let ret = match ret { - Some(Ok(stream)) => { - let (host, port) = this.dest.clone(); - let stream_writer = stream.writer(); - Some(anyhow::Ok(( - (stream, stream_writer), - RemoteAddr { - protocol: LocalProtocol::Udp { timeout: this.timeout }, - host, - port, - }, - ))) - } - Some(Err(err)) => Some(Err(anyhow::Error::new(err))), - None => None, - }; - Poll::Ready(ret) - } -} - -pub struct Socks5TunnelListener { - listener: Socks5Listener, -} - -impl Socks5TunnelListener { - pub fn new(listener: Socks5Listener) -> Self { - Self { listener } - } -} - -impl Stream for Socks5TunnelListener { - type Item = anyhow::Result<((ReadHalf, WriteHalf), RemoteAddr)>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { - let this = self.get_mut(); - let ret = ready!(Pin::new(&mut this.listener).poll_next(cx)); - let ret = match ret { - Some(Ok((stream, (host, port)))) => { - let protocol = stream.local_protocol(); - Some(anyhow::Ok((tokio::io::split(stream), RemoteAddr { protocol, host, port }))) - } - Some(Err(err)) => Some(Err(err)), - None => None, - }; - Poll::Ready(ret) - } -} - -pub struct HttpProxyTunnelListener { - listener: HttpProxyListener, - proxy_protocol: bool, -} - -impl HttpProxyTunnelListener { - pub fn new(listener: HttpProxyListener, proxy_protocol: bool) -> Self { - Self { - listener, - proxy_protocol, - } - } -} - -impl Stream for HttpProxyTunnelListener { - type Item = anyhow::Result<((OwnedReadHalf, OwnedWriteHalf), RemoteAddr)>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { - let this = self.get_mut(); - let ret = ready!(Pin::new(&mut this.listener).poll_next(cx)); - let ret = match ret { - Some(Ok((stream, (host, port)))) => { - let protocol = LocalProtocol::Tcp { - proxy_protocol: this.proxy_protocol, - }; - Some(anyhow::Ok((stream.into_split(), RemoteAddr { protocol, host, port }))) - } - Some(Err(err)) => Some(Err(err)), - None => None, - }; - Poll::Ready(ret) - } -} - -pub struct StdioTunnelListener -where - R: AsyncRead + Send + 'static, - W: AsyncWrite + Send + 'static, -{ - listener: Option<(R, W)>, - dest: (Host, u16), - proxy_protocol: bool, -} - -impl StdioTunnelListener -where - R: AsyncRead + Send + 'static, - W: AsyncWrite + Send + 'static, -{ - pub fn new(listener: (R, W), dest: (Host, u16), proxy_protocol: bool) -> Self { - Self { - listener: Some(listener), - proxy_protocol, - dest, - } - } -} - -impl Stream for StdioTunnelListener -where - R: AsyncRead + Send + 'static, - W: AsyncWrite + Send + 'static, -{ - type Item = anyhow::Result<((R, W), RemoteAddr)>; - - fn poll_next(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll> { - let this = unsafe { self.get_unchecked_mut() }; - let ret = match this.listener.take() { - None => None, - Some(stream) => { - let (host, port) = this.dest.clone(); - Some(Ok(( - stream, - RemoteAddr { - protocol: LocalProtocol::Tcp { - proxy_protocol: this.proxy_protocol, - }, - host, - port, - }, - ))) - } - }; - - Poll::Ready(ret) - } -}