Support auto-reload of tls certificate

This commit is contained in:
Σrebe - Romain GERARD 2023-12-29 09:56:47 +01:00
parent c9bc107e3b
commit 640102f82e
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
5 changed files with 316 additions and 18 deletions

View file

@ -13,6 +13,7 @@ use futures_util::{stream, TryStreamExt};
use hickory_resolver::config::{NameServerConfig, ResolverConfig, ResolverOpts};
use hyper::header::HOST;
use hyper::http::{HeaderName, HeaderValue};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use std::fmt::{Debug, Formatter};
@ -217,10 +218,12 @@ struct Server {
restrict_http_upgrade_path_prefix: Option<Vec<String>>,
/// [Optional] Use custom certificate (.crt) instead of the default embedded self signed certificate.
/// The certificate will be automatically reloaded if it changes
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
tls_certificate: Option<PathBuf>,
/// [Optional] Use a custom tls key (.key) that the server will use instead of the default embedded one
/// The private key will be automatically reloaded if it changes
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
tls_private_key: Option<PathBuf>,
}
@ -481,13 +484,14 @@ pub struct TlsClientConfig {
pub tls_verify_certificate: bool,
}
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct TlsServerConfig {
pub tls_certificate: Vec<Certificate>,
pub tls_key: PrivateKey,
pub tls_certificate: Mutex<Vec<Certificate>>,
pub tls_key: Mutex<PrivateKey>,
pub tls_certificate_path: Option<PathBuf>,
pub tls_key_path: Option<PathBuf>,
}
#[derive(Clone)]
pub struct WsServerConfig {
pub socket_so_mark: Option<u32>,
pub bind: SocketAddr,
@ -814,20 +818,23 @@ async fn main() {
}
Commands::Server(args) => {
let tls_config = if args.remote_addr.scheme() == "wss" {
let tls_certificate = if let Some(cert_path) = args.tls_certificate {
tls::load_certificates_from_pem(&cert_path).expect("Cannot load tls certificate")
let tls_certificate = if let Some(cert_path) = &args.tls_certificate {
tls::load_certificates_from_pem(cert_path).expect("Cannot load tls certificate")
} else {
embedded_certificate::TLS_CERTIFICATE.clone()
};
let tls_key = if let Some(key_path) = args.tls_private_key {
tls::load_private_key_from_file(&key_path).expect("Cannot load tls private key")
let tls_key = if let Some(key_path) = &args.tls_private_key {
tls::load_private_key_from_file(key_path).expect("Cannot load tls private key")
} else {
embedded_certificate::TLS_PRIVATE_KEY.clone()
};
Some(TlsServerConfig {
tls_certificate,
tls_key,
tls_certificate: Mutex::new(tls_certificate),
tls_key: Mutex::new(tls_key),
tls_certificate_path: args.tls_certificate,
tls_key_path: args.tls_private_key,
})
} else {
None

View file

@ -92,7 +92,7 @@ pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8
let mut config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(tls_cfg.tls_certificate.clone(), tls_cfg.tls_key.clone())
.with_single_cert(tls_cfg.tls_certificate.lock().clone(), tls_cfg.tls_key.lock().clone())
.with_context(|| "invalid tls certificate or private key")?;
if let Some(alpn_protocols) = alpn_protocols {

View file

@ -1,14 +1,18 @@
use ahash::{HashMap, HashMapExt};
use anyhow::anyhow;
use anyhow::{anyhow, Context};
use base64::Engine;
use futures_util::{pin_mut, FutureExt, Stream, StreamExt};
use std::cmp::min;
use std::fmt::Debug;
use std::fs::File;
use std::future::Future;
use std::ops::{Deref, Not};
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use std::thread;
use std::time::{Duration, SystemTime};
use super::{JwtTunnelConfig, JWT_DECODE, JWT_HEADER_PREFIX};
use crate::{socks5, tcp, tls, udp, LocalProtocol, WsServerConfig};
@ -19,6 +23,8 @@ use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{http, Request, Response, StatusCode};
use jsonwebtoken::TokenData;
use log::trace;
use notify::{EventKind, RecommendedWatcher, Watcher};
use once_cell::sync::Lazy;
use parking_lot::Mutex;
@ -369,18 +375,33 @@ async fn server_upgrade(server_config: Arc<WsServerConfig>, mut req: Request<Inc
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
info!("Starting wstunnel server listening on {}", server_config.bind);
// setup upgrade request handler
let config = server_config.clone();
let upgrade_fn = move |req: Request<Incoming>| server_upgrade(config.clone(), req).map::<anyhow::Result<_>, _>(Ok);
let listener = TcpListener::bind(&server_config.bind).await?;
let tls_acceptor = if let Some(tls) = &server_config.tls {
Some(tls::tls_acceptor(tls, Some(vec![b"http/1.1".to_vec()]))?)
// Init TLS if needed
let mut tls_context = if let Some(tls) = &server_config.tls {
let tls_reloader = TlsReloader::new(server_config.clone())?;
Some((
Arc::new(tls::tls_acceptor(tls, Some(vec![b"http/1.1".to_vec()]))?),
tls_reloader.tls_reload_certificate.clone(),
tls,
tls_reloader,
))
} else {
None
};
// Bind server and run forever to serve incoming connections.
let listener = TcpListener::bind(&server_config.bind).await?;
loop {
let (stream, peer_addr) = listener.accept().await?;
let (stream, peer_addr) = match listener.accept().await {
Ok(ret) => ret,
Err(err) => {
warn!("Error while accepting connection {:?}", err);
continue;
}
};
let _ = stream.set_nodelay(true);
let span = span!(
@ -395,7 +416,15 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
info!("Accepting connection");
let upgrade_fn = upgrade_fn.clone();
// TLS
if let Some(tls_acceptor) = &tls_acceptor {
if let Some((tls_acceptor, tls_reload_certificate, tls_config, _)) = tls_context.as_mut() {
// Reload TLS certificate if needed
if tls_reload_certificate.swap(false, Ordering::Relaxed) {
match tls::tls_acceptor(tls_config, Some(vec![b"http/1.1".to_vec()])) {
Ok(acceptor) => *tls_acceptor = Arc::new(acceptor),
Err(err) => error!("Cannot reload TLS certificate {:?}", err),
};
}
let tls_acceptor = tls_acceptor.clone();
let fut = async move {
info!("Doing TLS handshake");
@ -436,3 +465,136 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
};
}
}
struct TlsReloader {
fs_watcher: Arc<Mutex<Option<RecommendedWatcher>>>,
tls_reload_certificate: Arc<AtomicBool>,
}
impl TlsReloader {
pub fn new(server_config: Arc<WsServerConfig>) -> anyhow::Result<Self> {
let this = Self {
fs_watcher: Arc::new(Mutex::new(None)),
tls_reload_certificate: Arc::new(AtomicBool::new(false)),
};
// If there is no custom certificate and private key, there is nothing to watch
let Some((Some(cert_path), Some(key_path))) = server_config
.tls
.as_ref()
.map(|t| (&t.tls_certificate_path, &t.tls_key_path))
else {
return Ok(this);
};
let span = span!(Level::INFO, "tls_reloader");
let _enter = span.enter();
info!("Starting to watch tls certificate and private key for changes to reload them");
let tls_reload_certificate = this.tls_reload_certificate.clone();
let watcher = this.fs_watcher.clone();
let server_config = server_config.clone();
let span = span.clone();
let mut watcher = notify::recommended_watcher(move |event: notify::Result<notify::Event>| {
let _enter = span.enter();
let event = match event {
Ok(event) => event,
Err(err) => {
error!("Error while watching tls certificate and private key for changes {:?}", err);
return;
}
};
if event.kind.is_access() {
return;
}
let tls = server_config.tls.as_ref().unwrap();
let cert_path = tls.tls_certificate_path.as_ref().unwrap();
let key_path = tls.tls_key_path.as_ref().unwrap();
if let Some(path) = event.paths.iter().find(|p| p.ends_with(cert_path)) {
match event.kind {
EventKind::Create(_) | EventKind::Modify(_) => match tls::load_certificates_from_pem(cert_path) {
Ok(tls_certs) => {
*tls.tls_certificate.lock() = tls_certs;
tls_reload_certificate.store(true, Ordering::Relaxed);
}
Err(err) => {
warn!("Error while loading TLS certificate {:?}", err);
}
},
EventKind::Remove(_) => {
warn!("TLS certificate file has been removed, trying to re-set a watch for it");
Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf());
}
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
trace!("Ignoring event {:?}", event);
}
}
}
if let Some(path) = event.paths.iter().find(|p| p.ends_with(key_path)) {
match event.kind {
EventKind::Create(_) | EventKind::Modify(_) => match tls::load_private_key_from_file(key_path) {
Ok(tls_key) => {
*tls.tls_key.lock() = tls_key;
tls_reload_certificate.store(true, Ordering::Relaxed);
}
Err(err) => {
warn!("Error while loading TLS private key {:?}", err);
}
},
EventKind::Remove(_) => {
warn!("TLS private key file has been removed, trying to re-set a watch for it");
Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf());
}
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
trace!("Ignoring event {:?}", event);
}
}
}
})
.with_context(|| "Cannot create tls certificate watcher")?;
watcher.watch(cert_path, notify::RecursiveMode::NonRecursive)?;
watcher.watch(key_path, notify::RecursiveMode::NonRecursive)?;
*this.fs_watcher.lock() = Some(watcher);
Ok(this)
}
fn try_rewatch_certificate(watcher: Arc<Mutex<Option<RecommendedWatcher>>>, path: PathBuf) {
thread::spawn(move || {
let span = span!(Level::INFO, "tls_reloader");
let _enter = span.enter();
thread::sleep(Duration::from_secs(10));
while !path.exists() {
warn!("TLS file {:?} does not exist anymore, waiting for it to be created", path);
thread::sleep(Duration::from_secs(10));
}
let mut watcher = watcher.lock();
let _ = watcher.as_mut().unwrap().unwatch(&path);
let Ok(_) = watcher
.as_mut()
.unwrap()
.watch(&path, notify::RecursiveMode::NonRecursive)
.map_err(|err| {
error!("Cannot re-set a watch for TLS file {:?}: {:?}", path, err);
error!("TLS certificate will not be auto-reloaded anymore");
})
else {
return;
};
let Ok(file) = File::open(&path) else {
return;
};
let _ = file.set_modified(SystemTime::now()).map_err(|err| {
error!("Cannot force reload TLS file {:?}: {:?}", path, err);
error!("Old certificate will be used until the next change");
});
});
}
}