diff --git a/src/main.rs b/src/main.rs index 553be24..6735ee8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -62,7 +62,8 @@ struct Client { /// Listen on remote and forwards traffic from local. Can be specified multiple times. Only tcp is supported /// examples: /// 'tcp://1212:google.com:443' => listen on server for incoming tcp cnx on port 1212 and forward to google.com on port 443 from local machine - #[arg(short='R', long, value_name = "{tcp}://[BIND:]PORT:HOST:PORT", value_parser = parse_tunnel_arg, verbatim_doc_comment)] + /// 'udp://1212:1.1.1.1:53' => listen on server for incoming udp on port 1212 and forward to cloudflare dns 1.1.1.1 on port 53 from local machine + #[arg(short='R', long, value_name = "{tcp,udp}://[BIND:]PORT:HOST:PORT", value_parser = parse_tunnel_arg, verbatim_doc_comment)] remote_to_local: Vec, /// (linux only) Mark network packet with SO_MARK sockoption with the specified value. @@ -172,6 +173,7 @@ enum LocalProtocol { Stdio, Socks5, ReverseTcp, + ReverseUdp { timeout: Option }, } #[derive(Clone, Debug)] @@ -551,7 +553,31 @@ async fn main() { LocalProtocol::Tcp => { tunnel.local_protocol = ReverseTcp; tokio::spawn(async move { - if let Err(err) = tunnel::client::run_reverse_tunnel(client_config, tunnel).await { + let remote = tunnel.remote.clone(); + let cfg = client_config.clone(); + let connect_to_dest = || async { + tcp::connect(&remote.0, remote.1, cfg.socket_so_mark, cfg.timeout_connect).await + }; + + if let Err(err) = + tunnel::client::run_reverse_tunnel(client_config, tunnel, connect_to_dest).await + { + error!("{:?}", err); + } + }); + } + LocalProtocol::Udp { timeout } => { + tunnel.local_protocol = LocalProtocol::ReverseUdp { timeout: *timeout }; + + tokio::spawn(async move { + let cfg = client_config.clone(); + let remote = tunnel.remote.clone(); + let connect_to_dest = + || async { udp::connect(&remote.0, remote.1, cfg.timeout_connect).await }; + + if let Err(err) = + tunnel::client::run_reverse_tunnel(client_config, tunnel, connect_to_dest).await + { error!("{:?}", err); } }); @@ -628,6 +654,7 @@ async fn main() { } } LocalProtocol::ReverseTcp => {} + LocalProtocol::ReverseUdp { .. } => {} } } } diff --git a/src/tcp.rs b/src/tcp.rs index 91f908c..630c491 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -44,7 +44,7 @@ fn configure_socket(socket: &mut TcpSocket, so_mark: &Option) -> Result<(), pub async fn connect( host: &Host, port: u16, - so_mark: &Option, + so_mark: Option, connect_timeout: Duration, ) -> Result { info!("Opening TCP connection to {}:{}", host, port); @@ -68,7 +68,7 @@ pub async fn connect( SocketAddr::V6(_) => TcpSocket::new_v6()?, }; - configure_socket(&mut socket, so_mark)?; + configure_socket(&mut socket, &so_mark)?; match timeout(connect_timeout, socket.connect(addr)).await { Ok(Ok(stream)) => { cnx = Some(stream); @@ -103,7 +103,7 @@ pub async fn connect_with_http_proxy( proxy: &Url, host: &Host, port: u16, - so_mark: &Option, + so_mark: Option, connect_timeout: Duration, ) -> Result { let proxy_host = proxy.host().context("Cannot parse proxy host")?.to_owned(); @@ -226,7 +226,7 @@ mod tests { &"http://localhost:8080".parse().unwrap(), &Host::Domain("[::1]".to_string()), 1236, - &None, + None, Duration::from_secs(1), ) .await diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index 2c37ae4..fab5cea 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -1,5 +1,5 @@ use super::{JwtTunnelConfig, JWT_KEY}; -use crate::{tcp, LocalToRemote, WsClientConfig}; +use crate::{LocalToRemote, WsClientConfig}; use anyhow::{anyhow, Context}; use fastwebsockets::WebSocket; @@ -149,10 +149,16 @@ where Ok(()) } -pub async fn run_reverse_tunnel( +pub async fn run_reverse_tunnel( client_config: Arc, mut tunnel_cfg: LocalToRemote, -) -> anyhow::Result<()> { + connect_to_dest: F, +) -> anyhow::Result<()> +where + F: Fn() -> Fut, + Fut: Future>, + T: AsyncRead + AsyncWrite + Send + 'static, +{ // Invert local with remote let remote = tunnel_cfg.remote; tunnel_cfg.remote = match tunnel_cfg.local.ip() { @@ -178,14 +184,7 @@ pub async fn run_reverse_tunnel( ws.set_auto_apply_mask(client_config.websocket_mask_frame); // Connect to endpoint - let stream = tcp::connect( - &remote.0, - remote.1, - &client_config.socket_so_mark, - client_config.timeout_connect, - ) - .instrument(span.clone()) - .await; + let stream = connect_to_dest().instrument(span.clone()).await; let stream = match stream { Ok(s) => s, diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index dbe2bb6..3835e07 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -35,6 +35,7 @@ impl JwtTunnelConfig { LocalProtocol::Stdio => LocalProtocol::Tcp, LocalProtocol::Socks5 => LocalProtocol::Tcp, LocalProtocol::ReverseTcp => LocalProtocol::ReverseTcp, + LocalProtocol::ReverseUdp { .. } => tunnel.local_protocol, }, r: tunnel.remote.0.to_string(), rp: tunnel.remote.1, @@ -114,7 +115,7 @@ impl ManageConnection for WsClientConfig { async fn connect(&self) -> Result { let (host, port) = &self.remote_addr; - let so_mark = &self.socket_so_mark; + let so_mark = self.socket_so_mark; let timeout = self.timeout_connect; let tcp_stream = if let Some(http_proxy) = &self.http_proxy { diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 1737d04..00ff6a8 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -1,6 +1,7 @@ use ahash::{HashMap, HashMapExt}; -use futures_util::StreamExt; +use futures_util::{Stream, StreamExt}; use std::cmp::min; +use std::io; use std::ops::{Deref, Not}; use std::pin::Pin; use std::sync::Arc; @@ -15,6 +16,7 @@ use jsonwebtoken::TokenData; use once_cell::sync::Lazy; use parking_lot::Mutex; +use crate::udp::UdpStream; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tokio::sync::oneshot; @@ -71,7 +73,7 @@ async fn from_query( LocalProtocol::Tcp => { let host = Host::parse(&jwt.claims.r)?; let port = jwt.claims.rp; - let (rx, tx) = tcp::connect(&host, port, &server_config.socket_so_mark, Duration::from_secs(10)) + let (rx, tx) = tcp::connect(&host, port, server_config.socket_so_mark, Duration::from_secs(10)) .await? .into_split(); @@ -97,6 +99,27 @@ async fn from_query( Ok((jwt.claims.p, local_srv.0, local_srv.1, Box::pin(local_rx), Box::pin(local_tx))) } + LocalProtocol::ReverseUdp { timeout } => { + #[allow(clippy::type_complexity)] + static REVERSE: Lazy< + Mutex, u16), Pin> + Send>>>>, + > = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); + + let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp); + let listening_server = REVERSE.lock().remove(&local_srv); + let mut listening_server = if let Some(listening_server) = listening_server { + listening_server + } else { + let bind = format!("{}:{}", local_srv.0, local_srv.1); + Box::pin(udp::run_server(bind.parse()?, timeout).await?) + }; + + let udp = listening_server.next().await.unwrap()?; + let (local_rx, local_tx) = tokio::io::split(udp); + REVERSE.lock().insert(local_srv.clone(), listening_server); + + Ok((jwt.claims.p, local_srv.0, local_srv.1, Box::pin(local_rx), Box::pin(local_tx))) + } _ => Err(anyhow::anyhow!("Invalid upgrade request")), } }