Support auto-reload of tls certificate
This commit is contained in:
parent
c9bc107e3b
commit
640102f82e
5 changed files with 316 additions and 18 deletions
128
Cargo.lock
generated
128
Cargo.lock
generated
|
@ -279,6 +279,25 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-channel"
|
||||
version = "0.5.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "82a9b73a36529d9c47029b9fb3a6f0ea3cc916a261195352ba19e770fc1748b2"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3a430a770ebd84726f584a90ee7f020d28db52c6d02138900f22341f866d39c"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossterm"
|
||||
version = "0.27.0"
|
||||
|
@ -426,6 +445,18 @@ dependencies = [
|
|||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "filetime"
|
||||
version = "0.2.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1ee447700ac8aa0b2f2bd7bc4462ad686ba06baa6727ac149a2d6277f0d240fd"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
|
@ -441,6 +472,15 @@ dependencies = [
|
|||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fsevent-sys"
|
||||
version = "4.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "76ee7a02da4d231650c7cea31349b889be2f45ddb3ef3032d2ec8185f6313fd2"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.29"
|
||||
|
@ -807,6 +847,26 @@ dependencies = [
|
|||
"hashbrown",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "inotify"
|
||||
version = "0.9.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8069d3ec154eb856955c1c0fbffefbf5f3c40a104ec912d4797314c1801abff"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"inotify-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "inotify-sys"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ipconfig"
|
||||
version = "0.3.2"
|
||||
|
@ -853,6 +913,26 @@ dependencies = [
|
|||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "kqueue"
|
||||
version = "1.0.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7447f1ca1b7b563588a205fe93dea8df60fd981423a768bc1c0ded35ed147d0c"
|
||||
dependencies = [
|
||||
"kqueue-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "kqueue-sys"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed9625ffda8729b85e45cf04090035ac368927b8cebc34898e7c120f52e4838b"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
|
@ -959,6 +1039,25 @@ dependencies = [
|
|||
"memoffset",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "notify"
|
||||
version = "6.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d"
|
||||
dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"crossbeam-channel",
|
||||
"filetime",
|
||||
"fsevent-sys",
|
||||
"inotify",
|
||||
"kqueue",
|
||||
"libc",
|
||||
"log",
|
||||
"mio",
|
||||
"walkdir",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nu-ansi-term"
|
||||
version = "0.46.0"
|
||||
|
@ -1291,6 +1390,15 @@ version = "1.0.16"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c"
|
||||
|
||||
[[package]]
|
||||
name = "same-file"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.22"
|
||||
|
@ -1874,6 +1982,16 @@ version = "0.9.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
|
||||
[[package]]
|
||||
name = "walkdir"
|
||||
version = "2.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee"
|
||||
dependencies = [
|
||||
"same-file",
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "want"
|
||||
version = "0.3.1"
|
||||
|
@ -1965,6 +2083,15 @@ version = "0.4.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-util"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
|
@ -2135,6 +2262,7 @@ dependencies = [
|
|||
"jsonwebtoken",
|
||||
"log",
|
||||
"nix",
|
||||
"notify",
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"pin-project",
|
||||
|
|
|
@ -28,6 +28,7 @@ nix = { version = "0.27.1", features = ["socket", "net", "uio"] }
|
|||
once_cell = { version = "1.19.0", features = [] }
|
||||
parking_lot = "0.12.1"
|
||||
pin-project = "1"
|
||||
notify = { version = "6.1.1", features = [] }
|
||||
|
||||
rustls-native-certs = { version = "0.7.0", features = [] }
|
||||
rustls-pemfile = { version = "2.0.0", features = [] }
|
||||
|
|
27
src/main.rs
27
src/main.rs
|
@ -13,6 +13,7 @@ use futures_util::{stream, TryStreamExt};
|
|||
use hickory_resolver::config::{NameServerConfig, ResolverConfig, ResolverOpts};
|
||||
use hyper::header::HOST;
|
||||
use hyper::http::{HeaderName, HeaderValue};
|
||||
use parking_lot::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::fmt::{Debug, Formatter};
|
||||
|
@ -217,10 +218,12 @@ struct Server {
|
|||
restrict_http_upgrade_path_prefix: Option<Vec<String>>,
|
||||
|
||||
/// [Optional] Use custom certificate (.crt) instead of the default embedded self signed certificate.
|
||||
/// The certificate will be automatically reloaded if it changes
|
||||
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
|
||||
tls_certificate: Option<PathBuf>,
|
||||
|
||||
/// [Optional] Use a custom tls key (.key) that the server will use instead of the default embedded one
|
||||
/// The private key will be automatically reloaded if it changes
|
||||
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
|
||||
tls_private_key: Option<PathBuf>,
|
||||
}
|
||||
|
@ -481,13 +484,14 @@ pub struct TlsClientConfig {
|
|||
pub tls_verify_certificate: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Debug)]
|
||||
pub struct TlsServerConfig {
|
||||
pub tls_certificate: Vec<Certificate>,
|
||||
pub tls_key: PrivateKey,
|
||||
pub tls_certificate: Mutex<Vec<Certificate>>,
|
||||
pub tls_key: Mutex<PrivateKey>,
|
||||
pub tls_certificate_path: Option<PathBuf>,
|
||||
pub tls_key_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WsServerConfig {
|
||||
pub socket_so_mark: Option<u32>,
|
||||
pub bind: SocketAddr,
|
||||
|
@ -814,20 +818,23 @@ async fn main() {
|
|||
}
|
||||
Commands::Server(args) => {
|
||||
let tls_config = if args.remote_addr.scheme() == "wss" {
|
||||
let tls_certificate = if let Some(cert_path) = args.tls_certificate {
|
||||
tls::load_certificates_from_pem(&cert_path).expect("Cannot load tls certificate")
|
||||
let tls_certificate = if let Some(cert_path) = &args.tls_certificate {
|
||||
tls::load_certificates_from_pem(cert_path).expect("Cannot load tls certificate")
|
||||
} else {
|
||||
embedded_certificate::TLS_CERTIFICATE.clone()
|
||||
};
|
||||
|
||||
let tls_key = if let Some(key_path) = args.tls_private_key {
|
||||
tls::load_private_key_from_file(&key_path).expect("Cannot load tls private key")
|
||||
let tls_key = if let Some(key_path) = &args.tls_private_key {
|
||||
tls::load_private_key_from_file(key_path).expect("Cannot load tls private key")
|
||||
} else {
|
||||
embedded_certificate::TLS_PRIVATE_KEY.clone()
|
||||
};
|
||||
|
||||
Some(TlsServerConfig {
|
||||
tls_certificate,
|
||||
tls_key,
|
||||
tls_certificate: Mutex::new(tls_certificate),
|
||||
tls_key: Mutex::new(tls_key),
|
||||
tls_certificate_path: args.tls_certificate,
|
||||
tls_key_path: args.tls_private_key,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
|
|
|
@ -92,7 +92,7 @@ pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8
|
|||
let mut config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(tls_cfg.tls_certificate.clone(), tls_cfg.tls_key.clone())
|
||||
.with_single_cert(tls_cfg.tls_certificate.lock().clone(), tls_cfg.tls_key.lock().clone())
|
||||
.with_context(|| "invalid tls certificate or private key")?;
|
||||
|
||||
if let Some(alpn_protocols) = alpn_protocols {
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
use ahash::{HashMap, HashMapExt};
|
||||
use anyhow::anyhow;
|
||||
use anyhow::{anyhow, Context};
|
||||
use base64::Engine;
|
||||
use futures_util::{pin_mut, FutureExt, Stream, StreamExt};
|
||||
use std::cmp::min;
|
||||
use std::fmt::Debug;
|
||||
use std::fs::File;
|
||||
use std::future::Future;
|
||||
use std::ops::{Deref, Not};
|
||||
use std::path::PathBuf;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::thread;
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use super::{JwtTunnelConfig, JWT_DECODE, JWT_HEADER_PREFIX};
|
||||
use crate::{socks5, tcp, tls, udp, LocalProtocol, WsServerConfig};
|
||||
|
@ -19,6 +23,8 @@ use hyper::server::conn::http1;
|
|||
use hyper::service::service_fn;
|
||||
use hyper::{http, Request, Response, StatusCode};
|
||||
use jsonwebtoken::TokenData;
|
||||
use log::trace;
|
||||
use notify::{EventKind, RecommendedWatcher, Watcher};
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::Mutex;
|
||||
|
||||
|
@ -369,18 +375,33 @@ async fn server_upgrade(server_config: Arc<WsServerConfig>, mut req: Request<Inc
|
|||
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
|
||||
info!("Starting wstunnel server listening on {}", server_config.bind);
|
||||
|
||||
// setup upgrade request handler
|
||||
let config = server_config.clone();
|
||||
let upgrade_fn = move |req: Request<Incoming>| server_upgrade(config.clone(), req).map::<anyhow::Result<_>, _>(Ok);
|
||||
|
||||
let listener = TcpListener::bind(&server_config.bind).await?;
|
||||
let tls_acceptor = if let Some(tls) = &server_config.tls {
|
||||
Some(tls::tls_acceptor(tls, Some(vec![b"http/1.1".to_vec()]))?)
|
||||
// Init TLS if needed
|
||||
let mut tls_context = if let Some(tls) = &server_config.tls {
|
||||
let tls_reloader = TlsReloader::new(server_config.clone())?;
|
||||
Some((
|
||||
Arc::new(tls::tls_acceptor(tls, Some(vec![b"http/1.1".to_vec()]))?),
|
||||
tls_reloader.tls_reload_certificate.clone(),
|
||||
tls,
|
||||
tls_reloader,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Bind server and run forever to serve incoming connections.
|
||||
let listener = TcpListener::bind(&server_config.bind).await?;
|
||||
loop {
|
||||
let (stream, peer_addr) = listener.accept().await?;
|
||||
let (stream, peer_addr) = match listener.accept().await {
|
||||
Ok(ret) => ret,
|
||||
Err(err) => {
|
||||
warn!("Error while accepting connection {:?}", err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let _ = stream.set_nodelay(true);
|
||||
|
||||
let span = span!(
|
||||
|
@ -395,7 +416,15 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
|
|||
info!("Accepting connection");
|
||||
let upgrade_fn = upgrade_fn.clone();
|
||||
// TLS
|
||||
if let Some(tls_acceptor) = &tls_acceptor {
|
||||
if let Some((tls_acceptor, tls_reload_certificate, tls_config, _)) = tls_context.as_mut() {
|
||||
// Reload TLS certificate if needed
|
||||
if tls_reload_certificate.swap(false, Ordering::Relaxed) {
|
||||
match tls::tls_acceptor(tls_config, Some(vec![b"http/1.1".to_vec()])) {
|
||||
Ok(acceptor) => *tls_acceptor = Arc::new(acceptor),
|
||||
Err(err) => error!("Cannot reload TLS certificate {:?}", err),
|
||||
};
|
||||
}
|
||||
|
||||
let tls_acceptor = tls_acceptor.clone();
|
||||
let fut = async move {
|
||||
info!("Doing TLS handshake");
|
||||
|
@ -436,3 +465,136 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
|
|||
};
|
||||
}
|
||||
}
|
||||
|
||||
struct TlsReloader {
|
||||
fs_watcher: Arc<Mutex<Option<RecommendedWatcher>>>,
|
||||
tls_reload_certificate: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl TlsReloader {
|
||||
pub fn new(server_config: Arc<WsServerConfig>) -> anyhow::Result<Self> {
|
||||
let this = Self {
|
||||
fs_watcher: Arc::new(Mutex::new(None)),
|
||||
tls_reload_certificate: Arc::new(AtomicBool::new(false)),
|
||||
};
|
||||
|
||||
// If there is no custom certificate and private key, there is nothing to watch
|
||||
let Some((Some(cert_path), Some(key_path))) = server_config
|
||||
.tls
|
||||
.as_ref()
|
||||
.map(|t| (&t.tls_certificate_path, &t.tls_key_path))
|
||||
else {
|
||||
return Ok(this);
|
||||
};
|
||||
|
||||
let span = span!(Level::INFO, "tls_reloader");
|
||||
let _enter = span.enter();
|
||||
info!("Starting to watch tls certificate and private key for changes to reload them");
|
||||
let tls_reload_certificate = this.tls_reload_certificate.clone();
|
||||
let watcher = this.fs_watcher.clone();
|
||||
let server_config = server_config.clone();
|
||||
let span = span.clone();
|
||||
|
||||
let mut watcher = notify::recommended_watcher(move |event: notify::Result<notify::Event>| {
|
||||
let _enter = span.enter();
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(err) => {
|
||||
error!("Error while watching tls certificate and private key for changes {:?}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if event.kind.is_access() {
|
||||
return;
|
||||
}
|
||||
|
||||
let tls = server_config.tls.as_ref().unwrap();
|
||||
let cert_path = tls.tls_certificate_path.as_ref().unwrap();
|
||||
let key_path = tls.tls_key_path.as_ref().unwrap();
|
||||
|
||||
if let Some(path) = event.paths.iter().find(|p| p.ends_with(cert_path)) {
|
||||
match event.kind {
|
||||
EventKind::Create(_) | EventKind::Modify(_) => match tls::load_certificates_from_pem(cert_path) {
|
||||
Ok(tls_certs) => {
|
||||
*tls.tls_certificate.lock() = tls_certs;
|
||||
tls_reload_certificate.store(true, Ordering::Relaxed);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Error while loading TLS certificate {:?}", err);
|
||||
}
|
||||
},
|
||||
EventKind::Remove(_) => {
|
||||
warn!("TLS certificate file has been removed, trying to re-set a watch for it");
|
||||
Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf());
|
||||
}
|
||||
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
||||
trace!("Ignoring event {:?}", event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(path) = event.paths.iter().find(|p| p.ends_with(key_path)) {
|
||||
match event.kind {
|
||||
EventKind::Create(_) | EventKind::Modify(_) => match tls::load_private_key_from_file(key_path) {
|
||||
Ok(tls_key) => {
|
||||
*tls.tls_key.lock() = tls_key;
|
||||
tls_reload_certificate.store(true, Ordering::Relaxed);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Error while loading TLS private key {:?}", err);
|
||||
}
|
||||
},
|
||||
EventKind::Remove(_) => {
|
||||
warn!("TLS private key file has been removed, trying to re-set a watch for it");
|
||||
Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf());
|
||||
}
|
||||
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
||||
trace!("Ignoring event {:?}", event);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.with_context(|| "Cannot create tls certificate watcher")?;
|
||||
|
||||
watcher.watch(cert_path, notify::RecursiveMode::NonRecursive)?;
|
||||
watcher.watch(key_path, notify::RecursiveMode::NonRecursive)?;
|
||||
*this.fs_watcher.lock() = Some(watcher);
|
||||
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
fn try_rewatch_certificate(watcher: Arc<Mutex<Option<RecommendedWatcher>>>, path: PathBuf) {
|
||||
thread::spawn(move || {
|
||||
let span = span!(Level::INFO, "tls_reloader");
|
||||
let _enter = span.enter();
|
||||
|
||||
thread::sleep(Duration::from_secs(10));
|
||||
while !path.exists() {
|
||||
warn!("TLS file {:?} does not exist anymore, waiting for it to be created", path);
|
||||
thread::sleep(Duration::from_secs(10));
|
||||
}
|
||||
let mut watcher = watcher.lock();
|
||||
let _ = watcher.as_mut().unwrap().unwatch(&path);
|
||||
let Ok(_) = watcher
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.watch(&path, notify::RecursiveMode::NonRecursive)
|
||||
.map_err(|err| {
|
||||
error!("Cannot re-set a watch for TLS file {:?}: {:?}", path, err);
|
||||
error!("TLS certificate will not be auto-reloaded anymore");
|
||||
})
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Ok(file) = File::open(&path) else {
|
||||
return;
|
||||
};
|
||||
let _ = file.set_modified(SystemTime::now()).map_err(|err| {
|
||||
error!("Cannot force reload TLS file {:?}: {:?}", path, err);
|
||||
error!("Old certificate will be used until the next change");
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue