diff --git a/Cargo.lock b/Cargo.lock index 57c1027..96f3694 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -279,6 +279,25 @@ dependencies = [ "libc", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82a9b73a36529d9c47029b9fb3a6f0ea3cc916a261195352ba19e770fc1748b2" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3a430a770ebd84726f584a90ee7f020d28db52c6d02138900f22341f866d39c" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossterm" version = "0.27.0" @@ -426,6 +445,18 @@ dependencies = [ "utf-8", ] +[[package]] +name = "filetime" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ee447700ac8aa0b2f2bd7bc4462ad686ba06baa6727ac149a2d6277f0d240fd" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "windows-sys 0.52.0", +] + [[package]] name = "fnv" version = "1.0.7" @@ -441,6 +472,15 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fsevent-sys" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76ee7a02da4d231650c7cea31349b889be2f45ddb3ef3032d2ec8185f6313fd2" +dependencies = [ + "libc", +] + [[package]] name = "futures" version = "0.3.29" @@ -807,6 +847,26 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "inotify" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8069d3ec154eb856955c1c0fbffefbf5f3c40a104ec912d4797314c1801abff" +dependencies = [ + "bitflags 1.3.2", + "inotify-sys", + "libc", +] + +[[package]] +name = "inotify-sys" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" +dependencies = [ + "libc", +] + [[package]] name = "ipconfig" version = "0.3.2" @@ -853,6 +913,26 @@ dependencies = [ "serde_json", ] +[[package]] +name = "kqueue" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7447f1ca1b7b563588a205fe93dea8df60fd981423a768bc1c0ded35ed147d0c" +dependencies = [ + "kqueue-sys", + "libc", +] + +[[package]] +name = "kqueue-sys" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed9625ffda8729b85e45cf04090035ac368927b8cebc34898e7c120f52e4838b" +dependencies = [ + "bitflags 1.3.2", + "libc", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -959,6 +1039,25 @@ dependencies = [ "memoffset", ] +[[package]] +name = "notify" +version = "6.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d" +dependencies = [ + "bitflags 2.4.1", + "crossbeam-channel", + "filetime", + "fsevent-sys", + "inotify", + "kqueue", + "libc", + "log", + "mio", + "walkdir", + "windows-sys 0.48.0", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -1291,6 +1390,15 @@ version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.22" @@ -1874,6 +1982,16 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "walkdir" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -1965,6 +2083,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +dependencies = [ + "winapi", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" @@ -2135,6 +2262,7 @@ dependencies = [ "jsonwebtoken", "log", "nix", + "notify", "once_cell", "parking_lot", "pin-project", diff --git a/Cargo.toml b/Cargo.toml index 8f1e3ca..1e10e3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ nix = { version = "0.27.1", features = ["socket", "net", "uio"] } once_cell = { version = "1.19.0", features = [] } parking_lot = "0.12.1" pin-project = "1" +notify = { version = "6.1.1", features = [] } rustls-native-certs = { version = "0.7.0", features = [] } rustls-pemfile = { version = "2.0.0", features = [] } diff --git a/src/main.rs b/src/main.rs index f3e8d38..813ad8f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,7 @@ use futures_util::{stream, TryStreamExt}; use hickory_resolver::config::{NameServerConfig, ResolverConfig, ResolverOpts}; use hyper::header::HOST; use hyper::http::{HeaderName, HeaderValue}; +use parking_lot::Mutex; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; use std::fmt::{Debug, Formatter}; @@ -217,10 +218,12 @@ struct Server { restrict_http_upgrade_path_prefix: Option>, /// [Optional] Use custom certificate (.crt) instead of the default embedded self signed certificate. + /// The certificate will be automatically reloaded if it changes #[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)] tls_certificate: Option, /// [Optional] Use a custom tls key (.key) that the server will use instead of the default embedded one + /// The private key will be automatically reloaded if it changes #[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)] tls_private_key: Option, } @@ -481,13 +484,14 @@ pub struct TlsClientConfig { pub tls_verify_certificate: bool, } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct TlsServerConfig { - pub tls_certificate: Vec, - pub tls_key: PrivateKey, + pub tls_certificate: Mutex>, + pub tls_key: Mutex, + pub tls_certificate_path: Option, + pub tls_key_path: Option, } -#[derive(Clone)] pub struct WsServerConfig { pub socket_so_mark: Option, pub bind: SocketAddr, @@ -814,20 +818,23 @@ async fn main() { } Commands::Server(args) => { let tls_config = if args.remote_addr.scheme() == "wss" { - let tls_certificate = if let Some(cert_path) = args.tls_certificate { - tls::load_certificates_from_pem(&cert_path).expect("Cannot load tls certificate") + let tls_certificate = if let Some(cert_path) = &args.tls_certificate { + tls::load_certificates_from_pem(cert_path).expect("Cannot load tls certificate") } else { embedded_certificate::TLS_CERTIFICATE.clone() }; - let tls_key = if let Some(key_path) = args.tls_private_key { - tls::load_private_key_from_file(&key_path).expect("Cannot load tls private key") + let tls_key = if let Some(key_path) = &args.tls_private_key { + tls::load_private_key_from_file(key_path).expect("Cannot load tls private key") } else { embedded_certificate::TLS_PRIVATE_KEY.clone() }; + Some(TlsServerConfig { - tls_certificate, - tls_key, + tls_certificate: Mutex::new(tls_certificate), + tls_key: Mutex::new(tls_key), + tls_certificate_path: args.tls_certificate, + tls_key_path: args.tls_private_key, }) } else { None diff --git a/src/tls.rs b/src/tls.rs index c343c36..b0ad955 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -92,7 +92,7 @@ pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option, mut req: Request) -> anyhow::Result<()> { info!("Starting wstunnel server listening on {}", server_config.bind); + // setup upgrade request handler let config = server_config.clone(); let upgrade_fn = move |req: Request| server_upgrade(config.clone(), req).map::, _>(Ok); - let listener = TcpListener::bind(&server_config.bind).await?; - let tls_acceptor = if let Some(tls) = &server_config.tls { - Some(tls::tls_acceptor(tls, Some(vec![b"http/1.1".to_vec()]))?) + // Init TLS if needed + let mut tls_context = if let Some(tls) = &server_config.tls { + let tls_reloader = TlsReloader::new(server_config.clone())?; + Some(( + Arc::new(tls::tls_acceptor(tls, Some(vec![b"http/1.1".to_vec()]))?), + tls_reloader.tls_reload_certificate.clone(), + tls, + tls_reloader, + )) } else { None }; + // Bind server and run forever to serve incoming connections. + let listener = TcpListener::bind(&server_config.bind).await?; loop { - let (stream, peer_addr) = listener.accept().await?; + let (stream, peer_addr) = match listener.accept().await { + Ok(ret) => ret, + Err(err) => { + warn!("Error while accepting connection {:?}", err); + continue; + } + }; let _ = stream.set_nodelay(true); let span = span!( @@ -395,7 +416,15 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() info!("Accepting connection"); let upgrade_fn = upgrade_fn.clone(); // TLS - if let Some(tls_acceptor) = &tls_acceptor { + if let Some((tls_acceptor, tls_reload_certificate, tls_config, _)) = tls_context.as_mut() { + // Reload TLS certificate if needed + if tls_reload_certificate.swap(false, Ordering::Relaxed) { + match tls::tls_acceptor(tls_config, Some(vec![b"http/1.1".to_vec()])) { + Ok(acceptor) => *tls_acceptor = Arc::new(acceptor), + Err(err) => error!("Cannot reload TLS certificate {:?}", err), + }; + } + let tls_acceptor = tls_acceptor.clone(); let fut = async move { info!("Doing TLS handshake"); @@ -436,3 +465,136 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() }; } } + +struct TlsReloader { + fs_watcher: Arc>>, + tls_reload_certificate: Arc, +} + +impl TlsReloader { + pub fn new(server_config: Arc) -> anyhow::Result { + let this = Self { + fs_watcher: Arc::new(Mutex::new(None)), + tls_reload_certificate: Arc::new(AtomicBool::new(false)), + }; + + // If there is no custom certificate and private key, there is nothing to watch + let Some((Some(cert_path), Some(key_path))) = server_config + .tls + .as_ref() + .map(|t| (&t.tls_certificate_path, &t.tls_key_path)) + else { + return Ok(this); + }; + + let span = span!(Level::INFO, "tls_reloader"); + let _enter = span.enter(); + info!("Starting to watch tls certificate and private key for changes to reload them"); + let tls_reload_certificate = this.tls_reload_certificate.clone(); + let watcher = this.fs_watcher.clone(); + let server_config = server_config.clone(); + let span = span.clone(); + + let mut watcher = notify::recommended_watcher(move |event: notify::Result| { + let _enter = span.enter(); + let event = match event { + Ok(event) => event, + Err(err) => { + error!("Error while watching tls certificate and private key for changes {:?}", err); + return; + } + }; + + if event.kind.is_access() { + return; + } + + let tls = server_config.tls.as_ref().unwrap(); + let cert_path = tls.tls_certificate_path.as_ref().unwrap(); + let key_path = tls.tls_key_path.as_ref().unwrap(); + + if let Some(path) = event.paths.iter().find(|p| p.ends_with(cert_path)) { + match event.kind { + EventKind::Create(_) | EventKind::Modify(_) => match tls::load_certificates_from_pem(cert_path) { + Ok(tls_certs) => { + *tls.tls_certificate.lock() = tls_certs; + tls_reload_certificate.store(true, Ordering::Relaxed); + } + Err(err) => { + warn!("Error while loading TLS certificate {:?}", err); + } + }, + EventKind::Remove(_) => { + warn!("TLS certificate file has been removed, trying to re-set a watch for it"); + Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf()); + } + EventKind::Access(_) | EventKind::Other | EventKind::Any => { + trace!("Ignoring event {:?}", event); + } + } + } + + if let Some(path) = event.paths.iter().find(|p| p.ends_with(key_path)) { + match event.kind { + EventKind::Create(_) | EventKind::Modify(_) => match tls::load_private_key_from_file(key_path) { + Ok(tls_key) => { + *tls.tls_key.lock() = tls_key; + tls_reload_certificate.store(true, Ordering::Relaxed); + } + Err(err) => { + warn!("Error while loading TLS private key {:?}", err); + } + }, + EventKind::Remove(_) => { + warn!("TLS private key file has been removed, trying to re-set a watch for it"); + Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf()); + } + EventKind::Access(_) | EventKind::Other | EventKind::Any => { + trace!("Ignoring event {:?}", event); + } + } + } + }) + .with_context(|| "Cannot create tls certificate watcher")?; + + watcher.watch(cert_path, notify::RecursiveMode::NonRecursive)?; + watcher.watch(key_path, notify::RecursiveMode::NonRecursive)?; + *this.fs_watcher.lock() = Some(watcher); + + Ok(this) + } + + fn try_rewatch_certificate(watcher: Arc>>, path: PathBuf) { + thread::spawn(move || { + let span = span!(Level::INFO, "tls_reloader"); + let _enter = span.enter(); + + thread::sleep(Duration::from_secs(10)); + while !path.exists() { + warn!("TLS file {:?} does not exist anymore, waiting for it to be created", path); + thread::sleep(Duration::from_secs(10)); + } + let mut watcher = watcher.lock(); + let _ = watcher.as_mut().unwrap().unwatch(&path); + let Ok(_) = watcher + .as_mut() + .unwrap() + .watch(&path, notify::RecursiveMode::NonRecursive) + .map_err(|err| { + error!("Cannot re-set a watch for TLS file {:?}: {:?}", path, err); + error!("TLS certificate will not be auto-reloaded anymore"); + }) + else { + return; + }; + + let Ok(file) = File::open(&path) else { + return; + }; + let _ = file.set_modified(SystemTime::now()).map_err(|err| { + error!("Cannot force reload TLS file {:?}: {:?}", path, err); + error!("Old certificate will be used until the next change"); + }); + }); + } +}