TlsReloader cleanup

This commit is contained in:
Σrebe - Romain GERARD 2023-12-29 19:03:40 +01:00
parent 7ad36709bc
commit 5c7bc03e5f
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
2 changed files with 93 additions and 80 deletions

View file

@ -7,7 +7,6 @@ use std::fmt::Debug;
use std::future::Future; use std::future::Future;
use std::ops::{Deref, Not}; use std::ops::{Deref, Not};
use std::pin::Pin; use std::pin::Pin;
use std::sync::atomic::Ordering;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -376,7 +375,7 @@ struct TlsContext<'a> {
} }
impl TlsContext<'_> { impl TlsContext<'_> {
pub fn tls_acceptor(&mut self) -> &Arc<TlsAcceptor> { pub fn tls_acceptor(&mut self) -> &Arc<TlsAcceptor> {
if self.tls_reloader.tls_reload_certificate.swap(false, Ordering::Relaxed) { if self.tls_reloader.should_reload_certificate() {
match tls::tls_acceptor(self.tls_config, Some(vec![b"http/1.1".to_vec()])) { match tls::tls_acceptor(self.tls_config, Some(vec![b"http/1.1".to_vec()])) {
Ok(acceptor) => self.tls_acceptor = Arc::new(acceptor), Ok(acceptor) => self.tls_acceptor = Arc::new(acceptor),
Err(err) => error!("Cannot reload TLS certificate {:?}", err), Err(err) => error!("Cannot reload TLS certificate {:?}", err),

View file

@ -11,33 +11,88 @@ use std::thread;
use std::time::{Duration, SystemTime}; use std::time::{Duration, SystemTime};
use tracing::{error, info, warn}; use tracing::{error, info, warn};
struct TlsReloaderState {
fs_watcher: Mutex<RecommendedWatcher>,
tls_reload_certificate: AtomicBool,
server_config: Arc<WsServerConfig>,
cert_path: PathBuf,
key_path: PathBuf,
}
pub struct TlsReloader { pub struct TlsReloader {
fs_watcher: Arc<Mutex<Option<RecommendedWatcher>>>, state: Option<Arc<TlsReloaderState>>,
pub tls_reload_certificate: Arc<AtomicBool>,
} }
impl TlsReloader { impl TlsReloader {
pub fn new(server_config: Arc<WsServerConfig>) -> anyhow::Result<Self> { pub fn new(server_config: Arc<WsServerConfig>) -> anyhow::Result<Self> {
let this = Self {
fs_watcher: Arc::new(Mutex::new(None)),
tls_reload_certificate: Arc::new(AtomicBool::new(false)),
};
// If there is no custom certificate and private key, there is nothing to watch // If there is no custom certificate and private key, there is nothing to watch
let Some((Some(cert_path), Some(key_path))) = server_config let Some((Some(cert_path), Some(key_path))) = server_config
.tls .tls
.as_ref() .as_ref()
.map(|t| (&t.tls_certificate_path, &t.tls_key_path)) .map(|t| (&t.tls_certificate_path, &t.tls_key_path))
else { else {
return Ok(this); return Ok(Self { state: None });
}; };
info!("Starting to watch tls certificate and private key for changes to reload them"); let this = Arc::new(TlsReloaderState {
let tls_reload_certificate = this.tls_reload_certificate.clone(); fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?),
let watcher = this.fs_watcher.clone(); tls_reload_certificate: AtomicBool::new(false),
let server_config = server_config.clone(); cert_path: cert_path.to_path_buf(),
key_path: key_path.to_path_buf(),
server_config,
});
let mut watcher = notify::recommended_watcher(move |event: notify::Result<notify::Event>| { info!("Starting to watch tls certificate and private key for changes to reload them");
let mut watcher = notify::recommended_watcher({
let this = this.clone();
move |event: notify::Result<notify::Event>| Self::handle_fs_event(&this, event)
})
.with_context(|| "Cannot create tls certificate watcher")?;
watcher.watch(&this.cert_path, notify::RecursiveMode::NonRecursive)?;
watcher.watch(&this.key_path, notify::RecursiveMode::NonRecursive)?;
*this.fs_watcher.lock() = watcher;
Ok(Self { state: Some(this) })
}
pub fn should_reload_certificate(&self) -> bool {
match &self.state {
None => false,
Some(this) => this.tls_reload_certificate.swap(false, Ordering::Relaxed),
}
}
fn try_rewatch_certificate(this: Arc<TlsReloaderState>, path: PathBuf) {
thread::spawn(move || {
while !path.exists() {
warn!("TLS file {:?} does not exist anymore, waiting for it to be created", path);
thread::sleep(Duration::from_secs(10));
}
let mut watcher = this.fs_watcher.lock();
let _ = watcher.unwatch(&path);
let Ok(_) = watcher
.watch(&path, notify::RecursiveMode::NonRecursive)
.map_err(|err| {
error!("Cannot re-set a watch for TLS file {:?}: {:?}", path, err);
error!("TLS certificate will not be auto-reloaded anymore");
})
else {
return;
};
drop(watcher);
let Ok(file) = File::open(&path) else {
return;
};
let _ = file.set_modified(SystemTime::now()).map_err(|err| {
error!("Cannot force reload TLS file {:?}: {:?}", path, err);
error!("Old certificate will be used until the next change");
});
});
}
fn handle_fs_event(this: &Arc<TlsReloaderState>, event: notify::Result<notify::Event>) {
let event = match event { let event = match event {
Ok(event) => event, Ok(event) => event,
Err(err) => { Err(err) => {
@ -50,16 +105,13 @@ impl TlsReloader {
return; return;
} }
let tls = server_config.tls.as_ref().unwrap(); let tls = this.server_config.tls.as_ref().unwrap();
let cert_path = tls.tls_certificate_path.as_ref().unwrap(); if let Some(path) = event.paths.iter().find(|p| p.ends_with(&this.cert_path)) {
let key_path = tls.tls_key_path.as_ref().unwrap();
if let Some(path) = event.paths.iter().find(|p| p.ends_with(cert_path)) {
match event.kind { match event.kind {
EventKind::Create(_) | EventKind::Modify(_) => match tls::load_certificates_from_pem(cert_path) { EventKind::Create(_) | EventKind::Modify(_) => match tls::load_certificates_from_pem(&this.cert_path) {
Ok(tls_certs) => { Ok(tls_certs) => {
*tls.tls_certificate.lock() = tls_certs; *tls.tls_certificate.lock() = tls_certs;
tls_reload_certificate.store(true, Ordering::Relaxed); this.tls_reload_certificate.store(true, Ordering::Relaxed);
} }
Err(err) => { Err(err) => {
warn!("Error while loading TLS certificate {:?}", err); warn!("Error while loading TLS certificate {:?}", err);
@ -67,7 +119,7 @@ impl TlsReloader {
}, },
EventKind::Remove(_) => { EventKind::Remove(_) => {
warn!("TLS certificate file has been removed, trying to re-set a watch for it"); warn!("TLS certificate file has been removed, trying to re-set a watch for it");
Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf()); Self::try_rewatch_certificate(this.clone(), path.to_path_buf());
} }
EventKind::Access(_) | EventKind::Other | EventKind::Any => { EventKind::Access(_) | EventKind::Other | EventKind::Any => {
trace!("Ignoring event {:?}", event); trace!("Ignoring event {:?}", event);
@ -75,12 +127,12 @@ impl TlsReloader {
} }
} }
if let Some(path) = event.paths.iter().find(|p| p.ends_with(key_path)) { if let Some(path) = event.paths.iter().find(|p| p.ends_with(&this.key_path)) {
match event.kind { match event.kind {
EventKind::Create(_) | EventKind::Modify(_) => match tls::load_private_key_from_file(key_path) { EventKind::Create(_) | EventKind::Modify(_) => match tls::load_private_key_from_file(&this.key_path) {
Ok(tls_key) => { Ok(tls_key) => {
*tls.tls_key.lock() = tls_key; *tls.tls_key.lock() = tls_key;
tls_reload_certificate.store(true, Ordering::Relaxed); this.tls_reload_certificate.store(true, Ordering::Relaxed);
} }
Err(err) => { Err(err) => {
warn!("Error while loading TLS private key {:?}", err); warn!("Error while loading TLS private key {:?}", err);
@ -88,50 +140,12 @@ impl TlsReloader {
}, },
EventKind::Remove(_) => { EventKind::Remove(_) => {
warn!("TLS private key file has been removed, trying to re-set a watch for it"); warn!("TLS private key file has been removed, trying to re-set a watch for it");
Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf()); Self::try_rewatch_certificate(this.clone(), path.to_path_buf());
} }
EventKind::Access(_) | EventKind::Other | EventKind::Any => { EventKind::Access(_) | EventKind::Other | EventKind::Any => {
trace!("Ignoring event {:?}", event); trace!("Ignoring event {:?}", event);
} }
} }
} }
})
.with_context(|| "Cannot create tls certificate watcher")?;
watcher.watch(cert_path, notify::RecursiveMode::NonRecursive)?;
watcher.watch(key_path, notify::RecursiveMode::NonRecursive)?;
*this.fs_watcher.lock() = Some(watcher);
Ok(this)
}
fn try_rewatch_certificate(watcher: Arc<Mutex<Option<RecommendedWatcher>>>, path: PathBuf) {
thread::spawn(move || {
while !path.exists() {
warn!("TLS file {:?} does not exist anymore, waiting for it to be created", path);
thread::sleep(Duration::from_secs(10));
}
let mut watcher = watcher.lock();
let _ = watcher.as_mut().unwrap().unwatch(&path);
let Ok(_) = watcher
.as_mut()
.unwrap()
.watch(&path, notify::RecursiveMode::NonRecursive)
.map_err(|err| {
error!("Cannot re-set a watch for TLS file {:?}: {:?}", path, err);
error!("TLS certificate will not be auto-reloaded anymore");
})
else {
return;
};
let Ok(file) = File::open(&path) else {
return;
};
let _ = file.set_modified(SystemTime::now()).map_err(|err| {
error!("Cannot force reload TLS file {:?}: {:?}", path, err);
error!("Old certificate will be used until the next change");
});
});
} }
} }