cleanup tls reloader
This commit is contained in:
parent
0e05469fc7
commit
7ad36709bc
3 changed files with 171 additions and 159 deletions
|
@ -1,6 +1,7 @@
|
||||||
pub mod client;
|
pub mod client;
|
||||||
mod io;
|
mod io;
|
||||||
pub mod server;
|
pub mod server;
|
||||||
|
mod tls_reloader;
|
||||||
|
|
||||||
use crate::dns::DnsResolver;
|
use crate::dns::DnsResolver;
|
||||||
use crate::{tcp, tls, LocalProtocol, LocalToRemote, WsClientConfig};
|
use crate::{tcp, tls, LocalProtocol, LocalToRemote, WsClientConfig};
|
||||||
|
|
|
@ -1,21 +1,18 @@
|
||||||
use ahash::{HashMap, HashMapExt};
|
use ahash::{HashMap, HashMapExt};
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::anyhow;
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
use futures_util::{pin_mut, FutureExt, Stream, StreamExt};
|
use futures_util::{pin_mut, FutureExt, Stream, StreamExt};
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::fs::File;
|
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::ops::{Deref, Not};
|
use std::ops::{Deref, Not};
|
||||||
use std::path::PathBuf;
|
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::Ordering;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::thread;
|
use std::time::Duration;
|
||||||
use std::time::{Duration, SystemTime};
|
|
||||||
|
|
||||||
use super::{JwtTunnelConfig, JWT_DECODE, JWT_HEADER_PREFIX};
|
use super::{JwtTunnelConfig, JWT_DECODE, JWT_HEADER_PREFIX};
|
||||||
use crate::{socks5, tcp, tls, udp, LocalProtocol, WsServerConfig};
|
use crate::{socks5, tcp, tls, udp, LocalProtocol, TlsServerConfig, WsServerConfig};
|
||||||
use hyper::body::Incoming;
|
use hyper::body::Incoming;
|
||||||
use hyper::header::{COOKIE, SEC_WEBSOCKET_PROTOCOL};
|
use hyper::header::{COOKIE, SEC_WEBSOCKET_PROTOCOL};
|
||||||
use hyper::http::HeaderValue;
|
use hyper::http::HeaderValue;
|
||||||
|
@ -23,16 +20,16 @@ use hyper::server::conn::http1;
|
||||||
use hyper::service::service_fn;
|
use hyper::service::service_fn;
|
||||||
use hyper::{http, Request, Response, StatusCode};
|
use hyper::{http, Request, Response, StatusCode};
|
||||||
use jsonwebtoken::TokenData;
|
use jsonwebtoken::TokenData;
|
||||||
use log::trace;
|
|
||||||
use notify::{EventKind, RecommendedWatcher, Watcher};
|
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
|
|
||||||
|
use crate::tunnel::tls_reloader::TlsReloader;
|
||||||
use crate::udp::UdpStream;
|
use crate::udp::UdpStream;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio::select;
|
use tokio::select;
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
|
use tokio_rustls::TlsAcceptor;
|
||||||
use tracing::{error, info, span, warn, Instrument, Level, Span};
|
use tracing::{error, info, span, warn, Instrument, Level, Span};
|
||||||
use url::Host;
|
use url::Host;
|
||||||
|
|
||||||
|
@ -372,6 +369,24 @@ async fn server_upgrade(server_config: Arc<WsServerConfig>, mut req: Request<Inc
|
||||||
Response::from_parts(response.into_parts().0, "".to_string())
|
Response::from_parts(response.into_parts().0, "".to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct TlsContext<'a> {
|
||||||
|
tls_acceptor: Arc<TlsAcceptor>,
|
||||||
|
tls_reloader: TlsReloader,
|
||||||
|
tls_config: &'a TlsServerConfig,
|
||||||
|
}
|
||||||
|
impl TlsContext<'_> {
|
||||||
|
pub fn tls_acceptor(&mut self) -> &Arc<TlsAcceptor> {
|
||||||
|
if self.tls_reloader.tls_reload_certificate.swap(false, Ordering::Relaxed) {
|
||||||
|
match tls::tls_acceptor(self.tls_config, Some(vec![b"http/1.1".to_vec()])) {
|
||||||
|
Ok(acceptor) => self.tls_acceptor = Arc::new(acceptor),
|
||||||
|
Err(err) => error!("Cannot reload TLS certificate {:?}", err),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
&self.tls_acceptor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
|
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
|
||||||
info!("Starting wstunnel server listening on {}", server_config.bind);
|
info!("Starting wstunnel server listening on {}", server_config.bind);
|
||||||
|
|
||||||
|
@ -380,14 +395,13 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
|
||||||
let upgrade_fn = move |req: Request<Incoming>| server_upgrade(config.clone(), req).map::<anyhow::Result<_>, _>(Ok);
|
let upgrade_fn = move |req: Request<Incoming>| server_upgrade(config.clone(), req).map::<anyhow::Result<_>, _>(Ok);
|
||||||
|
|
||||||
// Init TLS if needed
|
// Init TLS if needed
|
||||||
let mut tls_context = if let Some(tls) = &server_config.tls {
|
let mut tls_context = if let Some(tls_config) = &server_config.tls {
|
||||||
let tls_reloader = TlsReloader::new(server_config.clone())?;
|
let tls_context = TlsContext {
|
||||||
Some((
|
tls_acceptor: Arc::new(tls::tls_acceptor(tls_config, Some(vec![b"http/1.1".to_vec()]))?),
|
||||||
Arc::new(tls::tls_acceptor(tls, Some(vec![b"http/1.1".to_vec()]))?),
|
tls_reloader: TlsReloader::new(server_config.clone())?,
|
||||||
tls_reloader.tls_reload_certificate.clone(),
|
tls_config,
|
||||||
tls,
|
};
|
||||||
tls_reloader,
|
Some(tls_context)
|
||||||
))
|
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
@ -416,16 +430,9 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
|
||||||
info!("Accepting connection");
|
info!("Accepting connection");
|
||||||
let upgrade_fn = upgrade_fn.clone();
|
let upgrade_fn = upgrade_fn.clone();
|
||||||
// TLS
|
// TLS
|
||||||
if let Some((tls_acceptor, tls_reload_certificate, tls_config, _)) = tls_context.as_mut() {
|
if let Some(tls) = tls_context.as_mut() {
|
||||||
// Reload TLS certificate if needed
|
// Reload TLS certificate if needed
|
||||||
if tls_reload_certificate.swap(false, Ordering::Relaxed) {
|
let tls_acceptor = tls.tls_acceptor().clone();
|
||||||
match tls::tls_acceptor(tls_config, Some(vec![b"http/1.1".to_vec()])) {
|
|
||||||
Ok(acceptor) => *tls_acceptor = Arc::new(acceptor),
|
|
||||||
Err(err) => error!("Cannot reload TLS certificate {:?}", err),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
let tls_acceptor = tls_acceptor.clone();
|
|
||||||
let fut = async move {
|
let fut = async move {
|
||||||
info!("Doing TLS handshake");
|
info!("Doing TLS handshake");
|
||||||
let tls_stream = match tls_acceptor.accept(stream).await {
|
let tls_stream = match tls_acceptor.accept(stream).await {
|
||||||
|
@ -465,136 +472,3 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TlsReloader {
|
|
||||||
fs_watcher: Arc<Mutex<Option<RecommendedWatcher>>>,
|
|
||||||
tls_reload_certificate: Arc<AtomicBool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TlsReloader {
|
|
||||||
pub fn new(server_config: Arc<WsServerConfig>) -> anyhow::Result<Self> {
|
|
||||||
let this = Self {
|
|
||||||
fs_watcher: Arc::new(Mutex::new(None)),
|
|
||||||
tls_reload_certificate: Arc::new(AtomicBool::new(false)),
|
|
||||||
};
|
|
||||||
|
|
||||||
// If there is no custom certificate and private key, there is nothing to watch
|
|
||||||
let Some((Some(cert_path), Some(key_path))) = server_config
|
|
||||||
.tls
|
|
||||||
.as_ref()
|
|
||||||
.map(|t| (&t.tls_certificate_path, &t.tls_key_path))
|
|
||||||
else {
|
|
||||||
return Ok(this);
|
|
||||||
};
|
|
||||||
|
|
||||||
let span = span!(Level::INFO, "tls_reloader");
|
|
||||||
let _enter = span.enter();
|
|
||||||
info!("Starting to watch tls certificate and private key for changes to reload them");
|
|
||||||
let tls_reload_certificate = this.tls_reload_certificate.clone();
|
|
||||||
let watcher = this.fs_watcher.clone();
|
|
||||||
let server_config = server_config.clone();
|
|
||||||
let span = span.clone();
|
|
||||||
|
|
||||||
let mut watcher = notify::recommended_watcher(move |event: notify::Result<notify::Event>| {
|
|
||||||
let _enter = span.enter();
|
|
||||||
let event = match event {
|
|
||||||
Ok(event) => event,
|
|
||||||
Err(err) => {
|
|
||||||
error!("Error while watching tls certificate and private key for changes {:?}", err);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if event.kind.is_access() {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let tls = server_config.tls.as_ref().unwrap();
|
|
||||||
let cert_path = tls.tls_certificate_path.as_ref().unwrap();
|
|
||||||
let key_path = tls.tls_key_path.as_ref().unwrap();
|
|
||||||
|
|
||||||
if let Some(path) = event.paths.iter().find(|p| p.ends_with(cert_path)) {
|
|
||||||
match event.kind {
|
|
||||||
EventKind::Create(_) | EventKind::Modify(_) => match tls::load_certificates_from_pem(cert_path) {
|
|
||||||
Ok(tls_certs) => {
|
|
||||||
*tls.tls_certificate.lock() = tls_certs;
|
|
||||||
tls_reload_certificate.store(true, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
warn!("Error while loading TLS certificate {:?}", err);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
EventKind::Remove(_) => {
|
|
||||||
warn!("TLS certificate file has been removed, trying to re-set a watch for it");
|
|
||||||
Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf());
|
|
||||||
}
|
|
||||||
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
|
||||||
trace!("Ignoring event {:?}", event);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(path) = event.paths.iter().find(|p| p.ends_with(key_path)) {
|
|
||||||
match event.kind {
|
|
||||||
EventKind::Create(_) | EventKind::Modify(_) => match tls::load_private_key_from_file(key_path) {
|
|
||||||
Ok(tls_key) => {
|
|
||||||
*tls.tls_key.lock() = tls_key;
|
|
||||||
tls_reload_certificate.store(true, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
warn!("Error while loading TLS private key {:?}", err);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
EventKind::Remove(_) => {
|
|
||||||
warn!("TLS private key file has been removed, trying to re-set a watch for it");
|
|
||||||
Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf());
|
|
||||||
}
|
|
||||||
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
|
||||||
trace!("Ignoring event {:?}", event);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.with_context(|| "Cannot create tls certificate watcher")?;
|
|
||||||
|
|
||||||
watcher.watch(cert_path, notify::RecursiveMode::NonRecursive)?;
|
|
||||||
watcher.watch(key_path, notify::RecursiveMode::NonRecursive)?;
|
|
||||||
*this.fs_watcher.lock() = Some(watcher);
|
|
||||||
|
|
||||||
Ok(this)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn try_rewatch_certificate(watcher: Arc<Mutex<Option<RecommendedWatcher>>>, path: PathBuf) {
|
|
||||||
thread::spawn(move || {
|
|
||||||
let span = span!(Level::INFO, "tls_reloader");
|
|
||||||
let _enter = span.enter();
|
|
||||||
|
|
||||||
thread::sleep(Duration::from_secs(10));
|
|
||||||
while !path.exists() {
|
|
||||||
warn!("TLS file {:?} does not exist anymore, waiting for it to be created", path);
|
|
||||||
thread::sleep(Duration::from_secs(10));
|
|
||||||
}
|
|
||||||
let mut watcher = watcher.lock();
|
|
||||||
let _ = watcher.as_mut().unwrap().unwatch(&path);
|
|
||||||
let Ok(_) = watcher
|
|
||||||
.as_mut()
|
|
||||||
.unwrap()
|
|
||||||
.watch(&path, notify::RecursiveMode::NonRecursive)
|
|
||||||
.map_err(|err| {
|
|
||||||
error!("Cannot re-set a watch for TLS file {:?}: {:?}", path, err);
|
|
||||||
error!("TLS certificate will not be auto-reloaded anymore");
|
|
||||||
})
|
|
||||||
else {
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
let Ok(file) = File::open(&path) else {
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
let _ = file.set_modified(SystemTime::now()).map_err(|err| {
|
|
||||||
error!("Cannot force reload TLS file {:?}: {:?}", path, err);
|
|
||||||
error!("Old certificate will be used until the next change");
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
137
src/tunnel/tls_reloader.rs
Normal file
137
src/tunnel/tls_reloader.rs
Normal file
|
@ -0,0 +1,137 @@
|
||||||
|
use crate::{tls, WsServerConfig};
|
||||||
|
use anyhow::Context;
|
||||||
|
use log::trace;
|
||||||
|
use notify::{EventKind, RecommendedWatcher, Watcher};
|
||||||
|
use parking_lot::Mutex;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::thread;
|
||||||
|
use std::time::{Duration, SystemTime};
|
||||||
|
use tracing::{error, info, warn};
|
||||||
|
|
||||||
|
pub struct TlsReloader {
|
||||||
|
fs_watcher: Arc<Mutex<Option<RecommendedWatcher>>>,
|
||||||
|
pub tls_reload_certificate: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TlsReloader {
|
||||||
|
pub fn new(server_config: Arc<WsServerConfig>) -> anyhow::Result<Self> {
|
||||||
|
let this = Self {
|
||||||
|
fs_watcher: Arc::new(Mutex::new(None)),
|
||||||
|
tls_reload_certificate: Arc::new(AtomicBool::new(false)),
|
||||||
|
};
|
||||||
|
|
||||||
|
// If there is no custom certificate and private key, there is nothing to watch
|
||||||
|
let Some((Some(cert_path), Some(key_path))) = server_config
|
||||||
|
.tls
|
||||||
|
.as_ref()
|
||||||
|
.map(|t| (&t.tls_certificate_path, &t.tls_key_path))
|
||||||
|
else {
|
||||||
|
return Ok(this);
|
||||||
|
};
|
||||||
|
|
||||||
|
info!("Starting to watch tls certificate and private key for changes to reload them");
|
||||||
|
let tls_reload_certificate = this.tls_reload_certificate.clone();
|
||||||
|
let watcher = this.fs_watcher.clone();
|
||||||
|
let server_config = server_config.clone();
|
||||||
|
|
||||||
|
let mut watcher = notify::recommended_watcher(move |event: notify::Result<notify::Event>| {
|
||||||
|
let event = match event {
|
||||||
|
Ok(event) => event,
|
||||||
|
Err(err) => {
|
||||||
|
error!("Error while watching tls certificate and private key for changes {:?}", err);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if event.kind.is_access() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let tls = server_config.tls.as_ref().unwrap();
|
||||||
|
let cert_path = tls.tls_certificate_path.as_ref().unwrap();
|
||||||
|
let key_path = tls.tls_key_path.as_ref().unwrap();
|
||||||
|
|
||||||
|
if let Some(path) = event.paths.iter().find(|p| p.ends_with(cert_path)) {
|
||||||
|
match event.kind {
|
||||||
|
EventKind::Create(_) | EventKind::Modify(_) => match tls::load_certificates_from_pem(cert_path) {
|
||||||
|
Ok(tls_certs) => {
|
||||||
|
*tls.tls_certificate.lock() = tls_certs;
|
||||||
|
tls_reload_certificate.store(true, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
warn!("Error while loading TLS certificate {:?}", err);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
EventKind::Remove(_) => {
|
||||||
|
warn!("TLS certificate file has been removed, trying to re-set a watch for it");
|
||||||
|
Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf());
|
||||||
|
}
|
||||||
|
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
||||||
|
trace!("Ignoring event {:?}", event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(path) = event.paths.iter().find(|p| p.ends_with(key_path)) {
|
||||||
|
match event.kind {
|
||||||
|
EventKind::Create(_) | EventKind::Modify(_) => match tls::load_private_key_from_file(key_path) {
|
||||||
|
Ok(tls_key) => {
|
||||||
|
*tls.tls_key.lock() = tls_key;
|
||||||
|
tls_reload_certificate.store(true, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
warn!("Error while loading TLS private key {:?}", err);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
EventKind::Remove(_) => {
|
||||||
|
warn!("TLS private key file has been removed, trying to re-set a watch for it");
|
||||||
|
Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf());
|
||||||
|
}
|
||||||
|
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
||||||
|
trace!("Ignoring event {:?}", event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.with_context(|| "Cannot create tls certificate watcher")?;
|
||||||
|
|
||||||
|
watcher.watch(cert_path, notify::RecursiveMode::NonRecursive)?;
|
||||||
|
watcher.watch(key_path, notify::RecursiveMode::NonRecursive)?;
|
||||||
|
*this.fs_watcher.lock() = Some(watcher);
|
||||||
|
|
||||||
|
Ok(this)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_rewatch_certificate(watcher: Arc<Mutex<Option<RecommendedWatcher>>>, path: PathBuf) {
|
||||||
|
thread::spawn(move || {
|
||||||
|
while !path.exists() {
|
||||||
|
warn!("TLS file {:?} does not exist anymore, waiting for it to be created", path);
|
||||||
|
thread::sleep(Duration::from_secs(10));
|
||||||
|
}
|
||||||
|
let mut watcher = watcher.lock();
|
||||||
|
let _ = watcher.as_mut().unwrap().unwatch(&path);
|
||||||
|
let Ok(_) = watcher
|
||||||
|
.as_mut()
|
||||||
|
.unwrap()
|
||||||
|
.watch(&path, notify::RecursiveMode::NonRecursive)
|
||||||
|
.map_err(|err| {
|
||||||
|
error!("Cannot re-set a watch for TLS file {:?}: {:?}", path, err);
|
||||||
|
error!("TLS certificate will not be auto-reloaded anymore");
|
||||||
|
})
|
||||||
|
else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Ok(file) = File::open(&path) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
let _ = file.set_modified(SystemTime::now()).map_err(|err| {
|
||||||
|
error!("Cannot force reload TLS file {:?}: {:?}", path, err);
|
||||||
|
error!("Old certificate will be used until the next change");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue