Refacto: Use proper type for WsClient

This commit is contained in:
Σrebe - Romain GERARD 2024-07-29 23:08:40 +02:00
parent 5e74ed233d
commit a33a889b3d
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
11 changed files with 453 additions and 412 deletions

View file

@ -3,15 +3,23 @@ mod protocols;
mod restrictions;
mod tunnel;
use crate::protocols::dns::DnsResolver;
use crate::protocols::tls;
use crate::restrictions::types::RestrictionsRules;
use crate::tunnel::client::{TlsClientConfig, WsClient, WsClientConfig};
use crate::tunnel::connectors::{Socks5TunnelConnector, TcpTunnelConnector, UdpTunnelConnector};
use crate::tunnel::listeners::{
new_stdio_listener, new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener,
};
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme};
use base64::Engine;
use clap::Parser;
use hyper::header::HOST;
use hyper::http::{HeaderName, HeaderValue};
use log::debug;
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use std::collections::BTreeMap;
use std::fmt::{Debug, Formatter};
use std::io::ErrorKind;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
@ -21,21 +29,8 @@ use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use tokio::select;
use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer, ServerName};
use tokio_rustls::TlsConnector;
use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer};
use tracing::{error, info};
use crate::protocols::dns::DnsResolver;
use crate::protocols::tls;
use crate::restrictions::types::RestrictionsRules;
use crate::tunnel::connectors::{Socks5TunnelConnector, TcpTunnelConnector, UdpTunnelConnector};
use crate::tunnel::listeners::{
new_stdio_listener, new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener,
};
use crate::tunnel::tls_reloader::TlsReloader;
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme};
use tracing_subscriber::filter::Directive;
use tracing_subscriber::EnvFilter;
use url::{Host, Url};
@ -695,22 +690,6 @@ fn parse_server_url(arg: &str) -> Result<Url, io::Error> {
Ok(url)
}
#[derive(Clone)]
pub struct TlsClientConfig {
pub tls_sni_disabled: bool,
pub tls_sni_override: Option<DnsName<'static>>,
pub tls_verify_certificate: bool,
tls_connector: Arc<RwLock<TlsConnector>>,
pub tls_certificate_path: Option<PathBuf>,
pub tls_key_path: Option<PathBuf>,
}
impl TlsClientConfig {
pub fn tls_connector(&self) -> TlsConnector {
self.tls_connector.read().clone()
}
}
#[derive(Debug)]
pub struct TlsServerConfig {
pub tls_certificate: Mutex<Vec<CertificateDer<'static>>>,
@ -754,59 +733,6 @@ impl Debug for WsServerConfig {
}
}
#[derive(Clone)]
pub struct WsClientConfig {
pub remote_addr: TransportAddr,
pub socket_so_mark: Option<u32>,
pub http_upgrade_path_prefix: String,
pub http_upgrade_credentials: Option<HeaderValue>,
pub http_headers: HashMap<HeaderName, HeaderValue>,
pub http_headers_file: Option<PathBuf>,
pub http_header_host: HeaderValue,
pub timeout_connect: Duration,
pub websocket_ping_frequency: Duration,
pub websocket_mask_frame: bool,
pub http_proxy: Option<Url>,
cnx_pool: Option<bb8::Pool<WsClientConfig>>,
tls_reloader: Option<Arc<TlsReloader>>,
pub dns_resolver: DnsResolver,
}
impl WsClientConfig {
pub const fn websocket_scheme(&self) -> &'static str {
match self.remote_addr.tls().is_some() {
false => "ws",
true => "wss",
}
}
pub fn cnx_pool(&self) -> &bb8::Pool<Self> {
self.cnx_pool.as_ref().unwrap()
}
pub fn websocket_host_url(&self) -> String {
format!("{}:{}", self.remote_addr.host(), self.remote_addr.port())
}
pub fn tls_server_name(&self) -> ServerName<'static> {
static INVALID_DNS_NAME: Lazy<DnsName> = Lazy::new(|| DnsName::try_from("dns-name-invalid.com").unwrap());
self.remote_addr
.tls()
.and_then(|tls| tls.tls_sni_override.as_ref())
.map_or_else(
|| match &self.remote_addr.host() {
Host::Domain(domain) => ServerName::DnsName(
DnsName::try_from(domain.clone()).unwrap_or_else(|_| INVALID_DNS_NAME.clone()),
),
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip).into()),
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip).into()),
},
|sni_override| ServerName::DnsName(sni_override.clone()),
)
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let args = Wstunnel::parse();
@ -866,24 +792,7 @@ async fn main() -> anyhow::Result<()> {
TransportScheme::from_str(args.remote_addr.scheme()).expect("invalid scheme in server url");
let tls = match transport_scheme {
TransportScheme::Ws | TransportScheme::Http => None,
TransportScheme::Wss => Some(TlsClientConfig {
tls_connector: Arc::new(RwLock::new(
tls::tls_connector(
args.tls_verify_certificate,
transport_scheme.alpn_protocols(),
!args.tls_sni_disable,
tls_certificate,
tls_key,
)
.expect("Cannot create tls connector"),
)),
tls_sni_override: args.tls_sni_override,
tls_verify_certificate: args.tls_verify_certificate,
tls_sni_disabled: args.tls_sni_disable,
tls_certificate_path: args.tls_certificate.clone(),
tls_key_path: args.tls_private_key.clone(),
}),
TransportScheme::Https => Some(TlsClientConfig {
TransportScheme::Wss | TransportScheme::Https => Some(TlsClientConfig {
tls_connector: Arc::new(RwLock::new(
tls::tls_connector(
args.tls_verify_certificate,
@ -936,7 +845,7 @@ async fn main() -> anyhow::Result<()> {
} else {
None
};
let mut client_config = WsClientConfig {
let client_config = WsClientConfig {
remote_addr: TransportAddr::new(
TransportScheme::from_str(args.remote_addr.scheme()).unwrap(),
args.remote_addr.host().unwrap().to_owned(),
@ -953,8 +862,6 @@ async fn main() -> anyhow::Result<()> {
timeout_connect: Duration::from_secs(10),
websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)),
websocket_mask_frame: args.websocket_mask_frame,
cnx_pool: None,
tls_reloader: None,
dns_resolver: DnsResolver::new_from_urls(
&args.dns_resolver,
http_proxy.clone(),
@ -965,28 +872,16 @@ async fn main() -> anyhow::Result<()> {
http_proxy,
};
let tls_reloader =
TlsReloader::new_for_client(Arc::new(client_config.clone())).expect("Cannot create tls reloader");
client_config.tls_reloader = Some(Arc::new(tls_reloader));
let pool = bb8::Pool::builder()
.max_size(1000)
.min_idle(Some(args.connection_min_idle))
.max_lifetime(Some(Duration::from_secs(30)))
.connection_timeout(args.connection_retry_max_backoff_sec)
.retry_connection(true)
.build(client_config.clone())
.await
.unwrap();
client_config.cnx_pool = Some(pool);
let client_config = Arc::new(client_config);
let client =
WsClient::new(client_config, args.connection_min_idle, args.connection_retry_max_backoff_sec).await?;
// Start tunnels
for tunnel in args.remote_to_local.into_iter() {
let client_config = client_config.clone();
let client = client.clone();
match &tunnel.local_protocol {
LocalProtocol::Tcp { proxy_protocol: _ } => {
tokio::spawn(async move {
let cfg = client_config.clone();
let cfg = client.config.clone();
let tcp_connector = TcpTunnelConnector::new(
&tunnel.remote.0,
tunnel.remote.1,
@ -1000,9 +895,7 @@ async fn main() -> anyhow::Result<()> {
host,
port,
};
if let Err(err) =
tunnel::client::run_reverse_tunnel(client_config, remote, tcp_connector).await
{
if let Err(err) = client.run_reverse_tunnel(remote, tcp_connector).await {
error!("{:?}", err);
}
});
@ -1011,7 +904,7 @@ async fn main() -> anyhow::Result<()> {
let timeout = *timeout;
tokio::spawn(async move {
let cfg = client_config.clone();
let cfg = client.config.clone();
let (host, port) = to_host_port(tunnel.local);
let remote = RemoteAddr {
protocol: LocalProtocol::ReverseUdp { timeout },
@ -1026,9 +919,7 @@ async fn main() -> anyhow::Result<()> {
&cfg.dns_resolver,
);
if let Err(err) =
tunnel::client::run_reverse_tunnel(client_config, remote.clone(), udp_connector).await
{
if let Err(err) = client.run_reverse_tunnel(remote.clone(), udp_connector).await {
error!("{:?}", err);
}
});
@ -1037,7 +928,7 @@ async fn main() -> anyhow::Result<()> {
let credentials = credentials.clone();
let timeout = *timeout;
tokio::spawn(async move {
let cfg = client_config.clone();
let cfg = client.config.clone();
let (host, port) = to_host_port(tunnel.local);
let remote = RemoteAddr {
protocol: LocalProtocol::ReverseSocks5 { timeout, credentials },
@ -1047,9 +938,7 @@ async fn main() -> anyhow::Result<()> {
let socks_connector =
Socks5TunnelConnector::new(cfg.socket_so_mark, cfg.timeout_connect, &cfg.dns_resolver);
if let Err(err) =
tunnel::client::run_reverse_tunnel(client_config, remote, socks_connector).await
{
if let Err(err) = client.run_reverse_tunnel(remote, socks_connector).await {
error!("{:?}", err);
}
});
@ -1060,7 +949,7 @@ async fn main() -> anyhow::Result<()> {
let credentials = credentials.clone();
let timeout = *timeout;
tokio::spawn(async move {
let cfg = client_config.clone();
let cfg = client.config.clone();
let (host, port) = to_host_port(tunnel.local);
let remote = RemoteAddr {
protocol: LocalProtocol::ReverseHttpProxy { timeout, credentials },
@ -1075,9 +964,7 @@ async fn main() -> anyhow::Result<()> {
&cfg.dns_resolver,
);
if let Err(err) =
tunnel::client::run_reverse_tunnel(client_config, remote.clone(), tcp_connector).await
{
if let Err(err) = client.run_reverse_tunnel(remote.clone(), tcp_connector).await {
error!("{:?}", err);
}
});
@ -1086,7 +973,7 @@ async fn main() -> anyhow::Result<()> {
LocalProtocol::Unix { path } => {
let path = path.clone();
tokio::spawn(async move {
let cfg = client_config.clone();
let cfg = client.config.clone();
let tcp_connector = TcpTunnelConnector::new(
&tunnel.remote.0,
tunnel.remote.1,
@ -1101,9 +988,7 @@ async fn main() -> anyhow::Result<()> {
host,
port,
};
if let Err(err) =
tunnel::client::run_reverse_tunnel(client_config, remote, tcp_connector).await
{
if let Err(err) = client.run_reverse_tunnel(remote, tcp_connector).await {
error!("{:?}", err);
}
});
@ -1126,14 +1011,14 @@ async fn main() -> anyhow::Result<()> {
}
for tunnel in args.local_to_remote.into_iter() {
let client_config = client_config.clone();
let client = client.clone();
match &tunnel.local_protocol {
LocalProtocol::Tcp { proxy_protocol } => {
let server =
TcpTunnelListener::new(tunnel.local, tunnel.remote.clone(), *proxy_protocol).await?;
tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = client.run_tunnel(server).await {
error!("{:?}", err);
}
});
@ -1144,7 +1029,7 @@ async fn main() -> anyhow::Result<()> {
let server = TproxyTcpTunnelListener::new(tunnel.local, false).await?;
tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = client.run_tunnel(server).await {
error!("{:?}", err);
}
});
@ -1154,7 +1039,7 @@ async fn main() -> anyhow::Result<()> {
use crate::tunnel::listeners::UnixTunnelListener;
let server = UnixTunnelListener::new(path, tunnel.remote.clone(), false).await?; // TODO: support proxy protocol
tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = client.run_tunnel(server).await {
error!("{:?}", err);
}
});
@ -1169,7 +1054,7 @@ async fn main() -> anyhow::Result<()> {
use crate::tunnel::listeners::new_tproxy_udp;
let server = new_tproxy_udp(tunnel.local, *timeout).await?;
tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = client.run_tunnel(server).await {
error!("{:?}", err);
}
});
@ -1182,7 +1067,7 @@ async fn main() -> anyhow::Result<()> {
let server = new_udp_listener(tunnel.local, tunnel.remote.clone(), *timeout).await?;
tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = client.run_tunnel(server).await {
error!("{:?}", err);
}
});
@ -1190,7 +1075,7 @@ async fn main() -> anyhow::Result<()> {
LocalProtocol::Socks5 { timeout, credentials } => {
let server = Socks5TunnelListener::new(tunnel.local, *timeout, credentials.clone()).await?;
tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = client.run_tunnel(server).await {
error!("{:?}", err);
}
});
@ -1204,7 +1089,7 @@ async fn main() -> anyhow::Result<()> {
HttpProxyTunnelListener::new(tunnel.local, *timeout, credentials.clone(), *proxy_protocol)
.await?;
tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = client.run_tunnel(server).await {
error!("{:?}", err);
}
});
@ -1213,7 +1098,7 @@ async fn main() -> anyhow::Result<()> {
LocalProtocol::Stdio => {
let (server, mut handle) = new_stdio_listener(tunnel.remote.clone(), false).await?; // TODO: support proxy protocol
tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
if let Err(err) = client.run_tunnel(server).await {
error!("{:?}", err);
}
});

View file

@ -1,4 +1,4 @@
use crate::{TlsServerConfig, WsClientConfig};
use crate::TlsServerConfig;
use anyhow::{anyhow, Context};
use std::fs::File;
@ -9,6 +9,7 @@ use std::sync::Arc;
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use crate::tunnel::client::WsClientConfig;
use crate::tunnel::TransportAddr;
use tokio_rustls::rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};

View file

@ -1,176 +0,0 @@
use super::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE};
use crate::tunnel::connectors::TunnelConnector;
use crate::tunnel::listeners::TunnelListener;
use crate::tunnel::transport::{TunnelReader, TunnelWriter};
use crate::{tunnel, WsClientConfig};
use futures_util::pin_mut;
use hyper::header::COOKIE;
use jsonwebtoken::TokenData;
use log::debug;
use std::ops::Deref;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::oneshot;
use tokio_stream::StreamExt;
use tracing::{error, event, span, Instrument, Level, Span};
use url::Host;
use uuid::Uuid;
async fn connect_to_server<R, W>(
request_id: Uuid,
client_cfg: &WsClientConfig,
remote_cfg: &RemoteAddr,
duplex_stream: (R, W),
) -> anyhow::Result<()>
where
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
// Connect to server with the correct protocol
let (ws_rx, ws_tx, response) = match client_cfg.remote_addr.scheme() {
TransportScheme::Ws | TransportScheme::Wss => {
tunnel::transport::websocket::connect(request_id, client_cfg, remote_cfg)
.await
.map(|(r, w, response)| (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response))?
}
TransportScheme::Http | TransportScheme::Https => {
tunnel::transport::http2::connect(request_id, client_cfg, remote_cfg)
.await
.map(|(r, w, response)| (TunnelReader::Http2(r), TunnelWriter::Http2(w), response))?
}
};
debug!("Server response: {:?}", response);
let (local_rx, local_tx) = duplex_stream;
let (close_tx, close_rx) = oneshot::channel::<()>();
// Forward local tx to websocket tx
let ping_frequency = client_cfg.websocket_ping_frequency;
tokio::spawn(
super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency))
.instrument(Span::current()),
);
// Forward websocket rx to local rx
let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
Ok(())
}
pub async fn run_tunnel(client_config: Arc<WsClientConfig>, incoming_cnx: impl TunnelListener) -> anyhow::Result<()> {
pin_mut!(incoming_cnx);
while let Some(cnx) = incoming_cnx.next().await {
let (cnx_stream, remote_addr) = match cnx {
Ok((cnx_stream, remote_addr)) => (cnx_stream, remote_addr),
Err(err) => {
error!("Error accepting connection: {:?}", err);
continue;
}
};
let request_id = Uuid::now_v7();
let span = span!(
Level::INFO,
"tunnel",
id = request_id.to_string(),
remote = format!("{}:{}", remote_addr.host, remote_addr.port)
);
let client_config = client_config.clone();
let tunnel = async move {
let _ = connect_to_server(request_id, &client_config, &remote_addr, cnx_stream)
.await
.map_err(|err| error!("{:?}", err));
}
.instrument(span);
tokio::spawn(tunnel);
}
Ok(())
}
pub async fn run_reverse_tunnel(
client_cfg: Arc<WsClientConfig>,
remote_addr: RemoteAddr,
connector: impl TunnelConnector,
) -> anyhow::Result<()> {
loop {
let client_config = client_cfg.clone();
let request_id = Uuid::now_v7();
let span = span!(
Level::INFO,
"tunnel",
id = request_id.to_string(),
remote = format!("{}:{}", remote_addr.host, remote_addr.port)
);
// Correctly configure tunnel cfg
let (ws_rx, ws_tx, response) = match client_cfg.remote_addr.scheme() {
TransportScheme::Ws | TransportScheme::Wss => {
match tunnel::transport::websocket::connect(request_id, &client_cfg, &remote_addr)
.instrument(span.clone())
.await
{
Ok((r, w, response)) => (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response),
Err(err) => {
event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err);
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
continue;
}
}
}
TransportScheme::Http | TransportScheme::Https => {
match tunnel::transport::http2::connect(request_id, &client_cfg, &remote_addr)
.instrument(span.clone())
.await
{
Ok((r, w, response)) => (TunnelReader::Http2(r), TunnelWriter::Http2(w), response),
Err(err) => {
event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err);
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
continue;
}
}
}
};
// Connect to endpoint
event!(parent: &span, Level::DEBUG, "Server response: {:?}", response);
let remote = response
.headers
.get(COOKIE)
.and_then(|h| h.to_str().ok())
.and_then(|h| {
let (validation, decode_key) = JWT_DECODE.deref();
let jwt: Option<TokenData<JwtTunnelConfig>> = jsonwebtoken::decode(h, decode_key, validation).ok();
jwt
})
.map(|jwt| RemoteAddr {
protocol: jwt.claims.p,
host: Host::parse(&jwt.claims.r).unwrap_or_else(|_| Host::Domain(String::new())),
port: jwt.claims.rp,
});
let (local_rx, local_tx) = match connector.connect(&remote).instrument(span.clone()).await {
Ok(s) => s,
Err(err) => {
event!(parent: &span, Level::ERROR, "Cannot connect to {remote:?}: {err:?}");
continue;
}
};
let (close_tx, close_rx) = oneshot::channel::<()>();
let tunnel = async move {
let ping_frequency = client_config.websocket_ping_frequency;
tokio::spawn(
super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency))
.in_current_span(),
);
// Forward websocket rx to local rx
let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
}
.instrument(span.clone());
tokio::spawn(tunnel);
}
}

221
src/tunnel/client/client.rs Normal file
View file

@ -0,0 +1,221 @@
use crate::tunnel;
use crate::tunnel::client::cnx_pool::WsConnection;
use crate::tunnel::client::WsClientConfig;
use crate::tunnel::connectors::TunnelConnector;
use crate::tunnel::listeners::TunnelListener;
use crate::tunnel::tls_reloader::TlsReloader;
use crate::tunnel::transport::{TunnelReader, TunnelWriter};
use crate::tunnel::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE};
use anyhow::Context;
use futures_util::pin_mut;
use hyper::header::COOKIE;
use jsonwebtoken::TokenData;
use log::debug;
use std::ops::Deref;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::oneshot;
use tokio_stream::StreamExt;
use tracing::{error, event, span, Instrument, Level, Span};
use url::Host;
use uuid::Uuid;
#[derive(Clone)]
pub struct WsClient {
pub config: Arc<WsClientConfig>,
pub cnx_pool: bb8::Pool<WsConnection>,
_tls_reloader: Arc<TlsReloader>,
}
impl WsClient {
pub async fn new(
config: WsClientConfig,
connection_min_idle: u32,
connection_retry_max_backoff_sec: Duration,
) -> anyhow::Result<Self> {
let config = Arc::new(config);
let cnx = WsConnection::new(config.clone());
let tls_reloader = TlsReloader::new_for_client(config.clone()).with_context(|| "Cannot create tls reloader")?;
let cnx_pool = bb8::Pool::builder()
.max_size(1000)
.min_idle(Some(connection_min_idle))
.max_lifetime(Some(Duration::from_secs(30)))
.connection_timeout(connection_retry_max_backoff_sec)
.retry_connection(true)
.build(cnx)
.await?;
Ok(Self {
config,
cnx_pool,
_tls_reloader: Arc::new(tls_reloader),
})
}
}
impl WsClient {
async fn connect_to_server<R, W>(
&self,
request_id: Uuid,
remote_cfg: &RemoteAddr,
duplex_stream: (R, W),
) -> anyhow::Result<()>
where
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
// Connect to server with the correct protocol
let (ws_rx, ws_tx, response) = match self.config.remote_addr.scheme() {
TransportScheme::Ws | TransportScheme::Wss => {
tunnel::transport::websocket::connect(request_id, self, remote_cfg)
.await
.map(|(r, w, response)| (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response))?
}
TransportScheme::Http | TransportScheme::Https => {
tunnel::transport::http2::connect(request_id, self, remote_cfg)
.await
.map(|(r, w, response)| (TunnelReader::Http2(r), TunnelWriter::Http2(w), response))?
}
};
debug!("Server response: {:?}", response);
let (local_rx, local_tx) = duplex_stream;
let (close_tx, close_rx) = oneshot::channel::<()>();
// Forward local tx to websocket tx
let ping_frequency = self.config.websocket_ping_frequency;
tokio::spawn(
super::super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency))
.instrument(Span::current()),
);
// Forward websocket rx to local rx
let _ = super::super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
Ok(())
}
pub async fn run_tunnel(self, tunnel_listener: impl TunnelListener) -> anyhow::Result<()> {
pin_mut!(tunnel_listener);
while let Some(cnx) = tunnel_listener.next().await {
let (cnx_stream, remote_addr) = match cnx {
Ok((cnx_stream, remote_addr)) => (cnx_stream, remote_addr),
Err(err) => {
error!("Error accepting connection: {:?}", err);
continue;
}
};
let request_id = Uuid::now_v7();
let span = span!(
Level::INFO,
"tunnel",
id = request_id.to_string(),
remote = format!("{}:{}", remote_addr.host, remote_addr.port)
);
let client = self.clone();
let tunnel = async move {
let _ = client
.connect_to_server(request_id, &remote_addr, cnx_stream)
.await
.map_err(|err| error!("{:?}", err));
}
.instrument(span);
tokio::spawn(tunnel);
}
Ok(())
}
pub async fn run_reverse_tunnel(
self,
remote_addr: RemoteAddr,
connector: impl TunnelConnector,
) -> anyhow::Result<()> {
loop {
let client = self.clone();
let request_id = Uuid::now_v7();
let span = span!(
Level::INFO,
"tunnel",
id = request_id.to_string(),
remote = format!("{}:{}", remote_addr.host, remote_addr.port)
);
// Correctly configure tunnel cfg
let (ws_rx, ws_tx, response) = match client.config.remote_addr.scheme() {
TransportScheme::Ws | TransportScheme::Wss => {
match tunnel::transport::websocket::connect(request_id, &client, &remote_addr)
.instrument(span.clone())
.await
{
Ok((r, w, response)) => (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response),
Err(err) => {
event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err);
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
}
}
}
TransportScheme::Http | TransportScheme::Https => {
match tunnel::transport::http2::connect(request_id, &client, &remote_addr)
.instrument(span.clone())
.await
{
Ok((r, w, response)) => (TunnelReader::Http2(r), TunnelWriter::Http2(w), response),
Err(err) => {
event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err);
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
}
}
}
};
// Connect to endpoint
event!(parent: &span, Level::DEBUG, "Server response: {:?}", response);
let remote = response
.headers
.get(COOKIE)
.and_then(|h| h.to_str().ok())
.and_then(|h| {
let (validation, decode_key) = JWT_DECODE.deref();
let jwt: Option<TokenData<JwtTunnelConfig>> = jsonwebtoken::decode(h, decode_key, validation).ok();
jwt
})
.map(|jwt| RemoteAddr {
protocol: jwt.claims.p,
host: Host::parse(&jwt.claims.r).unwrap_or_else(|_| Host::Domain(String::new())),
port: jwt.claims.rp,
});
let (local_rx, local_tx) = match connector.connect(&remote).instrument(span.clone()).await {
Ok(s) => s,
Err(err) => {
event!(parent: &span, Level::ERROR, "Cannot connect to {remote:?}: {err:?}");
continue;
}
};
let (close_tx, close_rx) = oneshot::channel::<()>();
let tunnel = async move {
let ping_frequency = client.config.websocket_ping_frequency;
tokio::spawn(
super::super::transport::io::propagate_local_to_remote(
local_rx,
ws_tx,
close_tx,
Some(ping_frequency),
)
.in_current_span(),
);
// Forward websocket rx to local rx
let _ = super::super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
}
.instrument(span.clone());
tokio::spawn(tunnel);
}
}
}

View file

@ -0,0 +1,74 @@
use crate::protocols;
use crate::protocols::tls;
use crate::tunnel::client::WsClientConfig;
use crate::tunnel::TransportStream;
use async_trait::async_trait;
use bb8::ManageConnection;
use std::ops::Deref;
use std::sync::Arc;
use tracing::instrument;
#[derive(Clone)]
pub struct WsConnection(Arc<WsClientConfig>);
impl WsConnection {
pub fn new(config: Arc<WsClientConfig>) -> Self {
Self(config)
}
}
impl Deref for WsConnection {
type Target = WsClientConfig;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[async_trait]
impl ManageConnection for WsConnection {
type Connection = Option<TransportStream>;
type Error = anyhow::Error;
#[instrument(level = "trace", name = "cnx_server", skip_all)]
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
let so_mark = self.socket_so_mark;
let timeout = self.timeout_connect;
let tcp_stream = if let Some(http_proxy) = &self.http_proxy {
protocols::tcp::connect_with_http_proxy(
http_proxy,
self.remote_addr.host(),
self.remote_addr.port(),
so_mark,
timeout,
&self.dns_resolver,
)
.await?
} else {
protocols::tcp::connect(
self.remote_addr.host(),
self.remote_addr.port(),
so_mark,
timeout,
&self.dns_resolver,
)
.await?
};
if self.remote_addr.tls().is_some() {
let tls_stream = tls::connect(self, tcp_stream).await?;
Ok(Some(TransportStream::Tls(tls_stream)))
} else {
Ok(Some(TransportStream::Plain(tcp_stream)))
}
}
async fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> {
Ok(())
}
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
conn.is_none()
}
}

View file

@ -0,0 +1,76 @@
use crate::protocols::dns::DnsResolver;
use crate::tunnel::TransportAddr;
use hyper::header::{HeaderName, HeaderValue};
use once_cell::sync::Lazy;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::net::IpAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio_rustls::rustls::pki_types::{DnsName, ServerName};
use tokio_rustls::TlsConnector;
use url::{Host, Url};
#[derive(Clone)]
pub struct WsClientConfig {
pub remote_addr: TransportAddr,
pub socket_so_mark: Option<u32>,
pub http_upgrade_path_prefix: String,
pub http_upgrade_credentials: Option<HeaderValue>,
pub http_headers: HashMap<HeaderName, HeaderValue>,
pub http_headers_file: Option<PathBuf>,
pub http_header_host: HeaderValue,
pub timeout_connect: Duration,
pub websocket_ping_frequency: Duration,
pub websocket_mask_frame: bool,
pub http_proxy: Option<Url>,
pub dns_resolver: DnsResolver,
}
impl WsClientConfig {
pub const fn websocket_scheme(&self) -> &'static str {
match self.remote_addr.tls().is_some() {
false => "ws",
true => "wss",
}
}
pub fn websocket_host_url(&self) -> String {
format!("{}:{}", self.remote_addr.host(), self.remote_addr.port())
}
pub fn tls_server_name(&self) -> ServerName<'static> {
static INVALID_DNS_NAME: Lazy<DnsName> = Lazy::new(|| DnsName::try_from("dns-name-invalid.com").unwrap());
self.remote_addr
.tls()
.and_then(|tls| tls.tls_sni_override.as_ref())
.map_or_else(
|| match &self.remote_addr.host() {
Host::Domain(domain) => ServerName::DnsName(
DnsName::try_from(domain.clone()).unwrap_or_else(|_| INVALID_DNS_NAME.clone()),
),
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip).into()),
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip).into()),
},
|sni_override| ServerName::DnsName(sni_override.clone()),
)
}
}
#[derive(Clone)]
pub struct TlsClientConfig {
pub tls_sni_disabled: bool,
pub tls_sni_override: Option<DnsName<'static>>,
pub tls_verify_certificate: bool,
pub tls_connector: Arc<RwLock<TlsConnector>>,
pub tls_certificate_path: Option<PathBuf>,
pub tls_key_path: Option<PathBuf>,
}
impl TlsClientConfig {
pub fn tls_connector(&self) -> TlsConnector {
self.tls_connector.read().clone()
}
}

8
src/tunnel/client/mod.rs Normal file
View file

@ -0,0 +1,8 @@
#![allow(clippy::module_inception)]
mod client;
mod cnx_pool;
mod config;
pub use client::WsClient;
pub use config::TlsClientConfig;
pub use config::WsClientConfig;

View file

@ -5,10 +5,7 @@ pub mod server;
pub mod tls_reloader;
mod transport;
use crate::protocols::tls;
use crate::{protocols, LocalProtocol, TlsClientConfig, WsClientConfig};
use async_trait::async_trait;
use bb8::ManageConnection;
use crate::{LocalProtocol, TlsClientConfig};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
@ -23,7 +20,6 @@ use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use tracing::instrument;
use url::Host;
use uuid::Uuid;
@ -304,54 +300,6 @@ impl AsyncWrite for TransportStream {
}
}
#[async_trait]
impl ManageConnection for WsClientConfig {
type Connection = Option<TransportStream>;
type Error = anyhow::Error;
#[instrument(level = "trace", name = "cnx_server", skip_all)]
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
let so_mark = self.socket_so_mark;
let timeout = self.timeout_connect;
let tcp_stream = if let Some(http_proxy) = &self.http_proxy {
protocols::tcp::connect_with_http_proxy(
http_proxy,
self.remote_addr.host(),
self.remote_addr.port(),
so_mark,
timeout,
&self.dns_resolver,
)
.await?
} else {
protocols::tcp::connect(
self.remote_addr.host(),
self.remote_addr.port(),
so_mark,
timeout,
&self.dns_resolver,
)
.await?
};
if self.remote_addr.tls().is_some() {
let tls_stream = tls::connect(self, tcp_stream).await?;
Ok(Some(TransportStream::Tls(tls_stream)))
} else {
Ok(Some(TransportStream::Plain(tcp_stream)))
}
}
async fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> {
Ok(())
}
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
conn.is_none()
}
}
pub fn to_host_port(addr: SocketAddr) -> (Host, u16) {
match addr.ip() {
IpAddr::V4(ip) => (Host::Ipv4(ip), addr.port()),

View file

@ -1,6 +1,7 @@
use crate::protocols::tls;
use crate::tunnel::client::WsClientConfig;
use crate::tunnel::tls_reloader::TlsReloaderState::{Client, Server};
use crate::{WsClientConfig, WsServerConfig};
use crate::WsServerConfig;
use anyhow::Context;
use log::trace;
use notify::{EventKind, RecommendedWatcher, Watcher};

View file

@ -1,6 +1,6 @@
use crate::tunnel::client::WsClient;
use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, TransportScheme};
use crate::WsClientConfig;
use anyhow::{anyhow, Context};
use bytes::{Bytes, BytesMut};
use http_body_util::{BodyExt, BodyStream, StreamBody};
@ -99,22 +99,24 @@ impl TunnelWrite for Http2TunnelWrite {
pub async fn connect(
request_id: Uuid,
client_cfg: &WsClientConfig,
client: &WsClient,
dest_addr: &RemoteAddr,
) -> anyhow::Result<(Http2TunnelRead, Http2TunnelWrite, Parts)> {
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
let mut pooled_cnx = match client.cnx_pool.get().await {
Ok(cnx) => Ok(cnx),
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
}?;
// In http2 HOST header does not exist, it is explicitly set in the authority from the request uri
let (headers_file, authority) = client_cfg
let (headers_file, authority) =
client
.config
.http_headers_file
.as_ref()
.map_or((None, None), |headers_file_path| {
let (host, headers) = headers_from_file(headers_file_path);
let host = if let Some((_, v)) = host {
match (client_cfg.remote_addr.scheme(), client_cfg.remote_addr.port()) {
match (client.config.remote_addr.scheme(), client.config.remote_addr.port()) {
(TransportScheme::Http, 80) | (TransportScheme::Https, 443) => {
Some(v.to_str().unwrap_or("").to_string())
}
@ -131,23 +133,23 @@ pub async fn connect(
.method("POST")
.uri(format!(
"{}://{}/{}/events",
client_cfg.remote_addr.scheme(),
client.config.remote_addr.scheme(),
authority
.as_deref()
.unwrap_or_else(|| client_cfg.http_header_host.to_str().unwrap_or("")),
&client_cfg.http_upgrade_path_prefix
.unwrap_or_else(|| client.config.http_header_host.to_str().unwrap_or("")),
&client.config.http_upgrade_path_prefix
))
.header(COOKIE, tunnel_to_jwt_token(request_id, dest_addr))
.header(CONTENT_TYPE, "application/json")
.version(hyper::Version::HTTP_2);
let headers = req.headers_mut().unwrap();
for (k, v) in &client_cfg.http_headers {
for (k, v) in &client.config.http_headers {
let _ = headers.remove(k);
headers.append(k, v.clone());
}
if let Some(auth) = &client_cfg.http_upgrade_credentials {
if let Some(auth) = &client.config.http_upgrade_credentials {
let _ = headers.remove(AUTHORIZATION);
headers.append(AUTHORIZATION, auth.clone());
}
@ -164,7 +166,7 @@ pub async fn connect(
let req = req.body(body).with_context(|| {
format!(
"failed to build HTTP request to contact the server {:?}",
client_cfg.remote_addr
client.config.remote_addr
)
})?;
debug!("with HTTP upgrade request {:?}", req);
@ -172,11 +174,11 @@ pub async fn connect(
let (mut request_sender, cnx) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.adaptive_window(true)
.keep_alive_interval(client_cfg.websocket_ping_frequency)
.keep_alive_interval(client.config.websocket_ping_frequency)
.keep_alive_while_idle(false)
.handshake(TokioIo::new(transport))
.await
.with_context(|| format!("failed to do http2 handshake with the server {:?}", client_cfg.remote_addr))?;
.with_context(|| format!("failed to do http2 handshake with the server {:?}", client.config.remote_addr))?;
tokio::spawn(async move {
if let Err(err) = cnx.await {
error!("{:?}", err)
@ -186,7 +188,7 @@ pub async fn connect(
let response = request_sender
.send_request(req)
.await
.with_context(|| format!("failed to send http2 request with the server {:?}", client_cfg.remote_addr))?;
.with_context(|| format!("failed to send http2 request with the server {:?}", client.config.remote_addr))?;
if !response.status().is_success() {
return Err(anyhow!(

View file

@ -1,6 +1,6 @@
use crate::tunnel::client::WsClient;
use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, JWT_HEADER_PREFIX};
use crate::WsClientConfig;
use anyhow::{anyhow, Context};
use bytes::{Bytes, BytesMut};
use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
@ -135,10 +135,11 @@ impl TunnelRead for WebsocketTunnelRead {
pub async fn connect(
request_id: Uuid,
client_cfg: &WsClientConfig,
client: &WsClient,
dest_addr: &RemoteAddr,
) -> anyhow::Result<(WebsocketTunnelRead, WebsocketTunnelWrite, Parts)> {
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
let client_cfg = &client.config;
let mut pooled_cnx = match client.cnx_pool.get().await {
Ok(cnx) => Ok(cnx),
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
}?;