diff --git a/Cargo.lock b/Cargo.lock index 37e6a28..15b9b84 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1614,7 +1614,7 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "wstunnel" -version = "7.9.2" +version = "7.9.4" dependencies = [ "ahash", "anyhow", diff --git a/src/main.rs b/src/main.rs index d341b29..553be24 100644 --- a/src/main.rs +++ b/src/main.rs @@ -28,6 +28,7 @@ use tokio_rustls::rustls::{Certificate, PrivateKey, ServerName}; use tracing::{error, info, Level}; +use crate::LocalProtocol::ReverseTcp; use tracing_subscriber::EnvFilter; use url::{Host, Url}; @@ -58,6 +59,12 @@ struct Client { #[arg(short='L', long, value_name = "{tcp,udp,socks5,stdio}://[BIND:]PORT:HOST:PORT", value_parser = parse_tunnel_arg, verbatim_doc_comment)] local_to_remote: Vec, + /// Listen on remote and forwards traffic from local. Can be specified multiple times. Only tcp is supported + /// examples: + /// 'tcp://1212:google.com:443' => listen on server for incoming tcp cnx on port 1212 and forward to google.com on port 443 from local machine + #[arg(short='R', long, value_name = "{tcp}://[BIND:]PORT:HOST:PORT", value_parser = parse_tunnel_arg, verbatim_doc_comment)] + remote_to_local: Vec, + /// (linux only) Mark network packet with SO_MARK sockoption with the specified value. /// You need to use {root, sudo, capabilities} to run wstunnel when using this option #[arg(long, value_name = "INT", verbatim_doc_comment)] @@ -164,6 +171,7 @@ enum LocalProtocol { Udp { timeout: Option }, Stdio, Socks5, + ReverseTcp, } #[derive(Clone, Debug)] @@ -537,6 +545,21 @@ async fn main() { let client_config = Arc::new(client_config); // Start tunnels + for mut tunnel in args.remote_to_local.into_iter() { + let client_config = client_config.clone(); + match &tunnel.local_protocol { + LocalProtocol::Tcp => { + tunnel.local_protocol = ReverseTcp; + tokio::spawn(async move { + if let Err(err) = tunnel::client::run_reverse_tunnel(client_config, tunnel).await { + error!("{:?}", err); + } + }); + } + _ => panic!("Invalid protocol for reverse tunnel"), + } + } + for tunnel in args.local_to_remote.into_iter() { let client_config = client_config.clone(); @@ -604,6 +627,7 @@ async fn main() { panic!("stdio is not implemented for non unix platform") } } + LocalProtocol::ReverseTcp => {} } } } diff --git a/src/tcp.rs b/src/tcp.rs index 83fbf79..91f908c 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -3,6 +3,7 @@ use std::{io, vec}; use base64::Engine; use bytes::BytesMut; +use log::warn; use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -74,11 +75,11 @@ pub async fn connect( break; } Ok(Err(err)) => { - debug!("Cannot connect to tcp endpoint {addr} reason {err}"); + warn!("Cannot connect to tcp endpoint {addr} reason {err}"); last_err = Some(err); } Err(_) => { - debug!( + warn!( "Cannot connect to tcp endpoint {addr} due to timeout of {}s elapsed", connect_timeout.as_secs() ); diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index 745fa7d..2c37ae4 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -1,5 +1,5 @@ use super::{JwtTunnelConfig, JWT_KEY}; -use crate::{LocalToRemote, WsClientConfig}; +use crate::{tcp, LocalToRemote, WsClientConfig}; use anyhow::{anyhow, Context}; use fastwebsockets::WebSocket; @@ -9,6 +9,7 @@ use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY}; use hyper::upgrade::Upgraded; use hyper::{Body, Request}; use std::future::Future; +use std::net::IpAddr; use std::ops::{Deref, DerefMut}; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; @@ -147,3 +148,67 @@ where Ok(()) } + +pub async fn run_reverse_tunnel( + client_config: Arc, + mut tunnel_cfg: LocalToRemote, +) -> anyhow::Result<()> { + // Invert local with remote + let remote = tunnel_cfg.remote; + tunnel_cfg.remote = match tunnel_cfg.local.ip() { + IpAddr::V4(ip) => (Host::Ipv4(ip), tunnel_cfg.local.port()), + IpAddr::V6(ip) => (Host::Ipv6(ip), tunnel_cfg.local.port()), + }; + + loop { + let client_config = client_config.clone(); + let request_id = Uuid::now_v7(); + let span = span!( + Level::INFO, + "tunnel", + id = request_id.to_string(), + remote = format!("{}:{}", tunnel_cfg.remote.0, tunnel_cfg.remote.1) + ); + let _span = span.enter(); + + // Correctly configure tunnel cfg + let mut ws = connect(request_id, &client_config, &tunnel_cfg) + .instrument(span.clone()) + .await?; + ws.set_auto_apply_mask(client_config.websocket_mask_frame); + + // Connect to endpoint + let stream = tcp::connect( + &remote.0, + remote.1, + &client_config.socket_so_mark, + client_config.timeout_connect, + ) + .instrument(span.clone()) + .await; + + let stream = match stream { + Ok(s) => s, + Err(err) => { + error!("Cannot connect to {remote:?}: {err:?}"); + continue; + } + }; + + let (local_rx, local_tx) = tokio::io::split(stream); + let (ws_rx, ws_tx) = ws.split(tokio::io::split); + let (close_tx, close_rx) = oneshot::channel::<()>(); + + let tunnel = async move { + let ping_frequency = client_config.websocket_ping_frequency; + tokio::spawn( + super::io::propagate_read(local_rx, ws_tx, close_tx, Some(ping_frequency)).instrument(Span::current()), + ); + + // Forward websocket rx to local rx + let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await; + } + .instrument(span.clone()); + tokio::spawn(tunnel); + } +} diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 3b9b0a7..dbe2bb6 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -34,6 +34,7 @@ impl JwtTunnelConfig { LocalProtocol::Udp { .. } => tunnel.local_protocol, LocalProtocol::Stdio => LocalProtocol::Tcp, LocalProtocol::Socks5 => LocalProtocol::Tcp, + LocalProtocol::ReverseTcp => LocalProtocol::ReverseTcp, }, r: tunnel.remote.0.to_string(), rp: tunnel.remote.1, diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 4b7a4c7..1737d04 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -1,3 +1,5 @@ +use ahash::{HashMap, HashMapExt}; +use futures_util::StreamExt; use std::cmp::min; use std::ops::{Deref, Not}; use std::pin::Pin; @@ -10,10 +12,13 @@ use hyper::server::conn::Http; use hyper::service::service_fn; use hyper::{http, Body, Request, Response, StatusCode}; use jsonwebtoken::TokenData; +use once_cell::sync::Lazy; +use parking_lot::Mutex; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tokio::sync::oneshot; +use tokio_stream::wrappers::TcpListenerStream; use tracing::{error, info, span, warn, Instrument, Level, Span}; use url::Host; @@ -63,7 +68,7 @@ async fn from_query( Box::pin(cnx), )) } - LocalProtocol::Tcp { .. } => { + LocalProtocol::Tcp => { let host = Host::parse(&jwt.claims.r)?; let port = jwt.claims.rp; let (rx, tx) = tcp::connect(&host, port, &server_config.socket_so_mark, Duration::from_secs(10)) @@ -72,6 +77,26 @@ async fn from_query( Ok((jwt.claims.p, host, port, Box::pin(rx), Box::pin(tx))) } + LocalProtocol::ReverseTcp => { + #[allow(clippy::type_complexity)] + static REVERSE: Lazy, u16), TcpListenerStream>>> = + Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); + + let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp); + let listening_server = REVERSE.lock().remove(&local_srv); + let mut listening_server = if let Some(listening_server) = listening_server { + listening_server + } else { + let bind = format!("{}:{}", local_srv.0, local_srv.1); + tcp::run_server(bind.parse()?).await? + }; + + let tcp = listening_server.next().await.unwrap()?; + let (local_rx, local_tx) = tokio::io::split(tcp); + REVERSE.lock().insert(local_srv.clone(), listening_server); + + Ok((jwt.claims.p, local_srv.0, local_srv.1, Box::pin(local_rx), Box::pin(local_tx))) + } _ => Err(anyhow::anyhow!("Invalid upgrade request")), } }