wstunnel/src/tls.rs
Σrebe - Romain GERARD b30bd381e1 Bump
Former-commit-id: c4282dccbec4fa9d64fb60334fe83caec963140f [formerly 12eeb52b4a8760d1ec7c13d6cc77c9213a6d3392] [formerly 90e4dda3b1a8e224de2820c387e1e4a07a4db372 [formerly 978616526843c8918e23384b2404ccbf241c4dbf]]
Former-commit-id: 5035c63e099ff2d0729a69c059c4d1ac1a288c8e [formerly 7225907b8ab627bd90b8542d3ba2884764f6a209]
Former-commit-id: 4a1fb1590711e763896b5d525091d35a85a1c70a
Former-commit-id: b343703f7ccfcbb40f1642cd150a9b98d1fcb05e
Former-commit-id: 1e7c5340f1f6bb43041a95f5e405c1e2ec0d7b29
Former-commit-id: dc0d08065cbef90e0b8e890af551ffe6f47a9b17 [formerly e90c7a4f23afd6cbea95ebc55dace2960f6aa003]
Former-commit-id: 4991d306af1ea50d3b0eeb46dc5f47a6a5b2f4b0
2023-10-18 09:50:47 +02:00

125 lines
3.9 KiB
Rust

use crate::{TlsClientConfig, TlsServerConfig, WsClientConfig};
use anyhow::{anyhow, Context};
use std::fs::File;
use std::io::BufReader;
use std::path::Path;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use tokio_rustls::rustls::client::{ServerCertVerified, ServerCertVerifier};
use tokio_rustls::rustls::{Certificate, ClientConfig, PrivateKey, ServerName};
use tokio_rustls::{rustls, TlsAcceptor, TlsConnector};
use tracing::info;
/// A [`ServerCertVerifier`] that trusts every server certificate unconditionally.
///
/// `tls_connector` installs it when `tls_verify_certificate` is disabled. Every input
/// (leaf certificate, intermediates, server name, SCTs, OCSP response, current time) is
/// ignored, so no chain, hostname or validity-period check is performed. This offers no
/// protection against man-in-the-middle attacks and should only be used when the user
/// explicitly opted out of verification.
pub struct NullVerifier;

impl ServerCertVerifier for NullVerifier {
    fn verify_server_cert(
        &self,
        _leaf: &Certificate,
        _chain: &[Certificate],
        _name: &ServerName,
        _sct_list: &mut dyn Iterator<Item = &[u8]>,
        _ocsp: &[u8],
        _at: SystemTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        // Deliberately ignore all inputs: the certificate is always considered valid.
        let trusted = ServerCertVerified::assertion();
        Ok(trusted)
    }
}
pub fn load_certificates_from_pem(path: &Path) -> anyhow::Result<Vec<Certificate>> {
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let certs = rustls_pemfile::certs(&mut reader)?;
Ok(certs.into_iter().map(Certificate).collect())
}
pub fn load_private_key_from_file(path: &Path) -> anyhow::Result<PrivateKey> {
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut reader)?;
match keys.len() {
0 => Err(anyhow!("No PKCS8-encoded private key found in {path:?}")),
1 => Ok(PrivateKey(keys.remove(0))),
_ => Err(anyhow!(
"More than one PKCS8-encoded private key found in {path:?}"
)),
}
}
pub fn tls_connector(
tls_cfg: &TlsClientConfig,
alpn_protocols: Option<Vec<Vec<u8>>>,
) -> anyhow::Result<TlsConnector> {
let mut root_store = rustls::RootCertStore::empty();
// Load system certificates and add them to the root store
let certs = rustls_native_certs::load_native_certs()
.with_context(|| "Cannot load system certificates")?;
for cert in certs {
root_store.add(&Certificate(cert.0))?;
}
let mut config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
// To bypass certificate verification
if !tls_cfg.tls_verify_certificate {
config
.dangerous()
.set_certificate_verifier(Arc::new(NullVerifier));
}
if let Some(alpn_protocols) = alpn_protocols {
config.alpn_protocols = alpn_protocols;
}
let tls_connector = TlsConnector::from(Arc::new(config));
Ok(tls_connector)
}
/// Build a [`TlsAcceptor`] for incoming server connections.
///
/// The server presents the certificate chain and private key stored in `tls_cfg` and
/// does not request client certificates. When `alpn_protocols` is `Some`, it becomes the
/// server configuration's ALPN protocol list.
///
/// # Errors
/// Returns an error (context: "invalid tls certificate or private key") when building the
/// rustls server configuration from the certificate/key pair fails.
pub fn tls_acceptor(
    tls_cfg: &TlsServerConfig,
    alpn_protocols: Option<Vec<Vec<u8>>>,
) -> anyhow::Result<TlsAcceptor> {
    // rustls takes ownership of the chain and key, so hand it copies of the config values.
    let cert_chain = tls_cfg.tls_certificate.clone();
    let private_key = tls_cfg.tls_key.clone();

    let mut server_config = rustls::ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(cert_chain, private_key)
        .with_context(|| "invalid tls certificate or private key")?;

    match alpn_protocols {
        Some(protocols) => server_config.alpn_protocols = protocols,
        None => {}
    }

    let shared_config = Arc::new(server_config);
    Ok(TlsAcceptor::from(shared_config))
}
/// Run the client side of a TLS handshake over the already-connected `tcp_stream`.
///
/// The SNI value comes from `server_cfg.tls_server_name()`. The connector is built from
/// `tls_cfg` and offers `http/1.1` as its only ALPN protocol.
///
/// # Errors
/// Fails if the connector cannot be built (see `tls_connector`), or if the handshake
/// with `server_cfg.remote_addr` fails.
pub async fn connect(
    server_cfg: &WsClientConfig,
    tls_cfg: &TlsClientConfig,
    tcp_stream: TcpStream,
) -> anyhow::Result<TlsStream<TcpStream>> {
    let sni = server_cfg.tls_server_name();
    let (host, port) = (&server_cfg.remote_addr.0, &server_cfg.remote_addr.1);
    info!(
        "Doing TLS handshake using sni {sni:?} with the server {}:{}",
        host, port
    );

    let http1_alpn = vec![b"http/1.1".to_vec()];
    let connector = tls_connector(tls_cfg, Some(http1_alpn))?;

    connector
        .connect(sni, tcp_stream)
        .await
        .with_context(|| {
            format!(
                "failed to do TLS handshake with the server {:?}",
                server_cfg.remote_addr
            )
        })
}