Allows to load EC and RSA encoded tls private key

This commit is contained in:
Σrebe - Romain GERARD 2023-12-04 18:21:55 +01:00
parent fffec24c99
commit b64b0bb70b
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
9 changed files with 214 additions and 124 deletions

View file

@ -1,15 +1,26 @@
use log::info;
use once_cell::sync::Lazy;
use tokio_rustls::rustls::{Certificate, PrivateKey};
pub static TLS_PRIVATE_KEY: Lazy<PrivateKey> = Lazy::new(|| {
info!("Loading embedded tls private key");
let key = include_bytes!("../certs/key.pem");
let mut keys =
rustls_pemfile::pkcs8_private_keys(&mut key.as_slice()).expect("failed to load embedded tls private key");
PrivateKey(keys.remove(0))
let key = rustls_pemfile::private_key(&mut key.as_slice())
.expect("failed to load embedded tls private key")
.expect("failed to load embedded tls private key");
PrivateKey(key.secret_der().to_vec())
});
pub static TLS_CERTIFICATE: Lazy<Vec<Certificate>> = Lazy::new(|| {
let cert = include_bytes!("../certs/cert.pem");
let certs = rustls_pemfile::certs(&mut cert.as_slice()).expect("failed to load embedded tls certificate");
info!("Loading embedded tls certificate");
certs.into_iter().map(Certificate).collect()
let cert = include_bytes!("../certs/cert.pem");
let certs = rustls_pemfile::certs(&mut cert.as_slice())
.next()
.expect("failed to load embedded tls certificate");
certs
.into_iter()
.map(|cert| Certificate(cert.as_ref().to_vec()))
.collect()
});

View file

