Former-commit-id: a5c765b81ecdd7a4a64ae9a29ac827f054dc0f6b [formerly 303f4a61e4c6b8146afa3035fd79c7ea9f4c3093] [formerly 514ae96091247299929d298c96fb0898b5c94bf2 [formerly 3496804edf70a12aa77173061c40f36d9301d58c]]
Former-commit-id: c344f3736face3ce16e7fda6e395f63adf864725 [formerly f82d31b14bf430666762dd5c4936d5a3107a0d17]
Former-commit-id: dc05fa123697b2011a8c8e77bd7c3cb94f6fa30f
Former-commit-id: 55513ef82602cc7eba5f9979282c3c0aa750ebba
Former-commit-id: 64bbd9b657056f122a10dd9d83fc3900f1f311a4
Former-commit-id: ef78d4ab69ea339f99878c8f3cc649c4f39a1f96 [formerly 252cf5bb9484db648812603dffb0cb686fed0aa3]
Former-commit-id: 583caf5e4d75622fb34badbec536acde3f2b4c58
This commit is contained in:
Σrebe - Romain GERARD 2023-10-17 00:00:45 +02:00
parent a15e8a2548
commit aca065fcab
4 changed files with 192 additions and 95 deletions

View file

@ -12,22 +12,20 @@ use clap::Parser;
use futures_util::{pin_mut, stream, Stream, StreamExt, TryStreamExt};
use hyper::http::HeaderValue;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::collections::{BTreeMap, HashMap};
use std::io;
use std::io::ErrorKind;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_rustls::rustls::server::DnsName;
use tokio_rustls::rustls::{Certificate, PrivateKey, ServerName};
use tracing::{debug, error, instrument, Instrument, Span};
use tracing::{debug, error, span, Instrument, Level};
use tracing_subscriber::EnvFilter;
use url::{Host, Url};
@ -52,7 +50,7 @@ enum Commands {
struct Client {
/// Listen on local and forwards traffic from remote
/// Can be specified multiple times
#[arg(short='L', long, value_name = "{tcp,udp}://[BIND:]PORT:HOST:PORT", value_parser = parse_env_var)]
#[arg(short='L', long, value_name = "{tcp,udp,socks5}://[BIND:]PORT:HOST:PORT", value_parser = parse_tunnel_arg)]
local_to_remote: Vec<LocalToRemote>,
/// (linux only) Mark network packet with SO_MARK sockoption with the specified value.
@ -138,24 +136,17 @@ struct Server {
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
enum L4Protocol {
enum LocalProtocol {
Tcp,
Udp { timeout: Option<Duration> },
Stdio,
}
impl L4Protocol {
fn new_udp() -> L4Protocol {
L4Protocol::Udp {
timeout: Some(Duration::from_secs(30)),
}
}
Socks5,
}
#[derive(Clone, Debug)]
pub struct LocalToRemote {
socket_so_mark: Option<i32>,
protocol: L4Protocol,
local_protocol: LocalProtocol,
local: SocketAddr,
remote: (Host<String>, u16),
}
@ -173,18 +164,9 @@ fn parse_duration_sec(arg: &str) -> Result<Duration, io::Error> {
Ok(Duration::from_secs(secs))
}
fn parse_env_var(arg: &str) -> Result<LocalToRemote, io::Error> {
fn parse_local_bind(arg: &str) -> Result<(SocketAddr, &str), io::Error> {
use std::io::Error;
let (mut protocol, arg) = match &arg[..6] {
"tcp://" => (L4Protocol::Tcp, &arg[6..]),
"udp://" => (L4Protocol::new_udp(), &arg[6..]),
_ => match &arg[..8] {
"stdio://" => (L4Protocol::Stdio, &arg[8..]),
_ => (L4Protocol::Tcp, arg),
},
};
let (bind, remaining) = if arg.starts_with('[') {
// ipv6 bind
let Some((ipv6_str, remaining)) = arg.split_once(']') else {
@ -217,12 +199,8 @@ fn parse_env_var(arg: &str) -> Result<LocalToRemote, io::Error> {
}
};
let Some((port_str, remaining)) = remaining.trim_start_matches(':').split_once(':') else {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("cannot parse bind port from {}", remaining),
));
};
let remaining = remaining.trim_start_matches(':');
let (port_str, remaining) = remaining.split_once([':', '?']).unwrap_or((remaining, ""));
let Ok(bind_port): Result<u16, _> = port_str.parse() else {
return Err(Error::new(
@ -231,6 +209,14 @@ fn parse_env_var(arg: &str) -> Result<LocalToRemote, io::Error> {
));
};
Ok((SocketAddr::new(bind, bind_port), remaining))
}
fn parse_tunnel_dest(
remaining: &str,
) -> Result<(Host<String>, u16, BTreeMap<String, String>), io::Error> {
use std::io::Error;
let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else {
return Err(Error::new(
ErrorKind::InvalidInput,
@ -252,14 +238,30 @@ fn parse_env_var(arg: &str) -> Result<LocalToRemote, io::Error> {
));
};
let options: BTreeMap<Cow<'_, str>, Cow<'_, str>> = remote.query_pairs().collect();
match &mut protocol {
L4Protocol::Stdio => {}
L4Protocol::Tcp => {}
L4Protocol::Udp {
ref mut timeout, ..
} => {
if let Some(duration) = options
let options: BTreeMap<String, String> = remote.query_pairs().into_owned().collect();
Ok((remote_host.to_owned(), remote_port, options))
}
fn parse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
use std::io::Error;
match &arg[..6] {
"tcp://" => {
let (local_bind, remaining) = parse_local_bind(&arg[6..])?;
let (dest_host, dest_port, options) = parse_tunnel_dest(remaining)?;
Ok(LocalToRemote {
socket_so_mark: options
.get("socket_so_mark")
.and_then(|x| x.parse::<i32>().ok()),
local_protocol: LocalProtocol::Tcp,
local: local_bind,
remote: (dest_host, dest_port),
})
}
"udp://" => {
let (local_bind, remaining) = parse_local_bind(&arg[6..])?;
let (dest_host, dest_port, options) = parse_tunnel_dest(remaining)?;
let timeout = options
.get("timeout_sec")
.and_then(|x| x.parse::<u64>().ok())
.map(|d| {
@ -269,20 +271,48 @@ fn parse_env_var(arg: &str) -> Result<LocalToRemote, io::Error> {
Some(Duration::from_secs(d))
}
})
{
*timeout = duration;
}
}
};
.unwrap_or(Some(Duration::from_secs(30)));
Ok(LocalToRemote {
socket_so_mark: options
.get("socket_so_mark")
.and_then(|x| x.parse::<i32>().ok()),
protocol,
local: SocketAddr::new(bind, bind_port),
remote: (remote_host.to_owned(), remote_port),
})
Ok(LocalToRemote {
socket_so_mark: options
.get("socket_so_mark")
.and_then(|x| x.parse::<i32>().ok()),
local_protocol: LocalProtocol::Udp { timeout },
local: local_bind,
remote: (dest_host, dest_port),
})
}
_ => match &arg[..8] {
"socks5:/" => {
let (local_bind, remaining) = parse_local_bind(&arg[9..])?;
let x = format!("0.0.0.0:0?{}", remaining);
let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?;
Ok(LocalToRemote {
socket_so_mark: options
.get("socket_so_mark")
.and_then(|x| x.parse::<i32>().ok()),
local_protocol: LocalProtocol::Socks5,
local: local_bind,
remote: (dest_host, dest_port),
})
}
"stdio://" => {
let (dest_host, dest_port, options) = parse_tunnel_dest(&arg[8..])?;
Ok(LocalToRemote {
socket_so_mark: options
.get("socket_so_mark")
.and_then(|x| x.parse::<i32>().ok()),
local_protocol: LocalProtocol::Stdio,
local: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(0), 0)),
remote: (dest_host, dest_port),
})
}
_ => Err(Error::new(
ErrorKind::InvalidInput,
format!("Invalid local protocol for tunnel {}", arg),
)),
},
}
}
fn parse_sni_override(arg: &str) -> Result<DnsName, io::Error> {
@ -432,7 +462,7 @@ async fn main() {
if args
.local_to_remote
.iter()
.filter(|x| x.protocol == L4Protocol::Stdio)
.filter(|x| x.local_protocol == LocalProtocol::Stdio)
.count()
> 0 => {}
_ => {
@ -474,14 +504,16 @@ async fn main() {
for tunnel in args.local_to_remote.into_iter() {
let server_config = server_config.clone();
match &tunnel.protocol {
L4Protocol::Tcp => {
match &tunnel.local_protocol {
LocalProtocol::Tcp => {
let remote = tunnel.remote.clone();
let server = tcp::run_server(tunnel.local)
.await
.unwrap_or_else(|err| {
panic!("Cannot start TCP server on {}: {}", tunnel.local, err)
})
.map_ok(TcpStream::into_split);
.map_err(anyhow::Error::new)
.map_ok(move |stream| (stream.into_split(), remote.clone()));
tokio::spawn(async move {
if let Err(err) = run_tunnel(server_config, tunnel, server).await {
@ -489,13 +521,15 @@ async fn main() {
}
});
}
L4Protocol::Udp { timeout } => {
LocalProtocol::Udp { timeout } => {
let remote = tunnel.remote.clone();
let server = udp::run_server(tunnel.local, *timeout)
.await
.unwrap_or_else(|err| {
panic!("Cannot start UDP server on {}: {}", tunnel.local, err)
})
.map_ok(tokio::io::split);
.map_err(anyhow::Error::new)
.map_ok(move |stream| (tokio::io::split(stream), remote.clone()));
tokio::spawn(async move {
if let Err(err) = run_tunnel(server_config, tunnel, server).await {
@ -503,7 +537,21 @@ async fn main() {
}
});
}
L4Protocol::Stdio => {
LocalProtocol::Socks5 => {
let server = socks5::run_server(tunnel.local)
.await
.unwrap_or_else(|err| {
panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err)
})
.map_ok(|(stream, remote_dest)| (stream.into_split(), remote_dest));
tokio::spawn(async move {
if let Err(err) = run_tunnel(server_config, tunnel, server).await {
error!("{:?}", err);
}
});
}
LocalProtocol::Stdio => {
#[cfg(target_family = "unix")]
{
let server = stdio::run_server().await.unwrap_or_else(|err| {
@ -512,8 +560,8 @@ async fn main() {
tokio::spawn(async move {
if let Err(err) = run_tunnel(
server_config,
tunnel,
stream::once(async move { Ok(server) }),
tunnel.clone(),
stream::once(async move { Ok((server, tunnel.remote)) }),
)
.await
{
@ -573,31 +621,28 @@ async fn main() {
tokio::signal::ctrl_c().await.unwrap();
}
#[instrument(name="tunnel", level="info", skip_all, fields(id=tracing::field::Empty, remote=tracing::field::Empty))]
async fn run_tunnel<T, R, W>(
server_config: Arc<WsClientConfig>,
tunnel: LocalToRemote,
incoming_cnx: T,
) -> anyhow::Result<()>
where
T: Stream<Item = io::Result<(R, W)>>,
T: Stream<Item = anyhow::Result<((R, W), (Host, u16))>>,
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
let span = Span::current();
let request_id = Uuid::now_v7();
span.record("id", request_id.to_string());
span.record(
"remote",
&format!("{}:{}", tunnel.remote.0, tunnel.remote.1),
);
let tunnel = Arc::new(tunnel);
pin_mut!(incoming_cnx);
while let Some(Ok(cnx_stream)) = incoming_cnx.next().await {
while let Some(Ok((cnx_stream, remote_dest))) = incoming_cnx.next().await {
let request_id = Uuid::now_v7();
let span = span!(
Level::INFO,
"tunnel",
id = request_id.to_string(),
remote = format!("{}:{}", remote_dest.0, remote_dest.1)
);
let server_config = server_config.clone();
let tunnel = tunnel.clone();
let mut tunnel = tunnel.clone();
tunnel.remote = remote_dest;
tokio::spawn(
async move {

View file

@ -1,22 +1,22 @@
use anyhow::Context;
use fast_socks5::server::{Config, DenyAuthentication, Socks5Server};
use fast_socks5::util::target_addr::TargetAddr;
use fast_socks5::{consts, ReplyError};
use futures_util::{stream, Stream, StreamExt};
use std::net::SocketAddr;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::pin::Pin;
use std::task::Poll;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use log::warn;
use tracing::{info, warn};
use url::Host;
pub struct Socks5Listener {
stream: Pin<Box<dyn Stream<Item = anyhow::Result<(TcpStream, Host, u16)>>>>,
stream: Pin<Box<dyn Stream<Item = anyhow::Result<(TcpStream, (Host, u16))>> + Send>>,
}
impl Stream for Socks5Listener {
type Item = anyhow::Result<(TcpStream, Host, u16)>;
type Item = anyhow::Result<(TcpStream, (Host, u16))>;
fn poll_next(
self: Pin<&mut Self>,
@ -27,7 +27,7 @@ impl Stream for Socks5Listener {
}
pub async fn run_server(bind: SocketAddr) -> Result<Socks5Listener, anyhow::Error> {
info!("Starting TCP server listening cnx on {}", bind);
info!("Starting SOCKS5 server listening cnx on {}", bind);
let server = Socks5Server::<DenyAuthentication>::bind(bind)
.await
@ -69,8 +69,22 @@ pub async fn run_server(bind: SocketAddr) -> Result<Socks5Listener, anyhow::Erro
TargetAddr::Ip(SocketAddr::V6(ip)) => (Host::Ipv6(*ip.ip()), ip.port()),
TargetAddr::Domain(host, port) => (Host::Domain(host.clone()), *port),
};
let mut cnx = cnx.into_inner();
let ret = cnx
.write_all(&new_reply(
&ReplyError::Succeeded,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0),
))
.await;
if let Err(err) = ret {
warn!("Cannot reply to socks5 client: {}", err);
continue;
}
drop(acceptor);
return Some((Ok((cnx.into_inner(), host, port)), server));
return Some((Ok((cnx, (host, port))), server));
}
});
@ -81,6 +95,32 @@ pub async fn run_server(bind: SocketAddr) -> Result<Socks5Listener, anyhow::Erro
Ok(listener)
}
pub fn new_reply(error: &ReplyError, sock_addr: SocketAddr) -> Vec<u8> {
let (addr_type, mut ip_oct, mut port) = match sock_addr {
SocketAddr::V4(sock) => (
consts::SOCKS5_ADDR_TYPE_IPV4,
sock.ip().octets().to_vec(),
sock.port().to_be_bytes().to_vec(),
),
SocketAddr::V6(sock) => (
consts::SOCKS5_ADDR_TYPE_IPV6,
sock.ip().octets().to_vec(),
sock.port().to_be_bytes().to_vec(),
),
};
let mut reply = vec![
consts::SOCKS5_VERSION,
error.as_u8(), // transform the error into byte code
0x00, // reserved
addr_type, // address type (ipv4, v6, domain)
];
reply.append(&mut ip_oct);
reply.append(&mut port);
reply
}
#[cfg(test)]
mod test {
use super::*;

View file

@ -1,8 +1,7 @@
use tokio_fd::AsyncFd;
use tracing::info;
pub async fn run_server() -> Result<(AsyncFd, AsyncFd), anyhow::Error> {
info!("Starting STDIO server");
eprintln!("Starting STDIO server");
let stdin = AsyncFd::try_from(libc::STDIN_FILENO)?;
let stdout = AsyncFd::try_from(libc::STDOUT_FILENO)?;

View file

@ -5,7 +5,7 @@ use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use crate::{tcp, tls, L4Protocol, LocalToRemote, WsClientConfig, WsServerConfig};
use crate::{tcp, tls, LocalProtocol, LocalToRemote, WsClientConfig, WsServerConfig};
use anyhow::Context;
use fastwebsockets::{
Frame, OpCode, Payload, WebSocket, WebSocketError, WebSocketRead, WebSocketWrite,
@ -28,7 +28,7 @@ use tokio::time::timeout;
use crate::udp::MyUdpSocket;
use serde::{Deserialize, Serialize};
use tracing::log::debug;
use tracing::{error, info, instrument, trace, warn, Instrument, Span};
use tracing::{error, info, instrument, span, trace, warn, Instrument, Level, Span};
use url::Host;
use uuid::Uuid;
@ -47,7 +47,7 @@ where
#[derive(Debug, Clone, Serialize, Deserialize)]
struct JwtTunnelConfig {
pub id: String,
pub p: L4Protocol,
pub p: LocalProtocol,
pub r: String,
pub rp: u16,
}
@ -81,7 +81,12 @@ pub async fn connect(
let data = JwtTunnelConfig {
id: request_id.to_string(),
p: tunnel_cfg.protocol,
p: match tunnel_cfg.local_protocol {
LocalProtocol::Tcp => LocalProtocol::Tcp,
LocalProtocol::Udp { .. } => tunnel_cfg.local_protocol,
LocalProtocol::Stdio => LocalProtocol::Tcp,
LocalProtocol::Socks5 => LocalProtocol::Tcp,
},
r: tunnel_cfg.remote.0.to_string(),
rp: tunnel_cfg.remote.1,
};
@ -166,7 +171,7 @@ async fn from_query(
server_config: &WsServerConfig,
query: &str,
) -> anyhow::Result<(
L4Protocol,
LocalProtocol,
Host,
u16,
Pin<Box<dyn AsyncRead + Send>>,
@ -204,19 +209,19 @@ async fn from_query(
}
match jwt.claims.p {
L4Protocol::Udp { .. } => {
LocalProtocol::Udp { .. } => {
let host = Host::parse(&jwt.claims.r)?;
let cnx = Arc::new(UdpSocket::bind("[::]:0").await?);
cnx.connect((host.to_string(), jwt.claims.rp)).await?;
Ok((
L4Protocol::Udp { timeout: None },
LocalProtocol::Udp { timeout: None },
host,
jwt.claims.rp,
Box::pin(MyUdpSocket::new(cnx.clone())),
Box::pin(MyUdpSocket::new(cnx)),
))
}
L4Protocol::Tcp { .. } => {
LocalProtocol::Tcp { .. } => {
let host = Host::parse(&jwt.claims.r)?;
let port = jwt.claims.rp;
let (rx, tx) = tcp::connect(
@ -330,7 +335,15 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
let (stream, peer_addr) = listener.accept().await?;
let _ = stream.set_nodelay(true);
Span::current().record("peer", peer_addr.to_string());
let span = span!(
Level::INFO,
"tunnel",
id = tracing::field::Empty,
remote = tracing::field::Empty,
peer = peer_addr.to_string(),
forwarded_for = tracing::field::Empty
);
info!("Accepting connection");
let upgrade_fn = upgrade_fn.clone();
// TLS
@ -354,7 +367,7 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
error!("Error while upgrading cnx to websocket: {:?}", e);
}
}
.instrument(Span::current());
.instrument(span);
tokio::spawn(fut);
// Normal
@ -369,7 +382,7 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
error!("Error while upgrading cnx to weboscket: {:?}", e);
}
}
.instrument(Span::current());
.instrument(span);
tokio::spawn(fut);
};