Add config file for restrictions
This commit is contained in:
parent
727e92902c
commit
8a228248d7
7 changed files with 559 additions and 75 deletions
36
Cargo.lock
generated
36
Cargo.lock
generated
|
@ -913,6 +913,9 @@ name = "ipnet"
|
||||||
version = "2.9.0"
|
version = "2.9.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
|
checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "itoa"
|
name = "itoa"
|
||||||
|
@ -1524,6 +1527,16 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_regex"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a8136f1a4ea815d7eac4101cfd0b16dc0cb5e1fe1b8609dfd728058656b7badf"
|
||||||
|
dependencies = [
|
||||||
|
"regex",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_with"
|
name = "serde_with"
|
||||||
version = "1.14.0"
|
version = "1.14.0"
|
||||||
|
@ -1546,6 +1559,19 @@ dependencies = [
|
||||||
"syn 1.0.109",
|
"syn 1.0.109",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_yaml"
|
||||||
|
version = "0.9.34+deprecated"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
|
||||||
|
dependencies = [
|
||||||
|
"indexmap",
|
||||||
|
"itoa",
|
||||||
|
"ryu",
|
||||||
|
"serde",
|
||||||
|
"unsafe-libyaml",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "sha1"
|
name = "sha1"
|
||||||
version = "0.10.6"
|
version = "0.10.6"
|
||||||
|
@ -1950,6 +1976,12 @@ dependencies = [
|
||||||
"tinyvec",
|
"tinyvec",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unsafe-libyaml"
|
||||||
|
version = "0.2.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "untrusted"
|
name = "untrusted"
|
||||||
version = "0.9.0"
|
version = "0.9.0"
|
||||||
|
@ -2285,6 +2317,7 @@ dependencies = [
|
||||||
"http-body-util",
|
"http-body-util",
|
||||||
"hyper",
|
"hyper",
|
||||||
"hyper-util",
|
"hyper-util",
|
||||||
|
"ipnet",
|
||||||
"jsonwebtoken",
|
"jsonwebtoken",
|
||||||
"log",
|
"log",
|
||||||
"nix",
|
"nix",
|
||||||
|
@ -2293,10 +2326,13 @@ dependencies = [
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"pin-project",
|
"pin-project",
|
||||||
"ppp",
|
"ppp",
|
||||||
|
"regex",
|
||||||
"rustls-native-certs",
|
"rustls-native-certs",
|
||||||
"rustls-pemfile 2.1.1",
|
"rustls-pemfile 2.1.1",
|
||||||
"scopeguard",
|
"scopeguard",
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_regex",
|
||||||
|
"serde_yaml",
|
||||||
"socket2",
|
"socket2",
|
||||||
"testcontainers",
|
"testcontainers",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
|
|
@ -18,7 +18,13 @@ fast-socks5 = { version = "0.9.6", features = [] }
|
||||||
fastwebsockets = { version = "0.7.1", features = ["upgrade", "simd", "unstable-split"] }
|
fastwebsockets = { version = "0.7.1", features = ["upgrade", "simd", "unstable-split"] }
|
||||||
futures-util = { version = "0.3.30" }
|
futures-util = { version = "0.3.30" }
|
||||||
hickory-resolver = { version = "0.24.0", features = ["tokio", "dns-over-https-rustls", "dns-over-rustls"] }
|
hickory-resolver = { version = "0.24.0", features = ["tokio", "dns-over-https-rustls", "dns-over-rustls"] }
|
||||||
ppp = { version = "2.2.0", features = [] }
|
ppp = { version = "2.2.0", features = [] }
|
||||||
|
|
||||||
|
# For config file parsing
|
||||||
|
regex = { version = "1.10.4", default-features = false, features = ["std", "perf"] }
|
||||||
|
serde_regex = "1.1.0"
|
||||||
|
serde_yaml = { version = "0.9.34", features = [] }
|
||||||
|
ipnet = { version = "2.9.0", features = ["serde"] }
|
||||||
|
|
||||||
hyper = { version = "1.2.0", features = ["client", "http1", "http2"] }
|
hyper = { version = "1.2.0", features = ["client", "http1", "http2"] }
|
||||||
hyper-util = { version = "0.1.3", features = ["tokio", "server", "server-auto"] }
|
hyper-util = { version = "0.1.3", features = ["tokio", "server", "server-auto"] }
|
||||||
|
|
84
restrictions.yaml
Normal file
84
restrictions.yaml
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
# Restrictions are whitelist rules for the tunnels
|
||||||
|
# By default, all requests are denied and only if a restriction match, the request is allowed
|
||||||
|
restrictions:
|
||||||
|
- name: "Allow all"
|
||||||
|
description: "This restriction allows all requests"
|
||||||
|
# This restriction apply only if it matches the prefix that match the given regex
|
||||||
|
# The regex does a match, so if you want to match exactly you need to bound the pattern with ^ $
|
||||||
|
# I.e: "tesotron" is going to match "XXXtesotronXXX", but "^tesotron$" is going to match only "tesotron"
|
||||||
|
match: !PathPrefix "^.*$"
|
||||||
|
|
||||||
|
# This is th list of tunnels your restriction is going to allow
|
||||||
|
# The list is going to be checked in order, the first match is going to allow the request
|
||||||
|
allow:
|
||||||
|
# !Tunnel allows forward tunnels
|
||||||
|
- !Tunnel
|
||||||
|
# Protocol that are allowed. Empty list means all protocols are allowed
|
||||||
|
protocol:
|
||||||
|
- Tcp
|
||||||
|
- Udp
|
||||||
|
# Port that are allowed. Can be a single port or an inclusive range (i.e. 80..90)
|
||||||
|
port: 9999
|
||||||
|
|
||||||
|
# if the tunnel wants to connect to a specific host, this regex must match
|
||||||
|
host: ^.*$
|
||||||
|
# if the tunnel wants to connect to a specific IP, it must match one of the network cidr
|
||||||
|
cidr:
|
||||||
|
- 0.0.0.0/0
|
||||||
|
- ::/0
|
||||||
|
|
||||||
|
# !ReverseTunnel allows reverse tunnels
|
||||||
|
# Not specifying anything means all reverse tunnels are allowed
|
||||||
|
- !ReverseTunnel
|
||||||
|
protocol:
|
||||||
|
- Tcp
|
||||||
|
- Udp
|
||||||
|
- Socks5
|
||||||
|
- Unix
|
||||||
|
port: 1..65535
|
||||||
|
cidr:
|
||||||
|
- 0.0.0.0/0
|
||||||
|
- ::/0
|
||||||
|
|
||||||
|
---
|
||||||
|
# Examples
|
||||||
|
restrictions:
|
||||||
|
- name: "example 1"
|
||||||
|
description: "Only allow forward tunnels to port 443 and forbid reverse tunnels"
|
||||||
|
match: !PathPrefix "^.*$"
|
||||||
|
allow:
|
||||||
|
- !Tunnel
|
||||||
|
port: 443
|
||||||
|
---
|
||||||
|
restrictions:
|
||||||
|
- name: "example 2"
|
||||||
|
description: "Only allow forward tunnels to local ssh and forbid reverse tunnels"
|
||||||
|
match: !PathPrefix "^.*$"
|
||||||
|
allow:
|
||||||
|
- !Tunnel
|
||||||
|
protocol:
|
||||||
|
- Tcp
|
||||||
|
port: 22
|
||||||
|
host: ^localhost$
|
||||||
|
cidr:
|
||||||
|
- 127.0.0.1/32
|
||||||
|
---
|
||||||
|
restrictions:
|
||||||
|
- name: "example 3"
|
||||||
|
description: "Only allow socks5 reverse tunnels listening on port between 1080..1443 on lan network"
|
||||||
|
match: !PathPrefix "^.*$"
|
||||||
|
allow:
|
||||||
|
- !ReverseTunnel
|
||||||
|
protocol:
|
||||||
|
- Socks5
|
||||||
|
port: 1080..1443
|
||||||
|
cidr:
|
||||||
|
- 192.168.0.0/16
|
||||||
|
---
|
||||||
|
restrictions:
|
||||||
|
- name: "example 4"
|
||||||
|
description: "Allow everything for client using path prefix my-super-secret-path"
|
||||||
|
match: !PathPrefix "^my-super-secret-path$"
|
||||||
|
allow:
|
||||||
|
- !Tunnel
|
||||||
|
- !ReverseTunnel
|
47
src/main.rs
47
src/main.rs
|
@ -1,5 +1,6 @@
|
||||||
mod dns;
|
mod dns;
|
||||||
mod embedded_certificate;
|
mod embedded_certificate;
|
||||||
|
mod restrictions;
|
||||||
mod socks5;
|
mod socks5;
|
||||||
mod socks5_udp;
|
mod socks5_udp;
|
||||||
mod stdio;
|
mod stdio;
|
||||||
|
@ -40,6 +41,7 @@ use tokio_rustls::TlsConnector;
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
|
|
||||||
use crate::dns::DnsResolver;
|
use crate::dns::DnsResolver;
|
||||||
|
use crate::restrictions::types::RestrictionsRules;
|
||||||
use crate::tunnel::tls_reloader::TlsReloader;
|
use crate::tunnel::tls_reloader::TlsReloader;
|
||||||
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme};
|
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme};
|
||||||
use crate::udp::MyUdpSocket;
|
use crate::udp::MyUdpSocket;
|
||||||
|
@ -287,6 +289,11 @@ struct Server {
|
||||||
)]
|
)]
|
||||||
restrict_http_upgrade_path_prefix: Option<Vec<String>>,
|
restrict_http_upgrade_path_prefix: Option<Vec<String>>,
|
||||||
|
|
||||||
|
/// Path to the location of the restriction yaml config file.
|
||||||
|
/// Restriction file is automatically reloaded if it changes
|
||||||
|
#[arg(long, verbatim_doc_comment)]
|
||||||
|
restriction_file: Option<PathBuf>,
|
||||||
|
|
||||||
/// [Optional] Use custom certificate (pem) instead of the default embedded self-signed certificate.
|
/// [Optional] Use custom certificate (pem) instead of the default embedded self-signed certificate.
|
||||||
/// The certificate will be automatically reloaded if it changes
|
/// The certificate will be automatically reloaded if it changes
|
||||||
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
|
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
|
||||||
|
@ -319,6 +326,15 @@ enum LocalProtocol {
|
||||||
Unix { path: PathBuf },
|
Unix { path: PathBuf },
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl LocalProtocol {
|
||||||
|
pub fn is_reverse_tunnel(&self) -> bool {
|
||||||
|
matches!(
|
||||||
|
self,
|
||||||
|
LocalProtocol::ReverseTcp | LocalProtocol::ReverseUdp { .. } | LocalProtocol::ReverseSocks5
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct LocalToRemote {
|
pub struct LocalToRemote {
|
||||||
local_protocol: LocalProtocol,
|
local_protocol: LocalProtocol,
|
||||||
|
@ -607,13 +623,12 @@ pub struct TlsServerConfig {
|
||||||
pub struct WsServerConfig {
|
pub struct WsServerConfig {
|
||||||
pub socket_so_mark: Option<u32>,
|
pub socket_so_mark: Option<u32>,
|
||||||
pub bind: SocketAddr,
|
pub bind: SocketAddr,
|
||||||
pub restrict_to: Option<Vec<String>>,
|
|
||||||
pub restrict_http_upgrade_path_prefix: Option<Vec<String>>,
|
|
||||||
pub websocket_ping_frequency: Option<Duration>,
|
pub websocket_ping_frequency: Option<Duration>,
|
||||||
pub timeout_connect: Duration,
|
pub timeout_connect: Duration,
|
||||||
pub websocket_mask_frame: bool,
|
pub websocket_mask_frame: bool,
|
||||||
pub tls: Option<TlsServerConfig>,
|
pub tls: Option<TlsServerConfig>,
|
||||||
pub dns_resolver: DnsResolver,
|
pub dns_resolver: DnsResolver,
|
||||||
|
pub restrictions: RestrictionsRules,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Debug for WsServerConfig {
|
impl Debug for WsServerConfig {
|
||||||
|
@ -621,8 +636,6 @@ impl Debug for WsServerConfig {
|
||||||
f.debug_struct("WsServerConfig")
|
f.debug_struct("WsServerConfig")
|
||||||
.field("socket_so_mark", &self.socket_so_mark)
|
.field("socket_so_mark", &self.socket_so_mark)
|
||||||
.field("bind", &self.bind)
|
.field("bind", &self.bind)
|
||||||
.field("restrict_to", &self.restrict_to)
|
|
||||||
.field("restrict_http_upgrade_path_prefix", &self.restrict_http_upgrade_path_prefix)
|
|
||||||
.field("websocket_ping_frequency", &self.websocket_ping_frequency)
|
.field("websocket_ping_frequency", &self.websocket_ping_frequency)
|
||||||
.field("timeout_connect", &self.timeout_connect)
|
.field("timeout_connect", &self.timeout_connect)
|
||||||
.field("websocket_mask_frame", &self.websocket_mask_frame)
|
.field("websocket_mask_frame", &self.websocket_mask_frame)
|
||||||
|
@ -1246,16 +1259,38 @@ async fn main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let restrictions = if let Some(path) = &args.restriction_file {
|
||||||
|
RestrictionsRules::from_config_file(path).expect("Cannot parse restriction file")
|
||||||
|
} else {
|
||||||
|
let restrict_to: Vec<(String, u16)> = args
|
||||||
|
.restrict_to
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or(&[])
|
||||||
|
.iter()
|
||||||
|
.map(|x| {
|
||||||
|
let (host, port) = x.rsplit_once(':').expect("Invalid restrict-to format");
|
||||||
|
(host.to_string(), port.parse::<u16>().expect("Invalid restrict-to port format"))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let restriction_cfg = RestrictionsRules::from_path_prefix(
|
||||||
|
args.restrict_http_upgrade_path_prefix.as_deref().unwrap_or(&[]),
|
||||||
|
&restrict_to,
|
||||||
|
)
|
||||||
|
.expect("Cannot covertion restriction rules from path-prefix and restric-to");
|
||||||
|
restriction_cfg
|
||||||
|
};
|
||||||
|
|
||||||
let server_config = WsServerConfig {
|
let server_config = WsServerConfig {
|
||||||
socket_so_mark: args.socket_so_mark,
|
socket_so_mark: args.socket_so_mark,
|
||||||
bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0],
|
bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0],
|
||||||
restrict_to: args.restrict_to,
|
|
||||||
restrict_http_upgrade_path_prefix: args.restrict_http_upgrade_path_prefix,
|
|
||||||
websocket_ping_frequency: args.websocket_ping_frequency_sec,
|
websocket_ping_frequency: args.websocket_ping_frequency_sec,
|
||||||
timeout_connect: Duration::from_secs(10),
|
timeout_connect: Duration::from_secs(10),
|
||||||
websocket_mask_frame: args.websocket_mask_frame,
|
websocket_mask_frame: args.websocket_mask_frame,
|
||||||
tls: tls_config,
|
tls: tls_config,
|
||||||
dns_resolver,
|
dns_resolver,
|
||||||
|
restrictions,
|
||||||
};
|
};
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
|
|
79
src/restrictions/mod.rs
Normal file
79
src/restrictions/mod.rs
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
use crate::restrictions::types::{default_cidr, default_host, default_port};
|
||||||
|
use regex::Regex;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::BufReader;
|
||||||
|
use std::ops::RangeInclusive;
|
||||||
|
use std::path::Path;
|
||||||
|
use types::RestrictionsRules;
|
||||||
|
|
||||||
|
pub mod types;
|
||||||
|
|
||||||
|
impl RestrictionsRules {
|
||||||
|
pub fn from_config_file(config_path: &Path) -> anyhow::Result<RestrictionsRules> {
|
||||||
|
let restrictions: RestrictionsRules = serde_yaml::from_reader(BufReader::new(File::open(config_path)?))?;
|
||||||
|
Ok(restrictions)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_path_prefix(
|
||||||
|
path_prefixes: &[String],
|
||||||
|
restrict_to: &[(String, u16)],
|
||||||
|
) -> anyhow::Result<RestrictionsRules> {
|
||||||
|
let mut tunnels_restrictions = if restrict_to.is_empty() {
|
||||||
|
let r = types::AllowConfig::Tunnel(types::AllowTunnelConfig {
|
||||||
|
protocol: vec![],
|
||||||
|
port: default_port(),
|
||||||
|
host: default_host(),
|
||||||
|
cidr: default_cidr(),
|
||||||
|
});
|
||||||
|
vec![r]
|
||||||
|
} else {
|
||||||
|
restrict_to
|
||||||
|
.iter()
|
||||||
|
.map(|(host, port)| {
|
||||||
|
// Fixme: Remove the unwrap
|
||||||
|
let reg = Regex::new(&format!("^{}$", regex::escape(host))).unwrap();
|
||||||
|
types::AllowConfig::Tunnel(types::AllowTunnelConfig {
|
||||||
|
protocol: vec![],
|
||||||
|
port: RangeInclusive::new(*port, *port),
|
||||||
|
host: reg,
|
||||||
|
cidr: default_cidr(),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
tunnels_restrictions.push(types::AllowConfig::ReverseTunnel(types::AllowReverseTunnelConfig {
|
||||||
|
protocol: vec![],
|
||||||
|
port: default_port(),
|
||||||
|
cidr: default_cidr(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
let restrictions = if path_prefixes.is_empty() {
|
||||||
|
// if no path prefixes are provided, we allow all
|
||||||
|
let reg = Regex::new(".").unwrap();
|
||||||
|
let r = types::RestrictionConfig {
|
||||||
|
name: "Allow All".to_string(),
|
||||||
|
r#match: types::MatchConfig::PathPrefix(reg),
|
||||||
|
allow: tunnels_restrictions,
|
||||||
|
};
|
||||||
|
vec![r]
|
||||||
|
} else {
|
||||||
|
path_prefixes
|
||||||
|
.iter()
|
||||||
|
.map(|path_prefix| {
|
||||||
|
// Fixme: Remove the unwrap
|
||||||
|
let reg = Regex::new(&format!("^{}$", regex::escape(path_prefix))).unwrap();
|
||||||
|
types::RestrictionConfig {
|
||||||
|
name: format!("Allow path prefix {}", path_prefix),
|
||||||
|
r#match: types::MatchConfig::PathPrefix(reg),
|
||||||
|
allow: tunnels_restrictions.clone(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
let restrictions = RestrictionsRules { restrictions };
|
||||||
|
|
||||||
|
Ok(restrictions)
|
||||||
|
}
|
||||||
|
}
|
141
src/restrictions/types.rs
Normal file
141
src/restrictions/types.rs
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
use crate::LocalProtocol;
|
||||||
|
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
||||||
|
use regex::Regex;
|
||||||
|
use serde::{Deserialize, Deserializer};
|
||||||
|
use std::ops::RangeInclusive;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct RestrictionsRules {
|
||||||
|
pub restrictions: Vec<RestrictionConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct RestrictionConfig {
|
||||||
|
pub name: String,
|
||||||
|
pub r#match: MatchConfig,
|
||||||
|
pub allow: Vec<AllowConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub enum MatchConfig {
|
||||||
|
Any,
|
||||||
|
#[serde(with = "serde_regex")]
|
||||||
|
PathPrefix(Regex),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub enum AllowConfig {
|
||||||
|
ReverseTunnel(AllowReverseTunnelConfig),
|
||||||
|
Tunnel(AllowTunnelConfig),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct AllowTunnelConfig {
|
||||||
|
#[serde(default)]
|
||||||
|
pub protocol: Vec<TunnelConfigProtocol>,
|
||||||
|
|
||||||
|
#[serde(deserialize_with = "deserialize_port_range")]
|
||||||
|
#[serde(default = "default_port")]
|
||||||
|
pub port: RangeInclusive<u16>,
|
||||||
|
|
||||||
|
#[serde(with = "serde_regex")]
|
||||||
|
#[serde(default = "default_host")]
|
||||||
|
pub host: Regex,
|
||||||
|
|
||||||
|
#[serde(default = "default_cidr")]
|
||||||
|
pub cidr: Vec<IpNet>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct AllowReverseTunnelConfig {
|
||||||
|
#[serde(default)]
|
||||||
|
pub protocol: Vec<ReverseTunnelConfigProtocol>,
|
||||||
|
|
||||||
|
#[serde(deserialize_with = "deserialize_port_range")]
|
||||||
|
#[serde(default = "default_port")]
|
||||||
|
pub port: RangeInclusive<u16>,
|
||||||
|
|
||||||
|
#[serde(default = "default_cidr")]
|
||||||
|
pub cidr: Vec<IpNet>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Eq, PartialEq)]
|
||||||
|
pub enum TunnelConfigProtocol {
|
||||||
|
Tcp,
|
||||||
|
Udp,
|
||||||
|
Unknown,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Eq, PartialEq)]
|
||||||
|
pub enum ReverseTunnelConfigProtocol {
|
||||||
|
Tcp,
|
||||||
|
Udp,
|
||||||
|
Socks5,
|
||||||
|
Unix,
|
||||||
|
Unknown,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default_port() -> RangeInclusive<u16> {
|
||||||
|
RangeInclusive::new(1, 65535)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default_host() -> Regex {
|
||||||
|
Regex::new("^.*$").unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default_cidr() -> Vec<IpNet> {
|
||||||
|
vec![IpNet::V4(Ipv4Net::default()), IpNet::V6(Ipv6Net::default())]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn deserialize_port_range<'de, D>(deserializer: D) -> Result<RangeInclusive<u16>, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let s = String::deserialize(deserializer)?;
|
||||||
|
let range = if let Some((l, r)) = s.split_once("..") {
|
||||||
|
RangeInclusive::new(
|
||||||
|
l.parse().map_err(serde::de::Error::custom)?,
|
||||||
|
r.parse().map_err(serde::de::Error::custom)?,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
let port = s.parse::<u16>().map_err(serde::de::Error::custom)?;
|
||||||
|
RangeInclusive::new(port, port)
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(range)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&LocalProtocol> for ReverseTunnelConfigProtocol {
|
||||||
|
fn from(value: &LocalProtocol) -> Self {
|
||||||
|
match value {
|
||||||
|
LocalProtocol::Tcp { .. }
|
||||||
|
| LocalProtocol::Udp { .. }
|
||||||
|
| LocalProtocol::Stdio
|
||||||
|
| LocalProtocol::Socks5 { .. }
|
||||||
|
| LocalProtocol::TProxyTcp { .. }
|
||||||
|
| LocalProtocol::TProxyUdp { .. }
|
||||||
|
| LocalProtocol::Unix { .. } => ReverseTunnelConfigProtocol::Unknown,
|
||||||
|
LocalProtocol::ReverseTcp => ReverseTunnelConfigProtocol::Tcp,
|
||||||
|
LocalProtocol::ReverseUdp { .. } => ReverseTunnelConfigProtocol::Udp,
|
||||||
|
LocalProtocol::ReverseSocks5 => ReverseTunnelConfigProtocol::Socks5,
|
||||||
|
LocalProtocol::ReverseUnix { .. } => ReverseTunnelConfigProtocol::Unix,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl From<&LocalProtocol> for TunnelConfigProtocol {
|
||||||
|
fn from(value: &LocalProtocol) -> Self {
|
||||||
|
match value {
|
||||||
|
LocalProtocol::ReverseTcp
|
||||||
|
| LocalProtocol::ReverseUdp { .. }
|
||||||
|
| LocalProtocol::ReverseSocks5
|
||||||
|
| LocalProtocol::ReverseUnix { .. }
|
||||||
|
| LocalProtocol::Stdio
|
||||||
|
| LocalProtocol::Socks5 { .. }
|
||||||
|
| LocalProtocol::TProxyTcp { .. }
|
||||||
|
| LocalProtocol::TProxyUdp { .. }
|
||||||
|
| LocalProtocol::Unix { .. } => TunnelConfigProtocol::Unknown,
|
||||||
|
LocalProtocol::Tcp { .. } => TunnelConfigProtocol::Tcp,
|
||||||
|
LocalProtocol::Udp { .. } => TunnelConfigProtocol::Udp,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,7 +8,7 @@ use std::cmp::min;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::ops::{Deref, Not};
|
use std::ops::Deref;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
@ -26,6 +26,9 @@ use jsonwebtoken::TokenData;
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
|
|
||||||
|
use crate::restrictions::types::{
|
||||||
|
AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol,
|
||||||
|
};
|
||||||
use crate::socks5::Socks5Stream;
|
use crate::socks5::Socks5Stream;
|
||||||
use crate::tunnel::tls_reloader::TlsReloader;
|
use crate::tunnel::tls_reloader::TlsReloader;
|
||||||
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
|
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
|
||||||
|
@ -43,12 +46,11 @@ use uuid::Uuid;
|
||||||
|
|
||||||
async fn run_tunnel(
|
async fn run_tunnel(
|
||||||
server_config: &WsServerConfig,
|
server_config: &WsServerConfig,
|
||||||
jwt: TokenData<JwtTunnelConfig>,
|
remote: RemoteAddr,
|
||||||
client_address: SocketAddr,
|
client_address: SocketAddr,
|
||||||
) -> anyhow::Result<(RemoteAddr, Pin<Box<dyn AsyncRead + Send>>, Pin<Box<dyn AsyncWrite + Send>>)> {
|
) -> anyhow::Result<(RemoteAddr, Pin<Box<dyn AsyncRead + Send>>, Pin<Box<dyn AsyncWrite + Send>>)> {
|
||||||
match jwt.claims.p {
|
match remote.protocol {
|
||||||
LocalProtocol::Udp { timeout, .. } => {
|
LocalProtocol::Udp { timeout, .. } => {
|
||||||
let remote = RemoteAddr::try_from(jwt.claims)?;
|
|
||||||
let cnx = udp::connect(
|
let cnx = udp::connect(
|
||||||
&remote.host,
|
&remote.host,
|
||||||
remote.port,
|
remote.port,
|
||||||
|
@ -60,7 +62,6 @@ async fn run_tunnel(
|
||||||
Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx)))
|
Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx)))
|
||||||
}
|
}
|
||||||
LocalProtocol::Tcp { proxy_protocol } => {
|
LocalProtocol::Tcp { proxy_protocol } => {
|
||||||
let remote = RemoteAddr::try_from(jwt.claims)?;
|
|
||||||
let mut socket = tcp::connect(
|
let mut socket = tcp::connect(
|
||||||
&remote.host,
|
&remote.host,
|
||||||
remote.port,
|
remote.port,
|
||||||
|
@ -89,14 +90,14 @@ async fn run_tunnel(
|
||||||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<TcpStream>>>> =
|
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<TcpStream>>>> =
|
||||||
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
||||||
|
|
||||||
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
|
let local_srv = (remote.host, remote.port);
|
||||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||||
let listening_server = tcp::run_server(bind.parse()?, false);
|
let listening_server = tcp::run_server(bind.parse()?, false);
|
||||||
let tcp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
let tcp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||||
let (local_rx, local_tx) = tcp.into_split();
|
let (local_rx, local_tx) = tcp.into_split();
|
||||||
|
|
||||||
let remote = RemoteAddr {
|
let remote = RemoteAddr {
|
||||||
protocol: jwt.claims.p,
|
protocol: remote.protocol,
|
||||||
host: local_srv.0,
|
host: local_srv.0,
|
||||||
port: local_srv.1,
|
port: local_srv.1,
|
||||||
};
|
};
|
||||||
|
@ -107,7 +108,7 @@ async fn run_tunnel(
|
||||||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<UdpStream>>>> =
|
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<UdpStream>>>> =
|
||||||
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
||||||
|
|
||||||
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
|
let local_srv = (remote.host, remote.port);
|
||||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||||
let listening_server =
|
let listening_server =
|
||||||
udp::run_server(bind.parse()?, timeout, |_| Ok(()), |send_socket| Ok(send_socket.clone()));
|
udp::run_server(bind.parse()?, timeout, |_| Ok(()), |send_socket| Ok(send_socket.clone()));
|
||||||
|
@ -115,7 +116,7 @@ async fn run_tunnel(
|
||||||
let (local_rx, local_tx) = tokio::io::split(udp);
|
let (local_rx, local_tx) = tokio::io::split(udp);
|
||||||
|
|
||||||
let remote = RemoteAddr {
|
let remote = RemoteAddr {
|
||||||
protocol: jwt.claims.p,
|
protocol: remote.protocol,
|
||||||
host: local_srv.0,
|
host: local_srv.0,
|
||||||
port: local_srv.1,
|
port: local_srv.1,
|
||||||
};
|
};
|
||||||
|
@ -126,7 +127,7 @@ async fn run_tunnel(
|
||||||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<(Socks5Stream, (Host, u16))>>>> =
|
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<(Socks5Stream, (Host, u16))>>>> =
|
||||||
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
||||||
|
|
||||||
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
|
let local_srv = (remote.host, remote.port);
|
||||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||||
let listening_server = socks5::run_server(bind.parse()?, None);
|
let listening_server = socks5::run_server(bind.parse()?, None);
|
||||||
let (stream, local_srv) = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
let (stream, local_srv) = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||||
|
@ -149,13 +150,13 @@ async fn run_tunnel(
|
||||||
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<UnixStream>>>> =
|
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<UnixStream>>>> =
|
||||||
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
||||||
|
|
||||||
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
|
let local_srv = (remote.host, remote.port);
|
||||||
let listening_server = unix_socket::run_server(path);
|
let listening_server = unix_socket::run_server(path);
|
||||||
let stream = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
let stream = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||||
let (local_rx, local_tx) = stream.into_split();
|
let (local_rx, local_tx) = stream.into_split();
|
||||||
|
|
||||||
let remote = RemoteAddr {
|
let remote = RemoteAddr {
|
||||||
protocol: jwt.claims.p.clone(),
|
protocol: remote.protocol,
|
||||||
host: local_srv.0,
|
host: local_srv.0,
|
||||||
port: local_srv.1,
|
port: local_srv.1,
|
||||||
};
|
};
|
||||||
|
@ -163,7 +164,7 @@ async fn run_tunnel(
|
||||||
}
|
}
|
||||||
#[cfg(not(unix))]
|
#[cfg(not(unix))]
|
||||||
LocalProtocol::ReverseUnix { ref path } => {
|
LocalProtocol::ReverseUnix { ref path } => {
|
||||||
error!("Received an unsupported target protocol {:?}", jwt.claims);
|
error!("Received an unsupported target protocol {:?}", remote);
|
||||||
Err(anyhow::anyhow!("Invalid upgrade request"))
|
Err(anyhow::anyhow!("Invalid upgrade request"))
|
||||||
}
|
}
|
||||||
LocalProtocol::Stdio
|
LocalProtocol::Stdio
|
||||||
|
@ -171,7 +172,7 @@ async fn run_tunnel(
|
||||||
| LocalProtocol::TProxyTcp
|
| LocalProtocol::TProxyTcp
|
||||||
| LocalProtocol::TProxyUdp { .. }
|
| LocalProtocol::TProxyUdp { .. }
|
||||||
| LocalProtocol::Unix { .. } => {
|
| LocalProtocol::Unix { .. } => {
|
||||||
error!("Received an unsupported target protocol {:?}", jwt.claims);
|
error!("Received an unsupported target protocol {:?}", remote);
|
||||||
Err(anyhow::anyhow!("Invalid upgrade request"))
|
Err(anyhow::anyhow!("Invalid upgrade request"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -251,11 +252,26 @@ fn extract_x_forwarded_for(req: &Request<Incoming>) -> Result<Option<(IpAddr, &s
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn validate_url(
|
fn extract_path_prefix(req: &Request<Incoming>) -> Result<&str, Response<String>> {
|
||||||
req: &Request<Incoming>,
|
let path = req.uri().path();
|
||||||
path_restriction_prefix: &Option<Vec<String>>,
|
let min_len = min(path.len(), 1);
|
||||||
) -> Result<(), Response<String>> {
|
if &path[0..min_len] != "/" {
|
||||||
if !req.uri().path().ends_with("/events") {
|
warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri());
|
||||||
|
return Err(http::Response::builder()
|
||||||
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
.body("Invalid upgrade request".to_string())
|
||||||
|
.unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some((l, r)) = path[min_len..].split_once('/') else {
|
||||||
|
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
|
||||||
|
return Err(http::Response::builder()
|
||||||
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
.body("Invalid upgrade request".into())
|
||||||
|
.unwrap());
|
||||||
|
};
|
||||||
|
|
||||||
|
if !r.ends_with("events") {
|
||||||
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
|
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
|
||||||
return Err(http::Response::builder()
|
return Err(http::Response::builder()
|
||||||
.status(StatusCode::BAD_REQUEST)
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
@ -263,26 +279,7 @@ fn validate_url(
|
||||||
.unwrap());
|
.unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(paths_prefix) = &path_restriction_prefix {
|
Ok(l)
|
||||||
let path = req.uri().path();
|
|
||||||
let min_len = min(path.len(), 1);
|
|
||||||
let mut max_len = 0;
|
|
||||||
if &path[0..min_len] != "/"
|
|
||||||
|| !paths_prefix.iter().any(|p| {
|
|
||||||
max_len = min(path.len(), p.len() + 1);
|
|
||||||
p == &path[min_len..max_len]
|
|
||||||
})
|
|
||||||
|| !path[max_len..].starts_with('/')
|
|
||||||
{
|
|
||||||
warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri());
|
|
||||||
return Err(http::Response::builder()
|
|
||||||
.status(StatusCode::BAD_REQUEST)
|
|
||||||
.body("Invalid upgrade request".to_string())
|
|
||||||
.unwrap());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
|
@ -316,25 +313,102 @@ fn extract_tunnel_info(req: &Request<Incoming>) -> Result<TokenData<JwtTunnelCon
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn validate_destination(
|
fn validate_tunnel<'a>(
|
||||||
_req: &Request<Incoming>,
|
_req: &Request<Incoming>,
|
||||||
jwt: &TokenData<JwtTunnelConfig>,
|
remote: &RemoteAddr,
|
||||||
destination_restriction: &Option<Vec<String>>,
|
path_prefix: &str,
|
||||||
) -> Result<(), Response<String>> {
|
restrictions: &'a RestrictionsRules,
|
||||||
let Some(allowed_dests) = &destination_restriction else {
|
) -> Result<&'a RestrictionConfig, Response<String>> {
|
||||||
return Ok(());
|
for restriction in &restrictions.restrictions {
|
||||||
};
|
match &restriction.r#match {
|
||||||
|
MatchConfig::Any => {}
|
||||||
|
MatchConfig::PathPrefix(path) => {
|
||||||
|
if !path.is_match(path_prefix) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
|
for allow in &restriction.allow {
|
||||||
if allowed_dests.iter().any(|dest| dest == &requested_dest).not() {
|
match allow {
|
||||||
warn!("Rejecting connection with not allowed destination: {}", requested_dest);
|
AllowConfig::ReverseTunnel(allow) => {
|
||||||
return Err(http::Response::builder()
|
if !remote.protocol.is_reverse_tunnel() || !allow.port.contains(&remote.port) {
|
||||||
.status(StatusCode::BAD_REQUEST)
|
continue;
|
||||||
.body("Invalid upgrade request".to_string())
|
}
|
||||||
.unwrap());
|
|
||||||
|
if !allow.protocol.is_empty()
|
||||||
|
&& !allow
|
||||||
|
.protocol
|
||||||
|
.contains(&ReverseTunnelConfigProtocol::from(&remote.protocol))
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match &remote.host {
|
||||||
|
Host::Domain(_) => {}
|
||||||
|
Host::Ipv4(ip) => {
|
||||||
|
let ip = IpAddr::V4(*ip);
|
||||||
|
for cidr in &allow.cidr {
|
||||||
|
if cidr.contains(&ip) {
|
||||||
|
return Ok(restriction);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Host::Ipv6(ip) => {
|
||||||
|
let ip = IpAddr::V6(*ip);
|
||||||
|
for cidr in &allow.cidr {
|
||||||
|
if cidr.contains(&ip) {
|
||||||
|
return Ok(restriction);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AllowConfig::Tunnel(allow) => {
|
||||||
|
if remote.protocol.is_reverse_tunnel() || !allow.port.contains(&remote.port) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !allow.protocol.is_empty()
|
||||||
|
&& !allow.protocol.contains(&TunnelConfigProtocol::from(&remote.protocol))
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match &remote.host {
|
||||||
|
Host::Domain(host) => {
|
||||||
|
if allow.host.is_match(host) {
|
||||||
|
return Ok(restriction);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Host::Ipv4(ip) => {
|
||||||
|
let ip = IpAddr::V4(*ip);
|
||||||
|
for cidr in &allow.cidr {
|
||||||
|
if cidr.contains(&ip) {
|
||||||
|
return Ok(restriction);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Host::Ipv6(ip) => {
|
||||||
|
let ip = IpAddr::V6(*ip);
|
||||||
|
for cidr in &allow.cidr {
|
||||||
|
if cidr.contains(&ip) {
|
||||||
|
return Ok(restriction);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
warn!("Rejecting connection with not allowed destination: {:?}", remote);
|
||||||
|
Err(http::Response::builder()
|
||||||
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
.body("Invalid upgrade request".to_string())
|
||||||
|
.unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn ws_server_upgrade(
|
async fn ws_server_upgrade(
|
||||||
|
@ -360,9 +434,10 @@ async fn ws_server_upgrade(
|
||||||
Err(err) => return err,
|
Err(err) => return err,
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) {
|
let path_prefix = match extract_path_prefix(&req) {
|
||||||
return err;
|
Ok(p) => p,
|
||||||
}
|
Err(err) => return err,
|
||||||
|
};
|
||||||
|
|
||||||
let jwt = match extract_tunnel_info(&req) {
|
let jwt = match extract_tunnel_info(&req) {
|
||||||
Ok(jwt) => jwt,
|
Ok(jwt) => jwt,
|
||||||
|
@ -372,12 +447,26 @@ async fn ws_server_upgrade(
|
||||||
Span::current().record("id", &jwt.claims.id);
|
Span::current().record("id", &jwt.claims.id);
|
||||||
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
|
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
|
||||||
|
|
||||||
if let Err(err) = validate_destination(&req, &jwt, &server_config.restrict_to) {
|
let remote = match RemoteAddr::try_from(jwt.claims) {
|
||||||
return err;
|
Ok(remote) => remote,
|
||||||
|
Err(err) => {
|
||||||
|
warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri());
|
||||||
|
return http::Response::builder()
|
||||||
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
.body("Invalid upgrade request".to_string())
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match validate_tunnel(&req, &remote, path_prefix, &server_config.restrictions) {
|
||||||
|
Ok(matched_restriction) => {
|
||||||
|
info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name);
|
||||||
|
}
|
||||||
|
Err(err) => return err,
|
||||||
}
|
}
|
||||||
|
|
||||||
let req_protocol = jwt.claims.p.clone();
|
let req_protocol = remote.protocol.clone();
|
||||||
let tunnel = match run_tunnel(&server_config, jwt, client_addr).await {
|
let tunnel = match run_tunnel(&server_config, remote, client_addr).await {
|
||||||
Ok(ret) => ret,
|
Ok(ret) => ret,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
||||||
|
@ -461,9 +550,10 @@ async fn http_server_upgrade(
|
||||||
Err(err) => return err.map(Either::Left),
|
Err(err) => return err.map(Either::Left),
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) {
|
let path_prefix = match extract_path_prefix(&req) {
|
||||||
return err.map(Either::Left);
|
Ok(p) => p,
|
||||||
}
|
Err(err) => return err.map(Either::Left),
|
||||||
|
};
|
||||||
|
|
||||||
let jwt = match extract_tunnel_info(&req) {
|
let jwt = match extract_tunnel_info(&req) {
|
||||||
Ok(jwt) => jwt,
|
Ok(jwt) => jwt,
|
||||||
|
@ -472,13 +562,26 @@ async fn http_server_upgrade(
|
||||||
|
|
||||||
Span::current().record("id", &jwt.claims.id);
|
Span::current().record("id", &jwt.claims.id);
|
||||||
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
|
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
|
||||||
|
let remote = match RemoteAddr::try_from(jwt.claims) {
|
||||||
|
Ok(remote) => remote,
|
||||||
|
Err(err) => {
|
||||||
|
warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri());
|
||||||
|
return http::Response::builder()
|
||||||
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
.body(Either::Left("Invalid upgrade request".to_string()))
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
if let Err(err) = validate_destination(&req, &jwt, &server_config.restrict_to) {
|
match validate_tunnel(&req, &remote, path_prefix, &server_config.restrictions) {
|
||||||
return err.map(Either::Left);
|
Ok(matched_restriction) => {
|
||||||
|
info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name);
|
||||||
|
}
|
||||||
|
Err(err) => return err.map(Either::Left),
|
||||||
}
|
}
|
||||||
|
|
||||||
let req_protocol = jwt.claims.p.clone();
|
let req_protocol = remote.protocol.clone();
|
||||||
let tunnel = match run_tunnel(&server_config, jwt, client_addr).await {
|
let tunnel = match run_tunnel(&server_config, remote, client_addr).await {
|
||||||
Ok(ret) => ret,
|
Ok(ret) => ret,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
||||||
|
|
Loading…
Reference in a new issue