Add config file for restrictions
This commit is contained in:
parent
727e92902c
commit
8a228248d7
7 changed files with 559 additions and 75 deletions
36
Cargo.lock
generated
36
Cargo.lock
generated
|
@ -913,6 +913,9 @@ name = "ipnet"
|
|||
version = "2.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
|
@ -1524,6 +1527,16 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_regex"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8136f1a4ea815d7eac4101cfd0b16dc0cb5e1fe1b8609dfd728058656b7badf"
|
||||
dependencies = [
|
||||
"regex",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_with"
|
||||
version = "1.14.0"
|
||||
|
@ -1546,6 +1559,19 @@ dependencies = [
|
|||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_yaml"
|
||||
version = "0.9.34+deprecated"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
"unsafe-libyaml",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1"
|
||||
version = "0.10.6"
|
||||
|
@ -1950,6 +1976,12 @@ dependencies = [
|
|||
"tinyvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unsafe-libyaml"
|
||||
version = "0.2.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.9.0"
|
||||
|
@ -2285,6 +2317,7 @@ dependencies = [
|
|||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"ipnet",
|
||||
"jsonwebtoken",
|
||||
"log",
|
||||
"nix",
|
||||
|
@ -2293,10 +2326,13 @@ dependencies = [
|
|||
"parking_lot",
|
||||
"pin-project",
|
||||
"ppp",
|
||||
"regex",
|
||||
"rustls-native-certs",
|
||||
"rustls-pemfile 2.1.1",
|
||||
"scopeguard",
|
||||
"serde",
|
||||
"serde_regex",
|
||||
"serde_yaml",
|
||||
"socket2",
|
||||
"testcontainers",
|
||||
"tokio",
|
||||
|
|
|
@ -18,7 +18,13 @@ fast-socks5 = { version = "0.9.6", features = [] }
|
|||
fastwebsockets = { version = "0.7.1", features = ["upgrade", "simd", "unstable-split"] }
|
||||
futures-util = { version = "0.3.30" }
|
||||
hickory-resolver = { version = "0.24.0", features = ["tokio", "dns-over-https-rustls", "dns-over-rustls"] }
|
||||
ppp = { version = "2.2.0", features = [] }
|
||||
ppp = { version = "2.2.0", features = [] }
|
||||
|
||||
# For config file parsing
|
||||
regex = { version = "1.10.4", default-features = false, features = ["std", "perf"] }
|
||||
serde_regex = "1.1.0"
|
||||
serde_yaml = { version = "0.9.34", features = [] }
|
||||
ipnet = { version = "2.9.0", features = ["serde"] }
|
||||
|
||||
hyper = { version = "1.2.0", features = ["client", "http1", "http2"] }
|
||||
hyper-util = { version = "0.1.3", features = ["tokio", "server", "server-auto"] }
|
||||
|
|
84
restrictions.yaml
Normal file
84
restrictions.yaml
Normal file
|
@ -0,0 +1,84 @@
|
|||
# Restrictions are whitelist rules for the tunnels
|
||||
# By default, all requests are denied and only if a restriction match, the request is allowed
|
||||
restrictions:
|
||||
- name: "Allow all"
|
||||
description: "This restriction allows all requests"
|
||||
# This restriction apply only if it matches the prefix that match the given regex
|
||||
# The regex does a match, so if you want to match exactly you need to bound the pattern with ^ $
|
||||
# I.e: "tesotron" is going to match "XXXtesotronXXX", but "^tesotron$" is going to match only "tesotron"
|
||||
match: !PathPrefix "^.*$"
|
||||
|
||||
# This is th list of tunnels your restriction is going to allow
|
||||
# The list is going to be checked in order, the first match is going to allow the request
|
||||
allow:
|
||||
# !Tunnel allows forward tunnels
|
||||
- !Tunnel
|
||||
# Protocol that are allowed. Empty list means all protocols are allowed
|
||||
protocol:
|
||||
- Tcp
|
||||
- Udp
|
||||
# Port that are allowed. Can be a single port or an inclusive range (i.e. 80..90)
|
||||
port: 9999
|
||||
|
||||
# if the tunnel wants to connect to a specific host, this regex must match
|
||||
host: ^.*$
|
||||
# if the tunnel wants to connect to a specific IP, it must match one of the network cidr
|
||||
cidr:
|
||||
- 0.0.0.0/0
|
||||
- ::/0
|
||||
|
||||
# !ReverseTunnel allows reverse tunnels
|
||||
# Not specifying anything means all reverse tunnels are allowed
|
||||
- !ReverseTunnel
|
||||
protocol:
|
||||
- Tcp
|
||||
- Udp
|
||||
- Socks5
|
||||
- Unix
|
||||
port: 1..65535
|
||||
cidr:
|
||||
- 0.0.0.0/0
|
||||
- ::/0
|
||||
|
||||
---
|
||||
# Examples
|
||||
restrictions:
|
||||
- name: "example 1"
|
||||
description: "Only allow forward tunnels to port 443 and forbid reverse tunnels"
|
||||
match: !PathPrefix "^.*$"
|
||||
allow:
|
||||
- !Tunnel
|
||||
port: 443
|
||||
---
|
||||
restrictions:
|
||||
- name: "example 2"
|
||||
description: "Only allow forward tunnels to local ssh and forbid reverse tunnels"
|
||||
match: !PathPrefix "^.*$"
|
||||
allow:
|
||||
- !Tunnel
|
||||
protocol:
|
||||
- Tcp
|
||||
port: 22
|
||||
host: ^localhost$
|
||||
cidr:
|
||||
- 127.0.0.1/32
|
||||
---
|
||||
restrictions:
|
||||
- name: "example 3"
|
||||
description: "Only allow socks5 reverse tunnels listening on port between 1080..1443 on lan network"
|
||||
match: !PathPrefix "^.*$"
|
||||
allow:
|
||||
- !ReverseTunnel
|
||||
protocol:
|
||||
- Socks5
|
||||
port: 1080..1443
|
||||
cidr:
|
||||
- 192.168.0.0/16
|
||||
---
|
||||
restrictions:
|
||||
- name: "example 4"
|
||||
description: "Allow everything for client using path prefix my-super-secret-path"
|
||||
match: !PathPrefix "^my-super-secret-path$"
|
||||
allow:
|
||||
- !Tunnel
|
||||
- !ReverseTunnel
|
47
src/main.rs
47
src/main.rs
|
@ -1,5 +1,6 @@
|
|||
mod dns;
|
||||
mod embedded_certificate;
|
||||
mod restrictions;
|
||||
mod socks5;
|
||||
mod socks5_udp;
|
||||
mod stdio;
|
||||
|
@ -40,6 +41,7 @@ use tokio_rustls::TlsConnector;
|
|||
use tracing::{error, info};
|
||||
|
||||
use crate::dns::DnsResolver;
|
||||
use crate::restrictions::types::RestrictionsRules;
|
||||
use crate::tunnel::tls_reloader::TlsReloader;
|
||||
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme};
|
||||
use crate::udp::MyUdpSocket;
|
||||
|
@ -287,6 +289,11 @@ struct Server {
|
|||
)]
|
||||
restrict_http_upgrade_path_prefix: Option<Vec<String>>,
|
||||
|
||||
/// Path to the location of the restriction yaml config file.
|
||||
/// Restriction file is automatically reloaded if it changes
|
||||
#[arg(long, verbatim_doc_comment)]
|
||||
restriction_file: Option<PathBuf>,
|
||||
|
||||
/// [Optional] Use custom certificate (pem) instead of the default embedded self-signed certificate.
|
||||
/// The certificate will be automatically reloaded if it changes
|
||||
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
|
||||
|
@ -319,6 +326,15 @@ enum LocalProtocol {
|
|||
Unix { path: PathBuf },
|
||||
}
|
||||
|
||||
impl LocalProtocol {
|
||||
pub fn is_reverse_tunnel(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
LocalProtocol::ReverseTcp | LocalProtocol::ReverseUdp { .. } | LocalProtocol::ReverseSocks5
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct LocalToRemote {
|
||||
local_protocol: LocalProtocol,
|
||||
|
@ -607,13 +623,12 @@ pub struct TlsServerConfig {
|
|||
pub struct WsServerConfig {
|
||||
pub socket_so_mark: Option<u32>,
|
||||
pub bind: SocketAddr,
|
||||
pub restrict_to: Option<Vec<String>>,
|
||||
pub restrict_http_upgrade_path_prefix: Option<Vec<String>>,
|
||||
pub websocket_ping_frequency: Option<Duration>,
|
||||
pub timeout_connect: Duration,
|
||||
pub websocket_mask_frame: bool,
|
||||
pub tls: Option<TlsServerConfig>,
|
||||
pub dns_resolver: DnsResolver,
|
||||
pub restrictions: RestrictionsRules,
|
||||
}
|
||||
|
||||
impl Debug for WsServerConfig {
|
||||
|
@ -621,8 +636,6 @@ impl Debug for WsServerConfig {
|
|||
f.debug_struct("WsServerConfig")
|
||||
.field("socket_so_mark", &self.socket_so_mark)
|
||||
.field("bind", &self.bind)
|
||||
.field("restrict_to", &self.restrict_to)
|
||||
.field("restrict_http_upgrade_path_prefix", &self.restrict_http_upgrade_path_prefix)
|
||||
.field("websocket_ping_frequency", &self.websocket_ping_frequency)
|
||||
.field("timeout_connect", &self.timeout_connect)
|
||||
.field("websocket_mask_frame", &self.websocket_mask_frame)
|
||||
|
@ -1246,16 +1259,38 @@ async fn main() {
|
|||
}
|
||||
}
|
||||
};
|
||||
|
||||
let restrictions = if let Some(path) = &args.restriction_file {
|
||||
RestrictionsRules::from_config_file(path).expect("Cannot parse restriction file")
|
||||
} else {
|
||||
let restrict_to: Vec<(String, u16)> = args
|
||||
.restrict_to
|
||||
.as_deref()
|
||||
.unwrap_or(&[])
|
||||
.iter()
|
||||
.map(|x| {
|
||||
let (host, port) = x.rsplit_once(':').expect("Invalid restrict-to format");
|
||||
(host.to_string(), port.parse::<u16>().expect("Invalid restrict-to port format"))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let restriction_cfg = RestrictionsRules::from_path_prefix(
|
||||
args.restrict_http_upgrade_path_prefix.as_deref().unwrap_or(&[]),
|
||||
&restrict_to,
|
||||
)
|
||||
.expect("Cannot covertion restriction rules from path-prefix and restric-to");
|
||||
restriction_cfg
|
||||
};
|
||||
|
||||
let server_config = WsServerConfig {
|
||||
socket_so_mark: args.socket_so_mark,
|
||||
bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0],
|
||||
restrict_to: args.restrict_to,
|
||||
restrict_http_upgrade_path_prefix: args.restrict_http_upgrade_path_prefix,
|
||||
websocket_ping_frequency: args.websocket_ping_frequency_sec,
|
||||
timeout_connect: Duration::from_secs(10),
|
||||
websocket_mask_frame: args.websocket_mask_frame,
|
||||
tls: tls_config,
|
||||
dns_resolver,
|
||||
restrictions,
|
||||
};
|
||||
|
||||
info!(
|
||||
|
|
79
src/restrictions/mod.rs
Normal file
79
src/restrictions/mod.rs
Normal file
|
@ -0,0 +1,79 @@
|
|||
use crate::restrictions::types::{default_cidr, default_host, default_port};
|
||||
use regex::Regex;
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::ops::RangeInclusive;
|
||||
use std::path::Path;
|
||||
use types::RestrictionsRules;
|
||||
|
||||
pub mod types;
|
||||
|
||||
impl RestrictionsRules {
|
||||
pub fn from_config_file(config_path: &Path) -> anyhow::Result<RestrictionsRules> {
|
||||
let restrictions: RestrictionsRules = serde_yaml::from_reader(BufReader::new(File::open(config_path)?))?;
|
||||
Ok(restrictions)
|
||||
}
|
||||
|
||||
pub fn from_path_prefix(
|
||||
path_prefixes: &[String],
|
||||
restrict_to: &[(String, u16)],
|
||||
) -> anyhow::Result<RestrictionsRules> {
|
||||
let mut tunnels_restrictions = if restrict_to.is_empty() {
|
||||
let r = types::AllowConfig::Tunnel(types::AllowTunnelConfig {
|
||||
protocol: vec![],
|
||||
port: default_port(),
|
||||
host: default_host(),
|
||||
cidr: default_cidr(),
|
||||
});
|
||||
vec![r]
|
||||
} else {
|
||||
restrict_to
|
||||
.iter()
|
||||
.map(|(host, port)| {
|
||||
// Fixme: Remove the unwrap
|
||||
let reg = Regex::new(&format!("^{}$", regex::escape(host))).unwrap();
|
||||
types::AllowConfig::Tunnel(types::AllowTunnelConfig {
|
||||
protocol: vec![],
|
||||
port: RangeInclusive::new(*port, *port),
|
||||
host: reg,
|
||||
cidr: default_cidr(),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
|
||||
tunnels_restrictions.push(types::AllowConfig::ReverseTunnel(types::AllowReverseTunnelConfig {
|
||||
protocol: vec![],
|
||||
port: default_port(),
|
||||
cidr: default_cidr(),
|
||||
}));
|
||||
|
||||
let restrictions = if path_prefixes.is_empty() {
|
||||
// if no path prefixes are provided, we allow all
|
||||
let reg = Regex::new(".").unwrap();
|
||||
let r = types::RestrictionConfig {
|
||||
name: "Allow All".to_string(),
|
||||
r#match: types::MatchConfig::PathPrefix(reg),
|
||||
allow: tunnels_restrictions,
|
||||
};
|
||||
vec![r]
|
||||
} else {
|
||||
path_prefixes
|
||||
.iter()
|
||||
.map(|path_prefix| {
|
||||
// Fixme: Remove the unwrap
|
||||
let reg = Regex::new(&format!("^{}$", regex::escape(path_prefix))).unwrap();
|
||||
types::RestrictionConfig {
|
||||
name: format!("Allow path prefix {}", path_prefix),
|
||||
r#match: types::MatchConfig::PathPrefix(reg),
|
||||
allow: tunnels_restrictions.clone(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
|
||||
let restrictions = RestrictionsRules { restrictions };
|
||||
|
||||
Ok(restrictions)
|
||||
}
|
||||
}
|
141
src/restrictions/types.rs
Normal file
141
src/restrictions/types.rs
Normal file
|
@ -0,0 +1,141 @@
|
|||
use crate::LocalProtocol;
|
||||
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use std::ops::RangeInclusive;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct RestrictionsRules {
|
||||
pub restrictions: Vec<RestrictionConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct RestrictionConfig {
|
||||
pub name: String,
|
||||
pub r#match: MatchConfig,
|
||||
pub allow: Vec<AllowConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub enum MatchConfig {
|
||||
Any,
|
||||
#[serde(with = "serde_regex")]
|
||||
PathPrefix(Regex),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub enum AllowConfig {
|
||||
ReverseTunnel(AllowReverseTunnelConfig),
|
||||
Tunnel(AllowTunnelConfig),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct AllowTunnelConfig {
|
||||
#[serde(default)]
|
||||
pub protocol: Vec<TunnelConfigProtocol>,
|
||||
|
||||
#[serde(deserialize_with = "deserialize_port_range")]
|
||||
#[serde(default = "default_port")]
|
||||
pub port: RangeInclusive<u16>,
|
||||
|
||||
#[serde(with = "serde_regex")]
|
||||
#[serde(default = "default_host")]
|
||||
pub host: Regex,
|
||||
|
||||
#[serde(default = "default_cidr")]
|
||||
pub cidr: Vec<IpNet>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct AllowReverseTunnelConfig {
|
||||
#[serde(default)]
|
||||
pub protocol: Vec<ReverseTunnelConfigProtocol>,
|
||||
|
||||
#[serde(deserialize_with = "deserialize_port_range")]
|
||||
#[serde(default = "default_port")]
|
||||
pub port: RangeInclusive<u16>,
|
||||
|
||||
#[serde(default = "default_cidr")]
|
||||
pub cidr: Vec<IpNet>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Eq, PartialEq)]
|
||||
pub enum TunnelConfigProtocol {
|
||||
Tcp,
|
||||
Udp,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Eq, PartialEq)]
|
||||
pub enum ReverseTunnelConfigProtocol {
|
||||
Tcp,
|
||||
Udp,
|
||||
Socks5,
|
||||
Unix,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
pub fn default_port() -> RangeInclusive<u16> {
|
||||
RangeInclusive::new(1, 65535)
|
||||
}
|
||||
|
||||
pub fn default_host() -> Regex {
|
||||
Regex::new("^.*$").unwrap()
|
||||
}
|
||||
|
||||
pub fn default_cidr() -> Vec<IpNet> {
|
||||
vec![IpNet::V4(Ipv4Net::default()), IpNet::V6(Ipv6Net::default())]
|
||||
}
|
||||
|
||||
fn deserialize_port_range<'de, D>(deserializer: D) -> Result<RangeInclusive<u16>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
let range = if let Some((l, r)) = s.split_once("..") {
|
||||
RangeInclusive::new(
|
||||
l.parse().map_err(serde::de::Error::custom)?,
|
||||
r.parse().map_err(serde::de::Error::custom)?,
|
||||
)
|
||||
} else {
|
||||
let port = s.parse::<u16>().map_err(serde::de::Error::custom)?;
|
||||
RangeInclusive::new(port, port)
|
||||
};
|
||||
|
||||
Ok(range)
|
||||
}
|
||||
|
||||
impl From<&LocalProtocol> for ReverseTunnelConfigProtocol {
|
||||
fn from(value: &LocalProtocol) -> Self {
|
||||
match value {
|
||||
LocalProtocol::Tcp { .. }
|
||||
| LocalProtocol::Udp { .. }
|
||||
| LocalProtocol::Stdio
|
||||
| LocalProtocol::Socks5 { .. }
|
||||
| LocalProtocol::TProxyTcp { .. }
|
||||
| LocalProtocol::TProxyUdp { .. }
|
||||
| LocalProtocol::Unix { .. } => ReverseTunnelConfigProtocol::Unknown,
|
||||
LocalProtocol::ReverseTcp => ReverseTunnelConfigProtocol::Tcp,
|
||||
LocalProtocol::ReverseUdp { .. } => ReverseTunnelConfigProtocol::Udp,
|
||||
LocalProtocol::ReverseSocks5 => ReverseTunnelConfigProtocol::Socks5,
|
||||
LocalProtocol::ReverseUnix { .. } => ReverseTunnelConfigProtocol::Unix,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl From<&LocalProtocol> for TunnelConfigProtocol {
|
||||
fn from(value: &LocalProtocol) -> Self {
|
||||
match value {
|
||||
LocalProtocol::ReverseTcp
|
||||
| LocalProtocol::ReverseUdp { .. }
|
||||
| LocalProtocol::ReverseSocks5
|
||||
| LocalProtocol::ReverseUnix { .. }
|
||||
| LocalProtocol::Stdio
|
||||
| LocalProtocol::Socks5 { .. }
|
||||
| LocalProtocol::TProxyTcp { .. }
|
||||
| LocalProtocol::TProxyUdp { .. }
|
||||
| LocalProtocol::Unix { .. } => TunnelConfigProtocol::Unknown,
|
||||
LocalProtocol::Tcp { .. } => TunnelConfigProtocol::Tcp,
|
||||
LocalProtocol::Udp { .. } => TunnelConfigProtocol::Udp,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -8,7 +8,7 @@ use std::cmp::min;
|
|||
use std::fmt::Debug;
|
||||
use std::future::Future;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::ops::{Deref, Not};
|
||||
use std::ops::Deref;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
@ -26,6 +26,9 @@ use jsonwebtoken::TokenData;
|
|||
use once_cell::sync::Lazy;
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use crate::restrictions::types::{
|
||||
AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol,
|
||||
};
|
||||
use crate::socks5::Socks5Stream;
|
||||
use crate::tunnel::tls_reloader::TlsReloader;
|
||||
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
|
||||
|
@ -43,12 +46,11 @@ use uuid::Uuid;
|
|||
|
||||
async fn run_tunnel(
|
||||
server_config: &WsServerConfig,
|
||||
jwt: TokenData<JwtTunnelConfig>,
|
||||
remote: RemoteAddr,
|
||||
client_address: SocketAddr,
|
||||
) -> anyhow::Result<(RemoteAddr, Pin<Box<dyn AsyncRead + Send>>, Pin<Box<dyn AsyncWrite + Send>>)> {
|
||||
match jwt.claims.p {
|
||||
match remote.protocol {
|
||||
LocalProtocol::Udp { timeout, .. } => {
|
||||
let remote = RemoteAddr::try_from(jwt.claims)?;
|
||||
let cnx = udp::connect(
|
||||
&remote.host,
|
||||
remote.port,
|
||||
|
@ -60,7 +62,6 @@ async fn run_tunnel(
|
|||
Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx)))
|
||||
}
|
||||
LocalProtocol::Tcp { proxy_protocol } => {
|
||||
let remote = RemoteAddr::try_from(jwt.claims)?;
|
||||
let mut socket = tcp::connect(
|
||||
&remote.host,
|
||||
remote.port,
|
||||
|
@ -89,14 +90,14 @@ async fn run_tunnel(
|
|||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<TcpStream>>>> =
|
||||
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
||||
|
||||
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
|
||||
let local_srv = (remote.host, remote.port);
|
||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||
let listening_server = tcp::run_server(bind.parse()?, false);
|
||||
let tcp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||
let (local_rx, local_tx) = tcp.into_split();
|
||||
|
||||
let remote = RemoteAddr {
|
||||
protocol: jwt.claims.p,
|
||||
protocol: remote.protocol,
|
||||
host: local_srv.0,
|
||||
port: local_srv.1,
|
||||
};
|
||||
|
@ -107,7 +108,7 @@ async fn run_tunnel(
|
|||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<UdpStream>>>> =
|
||||
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
||||
|
||||
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
|
||||
let local_srv = (remote.host, remote.port);
|
||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||
let listening_server =
|
||||
udp::run_server(bind.parse()?, timeout, |_| Ok(()), |send_socket| Ok(send_socket.clone()));
|
||||
|
@ -115,7 +116,7 @@ async fn run_tunnel(
|
|||
let (local_rx, local_tx) = tokio::io::split(udp);
|
||||
|
||||
let remote = RemoteAddr {
|
||||
protocol: jwt.claims.p,
|
||||
protocol: remote.protocol,
|
||||
host: local_srv.0,
|
||||
port: local_srv.1,
|
||||
};
|
||||
|
@ -126,7 +127,7 @@ async fn run_tunnel(
|
|||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<(Socks5Stream, (Host, u16))>>>> =
|
||||
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
||||
|
||||
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
|
||||
let local_srv = (remote.host, remote.port);
|
||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||
let listening_server = socks5::run_server(bind.parse()?, None);
|
||||
let (stream, local_srv) = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||
|
@ -149,13 +150,13 @@ async fn run_tunnel(
|
|||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<UnixStream>>>> =
|
||||
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
||||
|
||||
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
|
||||
let local_srv = (remote.host, remote.port);
|
||||
let listening_server = unix_socket::run_server(path);
|
||||
let stream = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||
let (local_rx, local_tx) = stream.into_split();
|
||||
|
||||
let remote = RemoteAddr {
|
||||
protocol: jwt.claims.p.clone(),
|
||||
protocol: remote.protocol,
|
||||
host: local_srv.0,
|
||||
port: local_srv.1,
|
||||
};
|
||||
|
@ -163,7 +164,7 @@ async fn run_tunnel(
|
|||
}
|
||||
#[cfg(not(unix))]
|
||||
LocalProtocol::ReverseUnix { ref path } => {
|
||||
error!("Received an unsupported target protocol {:?}", jwt.claims);
|
||||
error!("Received an unsupported target protocol {:?}", remote);
|
||||
Err(anyhow::anyhow!("Invalid upgrade request"))
|
||||
}
|
||||
LocalProtocol::Stdio
|
||||
|
@ -171,7 +172,7 @@ async fn run_tunnel(
|
|||
| LocalProtocol::TProxyTcp
|
||||
| LocalProtocol::TProxyUdp { .. }
|
||||
| LocalProtocol::Unix { .. } => {
|
||||
error!("Received an unsupported target protocol {:?}", jwt.claims);
|
||||
error!("Received an unsupported target protocol {:?}", remote);
|
||||
Err(anyhow::anyhow!("Invalid upgrade request"))
|
||||
}
|
||||
}
|
||||
|
@ -251,11 +252,26 @@ fn extract_x_forwarded_for(req: &Request<Incoming>) -> Result<Option<(IpAddr, &s
|
|||
}
|
||||
|
||||
#[inline]
|
||||
fn validate_url(
|
||||
req: &Request<Incoming>,
|
||||
path_restriction_prefix: &Option<Vec<String>>,
|
||||
) -> Result<(), Response<String>> {
|
||||
if !req.uri().path().ends_with("/events") {
|
||||
fn extract_path_prefix(req: &Request<Incoming>) -> Result<&str, Response<String>> {
|
||||
let path = req.uri().path();
|
||||
let min_len = min(path.len(), 1);
|
||||
if &path[0..min_len] != "/" {
|
||||
warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri());
|
||||
return Err(http::Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body("Invalid upgrade request".to_string())
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
let Some((l, r)) = path[min_len..].split_once('/') else {
|
||||
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
|
||||
return Err(http::Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body("Invalid upgrade request".into())
|
||||
.unwrap());
|
||||
};
|
||||
|
||||
if !r.ends_with("events") {
|
||||
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
|
||||
return Err(http::Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
|
@ -263,26 +279,7 @@ fn validate_url(
|
|||
.unwrap());
|
||||
}
|
||||
|
||||
if let Some(paths_prefix) = &path_restriction_prefix {
|
||||
let path = req.uri().path();
|
||||
let min_len = min(path.len(), 1);
|
||||
let mut max_len = 0;
|
||||
if &path[0..min_len] != "/"
|
||||
|| !paths_prefix.iter().any(|p| {
|
||||
max_len = min(path.len(), p.len() + 1);
|
||||
p == &path[min_len..max_len]
|
||||
})
|
||||
|| !path[max_len..].starts_with('/')
|
||||
{
|
||||
warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri());
|
||||
return Err(http::Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body("Invalid upgrade request".to_string())
|
||||
.unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(l)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -316,25 +313,102 @@ fn extract_tunnel_info(req: &Request<Incoming>) -> Result<TokenData<JwtTunnelCon
|
|||
}
|
||||
|
||||
#[inline]
|
||||
fn validate_destination(
|
||||
fn validate_tunnel<'a>(
|
||||
_req: &Request<Incoming>,
|
||||
jwt: &TokenData<JwtTunnelConfig>,
|
||||
destination_restriction: &Option<Vec<String>>,
|
||||
) -> Result<(), Response<String>> {
|
||||
let Some(allowed_dests) = &destination_restriction else {
|
||||
return Ok(());
|
||||
};
|
||||
remote: &RemoteAddr,
|
||||
path_prefix: &str,
|
||||
restrictions: &'a RestrictionsRules,
|
||||
) -> Result<&'a RestrictionConfig, Response<String>> {
|
||||
for restriction in &restrictions.restrictions {
|
||||
match &restriction.r#match {
|
||||
MatchConfig::Any => {}
|
||||
MatchConfig::PathPrefix(path) => {
|
||||
if !path.is_match(path_prefix) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
|
||||
if allowed_dests.iter().any(|dest| dest == &requested_dest).not() {
|
||||
warn!("Rejecting connection with not allowed destination: {}", requested_dest);
|
||||
return Err(http::Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body("Invalid upgrade request".to_string())
|
||||
.unwrap());
|
||||
for allow in &restriction.allow {
|
||||
match allow {
|
||||
AllowConfig::ReverseTunnel(allow) => {
|
||||
if !remote.protocol.is_reverse_tunnel() || !allow.port.contains(&remote.port) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if !allow.protocol.is_empty()
|
||||
&& !allow
|
||||
.protocol
|
||||
.contains(&ReverseTunnelConfigProtocol::from(&remote.protocol))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
match &remote.host {
|
||||
Host::Domain(_) => {}
|
||||
Host::Ipv4(ip) => {
|
||||
let ip = IpAddr::V4(*ip);
|
||||
for cidr in &allow.cidr {
|
||||
if cidr.contains(&ip) {
|
||||
return Ok(restriction);
|
||||
}
|
||||
}
|
||||
}
|
||||
Host::Ipv6(ip) => {
|
||||
let ip = IpAddr::V6(*ip);
|
||||
for cidr in &allow.cidr {
|
||||
if cidr.contains(&ip) {
|
||||
return Ok(restriction);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AllowConfig::Tunnel(allow) => {
|
||||
if remote.protocol.is_reverse_tunnel() || !allow.port.contains(&remote.port) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if !allow.protocol.is_empty()
|
||||
&& !allow.protocol.contains(&TunnelConfigProtocol::from(&remote.protocol))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
match &remote.host {
|
||||
Host::Domain(host) => {
|
||||
if allow.host.is_match(host) {
|
||||
return Ok(restriction);
|
||||
}
|
||||
}
|
||||
Host::Ipv4(ip) => {
|
||||
let ip = IpAddr::V4(*ip);
|
||||
for cidr in &allow.cidr {
|
||||
if cidr.contains(&ip) {
|
||||
return Ok(restriction);
|
||||
}
|
||||
}
|
||||
}
|
||||
Host::Ipv6(ip) => {
|
||||
let ip = IpAddr::V6(*ip);
|
||||
for cidr in &allow.cidr {
|
||||
if cidr.contains(&ip) {
|
||||
return Ok(restriction);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
warn!("Rejecting connection with not allowed destination: {:?}", remote);
|
||||
Err(http::Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body("Invalid upgrade request".to_string())
|
||||
.unwrap())
|
||||
}
|
||||
|
||||
async fn ws_server_upgrade(
|
||||
|
@ -360,9 +434,10 @@ async fn ws_server_upgrade(
|
|||
Err(err) => return err,
|
||||
};
|
||||
|
||||
if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) {
|
||||
return err;
|
||||
}
|
||||
let path_prefix = match extract_path_prefix(&req) {
|
||||
Ok(p) => p,
|
||||
Err(err) => return err,
|
||||
};
|
||||
|
||||
let jwt = match extract_tunnel_info(&req) {
|
||||
Ok(jwt) => jwt,
|
||||
|
@ -372,12 +447,26 @@ async fn ws_server_upgrade(
|
|||
Span::current().record("id", &jwt.claims.id);
|
||||
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
|
||||
|
||||
if let Err(err) = validate_destination(&req, &jwt, &server_config.restrict_to) {
|
||||
return err;
|
||||
let remote = match RemoteAddr::try_from(jwt.claims) {
|
||||
Ok(remote) => remote,
|
||||
Err(err) => {
|
||||
warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri());
|
||||
return http::Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body("Invalid upgrade request".to_string())
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
match validate_tunnel(&req, &remote, path_prefix, &server_config.restrictions) {
|
||||
Ok(matched_restriction) => {
|
||||
info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name);
|
||||
}
|
||||
Err(err) => return err,
|
||||
}
|
||||
|
||||
let req_protocol = jwt.claims.p.clone();
|
||||
let tunnel = match run_tunnel(&server_config, jwt, client_addr).await {
|
||||
let req_protocol = remote.protocol.clone();
|
||||
let tunnel = match run_tunnel(&server_config, remote, client_addr).await {
|
||||
Ok(ret) => ret,
|
||||
Err(err) => {
|
||||
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
||||
|
@ -461,9 +550,10 @@ async fn http_server_upgrade(
|
|||
Err(err) => return err.map(Either::Left),
|
||||
};
|
||||
|
||||
if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) {
|
||||
return err.map(Either::Left);
|
||||
}
|
||||
let path_prefix = match extract_path_prefix(&req) {
|
||||
Ok(p) => p,
|
||||
Err(err) => return err.map(Either::Left),
|
||||
};
|
||||
|
||||
let jwt = match extract_tunnel_info(&req) {
|
||||
Ok(jwt) => jwt,
|
||||
|
@ -472,13 +562,26 @@ async fn http_server_upgrade(
|
|||
|
||||
Span::current().record("id", &jwt.claims.id);
|
||||
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
|
||||
let remote = match RemoteAddr::try_from(jwt.claims) {
|
||||
Ok(remote) => remote,
|
||||
Err(err) => {
|
||||
warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri());
|
||||
return http::Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Either::Left("Invalid upgrade request".to_string()))
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = validate_destination(&req, &jwt, &server_config.restrict_to) {
|
||||
return err.map(Either::Left);
|
||||
match validate_tunnel(&req, &remote, path_prefix, &server_config.restrictions) {
|
||||
Ok(matched_restriction) => {
|
||||
info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name);
|
||||
}
|
||||
Err(err) => return err.map(Either::Left),
|
||||
}
|
||||
|
||||
let req_protocol = jwt.claims.p.clone();
|
||||
let tunnel = match run_tunnel(&server_config, jwt, client_addr).await {
|
||||
let req_protocol = remote.protocol.clone();
|
||||
let tunnel = match run_tunnel(&server_config, remote, client_addr).await {
|
||||
Ok(ret) => ret,
|
||||
Err(err) => {
|
||||
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
||||
|
|
Loading…
Reference in a new issue