From 10f15d122560def96165ba42970bee18ceba531b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Fri, 12 Jan 2024 19:31:00 +0100 Subject: [PATCH] Add support for unix socket --- src/main.rs | 91 +++++++++++++++++++++++++++++++++++++++++++- src/tunnel/mod.rs | 8 ++-- src/tunnel/server.rs | 38 ++++++++++++++++-- src/unix_socket.rs | 58 ++++++++++++++++++++++++++++ 4 files changed, 186 insertions(+), 9 deletions(-) create mode 100644 src/unix_socket.rs diff --git a/src/main.rs b/src/main.rs index 9364409..e4d031c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,8 @@ mod tcp; mod tls; mod tunnel; mod udp; +#[cfg(unix)] +mod unix_socket; use anyhow::anyhow; use base64::Engine; @@ -258,7 +260,7 @@ struct Server { tls_private_key: Option, } -#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] enum LocalProtocol { Tcp { proxy_protocol: bool }, Udp { timeout: Option }, @@ -269,6 +271,8 @@ enum LocalProtocol { ReverseTcp, ReverseUdp { timeout: Option }, ReverseSocks5, + ReverseUnix { path: PathBuf }, + Unix { path: PathBuf }, } #[derive(Clone, Debug)] @@ -392,6 +396,22 @@ fn parse_tunnel_arg(arg: &str) -> Result { remote: (dest_host, dest_port), }) } + "unix:/" => { + let Some((path, remote)) = arg[7..].split_once(':') else { + return Err(Error::new( + ErrorKind::InvalidInput, + format!("cannot parse unix socket path from {}", arg), + )); + }; + let (dest_host, dest_port, _options) = parse_tunnel_dest(remote)?; + Ok(LocalToRemote { + local_protocol: LocalProtocol::Unix { + path: PathBuf::from(path), + }, + local: SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)), + remote: (dest_host, dest_port), + }) + } _ => match &arg[..8] { "socks5:/" => { let (local_bind, remaining) = parse_local_bind(&arg[9..])?; @@ -800,7 +820,44 @@ async fn main() { } }); } - _ => panic!("Invalid protocol for reverse tunnel"), + LocalProtocol::Unix { path } => { + let path = path.clone(); + tokio::spawn(async move { + let remote = tunnel.remote.clone(); + let cfg = client_config.clone(); + let connect_to_dest = |_| async { + tcp::connect( + &remote.0, + remote.1, + cfg.socket_so_mark, + cfg.timeout_connect, + &cfg.dns_resolver, + ) + .await + }; + + let (host, port) = to_host_port(tunnel.local); + let remote = RemoteAddr { + protocol: LocalProtocol::ReverseUnix { path: path.clone() }, + host, + port, + }; + if let Err(err) = + tunnel::client::run_reverse_tunnel(client_config, remote, connect_to_dest).await + { + error!("{:?}", err); + } + }); + } + LocalProtocol::Stdio + | LocalProtocol::TProxyTcp + | LocalProtocol::TProxyUdp { .. } + | LocalProtocol::ReverseTcp + | LocalProtocol::ReverseUdp { .. } + | LocalProtocol::ReverseSocks5 + | LocalProtocol::ReverseUnix { .. } => { + panic!("Invalid protocol for reverse tunnel"); + } } } @@ -853,6 +910,35 @@ async fn main() { } }); } + #[cfg(unix)] + LocalProtocol::Unix { path } => { + let remote = tunnel.remote.clone(); + let server = unix_socket::run_server(path) + .await + .unwrap_or_else(|err| { + panic!("Cannot start Unix domain server on {}: {}", tunnel.local, err) + }) + .map_err(anyhow::Error::new) + .map_ok(move |stream| { + let remote = RemoteAddr { + protocol: LocalProtocol::Tcp { proxy_protocol: false }, + host: remote.0.clone(), + port: remote.1, + }; + (stream.into_split(), remote) + }); + + tokio::spawn(async move { + if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { + error!("{:?}", err); + } + }); + } + #[cfg(not(unix))] + LocalProtocol::Unix => { + panic!("Unix socket is not available for non Unix platform") + } + #[cfg(target_os = "linux")] LocalProtocol::TProxyUdp { timeout } => { let timeout = *timeout; @@ -951,6 +1037,7 @@ async fn main() { LocalProtocol::ReverseTcp => {} LocalProtocol::ReverseUdp { .. } => {} LocalProtocol::ReverseSocks5 => {} + LocalProtocol::ReverseUnix { .. } => {} } } } diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index b9da6ef..6c7cdc5 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -35,15 +35,17 @@ impl JwtTunnelConfig { Self { id: request_id.to_string(), p: match dest.protocol { - LocalProtocol::Tcp { .. } => dest.protocol, - LocalProtocol::Udp { .. } => dest.protocol, + LocalProtocol::Tcp { .. } => dest.protocol.clone(), + LocalProtocol::Udp { .. } => dest.protocol.clone(), LocalProtocol::Stdio => LocalProtocol::Tcp { proxy_protocol: false }, LocalProtocol::Socks5 { .. } => LocalProtocol::Tcp { proxy_protocol: false }, LocalProtocol::ReverseTcp => LocalProtocol::ReverseTcp, - LocalProtocol::ReverseUdp { .. } => dest.protocol, + LocalProtocol::ReverseUdp { .. } => dest.protocol.clone(), LocalProtocol::ReverseSocks5 => LocalProtocol::ReverseSocks5, LocalProtocol::TProxyTcp => LocalProtocol::Tcp { proxy_protocol: false }, LocalProtocol::TProxyUdp { timeout } => LocalProtocol::Udp { timeout }, + LocalProtocol::Unix { .. } => LocalProtocol::Tcp { proxy_protocol: false }, + LocalProtocol::ReverseUnix { .. } => dest.protocol.clone(), }, r: dest.host.to_string(), rp: dest.port, diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 4281644..d3bca27 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -11,7 +11,7 @@ use std::sync::Arc; use std::time::Duration; use super::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; -use crate::{socks5, tcp, tls, udp, LocalProtocol, TlsServerConfig, WsServerConfig}; +use crate::{socks5, tcp, tls, udp, unix_socket, LocalProtocol, TlsServerConfig, WsServerConfig}; use hyper::body::Incoming; use hyper::header::{COOKIE, SEC_WEBSOCKET_PROTOCOL}; use hyper::http::HeaderValue; @@ -26,7 +26,7 @@ use crate::socks5::Socks5Stream; use crate::tunnel::tls_reloader::TlsReloader; use crate::udp::UdpStream; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tokio::net::{TcpListener, TcpStream}; +use tokio::net::{TcpListener, TcpStream, UnixStream}; use tokio::select; use tokio::sync::{mpsc, oneshot}; use tokio_rustls::TlsAcceptor; @@ -133,7 +133,37 @@ async fn run_tunnel( }; Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) } - _ => Err(anyhow::anyhow!("Invalid upgrade request")), + #[cfg(unix)] + LocalProtocol::ReverseUnix { ref path } => { + #[allow(clippy::type_complexity)] + static SERVERS: Lazy, u16), mpsc::Receiver>>> = + Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); + + let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp); + let listening_server = unix_socket::run_server(path); + let stream = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; + let (local_rx, local_tx) = stream.into_split(); + + let remote = RemoteAddr { + protocol: jwt.claims.p.clone(), + host: local_srv.0, + port: local_srv.1, + }; + Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) + } + #[cfg(not(unix))] + LocalProtocol::ReverseUnix { ref path } => { + error!("Received an unsupported target protocol {:?}", jwt.claims); + Err(anyhow::anyhow!("Invalid upgrade request")) + } + LocalProtocol::Stdio + | LocalProtocol::Socks5 { .. } + | LocalProtocol::TProxyTcp + | LocalProtocol::TProxyUdp { .. } + | LocalProtocol::Unix { .. } => { + error!("Received an unsupported target protocol {:?}", jwt.claims); + Err(anyhow::anyhow!("Invalid upgrade request")) + } } } @@ -333,7 +363,7 @@ async fn server_upgrade( return err; } - let req_protocol = jwt.claims.p; + let req_protocol = jwt.claims.p.clone(); let tunnel = match run_tunnel(&server_config, jwt, client_addr).await { Ok(ret) => ret, Err(err) => { diff --git a/src/unix_socket.rs b/src/unix_socket.rs new file mode 100644 index 0000000..fe73647 --- /dev/null +++ b/src/unix_socket.rs @@ -0,0 +1,58 @@ +use anyhow::Context; +use futures_util::Stream; +use std::io; +use std::path::Path; +use std::pin::Pin; +use std::task::Poll; +use tokio::net::{UnixListener, UnixStream}; +use tracing::log::info; + +pub struct UnixListenerStream { + inner: UnixListener, + path_to_delete: bool, +} + +impl UnixListenerStream { + pub fn new(listener: UnixListener, path_to_delete: bool) -> Self { + Self { + inner: listener, + path_to_delete, + } + } +} + +impl Drop for UnixListenerStream { + fn drop(&mut self) { + if self.path_to_delete { + let Ok(addr) = &self.inner.local_addr() else { + return; + }; + let Some(path) = addr.as_pathname() else { + return; + }; + let _ = std::fs::remove_file(path); + } + } +} + +impl Stream for UnixListenerStream { + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll>> { + match self.inner.poll_accept(cx) { + Poll::Ready(Ok((stream, _))) => Poll::Ready(Some(Ok(stream))), + Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))), + Poll::Pending => Poll::Pending, + } + } +} + +pub async fn run_server(socket_path: &Path) -> Result { + info!("Starting Unix socket server listening cnx on {:?}", socket_path); + + let path_to_delete = socket_path.exists(); + let listener = UnixListener::bind(socket_path) + .with_context(|| format!("Cannot create Unix socket server {:?}", socket_path))?; + + Ok(UnixListenerStream::new(listener, path_to_delete)) +}