diff --git a/src/main.rs b/src/main.rs index d227dc8..db6eabd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,7 @@ mod tunnel; mod udp; #[cfg(unix)] mod unix_socket; +mod tls_utils; use anyhow::anyhow; use base64::Engine; @@ -44,11 +45,10 @@ use crate::restrictions::types::RestrictionsRules; use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme}; use crate::udp::MyUdpSocket; +use crate::tls_utils::{cn_from_certificate, find_leaf_certificate}; use tracing_subscriber::filter::Directive; use tracing_subscriber::EnvFilter; use url::{Host, Url}; -use x509_parser::parse_x509_certificate; -use x509_parser::prelude::X509Certificate; const DEFAULT_CLIENT_UPGRADE_PATH_PREFIX: &str = "v1"; @@ -602,30 +602,6 @@ fn parse_server_url(arg: &str) -> Result { Ok(url) } -/// Find a leaf certificate in a vector of certificates. It is assumed only a single leaf certificate -/// is present in the vector. The other certificates should be (intermediate) CA certificates. -fn find_leaf_certificate<'a>(tls_certificates: &'a Vec>) -> Option> { - for tls_certificate in tls_certificates { - if let Ok((_, tls_certificate_x509)) = parse_x509_certificate(tls_certificate) { - if !tls_certificate_x509.is_ca() { - return Some(tls_certificate_x509); - } - } - } - None -} - -/// Returns the common name (CN) as specified in the supplied certificate. -fn cn_from_certificate(tls_certificate_x509: &X509Certificate) -> Option { - tls_certificate_x509 - .tbs_certificate - .subject - .iter_common_name() - .flat_map(|cn| cn.as_str().ok()) - .map(|cn| cn.to_string()) - .next() -} - #[derive(Clone)] pub struct TlsClientConfig { pub tls_sni_disabled: bool, diff --git a/src/tls_utils.rs b/src/tls_utils.rs new file mode 100644 index 0000000..d9f301a --- /dev/null +++ b/src/tls_utils.rs @@ -0,0 +1,27 @@ +use tokio_rustls::rustls::pki_types::CertificateDer; +use x509_parser::parse_x509_certificate; +use x509_parser::prelude::X509Certificate; + +/// Find a leaf certificate in a vector of certificates. It is assumed only a single leaf certificate +/// is present in the vector. The other certificates should be (intermediate) CA certificates. +pub fn find_leaf_certificate<'a>(tls_certificates: &'a Vec>) -> Option> { + for tls_certificate in tls_certificates { + if let Ok((_, tls_certificate_x509)) = parse_x509_certificate(tls_certificate) { + if !tls_certificate_x509.is_ca() { + return Some(tls_certificate_x509); + } + } + } + None +} + +/// Returns the common name (CN) as specified in the supplied certificate. +pub fn cn_from_certificate(tls_certificate_x509: &X509Certificate) -> Option { + tls_certificate_x509 + .tbs_certificate + .subject + .iter_common_name() + .flat_map(|cn| cn.as_str().ok()) + .map(|cn| cn.to_string()) + .next() +} \ No newline at end of file diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index a9d18bd..6e778fb 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -37,6 +37,7 @@ use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; use crate::udp::UdpStream; +use crate::tls_utils::{cn_from_certificate, find_leaf_certificate}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::select; @@ -422,6 +423,7 @@ fn validate_tunnel<'a>( async fn ws_server_upgrade( server_config: Arc, restrictions: Arc, + restrict_path_prefix: Option, mut client_addr: SocketAddr, mut req: Request, ) -> Response { @@ -448,6 +450,16 @@ async fn ws_server_upgrade( Err(err) => return err, }; + if let Some(restrict_path) = restrict_path_prefix { + if path_prefix != restrict_path { + warn!("Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)", path_prefix, restrict_path); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Requested upgrade path does not match upgrade path restriction (mTLS, etc.)".into()) + .unwrap(); + } + } + let jwt = match extract_tunnel_info(&req) { Ok(jwt) => jwt, Err(err) => return err, @@ -547,6 +559,7 @@ async fn ws_server_upgrade( async fn http_server_upgrade( server_config: Arc, restrictions: Arc, + restrict_path_prefix: Option, mut client_addr: SocketAddr, mut req: Request, ) -> Response>> { @@ -565,6 +578,16 @@ async fn http_server_upgrade( Err(err) => return err.map(Either::Left), }; + if let Some(restrict_path) = restrict_path_prefix { + if path_prefix != restrict_path { + warn!("Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)", path_prefix, restrict_path); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Either::Left("Requested upgrade path does not match upgrade path restriction (mTLS, etc.)".to_string())) + .unwrap(); + } + } + let jwt = match extract_tunnel_info(&req) { Ok(jwt) => jwt, Err(err) => return err.map(Either::Left), @@ -674,34 +697,36 @@ pub async fn run_server(server_config: Arc, restrictions: Restri // setup upgrade request handler let mk_websocket_upgrade_fn = - |server_config: Arc, restrictions: Arc, client_addr: SocketAddr| { + |server_config: Arc, restrictions: Arc, restrict_path: Option,client_addr: SocketAddr| { move |req: Request| { - ws_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req) + ws_server_upgrade(server_config.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) .map::, _>(Ok) } }; let mk_http_upgrade_fn = - |server_config: Arc, restrictions: Arc, client_addr: SocketAddr| { + |server_config: Arc, restrictions: Arc, restrict_path: Option, client_addr: SocketAddr| { move |req: Request| { - http_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req) + http_server_upgrade(server_config.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) .map::, _>(Ok) } }; let mk_auto_upgrade_fn = |server_config: Arc, restrictions: Arc, + restrict_path: Option, client_addr: SocketAddr| { move |req: Request| { let server_config = server_config.clone(); let restrictions = restrictions.clone(); + let restrict_path = restrict_path.clone(); async move { if fastwebsockets::upgrade::is_upgrade_request(&req) { - ws_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req) + ws_server_upgrade(server_config.clone(), restrictions.clone(), restrict_path, client_addr, req) .map(|response| Ok::<_, anyhow::Error>(response.map(Either::Left))) .await } else if req.version() == Version::HTTP_2 { - http_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req) + http_server_upgrade(server_config.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) .map::, _>(Ok) .await } else { @@ -786,6 +811,14 @@ pub async fn run_server(server_config: Arc, restrictions: Restri } }; + let restrict_path = if let Some(client_cert_chain) = + tls_stream.inner().get_ref().1.peer_certificates() { + find_leaf_certificate(&client_cert_chain.to_vec()) + .and_then(|leaf_cert| cn_from_certificate(&leaf_cert)) + } else { + None + }; + match tls_stream.inner().get_ref().1.alpn_protocol() { // http2 Some(b"h2") => { @@ -794,7 +827,7 @@ pub async fn run_server(server_config: Arc, restrictions: Restri conn_builder.keep_alive_interval(ping); } - let http_upgrade_fn = mk_http_upgrade_fn(server_config, restrictions.clone(), peer_addr); + let http_upgrade_fn = mk_http_upgrade_fn(server_config, restrictions.clone(), restrict_path.clone(), peer_addr); let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn)); if let Err(e) = con_fut.await { error!("Error while upgrading cnx to http: {:?}", e); @@ -803,7 +836,7 @@ pub async fn run_server(server_config: Arc, restrictions: Restri // websocket _ => { let websocket_upgrade_fn = - mk_websocket_upgrade_fn(server_config, restrictions.clone(), peer_addr); + mk_websocket_upgrade_fn(server_config, restrictions.clone(), restrict_path.clone(), peer_addr); let conn_fut = http1::Builder::new() .serve_connection(tls_stream, service_fn(websocket_upgrade_fn)) .with_upgrades(); @@ -828,7 +861,7 @@ pub async fn run_server(server_config: Arc, restrictions: Restri conn_fut.http2().keep_alive_interval(ping); } - let websocket_upgrade_fn = mk_auto_upgrade_fn(server_config, restrictions.clone(), peer_addr); + let websocket_upgrade_fn = mk_auto_upgrade_fn(server_config, restrictions.clone(), None, peer_addr); let upgradable = conn_fut.serve_connection_with_upgrades(stream, service_fn(websocket_upgrade_fn)); if let Err(e) = upgradable.await {