feat(restriction): Auto-reload restriction file

This commit is contained in:
Σrebe - Romain GERARD 2024-05-01 12:07:18 +02:00
parent 368f6657fd
commit 5ef14d1a8c
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
4 changed files with 228 additions and 25 deletions

View file

@ -628,7 +628,7 @@ pub struct WsServerConfig {
pub websocket_mask_frame: bool, pub websocket_mask_frame: bool,
pub tls: Option<TlsServerConfig>, pub tls: Option<TlsServerConfig>,
pub dns_resolver: DnsResolver, pub dns_resolver: DnsResolver,
pub restrictions: RestrictionsRules, pub restriction_config: Option<PathBuf>,
} }
impl Debug for WsServerConfig { impl Debug for WsServerConfig {
@ -639,6 +639,7 @@ impl Debug for WsServerConfig {
.field("websocket_ping_frequency", &self.websocket_ping_frequency) .field("websocket_ping_frequency", &self.websocket_ping_frequency)
.field("timeout_connect", &self.timeout_connect) .field("timeout_connect", &self.timeout_connect)
.field("websocket_mask_frame", &self.websocket_mask_frame) .field("websocket_mask_frame", &self.websocket_mask_frame)
.field("restriction_config", &self.restriction_config)
.field("tls", &self.tls.is_some()) .field("tls", &self.tls.is_some())
.field( .field(
"mTLS", "mTLS",
@ -1270,7 +1271,10 @@ async fn main() {
.iter() .iter()
.map(|x| { .map(|x| {
let (host, port) = x.rsplit_once(':').expect("Invalid restrict-to format"); let (host, port) = x.rsplit_once(':').expect("Invalid restrict-to format");
(host.to_string(), port.parse::<u16>().expect("Invalid restrict-to port format")) (
host.trim_matches(&['[', ']']).to_string(),
port.parse::<u16>().expect("Invalid restrict-to port format"),
)
}) })
.collect(); .collect();
@ -1281,7 +1285,6 @@ async fn main() {
.expect("Cannot convert restriction rules from path-prefix and restric-to"); .expect("Cannot convert restriction rules from path-prefix and restric-to");
restriction_cfg restriction_cfg
}; };
debug!("Restriction rules: {:?}", restrictions);
let server_config = WsServerConfig { let server_config = WsServerConfig {
socket_so_mark: args.socket_so_mark, socket_so_mark: args.socket_so_mark,
@ -1291,7 +1294,7 @@ async fn main() {
websocket_mask_frame: args.websocket_mask_frame, websocket_mask_frame: args.websocket_mask_frame,
tls: tls_config, tls: tls_config,
dns_resolver, dns_resolver,
restrictions, restriction_config: args.restrict_config,
}; };
info!( info!(
@ -1299,7 +1302,8 @@ async fn main() {
env!("CARGO_PKG_VERSION"), env!("CARGO_PKG_VERSION"),
server_config server_config
); );
tunnel::server::run_server(Arc::new(server_config)) debug!("Restriction rules: {:#?}", restrictions);
tunnel::server::run_server(Arc::new(server_config), restrictions)
.await .await
.unwrap_or_else(|err| { .unwrap_or_else(|err| {
panic!("Cannot start wstunnel server: {:?}", err); panic!("Cannot start wstunnel server: {:?}", err);

View file

@ -0,0 +1,174 @@
use super::types::RestrictionsRules;
use crate::restrictions::config_reloader::RestrictionsRulesReloaderState::{Config, Static};
use anyhow::Context;
use log::trace;
use notify::{EventKind, RecommendedWatcher, Watcher};
use parking_lot::Mutex;
use std::path::PathBuf;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use tokio::sync::futures::Notified;
use tokio::sync::Notify;
use tracing::{error, info, warn};
struct ConfigReloaderState {
fs_watcher: Mutex<RecommendedWatcher>,
config_path: PathBuf,
should_reload_config: Notify,
}
enum RestrictionsRulesReloaderState {
Static(Notify),
Config(Arc<ConfigReloaderState>),
}
impl RestrictionsRulesReloaderState {
fn fs_watcher(&self) -> &Mutex<RecommendedWatcher> {
match self {
Static(_) => unreachable!(),
Config(this) => &this.fs_watcher,
}
}
}
pub struct RestrictionsRulesReloader {
state: RestrictionsRulesReloaderState,
restrictions: Arc<RestrictionsRules>,
}
impl RestrictionsRulesReloader {
pub fn new(restrictions_rules: RestrictionsRules, config_path: Option<PathBuf>) -> anyhow::Result<Self> {
// If there is no custom certificate and private key, there is nothing to watch
let config_path = if let Some(config_path) = config_path {
config_path
} else {
return Ok(Self {
state: Static(Notify::new()),
restrictions: Arc::new(restrictions_rules),
});
};
let this = Arc::new(ConfigReloaderState {
fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?),
should_reload_config: Notify::new(),
config_path,
});
info!("Starting to watch restriction config file for changes to reload them");
let mut watcher = notify::recommended_watcher({
let this = Config(this.clone());
move |event: notify::Result<notify::Event>| Self::handle_config_fs_event(&this, event)
})
.with_context(|| "Cannot create restriction config watcher")?;
watcher.watch(&this.config_path, notify::RecursiveMode::NonRecursive)?;
*this.fs_watcher.lock() = watcher;
Ok(Self {
state: Config(this),
restrictions: Arc::new(restrictions_rules),
})
}
pub fn reload_restrictions_config(&mut self) {
let restrictions = match &self.state {
Static(_) => return,
Config(st) => match RestrictionsRules::from_config_file(&st.config_path) {
Ok(restrictions) => {
info!("Restrictions config file has been reloaded");
restrictions
}
Err(err) => {
error!("Cannot reload restrictions config file, keeping the old one. Error: {:?}", err);
return;
}
},
};
self.restrictions = Arc::new(restrictions);
}
pub fn restrictions_rules(&self) -> &Arc<RestrictionsRules> {
&self.restrictions
}
pub fn wait_for_reload(&self) -> Notified {
match &self.state {
Static(st) => st.notified(),
Config(st) => st.should_reload_config.notified(),
}
}
fn try_rewatch_config(this: RestrictionsRulesReloaderState, path: PathBuf) {
thread::spawn(move || {
while !path.exists() {
warn!(
"Restrictions config file {:?} does not exist anymore, waiting for it to be created",
path
);
thread::sleep(Duration::from_secs(10));
}
let mut watcher = this.fs_watcher().lock();
let _ = watcher.unwatch(&path);
let Ok(_) = watcher
.watch(&path, notify::RecursiveMode::NonRecursive)
.map_err(|err| {
error!("Cannot re-set a watch for Restriction config file {:?}: {:?}", path, err);
error!("Restriction config file will not be auto-reloaded anymore");
})
else {
return;
};
drop(watcher);
// Generate a fake event to force-reload the certificate
let event = notify::Event {
kind: EventKind::Create(notify::event::CreateKind::Any),
paths: vec![path],
attrs: Default::default(),
};
match &this {
Static(_) => Self::handle_config_fs_event(&this, Ok(event)),
Config(_) => Self::handle_config_fs_event(&this, Ok(event)),
}
});
}
fn handle_config_fs_event(this: &RestrictionsRulesReloaderState, event: notify::Result<notify::Event>) {
let this = match this {
Static(_) => return,
Config(st) => st,
};
let event = match event {
Ok(event) => event,
Err(err) => {
error!("Error while watching restrictions config file for changes {:?}", err);
return;
}
};
if event.kind.is_access() {
return;
}
trace!("Received event: {:#?}", event);
if let Some(path) = event.paths.iter().find(|p| p.ends_with(&this.config_path)) {
match event.kind {
EventKind::Create(_) | EventKind::Modify(_) => {
this.should_reload_config.notify_one();
}
EventKind::Remove(_) => {
warn!("Restriction config file has been removed, trying to re-set a watch for it");
Self::try_rewatch_config(Config(this.clone()), path.to_path_buf());
}
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
trace!("Ignoring event {:?}", event);
}
}
}
}
}

View file

@ -13,6 +13,7 @@ use types::RestrictionsRules;
use crate::restrictions::types::{default_cidr, default_host}; use crate::restrictions::types::{default_cidr, default_host};
pub mod config_reloader;
pub mod types; pub mod types;
impl RestrictionsRules { impl RestrictionsRules {

View file

@ -26,6 +26,7 @@ use jsonwebtoken::TokenData;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use parking_lot::Mutex; use parking_lot::Mutex;
use crate::restrictions::config_reloader::RestrictionsRulesReloader;
use crate::restrictions::types::{ use crate::restrictions::types::{
AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol, AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol,
}; };
@ -418,6 +419,7 @@ fn validate_tunnel<'a>(
async fn ws_server_upgrade( async fn ws_server_upgrade(
server_config: Arc<WsServerConfig>, server_config: Arc<WsServerConfig>,
restrictions: Arc<RestrictionsRules>,
mut client_addr: SocketAddr, mut client_addr: SocketAddr,
mut req: Request<Incoming>, mut req: Request<Incoming>,
) -> Response<String> { ) -> Response<String> {
@ -463,7 +465,7 @@ async fn ws_server_upgrade(
} }
}; };
match validate_tunnel(&remote, path_prefix, &server_config.restrictions) { match validate_tunnel(&remote, path_prefix, &restrictions) {
Ok(matched_restriction) => { Ok(matched_restriction) => {
info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name);
} }
@ -542,6 +544,7 @@ async fn ws_server_upgrade(
async fn http_server_upgrade( async fn http_server_upgrade(
server_config: Arc<WsServerConfig>, server_config: Arc<WsServerConfig>,
restrictions: Arc<RestrictionsRules>,
mut client_addr: SocketAddr, mut client_addr: SocketAddr,
mut req: Request<Incoming>, mut req: Request<Incoming>,
) -> Response<Either<String, BoxBody<Bytes, anyhow::Error>>> { ) -> Response<Either<String, BoxBody<Bytes, anyhow::Error>>> {
@ -578,7 +581,7 @@ async fn http_server_upgrade(
} }
}; };
match validate_tunnel(&remote, path_prefix, &server_config.restrictions) { match validate_tunnel(&remote, path_prefix, &restrictions) {
Ok(matched_restriction) => { Ok(matched_restriction) => {
info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name);
} }
@ -664,32 +667,39 @@ impl TlsContext<'_> {
} }
} }
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> { pub async fn run_server(server_config: Arc<WsServerConfig>, restrictions: RestrictionsRules) -> anyhow::Result<()> {
info!("Starting wstunnel server listening on {}", server_config.bind); info!("Starting wstunnel server listening on {}", server_config.bind);
// setup upgrade request handler // setup upgrade request handler
let mk_websocket_upgrade_fn = |server_config: Arc<WsServerConfig>, client_addr: SocketAddr| { let mk_websocket_upgrade_fn =
move |req: Request<Incoming>| { |server_config: Arc<WsServerConfig>, restrictions: Arc<RestrictionsRules>, client_addr: SocketAddr| {
ws_server_upgrade(server_config.clone(), client_addr, req).map::<anyhow::Result<_>, _>(Ok) move |req: Request<Incoming>| {
} ws_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req)
}; .map::<anyhow::Result<_>, _>(Ok)
}
};
let mk_http_upgrade_fn = |server_config: Arc<WsServerConfig>, client_addr: SocketAddr| { let mk_http_upgrade_fn =
move |req: Request<Incoming>| { |server_config: Arc<WsServerConfig>, restrictions: Arc<RestrictionsRules>, client_addr: SocketAddr| {
http_server_upgrade(server_config.clone(), client_addr, req).map::<anyhow::Result<_>, _>(Ok) move |req: Request<Incoming>| {
} http_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req)
}; .map::<anyhow::Result<_>, _>(Ok)
}
};
let mk_auto_upgrade_fn = |server_config: Arc<WsServerConfig>, client_addr: SocketAddr| { let mk_auto_upgrade_fn = |server_config: Arc<WsServerConfig>,
restrictions: Arc<RestrictionsRules>,
client_addr: SocketAddr| {
move |req: Request<Incoming>| { move |req: Request<Incoming>| {
let server_config = server_config.clone(); let server_config = server_config.clone();
let restrictions = restrictions.clone();
async move { async move {
if fastwebsockets::upgrade::is_upgrade_request(&req) { if fastwebsockets::upgrade::is_upgrade_request(&req) {
ws_server_upgrade(server_config.clone(), client_addr, req) ws_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req)
.map(|response| Ok::<_, anyhow::Error>(response.map(Either::Left))) .map(|response| Ok::<_, anyhow::Error>(response.map(Either::Left)))
.await .await
} else if req.version() == Version::HTTP_2 { } else if req.version() == Version::HTTP_2 {
http_server_upgrade(server_config.clone(), client_addr, req) http_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req)
.map::<anyhow::Result<_>, _>(Ok) .map::<anyhow::Result<_>, _>(Ok)
.await .await
} else { } else {
@ -716,9 +726,21 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
}; };
// Bind server and run forever to serve incoming connections. // Bind server and run forever to serve incoming connections.
let mut restrictions = RestrictionsRulesReloader::new(restrictions, server_config.restriction_config.clone())?;
let listener = TcpListener::bind(&server_config.bind).await?; let listener = TcpListener::bind(&server_config.bind).await?;
loop { loop {
let (stream, peer_addr) = match listener.accept().await { let cnx = select! {
biased;
_ = restrictions.wait_for_reload() => {
restrictions.reload_restrictions_config();
continue;
},
cnx = listener.accept() => { cnx }
};
let (stream, peer_addr) = match cnx {
Ok(ret) => ret, Ok(ret) => ret,
Err(err) => { Err(err) => {
warn!("Error while accepting connection {:?}", err); warn!("Error while accepting connection {:?}", err);
@ -738,6 +760,7 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
info!("Accepting connection"); info!("Accepting connection");
let server_config = server_config.clone(); let server_config = server_config.clone();
let restrictions = restrictions.restrictions_rules().clone();
// Check if we need to enable TLS or not // Check if we need to enable TLS or not
match tls_context.as_mut() { match tls_context.as_mut() {
@ -762,7 +785,7 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
conn_builder.keep_alive_interval(ping); conn_builder.keep_alive_interval(ping);
} }
let http_upgrade_fn = mk_http_upgrade_fn(server_config, peer_addr); let http_upgrade_fn = mk_http_upgrade_fn(server_config, restrictions.clone(), peer_addr);
let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn)); let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn));
if let Err(e) = con_fut.await { if let Err(e) = con_fut.await {
error!("Error while upgrading cnx to http: {:?}", e); error!("Error while upgrading cnx to http: {:?}", e);
@ -770,7 +793,8 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
} }
// websocket // websocket
_ => { _ => {
let websocket_upgrade_fn = mk_websocket_upgrade_fn(server_config, peer_addr); let websocket_upgrade_fn =
mk_websocket_upgrade_fn(server_config, restrictions.clone(), peer_addr);
let conn_fut = http1::Builder::new() let conn_fut = http1::Builder::new()
.serve_connection(tls_stream, service_fn(websocket_upgrade_fn)) .serve_connection(tls_stream, service_fn(websocket_upgrade_fn))
.with_upgrades(); .with_upgrades();
@ -795,7 +819,7 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
conn_fut.http2().keep_alive_interval(ping); conn_fut.http2().keep_alive_interval(ping);
} }
let websocket_upgrade_fn = mk_auto_upgrade_fn(server_config, peer_addr); let websocket_upgrade_fn = mk_auto_upgrade_fn(server_config, restrictions.clone(), peer_addr);
let upgradable = conn_fut.serve_connection_with_upgrades(stream, service_fn(websocket_upgrade_fn)); let upgradable = conn_fut.serve_connection_with_upgrades(stream, service_fn(websocket_upgrade_fn));
if let Err(e) = upgradable.await { if let Err(e) = upgradable.await {