Add support for mTLS

This commit is contained in:
Σrebe - Romain GERARD 2024-04-17 20:13:49 +02:00
parent 4524397d4f
commit 70b5a216b0
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
20 changed files with 1051 additions and 57 deletions

View file

@ -1,6 +1,6 @@
pub mod client;
pub mod server;
mod tls_reloader;
pub mod tls_reloader;
mod transport;
use crate::{tcp, tls, LocalProtocol, TlsClientConfig, WsClientConfig};
@ -104,6 +104,15 @@ impl TransportScheme {
TransportScheme::Https => "https",
}
}
pub fn alpn_protocols(&self) -> Vec<Vec<u8>> {
match self {
TransportScheme::Ws => vec![],
TransportScheme::Wss => vec![b"http/1.1".to_vec()],
TransportScheme::Http => vec![],
TransportScheme::Https => vec![b"h2".to_vec()],
}
}
}
impl FromStr for TransportScheme {
type Err = ();

View file

@ -597,7 +597,7 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
let mut tls_context = if let Some(tls_config) = &server_config.tls {
let tls_context = TlsContext {
tls_acceptor: Arc::new(tls::tls_acceptor(tls_config, Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]))?),
tls_reloader: TlsReloader::new(server_config.clone())?,
tls_reloader: TlsReloader::new_for_server(server_config.clone())?,
tls_config,
};
Some(tls_context)

View file

