Prep work for new transport

This commit is contained in:
Σrebe - Romain GERARD 2024-01-13 18:42:15 +01:00
parent 62f6a0287d
commit 6375e14185
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
4 changed files with 215 additions and 40 deletions

View file

@ -18,6 +18,7 @@ use hickory_resolver::config::{NameServerConfig, ResolverConfig, ResolverOpts};
use hyper::header::HOST;
use hyper::http::{HeaderName, HeaderValue};
use log::{debug, warn};
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
@ -34,11 +35,12 @@ use tokio::net::TcpStream;
use tokio_rustls::rustls::server::DnsName;
use tokio_rustls::rustls::{Certificate, PrivateKey, ServerName};
use tokio_rustls::TlsConnector;
use tracing::{error, info};
use crate::dns::DnsResolver;
use crate::tunnel::{to_host_port, RemoteAddr};
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr};
use crate::udp::MyUdpSocket;
use tracing_subscriber::filter::Directive;
use tracing_subscriber::EnvFilter;
@ -197,7 +199,7 @@ struct Client {
/// Address of the wstunnel server
/// Example: With TLS wss://wstunnel.example.com or without ws://wstunnel.example.com
#[arg(value_name = "ws[s]://wstunnel.server.com[:port]", value_parser = parse_server_url, verbatim_doc_comment)]
#[arg(value_name = "ws[s]|http[s]://wstunnel.server.com[:port]", value_parser = parse_server_url, verbatim_doc_comment)]
remote_addr: Url,
}
@ -537,10 +539,11 @@ fn parse_server_url(arg: &str) -> Result<Url, io::Error> {
Ok(url)
}
#[derive(Clone, Debug)]
#[derive(Clone)]
pub struct TlsClientConfig {
pub tls_sni_override: Option<DnsName>,
pub tls_verify_certificate: bool,
pub tls_connector: TlsConnector,
}
#[derive(Debug)]
@ -580,9 +583,8 @@ impl Debug for WsServerConfig {
#[derive(Clone)]
pub struct WsClientConfig {
pub remote_addr: (Host<String>, u16),
pub remote_addr: TransportAddr,
pub socket_so_mark: Option<u32>,
pub tls: Option<TlsClientConfig>,
pub http_upgrade_path_prefix: String,
pub http_upgrade_credentials: Option<HeaderValue>,
pub http_headers: HashMap<HeaderName, HeaderValue>,
@ -597,9 +599,9 @@ pub struct WsClientConfig {
impl WsClientConfig {
pub fn websocket_scheme(&self) -> &'static str {
match self.tls {
None => "ws",
Some(_) => "wss",
match self.remote_addr.tls().is_some() {
false => "ws",
true => "wss",
}
}
@ -608,13 +610,18 @@ impl WsClientConfig {
}
pub fn websocket_host_url(&self) -> String {
format!("{}:{}", self.remote_addr.0, self.remote_addr.1)
format!("{}:{}", self.remote_addr.host(), self.remote_addr.port())
}
pub fn tls_server_name(&self) -> ServerName {
match self.tls.as_ref().and_then(|tls| tls.tls_sni_override.as_ref()) {
None => match &self.remote_addr.0 {
Host::Domain(domain) => ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap()),
static INVALID_DNS_NAME: Lazy<DnsName> =
Lazy::new(|| DnsName::try_from_ascii(b"dns-name-invalid.com").unwrap());
match self.remote_addr.tls().and_then(|tls| tls.tls_sni_override.as_ref()) {
None => match &self.remote_addr.host() {
Host::Domain(domain) => {
ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap_or_else(|_| INVALID_DNS_NAME.clone()))
}
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip)),
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip)),
},
@ -654,6 +661,8 @@ async fn main() {
let tls = match args.remote_addr.scheme() {
"ws" => None,
"wss" => Some(TlsClientConfig {
tls_connector: tls::tls_connector(args.tls_verify_certificate, Some(vec![b"http/1.1".to_vec()]))
.expect("Cannot create tls connector"),
tls_sni_override: args.tls_sni_override,
tls_verify_certificate: args.tls_verify_certificate,
}),
@ -671,12 +680,14 @@ async fn main() {
HeaderValue::from_str(&host).unwrap()
};
let mut client_config = WsClientConfig {
remote_addr: (
remote_addr: TransportAddr::from_str(
args.remote_addr.scheme(),
args.remote_addr.host().unwrap().to_owned(),
args.remote_addr.port_or_known_default().unwrap(),
),
tls,
)
.unwrap(),
socket_so_mark: args.socket_so_mark,
tls,
http_upgrade_path_prefix: args.http_upgrade_path_prefix,
http_upgrade_credentials: args.http_upgrade_credentials,
http_headers: args.http_headers.into_iter().filter(|(k, _)| k != HOST).collect(),

View file

@ -1,4 +1,4 @@
use crate::{TlsClientConfig, TlsServerConfig, WsClientConfig};
use crate::{TlsServerConfig, WsClientConfig};
use anyhow::{anyhow, Context};
use std::fs::File;
@ -11,6 +11,7 @@ use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use tokio_rustls::rustls::client::{ServerCertVerified, ServerCertVerifier};
use crate::tunnel::TransportAddr;
use tokio_rustls::rustls::{Certificate, ClientConfig, PrivateKey, ServerName};
use tokio_rustls::{rustls, TlsAcceptor, TlsConnector};
use tracing::info;
@ -62,7 +63,10 @@ pub fn load_private_key_from_file(path: &Path) -> anyhow::Result<PrivateKey> {
Ok(PrivateKey(private_key.secret_der().to_vec()))
}
fn tls_connector(tls_cfg: &TlsClientConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsConnector> {
pub fn tls_connector(
tls_verify_certificate: bool,
alpn_protocols: Option<Vec<Vec<u8>>>,
) -> anyhow::Result<TlsConnector> {
let mut root_store = rustls::RootCertStore::empty();
// Load system certificates and add them to the root store
@ -77,7 +81,7 @@ fn tls_connector(tls_cfg: &TlsClientConfig, alpn_protocols: Option<Vec<Vec<u8>>>
.with_no_client_auth();
// To bypass certificate verification
if !tls_cfg.tls_verify_certificate {
if !tls_verify_certificate {
config.dangerous().set_certificate_verifier(Arc::new(NullVerifier));
}
@ -101,22 +105,31 @@ pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8
Ok(TlsAcceptor::from(Arc::new(config)))
}
pub async fn connect(
client_cfg: &WsClientConfig,
tls_cfg: &TlsClientConfig,
tcp_stream: TcpStream,
) -> anyhow::Result<TlsStream<TcpStream>> {
pub async fn connect(client_cfg: &WsClientConfig, tcp_stream: TcpStream) -> anyhow::Result<TlsStream<TcpStream>> {
let sni = client_cfg.tls_server_name();
info!(
"Doing TLS handshake using sni {sni:?} with the server {}:{}",
client_cfg.remote_addr.0, client_cfg.remote_addr.1
client_cfg.remote_addr.host(),
client_cfg.remote_addr.port()
);
let tls_connector = tls_connector(tls_cfg, Some(vec![b"http/1.1".to_vec()]))?;
let tls_stream = tls_connector
.connect(sni, tcp_stream)
.await
.with_context(|| format!("failed to do TLS handshake with the server {:?}", client_cfg.remote_addr))?;
let tls_connector = match &client_cfg.remote_addr {
TransportAddr::WSS { tls, .. } => &tls.tls_connector,
TransportAddr::HTTPS { tls, .. } => &tls.tls_connector,
TransportAddr::HTTP { .. } | TransportAddr::WS { .. } => {
return Err(anyhow!(
"Transport does not support TLS: {}",
client_cfg.remote_addr.scheme_name()
))
}
};
let tls_stream = tls_connector.connect(sni, tcp_stream).await.with_context(|| {
format!(
"failed to do TLS handshake with the server {}:{}",
client_cfg.remote_addr.host(),
client_cfg.remote_addr.port()
)
})?;
Ok(tls_stream)
}

View file

@ -24,7 +24,7 @@ use tracing::{error, span, Instrument, Level, Span};
use url::Host;
use uuid::Uuid;
pub async fn connect(
async fn connect(
request_id: Uuid,
client_cfg: &WsClientConfig,
dest_addr: &RemoteAddr,
@ -70,6 +70,52 @@ pub async fn connect(
Ok((ws, response))
}
//async fn connect_http2(
// request_id: Uuid,
// client_cfg: &WsClientConfig,
// dest_addr: &RemoteAddr,
//) -> anyhow::Result<BodyStream<Incoming>> {
// let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
// Ok(cnx) => Ok(cnx),
// Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
// }?;
//
// let mut req = Request::builder()
// .method("GET")
// .uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix))
// .header(HOST, &client_cfg.http_header_host)
// .header(COOKIE, tunnel_to_jwt_token(request_id, dest_addr))
// .version(hyper::Version::HTTP_2);
//
// for (k, v) in &client_cfg.http_headers {
// req = req.header(k, v);
// }
// if let Some(auth) = &client_cfg.http_upgrade_credentials {
// req = req.header(AUTHORIZATION, auth);
// }
//
// let x: Vec<u8> = vec![];
// //let bosy = StreamBody::new(stream::iter(vec![anyhow::Result::Ok(hyper::body::Frame::data(x.as_slice()))]));
// let req = req.body(Empty::<Bytes>::new()).with_context(|| {
// format!(
// "failed to build HTTP request to contact the server {:?}",
// client_cfg.remote_addr
// )
// })?;
// debug!("with HTTP upgrade request {:?}", req);
// let transport = pooled_cnx.deref_mut().take().unwrap();
// let (mut request_sender, cnx) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()).handshake(TokioIo::new(transport)).await
// .with_context(|| format!("failed to do http2 handshake with the server {:?}", client_cfg.remote_addr))?;
// tokio::spawn(cnx);
//
// let response = request_sender.send_request(req)
// .await
// .with_context(|| format!("failed to send http2 request with the server {:?}", client_cfg.remote_addr))?;
//
// // TODO: verify response is ok
// Ok(BodyStream::new(response.into_body()))
//}
async fn connect_to_server<R, W>(
request_id: Uuid,
client_cfg: &WsClientConfig,

View file

@ -3,13 +3,14 @@ mod io;
pub mod server;
mod tls_reloader;
use crate::{tcp, tls, LocalProtocol, WsClientConfig};
use crate::{tcp, tls, LocalProtocol, TlsClientConfig, WsClientConfig};
use async_trait::async_trait;
use bb8::ManageConnection;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fmt::{Debug, Formatter};
use std::io::{Error, IoSlice};
use std::net::{IpAddr, SocketAddr};
use std::ops::Deref;
@ -77,6 +78,97 @@ pub struct RemoteAddr {
pub port: u16,
}
#[derive(Clone)]
pub enum TransportAddr {
WSS {
tls: TlsClientConfig,
host: Host,
port: u16,
},
WS {
host: Host,
port: u16,
},
HTTPS {
tls: TlsClientConfig,
host: Host,
port: u16,
},
HTTP {
host: Host,
port: u16,
},
}
impl Debug for TransportAddr {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{}://{}:{}", self.scheme_name(), self.host(), self.port()))
}
}
impl TransportAddr {
pub fn from_str(scheme: &str, host: Host, port: u16, tls: Option<TlsClientConfig>) -> Option<Self> {
match scheme {
"https" => {
let Some(tls) = tls else { return None };
Some(TransportAddr::HTTPS { tls, host, port })
}
"http" => Some(TransportAddr::HTTP { host, port }),
"wss" => {
let Some(tls) = tls else { return None };
Some(TransportAddr::WSS { tls, host, port })
}
"ws" => Some(TransportAddr::WS { host, port }),
_ => None,
}
}
pub fn is_websocket(&self) -> bool {
matches!(self, TransportAddr::WS { .. } | TransportAddr::WSS { .. })
}
pub fn is_http2(&self) -> bool {
matches!(self, TransportAddr::HTTP { .. } | TransportAddr::HTTPS { .. })
}
pub fn tls(&self) -> Option<&TlsClientConfig> {
match self {
TransportAddr::WSS { tls, .. } => Some(tls),
TransportAddr::HTTPS { tls, .. } => Some(tls),
TransportAddr::WS { .. } => None,
TransportAddr::HTTP { .. } => None,
}
}
pub fn host(&self) -> &Host {
match self {
TransportAddr::WSS { host, .. } => host,
TransportAddr::WS { host, .. } => host,
TransportAddr::HTTPS { host, .. } => host,
TransportAddr::HTTP { host, .. } => host,
}
}
pub fn port(&self) -> u16 {
match self {
TransportAddr::WSS { port, .. } => *port,
TransportAddr::WS { port, .. } => *port,
TransportAddr::HTTPS { port, .. } => *port,
TransportAddr::HTTP { port, .. } => *port,
}
}
pub fn scheme_name(&self) -> &str {
match self {
TransportAddr::WSS { .. } => "wss",
TransportAddr::WS { .. } => "ws",
TransportAddr::HTTPS { .. } => "https",
TransportAddr::HTTP { .. } => "http",
}
}
}
impl TryFrom<JwtTunnelConfig> for RemoteAddr {
type Error = anyhow::Error;
fn try_from(jwt: JwtTunnelConfig) -> anyhow::Result<Self> {
@ -150,22 +242,35 @@ impl ManageConnection for WsClientConfig {
#[instrument(level = "trace", name = "cnx_server", skip_all)]
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
let (host, port) = &self.remote_addr;
let so_mark = self.socket_so_mark;
let timeout = self.timeout_connect;
let tcp_stream = if let Some(http_proxy) = &self.http_proxy {
tcp::connect_with_http_proxy(http_proxy, host, *port, so_mark, timeout, &self.dns_resolver).await?
tcp::connect_with_http_proxy(
http_proxy,
self.remote_addr.host(),
self.remote_addr.port(),
so_mark,
timeout,
&self.dns_resolver,
)
.await?
} else {
tcp::connect(host, *port, so_mark, timeout, &self.dns_resolver).await?
tcp::connect(
self.remote_addr.host(),
self.remote_addr.port(),
so_mark,
timeout,
&self.dns_resolver,
)
.await?
};
match &self.tls {
None => Ok(Some(TransportStream::Plain(tcp_stream))),
Some(tls_cfg) => {
let tls_stream = tls::connect(self, tls_cfg, tcp_stream).await?;
Ok(Some(TransportStream::Tls(tls_stream)))
}
if self.remote_addr.tls().is_some() {
let tls_stream = tls::connect(self, tcp_stream).await?;
Ok(Some(TransportStream::Tls(tls_stream)))
} else {
Ok(Some(TransportStream::Plain(tcp_stream)))
}
}