@ -77,7 +77,7 @@ struct Client {
/// (linux only) Mark network packet with SO_MARK sockoption with the specified value.
/// You need to use {root, sudo, capabilities} to run wstunnel when using this option
#[arg(long, value_name = "INT", verbatim_doc_comment)]
socket_so_mark: Option<i32>,
socket_so_mark: Option<u32>,
/// Client will maintain a pool of open connection to the server, in order to speed up the connection process.
/// This option set the maximum number of connection that will be kept open.
@ -141,7 +141,7 @@ struct Server {
/// (linux only) Mark network packet with SO_MARK sockoption with the specified value.
/// You need to use {root, sudo, capabilities} to run wstunnel when using this option
#[arg(long, value_name = "INT", verbatim_doc_comment)]
socket_so_mark: Option<i32>,
socket_so_mark: Option<u32>,
/// Frequency at which the server will send websocket ping to client.
#[arg(long, value_name = "seconds", value_parser = parse_duration_sec, verbatim_doc_comment)]
@ -438,7 +438,7 @@ pub struct TlsServerConfig {
#[derive(Clone)]
pub struct WsServerConfig {
pub socket_so_mark: Option<i32>,
pub socket_so_mark: Option<u32>,
pub bind: SocketAddr,
pub restrict_to: Option<Vec<String>>,
pub restrict_http_upgrade_path_prefix: Option<Vec<String>>,
@ -465,7 +465,7 @@ impl Debug for WsServerConfig {
#[derive(Clone, Debug)]
pub struct WsClientConfig {
pub remote_addr: (Host<String>, u16),
pub socket_so_mark: Option<i32>,
pub socket_so_mark: Option<u32>,
pub tls: Option<TlsClientConfig>,
pub http_upgrade_path_prefix: String,
pub http_upgrade_credentials: Option<HeaderValue>,

View file

@ -3,8 +3,8 @@ use tokio_fd::AsyncFd;
pub async fn run_server() -> Result<(AsyncFd, AsyncFd), anyhow::Error> {
eprintln!("Starting STDIO server");
let stdin = AsyncFd::try_from(libc::STDIN_FILENO)?;
let stdout = AsyncFd::try_from(libc::STDOUT_FILENO)?;
let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?;
let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?;
Ok((stdin, stdout))
}

View file

@ -14,27 +14,22 @@ use tracing::debug;
use tracing::log::info;
use url::{Host, Url};
fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<i32>) -> Result<(), anyhow::Error> {
fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<u32>) -> Result<(), anyhow::Error> {
socket
.set_nodelay(true)
.with_context(|| format!("cannot set no_delay on socket: {}", io::Error::last_os_error()))?;
#[cfg(target_os = "linux")]
if let Some(so_mark) = so_mark {
use std::os::fd::AsRawFd;
unsafe {
let optval: libc::c_int = *so_mark;
let ret = libc::setsockopt(
socket.as_raw_fd(),
libc::SOL_SOCKET,
libc::SO_MARK,
&optval as *const _ as *const libc::c_void,
std::mem::size_of_val(&optval) as libc::socklen_t,
);
use std::os::fd::AsFd;
if ret != 0 {
return Err(anyhow!("Cannot set SO_MARK on the connection {:?}", io::Error::last_os_error()));
}
let ret = nix::sys::socket::setsockopt(&socket.as_fd(), nix::sys::socket::sockopt::Mark, so_mark);
if let Err(err) = ret {
return Err(anyhow!(
"Cannot set SO_MARK on the connection {:?} {:?}",
err,
io::Error::last_os_error()
));
}
}
@ -44,7 +39,7 @@ fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<i32>) -> Result<(),
pub async fn connect(
host: &Host<String>,
port: u16,
so_mark: Option<i32>,
so_mark: Option<u32>,
connect_timeout: Duration,
) -> Result<TcpStream, anyhow::Error> {
info!("Opening TCP connection to {}:{}", host, port);
@ -103,7 +98,7 @@ pub async fn connect_with_http_proxy(
proxy: &Url,
host: &Host<String>,
port: u16,
so_mark: Option<i32>,
so_mark: Option<u32>,
connect_timeout: Duration,
) -> Result<TcpStream, anyhow::Error> {
let proxy_host = proxy.host().context("Cannot parse proxy host")?.to_owned();

View file

@ -2,6 +2,7 @@ use crate::{TlsClientConfig, TlsServerConfig, WsClientConfig};
use anyhow::{anyhow, Context};
use std::fs::File;
use log::warn;
use std::io::BufReader;
use std::path::Path;
use std::sync::Arc;
@ -30,23 +31,35 @@ impl ServerCertVerifier for NullVerifier {
}
pub fn load_certificates_from_pem(path: &Path) -> anyhow::Result<Vec<Certificate>> {
info!("Loading tls certificate from {:?}", path);
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let certs = rustls_pemfile::certs(&mut reader)?;
let certs = rustls_pemfile::certs(&mut reader);
Ok(certs.into_iter().map(Certificate).collect())
Ok(certs
.into_iter()
.filter_map(|cert| match cert {
Ok(cert) => Some(Certificate(cert.to_vec())),
Err(err) => {
warn!("Error while parsing tls certificate: {:?}", err);
None
}
})
.collect())
}
pub fn load_private_key_from_file(path: &Path) -> anyhow::Result<PrivateKey> {
info!("Loading tls private key from {:?}", path);
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut reader)?;
match keys.len() {
0 => Err(anyhow!("No PKCS8-encoded private key found in {path:?}")),
1 => Ok(PrivateKey(keys.remove(0))),
_ => Err(anyhow!("More than one PKCS8-encoded private key found in {path:?}")),
}
let Some(private_key) = rustls_pemfile::private_key(&mut reader)? else {
return Err(anyhow!("No private key found in {path:?}"));
};
Ok(PrivateKey(private_key.secret_der().to_vec()))
}
fn tls_connector(tls_cfg: &TlsClientConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsConnector> {
@ -55,7 +68,7 @@ fn tls_connector(tls_cfg: &TlsClientConfig, alpn_protocols: Option<Vec<Vec<u8>>>
// Load system certificates and add them to the root store
let certs = rustls_native_certs::load_native_certs().with_context(|| "Cannot load system certificates")?;
for cert in certs {
root_store.add(&Certificate(cert.0))?;
root_store.add(&Certificate(cert.as_ref().to_vec()))?;
}
let mut config = ClientConfig::builder()

View file

@ -181,23 +181,19 @@ where
ws.set_auto_apply_mask(client_config.websocket_mask_frame);
// Connect to endpoint
let remote: (Host, u16) = response
let remote = response
.headers()
.get(COOKIE)
.and_then(|h| {
h.to_str()
.ok()
.and_then(|s| base64::engine::general_purpose::STANDARD.decode(s).ok())
.and_then(|s| Url::parse(&String::from_utf8_lossy(&s)).ok())
.and_then(|url| match (url.host(), url.port()) {
(Some(h), Some(p)) => Some((h.to_owned(), p)),
_ => None,
})
.and_then(|h| h.to_str().ok())
.and_then(|h| base64::engine::general_purpose::STANDARD.decode(h).ok())
.and_then(|h| Url::parse(&String::from_utf8_lossy(&h)).ok())
.and_then(|url| match (url.host(), url.port()) {
(Some(h), Some(p)) => Some((h.to_owned(), p)),
_ => None,
})
.unwrap_or(remote_ori.clone());
let stream = connect_to_dest(remote.clone()).instrument(span.clone()).await;
let stream = match stream {
let stream = match connect_to_dest(remote.clone()).instrument(span.clone()).await {
Ok(s) => s,
Err(err) => {
error!("Cannot connect to {remote:?}: {err:?}");

View file

@ -377,7 +377,7 @@ pub fn mk_send_socket_tproxy(listener: &Arc<UdpSocket>) -> anyhow::Result<Arc<Ud
use std::net::IpAddr;
use std::os::fd::AsRawFd;
let mut cmsg_space = cmsg_space!(libc::sockaddr_in6);
let mut cmsg_space = cmsg_space!(nix::libc::sockaddr_in6);
let mut buf = [0; 8];
let mut io = [IoSliceMut::new(&mut buf)];
let msg: RecvMsg<SockaddrIn> = nix::sys::socket::recvmsg(