@ -1,4 +1,5 @@
use crate::{tls, WsServerConfig};
use crate::tunnel::tls_reloader::TlsReloaderState::{Client, Server};
use crate::{tls, WsClientConfig, WsServerConfig};
use anyhow::Context;
use log::trace;
use notify::{EventKind, RecommendedWatcher, Watcher};
@ -10,41 +11,108 @@ use std::thread;
use std::time::Duration;
use tracing::{error, info, warn};
struct TlsReloaderState {
struct TlsReloaderServerState {
fs_watcher: Mutex<RecommendedWatcher>,
tls_reload_certificate: AtomicBool,
server_config: Arc<WsServerConfig>,
cert_path: PathBuf,
key_path: PathBuf,
client_ca_path: Option<PathBuf>,
}
struct TlsReloaderClientState {
fs_watcher: Mutex<RecommendedWatcher>,
tls_reload_certificate: AtomicBool,
client_config: Arc<WsClientConfig>,
cert_path: PathBuf,
key_path: PathBuf,
}
enum TlsReloaderState {
Empty,
Server(Arc<TlsReloaderServerState>),
Client(Arc<TlsReloaderClientState>),
}
impl TlsReloaderState {
fn fs_watcher(&self) -> &Mutex<RecommendedWatcher> {
match self {
TlsReloaderState::Empty => unreachable!(),
Server(this) => &this.fs_watcher,
Client(this) => &this.fs_watcher,
}
}
}
pub struct TlsReloader {
state: Option<Arc<TlsReloaderState>>,
state: TlsReloaderState,
}
impl TlsReloader {
pub fn new(server_config: Arc<WsServerConfig>) -> anyhow::Result<Self> {
pub fn new_for_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<Self> {
// If there is no custom certificate and private key, there is nothing to watch
let Some((Some(cert_path), Some(key_path))) = server_config
let Some((Some(cert_path), Some(key_path), client_ca_certs)) = server_config
.tls
.as_ref()
.map(|t| (&t.tls_certificate_path, &t.tls_key_path))
.map(|t| (&t.tls_certificate_path, &t.tls_key_path, &t.tls_client_ca_certs_path))
else {
return Ok(Self { state: None });
return Ok(Self {
state: TlsReloaderState::Empty,
});
};
let this = Arc::new(TlsReloaderState {
let this = Arc::new(TlsReloaderServerState {
fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?),
tls_reload_certificate: AtomicBool::new(false),
cert_path: cert_path.to_path_buf(),
key_path: key_path.to_path_buf(),
client_ca_path: client_ca_certs.as_ref().map(|x| x.to_path_buf()),
server_config,
});
info!("Starting to watch tls certificate and private key for changes to reload them");
info!("Starting to watch tls certificates and private key for changes to reload them");
let mut watcher = notify::recommended_watcher({
let this = this.clone();
let this = Server(this.clone());
move |event: notify::Result<notify::Event>| Self::handle_fs_event(&this, event)
move |event: notify::Result<notify::Event>| Self::handle_server_fs_event(&this, event)
})
.with_context(|| "Cannot create tls certificate watcher")?;
watcher.watch(&this.cert_path, notify::RecursiveMode::NonRecursive)?;
watcher.watch(&this.key_path, notify::RecursiveMode::NonRecursive)?;
if let Some(client_ca_path) = &this.client_ca_path {
watcher.watch(client_ca_path, notify::RecursiveMode::NonRecursive)?;
}
*this.fs_watcher.lock() = watcher;
Ok(Self { state: Server(this) })
}
pub fn new_for_client(client_config: Arc<WsClientConfig>) -> anyhow::Result<Self> {
// If there is no custom certificate and private key, there is nothing to watch
let Some((Some(cert_path), Some(key_path))) = client_config
.remote_addr
.tls()
.map(|t| (&t.tls_certificate_path, &t.tls_key_path))
else {
return Ok(Self {
state: TlsReloaderState::Empty,
});
};
let this = Arc::new(TlsReloaderClientState {
fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?),
tls_reload_certificate: AtomicBool::new(false),
cert_path: cert_path.to_path_buf(),
key_path: key_path.to_path_buf(),
client_config,
});
info!("Starting to watch tls certificates and private key for changes to reload them");
let mut watcher = notify::recommended_watcher({
let this = Client(this.clone());
move |event: notify::Result<notify::Event>| Self::handle_client_fs_event(&this, event)
})
.with_context(|| "Cannot create tls certificate watcher")?;
@ -52,24 +120,25 @@ impl TlsReloader {
watcher.watch(&this.key_path, notify::RecursiveMode::NonRecursive)?;
*this.fs_watcher.lock() = watcher;
Ok(Self { state: Some(this) })
Ok(Self { state: Client(this) })
}
#[inline]
pub fn should_reload_certificate(&self) -> bool {
match &self.state {
None => false,
Some(this) => this.tls_reload_certificate.swap(false, Ordering::Relaxed),
TlsReloaderState::Empty => false,
Server(this) => this.tls_reload_certificate.swap(false, Ordering::Relaxed),
Client(this) => this.tls_reload_certificate.swap(false, Ordering::Relaxed),
}
}
fn try_rewatch_certificate(this: Arc<TlsReloaderState>, path: PathBuf) {
fn try_rewatch_certificate(this: TlsReloaderState, path: PathBuf) {
thread::spawn(move || {
while !path.exists() {
warn!("TLS file {:?} does not exist anymore, waiting for it to be created", path);
thread::sleep(Duration::from_secs(10));
}
let mut watcher = this.fs_watcher.lock();
let mut watcher = this.fs_watcher().lock();
let _ = watcher.unwatch(&path);
let Ok(_) = watcher
.watch(&path, notify::RecursiveMode::NonRecursive)
@ -88,11 +157,21 @@ impl TlsReloader {
paths: vec![path],
attrs: Default::default(),
};
Self::handle_fs_event(&this, Ok(event));
match &this {
Server(_) => Self::handle_server_fs_event(&this, Ok(event)),
Client(_) => Self::handle_client_fs_event(&this, Ok(event)),
TlsReloaderState::Empty => {}
}
});
}
fn handle_fs_event(this: &Arc<TlsReloaderState>, event: notify::Result<notify::Event>) {
fn handle_server_fs_event(this: &TlsReloaderState, event: notify::Result<notify::Event>) {
let this = match this {
TlsReloaderState::Empty | Client(_) => return,
Server(st) => st,
};
let event = match event {
Ok(event) => event,
Err(err) => {
@ -115,12 +194,12 @@ impl TlsReloader {
}
Err(err) => {
warn!("Error while loading TLS certificate {:?}", err);
Self::try_rewatch_certificate(this.clone(), path.to_path_buf());
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
}
},
EventKind::Remove(_) => {
warn!("TLS certificate file has been removed, trying to re-set a watch for it");
Self::try_rewatch_certificate(this.clone(), path.to_path_buf());
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
}
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
trace!("Ignoring event {:?}", event);
@ -137,12 +216,142 @@ impl TlsReloader {
}
Err(err) => {
warn!("Error while loading TLS private key {:?}", err);
Self::try_rewatch_certificate(this.clone(), path.to_path_buf());
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
}
},
EventKind::Remove(_) => {
warn!("TLS private key file has been removed, trying to re-set a watch for it");
Self::try_rewatch_certificate(this.clone(), path.to_path_buf());
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
}
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
trace!("Ignoring event {:?}", event);
}
}
}
if let Some(client_ca_path) = &this.client_ca_path {
if let Some(path) = event.paths.iter().find(|p| p.ends_with(client_ca_path)) {
match event.kind {
EventKind::Create(_) | EventKind::Modify(_) => {
match tls::load_certificates_from_pem(client_ca_path) {
Ok(tls_certs) => {
if let Some(client_certs) = &tls.tls_client_ca_certificates {
*client_certs.lock() = tls_certs;
this.tls_reload_certificate.store(true, Ordering::Relaxed);
}
}
Err(err) => {
warn!("Error while loading TLS client certificate {:?}", err);
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
}
}
}
EventKind::Remove(_) => {
warn!("TLS client certificate has been removed, trying to re-set a watch for it");
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
}
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
trace!("Ignoring event {:?}", event);
}
}
}
}
}
fn handle_client_fs_event(this: &TlsReloaderState, event: notify::Result<notify::Event>) {
let this = match this {
TlsReloaderState::Empty | Server(_) => return,
Client(st) => st,
};
let event = match event {
Ok(event) => event,
Err(err) => {
error!("Error while watching tls certificate and private key for changes {:?}", err);
return;
}
};
if event.kind.is_access() {
return;
}
let tls = this.client_config.remote_addr.tls().unwrap();
if let Some(path) = event.paths.iter().find(|p| p.ends_with(&this.cert_path)) {
match event.kind {
EventKind::Create(_) | EventKind::Modify(_) => match (
tls::load_certificates_from_pem(&this.cert_path),
tls::load_private_key_from_file(&this.key_path),
) {
(Ok(tls_certs), Ok(tls_key)) => {
let tls_connector = tls::tls_connector(
tls.tls_verify_certificate,
this.client_config.remote_addr.scheme().alpn_protocols(),
!tls.tls_sni_disabled,
Some(tls_certs),
Some(tls_key),
);
let tls_connector = match tls_connector {
Ok(cn) => cn,
Err(err) => {
error!("Error while creating TLS connector {:?}", err);
return;
}
};
*tls.tls_connector.write() = tls_connector;
this.tls_reload_certificate.store(true, Ordering::Relaxed);
}
(Err(err), _) | (_, Err(err)) => {
warn!("Error while loading TLS certificate {:?}", err);
Self::try_rewatch_certificate(Client(this.clone()), path.to_path_buf());
}
},
EventKind::Remove(_) => {
warn!("TLS certificate file has been removed, trying to re-set a watch for it");
Self::try_rewatch_certificate(Client(this.clone()), path.to_path_buf());
}
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
trace!("Ignoring event {:?}", event);
}
}
}
if let Some(path) = event
.paths
.iter()
.find(|p| p.ends_with(tls.tls_key_path.as_ref().unwrap()))
{
match event.kind {
EventKind::Create(_) | EventKind::Modify(_) => match (
tls::load_certificates_from_pem(&this.cert_path),
tls::load_private_key_from_file(&this.key_path),
) {
(Ok(tls_certs), Ok(tls_key)) => {
let tls_connector = tls::tls_connector(
tls.tls_verify_certificate,
this.client_config.remote_addr.scheme().alpn_protocols(),
!tls.tls_sni_disabled,
Some(tls_certs),
Some(tls_key),
);
let tls_connector = match tls_connector {
Ok(cn) => cn,
Err(err) => {
error!("Error while creating TLS connector {:?}", err);
return;
}
};
*tls.tls_connector.write() = tls_connector;
this.tls_reload_certificate.store(true, Ordering::Relaxed);
}
(Err(err), _) | (_, Err(err)) => {
warn!("Error while loading TLS private key {:?}", err);
Self::try_rewatch_certificate(Client(this.clone()), path.to_path_buf());
}
},
EventKind::Remove(_) => {
warn!("TLS private key file has been removed, trying to re-set a watch for it");
Self::try_rewatch_certificate(Client(this.clone()), path.to_path_buf());
}
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
trace!("Ignoring event {:?}", event);