From 5c7bc03e5f4e0633e103350747df595af1eaf65e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Fri, 29 Dec 2023 19:03:40 +0100 Subject: [PATCH] TlsReloader cleanup --- src/tunnel/server.rs | 3 +- src/tunnel/tls_reloader.rs | 170 ++++++++++++++++++++----------------- 2 files changed, 93 insertions(+), 80 deletions(-) diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 6edd476..8fe7b99 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -7,7 +7,6 @@ use std::fmt::Debug; use std::future::Future; use std::ops::{Deref, Not}; use std::pin::Pin; -use std::sync::atomic::Ordering; use std::sync::Arc; use std::time::Duration; @@ -376,7 +375,7 @@ struct TlsContext<'a> { } impl TlsContext<'_> { pub fn tls_acceptor(&mut self) -> &Arc { - if self.tls_reloader.tls_reload_certificate.swap(false, Ordering::Relaxed) { + if self.tls_reloader.should_reload_certificate() { match tls::tls_acceptor(self.tls_config, Some(vec![b"http/1.1".to_vec()])) { Ok(acceptor) => self.tls_acceptor = Arc::new(acceptor), Err(err) => error!("Cannot reload TLS certificate {:?}", err), diff --git a/src/tunnel/tls_reloader.rs b/src/tunnel/tls_reloader.rs index e2e22b9..df18965 100644 --- a/src/tunnel/tls_reloader.rs +++ b/src/tunnel/tls_reloader.rs @@ -11,111 +11,67 @@ use std::thread; use std::time::{Duration, SystemTime}; use tracing::{error, info, warn}; +struct TlsReloaderState { + fs_watcher: Mutex, + tls_reload_certificate: AtomicBool, + server_config: Arc, + cert_path: PathBuf, + key_path: PathBuf, +} pub struct TlsReloader { - fs_watcher: Arc>>, - pub tls_reload_certificate: Arc, + state: Option>, } impl TlsReloader { pub fn new(server_config: Arc) -> anyhow::Result { - let this = Self { - fs_watcher: Arc::new(Mutex::new(None)), - tls_reload_certificate: Arc::new(AtomicBool::new(false)), - }; - // If there is no custom certificate and private key, there is nothing to watch let Some((Some(cert_path), Some(key_path))) = server_config .tls .as_ref() .map(|t| (&t.tls_certificate_path, &t.tls_key_path)) else { - return Ok(this); + return Ok(Self { state: None }); }; + let this = Arc::new(TlsReloaderState { + fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?), + tls_reload_certificate: AtomicBool::new(false), + cert_path: cert_path.to_path_buf(), + key_path: key_path.to_path_buf(), + server_config, + }); + info!("Starting to watch tls certificate and private key for changes to reload them"); - let tls_reload_certificate = this.tls_reload_certificate.clone(); - let watcher = this.fs_watcher.clone(); - let server_config = server_config.clone(); + let mut watcher = notify::recommended_watcher({ + let this = this.clone(); - let mut watcher = notify::recommended_watcher(move |event: notify::Result| { - let event = match event { - Ok(event) => event, - Err(err) => { - error!("Error while watching tls certificate and private key for changes {:?}", err); - return; - } - }; - - if event.kind.is_access() { - return; - } - - let tls = server_config.tls.as_ref().unwrap(); - let cert_path = tls.tls_certificate_path.as_ref().unwrap(); - let key_path = tls.tls_key_path.as_ref().unwrap(); - - if let Some(path) = event.paths.iter().find(|p| p.ends_with(cert_path)) { - match event.kind { - EventKind::Create(_) | EventKind::Modify(_) => match tls::load_certificates_from_pem(cert_path) { - Ok(tls_certs) => { - *tls.tls_certificate.lock() = tls_certs; - tls_reload_certificate.store(true, Ordering::Relaxed); - } - Err(err) => { - warn!("Error while loading TLS certificate {:?}", err); - } - }, - EventKind::Remove(_) => { - warn!("TLS certificate file has been removed, trying to re-set a watch for it"); - Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf()); - } - EventKind::Access(_) | EventKind::Other | EventKind::Any => { - trace!("Ignoring event {:?}", event); - } - } - } - - if let Some(path) = event.paths.iter().find(|p| p.ends_with(key_path)) { - match event.kind { - EventKind::Create(_) | EventKind::Modify(_) => match tls::load_private_key_from_file(key_path) { - Ok(tls_key) => { - *tls.tls_key.lock() = tls_key; - tls_reload_certificate.store(true, Ordering::Relaxed); - } - Err(err) => { - warn!("Error while loading TLS private key {:?}", err); - } - }, - EventKind::Remove(_) => { - warn!("TLS private key file has been removed, trying to re-set a watch for it"); - Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf()); - } - EventKind::Access(_) | EventKind::Other | EventKind::Any => { - trace!("Ignoring event {:?}", event); - } - } - } + move |event: notify::Result| Self::handle_fs_event(&this, event) }) .with_context(|| "Cannot create tls certificate watcher")?; - watcher.watch(cert_path, notify::RecursiveMode::NonRecursive)?; - watcher.watch(key_path, notify::RecursiveMode::NonRecursive)?; - *this.fs_watcher.lock() = Some(watcher); + watcher.watch(&this.cert_path, notify::RecursiveMode::NonRecursive)?; + watcher.watch(&this.key_path, notify::RecursiveMode::NonRecursive)?; + *this.fs_watcher.lock() = watcher; - Ok(this) + Ok(Self { state: Some(this) }) } - fn try_rewatch_certificate(watcher: Arc>>, path: PathBuf) { + pub fn should_reload_certificate(&self) -> bool { + match &self.state { + None => false, + Some(this) => this.tls_reload_certificate.swap(false, Ordering::Relaxed), + } + } + + fn try_rewatch_certificate(this: Arc, path: PathBuf) { thread::spawn(move || { while !path.exists() { warn!("TLS file {:?} does not exist anymore, waiting for it to be created", path); thread::sleep(Duration::from_secs(10)); } - let mut watcher = watcher.lock(); - let _ = watcher.as_mut().unwrap().unwatch(&path); + let mut watcher = this.fs_watcher.lock(); + let _ = watcher.unwatch(&path); let Ok(_) = watcher - .as_mut() - .unwrap() .watch(&path, notify::RecursiveMode::NonRecursive) .map_err(|err| { error!("Cannot re-set a watch for TLS file {:?}: {:?}", path, err); @@ -124,6 +80,7 @@ impl TlsReloader { else { return; }; + drop(watcher); let Ok(file) = File::open(&path) else { return; @@ -134,4 +91,61 @@ impl TlsReloader { }); }); } + + fn handle_fs_event(this: &Arc, event: notify::Result) { + let event = match event { + Ok(event) => event, + Err(err) => { + error!("Error while watching tls certificate and private key for changes {:?}", err); + return; + } + }; + + if event.kind.is_access() { + return; + } + + let tls = this.server_config.tls.as_ref().unwrap(); + if let Some(path) = event.paths.iter().find(|p| p.ends_with(&this.cert_path)) { + match event.kind { + EventKind::Create(_) | EventKind::Modify(_) => match tls::load_certificates_from_pem(&this.cert_path) { + Ok(tls_certs) => { + *tls.tls_certificate.lock() = tls_certs; + this.tls_reload_certificate.store(true, Ordering::Relaxed); + } + Err(err) => { + warn!("Error while loading TLS certificate {:?}", err); + } + }, + EventKind::Remove(_) => { + warn!("TLS certificate file has been removed, trying to re-set a watch for it"); + Self::try_rewatch_certificate(this.clone(), path.to_path_buf()); + } + EventKind::Access(_) | EventKind::Other | EventKind::Any => { + trace!("Ignoring event {:?}", event); + } + } + } + + if let Some(path) = event.paths.iter().find(|p| p.ends_with(&this.key_path)) { + match event.kind { + EventKind::Create(_) | EventKind::Modify(_) => match tls::load_private_key_from_file(&this.key_path) { + Ok(tls_key) => { + *tls.tls_key.lock() = tls_key; + this.tls_reload_certificate.store(true, Ordering::Relaxed); + } + Err(err) => { + warn!("Error while loading TLS private key {:?}", err); + } + }, + EventKind::Remove(_) => { + warn!("TLS private key file has been removed, trying to re-set a watch for it"); + Self::try_rewatch_certificate(this.clone(), path.to_path_buf()); + } + EventKind::Access(_) | EventKind::Other | EventKind::Any => { + trace!("Ignoring event {:?}", event); + } + } + } + } }