Prep work for new transport
This commit is contained in:
parent
62f6a0287d
commit
6375e14185
4 changed files with 215 additions and 40 deletions
41
src/main.rs
41
src/main.rs
|
@ -18,6 +18,7 @@ use hickory_resolver::config::{NameServerConfig, ResolverConfig, ResolverOpts};
|
|||
use hyper::header::HOST;
|
||||
use hyper::http::{HeaderName, HeaderValue};
|
||||
use log::{debug, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
|
@ -34,11 +35,12 @@ use tokio::net::TcpStream;
|
|||
|
||||
use tokio_rustls::rustls::server::DnsName;
|
||||
use tokio_rustls::rustls::{Certificate, PrivateKey, ServerName};
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::dns::DnsResolver;
|
||||
use crate::tunnel::{to_host_port, RemoteAddr};
|
||||
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr};
|
||||
use crate::udp::MyUdpSocket;
|
||||
use tracing_subscriber::filter::Directive;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
@ -197,7 +199,7 @@ struct Client {
|
|||
|
||||
/// Address of the wstunnel server
|
||||
/// Example: With TLS wss://wstunnel.example.com or without ws://wstunnel.example.com
|
||||
#[arg(value_name = "ws[s]://wstunnel.server.com[:port]", value_parser = parse_server_url, verbatim_doc_comment)]
|
||||
#[arg(value_name = "ws[s]|http[s]://wstunnel.server.com[:port]", value_parser = parse_server_url, verbatim_doc_comment)]
|
||||
remote_addr: Url,
|
||||
}
|
||||
|
||||
|
@ -537,10 +539,11 @@ fn parse_server_url(arg: &str) -> Result<Url, io::Error> {
|
|||
Ok(url)
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone)]
|
||||
pub struct TlsClientConfig {
|
||||
pub tls_sni_override: Option<DnsName>,
|
||||
pub tls_verify_certificate: bool,
|
||||
pub tls_connector: TlsConnector,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -580,9 +583,8 @@ impl Debug for WsServerConfig {
|
|||
|
||||
#[derive(Clone)]
|
||||
pub struct WsClientConfig {
|
||||
pub remote_addr: (Host<String>, u16),
|
||||
pub remote_addr: TransportAddr,
|
||||
pub socket_so_mark: Option<u32>,
|
||||
pub tls: Option<TlsClientConfig>,
|
||||
pub http_upgrade_path_prefix: String,
|
||||
pub http_upgrade_credentials: Option<HeaderValue>,
|
||||
pub http_headers: HashMap<HeaderName, HeaderValue>,
|
||||
|
@ -597,9 +599,9 @@ pub struct WsClientConfig {
|
|||
|
||||
impl WsClientConfig {
|
||||
pub fn websocket_scheme(&self) -> &'static str {
|
||||
match self.tls {
|
||||
None => "ws",
|
||||
Some(_) => "wss",
|
||||
match self.remote_addr.tls().is_some() {
|
||||
false => "ws",
|
||||
true => "wss",
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -608,13 +610,18 @@ impl WsClientConfig {
|
|||
}
|
||||
|
||||
pub fn websocket_host_url(&self) -> String {
|
||||
format!("{}:{}", self.remote_addr.0, self.remote_addr.1)
|
||||
format!("{}:{}", self.remote_addr.host(), self.remote_addr.port())
|
||||
}
|
||||
|
||||
pub fn tls_server_name(&self) -> ServerName {
|
||||
match self.tls.as_ref().and_then(|tls| tls.tls_sni_override.as_ref()) {
|
||||
None => match &self.remote_addr.0 {
|
||||
Host::Domain(domain) => ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap()),
|
||||
static INVALID_DNS_NAME: Lazy<DnsName> =
|
||||
Lazy::new(|| DnsName::try_from_ascii(b"dns-name-invalid.com").unwrap());
|
||||
|
||||
match self.remote_addr.tls().and_then(|tls| tls.tls_sni_override.as_ref()) {
|
||||
None => match &self.remote_addr.host() {
|
||||
Host::Domain(domain) => {
|
||||
ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap_or_else(|_| INVALID_DNS_NAME.clone()))
|
||||
}
|
||||
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip)),
|
||||
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip)),
|
||||
},
|
||||
|
@ -654,6 +661,8 @@ async fn main() {
|
|||
let tls = match args.remote_addr.scheme() {
|
||||
"ws" => None,
|
||||
"wss" => Some(TlsClientConfig {
|
||||
tls_connector: tls::tls_connector(args.tls_verify_certificate, Some(vec![b"http/1.1".to_vec()]))
|
||||
.expect("Cannot create tls connector"),
|
||||
tls_sni_override: args.tls_sni_override,
|
||||
tls_verify_certificate: args.tls_verify_certificate,
|
||||
}),
|
||||
|
@ -671,12 +680,14 @@ async fn main() {
|
|||
HeaderValue::from_str(&host).unwrap()
|
||||
};
|
||||
let mut client_config = WsClientConfig {
|
||||
remote_addr: (
|
||||
remote_addr: TransportAddr::from_str(
|
||||
args.remote_addr.scheme(),
|
||||
args.remote_addr.host().unwrap().to_owned(),
|
||||
args.remote_addr.port_or_known_default().unwrap(),
|
||||
),
|
||||
tls,
|
||||
)
|
||||
.unwrap(),
|
||||
socket_so_mark: args.socket_so_mark,
|
||||
tls,
|
||||
http_upgrade_path_prefix: args.http_upgrade_path_prefix,
|
||||
http_upgrade_credentials: args.http_upgrade_credentials,
|
||||
http_headers: args.http_headers.into_iter().filter(|(k, _)| k != HOST).collect(),
|
||||
|
|
41
src/tls.rs
41
src/tls.rs
|
@ -1,4 +1,4 @@
|
|||
use crate::{TlsClientConfig, TlsServerConfig, WsClientConfig};
|
||||
use crate::{TlsServerConfig, WsClientConfig};
|
||||
use anyhow::{anyhow, Context};
|
||||
use std::fs::File;
|
||||
|
||||
|
@ -11,6 +11,7 @@ use tokio::net::TcpStream;
|
|||
use tokio_rustls::client::TlsStream;
|
||||
use tokio_rustls::rustls::client::{ServerCertVerified, ServerCertVerifier};
|
||||
|
||||
use crate::tunnel::TransportAddr;
|
||||
use tokio_rustls::rustls::{Certificate, ClientConfig, PrivateKey, ServerName};
|
||||
use tokio_rustls::{rustls, TlsAcceptor, TlsConnector};
|
||||
use tracing::info;
|
||||
|
@ -62,7 +63,10 @@ pub fn load_private_key_from_file(path: &Path) -> anyhow::Result<PrivateKey> {
|
|||
Ok(PrivateKey(private_key.secret_der().to_vec()))
|
||||
}
|
||||
|
||||
fn tls_connector(tls_cfg: &TlsClientConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsConnector> {
|
||||
pub fn tls_connector(
|
||||
tls_verify_certificate: bool,
|
||||
alpn_protocols: Option<Vec<Vec<u8>>>,
|
||||
) -> anyhow::Result<TlsConnector> {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
|
||||
// Load system certificates and add them to the root store
|
||||
|
@ -77,7 +81,7 @@ fn tls_connector(tls_cfg: &TlsClientConfig, alpn_protocols: Option<Vec<Vec<u8>>>
|
|||
.with_no_client_auth();
|
||||
|
||||
// To bypass certificate verification
|
||||
if !tls_cfg.tls_verify_certificate {
|
||||
if !tls_verify_certificate {
|
||||
config.dangerous().set_certificate_verifier(Arc::new(NullVerifier));
|
||||
}
|
||||
|
||||
|
@ -101,22 +105,31 @@ pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8
|
|||
Ok(TlsAcceptor::from(Arc::new(config)))
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
client_cfg: &WsClientConfig,
|
||||
tls_cfg: &TlsClientConfig,
|
||||
tcp_stream: TcpStream,
|
||||
) -> anyhow::Result<TlsStream<TcpStream>> {
|
||||
pub async fn connect(client_cfg: &WsClientConfig, tcp_stream: TcpStream) -> anyhow::Result<TlsStream<TcpStream>> {
|
||||
let sni = client_cfg.tls_server_name();
|
||||
info!(
|
||||
"Doing TLS handshake using sni {sni:?} with the server {}:{}",
|
||||
client_cfg.remote_addr.0, client_cfg.remote_addr.1
|
||||
client_cfg.remote_addr.host(),
|
||||
client_cfg.remote_addr.port()
|
||||
);
|
||||
|
||||
let tls_connector = tls_connector(tls_cfg, Some(vec![b"http/1.1".to_vec()]))?;
|
||||
let tls_stream = tls_connector
|
||||
.connect(sni, tcp_stream)
|
||||
.await
|
||||
.with_context(|| format!("failed to do TLS handshake with the server {:?}", client_cfg.remote_addr))?;
|
||||
let tls_connector = match &client_cfg.remote_addr {
|
||||
TransportAddr::WSS { tls, .. } => &tls.tls_connector,
|
||||
TransportAddr::HTTPS { tls, .. } => &tls.tls_connector,
|
||||
TransportAddr::HTTP { .. } | TransportAddr::WS { .. } => {
|
||||
return Err(anyhow!(
|
||||
"Transport does not support TLS: {}",
|
||||
client_cfg.remote_addr.scheme_name()
|
||||
))
|
||||
}
|
||||
};
|
||||
let tls_stream = tls_connector.connect(sni, tcp_stream).await.with_context(|| {
|
||||
format!(
|
||||
"failed to do TLS handshake with the server {}:{}",
|
||||
client_cfg.remote_addr.host(),
|
||||
client_cfg.remote_addr.port()
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ use tracing::{error, span, Instrument, Level, Span};
|
|||
use url::Host;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn connect(
|
||||
async fn connect(
|
||||
request_id: Uuid,
|
||||
client_cfg: &WsClientConfig,
|
||||
dest_addr: &RemoteAddr,
|
||||
|
@ -70,6 +70,52 @@ pub async fn connect(
|
|||
Ok((ws, response))
|
||||
}
|
||||
|
||||
//async fn connect_http2(
|
||||
// request_id: Uuid,
|
||||
// client_cfg: &WsClientConfig,
|
||||
// dest_addr: &RemoteAddr,
|
||||
//) -> anyhow::Result<BodyStream<Incoming>> {
|
||||
// let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
|
||||
// Ok(cnx) => Ok(cnx),
|
||||
// Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
|
||||
// }?;
|
||||
//
|
||||
// let mut req = Request::builder()
|
||||
// .method("GET")
|
||||
// .uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix))
|
||||
// .header(HOST, &client_cfg.http_header_host)
|
||||
// .header(COOKIE, tunnel_to_jwt_token(request_id, dest_addr))
|
||||
// .version(hyper::Version::HTTP_2);
|
||||
//
|
||||
// for (k, v) in &client_cfg.http_headers {
|
||||
// req = req.header(k, v);
|
||||
// }
|
||||
// if let Some(auth) = &client_cfg.http_upgrade_credentials {
|
||||
// req = req.header(AUTHORIZATION, auth);
|
||||
// }
|
||||
//
|
||||
// let x: Vec<u8> = vec![];
|
||||
// //let bosy = StreamBody::new(stream::iter(vec![anyhow::Result::Ok(hyper::body::Frame::data(x.as_slice()))]));
|
||||
// let req = req.body(Empty::<Bytes>::new()).with_context(|| {
|
||||
// format!(
|
||||
// "failed to build HTTP request to contact the server {:?}",
|
||||
// client_cfg.remote_addr
|
||||
// )
|
||||
// })?;
|
||||
// debug!("with HTTP upgrade request {:?}", req);
|
||||
// let transport = pooled_cnx.deref_mut().take().unwrap();
|
||||
// let (mut request_sender, cnx) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()).handshake(TokioIo::new(transport)).await
|
||||
// .with_context(|| format!("failed to do http2 handshake with the server {:?}", client_cfg.remote_addr))?;
|
||||
// tokio::spawn(cnx);
|
||||
//
|
||||
// let response = request_sender.send_request(req)
|
||||
// .await
|
||||
// .with_context(|| format!("failed to send http2 request with the server {:?}", client_cfg.remote_addr))?;
|
||||
//
|
||||
// // TODO: verify response is ok
|
||||
// Ok(BodyStream::new(response.into_body()))
|
||||
//}
|
||||
|
||||
async fn connect_to_server<R, W>(
|
||||
request_id: Uuid,
|
||||
client_cfg: &WsClientConfig,
|
||||
|
|
|
@ -3,13 +3,14 @@ mod io;
|
|||
pub mod server;
|
||||
mod tls_reloader;
|
||||
|
||||
use crate::{tcp, tls, LocalProtocol, WsClientConfig};
|
||||
use crate::{tcp, tls, LocalProtocol, TlsClientConfig, WsClientConfig};
|
||||
use async_trait::async_trait;
|
||||
use bb8::ManageConnection;
|
||||
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
use std::fmt::{Debug, Formatter};
|
||||
use std::io::{Error, IoSlice};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::ops::Deref;
|
||||
|
@ -77,6 +78,97 @@ pub struct RemoteAddr {
|
|||
pub port: u16,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum TransportAddr {
|
||||
WSS {
|
||||
tls: TlsClientConfig,
|
||||
host: Host,
|
||||
port: u16,
|
||||
},
|
||||
WS {
|
||||
host: Host,
|
||||
port: u16,
|
||||
},
|
||||
HTTPS {
|
||||
tls: TlsClientConfig,
|
||||
host: Host,
|
||||
port: u16,
|
||||
},
|
||||
HTTP {
|
||||
host: Host,
|
||||
port: u16,
|
||||
},
|
||||
}
|
||||
|
||||
impl Debug for TransportAddr {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_fmt(format_args!("{}://{}:{}", self.scheme_name(), self.host(), self.port()))
|
||||
}
|
||||
}
|
||||
|
||||
impl TransportAddr {
|
||||
pub fn from_str(scheme: &str, host: Host, port: u16, tls: Option<TlsClientConfig>) -> Option<Self> {
|
||||
match scheme {
|
||||
"https" => {
|
||||
let Some(tls) = tls else { return None };
|
||||
|
||||
Some(TransportAddr::HTTPS { tls, host, port })
|
||||
}
|
||||
"http" => Some(TransportAddr::HTTP { host, port }),
|
||||
"wss" => {
|
||||
let Some(tls) = tls else { return None };
|
||||
|
||||
Some(TransportAddr::WSS { tls, host, port })
|
||||
}
|
||||
"ws" => Some(TransportAddr::WS { host, port }),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
pub fn is_websocket(&self) -> bool {
|
||||
matches!(self, TransportAddr::WS { .. } | TransportAddr::WSS { .. })
|
||||
}
|
||||
|
||||
pub fn is_http2(&self) -> bool {
|
||||
matches!(self, TransportAddr::HTTP { .. } | TransportAddr::HTTPS { .. })
|
||||
}
|
||||
|
||||
pub fn tls(&self) -> Option<&TlsClientConfig> {
|
||||
match self {
|
||||
TransportAddr::WSS { tls, .. } => Some(tls),
|
||||
TransportAddr::HTTPS { tls, .. } => Some(tls),
|
||||
TransportAddr::WS { .. } => None,
|
||||
TransportAddr::HTTP { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn host(&self) -> &Host {
|
||||
match self {
|
||||
TransportAddr::WSS { host, .. } => host,
|
||||
TransportAddr::WS { host, .. } => host,
|
||||
TransportAddr::HTTPS { host, .. } => host,
|
||||
TransportAddr::HTTP { host, .. } => host,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn port(&self) -> u16 {
|
||||
match self {
|
||||
TransportAddr::WSS { port, .. } => *port,
|
||||
TransportAddr::WS { port, .. } => *port,
|
||||
TransportAddr::HTTPS { port, .. } => *port,
|
||||
TransportAddr::HTTP { port, .. } => *port,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn scheme_name(&self) -> &str {
|
||||
match self {
|
||||
TransportAddr::WSS { .. } => "wss",
|
||||
TransportAddr::WS { .. } => "ws",
|
||||
TransportAddr::HTTPS { .. } => "https",
|
||||
TransportAddr::HTTP { .. } => "http",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<JwtTunnelConfig> for RemoteAddr {
|
||||
type Error = anyhow::Error;
|
||||
fn try_from(jwt: JwtTunnelConfig) -> anyhow::Result<Self> {
|
||||
|
@ -150,22 +242,35 @@ impl ManageConnection for WsClientConfig {
|
|||
|
||||
#[instrument(level = "trace", name = "cnx_server", skip_all)]
|
||||
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
|
||||
let (host, port) = &self.remote_addr;
|
||||
let so_mark = self.socket_so_mark;
|
||||
let timeout = self.timeout_connect;
|
||||
|
||||
let tcp_stream = if let Some(http_proxy) = &self.http_proxy {
|
||||
tcp::connect_with_http_proxy(http_proxy, host, *port, so_mark, timeout, &self.dns_resolver).await?
|
||||
tcp::connect_with_http_proxy(
|
||||
http_proxy,
|
||||
self.remote_addr.host(),
|
||||
self.remote_addr.port(),
|
||||
so_mark,
|
||||
timeout,
|
||||
&self.dns_resolver,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
tcp::connect(host, *port, so_mark, timeout, &self.dns_resolver).await?
|
||||
tcp::connect(
|
||||
self.remote_addr.host(),
|
||||
self.remote_addr.port(),
|
||||
so_mark,
|
||||
timeout,
|
||||
&self.dns_resolver,
|
||||
)
|
||||
.await?
|
||||
};
|
||||
|
||||
match &self.tls {
|
||||
None => Ok(Some(TransportStream::Plain(tcp_stream))),
|
||||
Some(tls_cfg) => {
|
||||
let tls_stream = tls::connect(self, tls_cfg, tcp_stream).await?;
|
||||
Ok(Some(TransportStream::Tls(tls_stream)))
|
||||
}
|
||||
if self.remote_addr.tls().is_some() {
|
||||
let tls_stream = tls::connect(self, tcp_stream).await?;
|
||||
Ok(Some(TransportStream::Tls(tls_stream)))
|
||||
} else {
|
||||
Ok(Some(TransportStream::Plain(tcp_stream)))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue