Add support for mTLS
This commit is contained in:
parent
4524397d4f
commit
70b5a216b0
20 changed files with 1051 additions and 57 deletions
120
src/main.rs
120
src/main.rs
|
@ -19,7 +19,7 @@ use hyper::header::HOST;
|
|||
use hyper::http::{HeaderName, HeaderValue};
|
||||
use log::{debug, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::Mutex;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::fmt::{Debug, Formatter};
|
||||
|
@ -40,6 +40,7 @@ use tokio_rustls::TlsConnector;
|
|||
use tracing::{error, info};
|
||||
|
||||
use crate::dns::DnsResolver;
|
||||
use crate::tunnel::tls_reloader::TlsReloader;
|
||||
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme};
|
||||
use crate::udp::MyUdpSocket;
|
||||
use tracing_subscriber::filter::Directive;
|
||||
|
@ -60,7 +61,7 @@ struct Wstunnel {
|
|||
|
||||
/// *WARNING* The flag does nothing, you need to set the env variable *WARNING*
|
||||
/// Control the number of threads that will be used.
|
||||
/// By default it is equal the number of cpus
|
||||
/// By default, it is equal the number of cpus
|
||||
#[arg(
|
||||
long,
|
||||
global = true,
|
||||
|
@ -132,7 +133,7 @@ struct Client {
|
|||
#[arg(short = 'c', long, value_name = "INT", default_value = "0", verbatim_doc_comment)]
|
||||
connection_min_idle: u32,
|
||||
|
||||
/// Domain name that will be use as SNI during TLS handshake
|
||||
/// Domain name that will be used as SNI during TLS handshake
|
||||
/// Warning: If you are behind a CDN (i.e: Cloudflare) you must set this domain also in the http HOST header.
|
||||
/// or it will be flagged as fishy and your request rejected
|
||||
#[arg(long, value_name = "DOMAIN_NAME", value_parser = parse_sni_override, verbatim_doc_comment)]
|
||||
|
@ -144,7 +145,7 @@ struct Client {
|
|||
tls_sni_disable: bool,
|
||||
|
||||
/// Enable TLS certificate verification.
|
||||
/// Disabled by default. The client will happily connect to any server with self signed certificate.
|
||||
/// Disabled by default. The client will happily connect to any server with self-signed certificate.
|
||||
#[arg(long, verbatim_doc_comment)]
|
||||
tls_verify_certificate: bool,
|
||||
|
||||
|
@ -192,7 +193,7 @@ struct Client {
|
|||
websocket_ping_frequency_sec: Option<Duration>,
|
||||
|
||||
/// Enable the masking of websocket frames. Default is false
|
||||
/// Enable this option only if you use unsecure (non TLS) websocket server and you see some issues. Otherwise, it is just overhead.
|
||||
/// Enable this option only if you use unsecure (non TLS) websocket server, and you see some issues. Otherwise, it is just overhead.
|
||||
#[arg(long, default_value = "false", verbatim_doc_comment)]
|
||||
websocket_mask_frame: bool,
|
||||
|
||||
|
@ -203,7 +204,7 @@ struct Client {
|
|||
|
||||
/// Send custom headers in the upgrade request reading them from a file.
|
||||
/// It overrides http_headers specified from command line.
|
||||
/// File is read everytime and file format must contains lines with `HEADER_NAME: HEADER_VALUE`
|
||||
/// File is read everytime and file format must contain lines with `HEADER_NAME: HEADER_VALUE`
|
||||
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
|
||||
http_headers_file: Option<PathBuf>,
|
||||
|
||||
|
@ -220,6 +221,17 @@ struct Client {
|
|||
/// The only way to make it works with http2 is to have wstunnel directly exposed to the internet without any reverse proxy in front of it
|
||||
#[arg(value_name = "ws[s]|http[s]://wstunnel.server.com[:port]", value_parser = parse_server_url, verbatim_doc_comment)]
|
||||
remote_addr: Url,
|
||||
|
||||
/// [Optional] Certificate (pem) to present to the server when connecting over TLS (HTTPS).
|
||||
/// Used when the server requires clients to authenticate themselves with a certificate (i.e. mTLS).
|
||||
/// The certificate will be automatically reloaded if it changes
|
||||
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
|
||||
tls_certificate: Option<PathBuf>,
|
||||
|
||||
/// [Optional] The private key for the corresponding certificate used with mTLS.
|
||||
/// The certificate will be automatically reloaded if it changes
|
||||
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
|
||||
tls_private_key: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(clap::Args, Debug)]
|
||||
|
@ -241,7 +253,7 @@ struct Server {
|
|||
websocket_ping_frequency_sec: Option<Duration>,
|
||||
|
||||
/// Enable the masking of websocket frames. Default is false
|
||||
/// Enable this option only if you use unsecure (non TLS) websocket server and you see some issues. Otherwise, it is just overhead.
|
||||
/// Enable this option only if you use unsecure (non TLS) websocket server, and you see some issues. Otherwise, it is just overhead.
|
||||
#[arg(long, default_value = "false", verbatim_doc_comment)]
|
||||
websocket_mask_frame: bool,
|
||||
|
||||
|
@ -264,7 +276,7 @@ struct Server {
|
|||
dns_resolver: Option<Vec<Url>>,
|
||||
|
||||
/// Server will only accept connection from if this specific path prefix is used during websocket upgrade.
|
||||
/// Useful if you specify in the client a custom path prefix and you want the server to only allow this one.
|
||||
/// Useful if you specify in the client a custom path prefix, and you want the server to only allow this one.
|
||||
/// The path prefix act as a secret to authenticate clients
|
||||
/// Disabled by default. Accept all path prefix. Can be specified multiple time
|
||||
#[arg(
|
||||
|
@ -275,7 +287,7 @@ struct Server {
|
|||
)]
|
||||
restrict_http_upgrade_path_prefix: Option<Vec<String>>,
|
||||
|
||||
/// [Optional] Use custom certificate (pem) instead of the default embedded self signed certificate.
|
||||
/// [Optional] Use custom certificate (pem) instead of the default embedded self-signed certificate.
|
||||
/// The certificate will be automatically reloaded if it changes
|
||||
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
|
||||
tls_certificate: Option<PathBuf>,
|
||||
|
@ -284,6 +296,12 @@ struct Server {
|
|||
/// The private key will be automatically reloaded if it changes
|
||||
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
|
||||
tls_private_key: Option<PathBuf>,
|
||||
|
||||
/// [Optional] Enables mTLS (client authentication with certificate). Argument must be PEM file
|
||||
/// containing one or more certificates of CA's of which the certificate of clients needs to be signed with.
|
||||
/// The ca will be automatically reloaded if it changes
|
||||
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
|
||||
tls_client_ca_certs: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
|
@ -565,15 +583,25 @@ pub struct TlsClientConfig {
|
|||
pub tls_sni_disabled: bool,
|
||||
pub tls_sni_override: Option<DnsName>,
|
||||
pub tls_verify_certificate: bool,
|
||||
pub tls_connector: TlsConnector,
|
||||
tls_connector: Arc<RwLock<TlsConnector>>,
|
||||
pub tls_certificate_path: Option<PathBuf>,
|
||||
pub tls_key_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl TlsClientConfig {
|
||||
pub fn tls_connector(&self) -> TlsConnector {
|
||||
self.tls_connector.read().clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TlsServerConfig {
|
||||
pub tls_certificate: Mutex<Vec<Certificate>>,
|
||||
pub tls_key: Mutex<PrivateKey>,
|
||||
pub tls_client_ca_certificates: Option<Mutex<Vec<Certificate>>>,
|
||||
pub tls_certificate_path: Option<PathBuf>,
|
||||
pub tls_key_path: Option<PathBuf>,
|
||||
pub tls_client_ca_certs_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
pub struct WsServerConfig {
|
||||
|
@ -599,6 +627,14 @@ impl Debug for WsServerConfig {
|
|||
.field("timeout_connect", &self.timeout_connect)
|
||||
.field("websocket_mask_frame", &self.websocket_mask_frame)
|
||||
.field("tls", &self.tls.is_some())
|
||||
.field(
|
||||
"mTLS",
|
||||
&self
|
||||
.tls
|
||||
.as_ref()
|
||||
.map(|x| x.tls_client_ca_certificates.is_some())
|
||||
.unwrap_or(false),
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
@ -617,6 +653,7 @@ pub struct WsClientConfig {
|
|||
pub websocket_mask_frame: bool,
|
||||
pub http_proxy: Option<Url>,
|
||||
cnx_pool: Option<bb8::Pool<WsClientConfig>>,
|
||||
tls_reloader: Option<Arc<TlsReloader>>,
|
||||
pub dns_resolver: DnsResolver,
|
||||
}
|
||||
|
||||
|
@ -682,30 +719,54 @@ async fn main() {
|
|||
|
||||
match args.commands {
|
||||
Commands::Client(args) => {
|
||||
let tls = match TransportScheme::from_str(args.remote_addr.scheme()).expect("invalid scheme in server url")
|
||||
let (tls_certificate, tls_key) = if let (Some(cert), Some(key)) =
|
||||
(args.tls_certificate.as_ref(), args.tls_private_key.as_ref())
|
||||
{
|
||||
let tls_certificate =
|
||||
tls::load_certificates_from_pem(cert).expect("Cannot load client TLS certificate (mTLS)");
|
||||
let tls_key = tls::load_private_key_from_file(key).expect("Cannot load client TLS private key (mTLS)");
|
||||
(Some(tls_certificate), Some(tls_key))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
let transport_scheme =
|
||||
TransportScheme::from_str(args.remote_addr.scheme()).expect("invalid scheme in server url");
|
||||
let tls = match transport_scheme {
|
||||
TransportScheme::Ws | TransportScheme::Http => None,
|
||||
TransportScheme::Wss => Some(TlsClientConfig {
|
||||
tls_connector: tls::tls_connector(
|
||||
args.tls_verify_certificate,
|
||||
Some(vec![b"http/1.1".to_vec()]),
|
||||
!args.tls_sni_disable,
|
||||
)
|
||||
.expect("Cannot create tls connector"),
|
||||
tls_connector: Arc::new(RwLock::new(
|
||||
tls::tls_connector(
|
||||
args.tls_verify_certificate,
|
||||
transport_scheme.alpn_protocols(),
|
||||
!args.tls_sni_disable,
|
||||
tls_certificate,
|
||||
tls_key,
|
||||
)
|
||||
.expect("Cannot create tls connector"),
|
||||
)),
|
||||
tls_sni_override: args.tls_sni_override,
|
||||
tls_verify_certificate: args.tls_verify_certificate,
|
||||
tls_sni_disabled: args.tls_sni_disable,
|
||||
tls_certificate_path: args.tls_certificate.clone(),
|
||||
tls_key_path: args.tls_private_key.clone(),
|
||||
}),
|
||||
TransportScheme::Https => Some(TlsClientConfig {
|
||||
tls_connector: tls::tls_connector(
|
||||
args.tls_verify_certificate,
|
||||
Some(vec![b"h2".to_vec()]),
|
||||
!args.tls_sni_disable,
|
||||
)
|
||||
.expect("Cannot create tls connector"),
|
||||
tls_connector: Arc::new(RwLock::new(
|
||||
tls::tls_connector(
|
||||
args.tls_verify_certificate,
|
||||
transport_scheme.alpn_protocols(),
|
||||
!args.tls_sni_disable,
|
||||
tls_certificate,
|
||||
tls_key,
|
||||
)
|
||||
.expect("Cannot create tls connector"),
|
||||
)),
|
||||
tls_sni_override: args.tls_sni_override,
|
||||
tls_verify_certificate: args.tls_verify_certificate,
|
||||
tls_sni_disabled: args.tls_sni_disable,
|
||||
tls_certificate_path: args.tls_certificate.clone(),
|
||||
tls_key_path: args.tls_private_key.clone(),
|
||||
}),
|
||||
};
|
||||
|
||||
|
@ -761,6 +822,7 @@ async fn main() {
|
|||
None
|
||||
},
|
||||
cnx_pool: None,
|
||||
tls_reloader: None,
|
||||
dns_resolver: if let Ok(resolver) = hickory_resolver::AsyncResolver::tokio_from_system_conf() {
|
||||
DnsResolver::TrustDns(resolver)
|
||||
} else {
|
||||
|
@ -769,6 +831,9 @@ async fn main() {
|
|||
},
|
||||
};
|
||||
|
||||
let tls_reloader =
|
||||
TlsReloader::new_for_client(Arc::new(client_config.clone())).expect("Cannot create tls reloader");
|
||||
client_config.tls_reloader = Some(Arc::new(tls_reloader));
|
||||
let pool = bb8::Pool::builder()
|
||||
.max_size(1000)
|
||||
.min_idle(Some(args.connection_min_idle))
|
||||
|
@ -1120,11 +1185,20 @@ async fn main() {
|
|||
embedded_certificate::TLS_PRIVATE_KEY.clone()
|
||||
};
|
||||
|
||||
let tls_client_ca_certificates = args.tls_client_ca_certs.as_ref().map(|tls_client_ca| {
|
||||
Mutex::new(
|
||||
tls::load_certificates_from_pem(tls_client_ca)
|
||||
.expect("Cannot load client CA certificate (mTLS)"),
|
||||
)
|
||||
});
|
||||
|
||||
Some(TlsServerConfig {
|
||||
tls_certificate: Mutex::new(tls_certificate),
|
||||
tls_key: Mutex::new(tls_key),
|
||||
tls_client_ca_certificates,
|
||||
tls_certificate_path: args.tls_certificate,
|
||||
tls_key_path: args.tls_private_key,
|
||||
tls_client_ca_certs_path: args.tls_client_ca_certs,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
|
|
40
src/tls.rs
40
src/tls.rs
|
@ -12,6 +12,7 @@ use tokio_rustls::client::TlsStream;
|
|||
use tokio_rustls::rustls::client::{ServerCertVerified, ServerCertVerifier};
|
||||
|
||||
use crate::tunnel::TransportAddr;
|
||||
use tokio_rustls::rustls::server::{AllowAnyAuthenticatedClient, NoClientAuth};
|
||||
use tokio_rustls::rustls::{Certificate, ClientConfig, KeyLogFile, PrivateKey, ServerName};
|
||||
use tokio_rustls::{rustls, TlsAcceptor, TlsConnector};
|
||||
use tracing::info;
|
||||
|
@ -65,8 +66,10 @@ pub fn load_private_key_from_file(path: &Path) -> anyhow::Result<PrivateKey> {
|
|||
|
||||
pub fn tls_connector(
|
||||
tls_verify_certificate: bool,
|
||||
alpn_protocols: Option<Vec<Vec<u8>>>,
|
||||
alpn_protocols: Vec<Vec<u8>>,
|
||||
enable_sni: bool,
|
||||
tls_client_certificate: Option<Vec<Certificate>>,
|
||||
tls_client_key: Option<PrivateKey>,
|
||||
) -> anyhow::Result<TlsConnector> {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
|
||||
|
@ -79,10 +82,16 @@ pub fn tls_connector(
|
|||
}
|
||||
}
|
||||
|
||||
let mut config = ClientConfig::builder()
|
||||
let config_builder = ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
.with_root_certificates(root_store);
|
||||
|
||||
let mut config = match (tls_client_certificate, tls_client_key) {
|
||||
(Some(tls_client_certificate), Some(tls_client_key)) => config_builder
|
||||
.with_client_auth_cert(tls_client_certificate, tls_client_key)
|
||||
.with_context(|| "Error setting up mTLS")?,
|
||||
_ => config_builder.with_no_client_auth(),
|
||||
};
|
||||
|
||||
config.enable_sni = enable_sni;
|
||||
config.key_log = Arc::new(KeyLogFile::new());
|
||||
|
@ -92,17 +101,28 @@ pub fn tls_connector(
|
|||
config.dangerous().set_certificate_verifier(Arc::new(NullVerifier));
|
||||
}
|
||||
|
||||
if let Some(alpn_protocols) = alpn_protocols {
|
||||
config.alpn_protocols = alpn_protocols;
|
||||
}
|
||||
config.alpn_protocols = alpn_protocols;
|
||||
let tls_connector = TlsConnector::from(Arc::new(config));
|
||||
Ok(tls_connector)
|
||||
}
|
||||
|
||||
pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsAcceptor> {
|
||||
let client_cert_verifier = if let Some(tls_client_ca_certificates) = &tls_cfg.tls_client_ca_certificates {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
for tls_client_ca_certificate in tls_client_ca_certificates.lock().iter() {
|
||||
root_store
|
||||
.add(tls_client_ca_certificate)
|
||||
.with_context(|| "Failed to add mTLS client CA certificate")?;
|
||||
}
|
||||
|
||||
Arc::new(AllowAnyAuthenticatedClient::new(root_store))
|
||||
} else {
|
||||
NoClientAuth::boxed()
|
||||
};
|
||||
|
||||
let mut config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_client_cert_verifier(client_cert_verifier)
|
||||
.with_single_cert(tls_cfg.tls_certificate.lock().clone(), tls_cfg.tls_key.lock().clone())
|
||||
.with_context(|| "invalid tls certificate or private key")?;
|
||||
|
||||
|
@ -116,8 +136,8 @@ pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8
|
|||
pub async fn connect(client_cfg: &WsClientConfig, tcp_stream: TcpStream) -> anyhow::Result<TlsStream<TcpStream>> {
|
||||
let sni = client_cfg.tls_server_name();
|
||||
let (tls_connector, sni_disabled) = match &client_cfg.remote_addr {
|
||||
TransportAddr::Wss { tls, .. } => (&tls.tls_connector, tls.tls_sni_disabled),
|
||||
TransportAddr::Https { tls, .. } => (&tls.tls_connector, tls.tls_sni_disabled),
|
||||
TransportAddr::Wss { tls, .. } => (tls.tls_connector(), tls.tls_sni_disabled),
|
||||
TransportAddr::Https { tls, .. } => (tls.tls_connector(), tls.tls_sni_disabled),
|
||||
TransportAddr::Http { .. } | TransportAddr::Ws { .. } => {
|
||||
return Err(anyhow!("Transport does not support TLS: {}", client_cfg.remote_addr.scheme()))
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
pub mod client;
|
||||
pub mod server;
|
||||
mod tls_reloader;
|
||||
pub mod tls_reloader;
|
||||
mod transport;
|
||||
|
||||
use crate::{tcp, tls, LocalProtocol, TlsClientConfig, WsClientConfig};
|
||||
|
@ -104,6 +104,15 @@ impl TransportScheme {
|
|||
TransportScheme::Https => "https",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn alpn_protocols(&self) -> Vec<Vec<u8>> {
|
||||
match self {
|
||||
TransportScheme::Ws => vec![],
|
||||
TransportScheme::Wss => vec![b"http/1.1".to_vec()],
|
||||
TransportScheme::Http => vec![],
|
||||
TransportScheme::Https => vec![b"h2".to_vec()],
|
||||
}
|
||||
}
|
||||
}
|
||||
impl FromStr for TransportScheme {
|
||||
type Err = ();
|
||||
|
|
|
@ -597,7 +597,7 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
|
|||
let mut tls_context = if let Some(tls_config) = &server_config.tls {
|
||||
let tls_context = TlsContext {
|
||||
tls_acceptor: Arc::new(tls::tls_acceptor(tls_config, Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]))?),
|
||||
tls_reloader: TlsReloader::new(server_config.clone())?,
|
||||
tls_reloader: TlsReloader::new_for_server(server_config.clone())?,
|
||||
tls_config,
|
||||
};
|
||||
Some(tls_context)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate::{tls, WsServerConfig};
|
||||
use crate::tunnel::tls_reloader::TlsReloaderState::{Client, Server};
|
||||
use crate::{tls, WsClientConfig, WsServerConfig};
|
||||
use anyhow::Context;
|
||||
use log::trace;
|
||||
use notify::{EventKind, RecommendedWatcher, Watcher};
|
||||
|
@ -10,41 +11,108 @@ use std::thread;
|
|||
use std::time::Duration;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
struct TlsReloaderState {
|
||||
struct TlsReloaderServerState {
|
||||
fs_watcher: Mutex<RecommendedWatcher>,
|
||||
tls_reload_certificate: AtomicBool,
|
||||
server_config: Arc<WsServerConfig>,
|
||||
cert_path: PathBuf,
|
||||
key_path: PathBuf,
|
||||
client_ca_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
struct TlsReloaderClientState {
|
||||
fs_watcher: Mutex<RecommendedWatcher>,
|
||||
tls_reload_certificate: AtomicBool,
|
||||
client_config: Arc<WsClientConfig>,
|
||||
cert_path: PathBuf,
|
||||
key_path: PathBuf,
|
||||
}
|
||||
|
||||
enum TlsReloaderState {
|
||||
Empty,
|
||||
Server(Arc<TlsReloaderServerState>),
|
||||
Client(Arc<TlsReloaderClientState>),
|
||||
}
|
||||
|
||||
impl TlsReloaderState {
|
||||
fn fs_watcher(&self) -> &Mutex<RecommendedWatcher> {
|
||||
match self {
|
||||
TlsReloaderState::Empty => unreachable!(),
|
||||
Server(this) => &this.fs_watcher,
|
||||
Client(this) => &this.fs_watcher,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TlsReloader {
|
||||
state: Option<Arc<TlsReloaderState>>,
|
||||
state: TlsReloaderState,
|
||||
}
|
||||
|
||||
impl TlsReloader {
|
||||
pub fn new(server_config: Arc<WsServerConfig>) -> anyhow::Result<Self> {
|
||||
pub fn new_for_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<Self> {
|
||||
// If there is no custom certificate and private key, there is nothing to watch
|
||||
let Some((Some(cert_path), Some(key_path))) = server_config
|
||||
let Some((Some(cert_path), Some(key_path), client_ca_certs)) = server_config
|
||||
.tls
|
||||
.as_ref()
|
||||
.map(|t| (&t.tls_certificate_path, &t.tls_key_path))
|
||||
.map(|t| (&t.tls_certificate_path, &t.tls_key_path, &t.tls_client_ca_certs_path))
|
||||
else {
|
||||
return Ok(Self { state: None });
|
||||
return Ok(Self {
|
||||
state: TlsReloaderState::Empty,
|
||||
});
|
||||
};
|
||||
|
||||
let this = Arc::new(TlsReloaderState {
|
||||
let this = Arc::new(TlsReloaderServerState {
|
||||
fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?),
|
||||
tls_reload_certificate: AtomicBool::new(false),
|
||||
cert_path: cert_path.to_path_buf(),
|
||||
key_path: key_path.to_path_buf(),
|
||||
client_ca_path: client_ca_certs.as_ref().map(|x| x.to_path_buf()),
|
||||
server_config,
|
||||
});
|
||||
|
||||
info!("Starting to watch tls certificate and private key for changes to reload them");
|
||||
info!("Starting to watch tls certificates and private key for changes to reload them");
|
||||
let mut watcher = notify::recommended_watcher({
|
||||
let this = this.clone();
|
||||
let this = Server(this.clone());
|
||||
|
||||
move |event: notify::Result<notify::Event>| Self::handle_fs_event(&this, event)
|
||||
move |event: notify::Result<notify::Event>| Self::handle_server_fs_event(&this, event)
|
||||
})
|
||||
.with_context(|| "Cannot create tls certificate watcher")?;
|
||||
|
||||
watcher.watch(&this.cert_path, notify::RecursiveMode::NonRecursive)?;
|
||||
watcher.watch(&this.key_path, notify::RecursiveMode::NonRecursive)?;
|
||||
if let Some(client_ca_path) = &this.client_ca_path {
|
||||
watcher.watch(client_ca_path, notify::RecursiveMode::NonRecursive)?;
|
||||
}
|
||||
*this.fs_watcher.lock() = watcher;
|
||||
|
||||
Ok(Self { state: Server(this) })
|
||||
}
|
||||
|
||||
pub fn new_for_client(client_config: Arc<WsClientConfig>) -> anyhow::Result<Self> {
|
||||
// If there is no custom certificate and private key, there is nothing to watch
|
||||
let Some((Some(cert_path), Some(key_path))) = client_config
|
||||
.remote_addr
|
||||
.tls()
|
||||
.map(|t| (&t.tls_certificate_path, &t.tls_key_path))
|
||||
else {
|
||||
return Ok(Self {
|
||||
state: TlsReloaderState::Empty,
|
||||
});
|
||||
};
|
||||
|
||||
let this = Arc::new(TlsReloaderClientState {
|
||||
fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?),
|
||||
tls_reload_certificate: AtomicBool::new(false),
|
||||
cert_path: cert_path.to_path_buf(),
|
||||
key_path: key_path.to_path_buf(),
|
||||
client_config,
|
||||
});
|
||||
|
||||
info!("Starting to watch tls certificates and private key for changes to reload them");
|
||||
let mut watcher = notify::recommended_watcher({
|
||||
let this = Client(this.clone());
|
||||
|
||||
move |event: notify::Result<notify::Event>| Self::handle_client_fs_event(&this, event)
|
||||
})
|
||||
.with_context(|| "Cannot create tls certificate watcher")?;
|
||||
|
||||
|
@ -52,24 +120,25 @@ impl TlsReloader {
|
|||
watcher.watch(&this.key_path, notify::RecursiveMode::NonRecursive)?;
|
||||
*this.fs_watcher.lock() = watcher;
|
||||
|
||||
Ok(Self { state: Some(this) })
|
||||
Ok(Self { state: Client(this) })
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn should_reload_certificate(&self) -> bool {
|
||||
match &self.state {
|
||||
None => false,
|
||||
Some(this) => this.tls_reload_certificate.swap(false, Ordering::Relaxed),
|
||||
TlsReloaderState::Empty => false,
|
||||
Server(this) => this.tls_reload_certificate.swap(false, Ordering::Relaxed),
|
||||
Client(this) => this.tls_reload_certificate.swap(false, Ordering::Relaxed),
|
||||
}
|
||||
}
|
||||
|
||||
fn try_rewatch_certificate(this: Arc<TlsReloaderState>, path: PathBuf) {
|
||||
fn try_rewatch_certificate(this: TlsReloaderState, path: PathBuf) {
|
||||
thread::spawn(move || {
|
||||
while !path.exists() {
|
||||
warn!("TLS file {:?} does not exist anymore, waiting for it to be created", path);
|
||||
thread::sleep(Duration::from_secs(10));
|
||||
}
|
||||
let mut watcher = this.fs_watcher.lock();
|
||||
let mut watcher = this.fs_watcher().lock();
|
||||
let _ = watcher.unwatch(&path);
|
||||
let Ok(_) = watcher
|
||||
.watch(&path, notify::RecursiveMode::NonRecursive)
|
||||
|
@ -88,11 +157,21 @@ impl TlsReloader {
|
|||
paths: vec![path],
|
||||
attrs: Default::default(),
|
||||
};
|
||||
Self::handle_fs_event(&this, Ok(event));
|
||||
|
||||
match &this {
|
||||
Server(_) => Self::handle_server_fs_event(&this, Ok(event)),
|
||||
Client(_) => Self::handle_client_fs_event(&this, Ok(event)),
|
||||
TlsReloaderState::Empty => {}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn handle_fs_event(this: &Arc<TlsReloaderState>, event: notify::Result<notify::Event>) {
|
||||
fn handle_server_fs_event(this: &TlsReloaderState, event: notify::Result<notify::Event>) {
|
||||
let this = match this {
|
||||
TlsReloaderState::Empty | Client(_) => return,
|
||||
Server(st) => st,
|
||||
};
|
||||
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(err) => {
|
||||
|
@ -115,12 +194,12 @@ impl TlsReloader {
|
|||
}
|
||||
Err(err) => {
|
||||
warn!("Error while loading TLS certificate {:?}", err);
|
||||
Self::try_rewatch_certificate(this.clone(), path.to_path_buf());
|
||||
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
|
||||
}
|
||||
},
|
||||
EventKind::Remove(_) => {
|
||||
warn!("TLS certificate file has been removed, trying to re-set a watch for it");
|
||||
Self::try_rewatch_certificate(this.clone(), path.to_path_buf());
|
||||
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
|
||||
}
|
||||
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
||||
trace!("Ignoring event {:?}", event);
|
||||
|
@ -137,12 +216,142 @@ impl TlsReloader {
|
|||
}
|
||||
Err(err) => {
|
||||
warn!("Error while loading TLS private key {:?}", err);
|
||||
Self::try_rewatch_certificate(this.clone(), path.to_path_buf());
|
||||
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
|
||||
}
|
||||
},
|
||||
EventKind::Remove(_) => {
|
||||
warn!("TLS private key file has been removed, trying to re-set a watch for it");
|
||||
Self::try_rewatch_certificate(this.clone(), path.to_path_buf());
|
||||
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
|
||||
}
|
||||
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
||||
trace!("Ignoring event {:?}", event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(client_ca_path) = &this.client_ca_path {
|
||||
if let Some(path) = event.paths.iter().find(|p| p.ends_with(client_ca_path)) {
|
||||
match event.kind {
|
||||
EventKind::Create(_) | EventKind::Modify(_) => {
|
||||
match tls::load_certificates_from_pem(client_ca_path) {
|
||||
Ok(tls_certs) => {
|
||||
if let Some(client_certs) = &tls.tls_client_ca_certificates {
|
||||
*client_certs.lock() = tls_certs;
|
||||
this.tls_reload_certificate.store(true, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Error while loading TLS client certificate {:?}", err);
|
||||
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
|
||||
}
|
||||
}
|
||||
}
|
||||
EventKind::Remove(_) => {
|
||||
warn!("TLS client certificate has been removed, trying to re-set a watch for it");
|
||||
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
|
||||
}
|
||||
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
||||
trace!("Ignoring event {:?}", event);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_client_fs_event(this: &TlsReloaderState, event: notify::Result<notify::Event>) {
|
||||
let this = match this {
|
||||
TlsReloaderState::Empty | Server(_) => return,
|
||||
Client(st) => st,
|
||||
};
|
||||
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(err) => {
|
||||
error!("Error while watching tls certificate and private key for changes {:?}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if event.kind.is_access() {
|
||||
return;
|
||||
}
|
||||
|
||||
let tls = this.client_config.remote_addr.tls().unwrap();
|
||||
if let Some(path) = event.paths.iter().find(|p| p.ends_with(&this.cert_path)) {
|
||||
match event.kind {
|
||||
EventKind::Create(_) | EventKind::Modify(_) => match (
|
||||
tls::load_certificates_from_pem(&this.cert_path),
|
||||
tls::load_private_key_from_file(&this.key_path),
|
||||
) {
|
||||
(Ok(tls_certs), Ok(tls_key)) => {
|
||||
let tls_connector = tls::tls_connector(
|
||||
tls.tls_verify_certificate,
|
||||
this.client_config.remote_addr.scheme().alpn_protocols(),
|
||||
!tls.tls_sni_disabled,
|
||||
Some(tls_certs),
|
||||
Some(tls_key),
|
||||
);
|
||||
let tls_connector = match tls_connector {
|
||||
Ok(cn) => cn,
|
||||
Err(err) => {
|
||||
error!("Error while creating TLS connector {:?}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
*tls.tls_connector.write() = tls_connector;
|
||||
this.tls_reload_certificate.store(true, Ordering::Relaxed);
|
||||
}
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
warn!("Error while loading TLS certificate {:?}", err);
|
||||
Self::try_rewatch_certificate(Client(this.clone()), path.to_path_buf());
|
||||
}
|
||||
},
|
||||
EventKind::Remove(_) => {
|
||||
warn!("TLS certificate file has been removed, trying to re-set a watch for it");
|
||||
Self::try_rewatch_certificate(Client(this.clone()), path.to_path_buf());
|
||||
}
|
||||
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
||||
trace!("Ignoring event {:?}", event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(path) = event
|
||||
.paths
|
||||
.iter()
|
||||
.find(|p| p.ends_with(tls.tls_key_path.as_ref().unwrap()))
|
||||
{
|
||||
match event.kind {
|
||||
EventKind::Create(_) | EventKind::Modify(_) => match (
|
||||
tls::load_certificates_from_pem(&this.cert_path),
|
||||
tls::load_private_key_from_file(&this.key_path),
|
||||
) {
|
||||
(Ok(tls_certs), Ok(tls_key)) => {
|
||||
let tls_connector = tls::tls_connector(
|
||||
tls.tls_verify_certificate,
|
||||
this.client_config.remote_addr.scheme().alpn_protocols(),
|
||||
!tls.tls_sni_disabled,
|
||||
Some(tls_certs),
|
||||
Some(tls_key),
|
||||
);
|
||||
let tls_connector = match tls_connector {
|
||||
Ok(cn) => cn,
|
||||
Err(err) => {
|
||||
error!("Error while creating TLS connector {:?}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
*tls.tls_connector.write() = tls_connector;
|
||||
this.tls_reload_certificate.store(true, Ordering::Relaxed);
|
||||
}
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
warn!("Error while loading TLS private key {:?}", err);
|
||||
Self::try_rewatch_certificate(Client(this.clone()), path.to_path_buf());
|
||||
}
|
||||
},
|
||||
EventKind::Remove(_) => {
|
||||
warn!("TLS private key file has been removed, trying to re-set a watch for it");
|
||||
Self::try_rewatch_certificate(Client(this.clone()), path.to_path_buf());
|
||||
}
|
||||
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
||||
trace!("Ignoring event {:?}", event);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue