diff --git a/src/main.rs b/src/main.rs index b0cc1a4..80fda7a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -628,7 +628,7 @@ pub struct WsServerConfig { pub websocket_mask_frame: bool, pub tls: Option, pub dns_resolver: DnsResolver, - pub restrictions: RestrictionsRules, + pub restriction_config: Option, } impl Debug for WsServerConfig { @@ -639,6 +639,7 @@ impl Debug for WsServerConfig { .field("websocket_ping_frequency", &self.websocket_ping_frequency) .field("timeout_connect", &self.timeout_connect) .field("websocket_mask_frame", &self.websocket_mask_frame) + .field("restriction_config", &self.restriction_config) .field("tls", &self.tls.is_some()) .field( "mTLS", @@ -1270,7 +1271,10 @@ async fn main() { .iter() .map(|x| { let (host, port) = x.rsplit_once(':').expect("Invalid restrict-to format"); - (host.to_string(), port.parse::().expect("Invalid restrict-to port format")) + ( + host.trim_matches(&['[', ']']).to_string(), + port.parse::().expect("Invalid restrict-to port format"), + ) }) .collect(); @@ -1281,7 +1285,6 @@ async fn main() { .expect("Cannot convert restriction rules from path-prefix and restric-to"); restriction_cfg }; - debug!("Restriction rules: {:?}", restrictions); let server_config = WsServerConfig { socket_so_mark: args.socket_so_mark, @@ -1291,7 +1294,7 @@ async fn main() { websocket_mask_frame: args.websocket_mask_frame, tls: tls_config, dns_resolver, - restrictions, + restriction_config: args.restrict_config, }; info!( @@ -1299,7 +1302,8 @@ async fn main() { env!("CARGO_PKG_VERSION"), server_config ); - tunnel::server::run_server(Arc::new(server_config)) + debug!("Restriction rules: {:#?}", restrictions); + tunnel::server::run_server(Arc::new(server_config), restrictions) .await .unwrap_or_else(|err| { panic!("Cannot start wstunnel server: {:?}", err); diff --git a/src/restrictions/config_reloader.rs b/src/restrictions/config_reloader.rs new file mode 100644 index 0000000..95ca2b6 --- /dev/null +++ b/src/restrictions/config_reloader.rs @@ -0,0 +1,174 @@ +use super::types::RestrictionsRules; +use crate::restrictions::config_reloader::RestrictionsRulesReloaderState::{Config, Static}; +use anyhow::Context; +use log::trace; +use notify::{EventKind, RecommendedWatcher, Watcher}; +use parking_lot::Mutex; +use std::path::PathBuf; +use std::sync::Arc; +use std::thread; +use std::time::Duration; +use tokio::sync::futures::Notified; +use tokio::sync::Notify; +use tracing::{error, info, warn}; + +struct ConfigReloaderState { + fs_watcher: Mutex, + config_path: PathBuf, + should_reload_config: Notify, +} + +enum RestrictionsRulesReloaderState { + Static(Notify), + Config(Arc), +} + +impl RestrictionsRulesReloaderState { + fn fs_watcher(&self) -> &Mutex { + match self { + Static(_) => unreachable!(), + Config(this) => &this.fs_watcher, + } + } +} + +pub struct RestrictionsRulesReloader { + state: RestrictionsRulesReloaderState, + restrictions: Arc, +} + +impl RestrictionsRulesReloader { + pub fn new(restrictions_rules: RestrictionsRules, config_path: Option) -> anyhow::Result { + // If there is no custom certificate and private key, there is nothing to watch + let config_path = if let Some(config_path) = config_path { + config_path + } else { + return Ok(Self { + state: Static(Notify::new()), + restrictions: Arc::new(restrictions_rules), + }); + }; + + let this = Arc::new(ConfigReloaderState { + fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?), + should_reload_config: Notify::new(), + config_path, + }); + + info!("Starting to watch restriction config file for changes to reload them"); + let mut watcher = notify::recommended_watcher({ + let this = Config(this.clone()); + + move |event: notify::Result| Self::handle_config_fs_event(&this, event) + }) + .with_context(|| "Cannot create restriction config watcher")?; + + watcher.watch(&this.config_path, notify::RecursiveMode::NonRecursive)?; + *this.fs_watcher.lock() = watcher; + + Ok(Self { + state: Config(this), + restrictions: Arc::new(restrictions_rules), + }) + } + + pub fn reload_restrictions_config(&mut self) { + let restrictions = match &self.state { + Static(_) => return, + Config(st) => match RestrictionsRules::from_config_file(&st.config_path) { + Ok(restrictions) => { + info!("Restrictions config file has been reloaded"); + restrictions + } + Err(err) => { + error!("Cannot reload restrictions config file, keeping the old one. Error: {:?}", err); + return; + } + }, + }; + + self.restrictions = Arc::new(restrictions); + } + + pub fn restrictions_rules(&self) -> &Arc { + &self.restrictions + } + + pub fn wait_for_reload(&self) -> Notified { + match &self.state { + Static(st) => st.notified(), + Config(st) => st.should_reload_config.notified(), + } + } + + fn try_rewatch_config(this: RestrictionsRulesReloaderState, path: PathBuf) { + thread::spawn(move || { + while !path.exists() { + warn!( + "Restrictions config file {:?} does not exist anymore, waiting for it to be created", + path + ); + thread::sleep(Duration::from_secs(10)); + } + let mut watcher = this.fs_watcher().lock(); + let _ = watcher.unwatch(&path); + let Ok(_) = watcher + .watch(&path, notify::RecursiveMode::NonRecursive) + .map_err(|err| { + error!("Cannot re-set a watch for Restriction config file {:?}: {:?}", path, err); + error!("Restriction config file will not be auto-reloaded anymore"); + }) + else { + return; + }; + drop(watcher); + + // Generate a fake event to force-reload the certificate + let event = notify::Event { + kind: EventKind::Create(notify::event::CreateKind::Any), + paths: vec![path], + attrs: Default::default(), + }; + + match &this { + Static(_) => Self::handle_config_fs_event(&this, Ok(event)), + Config(_) => Self::handle_config_fs_event(&this, Ok(event)), + } + }); + } + + fn handle_config_fs_event(this: &RestrictionsRulesReloaderState, event: notify::Result) { + let this = match this { + Static(_) => return, + Config(st) => st, + }; + + let event = match event { + Ok(event) => event, + Err(err) => { + error!("Error while watching restrictions config file for changes {:?}", err); + return; + } + }; + + if event.kind.is_access() { + return; + } + + trace!("Received event: {:#?}", event); + if let Some(path) = event.paths.iter().find(|p| p.ends_with(&this.config_path)) { + match event.kind { + EventKind::Create(_) | EventKind::Modify(_) => { + this.should_reload_config.notify_one(); + } + EventKind::Remove(_) => { + warn!("Restriction config file has been removed, trying to re-set a watch for it"); + Self::try_rewatch_config(Config(this.clone()), path.to_path_buf()); + } + EventKind::Access(_) | EventKind::Other | EventKind::Any => { + trace!("Ignoring event {:?}", event); + } + } + } + } +} diff --git a/src/restrictions/mod.rs b/src/restrictions/mod.rs index e19edf0..cf09ca3 100644 --- a/src/restrictions/mod.rs +++ b/src/restrictions/mod.rs @@ -13,6 +13,7 @@ use types::RestrictionsRules; use crate::restrictions::types::{default_cidr, default_host}; +pub mod config_reloader; pub mod types; impl RestrictionsRules { diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index fd382ea..28424e1 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -26,6 +26,7 @@ use jsonwebtoken::TokenData; use once_cell::sync::Lazy; use parking_lot::Mutex; +use crate::restrictions::config_reloader::RestrictionsRulesReloader; use crate::restrictions::types::{ AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol, }; @@ -418,6 +419,7 @@ fn validate_tunnel<'a>( async fn ws_server_upgrade( server_config: Arc, + restrictions: Arc, mut client_addr: SocketAddr, mut req: Request, ) -> Response { @@ -463,7 +465,7 @@ async fn ws_server_upgrade( } }; - match validate_tunnel(&remote, path_prefix, &server_config.restrictions) { + match validate_tunnel(&remote, path_prefix, &restrictions) { Ok(matched_restriction) => { info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); } @@ -542,6 +544,7 @@ async fn ws_server_upgrade( async fn http_server_upgrade( server_config: Arc, + restrictions: Arc, mut client_addr: SocketAddr, mut req: Request, ) -> Response>> { @@ -578,7 +581,7 @@ async fn http_server_upgrade( } }; - match validate_tunnel(&remote, path_prefix, &server_config.restrictions) { + match validate_tunnel(&remote, path_prefix, &restrictions) { Ok(matched_restriction) => { info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); } @@ -664,32 +667,39 @@ impl TlsContext<'_> { } } -pub async fn run_server(server_config: Arc) -> anyhow::Result<()> { +pub async fn run_server(server_config: Arc, restrictions: RestrictionsRules) -> anyhow::Result<()> { info!("Starting wstunnel server listening on {}", server_config.bind); // setup upgrade request handler - let mk_websocket_upgrade_fn = |server_config: Arc, client_addr: SocketAddr| { - move |req: Request| { - ws_server_upgrade(server_config.clone(), client_addr, req).map::, _>(Ok) - } - }; + let mk_websocket_upgrade_fn = + |server_config: Arc, restrictions: Arc, client_addr: SocketAddr| { + move |req: Request| { + ws_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req) + .map::, _>(Ok) + } + }; - let mk_http_upgrade_fn = |server_config: Arc, client_addr: SocketAddr| { - move |req: Request| { - http_server_upgrade(server_config.clone(), client_addr, req).map::, _>(Ok) - } - }; + let mk_http_upgrade_fn = + |server_config: Arc, restrictions: Arc, client_addr: SocketAddr| { + move |req: Request| { + http_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req) + .map::, _>(Ok) + } + }; - let mk_auto_upgrade_fn = |server_config: Arc, client_addr: SocketAddr| { + let mk_auto_upgrade_fn = |server_config: Arc, + restrictions: Arc, + client_addr: SocketAddr| { move |req: Request| { let server_config = server_config.clone(); + let restrictions = restrictions.clone(); async move { if fastwebsockets::upgrade::is_upgrade_request(&req) { - ws_server_upgrade(server_config.clone(), client_addr, req) + ws_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req) .map(|response| Ok::<_, anyhow::Error>(response.map(Either::Left))) .await } else if req.version() == Version::HTTP_2 { - http_server_upgrade(server_config.clone(), client_addr, req) + http_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req) .map::, _>(Ok) .await } else { @@ -716,9 +726,21 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() }; // Bind server and run forever to serve incoming connections. + let mut restrictions = RestrictionsRulesReloader::new(restrictions, server_config.restriction_config.clone())?; let listener = TcpListener::bind(&server_config.bind).await?; loop { - let (stream, peer_addr) = match listener.accept().await { + let cnx = select! { + biased; + + _ = restrictions.wait_for_reload() => { + restrictions.reload_restrictions_config(); + continue; + }, + + cnx = listener.accept() => { cnx } + }; + + let (stream, peer_addr) = match cnx { Ok(ret) => ret, Err(err) => { warn!("Error while accepting connection {:?}", err); @@ -738,6 +760,7 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() info!("Accepting connection"); let server_config = server_config.clone(); + let restrictions = restrictions.restrictions_rules().clone(); // Check if we need to enable TLS or not match tls_context.as_mut() { @@ -762,7 +785,7 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() conn_builder.keep_alive_interval(ping); } - let http_upgrade_fn = mk_http_upgrade_fn(server_config, peer_addr); + let http_upgrade_fn = mk_http_upgrade_fn(server_config, restrictions.clone(), peer_addr); let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn)); if let Err(e) = con_fut.await { error!("Error while upgrading cnx to http: {:?}", e); @@ -770,7 +793,8 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() } // websocket _ => { - let websocket_upgrade_fn = mk_websocket_upgrade_fn(server_config, peer_addr); + let websocket_upgrade_fn = + mk_websocket_upgrade_fn(server_config, restrictions.clone(), peer_addr); let conn_fut = http1::Builder::new() .serve_connection(tls_stream, service_fn(websocket_upgrade_fn)) .with_upgrades(); @@ -795,7 +819,7 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() conn_fut.http2().keep_alive_interval(ping); } - let websocket_upgrade_fn = mk_auto_upgrade_fn(server_config, peer_addr); + let websocket_upgrade_fn = mk_auto_upgrade_fn(server_config, restrictions.clone(), peer_addr); let upgradable = conn_fut.serve_connection_with_upgrades(stream, service_fn(websocket_upgrade_fn)); if let Err(e) = upgradable.await {