diff --git a/restrictions.yaml b/restrictions.yaml index 368cb4c..c86e95a 100644 --- a/restrictions.yaml +++ b/restrictions.yaml @@ -48,6 +48,11 @@ restrictions: - Unix port: - 1..65535 + # Maps ports on the server side from X to Y (X:Y). For example with 10001:8080 configured and a client + # which connects using '-R tcp://10001:localhost:80' the server will listen on port 8080 instead of 10001. + # The originally requested ports (NOT the mapped ports) need to be allowed via the 'ports' directive. + port_mapping: + - 10001:8080 cidr: - 0.0.0.0/0 - ::/0 diff --git a/src/restrictions/mod.rs b/src/restrictions/mod.rs index cf09ca3..87f7bd1 100644 --- a/src/restrictions/mod.rs +++ b/src/restrictions/mod.rs @@ -36,6 +36,7 @@ impl RestrictionsRules { let reverse_tunnel = types::AllowConfig::ReverseTunnel(types::AllowReverseTunnelConfig { protocol: vec![], port: vec![], + port_mapping: Default::default(), cidr: default_cidr(), }); @@ -56,6 +57,7 @@ impl RestrictionsRules { types::AllowConfig::ReverseTunnel(types::AllowReverseTunnelConfig { protocol: vec![], port: vec![RangeInclusive::new(*port, *port)], + port_mapping: Default::default(), cidr: vec![IpNet::new(ip, if ip.is_ipv4() { 32 } else { 128 })?], }), ] @@ -70,6 +72,7 @@ impl RestrictionsRules { types::AllowConfig::ReverseTunnel(types::AllowReverseTunnelConfig { protocol: vec![], port: vec![], + port_mapping: Default::default(), cidr: default_cidr(), }), ] diff --git a/src/restrictions/types.rs b/src/restrictions/types.rs index b7eefde..e0cd49c 100644 --- a/src/restrictions/types.rs +++ b/src/restrictions/types.rs @@ -2,6 +2,7 @@ use crate::LocalProtocol; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use regex::Regex; use serde::{Deserialize, Deserializer}; +use std::collections::HashMap; use std::ops::RangeInclusive; #[derive(Debug, Clone, Deserialize)] @@ -56,6 +57,10 @@ pub struct AllowReverseTunnelConfig { #[serde(default)] pub port: Vec>, + #[serde(deserialize_with = "deserialize_port_mapping")] + #[serde(default)] + pub port_mapping: HashMap, + #[serde(default = "default_cidr")] pub cidr: Vec, } @@ -110,6 +115,29 @@ where Ok(ranges) } +fn deserialize_port_mapping<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let mappings: Vec = Deserialize::deserialize(deserializer)?; + mappings + .into_iter() + .map(|port_mapping| { + let port_mapping_parts: Vec<&str> = port_mapping.split(':').collect(); + if port_mapping_parts.len() != 2 { + Err(serde::de::Error::custom(format!( + "Invalid port_mapping entry: {}", + port_mapping + ))) + } else { + let orig_port = port_mapping_parts[0].parse::().map_err(serde::de::Error::custom)?; + let target_port = port_mapping_parts[1].parse::().map_err(serde::de::Error::custom)?; + Ok((orig_port, target_port)) + } + }) + .collect() +} + fn deserialize_non_empty_vec<'de, D, T>(d: D) -> Result, D::Error> where D: Deserializer<'de>, diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index d53ccd1..e17b40d 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -50,6 +50,7 @@ use uuid::Uuid; async fn run_tunnel( server_config: &WsServerConfig, + restriction: &RestrictionConfig, remote: RemoteAddr, client_address: SocketAddr, ) -> anyhow::Result<(RemoteAddr, Pin>, Pin>)> { @@ -94,7 +95,8 @@ async fn run_tunnel( static SERVERS: Lazy, u16), mpsc::Receiver>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); - let local_srv = (remote.host, remote.port); + let remote_port = find_mapped_port(remote.port, restriction); + let local_srv = (remote.host, remote_port); let bind = format!("{}:{}", local_srv.0, local_srv.1); let listening_server = tcp::run_server(bind.parse()?, false); let tcp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; @@ -112,7 +114,8 @@ async fn run_tunnel( static SERVERS: Lazy, u16), mpsc::Receiver>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); - let local_srv = (remote.host, remote.port); + let remote_port = find_mapped_port(remote.port, restriction); + let local_srv = (remote.host, remote_port); let bind = format!("{}:{}", local_srv.0, local_srv.1); let listening_server = udp::run_server(bind.parse()?, timeout, |_| Ok(()), |send_socket| Ok(send_socket.clone())); @@ -131,7 +134,8 @@ async fn run_tunnel( static SERVERS: Lazy, u16), mpsc::Receiver<(Socks5Stream, (Host, u16))>>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); - let local_srv = (remote.host, remote.port); + let remote_port = find_mapped_port(remote.port, restriction); + let local_srv = (remote.host, remote_port); let bind = format!("{}:{}", local_srv.0, local_srv.1); let listening_server = socks5::run_server(bind.parse()?, None); let (stream, local_srv) = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; @@ -154,7 +158,8 @@ async fn run_tunnel( static SERVERS: Lazy, u16), mpsc::Receiver>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); - let local_srv = (remote.host, remote.port); + let remote_port = find_mapped_port(remote.port, restriction); + let local_srv = (remote.host, remote_port); let listening_server = unix_socket::run_server(path); let stream = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; let (local_rx, local_tx) = stream.into_split(); @@ -182,6 +187,29 @@ async fn run_tunnel( } } +/// Checks if the requested (remote) port has been mapped in the configuration to another port. +/// If it is not mapped the original port number is returned. +#[inline] +fn find_mapped_port(req_port: u16, restriction: &RestrictionConfig) -> u16 { + // Determine if the requested port is to be mapped to a different port. + let remote_port = restriction + .allow + .iter() + .find_map(|allow| { + if let AllowConfig::ReverseTunnel(allow) = allow { + return allow.port_mapping.get(&req_port).cloned(); + } + None + }) + .unwrap_or(req_port); + + if req_port != remote_port { + info!("Client requested port {} was mapped to {}", req_port, remote_port); + } + + remote_port +} + #[allow(clippy::type_complexity)] async fn run_listening_server( local_srv: &(Host, u16), @@ -482,15 +510,16 @@ async fn ws_server_upgrade( } }; - match validate_tunnel(&remote, path_prefix, &restrictions) { + let restriction = match validate_tunnel(&remote, path_prefix, &restrictions) { Ok(matched_restriction) => { info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); + matched_restriction } Err(err) => return err, - } + }; let req_protocol = remote.protocol.clone(); - let tunnel = match run_tunnel(&server_config, remote, client_addr).await { + let tunnel = match run_tunnel(&server_config, restriction, remote, client_addr).await { Ok(ret) => ret, Err(err) => { warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); @@ -612,15 +641,16 @@ async fn http_server_upgrade( } }; - match validate_tunnel(&remote, path_prefix, &restrictions) { + let restriction = match validate_tunnel(&remote, path_prefix, &restrictions) { Ok(matched_restriction) => { info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); + matched_restriction } Err(err) => return err.map(Either::Left), - } + }; let req_protocol = remote.protocol.clone(); - let tunnel = match run_tunnel(&server_config, remote, client_addr).await { + let tunnel = match run_tunnel(&server_config, restriction, remote, client_addr).await { Ok(ret) => ret, Err(err) => { warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());