diff --git a/src/main.rs b/src/main.rs index db6eabd..7d1f4f9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,11 +6,11 @@ mod socks5_udp; mod stdio; mod tcp; mod tls; +mod tls_utils; mod tunnel; mod udp; #[cfg(unix)] mod unix_socket; -mod tls_utils; use anyhow::anyhow; use base64::Engine; @@ -42,10 +42,10 @@ use tracing::{error, info}; use crate::dns::DnsResolver; use crate::restrictions::types::RestrictionsRules; +use crate::tls_utils::{cn_from_certificate, find_leaf_certificate}; use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme}; use crate::udp::MyUdpSocket; -use crate::tls_utils::{cn_from_certificate, find_leaf_certificate}; use tracing_subscriber::filter::Directive; use tracing_subscriber::EnvFilter; use url::{Host, Url}; @@ -756,7 +756,7 @@ async fn main() { // to be the common name (CN) of the client's certificate. tls_certificate .as_ref() - .and_then(find_leaf_certificate) + .and_then(|certs| find_leaf_certificate(certs.as_slice())) .and_then(|leaf_cert| cn_from_certificate(&leaf_cert)) .unwrap_or(args.http_upgrade_path_prefix) } else { diff --git a/src/tls_utils.rs b/src/tls_utils.rs index d9f301a..6ccb284 100644 --- a/src/tls_utils.rs +++ b/src/tls_utils.rs @@ -4,7 +4,7 @@ use x509_parser::prelude::X509Certificate; /// Find a leaf certificate in a vector of certificates. It is assumed only a single leaf certificate /// is present in the vector. The other certificates should be (intermediate) CA certificates. -pub fn find_leaf_certificate<'a>(tls_certificates: &'a Vec>) -> Option> { +pub fn find_leaf_certificate<'a>(tls_certificates: &'a [CertificateDer<'static>]) -> Option> { for tls_certificate in tls_certificates { if let Ok((_, tls_certificate_x509)) = parse_x509_certificate(tls_certificate) { if !tls_certificate_x509.is_ca() { @@ -22,6 +22,6 @@ pub fn cn_from_certificate(tls_certificate_x509: &X509Certificate) -> Option, restrictions: Restri info!("Starting wstunnel server listening on {}", server_config.bind); // setup upgrade request handler - let mk_websocket_upgrade_fn = - |server_config: Arc, restrictions: Arc, restrict_path: Option,client_addr: SocketAddr| { - move |req: Request| { - ws_server_upgrade(server_config.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) - .map::, _>(Ok) - } - }; + let mk_websocket_upgrade_fn = |server_config: Arc, + restrictions: Arc, + restrict_path: Option, + client_addr: SocketAddr| { + move |req: Request| { + ws_server_upgrade( + server_config.clone(), + restrictions.clone(), + restrict_path.clone(), + client_addr, + req, + ) + .map::, _>(Ok) + } + }; - let mk_http_upgrade_fn = - |server_config: Arc, restrictions: Arc, restrict_path: Option, client_addr: SocketAddr| { - move |req: Request| { - http_server_upgrade(server_config.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) - .map::, _>(Ok) - } - }; + let mk_http_upgrade_fn = |server_config: Arc, + restrictions: Arc, + restrict_path: Option, + client_addr: SocketAddr| { + move |req: Request| { + http_server_upgrade( + server_config.clone(), + restrictions.clone(), + restrict_path.clone(), + client_addr, + req, + ) + .map::, _>(Ok) + } + }; let mk_auto_upgrade_fn = |server_config: Arc, restrictions: Arc, @@ -726,9 +748,15 @@ pub async fn run_server(server_config: Arc, restrictions: Restri .map(|response| Ok::<_, anyhow::Error>(response.map(Either::Left))) .await } else if req.version() == Version::HTTP_2 { - http_server_upgrade(server_config.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) - .map::, _>(Ok) - .await + http_server_upgrade( + server_config.clone(), + restrictions.clone(), + restrict_path.clone(), + client_addr, + req, + ) + .map::, _>(Ok) + .await } else { error!("Invalid protocol version request, got {:?} while expecting either websocket http1 upgrade or http2", req.version()); Ok(http::Response::builder() @@ -811,15 +839,13 @@ pub async fn run_server(server_config: Arc, restrictions: Restri } }; - let restrict_path = if let Some(client_cert_chain) = - tls_stream.inner().get_ref().1.peer_certificates() { - find_leaf_certificate(&client_cert_chain.to_vec()) - .and_then(|leaf_cert| cn_from_certificate(&leaf_cert)) - } else { - None - }; - - match tls_stream.inner().get_ref().1.alpn_protocol() { + let tls_ctx = tls_stream.inner().get_ref().1; + // extract client certificate common name if any + let restrict_path = tls_ctx + .peer_certificates() + .and_then(find_leaf_certificate) + .and_then(|c| cn_from_certificate(&c)); + match tls_ctx.alpn_protocol() { // http2 Some(b"h2") => { let mut conn_builder = http2::Builder::new(TokioExecutor::new()); @@ -827,7 +853,8 @@ pub async fn run_server(server_config: Arc, restrictions: Restri conn_builder.keep_alive_interval(ping); } - let http_upgrade_fn = mk_http_upgrade_fn(server_config, restrictions.clone(), restrict_path.clone(), peer_addr); + let http_upgrade_fn = + mk_http_upgrade_fn(server_config, restrictions.clone(), restrict_path, peer_addr); let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn)); if let Err(e) = con_fut.await { error!("Error while upgrading cnx to http: {:?}", e); @@ -836,7 +863,7 @@ pub async fn run_server(server_config: Arc, restrictions: Restri // websocket _ => { let websocket_upgrade_fn = - mk_websocket_upgrade_fn(server_config, restrictions.clone(), restrict_path.clone(), peer_addr); + mk_websocket_upgrade_fn(server_config, restrictions.clone(), restrict_path, peer_addr); let conn_fut = http1::Builder::new() .serve_connection(tls_stream, service_fn(websocket_upgrade_fn)) .with_upgrades();