From 6375e14185b975a79d1c6efe67cb5ecaa996ec0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Sat, 13 Jan 2024 18:42:15 +0100 Subject: [PATCH] Prep work for new transport --- src/main.rs | 41 ++++++++------ src/tls.rs | 41 +++++++++----- src/tunnel/client.rs | 48 ++++++++++++++++- src/tunnel/mod.rs | 125 +++++++++++++++++++++++++++++++++++++++---- 4 files changed, 215 insertions(+), 40 deletions(-) diff --git a/src/main.rs b/src/main.rs index 68b2661..b448c7c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,7 @@ use hickory_resolver::config::{NameServerConfig, ResolverConfig, ResolverOpts}; use hyper::header::HOST; use hyper::http::{HeaderName, HeaderValue}; use log::{debug, warn}; +use once_cell::sync::Lazy; use parking_lot::Mutex; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; @@ -34,11 +35,12 @@ use tokio::net::TcpStream; use tokio_rustls::rustls::server::DnsName; use tokio_rustls::rustls::{Certificate, PrivateKey, ServerName}; +use tokio_rustls::TlsConnector; use tracing::{error, info}; use crate::dns::DnsResolver; -use crate::tunnel::{to_host_port, RemoteAddr}; +use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr}; use crate::udp::MyUdpSocket; use tracing_subscriber::filter::Directive; use tracing_subscriber::EnvFilter; @@ -197,7 +199,7 @@ struct Client { /// Address of the wstunnel server /// Example: With TLS wss://wstunnel.example.com or without ws://wstunnel.example.com - #[arg(value_name = "ws[s]://wstunnel.server.com[:port]", value_parser = parse_server_url, verbatim_doc_comment)] + #[arg(value_name = "ws[s]|http[s]://wstunnel.server.com[:port]", value_parser = parse_server_url, verbatim_doc_comment)] remote_addr: Url, } @@ -537,10 +539,11 @@ fn parse_server_url(arg: &str) -> Result { Ok(url) } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct TlsClientConfig { pub tls_sni_override: Option, pub tls_verify_certificate: bool, + pub tls_connector: TlsConnector, } #[derive(Debug)] @@ -580,9 +583,8 @@ impl Debug for WsServerConfig { #[derive(Clone)] pub struct WsClientConfig { - pub remote_addr: (Host, u16), + pub remote_addr: TransportAddr, pub socket_so_mark: Option, - pub tls: Option, pub http_upgrade_path_prefix: String, pub http_upgrade_credentials: Option, pub http_headers: HashMap, @@ -597,9 +599,9 @@ pub struct WsClientConfig { impl WsClientConfig { pub fn websocket_scheme(&self) -> &'static str { - match self.tls { - None => "ws", - Some(_) => "wss", + match self.remote_addr.tls().is_some() { + false => "ws", + true => "wss", } } @@ -608,13 +610,18 @@ impl WsClientConfig { } pub fn websocket_host_url(&self) -> String { - format!("{}:{}", self.remote_addr.0, self.remote_addr.1) + format!("{}:{}", self.remote_addr.host(), self.remote_addr.port()) } pub fn tls_server_name(&self) -> ServerName { - match self.tls.as_ref().and_then(|tls| tls.tls_sni_override.as_ref()) { - None => match &self.remote_addr.0 { - Host::Domain(domain) => ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap()), + static INVALID_DNS_NAME: Lazy = + Lazy::new(|| DnsName::try_from_ascii(b"dns-name-invalid.com").unwrap()); + + match self.remote_addr.tls().and_then(|tls| tls.tls_sni_override.as_ref()) { + None => match &self.remote_addr.host() { + Host::Domain(domain) => { + ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap_or_else(|_| INVALID_DNS_NAME.clone())) + } Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip)), Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip)), }, @@ -654,6 +661,8 @@ async fn main() { let tls = match args.remote_addr.scheme() { "ws" => None, "wss" => Some(TlsClientConfig { + tls_connector: tls::tls_connector(args.tls_verify_certificate, Some(vec![b"http/1.1".to_vec()])) + .expect("Cannot create tls connector"), tls_sni_override: args.tls_sni_override, tls_verify_certificate: args.tls_verify_certificate, }), @@ -671,12 +680,14 @@ async fn main() { HeaderValue::from_str(&host).unwrap() }; let mut client_config = WsClientConfig { - remote_addr: ( + remote_addr: TransportAddr::from_str( + args.remote_addr.scheme(), args.remote_addr.host().unwrap().to_owned(), args.remote_addr.port_or_known_default().unwrap(), - ), + tls, + ) + .unwrap(), socket_so_mark: args.socket_so_mark, - tls, http_upgrade_path_prefix: args.http_upgrade_path_prefix, http_upgrade_credentials: args.http_upgrade_credentials, http_headers: args.http_headers.into_iter().filter(|(k, _)| k != HOST).collect(), diff --git a/src/tls.rs b/src/tls.rs index b0ad955..7f5b247 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,4 +1,4 @@ -use crate::{TlsClientConfig, TlsServerConfig, WsClientConfig}; +use crate::{TlsServerConfig, WsClientConfig}; use anyhow::{anyhow, Context}; use std::fs::File; @@ -11,6 +11,7 @@ use tokio::net::TcpStream; use tokio_rustls::client::TlsStream; use tokio_rustls::rustls::client::{ServerCertVerified, ServerCertVerifier}; +use crate::tunnel::TransportAddr; use tokio_rustls::rustls::{Certificate, ClientConfig, PrivateKey, ServerName}; use tokio_rustls::{rustls, TlsAcceptor, TlsConnector}; use tracing::info; @@ -62,7 +63,10 @@ pub fn load_private_key_from_file(path: &Path) -> anyhow::Result { Ok(PrivateKey(private_key.secret_der().to_vec())) } -fn tls_connector(tls_cfg: &TlsClientConfig, alpn_protocols: Option>>) -> anyhow::Result { +pub fn tls_connector( + tls_verify_certificate: bool, + alpn_protocols: Option>>, +) -> anyhow::Result { let mut root_store = rustls::RootCertStore::empty(); // Load system certificates and add them to the root store @@ -77,7 +81,7 @@ fn tls_connector(tls_cfg: &TlsClientConfig, alpn_protocols: Option>> .with_no_client_auth(); // To bypass certificate verification - if !tls_cfg.tls_verify_certificate { + if !tls_verify_certificate { config.dangerous().set_certificate_verifier(Arc::new(NullVerifier)); } @@ -101,22 +105,31 @@ pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option anyhow::Result> { +pub async fn connect(client_cfg: &WsClientConfig, tcp_stream: TcpStream) -> anyhow::Result> { let sni = client_cfg.tls_server_name(); info!( "Doing TLS handshake using sni {sni:?} with the server {}:{}", - client_cfg.remote_addr.0, client_cfg.remote_addr.1 + client_cfg.remote_addr.host(), + client_cfg.remote_addr.port() ); - let tls_connector = tls_connector(tls_cfg, Some(vec![b"http/1.1".to_vec()]))?; - let tls_stream = tls_connector - .connect(sni, tcp_stream) - .await - .with_context(|| format!("failed to do TLS handshake with the server {:?}", client_cfg.remote_addr))?; + let tls_connector = match &client_cfg.remote_addr { + TransportAddr::WSS { tls, .. } => &tls.tls_connector, + TransportAddr::HTTPS { tls, .. } => &tls.tls_connector, + TransportAddr::HTTP { .. } | TransportAddr::WS { .. } => { + return Err(anyhow!( + "Transport does not support TLS: {}", + client_cfg.remote_addr.scheme_name() + )) + } + }; + let tls_stream = tls_connector.connect(sni, tcp_stream).await.with_context(|| { + format!( + "failed to do TLS handshake with the server {}:{}", + client_cfg.remote_addr.host(), + client_cfg.remote_addr.port() + ) + })?; Ok(tls_stream) } diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index c0b2d56..06e27b5 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -24,7 +24,7 @@ use tracing::{error, span, Instrument, Level, Span}; use url::Host; use uuid::Uuid; -pub async fn connect( +async fn connect( request_id: Uuid, client_cfg: &WsClientConfig, dest_addr: &RemoteAddr, @@ -70,6 +70,52 @@ pub async fn connect( Ok((ws, response)) } +//async fn connect_http2( +// request_id: Uuid, +// client_cfg: &WsClientConfig, +// dest_addr: &RemoteAddr, +//) -> anyhow::Result> { +// let mut pooled_cnx = match client_cfg.cnx_pool().get().await { +// Ok(cnx) => Ok(cnx), +// Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")), +// }?; +// +// let mut req = Request::builder() +// .method("GET") +// .uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix)) +// .header(HOST, &client_cfg.http_header_host) +// .header(COOKIE, tunnel_to_jwt_token(request_id, dest_addr)) +// .version(hyper::Version::HTTP_2); +// +// for (k, v) in &client_cfg.http_headers { +// req = req.header(k, v); +// } +// if let Some(auth) = &client_cfg.http_upgrade_credentials { +// req = req.header(AUTHORIZATION, auth); +// } +// +// let x: Vec = vec![]; +// //let bosy = StreamBody::new(stream::iter(vec![anyhow::Result::Ok(hyper::body::Frame::data(x.as_slice()))])); +// let req = req.body(Empty::::new()).with_context(|| { +// format!( +// "failed to build HTTP request to contact the server {:?}", +// client_cfg.remote_addr +// ) +// })?; +// debug!("with HTTP upgrade request {:?}", req); +// let transport = pooled_cnx.deref_mut().take().unwrap(); +// let (mut request_sender, cnx) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()).handshake(TokioIo::new(transport)).await +// .with_context(|| format!("failed to do http2 handshake with the server {:?}", client_cfg.remote_addr))?; +// tokio::spawn(cnx); +// +// let response = request_sender.send_request(req) +// .await +// .with_context(|| format!("failed to send http2 request with the server {:?}", client_cfg.remote_addr))?; +// +// // TODO: verify response is ok +// Ok(BodyStream::new(response.into_body())) +//} + async fn connect_to_server( request_id: Uuid, client_cfg: &WsClientConfig, diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 6c7cdc5..b8d8225 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -3,13 +3,14 @@ mod io; pub mod server; mod tls_reloader; -use crate::{tcp, tls, LocalProtocol, WsClientConfig}; +use crate::{tcp, tls, LocalProtocol, TlsClientConfig, WsClientConfig}; use async_trait::async_trait; use bb8::ManageConnection; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation}; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use std::collections::HashSet; +use std::fmt::{Debug, Formatter}; use std::io::{Error, IoSlice}; use std::net::{IpAddr, SocketAddr}; use std::ops::Deref; @@ -77,6 +78,97 @@ pub struct RemoteAddr { pub port: u16, } +#[derive(Clone)] +pub enum TransportAddr { + WSS { + tls: TlsClientConfig, + host: Host, + port: u16, + }, + WS { + host: Host, + port: u16, + }, + HTTPS { + tls: TlsClientConfig, + host: Host, + port: u16, + }, + HTTP { + host: Host, + port: u16, + }, +} + +impl Debug for TransportAddr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{}://{}:{}", self.scheme_name(), self.host(), self.port())) + } +} + +impl TransportAddr { + pub fn from_str(scheme: &str, host: Host, port: u16, tls: Option) -> Option { + match scheme { + "https" => { + let Some(tls) = tls else { return None }; + + Some(TransportAddr::HTTPS { tls, host, port }) + } + "http" => Some(TransportAddr::HTTP { host, port }), + "wss" => { + let Some(tls) = tls else { return None }; + + Some(TransportAddr::WSS { tls, host, port }) + } + "ws" => Some(TransportAddr::WS { host, port }), + _ => None, + } + } + pub fn is_websocket(&self) -> bool { + matches!(self, TransportAddr::WS { .. } | TransportAddr::WSS { .. }) + } + + pub fn is_http2(&self) -> bool { + matches!(self, TransportAddr::HTTP { .. } | TransportAddr::HTTPS { .. }) + } + + pub fn tls(&self) -> Option<&TlsClientConfig> { + match self { + TransportAddr::WSS { tls, .. } => Some(tls), + TransportAddr::HTTPS { tls, .. } => Some(tls), + TransportAddr::WS { .. } => None, + TransportAddr::HTTP { .. } => None, + } + } + + pub fn host(&self) -> &Host { + match self { + TransportAddr::WSS { host, .. } => host, + TransportAddr::WS { host, .. } => host, + TransportAddr::HTTPS { host, .. } => host, + TransportAddr::HTTP { host, .. } => host, + } + } + + pub fn port(&self) -> u16 { + match self { + TransportAddr::WSS { port, .. } => *port, + TransportAddr::WS { port, .. } => *port, + TransportAddr::HTTPS { port, .. } => *port, + TransportAddr::HTTP { port, .. } => *port, + } + } + + pub fn scheme_name(&self) -> &str { + match self { + TransportAddr::WSS { .. } => "wss", + TransportAddr::WS { .. } => "ws", + TransportAddr::HTTPS { .. } => "https", + TransportAddr::HTTP { .. } => "http", + } + } +} + impl TryFrom for RemoteAddr { type Error = anyhow::Error; fn try_from(jwt: JwtTunnelConfig) -> anyhow::Result { @@ -150,22 +242,35 @@ impl ManageConnection for WsClientConfig { #[instrument(level = "trace", name = "cnx_server", skip_all)] async fn connect(&self) -> Result { - let (host, port) = &self.remote_addr; let so_mark = self.socket_so_mark; let timeout = self.timeout_connect; let tcp_stream = if let Some(http_proxy) = &self.http_proxy { - tcp::connect_with_http_proxy(http_proxy, host, *port, so_mark, timeout, &self.dns_resolver).await? + tcp::connect_with_http_proxy( + http_proxy, + self.remote_addr.host(), + self.remote_addr.port(), + so_mark, + timeout, + &self.dns_resolver, + ) + .await? } else { - tcp::connect(host, *port, so_mark, timeout, &self.dns_resolver).await? + tcp::connect( + self.remote_addr.host(), + self.remote_addr.port(), + so_mark, + timeout, + &self.dns_resolver, + ) + .await? }; - match &self.tls { - None => Ok(Some(TransportStream::Plain(tcp_stream))), - Some(tls_cfg) => { - let tls_stream = tls::connect(self, tls_cfg, tcp_stream).await?; - Ok(Some(TransportStream::Tls(tls_stream))) - } + if self.remote_addr.tls().is_some() { + let tls_stream = tls::connect(self, tcp_stream).await?; + Ok(Some(TransportStream::Tls(tls_stream))) + } else { + Ok(Some(TransportStream::Plain(tcp_stream))) } }