cleanup argument parsing

This commit is contained in:
Σrebe - Romain GERARD 2024-08-01 22:26:05 +02:00
parent f149b8190b
commit 811a1e6adf
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4

View file

@ -114,7 +114,7 @@ struct Client {
/// 'socks5://[::1]:1212' => listen on server for incoming socks5 request on port 1212 and forward dynamically request from local machine (login/password is supported) /// 'socks5://[::1]:1212' => listen on server for incoming socks5 request on port 1212 and forward dynamically request from local machine (login/password is supported)
/// 'http://[::1]:1212' => listen on server for incoming http proxy request on port 1212 and forward dynamically request from local machine (login/password is supported) /// 'http://[::1]:1212' => listen on server for incoming http proxy request on port 1212 and forward dynamically request from local machine (login/password is supported)
/// 'unix://wstunnel.sock:g.com:443' => listen on server for incoming data from unix socket of path wstunnel.sock and forward to g.com:443 from local machine /// 'unix://wstunnel.sock:g.com:443' => listen on server for incoming data from unix socket of path wstunnel.sock and forward to g.com:443 from local machine
#[arg(short='R', long, value_name = "{tcp,udp,socks5,unix}://[BIND:]PORT:HOST:PORT", value_parser = parse_tunnel_arg, verbatim_doc_comment)] #[arg(short='R', long, value_name = "{tcp,udp,socks5,unix}://[BIND:]PORT:HOST:PORT", value_parser = parse_reverse_tunnel_arg, verbatim_doc_comment)]
remote_to_local: Vec<LocalToRemote>, remote_to_local: Vec<LocalToRemote>,
/// (linux only) Mark network packet with SO_MARK sockoption with the specified value. /// (linux only) Mark network packet with SO_MARK sockoption with the specified value.
@ -468,35 +468,53 @@ fn parse_tunnel_dest(remaining: &str) -> Result<(Host<String>, u16, BTreeMap<Str
fn parse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> { fn parse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
use std::io::Error; use std::io::Error;
let get_timeout = |options: &BTreeMap<String, String>| {
options
.get("timeout_sec")
.and_then(|x| x.parse::<u64>().ok())
.map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) })
.unwrap_or(Some(Duration::from_secs(30)))
};
let get_credentials = |options: &BTreeMap<String, String>| {
options
.get("login")
.and_then(|login| options.get("password").map(|p| (login.to_string(), p.to_string())))
};
let get_proxy_protocol = |options: &BTreeMap<String, String>| options.contains_key("proxy_protocol");
match &arg[..6] { let Some((proto, tunnel_info)) = arg.split_once("://") else {
"tcp://" => { return Err(Error::new(
let (local_bind, remaining) = parse_local_bind(&arg[6..])?; ErrorKind::InvalidInput,
format!("cannot parse protocol from {}", arg),
));
};
match proto {
"tcp" => {
let (local_bind, remaining) = parse_local_bind(tunnel_info)?;
let (dest_host, dest_port, options) = parse_tunnel_dest(remaining)?; let (dest_host, dest_port, options) = parse_tunnel_dest(remaining)?;
let proxy_protocol = options.contains_key("proxy_protocol");
Ok(LocalToRemote { Ok(LocalToRemote {
local_protocol: LocalProtocol::Tcp { proxy_protocol }, local_protocol: LocalProtocol::Tcp {
proxy_protocol: get_proxy_protocol(&options),
},
local: local_bind, local: local_bind,
remote: (dest_host, dest_port), remote: (dest_host, dest_port),
}) })
} }
"udp://" => { "udp" => {
let (local_bind, remaining) = parse_local_bind(&arg[6..])?; let (local_bind, remaining) = parse_local_bind(tunnel_info)?;
let (dest_host, dest_port, options) = parse_tunnel_dest(remaining)?; let (dest_host, dest_port, options) = parse_tunnel_dest(remaining)?;
let timeout = options
.get("timeout_sec")
.and_then(|x| x.parse::<u64>().ok())
.map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) })
.unwrap_or(Some(Duration::from_secs(30)));
Ok(LocalToRemote { Ok(LocalToRemote {
local_protocol: LocalProtocol::Udp { timeout }, local_protocol: LocalProtocol::Udp {
timeout: get_timeout(&options),
},
local: local_bind, local: local_bind,
remote: (dest_host, dest_port), remote: (dest_host, dest_port),
}) })
} }
"unix:/" => { "unix" => {
let Some((path, remote)) = arg[7..].split_once(':') else { let Some((path, remote)) = tunnel_info.split_once(':') else {
return Err(Error::new( return Err(Error::new(
ErrorKind::InvalidInput, ErrorKind::InvalidInput,
format!("cannot parse unix socket path from {}", arg), format!("cannot parse unix socket path from {}", arg),
@ -511,89 +529,104 @@ fn parse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
remote: (dest_host, dest_port), remote: (dest_host, dest_port),
}) })
} }
"http:/" => { "http" => {
let (local_bind, remaining) = parse_local_bind(&arg["http://".len()..])?; let (local_bind, remaining) = parse_local_bind(tunnel_info)?;
let x = format!("0.0.0.0:0?{}", remaining); let x = format!("0.0.0.0:0?{}", remaining);
let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?; let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?;
let proxy_protocol = options.contains_key("proxy_protocol");
let timeout = options
.get("timeout_sec")
.and_then(|x| x.parse::<u64>().ok())
.map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) })
.unwrap_or(Some(Duration::from_secs(30)));
let credentials = options
.get("login")
.and_then(|login| options.get("password").map(|p| (login.to_string(), p.to_string())));
Ok(LocalToRemote { Ok(LocalToRemote {
local_protocol: LocalProtocol::HttpProxy { local_protocol: LocalProtocol::HttpProxy {
timeout, timeout: get_timeout(&options),
credentials, credentials: get_credentials(&options),
proxy_protocol, proxy_protocol: get_proxy_protocol(&options),
}, },
local: local_bind, local: local_bind,
remote: (dest_host, dest_port), remote: (dest_host, dest_port),
}) })
} }
_ => match &arg[..8] { "socks5" => {
"socks5:/" => { let (local_bind, remaining) = parse_local_bind(tunnel_info)?;
let (local_bind, remaining) = parse_local_bind(&arg["socks5://".len()..])?; let x = format!("0.0.0.0:0?{}", remaining);
let x = format!("0.0.0.0:0?{}", remaining); let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?;
let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?; Ok(LocalToRemote {
let timeout = options local_protocol: LocalProtocol::Socks5 {
.get("timeout_sec") timeout: get_timeout(&options),
.and_then(|x| x.parse::<u64>().ok()) credentials: get_credentials(&options),
.map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) }) },
.unwrap_or(Some(Duration::from_secs(30))); local: local_bind,
let credentials = options remote: (dest_host, dest_port),
.get("login") })
.and_then(|login| options.get("password").map(|p| (login.to_string(), p.to_string()))); }
Ok(LocalToRemote { "stdio" => {
local_protocol: LocalProtocol::Socks5 { timeout, credentials }, let (dest_host, dest_port, _options) = parse_tunnel_dest(tunnel_info)?;
local: local_bind, Ok(LocalToRemote {
remote: (dest_host, dest_port), local_protocol: LocalProtocol::Stdio,
}) local: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(0), 0)),
} remote: (dest_host, dest_port),
"stdio://" => { })
let (dest_host, dest_port, _options) = parse_tunnel_dest(&arg["stdio://".len()..])?; }
Ok(LocalToRemote { "tproxy+tcp" => {
local_protocol: LocalProtocol::Stdio, let (local_bind, remaining) = parse_local_bind(tunnel_info)?;
local: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(0), 0)), let x = format!("0.0.0.0:0?{}", remaining);
remote: (dest_host, dest_port), let (dest_host, dest_port, _options) = parse_tunnel_dest(&x)?;
}) Ok(LocalToRemote {
} local_protocol: LocalProtocol::TProxyTcp,
"tproxy+t" => { local: local_bind,
let (local_bind, remaining) = parse_local_bind(&arg["tproxy+tcp://".len()..])?; remote: (dest_host, dest_port),
let x = format!("0.0.0.0:0?{}", remaining); })
let (dest_host, dest_port, _options) = parse_tunnel_dest(&x)?; }
Ok(LocalToRemote { "tproxy+udp" => {
local_protocol: LocalProtocol::TProxyTcp, let (local_bind, remaining) = parse_local_bind(tunnel_info)?;
local: local_bind, let x = format!("0.0.0.0:0?{}", remaining);
remote: (dest_host, dest_port), let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?;
}) Ok(LocalToRemote {
} local_protocol: LocalProtocol::TProxyUdp {
"tproxy+u" => { timeout: get_timeout(&options),
let (local_bind, remaining) = parse_local_bind(&arg["tproxy+udp://".len()..])?; },
let x = format!("0.0.0.0:0?{}", remaining); local: local_bind,
let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?; remote: (dest_host, dest_port),
let timeout = options })
.get("timeout_sec") }
.and_then(|x| x.parse::<u64>().ok()) _ => Err(Error::new(
.map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) }) ErrorKind::InvalidInput,
.unwrap_or(Some(Duration::from_secs(30))); format!("Invalid local protocol for tunnel {}", arg),
Ok(LocalToRemote { )),
local_protocol: LocalProtocol::TProxyUdp { timeout },
local: local_bind,
remote: (dest_host, dest_port),
})
}
_ => Err(Error::new(
ErrorKind::InvalidInput,
format!("Invalid local protocol for tunnel {}", arg),
)),
},
} }
} }
fn parse_reverse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
let proto = parse_tunnel_arg(arg)?;
let local_protocol = match proto.local_protocol {
LocalProtocol::Tcp { .. } => LocalProtocol::ReverseTcp {},
LocalProtocol::Udp { timeout } => LocalProtocol::ReverseUdp { timeout },
LocalProtocol::Socks5 { timeout, credentials } => LocalProtocol::ReverseSocks5 { timeout, credentials },
LocalProtocol::HttpProxy {
timeout,
credentials,
proxy_protocol: _proxy_protocol,
} => LocalProtocol::ReverseHttpProxy { timeout, credentials },
LocalProtocol::Unix { path } => LocalProtocol::ReverseUnix { path },
LocalProtocol::ReverseTcp { .. }
| LocalProtocol::ReverseUdp { .. }
| LocalProtocol::ReverseSocks5 { .. }
| LocalProtocol::ReverseHttpProxy { .. }
| LocalProtocol::ReverseUnix { .. }
| LocalProtocol::TProxyTcp
| LocalProtocol::TProxyUdp { .. }
| LocalProtocol::Stdio => {
return Err(io::Error::new(
ErrorKind::InvalidInput,
format!("Cannot use {:?} as reverse tunnels {}", proto.local_protocol, arg),
))
}
};
Ok(LocalToRemote {
local_protocol,
local: proto.local,
remote: proto.remote,
})
}
fn parse_sni_override(arg: &str) -> Result<DnsName<'static>, io::Error> { fn parse_sni_override(arg: &str) -> Result<DnsName<'static>, io::Error> {
match DnsName::try_from(arg.to_string()) { match DnsName::try_from(arg.to_string()) {
Ok(val) => Ok(val), Ok(val) => Ok(val),
@ -788,7 +821,7 @@ async fn main() -> anyhow::Result<()> {
for tunnel in args.remote_to_local.into_iter() { for tunnel in args.remote_to_local.into_iter() {
let client = client.clone(); let client = client.clone();
match &tunnel.local_protocol { match &tunnel.local_protocol {
LocalProtocol::Tcp { proxy_protocol: _ } => { LocalProtocol::ReverseTcp { .. } => {
tokio::spawn(async move { tokio::spawn(async move {
let cfg = client.config.clone(); let cfg = client.config.clone();
let tcp_connector = TcpTunnelConnector::new( let tcp_connector = TcpTunnelConnector::new(
@ -809,7 +842,7 @@ async fn main() -> anyhow::Result<()> {
} }
}); });
} }
LocalProtocol::Udp { timeout } => { LocalProtocol::ReverseUdp { timeout } => {
let timeout = *timeout; let timeout = *timeout;
tokio::spawn(async move { tokio::spawn(async move {
@ -833,7 +866,7 @@ async fn main() -> anyhow::Result<()> {
} }
}); });
} }
LocalProtocol::Socks5 { timeout, credentials } => { LocalProtocol::ReverseSocks5 { timeout, credentials } => {
let credentials = credentials.clone(); let credentials = credentials.clone();
let timeout = *timeout; let timeout = *timeout;
tokio::spawn(async move { tokio::spawn(async move {
@ -852,9 +885,7 @@ async fn main() -> anyhow::Result<()> {
} }
}); });
} }
LocalProtocol::HttpProxy { LocalProtocol::ReverseHttpProxy { timeout, credentials } => {
timeout, credentials, ..
} => {
let credentials = credentials.clone(); let credentials = credentials.clone();
let timeout = *timeout; let timeout = *timeout;
tokio::spawn(async move { tokio::spawn(async move {
@ -879,7 +910,7 @@ async fn main() -> anyhow::Result<()> {
}); });
} }
#[cfg(unix)] #[cfg(unix)]
LocalProtocol::Unix { path } => { LocalProtocol::ReverseUnix { path } => {
let path = path.clone(); let path = path.clone();
tokio::spawn(async move { tokio::spawn(async move {
let cfg = client.config.clone(); let cfg = client.config.clone();
@ -903,17 +934,17 @@ async fn main() -> anyhow::Result<()> {
}); });
} }
#[cfg(not(unix))] #[cfg(not(unix))]
LocalProtocol::Unix { .. } => { LocalProtocol::ReverseUnix { .. } => {
panic!("Unix socket is not available for non Unix platform") panic!("Unix socket is not available for non Unix platform")
} }
LocalProtocol::Stdio LocalProtocol::Stdio
| LocalProtocol::TProxyTcp | LocalProtocol::TProxyTcp
| LocalProtocol::TProxyUdp { .. } | LocalProtocol::TProxyUdp { .. }
| LocalProtocol::ReverseTcp | LocalProtocol::Tcp { .. }
| LocalProtocol::ReverseUdp { .. } | LocalProtocol::Udp { .. }
| LocalProtocol::ReverseSocks5 { .. } | LocalProtocol::Socks5 { .. }
| LocalProtocol::ReverseHttpProxy { .. } => {} | LocalProtocol::HttpProxy { .. } => {}
LocalProtocol::ReverseUnix { .. } => { LocalProtocol::Unix { .. } => {
panic!("Invalid protocol for reverse tunnel"); panic!("Invalid protocol for reverse tunnel");
} }
} }