Add support for mTLS
This commit is contained in:
parent
4524397d4f
commit
70b5a216b0
20 changed files with 1051 additions and 57 deletions
|
@ -1,6 +1,6 @@
|
|||
pub mod client;
|
||||
pub mod server;
|
||||
mod tls_reloader;
|
||||
pub mod tls_reloader;
|
||||
mod transport;
|
||||
|
||||
use crate::{tcp, tls, LocalProtocol, TlsClientConfig, WsClientConfig};
|
||||
|
@ -104,6 +104,15 @@ impl TransportScheme {
|
|||
TransportScheme::Https => "https",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn alpn_protocols(&self) -> Vec<Vec<u8>> {
|
||||
match self {
|
||||
TransportScheme::Ws => vec![],
|
||||
TransportScheme::Wss => vec![b"http/1.1".to_vec()],
|
||||
TransportScheme::Http => vec![],
|
||||
TransportScheme::Https => vec![b"h2".to_vec()],
|
||||
}
|
||||
}
|
||||
}
|
||||
impl FromStr for TransportScheme {
|
||||
type Err = ();
|
||||
|
|
|
@ -597,7 +597,7 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
|
|||
let mut tls_context = if let Some(tls_config) = &server_config.tls {
|
||||
let tls_context = TlsContext {
|
||||
tls_acceptor: Arc::new(tls::tls_acceptor(tls_config, Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]))?),
|
||||
tls_reloader: TlsReloader::new(server_config.clone())?,
|
||||
tls_reloader: TlsReloader::new_for_server(server_config.clone())?,
|
||||
tls_config,
|
||||
};
|
||||
Some(tls_context)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate::{tls, WsServerConfig};
|
||||
use crate::tunnel::tls_reloader::TlsReloaderState::{Client, Server};
|
||||
use crate::{tls, WsClientConfig, WsServerConfig};
|
||||
use anyhow::Context;
|
||||
use log::trace;
|
||||
use notify::{EventKind, RecommendedWatcher, Watcher};
|
||||
|
@ -10,41 +11,108 @@ use std::thread;
|
|||
use std::time::Duration;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
struct TlsReloaderState {
|
||||
struct TlsReloaderServerState {
|
||||
fs_watcher: Mutex<RecommendedWatcher>,
|
||||
tls_reload_certificate: AtomicBool,
|
||||
server_config: Arc<WsServerConfig>,
|
||||
cert_path: PathBuf,
|
||||
key_path: PathBuf,
|
||||
client_ca_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
struct TlsReloaderClientState {
|
||||
fs_watcher: Mutex<RecommendedWatcher>,
|
||||
tls_reload_certificate: AtomicBool,
|
||||
client_config: Arc<WsClientConfig>,
|
||||
cert_path: PathBuf,
|
||||
key_path: PathBuf,
|
||||
}
|
||||
|
||||
enum TlsReloaderState {
|
||||
Empty,
|
||||
Server(Arc<TlsReloaderServerState>),
|
||||
Client(Arc<TlsReloaderClientState>),
|
||||
}
|
||||
|
||||
impl TlsReloaderState {
|
||||
fn fs_watcher(&self) -> &Mutex<RecommendedWatcher> {
|
||||
match self {
|
||||
TlsReloaderState::Empty => unreachable!(),
|
||||
Server(this) => &this.fs_watcher,
|
||||
Client(this) => &this.fs_watcher,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TlsReloader {
|
||||
state: Option<Arc<TlsReloaderState>>,
|
||||
state: TlsReloaderState,
|
||||
}
|
||||
|
||||
impl TlsReloader {
|
||||
pub fn new(server_config: Arc<WsServerConfig>) -> anyhow::Result<Self> {
|
||||
pub fn new_for_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<Self> {
|
||||
// If there is no custom certificate and private key, there is nothing to watch
|
||||
let Some((Some(cert_path), Some(key_path))) = server_config
|
||||
let Some((Some(cert_path), Some(key_path), client_ca_certs)) = server_config
|
||||
.tls
|
||||
.as_ref()
|
||||
.map(|t| (&t.tls_certificate_path, &t.tls_key_path))
|
||||
.map(|t| (&t.tls_certificate_path, &t.tls_key_path, &t.tls_client_ca_certs_path))
|
||||
else {
|
||||
return Ok(Self { state: None });
|
||||
return Ok(Self {
|
||||
state: TlsReloaderState::Empty,
|
||||
});
|
||||
};
|
||||
|
||||
let this = Arc::new(TlsReloaderState {
|
||||
let this = Arc::new(TlsReloaderServerState {
|
||||
fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?),
|
||||
tls_reload_certificate: AtomicBool::new(false),
|
||||
cert_path: cert_path.to_path_buf(),
|
||||
key_path: key_path.to_path_buf(),
|
||||
client_ca_path: client_ca_certs.as_ref().map(|x| x.to_path_buf()),
|
||||
server_config,
|
||||
});
|
||||
|
||||
info!("Starting to watch tls certificate and private key for changes to reload them");
|
||||
info!("Starting to watch tls certificates and private key for changes to reload them");
|
||||
let mut watcher = notify::recommended_watcher({
|
||||
let this = this.clone();
|
||||
let this = Server(this.clone());
|
||||
|
||||
move |event: notify::Result<notify::Event>| Self::handle_fs_event(&this, event)
|
||||
move |event: notify::Result<notify::Event>| Self::handle_server_fs_event(&this, event)
|
||||
})
|
||||
.with_context(|| "Cannot create tls certificate watcher")?;
|
||||
|
||||
watcher.watch(&this.cert_path, notify::RecursiveMode::NonRecursive)?;
|
||||
watcher.watch(&this.key_path, notify::RecursiveMode::NonRecursive)?;
|
||||
if let Some(client_ca_path) = &this.client_ca_path {
|
||||
watcher.watch(client_ca_path, notify::RecursiveMode::NonRecursive)?;
|
||||
}
|
||||
*this.fs_watcher.lock() = watcher;
|
||||
|
||||
Ok(Self { state: Server(this) })
|
||||
}
|
||||
|
||||
pub fn new_for_client(client_config: Arc<WsClientConfig>) -> anyhow::Result<Self> {
|
||||
// If there is no custom certificate and private key, there is nothing to watch
|
||||
let Some((Some(cert_path), Some(key_path))) = client_config
|
||||
.remote_addr
|
||||
.tls()
|
||||
.map(|t| (&t.tls_certificate_path, &t.tls_key_path))
|
||||
else {
|
||||
return Ok(Self {
|
||||
state: TlsReloaderState::Empty,
|
||||
});
|
||||
};
|
||||
|
||||
let this = Arc::new(TlsReloaderClientState {
|
||||
fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?),
|
||||
tls_reload_certificate: AtomicBool::new(false),
|
||||
cert_path: cert_path.to_path_buf(),
|
||||
key_path: key_path.to_path_buf(),
|
||||
client_config,
|
||||
});
|
||||
|
||||
info!("Starting to watch tls certificates and private key for changes to reload them");
|
||||
let mut watcher = notify::recommended_watcher({
|
||||
let this = Client(this.clone());
|
||||
|
||||
move |event: notify::Result<notify::Event>| Self::handle_client_fs_event(&this, event)
|
||||
})
|
||||
.with_context(|| "Cannot create tls certificate watcher")?;
|
||||
|
||||
|
@ -52,24 +120,25 @@ impl TlsReloader {
|
|||
watcher.watch(&this.key_path, notify::RecursiveMode::NonRecursive)?;
|
||||
*this.fs_watcher.lock() = watcher;
|
||||
|
||||
Ok(Self { state: Some(this) })
|
||||
Ok(Self { state: Client(this) })
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn should_reload_certificate(&self) -> bool {
|
||||
match &self.state {
|
||||
None => false,
|
||||
Some(this) => this.tls_reload_certificate.swap(false, Ordering::Relaxed),
|
||||
TlsReloaderState::Empty => false,
|
||||
Server(this) => this.tls_reload_certificate.swap(false, Ordering::Relaxed),
|
||||
Client(this) => this.tls_reload_certificate.swap(false, Ordering::Relaxed),
|
||||
}
|
||||
}
|
||||
|
||||
fn try_rewatch_certificate(this: Arc<TlsReloaderState>, path: PathBuf) {
|
||||
fn try_rewatch_certificate(this: TlsReloaderState, path: PathBuf) {
|
||||
thread::spawn(move || {
|
||||
while !path.exists() {
|
||||
warn!("TLS file {:?} does not exist anymore, waiting for it to be created", path);
|
||||
thread::sleep(Duration::from_secs(10));
|
||||
}
|
||||
let mut watcher = this.fs_watcher.lock();
|
||||
let mut watcher = this.fs_watcher().lock();
|
||||
let _ = watcher.unwatch(&path);
|
||||
let Ok(_) = watcher
|
||||
.watch(&path, notify::RecursiveMode::NonRecursive)
|
||||
|
@ -88,11 +157,21 @@ impl TlsReloader {
|
|||
paths: vec![path],
|
||||
attrs: Default::default(),
|
||||
};
|
||||
Self::handle_fs_event(&this, Ok(event));
|
||||
|
||||
match &this {
|
||||
Server(_) => Self::handle_server_fs_event(&this, Ok(event)),
|
||||
Client(_) => Self::handle_client_fs_event(&this, Ok(event)),
|
||||
TlsReloaderState::Empty => {}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn handle_fs_event(this: &Arc<TlsReloaderState>, event: notify::Result<notify::Event>) {
|
||||
fn handle_server_fs_event(this: &TlsReloaderState, event: notify::Result<notify::Event>) {
|
||||
let this = match this {
|
||||
TlsReloaderState::Empty | Client(_) => return,
|
||||
Server(st) => st,
|
||||
};
|
||||
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(err) => {
|
||||
|
@ -115,12 +194,12 @@ impl TlsReloader {
|
|||
}
|
||||
Err(err) => {
|
||||
warn!("Error while loading TLS certificate {:?}", err);
|
||||
Self::try_rewatch_certificate(this.clone(), path.to_path_buf());
|
||||
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
|
||||
}
|
||||
},
|
||||
EventKind::Remove(_) => {
|
||||
warn!("TLS certificate file has been removed, trying to re-set a watch for it");
|
||||
Self::try_rewatch_certificate(this.clone(), path.to_path_buf());
|
||||
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
|
||||
}
|
||||
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
||||
trace!("Ignoring event {:?}", event);
|
||||
|
@ -137,12 +216,142 @@ impl TlsReloader {
|
|||
}
|
||||
Err(err) => {
|
||||
warn!("Error while loading TLS private key {:?}", err);
|
||||
Self::try_rewatch_certificate(this.clone(), path.to_path_buf());
|
||||
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
|
||||
}
|
||||
},
|
||||
EventKind::Remove(_) => {
|
||||
warn!("TLS private key file has been removed, trying to re-set a watch for it");
|
||||
Self::try_rewatch_certificate(this.clone(), path.to_path_buf());
|
||||
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
|
||||
}
|
||||
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
||||
trace!("Ignoring event {:?}", event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(client_ca_path) = &this.client_ca_path {
|
||||
if let Some(path) = event.paths.iter().find(|p| p.ends_with(client_ca_path)) {
|
||||
match event.kind {
|
||||
EventKind::Create(_) | EventKind::Modify(_) => {
|
||||
match tls::load_certificates_from_pem(client_ca_path) {
|
||||
Ok(tls_certs) => {
|
||||
if let Some(client_certs) = &tls.tls_client_ca_certificates {
|
||||
*client_certs.lock() = tls_certs;
|
||||
this.tls_reload_certificate.store(true, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Error while loading TLS client certificate {:?}", err);
|
||||
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
|
||||
}
|
||||
}
|
||||
}
|
||||
EventKind::Remove(_) => {
|
||||
warn!("TLS client certificate has been removed, trying to re-set a watch for it");
|
||||
Self::try_rewatch_certificate(Server(this.clone()), path.to_path_buf());
|
||||
}
|
||||
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
||||
trace!("Ignoring event {:?}", event);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_client_fs_event(this: &TlsReloaderState, event: notify::Result<notify::Event>) {
|
||||
let this = match this {
|
||||
TlsReloaderState::Empty | Server(_) => return,
|
||||
Client(st) => st,
|
||||
};
|
||||
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(err) => {
|
||||
error!("Error while watching tls certificate and private key for changes {:?}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if event.kind.is_access() {
|
||||
return;
|
||||
}
|
||||
|
||||
let tls = this.client_config.remote_addr.tls().unwrap();
|
||||
if let Some(path) = event.paths.iter().find(|p| p.ends_with(&this.cert_path)) {
|
||||
match event.kind {
|
||||
EventKind::Create(_) | EventKind::Modify(_) => match (
|
||||
tls::load_certificates_from_pem(&this.cert_path),
|
||||
tls::load_private_key_from_file(&this.key_path),
|
||||
) {
|
||||
(Ok(tls_certs), Ok(tls_key)) => {
|
||||
let tls_connector = tls::tls_connector(
|
||||
tls.tls_verify_certificate,
|
||||
this.client_config.remote_addr.scheme().alpn_protocols(),
|
||||
!tls.tls_sni_disabled,
|
||||
Some(tls_certs),
|
||||
Some(tls_key),
|
||||
);
|
||||
let tls_connector = match tls_connector {
|
||||
Ok(cn) => cn,
|
||||
Err(err) => {
|
||||
error!("Error while creating TLS connector {:?}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
*tls.tls_connector.write() = tls_connector;
|
||||
this.tls_reload_certificate.store(true, Ordering::Relaxed);
|
||||
}
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
warn!("Error while loading TLS certificate {:?}", err);
|
||||
Self::try_rewatch_certificate(Client(this.clone()), path.to_path_buf());
|
||||
}
|
||||
},
|
||||
EventKind::Remove(_) => {
|
||||
warn!("TLS certificate file has been removed, trying to re-set a watch for it");
|
||||
Self::try_rewatch_certificate(Client(this.clone()), path.to_path_buf());
|
||||
}
|
||||
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
||||
trace!("Ignoring event {:?}", event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(path) = event
|
||||
.paths
|
||||
.iter()
|
||||
.find(|p| p.ends_with(tls.tls_key_path.as_ref().unwrap()))
|
||||
{
|
||||
match event.kind {
|
||||
EventKind::Create(_) | EventKind::Modify(_) => match (
|
||||
tls::load_certificates_from_pem(&this.cert_path),
|
||||
tls::load_private_key_from_file(&this.key_path),
|
||||
) {
|
||||
(Ok(tls_certs), Ok(tls_key)) => {
|
||||
let tls_connector = tls::tls_connector(
|
||||
tls.tls_verify_certificate,
|
||||
this.client_config.remote_addr.scheme().alpn_protocols(),
|
||||
!tls.tls_sni_disabled,
|
||||
Some(tls_certs),
|
||||
Some(tls_key),
|
||||
);
|
||||
let tls_connector = match tls_connector {
|
||||
Ok(cn) => cn,
|
||||
Err(err) => {
|
||||
error!("Error while creating TLS connector {:?}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
*tls.tls_connector.write() = tls_connector;
|
||||
this.tls_reload_certificate.store(true, Ordering::Relaxed);
|
||||
}
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
warn!("Error while loading TLS private key {:?}", err);
|
||||
Self::try_rewatch_certificate(Client(this.clone()), path.to_path_buf());
|
||||
}
|
||||
},
|
||||
EventKind::Remove(_) => {
|
||||
warn!("TLS private key file has been removed, trying to re-set a watch for it");
|
||||
Self::try_rewatch_certificate(Client(this.clone()), path.to_path_buf());
|
||||
}
|
||||
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
|
||||
trace!("Ignoring event {:?}", event);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue