diff --git a/src/main.rs b/src/main.rs index d591492..41ab62f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,22 +12,20 @@ use clap::Parser; use futures_util::{pin_mut, stream, Stream, StreamExt, TryStreamExt}; use hyper::http::HeaderValue; use serde::{Deserialize, Serialize}; -use std::borrow::Cow; use std::collections::{BTreeMap, HashMap}; use std::io; use std::io::ErrorKind; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}; use std::path::PathBuf; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::TcpStream; use tokio_rustls::rustls::server::DnsName; use tokio_rustls::rustls::{Certificate, PrivateKey, ServerName}; -use tracing::{debug, error, instrument, Instrument, Span}; +use tracing::{debug, error, span, Instrument, Level}; use tracing_subscriber::EnvFilter; use url::{Host, Url}; @@ -52,7 +50,7 @@ enum Commands { struct Client { /// Listen on local and forwards traffic from remote /// Can be specified multiple times - #[arg(short='L', long, value_name = "{tcp,udp}://[BIND:]PORT:HOST:PORT", value_parser = parse_env_var)] + #[arg(short='L', long, value_name = "{tcp,udp,socks5}://[BIND:]PORT:HOST:PORT", value_parser = parse_tunnel_arg)] local_to_remote: Vec, /// (linux only) Mark network packet with SO_MARK sockoption with the specified value. @@ -138,24 +136,17 @@ struct Server { } #[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] -enum L4Protocol { +enum LocalProtocol { Tcp, Udp { timeout: Option }, Stdio, -} - -impl L4Protocol { - fn new_udp() -> L4Protocol { - L4Protocol::Udp { - timeout: Some(Duration::from_secs(30)), - } - } + Socks5, } #[derive(Clone, Debug)] pub struct LocalToRemote { socket_so_mark: Option, - protocol: L4Protocol, + local_protocol: LocalProtocol, local: SocketAddr, remote: (Host, u16), } @@ -173,18 +164,9 @@ fn parse_duration_sec(arg: &str) -> Result { Ok(Duration::from_secs(secs)) } -fn parse_env_var(arg: &str) -> Result { +fn parse_local_bind(arg: &str) -> Result<(SocketAddr, &str), io::Error> { use std::io::Error; - let (mut protocol, arg) = match &arg[..6] { - "tcp://" => (L4Protocol::Tcp, &arg[6..]), - "udp://" => (L4Protocol::new_udp(), &arg[6..]), - _ => match &arg[..8] { - "stdio://" => (L4Protocol::Stdio, &arg[8..]), - _ => (L4Protocol::Tcp, arg), - }, - }; - let (bind, remaining) = if arg.starts_with('[') { // ipv6 bind let Some((ipv6_str, remaining)) = arg.split_once(']') else { @@ -217,12 +199,8 @@ fn parse_env_var(arg: &str) -> Result { } }; - let Some((port_str, remaining)) = remaining.trim_start_matches(':').split_once(':') else { - return Err(Error::new( - ErrorKind::InvalidInput, - format!("cannot parse bind port from {}", remaining), - )); - }; + let remaining = remaining.trim_start_matches(':'); + let (port_str, remaining) = remaining.split_once([':', '?']).unwrap_or((remaining, "")); let Ok(bind_port): Result = port_str.parse() else { return Err(Error::new( @@ -231,6 +209,14 @@ fn parse_env_var(arg: &str) -> Result { )); }; + Ok((SocketAddr::new(bind, bind_port), remaining)) +} + +fn parse_tunnel_dest( + remaining: &str, +) -> Result<(Host, u16, BTreeMap), io::Error> { + use std::io::Error; + let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else { return Err(Error::new( ErrorKind::InvalidInput, @@ -252,14 +238,30 @@ fn parse_env_var(arg: &str) -> Result { )); }; - let options: BTreeMap, Cow<'_, str>> = remote.query_pairs().collect(); - match &mut protocol { - L4Protocol::Stdio => {} - L4Protocol::Tcp => {} - L4Protocol::Udp { - ref mut timeout, .. - } => { - if let Some(duration) = options + let options: BTreeMap = remote.query_pairs().into_owned().collect(); + Ok((remote_host.to_owned(), remote_port, options)) +} + +fn parse_tunnel_arg(arg: &str) -> Result { + use std::io::Error; + + match &arg[..6] { + "tcp://" => { + let (local_bind, remaining) = parse_local_bind(&arg[6..])?; + let (dest_host, dest_port, options) = parse_tunnel_dest(remaining)?; + Ok(LocalToRemote { + socket_so_mark: options + .get("socket_so_mark") + .and_then(|x| x.parse::().ok()), + local_protocol: LocalProtocol::Tcp, + local: local_bind, + remote: (dest_host, dest_port), + }) + } + "udp://" => { + let (local_bind, remaining) = parse_local_bind(&arg[6..])?; + let (dest_host, dest_port, options) = parse_tunnel_dest(remaining)?; + let timeout = options .get("timeout_sec") .and_then(|x| x.parse::().ok()) .map(|d| { @@ -269,20 +271,48 @@ fn parse_env_var(arg: &str) -> Result { Some(Duration::from_secs(d)) } }) - { - *timeout = duration; - } - } - }; + .unwrap_or(Some(Duration::from_secs(30))); - Ok(LocalToRemote { - socket_so_mark: options - .get("socket_so_mark") - .and_then(|x| x.parse::().ok()), - protocol, - local: SocketAddr::new(bind, bind_port), - remote: (remote_host.to_owned(), remote_port), - }) + Ok(LocalToRemote { + socket_so_mark: options + .get("socket_so_mark") + .and_then(|x| x.parse::().ok()), + local_protocol: LocalProtocol::Udp { timeout }, + local: local_bind, + remote: (dest_host, dest_port), + }) + } + _ => match &arg[..8] { + "socks5:/" => { + let (local_bind, remaining) = parse_local_bind(&arg[9..])?; + let x = format!("0.0.0.0:0?{}", remaining); + let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?; + Ok(LocalToRemote { + socket_so_mark: options + .get("socket_so_mark") + .and_then(|x| x.parse::().ok()), + local_protocol: LocalProtocol::Socks5, + local: local_bind, + remote: (dest_host, dest_port), + }) + } + "stdio://" => { + let (dest_host, dest_port, options) = parse_tunnel_dest(&arg[8..])?; + Ok(LocalToRemote { + socket_so_mark: options + .get("socket_so_mark") + .and_then(|x| x.parse::().ok()), + local_protocol: LocalProtocol::Stdio, + local: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(0), 0)), + remote: (dest_host, dest_port), + }) + } + _ => Err(Error::new( + ErrorKind::InvalidInput, + format!("Invalid local protocol for tunnel {}", arg), + )), + }, + } } fn parse_sni_override(arg: &str) -> Result { @@ -432,7 +462,7 @@ async fn main() { if args .local_to_remote .iter() - .filter(|x| x.protocol == L4Protocol::Stdio) + .filter(|x| x.local_protocol == LocalProtocol::Stdio) .count() > 0 => {} _ => { @@ -474,14 +504,16 @@ async fn main() { for tunnel in args.local_to_remote.into_iter() { let server_config = server_config.clone(); - match &tunnel.protocol { - L4Protocol::Tcp => { + match &tunnel.local_protocol { + LocalProtocol::Tcp => { + let remote = tunnel.remote.clone(); let server = tcp::run_server(tunnel.local) .await .unwrap_or_else(|err| { panic!("Cannot start TCP server on {}: {}", tunnel.local, err) }) - .map_ok(TcpStream::into_split); + .map_err(anyhow::Error::new) + .map_ok(move |stream| (stream.into_split(), remote.clone())); tokio::spawn(async move { if let Err(err) = run_tunnel(server_config, tunnel, server).await { @@ -489,13 +521,15 @@ async fn main() { } }); } - L4Protocol::Udp { timeout } => { + LocalProtocol::Udp { timeout } => { + let remote = tunnel.remote.clone(); let server = udp::run_server(tunnel.local, *timeout) .await .unwrap_or_else(|err| { panic!("Cannot start UDP server on {}: {}", tunnel.local, err) }) - .map_ok(tokio::io::split); + .map_err(anyhow::Error::new) + .map_ok(move |stream| (tokio::io::split(stream), remote.clone())); tokio::spawn(async move { if let Err(err) = run_tunnel(server_config, tunnel, server).await { @@ -503,7 +537,21 @@ async fn main() { } }); } - L4Protocol::Stdio => { + LocalProtocol::Socks5 => { + let server = socks5::run_server(tunnel.local) + .await + .unwrap_or_else(|err| { + panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err) + }) + .map_ok(|(stream, remote_dest)| (stream.into_split(), remote_dest)); + + tokio::spawn(async move { + if let Err(err) = run_tunnel(server_config, tunnel, server).await { + error!("{:?}", err); + } + }); + } + LocalProtocol::Stdio => { #[cfg(target_family = "unix")] { let server = stdio::run_server().await.unwrap_or_else(|err| { @@ -512,8 +560,8 @@ async fn main() { tokio::spawn(async move { if let Err(err) = run_tunnel( server_config, - tunnel, - stream::once(async move { Ok(server) }), + tunnel.clone(), + stream::once(async move { Ok((server, tunnel.remote)) }), ) .await { @@ -573,31 +621,28 @@ async fn main() { tokio::signal::ctrl_c().await.unwrap(); } -#[instrument(name="tunnel", level="info", skip_all, fields(id=tracing::field::Empty, remote=tracing::field::Empty))] async fn run_tunnel( server_config: Arc, tunnel: LocalToRemote, incoming_cnx: T, ) -> anyhow::Result<()> where - T: Stream>, + T: Stream>, R: AsyncRead + Send + 'static, W: AsyncWrite + Send + 'static, { - let span = Span::current(); - let request_id = Uuid::now_v7(); - span.record("id", request_id.to_string()); - span.record( - "remote", - &format!("{}:{}", tunnel.remote.0, tunnel.remote.1), - ); - - let tunnel = Arc::new(tunnel); pin_mut!(incoming_cnx); - - while let Some(Ok(cnx_stream)) = incoming_cnx.next().await { + while let Some(Ok((cnx_stream, remote_dest))) = incoming_cnx.next().await { + let request_id = Uuid::now_v7(); + let span = span!( + Level::INFO, + "tunnel", + id = request_id.to_string(), + remote = format!("{}:{}", remote_dest.0, remote_dest.1) + ); let server_config = server_config.clone(); - let tunnel = tunnel.clone(); + let mut tunnel = tunnel.clone(); + tunnel.remote = remote_dest; tokio::spawn( async move { diff --git a/src/socks5.rs b/src/socks5.rs index 5b39a4b..d75a6f0 100644 --- a/src/socks5.rs +++ b/src/socks5.rs @@ -1,22 +1,22 @@ use anyhow::Context; use fast_socks5::server::{Config, DenyAuthentication, Socks5Server}; use fast_socks5::util::target_addr::TargetAddr; +use fast_socks5::{consts, ReplyError}; use futures_util::{stream, Stream, StreamExt}; -use std::net::SocketAddr; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::pin::Pin; use std::task::Poll; +use tokio::io::AsyncWriteExt; use tokio::net::TcpStream; - -use log::warn; use tracing::{info, warn}; use url::Host; pub struct Socks5Listener { - stream: Pin>>>, + stream: Pin> + Send>>, } impl Stream for Socks5Listener { - type Item = anyhow::Result<(TcpStream, Host, u16)>; + type Item = anyhow::Result<(TcpStream, (Host, u16))>; fn poll_next( self: Pin<&mut Self>, @@ -27,7 +27,7 @@ impl Stream for Socks5Listener { } pub async fn run_server(bind: SocketAddr) -> Result { - info!("Starting TCP server listening cnx on {}", bind); + info!("Starting SOCKS5 server listening cnx on {}", bind); let server = Socks5Server::::bind(bind) .await @@ -69,8 +69,22 @@ pub async fn run_server(bind: SocketAddr) -> Result (Host::Ipv6(*ip.ip()), ip.port()), TargetAddr::Domain(host, port) => (Host::Domain(host.clone()), *port), }; + + let mut cnx = cnx.into_inner(); + let ret = cnx + .write_all(&new_reply( + &ReplyError::Succeeded, + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0), + )) + .await; + + if let Err(err) = ret { + warn!("Cannot reply to socks5 client: {}", err); + continue; + } + drop(acceptor); - return Some((Ok((cnx.into_inner(), host, port)), server)); + return Some((Ok((cnx, (host, port))), server)); } }); @@ -81,6 +95,32 @@ pub async fn run_server(bind: SocketAddr) -> Result Vec { + let (addr_type, mut ip_oct, mut port) = match sock_addr { + SocketAddr::V4(sock) => ( + consts::SOCKS5_ADDR_TYPE_IPV4, + sock.ip().octets().to_vec(), + sock.port().to_be_bytes().to_vec(), + ), + SocketAddr::V6(sock) => ( + consts::SOCKS5_ADDR_TYPE_IPV6, + sock.ip().octets().to_vec(), + sock.port().to_be_bytes().to_vec(), + ), + }; + + let mut reply = vec![ + consts::SOCKS5_VERSION, + error.as_u8(), // transform the error into byte code + 0x00, // reserved + addr_type, // address type (ipv4, v6, domain) + ]; + reply.append(&mut ip_oct); + reply.append(&mut port); + + reply +} + #[cfg(test)] mod test { use super::*; diff --git a/src/stdio.rs b/src/stdio.rs index 26cfed9..c9ee40c 100644 --- a/src/stdio.rs +++ b/src/stdio.rs @@ -1,8 +1,7 @@ use tokio_fd::AsyncFd; -use tracing::info; pub async fn run_server() -> Result<(AsyncFd, AsyncFd), anyhow::Error> { - info!("Starting STDIO server"); + eprintln!("Starting STDIO server"); let stdin = AsyncFd::try_from(libc::STDIN_FILENO)?; let stdout = AsyncFd::try_from(libc::STDOUT_FILENO)?; diff --git a/src/transport.rs b/src/transport.rs index 072e4d6..8aaa162 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -5,7 +5,7 @@ use std::pin::Pin; use std::sync::Arc; use std::time::Duration; -use crate::{tcp, tls, L4Protocol, LocalToRemote, WsClientConfig, WsServerConfig}; +use crate::{tcp, tls, LocalProtocol, LocalToRemote, WsClientConfig, WsServerConfig}; use anyhow::Context; use fastwebsockets::{ Frame, OpCode, Payload, WebSocket, WebSocketError, WebSocketRead, WebSocketWrite, @@ -28,7 +28,7 @@ use tokio::time::timeout; use crate::udp::MyUdpSocket; use serde::{Deserialize, Serialize}; use tracing::log::debug; -use tracing::{error, info, instrument, trace, warn, Instrument, Span}; +use tracing::{error, info, instrument, span, trace, warn, Instrument, Level, Span}; use url::Host; use uuid::Uuid; @@ -47,7 +47,7 @@ where #[derive(Debug, Clone, Serialize, Deserialize)] struct JwtTunnelConfig { pub id: String, - pub p: L4Protocol, + pub p: LocalProtocol, pub r: String, pub rp: u16, } @@ -81,7 +81,12 @@ pub async fn connect( let data = JwtTunnelConfig { id: request_id.to_string(), - p: tunnel_cfg.protocol, + p: match tunnel_cfg.local_protocol { + LocalProtocol::Tcp => LocalProtocol::Tcp, + LocalProtocol::Udp { .. } => tunnel_cfg.local_protocol, + LocalProtocol::Stdio => LocalProtocol::Tcp, + LocalProtocol::Socks5 => LocalProtocol::Tcp, + }, r: tunnel_cfg.remote.0.to_string(), rp: tunnel_cfg.remote.1, }; @@ -166,7 +171,7 @@ async fn from_query( server_config: &WsServerConfig, query: &str, ) -> anyhow::Result<( - L4Protocol, + LocalProtocol, Host, u16, Pin>, @@ -204,19 +209,19 @@ async fn from_query( } match jwt.claims.p { - L4Protocol::Udp { .. } => { + LocalProtocol::Udp { .. } => { let host = Host::parse(&jwt.claims.r)?; let cnx = Arc::new(UdpSocket::bind("[::]:0").await?); cnx.connect((host.to_string(), jwt.claims.rp)).await?; Ok(( - L4Protocol::Udp { timeout: None }, + LocalProtocol::Udp { timeout: None }, host, jwt.claims.rp, Box::pin(MyUdpSocket::new(cnx.clone())), Box::pin(MyUdpSocket::new(cnx)), )) } - L4Protocol::Tcp { .. } => { + LocalProtocol::Tcp { .. } => { let host = Host::parse(&jwt.claims.r)?; let port = jwt.claims.rp; let (rx, tx) = tcp::connect( @@ -330,7 +335,15 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() let (stream, peer_addr) = listener.accept().await?; let _ = stream.set_nodelay(true); - Span::current().record("peer", peer_addr.to_string()); + let span = span!( + Level::INFO, + "tunnel", + id = tracing::field::Empty, + remote = tracing::field::Empty, + peer = peer_addr.to_string(), + forwarded_for = tracing::field::Empty + ); + info!("Accepting connection"); let upgrade_fn = upgrade_fn.clone(); // TLS @@ -354,7 +367,7 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() error!("Error while upgrading cnx to websocket: {:?}", e); } } - .instrument(Span::current()); + .instrument(span); tokio::spawn(fut); // Normal @@ -369,7 +382,7 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() error!("Error while upgrading cnx to weboscket: {:?}", e); } } - .instrument(Span::current()); + .instrument(span); tokio::spawn(fut); };