Add config file for restrictions

This commit is contained in:
Σrebe - Romain GERARD 2024-04-27 22:40:32 +02:00
parent 727e92902c
commit 8a228248d7
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
7 changed files with 559 additions and 75 deletions

79
src/restrictions/mod.rs Normal file
View file

@ -0,0 +1,79 @@
use crate::restrictions::types::{default_cidr, default_host, default_port};
use regex::Regex;
use std::fs::File;
use std::io::BufReader;
use std::ops::RangeInclusive;
use std::path::Path;
use types::RestrictionsRules;
pub mod types;
impl RestrictionsRules {
pub fn from_config_file(config_path: &Path) -> anyhow::Result<RestrictionsRules> {
let restrictions: RestrictionsRules = serde_yaml::from_reader(BufReader::new(File::open(config_path)?))?;
Ok(restrictions)
}
pub fn from_path_prefix(
path_prefixes: &[String],
restrict_to: &[(String, u16)],
) -> anyhow::Result<RestrictionsRules> {
let mut tunnels_restrictions = if restrict_to.is_empty() {
let r = types::AllowConfig::Tunnel(types::AllowTunnelConfig {
protocol: vec![],
port: default_port(),
host: default_host(),
cidr: default_cidr(),
});
vec![r]
} else {
restrict_to
.iter()
.map(|(host, port)| {
// Fixme: Remove the unwrap
let reg = Regex::new(&format!("^{}$", regex::escape(host))).unwrap();
types::AllowConfig::Tunnel(types::AllowTunnelConfig {
protocol: vec![],
port: RangeInclusive::new(*port, *port),
host: reg,
cidr: default_cidr(),
})
})
.collect()
};
tunnels_restrictions.push(types::AllowConfig::ReverseTunnel(types::AllowReverseTunnelConfig {
protocol: vec![],
port: default_port(),
cidr: default_cidr(),
}));
let restrictions = if path_prefixes.is_empty() {
// if no path prefixes are provided, we allow all
let reg = Regex::new(".").unwrap();
let r = types::RestrictionConfig {
name: "Allow All".to_string(),
r#match: types::MatchConfig::PathPrefix(reg),
allow: tunnels_restrictions,
};
vec![r]
} else {
path_prefixes
.iter()
.map(|path_prefix| {
// Fixme: Remove the unwrap
let reg = Regex::new(&format!("^{}$", regex::escape(path_prefix))).unwrap();
types::RestrictionConfig {
name: format!("Allow path prefix {}", path_prefix),
r#match: types::MatchConfig::PathPrefix(reg),
allow: tunnels_restrictions.clone(),
}
})
.collect()
};
let restrictions = RestrictionsRules { restrictions };
Ok(restrictions)
}
}

141
src/restrictions/types.rs Normal file
View file

@ -0,0 +1,141 @@
use crate::LocalProtocol;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use regex::Regex;
use serde::{Deserialize, Deserializer};
use std::ops::RangeInclusive;
#[derive(Debug, Clone, Deserialize)]
pub struct RestrictionsRules {
pub restrictions: Vec<RestrictionConfig>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct RestrictionConfig {
pub name: String,
pub r#match: MatchConfig,
pub allow: Vec<AllowConfig>,
}
#[derive(Debug, Clone, Deserialize)]
pub enum MatchConfig {
Any,
#[serde(with = "serde_regex")]
PathPrefix(Regex),
}
#[derive(Debug, Clone, Deserialize)]
pub enum AllowConfig {
ReverseTunnel(AllowReverseTunnelConfig),
Tunnel(AllowTunnelConfig),
}
#[derive(Debug, Clone, Deserialize)]
pub struct AllowTunnelConfig {
#[serde(default)]
pub protocol: Vec<TunnelConfigProtocol>,
#[serde(deserialize_with = "deserialize_port_range")]
#[serde(default = "default_port")]
pub port: RangeInclusive<u16>,
#[serde(with = "serde_regex")]
#[serde(default = "default_host")]
pub host: Regex,
#[serde(default = "default_cidr")]
pub cidr: Vec<IpNet>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct AllowReverseTunnelConfig {
#[serde(default)]
pub protocol: Vec<ReverseTunnelConfigProtocol>,
#[serde(deserialize_with = "deserialize_port_range")]
#[serde(default = "default_port")]
pub port: RangeInclusive<u16>,
#[serde(default = "default_cidr")]
pub cidr: Vec<IpNet>,
}
#[derive(Debug, Clone, Deserialize, Eq, PartialEq)]
pub enum TunnelConfigProtocol {
Tcp,
Udp,
Unknown,
}
#[derive(Debug, Clone, Deserialize, Eq, PartialEq)]
pub enum ReverseTunnelConfigProtocol {
Tcp,
Udp,
Socks5,
Unix,
Unknown,
}
pub fn default_port() -> RangeInclusive<u16> {
RangeInclusive::new(1, 65535)
}
pub fn default_host() -> Regex {
Regex::new("^.*$").unwrap()
}
pub fn default_cidr() -> Vec<IpNet> {
vec![IpNet::V4(Ipv4Net::default()), IpNet::V6(Ipv6Net::default())]
}
fn deserialize_port_range<'de, D>(deserializer: D) -> Result<RangeInclusive<u16>, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let range = if let Some((l, r)) = s.split_once("..") {
RangeInclusive::new(
l.parse().map_err(serde::de::Error::custom)?,
r.parse().map_err(serde::de::Error::custom)?,
)
} else {
let port = s.parse::<u16>().map_err(serde::de::Error::custom)?;
RangeInclusive::new(port, port)
};
Ok(range)
}
impl From<&LocalProtocol> for ReverseTunnelConfigProtocol {
fn from(value: &LocalProtocol) -> Self {
match value {
LocalProtocol::Tcp { .. }
| LocalProtocol::Udp { .. }
| LocalProtocol::Stdio
| LocalProtocol::Socks5 { .. }
| LocalProtocol::TProxyTcp { .. }
| LocalProtocol::TProxyUdp { .. }
| LocalProtocol::Unix { .. } => ReverseTunnelConfigProtocol::Unknown,
LocalProtocol::ReverseTcp => ReverseTunnelConfigProtocol::Tcp,
LocalProtocol::ReverseUdp { .. } => ReverseTunnelConfigProtocol::Udp,
LocalProtocol::ReverseSocks5 => ReverseTunnelConfigProtocol::Socks5,
LocalProtocol::ReverseUnix { .. } => ReverseTunnelConfigProtocol::Unix,
}
}
}
impl From<&LocalProtocol> for TunnelConfigProtocol {
fn from(value: &LocalProtocol) -> Self {
match value {
LocalProtocol::ReverseTcp
| LocalProtocol::ReverseUdp { .. }
| LocalProtocol::ReverseSocks5
| LocalProtocol::ReverseUnix { .. }
| LocalProtocol::Stdio
| LocalProtocol::Socks5 { .. }
| LocalProtocol::TProxyTcp { .. }
| LocalProtocol::TProxyUdp { .. }
| LocalProtocol::Unix { .. } => TunnelConfigProtocol::Unknown,
LocalProtocol::Tcp { .. } => TunnelConfigProtocol::Tcp,
LocalProtocol::Udp { .. } => TunnelConfigProtocol::Udp,
}
}
}