Add config file for restrictions

This commit is contained in:
Σrebe - Romain GERARD 2024-04-27 22:40:32 +02:00
parent 727e92902c
commit 8a228248d7
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
7 changed files with 559 additions and 75 deletions

36
Cargo.lock generated
View file

@ -913,6 +913,9 @@ name = "ipnet"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
dependencies = [
"serde",
]
[[package]]
name = "itoa"
@ -1524,6 +1527,16 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_regex"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8136f1a4ea815d7eac4101cfd0b16dc0cb5e1fe1b8609dfd728058656b7badf"
dependencies = [
"regex",
"serde",
]
[[package]]
name = "serde_with"
version = "1.14.0"
@ -1546,6 +1559,19 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "serde_yaml"
version = "0.9.34+deprecated"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
dependencies = [
"indexmap",
"itoa",
"ryu",
"serde",
"unsafe-libyaml",
]
[[package]]
name = "sha1"
version = "0.10.6"
@ -1950,6 +1976,12 @@ dependencies = [
"tinyvec",
]
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
[[package]]
name = "untrusted"
version = "0.9.0"
@ -2285,6 +2317,7 @@ dependencies = [
"http-body-util",
"hyper",
"hyper-util",
"ipnet",
"jsonwebtoken",
"log",
"nix",
@ -2293,10 +2326,13 @@ dependencies = [
"parking_lot",
"pin-project",
"ppp",
"regex",
"rustls-native-certs",
"rustls-pemfile 2.1.1",
"scopeguard",
"serde",
"serde_regex",
"serde_yaml",
"socket2",
"testcontainers",
"tokio",

View file

@ -18,7 +18,13 @@ fast-socks5 = { version = "0.9.6", features = [] }
fastwebsockets = { version = "0.7.1", features = ["upgrade", "simd", "unstable-split"] }
futures-util = { version = "0.3.30" }
hickory-resolver = { version = "0.24.0", features = ["tokio", "dns-over-https-rustls", "dns-over-rustls"] }
ppp = { version = "2.2.0", features = [] }
ppp = { version = "2.2.0", features = [] }
# For config file parsing
regex = { version = "1.10.4", default-features = false, features = ["std", "perf"] }
serde_regex = "1.1.0"
serde_yaml = { version = "0.9.34", features = [] }
ipnet = { version = "2.9.0", features = ["serde"] }
hyper = { version = "1.2.0", features = ["client", "http1", "http2"] }
hyper-util = { version = "0.1.3", features = ["tokio", "server", "server-auto"] }

84
restrictions.yaml Normal file
View file

@ -0,0 +1,84 @@
# Restrictions are whitelist rules for the tunnels
# By default, all requests are denied and only if a restriction match, the request is allowed
restrictions:
- name: "Allow all"
description: "This restriction allows all requests"
# This restriction apply only if it matches the prefix that match the given regex
# The regex does a match, so if you want to match exactly you need to bound the pattern with ^ $
# I.e: "tesotron" is going to match "XXXtesotronXXX", but "^tesotron$" is going to match only "tesotron"
match: !PathPrefix "^.*$"
# This is th list of tunnels your restriction is going to allow
# The list is going to be checked in order, the first match is going to allow the request
allow:
# !Tunnel allows forward tunnels
- !Tunnel
# Protocol that are allowed. Empty list means all protocols are allowed
protocol:
- Tcp
- Udp
# Port that are allowed. Can be a single port or an inclusive range (i.e. 80..90)
port: 9999
# if the tunnel wants to connect to a specific host, this regex must match
host: ^.*$
# if the tunnel wants to connect to a specific IP, it must match one of the network cidr
cidr:
- 0.0.0.0/0
- ::/0
# !ReverseTunnel allows reverse tunnels
# Not specifying anything means all reverse tunnels are allowed
- !ReverseTunnel
protocol:
- Tcp
- Udp
- Socks5
- Unix
port: 1..65535
cidr:
- 0.0.0.0/0
- ::/0
---
# Examples
restrictions:
- name: "example 1"
description: "Only allow forward tunnels to port 443 and forbid reverse tunnels"
match: !PathPrefix "^.*$"
allow:
- !Tunnel
port: 443
---
restrictions:
- name: "example 2"
description: "Only allow forward tunnels to local ssh and forbid reverse tunnels"
match: !PathPrefix "^.*$"
allow:
- !Tunnel
protocol:
- Tcp
port: 22
host: ^localhost$
cidr:
- 127.0.0.1/32
---
restrictions:
- name: "example 3"
description: "Only allow socks5 reverse tunnels listening on port between 1080..1443 on lan network"
match: !PathPrefix "^.*$"
allow:
- !ReverseTunnel
protocol:
- Socks5
port: 1080..1443
cidr:
- 192.168.0.0/16
---
restrictions:
- name: "example 4"
description: "Allow everything for client using path prefix my-super-secret-path"
match: !PathPrefix "^my-super-secret-path$"
allow:
- !Tunnel
- !ReverseTunnel

View file

@ -1,5 +1,6 @@
mod dns;
mod embedded_certificate;
mod restrictions;
mod socks5;
mod socks5_udp;
mod stdio;
@ -40,6 +41,7 @@ use tokio_rustls::TlsConnector;
use tracing::{error, info};
use crate::dns::DnsResolver;
use crate::restrictions::types::RestrictionsRules;
use crate::tunnel::tls_reloader::TlsReloader;
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme};
use crate::udp::MyUdpSocket;
@ -287,6 +289,11 @@ struct Server {
)]
restrict_http_upgrade_path_prefix: Option<Vec<String>>,
/// Path to the location of the restriction yaml config file.
/// Restriction file is automatically reloaded if it changes
#[arg(long, verbatim_doc_comment)]
restriction_file: Option<PathBuf>,
/// [Optional] Use custom certificate (pem) instead of the default embedded self-signed certificate.
/// The certificate will be automatically reloaded if it changes
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
@ -319,6 +326,15 @@ enum LocalProtocol {
Unix { path: PathBuf },
}
impl LocalProtocol {
pub fn is_reverse_tunnel(&self) -> bool {
matches!(
self,
LocalProtocol::ReverseTcp | LocalProtocol::ReverseUdp { .. } | LocalProtocol::ReverseSocks5
)
}
}
#[derive(Clone, Debug)]
pub struct LocalToRemote {
local_protocol: LocalProtocol,
@ -607,13 +623,12 @@ pub struct TlsServerConfig {
pub struct WsServerConfig {
pub socket_so_mark: Option<u32>,
pub bind: SocketAddr,
pub restrict_to: Option<Vec<String>>,
pub restrict_http_upgrade_path_prefix: Option<Vec<String>>,
pub websocket_ping_frequency: Option<Duration>,
pub timeout_connect: Duration,
pub websocket_mask_frame: bool,
pub tls: Option<TlsServerConfig>,
pub dns_resolver: DnsResolver,
pub restrictions: RestrictionsRules,
}
impl Debug for WsServerConfig {
@ -621,8 +636,6 @@ impl Debug for WsServerConfig {
f.debug_struct("WsServerConfig")
.field("socket_so_mark", &self.socket_so_mark)
.field("bind", &self.bind)
.field("restrict_to", &self.restrict_to)
.field("restrict_http_upgrade_path_prefix", &self.restrict_http_upgrade_path_prefix)
.field("websocket_ping_frequency", &self.websocket_ping_frequency)
.field("timeout_connect", &self.timeout_connect)
.field("websocket_mask_frame", &self.websocket_mask_frame)
@ -1246,16 +1259,38 @@ async fn main() {
}
}
};
let restrictions = if let Some(path) = &args.restriction_file {
RestrictionsRules::from_config_file(path).expect("Cannot parse restriction file")
} else {
let restrict_to: Vec<(String, u16)> = args
.restrict_to
.as_deref()
.unwrap_or(&[])
.iter()
.map(|x| {
let (host, port) = x.rsplit_once(':').expect("Invalid restrict-to format");
(host.to_string(), port.parse::<u16>().expect("Invalid restrict-to port format"))
})
.collect();
let restriction_cfg = RestrictionsRules::from_path_prefix(
args.restrict_http_upgrade_path_prefix.as_deref().unwrap_or(&[]),
&restrict_to,
)
.expect("Cannot covertion restriction rules from path-prefix and restric-to");
restriction_cfg
};
let server_config = WsServerConfig {
socket_so_mark: args.socket_so_mark,
bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0],
restrict_to: args.restrict_to,
restrict_http_upgrade_path_prefix: args.restrict_http_upgrade_path_prefix,
websocket_ping_frequency: args.websocket_ping_frequency_sec,
timeout_connect: Duration::from_secs(10),
websocket_mask_frame: args.websocket_mask_frame,
tls: tls_config,
dns_resolver,
restrictions,
};
info!(

79
src/restrictions/mod.rs Normal file
View file

@ -0,0 +1,79 @@
use crate::restrictions::types::{default_cidr, default_host, default_port};
use regex::Regex;
use std::fs::File;
use std::io::BufReader;
use std::ops::RangeInclusive;
use std::path::Path;
use types::RestrictionsRules;
pub mod types;
impl RestrictionsRules {
pub fn from_config_file(config_path: &Path) -> anyhow::Result<RestrictionsRules> {
let restrictions: RestrictionsRules = serde_yaml::from_reader(BufReader::new(File::open(config_path)?))?;
Ok(restrictions)
}
pub fn from_path_prefix(
path_prefixes: &[String],
restrict_to: &[(String, u16)],
) -> anyhow::Result<RestrictionsRules> {
let mut tunnels_restrictions = if restrict_to.is_empty() {
let r = types::AllowConfig::Tunnel(types::AllowTunnelConfig {
protocol: vec![],
port: default_port(),
host: default_host(),
cidr: default_cidr(),
});
vec![r]
} else {
restrict_to
.iter()
.map(|(host, port)| {
// Fixme: Remove the unwrap
let reg = Regex::new(&format!("^{}$", regex::escape(host))).unwrap();
types::AllowConfig::Tunnel(types::AllowTunnelConfig {
protocol: vec![],
port: RangeInclusive::new(*port, *port),
host: reg,
cidr: default_cidr(),
})
})
.collect()
};
tunnels_restrictions.push(types::AllowConfig::ReverseTunnel(types::AllowReverseTunnelConfig {
protocol: vec![],
port: default_port(),
cidr: default_cidr(),
}));
let restrictions = if path_prefixes.is_empty() {
// if no path prefixes are provided, we allow all
let reg = Regex::new(".").unwrap();
let r = types::RestrictionConfig {
name: "Allow All".to_string(),
r#match: types::MatchConfig::PathPrefix(reg),
allow: tunnels_restrictions,
};
vec![r]
} else {
path_prefixes
.iter()
.map(|path_prefix| {
// Fixme: Remove the unwrap
let reg = Regex::new(&format!("^{}$", regex::escape(path_prefix))).unwrap();
types::RestrictionConfig {
name: format!("Allow path prefix {}", path_prefix),
r#match: types::MatchConfig::PathPrefix(reg),
allow: tunnels_restrictions.clone(),
}
})
.collect()
};
let restrictions = RestrictionsRules { restrictions };
Ok(restrictions)
}
}

141
src/restrictions/types.rs Normal file
View file

@ -0,0 +1,141 @@
use crate::LocalProtocol;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use regex::Regex;
use serde::{Deserialize, Deserializer};
use std::ops::RangeInclusive;
#[derive(Debug, Clone, Deserialize)]
pub struct RestrictionsRules {
pub restrictions: Vec<RestrictionConfig>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct RestrictionConfig {
pub name: String,
pub r#match: MatchConfig,
pub allow: Vec<AllowConfig>,
}
#[derive(Debug, Clone, Deserialize)]
pub enum MatchConfig {
Any,
#[serde(with = "serde_regex")]
PathPrefix(Regex),
}
#[derive(Debug, Clone, Deserialize)]
pub enum AllowConfig {
ReverseTunnel(AllowReverseTunnelConfig),
Tunnel(AllowTunnelConfig),
}
#[derive(Debug, Clone, Deserialize)]
pub struct AllowTunnelConfig {
#[serde(default)]
pub protocol: Vec<TunnelConfigProtocol>,
#[serde(deserialize_with = "deserialize_port_range")]
#[serde(default = "default_port")]
pub port: RangeInclusive<u16>,
#[serde(with = "serde_regex")]
#[serde(default = "default_host")]
pub host: Regex,
#[serde(default = "default_cidr")]
pub cidr: Vec<IpNet>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct AllowReverseTunnelConfig {
#[serde(default)]
pub protocol: Vec<ReverseTunnelConfigProtocol>,
#[serde(deserialize_with = "deserialize_port_range")]
#[serde(default = "default_port")]
pub port: RangeInclusive<u16>,
#[serde(default = "default_cidr")]
pub cidr: Vec<IpNet>,
}
#[derive(Debug, Clone, Deserialize, Eq, PartialEq)]
pub enum TunnelConfigProtocol {
Tcp,
Udp,
Unknown,
}
#[derive(Debug, Clone, Deserialize, Eq, PartialEq)]
pub enum ReverseTunnelConfigProtocol {
Tcp,
Udp,
Socks5,
Unix,
Unknown,
}
pub fn default_port() -> RangeInclusive<u16> {
RangeInclusive::new(1, 65535)
}
pub fn default_host() -> Regex {
Regex::new("^.*$").unwrap()
}
pub fn default_cidr() -> Vec<IpNet> {
vec![IpNet::V4(Ipv4Net::default()), IpNet::V6(Ipv6Net::default())]
}
fn deserialize_port_range<'de, D>(deserializer: D) -> Result<RangeInclusive<u16>, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let range = if let Some((l, r)) = s.split_once("..") {
RangeInclusive::new(
l.parse().map_err(serde::de::Error::custom)?,
r.parse().map_err(serde::de::Error::custom)?,
)
} else {
let port = s.parse::<u16>().map_err(serde::de::Error::custom)?;
RangeInclusive::new(port, port)
};
Ok(range)
}
impl From<&LocalProtocol> for ReverseTunnelConfigProtocol {
fn from(value: &LocalProtocol) -> Self {
match value {
LocalProtocol::Tcp { .. }
| LocalProtocol::Udp { .. }
| LocalProtocol::Stdio
| LocalProtocol::Socks5 { .. }
| LocalProtocol::TProxyTcp { .. }
| LocalProtocol::TProxyUdp { .. }
| LocalProtocol::Unix { .. } => ReverseTunnelConfigProtocol::Unknown,
LocalProtocol::ReverseTcp => ReverseTunnelConfigProtocol::Tcp,
LocalProtocol::ReverseUdp { .. } => ReverseTunnelConfigProtocol::Udp,
LocalProtocol::ReverseSocks5 => ReverseTunnelConfigProtocol::Socks5,
LocalProtocol::ReverseUnix { .. } => ReverseTunnelConfigProtocol::Unix,
}
}
}
impl From<&LocalProtocol> for TunnelConfigProtocol {
fn from(value: &LocalProtocol) -> Self {
match value {
LocalProtocol::ReverseTcp
| LocalProtocol::ReverseUdp { .. }
| LocalProtocol::ReverseSocks5
| LocalProtocol::ReverseUnix { .. }
| LocalProtocol::Stdio
| LocalProtocol::Socks5 { .. }
| LocalProtocol::TProxyTcp { .. }
| LocalProtocol::TProxyUdp { .. }
| LocalProtocol::Unix { .. } => TunnelConfigProtocol::Unknown,
LocalProtocol::Tcp { .. } => TunnelConfigProtocol::Tcp,
LocalProtocol::Udp { .. } => TunnelConfigProtocol::Udp,
}
}
}

View file

@ -8,7 +8,7 @@ use std::cmp::min;
use std::fmt::Debug;
use std::future::Future;
use std::net::{IpAddr, SocketAddr};
use std::ops::{Deref, Not};
use std::ops::Deref;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
@ -26,6 +26,9 @@ use jsonwebtoken::TokenData;
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use crate::restrictions::types::{
AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol,
};
use crate::socks5::Socks5Stream;
use crate::tunnel::tls_reloader::TlsReloader;
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
@ -43,12 +46,11 @@ use uuid::Uuid;
async fn run_tunnel(
server_config: &WsServerConfig,
jwt: TokenData<JwtTunnelConfig>,
remote: RemoteAddr,
client_address: SocketAddr,
) -> anyhow::Result<(RemoteAddr, Pin<Box<dyn AsyncRead + Send>>, Pin<Box<dyn AsyncWrite + Send>>)> {
match jwt.claims.p {
match remote.protocol {
LocalProtocol::Udp { timeout, .. } => {
let remote = RemoteAddr::try_from(jwt.claims)?;
let cnx = udp::connect(
&remote.host,
remote.port,
@ -60,7 +62,6 @@ async fn run_tunnel(
Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx)))
}
LocalProtocol::Tcp { proxy_protocol } => {
let remote = RemoteAddr::try_from(jwt.claims)?;
let mut socket = tcp::connect(
&remote.host,
remote.port,
@ -89,14 +90,14 @@ async fn run_tunnel(
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<TcpStream>>>> =
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
let local_srv = (remote.host, remote.port);
let bind = format!("{}:{}", local_srv.0, local_srv.1);
let listening_server = tcp::run_server(bind.parse()?, false);
let tcp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
let (local_rx, local_tx) = tcp.into_split();
let remote = RemoteAddr {
protocol: jwt.claims.p,
protocol: remote.protocol,
host: local_srv.0,
port: local_srv.1,
};
@ -107,7 +108,7 @@ async fn run_tunnel(
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<UdpStream>>>> =
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
let local_srv = (remote.host, remote.port);
let bind = format!("{}:{}", local_srv.0, local_srv.1);
let listening_server =
udp::run_server(bind.parse()?, timeout, |_| Ok(()), |send_socket| Ok(send_socket.clone()));
@ -115,7 +116,7 @@ async fn run_tunnel(
let (local_rx, local_tx) = tokio::io::split(udp);
let remote = RemoteAddr {
protocol: jwt.claims.p,
protocol: remote.protocol,
host: local_srv.0,
port: local_srv.1,
};
@ -126,7 +127,7 @@ async fn run_tunnel(
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<(Socks5Stream, (Host, u16))>>>> =
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
let local_srv = (remote.host, remote.port);
let bind = format!("{}:{}", local_srv.0, local_srv.1);
let listening_server = socks5::run_server(bind.parse()?, None);
let (stream, local_srv) = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
@ -149,13 +150,13 @@ async fn run_tunnel(
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<UnixStream>>>> =
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
let local_srv = (remote.host, remote.port);
let listening_server = unix_socket::run_server(path);
let stream = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
let (local_rx, local_tx) = stream.into_split();
let remote = RemoteAddr {
protocol: jwt.claims.p.clone(),
protocol: remote.protocol,
host: local_srv.0,
port: local_srv.1,
};
@ -163,7 +164,7 @@ async fn run_tunnel(
}
#[cfg(not(unix))]
LocalProtocol::ReverseUnix { ref path } => {
error!("Received an unsupported target protocol {:?}", jwt.claims);
error!("Received an unsupported target protocol {:?}", remote);
Err(anyhow::anyhow!("Invalid upgrade request"))
}
LocalProtocol::Stdio
@ -171,7 +172,7 @@ async fn run_tunnel(
| LocalProtocol::TProxyTcp
| LocalProtocol::TProxyUdp { .. }
| LocalProtocol::Unix { .. } => {
error!("Received an unsupported target protocol {:?}", jwt.claims);
error!("Received an unsupported target protocol {:?}", remote);
Err(anyhow::anyhow!("Invalid upgrade request"))
}
}
@ -251,11 +252,26 @@ fn extract_x_forwarded_for(req: &Request<Incoming>) -> Result<Option<(IpAddr, &s
}
#[inline]
fn validate_url(
req: &Request<Incoming>,
path_restriction_prefix: &Option<Vec<String>>,
) -> Result<(), Response<String>> {
if !req.uri().path().ends_with("/events") {
fn extract_path_prefix(req: &Request<Incoming>) -> Result<&str, Response<String>> {
let path = req.uri().path();
let min_len = min(path.len(), 1);
if &path[0..min_len] != "/" {
warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri());
return Err(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body("Invalid upgrade request".to_string())
.unwrap());
}
let Some((l, r)) = path[min_len..].split_once('/') else {
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
return Err(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body("Invalid upgrade request".into())
.unwrap());
};
if !r.ends_with("events") {
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
return Err(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
@ -263,26 +279,7 @@ fn validate_url(
.unwrap());
}
if let Some(paths_prefix) = &path_restriction_prefix {
let path = req.uri().path();
let min_len = min(path.len(), 1);
let mut max_len = 0;
if &path[0..min_len] != "/"
|| !paths_prefix.iter().any(|p| {
max_len = min(path.len(), p.len() + 1);
p == &path[min_len..max_len]
})
|| !path[max_len..].starts_with('/')
{
warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri());
return Err(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body("Invalid upgrade request".to_string())
.unwrap());
}
}
Ok(())
Ok(l)
}
#[inline]
@ -316,25 +313,102 @@ fn extract_tunnel_info(req: &Request<Incoming>) -> Result<TokenData<JwtTunnelCon
}
#[inline]
fn validate_destination(
fn validate_tunnel<'a>(
_req: &Request<Incoming>,
jwt: &TokenData<JwtTunnelConfig>,
destination_restriction: &Option<Vec<String>>,
) -> Result<(), Response<String>> {
let Some(allowed_dests) = &destination_restriction else {
return Ok(());
};
remote: &RemoteAddr,
path_prefix: &str,
restrictions: &'a RestrictionsRules,
) -> Result<&'a RestrictionConfig, Response<String>> {
for restriction in &restrictions.restrictions {
match &restriction.r#match {
MatchConfig::Any => {}
MatchConfig::PathPrefix(path) => {
if !path.is_match(path_prefix) {
continue;
}
}
}
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
if allowed_dests.iter().any(|dest| dest == &requested_dest).not() {
warn!("Rejecting connection with not allowed destination: {}", requested_dest);
return Err(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body("Invalid upgrade request".to_string())
.unwrap());
for allow in &restriction.allow {
match allow {
AllowConfig::ReverseTunnel(allow) => {
if !remote.protocol.is_reverse_tunnel() || !allow.port.contains(&remote.port) {
continue;
}
if !allow.protocol.is_empty()
&& !allow
.protocol
.contains(&ReverseTunnelConfigProtocol::from(&remote.protocol))
{
continue;
}
match &remote.host {
Host::Domain(_) => {}
Host::Ipv4(ip) => {
let ip = IpAddr::V4(*ip);
for cidr in &allow.cidr {
if cidr.contains(&ip) {
return Ok(restriction);
}
}
}
Host::Ipv6(ip) => {
let ip = IpAddr::V6(*ip);
for cidr in &allow.cidr {
if cidr.contains(&ip) {
return Ok(restriction);
}
}
}
}
}
AllowConfig::Tunnel(allow) => {
if remote.protocol.is_reverse_tunnel() || !allow.port.contains(&remote.port) {
continue;
}
if !allow.protocol.is_empty()
&& !allow.protocol.contains(&TunnelConfigProtocol::from(&remote.protocol))
{
continue;
}
match &remote.host {
Host::Domain(host) => {
if allow.host.is_match(host) {
return Ok(restriction);
}
}
Host::Ipv4(ip) => {
let ip = IpAddr::V4(*ip);
for cidr in &allow.cidr {
if cidr.contains(&ip) {
return Ok(restriction);
}
}
}
Host::Ipv6(ip) => {
let ip = IpAddr::V6(*ip);
for cidr in &allow.cidr {
if cidr.contains(&ip) {
return Ok(restriction);
}
}
}
}
}
}
}
}
Ok(())
warn!("Rejecting connection with not allowed destination: {:?}", remote);
Err(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body("Invalid upgrade request".to_string())
.unwrap())
}
async fn ws_server_upgrade(
@ -360,9 +434,10 @@ async fn ws_server_upgrade(
Err(err) => return err,
};
if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) {
return err;
}
let path_prefix = match extract_path_prefix(&req) {
Ok(p) => p,
Err(err) => return err,
};
let jwt = match extract_tunnel_info(&req) {
Ok(jwt) => jwt,
@ -372,12 +447,26 @@ async fn ws_server_upgrade(
Span::current().record("id", &jwt.claims.id);
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
if let Err(err) = validate_destination(&req, &jwt, &server_config.restrict_to) {
return err;
let remote = match RemoteAddr::try_from(jwt.claims) {
Ok(remote) => remote,
Err(err) => {
warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri());
return http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body("Invalid upgrade request".to_string())
.unwrap();
}
};
match validate_tunnel(&req, &remote, path_prefix, &server_config.restrictions) {
Ok(matched_restriction) => {
info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name);
}
Err(err) => return err,
}
let req_protocol = jwt.claims.p.clone();
let tunnel = match run_tunnel(&server_config, jwt, client_addr).await {
let req_protocol = remote.protocol.clone();
let tunnel = match run_tunnel(&server_config, remote, client_addr).await {
Ok(ret) => ret,
Err(err) => {
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
@ -461,9 +550,10 @@ async fn http_server_upgrade(
Err(err) => return err.map(Either::Left),
};
if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) {
return err.map(Either::Left);
}
let path_prefix = match extract_path_prefix(&req) {
Ok(p) => p,
Err(err) => return err.map(Either::Left),
};
let jwt = match extract_tunnel_info(&req) {
Ok(jwt) => jwt,
@ -472,13 +562,26 @@ async fn http_server_upgrade(
Span::current().record("id", &jwt.claims.id);
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
let remote = match RemoteAddr::try_from(jwt.claims) {
Ok(remote) => remote,
Err(err) => {
warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri());
return http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Either::Left("Invalid upgrade request".to_string()))
.unwrap();
}
};
if let Err(err) = validate_destination(&req, &jwt, &server_config.restrict_to) {
return err.map(Either::Left);
match validate_tunnel(&req, &remote, path_prefix, &server_config.restrictions) {
Ok(matched_restriction) => {
info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name);
}
Err(err) => return err.map(Either::Left),
}
let req_protocol = jwt.claims.p.clone();
let tunnel = match run_tunnel(&server_config, jwt, client_addr).await {
let req_protocol = remote.protocol.clone();
let tunnel = match run_tunnel(&server_config, remote, client_addr).await {
Ok(ret) => ret,
Err(err) => {
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());