From 8a228248d726fd2c7cb1d0d2eadd558b49491bb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Sat, 27 Apr 2024 22:40:32 +0200 Subject: [PATCH] Add config file for restrictions --- Cargo.lock | 36 ++++++ Cargo.toml | 8 +- restrictions.yaml | 84 ++++++++++++++ src/main.rs | 47 +++++++- src/restrictions/mod.rs | 79 +++++++++++++ src/restrictions/types.rs | 141 ++++++++++++++++++++++ src/tunnel/server.rs | 239 +++++++++++++++++++++++++++----------- 7 files changed, 559 insertions(+), 75 deletions(-) create mode 100644 restrictions.yaml create mode 100644 src/restrictions/mod.rs create mode 100644 src/restrictions/types.rs diff --git a/Cargo.lock b/Cargo.lock index 5c127c2..f5ca821 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -913,6 +913,9 @@ name = "ipnet" version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +dependencies = [ + "serde", +] [[package]] name = "itoa" @@ -1524,6 +1527,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_regex" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8136f1a4ea815d7eac4101cfd0b16dc0cb5e1fe1b8609dfd728058656b7badf" +dependencies = [ + "regex", + "serde", +] + [[package]] name = "serde_with" version = "1.14.0" @@ -1546,6 +1559,19 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "serde_yaml" +version = "0.9.34+deprecated" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" +dependencies = [ + "indexmap", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + [[package]] name = "sha1" version = "0.10.6" @@ -1950,6 +1976,12 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unsafe-libyaml" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" + [[package]] name = "untrusted" version = "0.9.0" @@ -2285,6 +2317,7 @@ dependencies = [ "http-body-util", "hyper", "hyper-util", + "ipnet", "jsonwebtoken", "log", "nix", @@ -2293,10 +2326,13 @@ dependencies = [ "parking_lot", "pin-project", "ppp", + "regex", "rustls-native-certs", "rustls-pemfile 2.1.1", "scopeguard", "serde", + "serde_regex", + "serde_yaml", "socket2", "testcontainers", "tokio", diff --git a/Cargo.toml b/Cargo.toml index edf1109..f354599 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,13 @@ fast-socks5 = { version = "0.9.6", features = [] } fastwebsockets = { version = "0.7.1", features = ["upgrade", "simd", "unstable-split"] } futures-util = { version = "0.3.30" } hickory-resolver = { version = "0.24.0", features = ["tokio", "dns-over-https-rustls", "dns-over-rustls"] } -ppp = { version = "2.2.0", features = [] } +ppp = { version = "2.2.0", features = [] } + +# For config file parsing +regex = { version = "1.10.4", default-features = false, features = ["std", "perf"] } +serde_regex = "1.1.0" +serde_yaml = { version = "0.9.34", features = [] } +ipnet = { version = "2.9.0", features = ["serde"] } hyper = { version = "1.2.0", features = ["client", "http1", "http2"] } hyper-util = { version = "0.1.3", features = ["tokio", "server", "server-auto"] } diff --git a/restrictions.yaml b/restrictions.yaml new file mode 100644 index 0000000..a1a49f8 --- /dev/null +++ b/restrictions.yaml @@ -0,0 +1,84 @@ +# Restrictions are whitelist rules for the tunnels +# By default, all requests are denied and only if a restriction match, the request is allowed +restrictions: + - name: "Allow all" + description: "This restriction allows all requests" + # This restriction apply only if it matches the prefix that match the given regex + # The regex does a match, so if you want to match exactly you need to bound the pattern with ^ $ + # I.e: "tesotron" is going to match "XXXtesotronXXX", but "^tesotron$" is going to match only "tesotron" + match: !PathPrefix "^.*$" + + # This is th list of tunnels your restriction is going to allow + # The list is going to be checked in order, the first match is going to allow the request + allow: + # !Tunnel allows forward tunnels + - !Tunnel + # Protocol that are allowed. Empty list means all protocols are allowed + protocol: + - Tcp + - Udp + # Port that are allowed. Can be a single port or an inclusive range (i.e. 80..90) + port: 9999 + + # if the tunnel wants to connect to a specific host, this regex must match + host: ^.*$ + # if the tunnel wants to connect to a specific IP, it must match one of the network cidr + cidr: + - 0.0.0.0/0 + - ::/0 + + # !ReverseTunnel allows reverse tunnels + # Not specifying anything means all reverse tunnels are allowed + - !ReverseTunnel + protocol: + - Tcp + - Udp + - Socks5 + - Unix + port: 1..65535 + cidr: + - 0.0.0.0/0 + - ::/0 + +--- +# Examples +restrictions: + - name: "example 1" + description: "Only allow forward tunnels to port 443 and forbid reverse tunnels" + match: !PathPrefix "^.*$" + allow: + - !Tunnel + port: 443 +--- +restrictions: + - name: "example 2" + description: "Only allow forward tunnels to local ssh and forbid reverse tunnels" + match: !PathPrefix "^.*$" + allow: + - !Tunnel + protocol: + - Tcp + port: 22 + host: ^localhost$ + cidr: + - 127.0.0.1/32 +--- +restrictions: + - name: "example 3" + description: "Only allow socks5 reverse tunnels listening on port between 1080..1443 on lan network" + match: !PathPrefix "^.*$" + allow: + - !ReverseTunnel + protocol: + - Socks5 + port: 1080..1443 + cidr: + - 192.168.0.0/16 +--- +restrictions: + - name: "example 4" + description: "Allow everything for client using path prefix my-super-secret-path" + match: !PathPrefix "^my-super-secret-path$" + allow: + - !Tunnel + - !ReverseTunnel diff --git a/src/main.rs b/src/main.rs index 79abcc1..de742e6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ mod dns; mod embedded_certificate; +mod restrictions; mod socks5; mod socks5_udp; mod stdio; @@ -40,6 +41,7 @@ use tokio_rustls::TlsConnector; use tracing::{error, info}; use crate::dns::DnsResolver; +use crate::restrictions::types::RestrictionsRules; use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme}; use crate::udp::MyUdpSocket; @@ -287,6 +289,11 @@ struct Server { )] restrict_http_upgrade_path_prefix: Option>, + /// Path to the location of the restriction yaml config file. + /// Restriction file is automatically reloaded if it changes + #[arg(long, verbatim_doc_comment)] + restriction_file: Option, + /// [Optional] Use custom certificate (pem) instead of the default embedded self-signed certificate. /// The certificate will be automatically reloaded if it changes #[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)] @@ -319,6 +326,15 @@ enum LocalProtocol { Unix { path: PathBuf }, } +impl LocalProtocol { + pub fn is_reverse_tunnel(&self) -> bool { + matches!( + self, + LocalProtocol::ReverseTcp | LocalProtocol::ReverseUdp { .. } | LocalProtocol::ReverseSocks5 + ) + } +} + #[derive(Clone, Debug)] pub struct LocalToRemote { local_protocol: LocalProtocol, @@ -607,13 +623,12 @@ pub struct TlsServerConfig { pub struct WsServerConfig { pub socket_so_mark: Option, pub bind: SocketAddr, - pub restrict_to: Option>, - pub restrict_http_upgrade_path_prefix: Option>, pub websocket_ping_frequency: Option, pub timeout_connect: Duration, pub websocket_mask_frame: bool, pub tls: Option, pub dns_resolver: DnsResolver, + pub restrictions: RestrictionsRules, } impl Debug for WsServerConfig { @@ -621,8 +636,6 @@ impl Debug for WsServerConfig { f.debug_struct("WsServerConfig") .field("socket_so_mark", &self.socket_so_mark) .field("bind", &self.bind) - .field("restrict_to", &self.restrict_to) - .field("restrict_http_upgrade_path_prefix", &self.restrict_http_upgrade_path_prefix) .field("websocket_ping_frequency", &self.websocket_ping_frequency) .field("timeout_connect", &self.timeout_connect) .field("websocket_mask_frame", &self.websocket_mask_frame) @@ -1246,16 +1259,38 @@ async fn main() { } } }; + + let restrictions = if let Some(path) = &args.restriction_file { + RestrictionsRules::from_config_file(path).expect("Cannot parse restriction file") + } else { + let restrict_to: Vec<(String, u16)> = args + .restrict_to + .as_deref() + .unwrap_or(&[]) + .iter() + .map(|x| { + let (host, port) = x.rsplit_once(':').expect("Invalid restrict-to format"); + (host.to_string(), port.parse::().expect("Invalid restrict-to port format")) + }) + .collect(); + + let restriction_cfg = RestrictionsRules::from_path_prefix( + args.restrict_http_upgrade_path_prefix.as_deref().unwrap_or(&[]), + &restrict_to, + ) + .expect("Cannot covertion restriction rules from path-prefix and restric-to"); + restriction_cfg + }; + let server_config = WsServerConfig { socket_so_mark: args.socket_so_mark, bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0], - restrict_to: args.restrict_to, - restrict_http_upgrade_path_prefix: args.restrict_http_upgrade_path_prefix, websocket_ping_frequency: args.websocket_ping_frequency_sec, timeout_connect: Duration::from_secs(10), websocket_mask_frame: args.websocket_mask_frame, tls: tls_config, dns_resolver, + restrictions, }; info!( diff --git a/src/restrictions/mod.rs b/src/restrictions/mod.rs new file mode 100644 index 0000000..31d1149 --- /dev/null +++ b/src/restrictions/mod.rs @@ -0,0 +1,79 @@ +use crate::restrictions::types::{default_cidr, default_host, default_port}; +use regex::Regex; +use std::fs::File; +use std::io::BufReader; +use std::ops::RangeInclusive; +use std::path::Path; +use types::RestrictionsRules; + +pub mod types; + +impl RestrictionsRules { + pub fn from_config_file(config_path: &Path) -> anyhow::Result { + let restrictions: RestrictionsRules = serde_yaml::from_reader(BufReader::new(File::open(config_path)?))?; + Ok(restrictions) + } + + pub fn from_path_prefix( + path_prefixes: &[String], + restrict_to: &[(String, u16)], + ) -> anyhow::Result { + let mut tunnels_restrictions = if restrict_to.is_empty() { + let r = types::AllowConfig::Tunnel(types::AllowTunnelConfig { + protocol: vec![], + port: default_port(), + host: default_host(), + cidr: default_cidr(), + }); + vec![r] + } else { + restrict_to + .iter() + .map(|(host, port)| { + // Fixme: Remove the unwrap + let reg = Regex::new(&format!("^{}$", regex::escape(host))).unwrap(); + types::AllowConfig::Tunnel(types::AllowTunnelConfig { + protocol: vec![], + port: RangeInclusive::new(*port, *port), + host: reg, + cidr: default_cidr(), + }) + }) + .collect() + }; + + tunnels_restrictions.push(types::AllowConfig::ReverseTunnel(types::AllowReverseTunnelConfig { + protocol: vec![], + port: default_port(), + cidr: default_cidr(), + })); + + let restrictions = if path_prefixes.is_empty() { + // if no path prefixes are provided, we allow all + let reg = Regex::new(".").unwrap(); + let r = types::RestrictionConfig { + name: "Allow All".to_string(), + r#match: types::MatchConfig::PathPrefix(reg), + allow: tunnels_restrictions, + }; + vec![r] + } else { + path_prefixes + .iter() + .map(|path_prefix| { + // Fixme: Remove the unwrap + let reg = Regex::new(&format!("^{}$", regex::escape(path_prefix))).unwrap(); + types::RestrictionConfig { + name: format!("Allow path prefix {}", path_prefix), + r#match: types::MatchConfig::PathPrefix(reg), + allow: tunnels_restrictions.clone(), + } + }) + .collect() + }; + + let restrictions = RestrictionsRules { restrictions }; + + Ok(restrictions) + } +} diff --git a/src/restrictions/types.rs b/src/restrictions/types.rs new file mode 100644 index 0000000..7f065d1 --- /dev/null +++ b/src/restrictions/types.rs @@ -0,0 +1,141 @@ +use crate::LocalProtocol; +use ipnet::{IpNet, Ipv4Net, Ipv6Net}; +use regex::Regex; +use serde::{Deserialize, Deserializer}; +use std::ops::RangeInclusive; + +#[derive(Debug, Clone, Deserialize)] +pub struct RestrictionsRules { + pub restrictions: Vec, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct RestrictionConfig { + pub name: String, + pub r#match: MatchConfig, + pub allow: Vec, +} + +#[derive(Debug, Clone, Deserialize)] +pub enum MatchConfig { + Any, + #[serde(with = "serde_regex")] + PathPrefix(Regex), +} + +#[derive(Debug, Clone, Deserialize)] +pub enum AllowConfig { + ReverseTunnel(AllowReverseTunnelConfig), + Tunnel(AllowTunnelConfig), +} + +#[derive(Debug, Clone, Deserialize)] +pub struct AllowTunnelConfig { + #[serde(default)] + pub protocol: Vec, + + #[serde(deserialize_with = "deserialize_port_range")] + #[serde(default = "default_port")] + pub port: RangeInclusive, + + #[serde(with = "serde_regex")] + #[serde(default = "default_host")] + pub host: Regex, + + #[serde(default = "default_cidr")] + pub cidr: Vec, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct AllowReverseTunnelConfig { + #[serde(default)] + pub protocol: Vec, + + #[serde(deserialize_with = "deserialize_port_range")] + #[serde(default = "default_port")] + pub port: RangeInclusive, + + #[serde(default = "default_cidr")] + pub cidr: Vec, +} + +#[derive(Debug, Clone, Deserialize, Eq, PartialEq)] +pub enum TunnelConfigProtocol { + Tcp, + Udp, + Unknown, +} + +#[derive(Debug, Clone, Deserialize, Eq, PartialEq)] +pub enum ReverseTunnelConfigProtocol { + Tcp, + Udp, + Socks5, + Unix, + Unknown, +} + +pub fn default_port() -> RangeInclusive { + RangeInclusive::new(1, 65535) +} + +pub fn default_host() -> Regex { + Regex::new("^.*$").unwrap() +} + +pub fn default_cidr() -> Vec { + vec![IpNet::V4(Ipv4Net::default()), IpNet::V6(Ipv6Net::default())] +} + +fn deserialize_port_range<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + let range = if let Some((l, r)) = s.split_once("..") { + RangeInclusive::new( + l.parse().map_err(serde::de::Error::custom)?, + r.parse().map_err(serde::de::Error::custom)?, + ) + } else { + let port = s.parse::().map_err(serde::de::Error::custom)?; + RangeInclusive::new(port, port) + }; + + Ok(range) +} + +impl From<&LocalProtocol> for ReverseTunnelConfigProtocol { + fn from(value: &LocalProtocol) -> Self { + match value { + LocalProtocol::Tcp { .. } + | LocalProtocol::Udp { .. } + | LocalProtocol::Stdio + | LocalProtocol::Socks5 { .. } + | LocalProtocol::TProxyTcp { .. } + | LocalProtocol::TProxyUdp { .. } + | LocalProtocol::Unix { .. } => ReverseTunnelConfigProtocol::Unknown, + LocalProtocol::ReverseTcp => ReverseTunnelConfigProtocol::Tcp, + LocalProtocol::ReverseUdp { .. } => ReverseTunnelConfigProtocol::Udp, + LocalProtocol::ReverseSocks5 => ReverseTunnelConfigProtocol::Socks5, + LocalProtocol::ReverseUnix { .. } => ReverseTunnelConfigProtocol::Unix, + } + } +} +impl From<&LocalProtocol> for TunnelConfigProtocol { + fn from(value: &LocalProtocol) -> Self { + match value { + LocalProtocol::ReverseTcp + | LocalProtocol::ReverseUdp { .. } + | LocalProtocol::ReverseSocks5 + | LocalProtocol::ReverseUnix { .. } + | LocalProtocol::Stdio + | LocalProtocol::Socks5 { .. } + | LocalProtocol::TProxyTcp { .. } + | LocalProtocol::TProxyUdp { .. } + | LocalProtocol::Unix { .. } => TunnelConfigProtocol::Unknown, + LocalProtocol::Tcp { .. } => TunnelConfigProtocol::Tcp, + LocalProtocol::Udp { .. } => TunnelConfigProtocol::Udp, + } + } +} diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 92295b0..f966e60 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -8,7 +8,7 @@ use std::cmp::min; use std::fmt::Debug; use std::future::Future; use std::net::{IpAddr, SocketAddr}; -use std::ops::{Deref, Not}; +use std::ops::Deref; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; @@ -26,6 +26,9 @@ use jsonwebtoken::TokenData; use once_cell::sync::Lazy; use parking_lot::Mutex; +use crate::restrictions::types::{ + AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol, +}; use crate::socks5::Socks5Stream; use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; @@ -43,12 +46,11 @@ use uuid::Uuid; async fn run_tunnel( server_config: &WsServerConfig, - jwt: TokenData, + remote: RemoteAddr, client_address: SocketAddr, ) -> anyhow::Result<(RemoteAddr, Pin>, Pin>)> { - match jwt.claims.p { + match remote.protocol { LocalProtocol::Udp { timeout, .. } => { - let remote = RemoteAddr::try_from(jwt.claims)?; let cnx = udp::connect( &remote.host, remote.port, @@ -60,7 +62,6 @@ async fn run_tunnel( Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx))) } LocalProtocol::Tcp { proxy_protocol } => { - let remote = RemoteAddr::try_from(jwt.claims)?; let mut socket = tcp::connect( &remote.host, remote.port, @@ -89,14 +90,14 @@ async fn run_tunnel( static SERVERS: Lazy, u16), mpsc::Receiver>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); - let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp); + let local_srv = (remote.host, remote.port); let bind = format!("{}:{}", local_srv.0, local_srv.1); let listening_server = tcp::run_server(bind.parse()?, false); let tcp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; let (local_rx, local_tx) = tcp.into_split(); let remote = RemoteAddr { - protocol: jwt.claims.p, + protocol: remote.protocol, host: local_srv.0, port: local_srv.1, }; @@ -107,7 +108,7 @@ async fn run_tunnel( static SERVERS: Lazy, u16), mpsc::Receiver>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); - let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp); + let local_srv = (remote.host, remote.port); let bind = format!("{}:{}", local_srv.0, local_srv.1); let listening_server = udp::run_server(bind.parse()?, timeout, |_| Ok(()), |send_socket| Ok(send_socket.clone())); @@ -115,7 +116,7 @@ async fn run_tunnel( let (local_rx, local_tx) = tokio::io::split(udp); let remote = RemoteAddr { - protocol: jwt.claims.p, + protocol: remote.protocol, host: local_srv.0, port: local_srv.1, }; @@ -126,7 +127,7 @@ async fn run_tunnel( static SERVERS: Lazy, u16), mpsc::Receiver<(Socks5Stream, (Host, u16))>>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); - let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp); + let local_srv = (remote.host, remote.port); let bind = format!("{}:{}", local_srv.0, local_srv.1); let listening_server = socks5::run_server(bind.parse()?, None); let (stream, local_srv) = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; @@ -149,13 +150,13 @@ async fn run_tunnel( static SERVERS: Lazy, u16), mpsc::Receiver>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); - let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp); + let local_srv = (remote.host, remote.port); let listening_server = unix_socket::run_server(path); let stream = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; let (local_rx, local_tx) = stream.into_split(); let remote = RemoteAddr { - protocol: jwt.claims.p.clone(), + protocol: remote.protocol, host: local_srv.0, port: local_srv.1, }; @@ -163,7 +164,7 @@ async fn run_tunnel( } #[cfg(not(unix))] LocalProtocol::ReverseUnix { ref path } => { - error!("Received an unsupported target protocol {:?}", jwt.claims); + error!("Received an unsupported target protocol {:?}", remote); Err(anyhow::anyhow!("Invalid upgrade request")) } LocalProtocol::Stdio @@ -171,7 +172,7 @@ async fn run_tunnel( | LocalProtocol::TProxyTcp | LocalProtocol::TProxyUdp { .. } | LocalProtocol::Unix { .. } => { - error!("Received an unsupported target protocol {:?}", jwt.claims); + error!("Received an unsupported target protocol {:?}", remote); Err(anyhow::anyhow!("Invalid upgrade request")) } } @@ -251,11 +252,26 @@ fn extract_x_forwarded_for(req: &Request) -> Result, - path_restriction_prefix: &Option>, -) -> Result<(), Response> { - if !req.uri().path().ends_with("/events") { +fn extract_path_prefix(req: &Request) -> Result<&str, Response> { + let path = req.uri().path(); + let min_len = min(path.len(), 1); + if &path[0..min_len] != "/" { + warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri()); + return Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap()); + } + + let Some((l, r)) = path[min_len..].split_once('/') else { + warn!("Rejecting connection with bad upgrade request: {}", req.uri()); + return Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".into()) + .unwrap()); + }; + + if !r.ends_with("events") { warn!("Rejecting connection with bad upgrade request: {}", req.uri()); return Err(http::Response::builder() .status(StatusCode::BAD_REQUEST) @@ -263,26 +279,7 @@ fn validate_url( .unwrap()); } - if let Some(paths_prefix) = &path_restriction_prefix { - let path = req.uri().path(); - let min_len = min(path.len(), 1); - let mut max_len = 0; - if &path[0..min_len] != "/" - || !paths_prefix.iter().any(|p| { - max_len = min(path.len(), p.len() + 1); - p == &path[min_len..max_len] - }) - || !path[max_len..].starts_with('/') - { - warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri()); - return Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap()); - } - } - - Ok(()) + Ok(l) } #[inline] @@ -316,25 +313,102 @@ fn extract_tunnel_info(req: &Request) -> Result( _req: &Request, - jwt: &TokenData, - destination_restriction: &Option>, -) -> Result<(), Response> { - let Some(allowed_dests) = &destination_restriction else { - return Ok(()); - }; + remote: &RemoteAddr, + path_prefix: &str, + restrictions: &'a RestrictionsRules, +) -> Result<&'a RestrictionConfig, Response> { + for restriction in &restrictions.restrictions { + match &restriction.r#match { + MatchConfig::Any => {} + MatchConfig::PathPrefix(path) => { + if !path.is_match(path_prefix) { + continue; + } + } + } - let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp); - if allowed_dests.iter().any(|dest| dest == &requested_dest).not() { - warn!("Rejecting connection with not allowed destination: {}", requested_dest); - return Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap()); + for allow in &restriction.allow { + match allow { + AllowConfig::ReverseTunnel(allow) => { + if !remote.protocol.is_reverse_tunnel() || !allow.port.contains(&remote.port) { + continue; + } + + if !allow.protocol.is_empty() + && !allow + .protocol + .contains(&ReverseTunnelConfigProtocol::from(&remote.protocol)) + { + continue; + } + + match &remote.host { + Host::Domain(_) => {} + Host::Ipv4(ip) => { + let ip = IpAddr::V4(*ip); + for cidr in &allow.cidr { + if cidr.contains(&ip) { + return Ok(restriction); + } + } + } + Host::Ipv6(ip) => { + let ip = IpAddr::V6(*ip); + for cidr in &allow.cidr { + if cidr.contains(&ip) { + return Ok(restriction); + } + } + } + } + } + + AllowConfig::Tunnel(allow) => { + if remote.protocol.is_reverse_tunnel() || !allow.port.contains(&remote.port) { + continue; + } + + if !allow.protocol.is_empty() + && !allow.protocol.contains(&TunnelConfigProtocol::from(&remote.protocol)) + { + continue; + } + + match &remote.host { + Host::Domain(host) => { + if allow.host.is_match(host) { + return Ok(restriction); + } + } + Host::Ipv4(ip) => { + let ip = IpAddr::V4(*ip); + for cidr in &allow.cidr { + if cidr.contains(&ip) { + return Ok(restriction); + } + } + } + Host::Ipv6(ip) => { + let ip = IpAddr::V6(*ip); + for cidr in &allow.cidr { + if cidr.contains(&ip) { + return Ok(restriction); + } + } + } + } + } + } + } } - Ok(()) + warn!("Rejecting connection with not allowed destination: {:?}", remote); + Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap()) } async fn ws_server_upgrade( @@ -360,9 +434,10 @@ async fn ws_server_upgrade( Err(err) => return err, }; - if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) { - return err; - } + let path_prefix = match extract_path_prefix(&req) { + Ok(p) => p, + Err(err) => return err, + }; let jwt = match extract_tunnel_info(&req) { Ok(jwt) => jwt, @@ -372,12 +447,26 @@ async fn ws_server_upgrade( Span::current().record("id", &jwt.claims.id); Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); - if let Err(err) = validate_destination(&req, &jwt, &server_config.restrict_to) { - return err; + let remote = match RemoteAddr::try_from(jwt.claims) { + Ok(remote) => remote, + Err(err) => { + warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri()); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap(); + } + }; + + match validate_tunnel(&req, &remote, path_prefix, &server_config.restrictions) { + Ok(matched_restriction) => { + info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); + } + Err(err) => return err, } - let req_protocol = jwt.claims.p.clone(); - let tunnel = match run_tunnel(&server_config, jwt, client_addr).await { + let req_protocol = remote.protocol.clone(); + let tunnel = match run_tunnel(&server_config, remote, client_addr).await { Ok(ret) => ret, Err(err) => { warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); @@ -461,9 +550,10 @@ async fn http_server_upgrade( Err(err) => return err.map(Either::Left), }; - if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) { - return err.map(Either::Left); - } + let path_prefix = match extract_path_prefix(&req) { + Ok(p) => p, + Err(err) => return err.map(Either::Left), + }; let jwt = match extract_tunnel_info(&req) { Ok(jwt) => jwt, @@ -472,13 +562,26 @@ async fn http_server_upgrade( Span::current().record("id", &jwt.claims.id); Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); + let remote = match RemoteAddr::try_from(jwt.claims) { + Ok(remote) => remote, + Err(err) => { + warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri()); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Either::Left("Invalid upgrade request".to_string())) + .unwrap(); + } + }; - if let Err(err) = validate_destination(&req, &jwt, &server_config.restrict_to) { - return err.map(Either::Left); + match validate_tunnel(&req, &remote, path_prefix, &server_config.restrictions) { + Ok(matched_restriction) => { + info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); + } + Err(err) => return err.map(Either::Left), } - let req_protocol = jwt.claims.p.clone(); - let tunnel = match run_tunnel(&server_config, jwt, client_addr).await { + let req_protocol = remote.protocol.clone(); + let tunnel = match run_tunnel(&server_config, remote, client_addr).await { Ok(ret) => ret, Err(err) => { warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());