diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 3c80a34..5aeb704 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -1,6 +1,7 @@ pub mod client; mod io; pub mod server; +mod tls_reloader; use crate::dns::DnsResolver; use crate::{tcp, tls, LocalProtocol, LocalToRemote, WsClientConfig}; diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index a741035..6edd476 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -1,21 +1,18 @@ use ahash::{HashMap, HashMapExt}; -use anyhow::{anyhow, Context}; +use anyhow::anyhow; use base64::Engine; use futures_util::{pin_mut, FutureExt, Stream, StreamExt}; use std::cmp::min; use std::fmt::Debug; -use std::fs::File; use std::future::Future; use std::ops::{Deref, Not}; -use std::path::PathBuf; use std::pin::Pin; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::Ordering; use std::sync::Arc; -use std::thread; -use std::time::{Duration, SystemTime}; +use std::time::Duration; use super::{JwtTunnelConfig, JWT_DECODE, JWT_HEADER_PREFIX}; -use crate::{socks5, tcp, tls, udp, LocalProtocol, WsServerConfig}; +use crate::{socks5, tcp, tls, udp, LocalProtocol, TlsServerConfig, WsServerConfig}; use hyper::body::Incoming; use hyper::header::{COOKIE, SEC_WEBSOCKET_PROTOCOL}; use hyper::http::HeaderValue; @@ -23,16 +20,16 @@ use hyper::server::conn::http1; use hyper::service::service_fn; use hyper::{http, Request, Response, StatusCode}; use jsonwebtoken::TokenData; -use log::trace; -use notify::{EventKind, RecommendedWatcher, Watcher}; use once_cell::sync::Lazy; use parking_lot::Mutex; +use crate::tunnel::tls_reloader::TlsReloader; use crate::udp::UdpStream; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream}; use tokio::select; use tokio::sync::{mpsc, oneshot}; +use tokio_rustls::TlsAcceptor; use tracing::{error, info, span, warn, Instrument, Level, Span}; use url::Host; @@ -372,6 +369,24 @@ async fn server_upgrade(server_config: Arc, mut req: Request { + tls_acceptor: Arc, + tls_reloader: TlsReloader, + tls_config: &'a TlsServerConfig, +} +impl TlsContext<'_> { + pub fn tls_acceptor(&mut self) -> &Arc { + if self.tls_reloader.tls_reload_certificate.swap(false, Ordering::Relaxed) { + match tls::tls_acceptor(self.tls_config, Some(vec![b"http/1.1".to_vec()])) { + Ok(acceptor) => self.tls_acceptor = Arc::new(acceptor), + Err(err) => error!("Cannot reload TLS certificate {:?}", err), + }; + } + + &self.tls_acceptor + } +} + pub async fn run_server(server_config: Arc) -> anyhow::Result<()> { info!("Starting wstunnel server listening on {}", server_config.bind); @@ -380,14 +395,13 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() let upgrade_fn = move |req: Request| server_upgrade(config.clone(), req).map::, _>(Ok); // Init TLS if needed - let mut tls_context = if let Some(tls) = &server_config.tls { - let tls_reloader = TlsReloader::new(server_config.clone())?; - Some(( - Arc::new(tls::tls_acceptor(tls, Some(vec![b"http/1.1".to_vec()]))?), - tls_reloader.tls_reload_certificate.clone(), - tls, - tls_reloader, - )) + let mut tls_context = if let Some(tls_config) = &server_config.tls { + let tls_context = TlsContext { + tls_acceptor: Arc::new(tls::tls_acceptor(tls_config, Some(vec![b"http/1.1".to_vec()]))?), + tls_reloader: TlsReloader::new(server_config.clone())?, + tls_config, + }; + Some(tls_context) } else { None }; @@ -416,16 +430,9 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() info!("Accepting connection"); let upgrade_fn = upgrade_fn.clone(); // TLS - if let Some((tls_acceptor, tls_reload_certificate, tls_config, _)) = tls_context.as_mut() { + if let Some(tls) = tls_context.as_mut() { // Reload TLS certificate if needed - if tls_reload_certificate.swap(false, Ordering::Relaxed) { - match tls::tls_acceptor(tls_config, Some(vec![b"http/1.1".to_vec()])) { - Ok(acceptor) => *tls_acceptor = Arc::new(acceptor), - Err(err) => error!("Cannot reload TLS certificate {:?}", err), - }; - } - - let tls_acceptor = tls_acceptor.clone(); + let tls_acceptor = tls.tls_acceptor().clone(); let fut = async move { info!("Doing TLS handshake"); let tls_stream = match tls_acceptor.accept(stream).await { @@ -465,136 +472,3 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() }; } } - -struct TlsReloader { - fs_watcher: Arc>>, - tls_reload_certificate: Arc, -} - -impl TlsReloader { - pub fn new(server_config: Arc) -> anyhow::Result { - let this = Self { - fs_watcher: Arc::new(Mutex::new(None)), - tls_reload_certificate: Arc::new(AtomicBool::new(false)), - }; - - // If there is no custom certificate and private key, there is nothing to watch - let Some((Some(cert_path), Some(key_path))) = server_config - .tls - .as_ref() - .map(|t| (&t.tls_certificate_path, &t.tls_key_path)) - else { - return Ok(this); - }; - - let span = span!(Level::INFO, "tls_reloader"); - let _enter = span.enter(); - info!("Starting to watch tls certificate and private key for changes to reload them"); - let tls_reload_certificate = this.tls_reload_certificate.clone(); - let watcher = this.fs_watcher.clone(); - let server_config = server_config.clone(); - let span = span.clone(); - - let mut watcher = notify::recommended_watcher(move |event: notify::Result| { - let _enter = span.enter(); - let event = match event { - Ok(event) => event, - Err(err) => { - error!("Error while watching tls certificate and private key for changes {:?}", err); - return; - } - }; - - if event.kind.is_access() { - return; - } - - let tls = server_config.tls.as_ref().unwrap(); - let cert_path = tls.tls_certificate_path.as_ref().unwrap(); - let key_path = tls.tls_key_path.as_ref().unwrap(); - - if let Some(path) = event.paths.iter().find(|p| p.ends_with(cert_path)) { - match event.kind { - EventKind::Create(_) | EventKind::Modify(_) => match tls::load_certificates_from_pem(cert_path) { - Ok(tls_certs) => { - *tls.tls_certificate.lock() = tls_certs; - tls_reload_certificate.store(true, Ordering::Relaxed); - } - Err(err) => { - warn!("Error while loading TLS certificate {:?}", err); - } - }, - EventKind::Remove(_) => { - warn!("TLS certificate file has been removed, trying to re-set a watch for it"); - Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf()); - } - EventKind::Access(_) | EventKind::Other | EventKind::Any => { - trace!("Ignoring event {:?}", event); - } - } - } - - if let Some(path) = event.paths.iter().find(|p| p.ends_with(key_path)) { - match event.kind { - EventKind::Create(_) | EventKind::Modify(_) => match tls::load_private_key_from_file(key_path) { - Ok(tls_key) => { - *tls.tls_key.lock() = tls_key; - tls_reload_certificate.store(true, Ordering::Relaxed); - } - Err(err) => { - warn!("Error while loading TLS private key {:?}", err); - } - }, - EventKind::Remove(_) => { - warn!("TLS private key file has been removed, trying to re-set a watch for it"); - Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf()); - } - EventKind::Access(_) | EventKind::Other | EventKind::Any => { - trace!("Ignoring event {:?}", event); - } - } - } - }) - .with_context(|| "Cannot create tls certificate watcher")?; - - watcher.watch(cert_path, notify::RecursiveMode::NonRecursive)?; - watcher.watch(key_path, notify::RecursiveMode::NonRecursive)?; - *this.fs_watcher.lock() = Some(watcher); - - Ok(this) - } - - fn try_rewatch_certificate(watcher: Arc>>, path: PathBuf) { - thread::spawn(move || { - let span = span!(Level::INFO, "tls_reloader"); - let _enter = span.enter(); - - thread::sleep(Duration::from_secs(10)); - while !path.exists() { - warn!("TLS file {:?} does not exist anymore, waiting for it to be created", path); - thread::sleep(Duration::from_secs(10)); - } - let mut watcher = watcher.lock(); - let _ = watcher.as_mut().unwrap().unwatch(&path); - let Ok(_) = watcher - .as_mut() - .unwrap() - .watch(&path, notify::RecursiveMode::NonRecursive) - .map_err(|err| { - error!("Cannot re-set a watch for TLS file {:?}: {:?}", path, err); - error!("TLS certificate will not be auto-reloaded anymore"); - }) - else { - return; - }; - - let Ok(file) = File::open(&path) else { - return; - }; - let _ = file.set_modified(SystemTime::now()).map_err(|err| { - error!("Cannot force reload TLS file {:?}: {:?}", path, err); - error!("Old certificate will be used until the next change"); - }); - }); - } -} diff --git a/src/tunnel/tls_reloader.rs b/src/tunnel/tls_reloader.rs new file mode 100644 index 0000000..e2e22b9 --- /dev/null +++ b/src/tunnel/tls_reloader.rs @@ -0,0 +1,137 @@ +use crate::{tls, WsServerConfig}; +use anyhow::Context; +use log::trace; +use notify::{EventKind, RecommendedWatcher, Watcher}; +use parking_lot::Mutex; +use std::fs::File; +use std::path::PathBuf; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::thread; +use std::time::{Duration, SystemTime}; +use tracing::{error, info, warn}; + +pub struct TlsReloader { + fs_watcher: Arc>>, + pub tls_reload_certificate: Arc, +} + +impl TlsReloader { + pub fn new(server_config: Arc) -> anyhow::Result { + let this = Self { + fs_watcher: Arc::new(Mutex::new(None)), + tls_reload_certificate: Arc::new(AtomicBool::new(false)), + }; + + // If there is no custom certificate and private key, there is nothing to watch + let Some((Some(cert_path), Some(key_path))) = server_config + .tls + .as_ref() + .map(|t| (&t.tls_certificate_path, &t.tls_key_path)) + else { + return Ok(this); + }; + + info!("Starting to watch tls certificate and private key for changes to reload them"); + let tls_reload_certificate = this.tls_reload_certificate.clone(); + let watcher = this.fs_watcher.clone(); + let server_config = server_config.clone(); + + let mut watcher = notify::recommended_watcher(move |event: notify::Result| { + let event = match event { + Ok(event) => event, + Err(err) => { + error!("Error while watching tls certificate and private key for changes {:?}", err); + return; + } + }; + + if event.kind.is_access() { + return; + } + + let tls = server_config.tls.as_ref().unwrap(); + let cert_path = tls.tls_certificate_path.as_ref().unwrap(); + let key_path = tls.tls_key_path.as_ref().unwrap(); + + if let Some(path) = event.paths.iter().find(|p| p.ends_with(cert_path)) { + match event.kind { + EventKind::Create(_) | EventKind::Modify(_) => match tls::load_certificates_from_pem(cert_path) { + Ok(tls_certs) => { + *tls.tls_certificate.lock() = tls_certs; + tls_reload_certificate.store(true, Ordering::Relaxed); + } + Err(err) => { + warn!("Error while loading TLS certificate {:?}", err); + } + }, + EventKind::Remove(_) => { + warn!("TLS certificate file has been removed, trying to re-set a watch for it"); + Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf()); + } + EventKind::Access(_) | EventKind::Other | EventKind::Any => { + trace!("Ignoring event {:?}", event); + } + } + } + + if let Some(path) = event.paths.iter().find(|p| p.ends_with(key_path)) { + match event.kind { + EventKind::Create(_) | EventKind::Modify(_) => match tls::load_private_key_from_file(key_path) { + Ok(tls_key) => { + *tls.tls_key.lock() = tls_key; + tls_reload_certificate.store(true, Ordering::Relaxed); + } + Err(err) => { + warn!("Error while loading TLS private key {:?}", err); + } + }, + EventKind::Remove(_) => { + warn!("TLS private key file has been removed, trying to re-set a watch for it"); + Self::try_rewatch_certificate(watcher.clone(), path.to_path_buf()); + } + EventKind::Access(_) | EventKind::Other | EventKind::Any => { + trace!("Ignoring event {:?}", event); + } + } + } + }) + .with_context(|| "Cannot create tls certificate watcher")?; + + watcher.watch(cert_path, notify::RecursiveMode::NonRecursive)?; + watcher.watch(key_path, notify::RecursiveMode::NonRecursive)?; + *this.fs_watcher.lock() = Some(watcher); + + Ok(this) + } + + fn try_rewatch_certificate(watcher: Arc>>, path: PathBuf) { + thread::spawn(move || { + while !path.exists() { + warn!("TLS file {:?} does not exist anymore, waiting for it to be created", path); + thread::sleep(Duration::from_secs(10)); + } + let mut watcher = watcher.lock(); + let _ = watcher.as_mut().unwrap().unwatch(&path); + let Ok(_) = watcher + .as_mut() + .unwrap() + .watch(&path, notify::RecursiveMode::NonRecursive) + .map_err(|err| { + error!("Cannot re-set a watch for TLS file {:?}: {:?}", path, err); + error!("TLS certificate will not be auto-reloaded anymore"); + }) + else { + return; + }; + + let Ok(file) = File::open(&path) else { + return; + }; + let _ = file.set_modified(SystemTime::now()).map_err(|err| { + error!("Cannot force reload TLS file {:?}: {:?}", path, err); + error!("Old certificate will be used until the next change"); + }); + }); + } +}