From 88e42d3b9fa02feb435711ab474a694c676bc450 Mon Sep 17 00:00:00 2001 From: Jasper Siepkes Date: Mon, 6 May 2024 09:00:08 +0100 Subject: [PATCH] Allow client certificate CN to be used for upgrade path (#264) This change causes the wstunnel client to use the common name (CN) of the client's certificate for the upgrade path when mTLS is enabled. --- Cargo.toml | 1 + src/main.rs | 43 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c49144d..55c8906 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ notify = { version = "6.1.1", features = [] } rustls-native-certs = { version = "0.7.0", features = [] } rustls-pemfile = { version = "2.1.1", features = [] } +x509-parser = "0.16.0" scopeguard = "1.2.0" serde = { version = "1.0.197", features = ["derive"] } socket2 = { version = "0.5.6", features = [] } diff --git a/src/main.rs b/src/main.rs index 80fda7a..c787046 100644 --- a/src/main.rs +++ b/src/main.rs @@ -48,6 +48,10 @@ use crate::udp::MyUdpSocket; use tracing_subscriber::filter::Directive; use tracing_subscriber::EnvFilter; use url::{Host, Url}; +use x509_parser::{parse_x509_certificate}; +use x509_parser::prelude::{X509Certificate}; + +const DEFAULT_CLIENT_UPGRADE_PATH_PREFIX: &str = "v1"; /// Use Websocket or HTTP2 protocol to tunnel {TCP,UDP} traffic /// wsTunnelClient <---> wsTunnelServer <---> RemoteHost @@ -179,7 +183,7 @@ struct Client { #[arg( short = 'P', long, - default_value = "v1", + default_value = DEFAULT_CLIENT_UPGRADE_PATH_PREFIX, verbatim_doc_comment, env = "WSTUNNEL_HTTP_UPGRADE_PATH_PREFIX" )] @@ -226,6 +230,7 @@ struct Client { /// [Optional] Certificate (pem) to present to the server when connecting over TLS (HTTPS). /// Used when the server requires clients to authenticate themselves with a certificate (i.e. mTLS). + /// Unless overridden, the HTTP upgrade path will be configured to be the common name (CN) of the certificate. /// The certificate will be automatically reloaded if it changes #[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)] tls_certificate: Option, @@ -594,6 +599,28 @@ fn parse_server_url(arg: &str) -> Result { Ok(url) } +/// Find a leaf certificate in a vector of certificates. It is assumed only a single leaf certificate +/// is present in the vector. The other certificates should be (intermediate) CA certificates. +fn find_leaf_certificate(tls_certificates: &Vec) -> Option { + for tls_certificate in tls_certificates { + if let Ok((_, tls_certificate_x509)) = parse_x509_certificate(&tls_certificate.0) { + if !tls_certificate_x509.is_ca() { + return Some(tls_certificate_x509); + } + } + } + None +} + +/// Returns the common name (CN) as specified in the supplied certificate. +fn cn_from_certificate(tls_certificate_x509: &X509Certificate) -> Option { + tls_certificate_x509.tbs_certificate.subject + .iter_common_name() + .flat_map(|cn| cn.as_str().ok()) + .map(|cn| cn.to_string()) + .next() +} + #[derive(Clone)] pub struct TlsClientConfig { pub tls_sni_disabled: bool, @@ -744,6 +771,18 @@ async fn main() { (None, None) }; + let http_upgrade_path_prefix = if args.http_upgrade_path_prefix.eq(DEFAULT_CLIENT_UPGRADE_PATH_PREFIX) { + // When using mTLS and no manual http upgrade path is specified configure the HTTP upgrade path + // to be the common name (CN) of the client's certificate. + tls_certificate.as_ref() + .and_then(|tls_certs| find_leaf_certificate(tls_certs)) + .and_then(|leaf_cert| cn_from_certificate(&leaf_cert)) + .unwrap_or(args.http_upgrade_path_prefix) + } else { + args.http_upgrade_path_prefix + }; + println!("{}", http_upgrade_path_prefix); + let transport_scheme = TransportScheme::from_str(args.remote_addr.scheme()).expect("invalid scheme in server url"); let tls = match transport_scheme { @@ -808,7 +847,7 @@ async fn main() { ) .unwrap(), socket_so_mark: args.socket_so_mark, - http_upgrade_path_prefix: args.http_upgrade_path_prefix, + http_upgrade_path_prefix, http_upgrade_credentials: args.http_upgrade_credentials, http_headers: args.http_headers.into_iter().filter(|(k, _)| k != HOST).collect(), http_headers_file: args.http_headers_file,