Reduce allocation when using client certificate

This commit is contained in:
Σrebe - Romain GERARD 2024-05-16 09:05:04 +02:00
parent ddebdfd3d2
commit 246862a6da
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
3 changed files with 66 additions and 39 deletions

View file

@ -6,11 +6,11 @@ mod socks5_udp;
mod stdio; mod stdio;
mod tcp; mod tcp;
mod tls; mod tls;
mod tls_utils;
mod tunnel; mod tunnel;
mod udp; mod udp;
#[cfg(unix)] #[cfg(unix)]
mod unix_socket; mod unix_socket;
mod tls_utils;
use anyhow::anyhow; use anyhow::anyhow;
use base64::Engine; use base64::Engine;
@ -42,10 +42,10 @@ use tracing::{error, info};
use crate::dns::DnsResolver; use crate::dns::DnsResolver;
use crate::restrictions::types::RestrictionsRules; use crate::restrictions::types::RestrictionsRules;
use crate::tls_utils::{cn_from_certificate, find_leaf_certificate};
use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::tls_reloader::TlsReloader;
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme}; use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme};
use crate::udp::MyUdpSocket; use crate::udp::MyUdpSocket;
use crate::tls_utils::{cn_from_certificate, find_leaf_certificate};
use tracing_subscriber::filter::Directive; use tracing_subscriber::filter::Directive;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use url::{Host, Url}; use url::{Host, Url};
@ -756,7 +756,7 @@ async fn main() {
// to be the common name (CN) of the client's certificate. // to be the common name (CN) of the client's certificate.
tls_certificate tls_certificate
.as_ref() .as_ref()
.and_then(find_leaf_certificate) .and_then(|certs| find_leaf_certificate(certs.as_slice()))
.and_then(|leaf_cert| cn_from_certificate(&leaf_cert)) .and_then(|leaf_cert| cn_from_certificate(&leaf_cert))
.unwrap_or(args.http_upgrade_path_prefix) .unwrap_or(args.http_upgrade_path_prefix)
} else { } else {

View file

@ -4,7 +4,7 @@ use x509_parser::prelude::X509Certificate;
/// Find a leaf certificate in a vector of certificates. It is assumed only a single leaf certificate /// Find a leaf certificate in a vector of certificates. It is assumed only a single leaf certificate
/// is present in the vector. The other certificates should be (intermediate) CA certificates. /// is present in the vector. The other certificates should be (intermediate) CA certificates.
pub fn find_leaf_certificate<'a>(tls_certificates: &'a Vec<CertificateDer<'static>>) -> Option<X509Certificate<'a>> { pub fn find_leaf_certificate<'a>(tls_certificates: &'a [CertificateDer<'static>]) -> Option<X509Certificate<'a>> {
for tls_certificate in tls_certificates { for tls_certificate in tls_certificates {
if let Ok((_, tls_certificate_x509)) = parse_x509_certificate(tls_certificate) { if let Ok((_, tls_certificate_x509)) = parse_x509_certificate(tls_certificate) {
if !tls_certificate_x509.is_ca() { if !tls_certificate_x509.is_ca() {
@ -22,6 +22,6 @@ pub fn cn_from_certificate(tls_certificate_x509: &X509Certificate) -> Option<Str
.subject .subject
.iter_common_name() .iter_common_name()
.flat_map(|cn| cn.as_str().ok()) .flat_map(|cn| cn.as_str().ok())
.map(|cn| cn.to_string())
.next() .next()
} .map(|cn| cn.to_string())
}

View file

@ -33,11 +33,11 @@ use crate::restrictions::types::{
AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol, AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol,
}; };
use crate::socks5::Socks5Stream; use crate::socks5::Socks5Stream;
use crate::tls_utils::{cn_from_certificate, find_leaf_certificate};
use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::tls_reloader::TlsReloader;
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite};
use crate::udp::UdpStream; use crate::udp::UdpStream;
use crate::tls_utils::{cn_from_certificate, find_leaf_certificate};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::select; use tokio::select;
@ -452,10 +452,13 @@ async fn ws_server_upgrade(
if let Some(restrict_path) = restrict_path_prefix { if let Some(restrict_path) = restrict_path_prefix {
if path_prefix != restrict_path { if path_prefix != restrict_path {
warn!("Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)", path_prefix, restrict_path); warn!(
"Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)",
path_prefix, restrict_path
);
return http::Response::builder() return http::Response::builder()
.status(StatusCode::BAD_REQUEST) .status(StatusCode::BAD_REQUEST)
.body("Requested upgrade path does not match upgrade path restriction (mTLS, etc.)".into()) .body("Invalid upgrade request".to_string())
.unwrap(); .unwrap();
} }
} }
@ -580,10 +583,13 @@ async fn http_server_upgrade(
if let Some(restrict_path) = restrict_path_prefix { if let Some(restrict_path) = restrict_path_prefix {
if path_prefix != restrict_path { if path_prefix != restrict_path {
warn!("Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)", path_prefix, restrict_path); warn!(
"Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)",
path_prefix, restrict_path
);
return http::Response::builder() return http::Response::builder()
.status(StatusCode::BAD_REQUEST) .status(StatusCode::BAD_REQUEST)
.body(Either::Left("Requested upgrade path does not match upgrade path restriction (mTLS, etc.)".to_string())) .body(Either::Left("Invalid upgrade request".to_string()))
.unwrap(); .unwrap();
} }
} }
@ -696,21 +702,37 @@ pub async fn run_server(server_config: Arc<WsServerConfig>, restrictions: Restri
info!("Starting wstunnel server listening on {}", server_config.bind); info!("Starting wstunnel server listening on {}", server_config.bind);
// setup upgrade request handler // setup upgrade request handler
let mk_websocket_upgrade_fn = let mk_websocket_upgrade_fn = |server_config: Arc<WsServerConfig>,
|server_config: Arc<WsServerConfig>, restrictions: Arc<RestrictionsRules>, restrict_path: Option<String>,client_addr: SocketAddr| { restrictions: Arc<RestrictionsRules>,
move |req: Request<Incoming>| { restrict_path: Option<String>,
ws_server_upgrade(server_config.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) client_addr: SocketAddr| {
.map::<anyhow::Result<_>, _>(Ok) move |req: Request<Incoming>| {
} ws_server_upgrade(
}; server_config.clone(),
restrictions.clone(),
restrict_path.clone(),
client_addr,
req,
)
.map::<anyhow::Result<_>, _>(Ok)
}
};
let mk_http_upgrade_fn = let mk_http_upgrade_fn = |server_config: Arc<WsServerConfig>,
|server_config: Arc<WsServerConfig>, restrictions: Arc<RestrictionsRules>, restrict_path: Option<String>, client_addr: SocketAddr| { restrictions: Arc<RestrictionsRules>,
move |req: Request<Incoming>| { restrict_path: Option<String>,
http_server_upgrade(server_config.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) client_addr: SocketAddr| {
.map::<anyhow::Result<_>, _>(Ok) move |req: Request<Incoming>| {
} http_server_upgrade(
}; server_config.clone(),
restrictions.clone(),
restrict_path.clone(),
client_addr,
req,
)
.map::<anyhow::Result<_>, _>(Ok)
}
};
let mk_auto_upgrade_fn = |server_config: Arc<WsServerConfig>, let mk_auto_upgrade_fn = |server_config: Arc<WsServerConfig>,
restrictions: Arc<RestrictionsRules>, restrictions: Arc<RestrictionsRules>,
@ -726,9 +748,15 @@ pub async fn run_server(server_config: Arc<WsServerConfig>, restrictions: Restri
.map(|response| Ok::<_, anyhow::Error>(response.map(Either::Left))) .map(|response| Ok::<_, anyhow::Error>(response.map(Either::Left)))
.await .await
} else if req.version() == Version::HTTP_2 { } else if req.version() == Version::HTTP_2 {
http_server_upgrade(server_config.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) http_server_upgrade(
.map::<anyhow::Result<_>, _>(Ok) server_config.clone(),
.await restrictions.clone(),
restrict_path.clone(),
client_addr,
req,
)
.map::<anyhow::Result<_>, _>(Ok)
.await
} else { } else {
error!("Invalid protocol version request, got {:?} while expecting either websocket http1 upgrade or http2", req.version()); error!("Invalid protocol version request, got {:?} while expecting either websocket http1 upgrade or http2", req.version());
Ok(http::Response::builder() Ok(http::Response::builder()
@ -811,15 +839,13 @@ pub async fn run_server(server_config: Arc<WsServerConfig>, restrictions: Restri
} }
}; };
let restrict_path = if let Some(client_cert_chain) = let tls_ctx = tls_stream.inner().get_ref().1;
tls_stream.inner().get_ref().1.peer_certificates() { // extract client certificate common name if any
find_leaf_certificate(&client_cert_chain.to_vec()) let restrict_path = tls_ctx
.and_then(|leaf_cert| cn_from_certificate(&leaf_cert)) .peer_certificates()
} else { .and_then(find_leaf_certificate)
None .and_then(|c| cn_from_certificate(&c));
}; match tls_ctx.alpn_protocol() {
match tls_stream.inner().get_ref().1.alpn_protocol() {
// http2 // http2
Some(b"h2") => { Some(b"h2") => {
let mut conn_builder = http2::Builder::new(TokioExecutor::new()); let mut conn_builder = http2::Builder::new(TokioExecutor::new());
@ -827,7 +853,8 @@ pub async fn run_server(server_config: Arc<WsServerConfig>, restrictions: Restri
conn_builder.keep_alive_interval(ping); conn_builder.keep_alive_interval(ping);
} }
let http_upgrade_fn = mk_http_upgrade_fn(server_config, restrictions.clone(), restrict_path.clone(), peer_addr); let http_upgrade_fn =
mk_http_upgrade_fn(server_config, restrictions.clone(), restrict_path, peer_addr);
let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn)); let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn));
if let Err(e) = con_fut.await { if let Err(e) = con_fut.await {
error!("Error while upgrading cnx to http: {:?}", e); error!("Error while upgrading cnx to http: {:?}", e);
@ -836,7 +863,7 @@ pub async fn run_server(server_config: Arc<WsServerConfig>, restrictions: Restri
// websocket // websocket
_ => { _ => {
let websocket_upgrade_fn = let websocket_upgrade_fn =
mk_websocket_upgrade_fn(server_config, restrictions.clone(), restrict_path.clone(), peer_addr); mk_websocket_upgrade_fn(server_config, restrictions.clone(), restrict_path, peer_addr);
let conn_fut = http1::Builder::new() let conn_fut = http1::Builder::new()
.serve_connection(tls_stream, service_fn(websocket_upgrade_fn)) .serve_connection(tls_stream, service_fn(websocket_upgrade_fn))
.with_upgrades(); .with_upgrades();