diff --git a/restrictions.yaml b/restrictions.yaml index a1a49f8..795b73a 100644 --- a/restrictions.yaml +++ b/restrictions.yaml @@ -18,7 +18,10 @@ restrictions: - Tcp - Udp # Port that are allowed. Can be a single port or an inclusive range (i.e. 80..90) - port: 9999 + port: + - 80 + - 443 + - 8080..8089 # if the tunnel wants to connect to a specific host, this regex must match host: ^.*$ @@ -35,7 +38,8 @@ restrictions: - Udp - Socks5 - Unix - port: 1..65535 + port: + - 1..65535 cidr: - 0.0.0.0/0 - ::/0 @@ -48,7 +52,8 @@ restrictions: match: !PathPrefix "^.*$" allow: - !Tunnel - port: 443 + port: + - 443 --- restrictions: - name: "example 2" @@ -58,7 +63,8 @@ restrictions: - !Tunnel protocol: - Tcp - port: 22 + port: + - 22 host: ^localhost$ cidr: - 127.0.0.1/32 @@ -71,7 +77,8 @@ restrictions: - !ReverseTunnel protocol: - Socks5 - port: 1080..1443 + port: + - 1080..1443 cidr: - 192.168.0.0/16 --- diff --git a/src/main.rs b/src/main.rs index de742e6..0421739 100644 --- a/src/main.rs +++ b/src/main.rs @@ -259,12 +259,6 @@ struct Server { #[arg(long, default_value = "false", verbatim_doc_comment)] websocket_mask_frame: bool, - /// Server will only accept connection from the specified tunnel information. - /// Can be specified multiple time - /// Example: --restrict-to "google.com:443" --restrict-to "localhost:22" - #[arg(long, value_name = "DEST:PORT", verbatim_doc_comment)] - restrict_to: Option>, - /// Dns resolver to use to lookup ips of domain name /// This option is not going to work if you use transparent proxy /// Can be specified multiple time @@ -277,6 +271,12 @@ struct Server { #[arg(long, verbatim_doc_comment)] dns_resolver: Option>, + /// Server will only accept connection from the specified tunnel information. + /// Can be specified multiple time + /// Example: --restrict-to "google.com:443" --restrict-to "localhost:22" + #[arg(long, value_name = "DEST:PORT", verbatim_doc_comment)] + restrict_to: Option>, + /// Server will only accept connection from if this specific path prefix is used during websocket upgrade. /// Useful if you specify in the client a custom path prefix, and you want the server to only allow this one. /// The path prefix act as a secret to authenticate clients @@ -292,7 +292,7 @@ struct Server { /// Path to the location of the restriction yaml config file. /// Restriction file is automatically reloaded if it changes #[arg(long, verbatim_doc_comment)] - restriction_file: Option, + restrict_config: Option, /// [Optional] Use custom certificate (pem) instead of the default embedded self-signed certificate. /// The certificate will be automatically reloaded if it changes @@ -1260,7 +1260,7 @@ async fn main() { } }; - let restrictions = if let Some(path) = &args.restriction_file { + let restrictions = if let Some(path) = &args.restrict_config { RestrictionsRules::from_config_file(path).expect("Cannot parse restriction file") } else { let restrict_to: Vec<(String, u16)> = args diff --git a/src/restrictions/mod.rs b/src/restrictions/mod.rs index 31d1149..5581099 100644 --- a/src/restrictions/mod.rs +++ b/src/restrictions/mod.rs @@ -1,4 +1,4 @@ -use crate::restrictions::types::{default_cidr, default_host, default_port}; +use crate::restrictions::types::{default_cidr, default_host}; use regex::Regex; use std::fs::File; use std::io::BufReader; @@ -21,7 +21,7 @@ impl RestrictionsRules { let mut tunnels_restrictions = if restrict_to.is_empty() { let r = types::AllowConfig::Tunnel(types::AllowTunnelConfig { protocol: vec![], - port: default_port(), + port: vec![], host: default_host(), cidr: default_cidr(), }); @@ -30,21 +30,20 @@ impl RestrictionsRules { restrict_to .iter() .map(|(host, port)| { - // Fixme: Remove the unwrap - let reg = Regex::new(&format!("^{}$", regex::escape(host))).unwrap(); - types::AllowConfig::Tunnel(types::AllowTunnelConfig { + let reg = Regex::new(&format!("^{}$", regex::escape(host)))?; + Ok(types::AllowConfig::Tunnel(types::AllowTunnelConfig { protocol: vec![], - port: RangeInclusive::new(*port, *port), + port: vec![RangeInclusive::new(*port, *port)], host: reg, cidr: default_cidr(), - }) + })) }) - .collect() + .collect::, anyhow::Error>>()? }; tunnels_restrictions.push(types::AllowConfig::ReverseTunnel(types::AllowReverseTunnelConfig { protocol: vec![], - port: default_port(), + port: vec![], cidr: default_cidr(), })); @@ -61,19 +60,16 @@ impl RestrictionsRules { path_prefixes .iter() .map(|path_prefix| { - // Fixme: Remove the unwrap - let reg = Regex::new(&format!("^{}$", regex::escape(path_prefix))).unwrap(); - types::RestrictionConfig { + let reg = Regex::new(&format!("^{}$", regex::escape(path_prefix)))?; + Ok(types::RestrictionConfig { name: format!("Allow path prefix {}", path_prefix), r#match: types::MatchConfig::PathPrefix(reg), allow: tunnels_restrictions.clone(), - } + }) }) - .collect() + .collect::, anyhow::Error>>()? }; - let restrictions = RestrictionsRules { restrictions }; - - Ok(restrictions) + Ok(RestrictionsRules { restrictions }) } } diff --git a/src/restrictions/types.rs b/src/restrictions/types.rs index 7f065d1..f20e5e2 100644 --- a/src/restrictions/types.rs +++ b/src/restrictions/types.rs @@ -35,8 +35,8 @@ pub struct AllowTunnelConfig { pub protocol: Vec, #[serde(deserialize_with = "deserialize_port_range")] - #[serde(default = "default_port")] - pub port: RangeInclusive, + #[serde(default)] + pub port: Vec>, #[serde(with = "serde_regex")] #[serde(default = "default_host")] @@ -52,8 +52,8 @@ pub struct AllowReverseTunnelConfig { pub protocol: Vec, #[serde(deserialize_with = "deserialize_port_range")] - #[serde(default = "default_port")] - pub port: RangeInclusive, + #[serde(default)] + pub port: Vec>, #[serde(default = "default_cidr")] pub cidr: Vec, @@ -75,10 +75,6 @@ pub enum ReverseTunnelConfigProtocol { Unknown, } -pub fn default_port() -> RangeInclusive { - RangeInclusive::new(1, 65535) -} - pub fn default_host() -> Regex { Regex::new("^.*$").unwrap() } @@ -87,22 +83,30 @@ pub fn default_cidr() -> Vec { vec![IpNet::V4(Ipv4Net::default()), IpNet::V6(Ipv6Net::default())] } -fn deserialize_port_range<'de, D>(deserializer: D) -> Result, D::Error> +fn deserialize_port_range<'de, D>(deserializer: D) -> Result>, D::Error> where D: Deserializer<'de>, { - let s = String::deserialize(deserializer)?; - let range = if let Some((l, r)) = s.split_once("..") { - RangeInclusive::new( - l.parse().map_err(serde::de::Error::custom)?, - r.parse().map_err(serde::de::Error::custom)?, - ) - } else { - let port = s.parse::().map_err(serde::de::Error::custom)?; - RangeInclusive::new(port, port) - }; + let s = Vec::::deserialize(deserializer)?; + let ranges = s + .into_iter() + .map(|s| { + let range: Result, D::Error> = if let Some((l, r)) = s.split_once("..") { + Ok(RangeInclusive::new( + l.parse().map_err(::custom)?, + r.parse().map_err(::custom)?, + )) + } else { + let port = s.parse::().map_err(serde::de::Error::custom)?; + Ok(RangeInclusive::new(port, port)) + }; + range + }) + .collect::>() + .into_iter() + .collect::>, D::Error>>()?; - Ok(range) + Ok(ranges) } impl From<&LocalProtocol> for ReverseTunnelConfigProtocol { diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 67cfb80..6a3fcd0 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -211,7 +211,7 @@ where } Some(Ok(cnx)) => { if tx.send_timeout(cnx, send_timeout).await.is_err() { - info!("New remote connection failed to be picked by client after {}s. Closing remote tunnel server", send_timeout.as_secs()); + info!("New reverse connection failed to be picked by client after {}s. Closing reverse tunnel server", send_timeout.as_secs()); break; } } @@ -223,7 +223,7 @@ where } } } - info!("Stopping listening server"); + info!("Stopping listening reverse server"); }; tokio::spawn(fut.instrument(Span::current())); @@ -233,7 +233,7 @@ where let cnx = listening_server .recv() .await - .ok_or_else(|| anyhow!("listening server stopped"))?; + .ok_or_else(|| anyhow!("listening reverse server stopped"))?; servers.lock().insert(local_srv.clone(), listening_server); Ok(cnx) } @@ -314,7 +314,6 @@ fn extract_tunnel_info(req: &Request) -> Result( - _req: &Request, remote: &RemoteAddr, path_prefix: &str, restrictions: &'a RestrictionsRules, @@ -332,7 +331,11 @@ fn validate_tunnel<'a>( for allow in &restriction.allow { match allow { AllowConfig::ReverseTunnel(allow) => { - if !remote.protocol.is_reverse_tunnel() || !allow.port.contains(&remote.port) { + if !remote.protocol.is_reverse_tunnel() { + continue; + } + + if !allow.port.is_empty() && !allow.port.iter().any(|range| range.contains(&remote.port)) { continue; } @@ -366,7 +369,11 @@ fn validate_tunnel<'a>( } AllowConfig::Tunnel(allow) => { - if remote.protocol.is_reverse_tunnel() || !allow.port.contains(&remote.port) { + if remote.protocol.is_reverse_tunnel() { + continue; + } + + if !allow.port.is_empty() && !allow.port.iter().any(|range| range.contains(&remote.port)) { continue; } @@ -458,7 +465,7 @@ async fn ws_server_upgrade( } }; - match validate_tunnel(&req, &remote, path_prefix, &server_config.restrictions) { + match validate_tunnel(&remote, path_prefix, &server_config.restrictions) { Ok(matched_restriction) => { info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); } @@ -573,7 +580,7 @@ async fn http_server_upgrade( } }; - match validate_tunnel(&req, &remote, path_prefix, &server_config.restrictions) { + match validate_tunnel(&remote, path_prefix, &server_config.restrictions) { Ok(matched_restriction) => { info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); }