Refacto: Use proper type for WsClient
This commit is contained in:
parent
5e74ed233d
commit
a33a889b3d
11 changed files with 453 additions and 412 deletions
185
src/main.rs
185
src/main.rs
|
@ -3,15 +3,23 @@ mod protocols;
|
||||||
mod restrictions;
|
mod restrictions;
|
||||||
mod tunnel;
|
mod tunnel;
|
||||||
|
|
||||||
|
use crate::protocols::dns::DnsResolver;
|
||||||
|
use crate::protocols::tls;
|
||||||
|
use crate::restrictions::types::RestrictionsRules;
|
||||||
|
use crate::tunnel::client::{TlsClientConfig, WsClient, WsClientConfig};
|
||||||
|
use crate::tunnel::connectors::{Socks5TunnelConnector, TcpTunnelConnector, UdpTunnelConnector};
|
||||||
|
use crate::tunnel::listeners::{
|
||||||
|
new_stdio_listener, new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener,
|
||||||
|
};
|
||||||
|
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme};
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use hyper::header::HOST;
|
use hyper::header::HOST;
|
||||||
use hyper::http::{HeaderName, HeaderValue};
|
use hyper::http::{HeaderName, HeaderValue};
|
||||||
use log::debug;
|
use log::debug;
|
||||||
use once_cell::sync::Lazy;
|
|
||||||
use parking_lot::{Mutex, RwLock};
|
use parking_lot::{Mutex, RwLock};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::{BTreeMap, HashMap};
|
use std::collections::BTreeMap;
|
||||||
use std::fmt::{Debug, Formatter};
|
use std::fmt::{Debug, Formatter};
|
||||||
use std::io::ErrorKind;
|
use std::io::ErrorKind;
|
||||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||||
|
@ -21,21 +29,8 @@ use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use std::{fmt, io};
|
use std::{fmt, io};
|
||||||
use tokio::select;
|
use tokio::select;
|
||||||
|
use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer};
|
||||||
use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer, ServerName};
|
|
||||||
use tokio_rustls::TlsConnector;
|
|
||||||
|
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
|
|
||||||
use crate::protocols::dns::DnsResolver;
|
|
||||||
use crate::protocols::tls;
|
|
||||||
use crate::restrictions::types::RestrictionsRules;
|
|
||||||
use crate::tunnel::connectors::{Socks5TunnelConnector, TcpTunnelConnector, UdpTunnelConnector};
|
|
||||||
use crate::tunnel::listeners::{
|
|
||||||
new_stdio_listener, new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener,
|
|
||||||
};
|
|
||||||
use crate::tunnel::tls_reloader::TlsReloader;
|
|
||||||
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme};
|
|
||||||
use tracing_subscriber::filter::Directive;
|
use tracing_subscriber::filter::Directive;
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
use url::{Host, Url};
|
use url::{Host, Url};
|
||||||
|
@ -695,22 +690,6 @@ fn parse_server_url(arg: &str) -> Result<Url, io::Error> {
|
||||||
Ok(url)
|
Ok(url)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct TlsClientConfig {
|
|
||||||
pub tls_sni_disabled: bool,
|
|
||||||
pub tls_sni_override: Option<DnsName<'static>>,
|
|
||||||
pub tls_verify_certificate: bool,
|
|
||||||
tls_connector: Arc<RwLock<TlsConnector>>,
|
|
||||||
pub tls_certificate_path: Option<PathBuf>,
|
|
||||||
pub tls_key_path: Option<PathBuf>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TlsClientConfig {
|
|
||||||
pub fn tls_connector(&self) -> TlsConnector {
|
|
||||||
self.tls_connector.read().clone()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct TlsServerConfig {
|
pub struct TlsServerConfig {
|
||||||
pub tls_certificate: Mutex<Vec<CertificateDer<'static>>>,
|
pub tls_certificate: Mutex<Vec<CertificateDer<'static>>>,
|
||||||
|
@ -754,59 +733,6 @@ impl Debug for WsServerConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct WsClientConfig {
|
|
||||||
pub remote_addr: TransportAddr,
|
|
||||||
pub socket_so_mark: Option<u32>,
|
|
||||||
pub http_upgrade_path_prefix: String,
|
|
||||||
pub http_upgrade_credentials: Option<HeaderValue>,
|
|
||||||
pub http_headers: HashMap<HeaderName, HeaderValue>,
|
|
||||||
pub http_headers_file: Option<PathBuf>,
|
|
||||||
pub http_header_host: HeaderValue,
|
|
||||||
pub timeout_connect: Duration,
|
|
||||||
pub websocket_ping_frequency: Duration,
|
|
||||||
pub websocket_mask_frame: bool,
|
|
||||||
pub http_proxy: Option<Url>,
|
|
||||||
cnx_pool: Option<bb8::Pool<WsClientConfig>>,
|
|
||||||
tls_reloader: Option<Arc<TlsReloader>>,
|
|
||||||
pub dns_resolver: DnsResolver,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl WsClientConfig {
|
|
||||||
pub const fn websocket_scheme(&self) -> &'static str {
|
|
||||||
match self.remote_addr.tls().is_some() {
|
|
||||||
false => "ws",
|
|
||||||
true => "wss",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn cnx_pool(&self) -> &bb8::Pool<Self> {
|
|
||||||
self.cnx_pool.as_ref().unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn websocket_host_url(&self) -> String {
|
|
||||||
format!("{}:{}", self.remote_addr.host(), self.remote_addr.port())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn tls_server_name(&self) -> ServerName<'static> {
|
|
||||||
static INVALID_DNS_NAME: Lazy<DnsName> = Lazy::new(|| DnsName::try_from("dns-name-invalid.com").unwrap());
|
|
||||||
|
|
||||||
self.remote_addr
|
|
||||||
.tls()
|
|
||||||
.and_then(|tls| tls.tls_sni_override.as_ref())
|
|
||||||
.map_or_else(
|
|
||||||
|| match &self.remote_addr.host() {
|
|
||||||
Host::Domain(domain) => ServerName::DnsName(
|
|
||||||
DnsName::try_from(domain.clone()).unwrap_or_else(|_| INVALID_DNS_NAME.clone()),
|
|
||||||
),
|
|
||||||
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip).into()),
|
|
||||||
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip).into()),
|
|
||||||
},
|
|
||||||
|sni_override| ServerName::DnsName(sni_override.clone()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
let args = Wstunnel::parse();
|
let args = Wstunnel::parse();
|
||||||
|
@ -866,24 +792,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
TransportScheme::from_str(args.remote_addr.scheme()).expect("invalid scheme in server url");
|
TransportScheme::from_str(args.remote_addr.scheme()).expect("invalid scheme in server url");
|
||||||
let tls = match transport_scheme {
|
let tls = match transport_scheme {
|
||||||
TransportScheme::Ws | TransportScheme::Http => None,
|
TransportScheme::Ws | TransportScheme::Http => None,
|
||||||
TransportScheme::Wss => Some(TlsClientConfig {
|
TransportScheme::Wss | TransportScheme::Https => Some(TlsClientConfig {
|
||||||
tls_connector: Arc::new(RwLock::new(
|
|
||||||
tls::tls_connector(
|
|
||||||
args.tls_verify_certificate,
|
|
||||||
transport_scheme.alpn_protocols(),
|
|
||||||
!args.tls_sni_disable,
|
|
||||||
tls_certificate,
|
|
||||||
tls_key,
|
|
||||||
)
|
|
||||||
.expect("Cannot create tls connector"),
|
|
||||||
)),
|
|
||||||
tls_sni_override: args.tls_sni_override,
|
|
||||||
tls_verify_certificate: args.tls_verify_certificate,
|
|
||||||
tls_sni_disabled: args.tls_sni_disable,
|
|
||||||
tls_certificate_path: args.tls_certificate.clone(),
|
|
||||||
tls_key_path: args.tls_private_key.clone(),
|
|
||||||
}),
|
|
||||||
TransportScheme::Https => Some(TlsClientConfig {
|
|
||||||
tls_connector: Arc::new(RwLock::new(
|
tls_connector: Arc::new(RwLock::new(
|
||||||
tls::tls_connector(
|
tls::tls_connector(
|
||||||
args.tls_verify_certificate,
|
args.tls_verify_certificate,
|
||||||
|
@ -936,7 +845,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let mut client_config = WsClientConfig {
|
let client_config = WsClientConfig {
|
||||||
remote_addr: TransportAddr::new(
|
remote_addr: TransportAddr::new(
|
||||||
TransportScheme::from_str(args.remote_addr.scheme()).unwrap(),
|
TransportScheme::from_str(args.remote_addr.scheme()).unwrap(),
|
||||||
args.remote_addr.host().unwrap().to_owned(),
|
args.remote_addr.host().unwrap().to_owned(),
|
||||||
|
@ -953,8 +862,6 @@ async fn main() -> anyhow::Result<()> {
|
||||||
timeout_connect: Duration::from_secs(10),
|
timeout_connect: Duration::from_secs(10),
|
||||||
websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)),
|
websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)),
|
||||||
websocket_mask_frame: args.websocket_mask_frame,
|
websocket_mask_frame: args.websocket_mask_frame,
|
||||||
cnx_pool: None,
|
|
||||||
tls_reloader: None,
|
|
||||||
dns_resolver: DnsResolver::new_from_urls(
|
dns_resolver: DnsResolver::new_from_urls(
|
||||||
&args.dns_resolver,
|
&args.dns_resolver,
|
||||||
http_proxy.clone(),
|
http_proxy.clone(),
|
||||||
|
@ -965,28 +872,16 @@ async fn main() -> anyhow::Result<()> {
|
||||||
http_proxy,
|
http_proxy,
|
||||||
};
|
};
|
||||||
|
|
||||||
let tls_reloader =
|
let client =
|
||||||
TlsReloader::new_for_client(Arc::new(client_config.clone())).expect("Cannot create tls reloader");
|
WsClient::new(client_config, args.connection_min_idle, args.connection_retry_max_backoff_sec).await?;
|
||||||
client_config.tls_reloader = Some(Arc::new(tls_reloader));
|
|
||||||
let pool = bb8::Pool::builder()
|
|
||||||
.max_size(1000)
|
|
||||||
.min_idle(Some(args.connection_min_idle))
|
|
||||||
.max_lifetime(Some(Duration::from_secs(30)))
|
|
||||||
.connection_timeout(args.connection_retry_max_backoff_sec)
|
|
||||||
.retry_connection(true)
|
|
||||||
.build(client_config.clone())
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
client_config.cnx_pool = Some(pool);
|
|
||||||
let client_config = Arc::new(client_config);
|
|
||||||
|
|
||||||
// Start tunnels
|
// Start tunnels
|
||||||
for tunnel in args.remote_to_local.into_iter() {
|
for tunnel in args.remote_to_local.into_iter() {
|
||||||
let client_config = client_config.clone();
|
let client = client.clone();
|
||||||
match &tunnel.local_protocol {
|
match &tunnel.local_protocol {
|
||||||
LocalProtocol::Tcp { proxy_protocol: _ } => {
|
LocalProtocol::Tcp { proxy_protocol: _ } => {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let cfg = client_config.clone();
|
let cfg = client.config.clone();
|
||||||
let tcp_connector = TcpTunnelConnector::new(
|
let tcp_connector = TcpTunnelConnector::new(
|
||||||
&tunnel.remote.0,
|
&tunnel.remote.0,
|
||||||
tunnel.remote.1,
|
tunnel.remote.1,
|
||||||
|
@ -1000,9 +895,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
host,
|
host,
|
||||||
port,
|
port,
|
||||||
};
|
};
|
||||||
if let Err(err) =
|
if let Err(err) = client.run_reverse_tunnel(remote, tcp_connector).await {
|
||||||
tunnel::client::run_reverse_tunnel(client_config, remote, tcp_connector).await
|
|
||||||
{
|
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -1011,7 +904,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
let timeout = *timeout;
|
let timeout = *timeout;
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let cfg = client_config.clone();
|
let cfg = client.config.clone();
|
||||||
let (host, port) = to_host_port(tunnel.local);
|
let (host, port) = to_host_port(tunnel.local);
|
||||||
let remote = RemoteAddr {
|
let remote = RemoteAddr {
|
||||||
protocol: LocalProtocol::ReverseUdp { timeout },
|
protocol: LocalProtocol::ReverseUdp { timeout },
|
||||||
|
@ -1026,9 +919,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
&cfg.dns_resolver,
|
&cfg.dns_resolver,
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Err(err) =
|
if let Err(err) = client.run_reverse_tunnel(remote.clone(), udp_connector).await {
|
||||||
tunnel::client::run_reverse_tunnel(client_config, remote.clone(), udp_connector).await
|
|
||||||
{
|
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -1037,7 +928,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
let credentials = credentials.clone();
|
let credentials = credentials.clone();
|
||||||
let timeout = *timeout;
|
let timeout = *timeout;
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let cfg = client_config.clone();
|
let cfg = client.config.clone();
|
||||||
let (host, port) = to_host_port(tunnel.local);
|
let (host, port) = to_host_port(tunnel.local);
|
||||||
let remote = RemoteAddr {
|
let remote = RemoteAddr {
|
||||||
protocol: LocalProtocol::ReverseSocks5 { timeout, credentials },
|
protocol: LocalProtocol::ReverseSocks5 { timeout, credentials },
|
||||||
|
@ -1047,9 +938,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
let socks_connector =
|
let socks_connector =
|
||||||
Socks5TunnelConnector::new(cfg.socket_so_mark, cfg.timeout_connect, &cfg.dns_resolver);
|
Socks5TunnelConnector::new(cfg.socket_so_mark, cfg.timeout_connect, &cfg.dns_resolver);
|
||||||
|
|
||||||
if let Err(err) =
|
if let Err(err) = client.run_reverse_tunnel(remote, socks_connector).await {
|
||||||
tunnel::client::run_reverse_tunnel(client_config, remote, socks_connector).await
|
|
||||||
{
|
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -1060,7 +949,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
let credentials = credentials.clone();
|
let credentials = credentials.clone();
|
||||||
let timeout = *timeout;
|
let timeout = *timeout;
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let cfg = client_config.clone();
|
let cfg = client.config.clone();
|
||||||
let (host, port) = to_host_port(tunnel.local);
|
let (host, port) = to_host_port(tunnel.local);
|
||||||
let remote = RemoteAddr {
|
let remote = RemoteAddr {
|
||||||
protocol: LocalProtocol::ReverseHttpProxy { timeout, credentials },
|
protocol: LocalProtocol::ReverseHttpProxy { timeout, credentials },
|
||||||
|
@ -1075,9 +964,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
&cfg.dns_resolver,
|
&cfg.dns_resolver,
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Err(err) =
|
if let Err(err) = client.run_reverse_tunnel(remote.clone(), tcp_connector).await {
|
||||||
tunnel::client::run_reverse_tunnel(client_config, remote.clone(), tcp_connector).await
|
|
||||||
{
|
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -1086,7 +973,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
LocalProtocol::Unix { path } => {
|
LocalProtocol::Unix { path } => {
|
||||||
let path = path.clone();
|
let path = path.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let cfg = client_config.clone();
|
let cfg = client.config.clone();
|
||||||
let tcp_connector = TcpTunnelConnector::new(
|
let tcp_connector = TcpTunnelConnector::new(
|
||||||
&tunnel.remote.0,
|
&tunnel.remote.0,
|
||||||
tunnel.remote.1,
|
tunnel.remote.1,
|
||||||
|
@ -1101,9 +988,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
host,
|
host,
|
||||||
port,
|
port,
|
||||||
};
|
};
|
||||||
if let Err(err) =
|
if let Err(err) = client.run_reverse_tunnel(remote, tcp_connector).await {
|
||||||
tunnel::client::run_reverse_tunnel(client_config, remote, tcp_connector).await
|
|
||||||
{
|
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -1126,14 +1011,14 @@ async fn main() -> anyhow::Result<()> {
|
||||||
}
|
}
|
||||||
|
|
||||||
for tunnel in args.local_to_remote.into_iter() {
|
for tunnel in args.local_to_remote.into_iter() {
|
||||||
let client_config = client_config.clone();
|
let client = client.clone();
|
||||||
|
|
||||||
match &tunnel.local_protocol {
|
match &tunnel.local_protocol {
|
||||||
LocalProtocol::Tcp { proxy_protocol } => {
|
LocalProtocol::Tcp { proxy_protocol } => {
|
||||||
let server =
|
let server =
|
||||||
TcpTunnelListener::new(tunnel.local, tunnel.remote.clone(), *proxy_protocol).await?;
|
TcpTunnelListener::new(tunnel.local, tunnel.remote.clone(), *proxy_protocol).await?;
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
if let Err(err) = client.run_tunnel(server).await {
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -1144,7 +1029,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
let server = TproxyTcpTunnelListener::new(tunnel.local, false).await?;
|
let server = TproxyTcpTunnelListener::new(tunnel.local, false).await?;
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
if let Err(err) = client.run_tunnel(server).await {
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -1154,7 +1039,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
use crate::tunnel::listeners::UnixTunnelListener;
|
use crate::tunnel::listeners::UnixTunnelListener;
|
||||||
let server = UnixTunnelListener::new(path, tunnel.remote.clone(), false).await?; // TODO: support proxy protocol
|
let server = UnixTunnelListener::new(path, tunnel.remote.clone(), false).await?; // TODO: support proxy protocol
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
if let Err(err) = client.run_tunnel(server).await {
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -1169,7 +1054,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
use crate::tunnel::listeners::new_tproxy_udp;
|
use crate::tunnel::listeners::new_tproxy_udp;
|
||||||
let server = new_tproxy_udp(tunnel.local, *timeout).await?;
|
let server = new_tproxy_udp(tunnel.local, *timeout).await?;
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
if let Err(err) = client.run_tunnel(server).await {
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -1182,7 +1067,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
let server = new_udp_listener(tunnel.local, tunnel.remote.clone(), *timeout).await?;
|
let server = new_udp_listener(tunnel.local, tunnel.remote.clone(), *timeout).await?;
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
if let Err(err) = client.run_tunnel(server).await {
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -1190,7 +1075,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
LocalProtocol::Socks5 { timeout, credentials } => {
|
LocalProtocol::Socks5 { timeout, credentials } => {
|
||||||
let server = Socks5TunnelListener::new(tunnel.local, *timeout, credentials.clone()).await?;
|
let server = Socks5TunnelListener::new(tunnel.local, *timeout, credentials.clone()).await?;
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
if let Err(err) = client.run_tunnel(server).await {
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -1204,7 +1089,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
HttpProxyTunnelListener::new(tunnel.local, *timeout, credentials.clone(), *proxy_protocol)
|
HttpProxyTunnelListener::new(tunnel.local, *timeout, credentials.clone(), *proxy_protocol)
|
||||||
.await?;
|
.await?;
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
if let Err(err) = client.run_tunnel(server).await {
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -1213,7 +1098,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
LocalProtocol::Stdio => {
|
LocalProtocol::Stdio => {
|
||||||
let (server, mut handle) = new_stdio_listener(tunnel.remote.clone(), false).await?; // TODO: support proxy protocol
|
let (server, mut handle) = new_stdio_listener(tunnel.remote.clone(), false).await?; // TODO: support proxy protocol
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
if let Err(err) = client.run_tunnel(server).await {
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use crate::{TlsServerConfig, WsClientConfig};
|
use crate::TlsServerConfig;
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::{anyhow, Context};
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@ use std::sync::Arc;
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio_rustls::client::TlsStream;
|
use tokio_rustls::client::TlsStream;
|
||||||
|
|
||||||
|
use crate::tunnel::client::WsClientConfig;
|
||||||
use crate::tunnel::TransportAddr;
|
use crate::tunnel::TransportAddr;
|
||||||
use tokio_rustls::rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
|
use tokio_rustls::rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
|
||||||
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
|
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
|
||||||
|
|
|
@ -1,176 +0,0 @@
|
||||||
use super::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE};
|
|
||||||
use crate::tunnel::connectors::TunnelConnector;
|
|
||||||
use crate::tunnel::listeners::TunnelListener;
|
|
||||||
use crate::tunnel::transport::{TunnelReader, TunnelWriter};
|
|
||||||
use crate::{tunnel, WsClientConfig};
|
|
||||||
use futures_util::pin_mut;
|
|
||||||
use hyper::header::COOKIE;
|
|
||||||
use jsonwebtoken::TokenData;
|
|
||||||
use log::debug;
|
|
||||||
use std::ops::Deref;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
|
||||||
use tokio::sync::oneshot;
|
|
||||||
use tokio_stream::StreamExt;
|
|
||||||
use tracing::{error, event, span, Instrument, Level, Span};
|
|
||||||
use url::Host;
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
async fn connect_to_server<R, W>(
|
|
||||||
request_id: Uuid,
|
|
||||||
client_cfg: &WsClientConfig,
|
|
||||||
remote_cfg: &RemoteAddr,
|
|
||||||
duplex_stream: (R, W),
|
|
||||||
) -> anyhow::Result<()>
|
|
||||||
where
|
|
||||||
R: AsyncRead + Send + 'static,
|
|
||||||
W: AsyncWrite + Send + 'static,
|
|
||||||
{
|
|
||||||
// Connect to server with the correct protocol
|
|
||||||
let (ws_rx, ws_tx, response) = match client_cfg.remote_addr.scheme() {
|
|
||||||
TransportScheme::Ws | TransportScheme::Wss => {
|
|
||||||
tunnel::transport::websocket::connect(request_id, client_cfg, remote_cfg)
|
|
||||||
.await
|
|
||||||
.map(|(r, w, response)| (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response))?
|
|
||||||
}
|
|
||||||
TransportScheme::Http | TransportScheme::Https => {
|
|
||||||
tunnel::transport::http2::connect(request_id, client_cfg, remote_cfg)
|
|
||||||
.await
|
|
||||||
.map(|(r, w, response)| (TunnelReader::Http2(r), TunnelWriter::Http2(w), response))?
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
debug!("Server response: {:?}", response);
|
|
||||||
let (local_rx, local_tx) = duplex_stream;
|
|
||||||
let (close_tx, close_rx) = oneshot::channel::<()>();
|
|
||||||
|
|
||||||
// Forward local tx to websocket tx
|
|
||||||
let ping_frequency = client_cfg.websocket_ping_frequency;
|
|
||||||
tokio::spawn(
|
|
||||||
super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency))
|
|
||||||
.instrument(Span::current()),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Forward websocket rx to local rx
|
|
||||||
let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn run_tunnel(client_config: Arc<WsClientConfig>, incoming_cnx: impl TunnelListener) -> anyhow::Result<()> {
|
|
||||||
pin_mut!(incoming_cnx);
|
|
||||||
while let Some(cnx) = incoming_cnx.next().await {
|
|
||||||
let (cnx_stream, remote_addr) = match cnx {
|
|
||||||
Ok((cnx_stream, remote_addr)) => (cnx_stream, remote_addr),
|
|
||||||
Err(err) => {
|
|
||||||
error!("Error accepting connection: {:?}", err);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let request_id = Uuid::now_v7();
|
|
||||||
let span = span!(
|
|
||||||
Level::INFO,
|
|
||||||
"tunnel",
|
|
||||||
id = request_id.to_string(),
|
|
||||||
remote = format!("{}:{}", remote_addr.host, remote_addr.port)
|
|
||||||
);
|
|
||||||
let client_config = client_config.clone();
|
|
||||||
|
|
||||||
let tunnel = async move {
|
|
||||||
let _ = connect_to_server(request_id, &client_config, &remote_addr, cnx_stream)
|
|
||||||
.await
|
|
||||||
.map_err(|err| error!("{:?}", err));
|
|
||||||
}
|
|
||||||
.instrument(span);
|
|
||||||
|
|
||||||
tokio::spawn(tunnel);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn run_reverse_tunnel(
|
|
||||||
client_cfg: Arc<WsClientConfig>,
|
|
||||||
remote_addr: RemoteAddr,
|
|
||||||
connector: impl TunnelConnector,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
loop {
|
|
||||||
let client_config = client_cfg.clone();
|
|
||||||
let request_id = Uuid::now_v7();
|
|
||||||
let span = span!(
|
|
||||||
Level::INFO,
|
|
||||||
"tunnel",
|
|
||||||
id = request_id.to_string(),
|
|
||||||
remote = format!("{}:{}", remote_addr.host, remote_addr.port)
|
|
||||||
);
|
|
||||||
// Correctly configure tunnel cfg
|
|
||||||
let (ws_rx, ws_tx, response) = match client_cfg.remote_addr.scheme() {
|
|
||||||
TransportScheme::Ws | TransportScheme::Wss => {
|
|
||||||
match tunnel::transport::websocket::connect(request_id, &client_cfg, &remote_addr)
|
|
||||||
.instrument(span.clone())
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok((r, w, response)) => (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response),
|
|
||||||
Err(err) => {
|
|
||||||
event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err);
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
TransportScheme::Http | TransportScheme::Https => {
|
|
||||||
match tunnel::transport::http2::connect(request_id, &client_cfg, &remote_addr)
|
|
||||||
.instrument(span.clone())
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok((r, w, response)) => (TunnelReader::Http2(r), TunnelWriter::Http2(w), response),
|
|
||||||
Err(err) => {
|
|
||||||
event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err);
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Connect to endpoint
|
|
||||||
event!(parent: &span, Level::DEBUG, "Server response: {:?}", response);
|
|
||||||
let remote = response
|
|
||||||
.headers
|
|
||||||
.get(COOKIE)
|
|
||||||
.and_then(|h| h.to_str().ok())
|
|
||||||
.and_then(|h| {
|
|
||||||
let (validation, decode_key) = JWT_DECODE.deref();
|
|
||||||
let jwt: Option<TokenData<JwtTunnelConfig>> = jsonwebtoken::decode(h, decode_key, validation).ok();
|
|
||||||
jwt
|
|
||||||
})
|
|
||||||
.map(|jwt| RemoteAddr {
|
|
||||||
protocol: jwt.claims.p,
|
|
||||||
host: Host::parse(&jwt.claims.r).unwrap_or_else(|_| Host::Domain(String::new())),
|
|
||||||
port: jwt.claims.rp,
|
|
||||||
});
|
|
||||||
|
|
||||||
let (local_rx, local_tx) = match connector.connect(&remote).instrument(span.clone()).await {
|
|
||||||
Ok(s) => s,
|
|
||||||
Err(err) => {
|
|
||||||
event!(parent: &span, Level::ERROR, "Cannot connect to {remote:?}: {err:?}");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let (close_tx, close_rx) = oneshot::channel::<()>();
|
|
||||||
let tunnel = async move {
|
|
||||||
let ping_frequency = client_config.websocket_ping_frequency;
|
|
||||||
tokio::spawn(
|
|
||||||
super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency))
|
|
||||||
.in_current_span(),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Forward websocket rx to local rx
|
|
||||||
let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
|
|
||||||
}
|
|
||||||
.instrument(span.clone());
|
|
||||||
tokio::spawn(tunnel);
|
|
||||||
}
|
|
||||||
}
|
|
221
src/tunnel/client/client.rs
Normal file
221
src/tunnel/client/client.rs
Normal file
|
@ -0,0 +1,221 @@
|
||||||
|
use crate::tunnel;
|
||||||
|
use crate::tunnel::client::cnx_pool::WsConnection;
|
||||||
|
use crate::tunnel::client::WsClientConfig;
|
||||||
|
use crate::tunnel::connectors::TunnelConnector;
|
||||||
|
use crate::tunnel::listeners::TunnelListener;
|
||||||
|
use crate::tunnel::tls_reloader::TlsReloader;
|
||||||
|
use crate::tunnel::transport::{TunnelReader, TunnelWriter};
|
||||||
|
use crate::tunnel::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE};
|
||||||
|
use anyhow::Context;
|
||||||
|
use futures_util::pin_mut;
|
||||||
|
use hyper::header::COOKIE;
|
||||||
|
use jsonwebtoken::TokenData;
|
||||||
|
use log::debug;
|
||||||
|
use std::ops::Deref;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
use tokio::sync::oneshot;
|
||||||
|
use tokio_stream::StreamExt;
|
||||||
|
use tracing::{error, event, span, Instrument, Level, Span};
|
||||||
|
use url::Host;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct WsClient {
|
||||||
|
pub config: Arc<WsClientConfig>,
|
||||||
|
pub cnx_pool: bb8::Pool<WsConnection>,
|
||||||
|
_tls_reloader: Arc<TlsReloader>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WsClient {
|
||||||
|
pub async fn new(
|
||||||
|
config: WsClientConfig,
|
||||||
|
connection_min_idle: u32,
|
||||||
|
connection_retry_max_backoff_sec: Duration,
|
||||||
|
) -> anyhow::Result<Self> {
|
||||||
|
let config = Arc::new(config);
|
||||||
|
let cnx = WsConnection::new(config.clone());
|
||||||
|
let tls_reloader = TlsReloader::new_for_client(config.clone()).with_context(|| "Cannot create tls reloader")?;
|
||||||
|
let cnx_pool = bb8::Pool::builder()
|
||||||
|
.max_size(1000)
|
||||||
|
.min_idle(Some(connection_min_idle))
|
||||||
|
.max_lifetime(Some(Duration::from_secs(30)))
|
||||||
|
.connection_timeout(connection_retry_max_backoff_sec)
|
||||||
|
.retry_connection(true)
|
||||||
|
.build(cnx)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
config,
|
||||||
|
cnx_pool,
|
||||||
|
_tls_reloader: Arc::new(tls_reloader),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WsClient {
|
||||||
|
async fn connect_to_server<R, W>(
|
||||||
|
&self,
|
||||||
|
request_id: Uuid,
|
||||||
|
remote_cfg: &RemoteAddr,
|
||||||
|
duplex_stream: (R, W),
|
||||||
|
) -> anyhow::Result<()>
|
||||||
|
where
|
||||||
|
R: AsyncRead + Send + 'static,
|
||||||
|
W: AsyncWrite + Send + 'static,
|
||||||
|
{
|
||||||
|
// Connect to server with the correct protocol
|
||||||
|
let (ws_rx, ws_tx, response) = match self.config.remote_addr.scheme() {
|
||||||
|
TransportScheme::Ws | TransportScheme::Wss => {
|
||||||
|
tunnel::transport::websocket::connect(request_id, self, remote_cfg)
|
||||||
|
.await
|
||||||
|
.map(|(r, w, response)| (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response))?
|
||||||
|
}
|
||||||
|
TransportScheme::Http | TransportScheme::Https => {
|
||||||
|
tunnel::transport::http2::connect(request_id, self, remote_cfg)
|
||||||
|
.await
|
||||||
|
.map(|(r, w, response)| (TunnelReader::Http2(r), TunnelWriter::Http2(w), response))?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!("Server response: {:?}", response);
|
||||||
|
let (local_rx, local_tx) = duplex_stream;
|
||||||
|
let (close_tx, close_rx) = oneshot::channel::<()>();
|
||||||
|
|
||||||
|
// Forward local tx to websocket tx
|
||||||
|
let ping_frequency = self.config.websocket_ping_frequency;
|
||||||
|
tokio::spawn(
|
||||||
|
super::super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency))
|
||||||
|
.instrument(Span::current()),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Forward websocket rx to local rx
|
||||||
|
let _ = super::super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_tunnel(self, tunnel_listener: impl TunnelListener) -> anyhow::Result<()> {
|
||||||
|
pin_mut!(tunnel_listener);
|
||||||
|
while let Some(cnx) = tunnel_listener.next().await {
|
||||||
|
let (cnx_stream, remote_addr) = match cnx {
|
||||||
|
Ok((cnx_stream, remote_addr)) => (cnx_stream, remote_addr),
|
||||||
|
Err(err) => {
|
||||||
|
error!("Error accepting connection: {:?}", err);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let request_id = Uuid::now_v7();
|
||||||
|
let span = span!(
|
||||||
|
Level::INFO,
|
||||||
|
"tunnel",
|
||||||
|
id = request_id.to_string(),
|
||||||
|
remote = format!("{}:{}", remote_addr.host, remote_addr.port)
|
||||||
|
);
|
||||||
|
let client = self.clone();
|
||||||
|
let tunnel = async move {
|
||||||
|
let _ = client
|
||||||
|
.connect_to_server(request_id, &remote_addr, cnx_stream)
|
||||||
|
.await
|
||||||
|
.map_err(|err| error!("{:?}", err));
|
||||||
|
}
|
||||||
|
.instrument(span);
|
||||||
|
|
||||||
|
tokio::spawn(tunnel);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_reverse_tunnel(
|
||||||
|
self,
|
||||||
|
remote_addr: RemoteAddr,
|
||||||
|
connector: impl TunnelConnector,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
loop {
|
||||||
|
let client = self.clone();
|
||||||
|
let request_id = Uuid::now_v7();
|
||||||
|
let span = span!(
|
||||||
|
Level::INFO,
|
||||||
|
"tunnel",
|
||||||
|
id = request_id.to_string(),
|
||||||
|
remote = format!("{}:{}", remote_addr.host, remote_addr.port)
|
||||||
|
);
|
||||||
|
// Correctly configure tunnel cfg
|
||||||
|
let (ws_rx, ws_tx, response) = match client.config.remote_addr.scheme() {
|
||||||
|
TransportScheme::Ws | TransportScheme::Wss => {
|
||||||
|
match tunnel::transport::websocket::connect(request_id, &client, &remote_addr)
|
||||||
|
.instrument(span.clone())
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok((r, w, response)) => (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response),
|
||||||
|
Err(err) => {
|
||||||
|
event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err);
|
||||||
|
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TransportScheme::Http | TransportScheme::Https => {
|
||||||
|
match tunnel::transport::http2::connect(request_id, &client, &remote_addr)
|
||||||
|
.instrument(span.clone())
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok((r, w, response)) => (TunnelReader::Http2(r), TunnelWriter::Http2(w), response),
|
||||||
|
Err(err) => {
|
||||||
|
event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err);
|
||||||
|
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Connect to endpoint
|
||||||
|
event!(parent: &span, Level::DEBUG, "Server response: {:?}", response);
|
||||||
|
let remote = response
|
||||||
|
.headers
|
||||||
|
.get(COOKIE)
|
||||||
|
.and_then(|h| h.to_str().ok())
|
||||||
|
.and_then(|h| {
|
||||||
|
let (validation, decode_key) = JWT_DECODE.deref();
|
||||||
|
let jwt: Option<TokenData<JwtTunnelConfig>> = jsonwebtoken::decode(h, decode_key, validation).ok();
|
||||||
|
jwt
|
||||||
|
})
|
||||||
|
.map(|jwt| RemoteAddr {
|
||||||
|
protocol: jwt.claims.p,
|
||||||
|
host: Host::parse(&jwt.claims.r).unwrap_or_else(|_| Host::Domain(String::new())),
|
||||||
|
port: jwt.claims.rp,
|
||||||
|
});
|
||||||
|
|
||||||
|
let (local_rx, local_tx) = match connector.connect(&remote).instrument(span.clone()).await {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(err) => {
|
||||||
|
event!(parent: &span, Level::ERROR, "Cannot connect to {remote:?}: {err:?}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let (close_tx, close_rx) = oneshot::channel::<()>();
|
||||||
|
let tunnel = async move {
|
||||||
|
let ping_frequency = client.config.websocket_ping_frequency;
|
||||||
|
tokio::spawn(
|
||||||
|
super::super::transport::io::propagate_local_to_remote(
|
||||||
|
local_rx,
|
||||||
|
ws_tx,
|
||||||
|
close_tx,
|
||||||
|
Some(ping_frequency),
|
||||||
|
)
|
||||||
|
.in_current_span(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Forward websocket rx to local rx
|
||||||
|
let _ = super::super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
|
||||||
|
}
|
||||||
|
.instrument(span.clone());
|
||||||
|
tokio::spawn(tunnel);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
74
src/tunnel/client/cnx_pool.rs
Normal file
74
src/tunnel/client/cnx_pool.rs
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
use crate::protocols;
|
||||||
|
use crate::protocols::tls;
|
||||||
|
use crate::tunnel::client::WsClientConfig;
|
||||||
|
use crate::tunnel::TransportStream;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use bb8::ManageConnection;
|
||||||
|
use std::ops::Deref;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tracing::instrument;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct WsConnection(Arc<WsClientConfig>);
|
||||||
|
|
||||||
|
impl WsConnection {
|
||||||
|
pub fn new(config: Arc<WsClientConfig>) -> Self {
|
||||||
|
Self(config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Deref for WsConnection {
|
||||||
|
type Target = WsClientConfig;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ManageConnection for WsConnection {
|
||||||
|
type Connection = Option<TransportStream>;
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
|
#[instrument(level = "trace", name = "cnx_server", skip_all)]
|
||||||
|
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
|
||||||
|
let so_mark = self.socket_so_mark;
|
||||||
|
let timeout = self.timeout_connect;
|
||||||
|
|
||||||
|
let tcp_stream = if let Some(http_proxy) = &self.http_proxy {
|
||||||
|
protocols::tcp::connect_with_http_proxy(
|
||||||
|
http_proxy,
|
||||||
|
self.remote_addr.host(),
|
||||||
|
self.remote_addr.port(),
|
||||||
|
so_mark,
|
||||||
|
timeout,
|
||||||
|
&self.dns_resolver,
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
} else {
|
||||||
|
protocols::tcp::connect(
|
||||||
|
self.remote_addr.host(),
|
||||||
|
self.remote_addr.port(),
|
||||||
|
so_mark,
|
||||||
|
timeout,
|
||||||
|
&self.dns_resolver,
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.remote_addr.tls().is_some() {
|
||||||
|
let tls_stream = tls::connect(self, tcp_stream).await?;
|
||||||
|
Ok(Some(TransportStream::Tls(tls_stream)))
|
||||||
|
} else {
|
||||||
|
Ok(Some(TransportStream::Plain(tcp_stream)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
|
||||||
|
conn.is_none()
|
||||||
|
}
|
||||||
|
}
|
76
src/tunnel/client/config.rs
Normal file
76
src/tunnel/client/config.rs
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
use crate::protocols::dns::DnsResolver;
|
||||||
|
use crate::tunnel::TransportAddr;
|
||||||
|
use hyper::header::{HeaderName, HeaderValue};
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use parking_lot::RwLock;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio_rustls::rustls::pki_types::{DnsName, ServerName};
|
||||||
|
use tokio_rustls::TlsConnector;
|
||||||
|
use url::{Host, Url};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct WsClientConfig {
|
||||||
|
pub remote_addr: TransportAddr,
|
||||||
|
pub socket_so_mark: Option<u32>,
|
||||||
|
pub http_upgrade_path_prefix: String,
|
||||||
|
pub http_upgrade_credentials: Option<HeaderValue>,
|
||||||
|
pub http_headers: HashMap<HeaderName, HeaderValue>,
|
||||||
|
pub http_headers_file: Option<PathBuf>,
|
||||||
|
pub http_header_host: HeaderValue,
|
||||||
|
pub timeout_connect: Duration,
|
||||||
|
pub websocket_ping_frequency: Duration,
|
||||||
|
pub websocket_mask_frame: bool,
|
||||||
|
pub http_proxy: Option<Url>,
|
||||||
|
pub dns_resolver: DnsResolver,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WsClientConfig {
|
||||||
|
pub const fn websocket_scheme(&self) -> &'static str {
|
||||||
|
match self.remote_addr.tls().is_some() {
|
||||||
|
false => "ws",
|
||||||
|
true => "wss",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn websocket_host_url(&self) -> String {
|
||||||
|
format!("{}:{}", self.remote_addr.host(), self.remote_addr.port())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tls_server_name(&self) -> ServerName<'static> {
|
||||||
|
static INVALID_DNS_NAME: Lazy<DnsName> = Lazy::new(|| DnsName::try_from("dns-name-invalid.com").unwrap());
|
||||||
|
|
||||||
|
self.remote_addr
|
||||||
|
.tls()
|
||||||
|
.and_then(|tls| tls.tls_sni_override.as_ref())
|
||||||
|
.map_or_else(
|
||||||
|
|| match &self.remote_addr.host() {
|
||||||
|
Host::Domain(domain) => ServerName::DnsName(
|
||||||
|
DnsName::try_from(domain.clone()).unwrap_or_else(|_| INVALID_DNS_NAME.clone()),
|
||||||
|
),
|
||||||
|
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip).into()),
|
||||||
|
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip).into()),
|
||||||
|
},
|
||||||
|
|sni_override| ServerName::DnsName(sni_override.clone()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct TlsClientConfig {
|
||||||
|
pub tls_sni_disabled: bool,
|
||||||
|
pub tls_sni_override: Option<DnsName<'static>>,
|
||||||
|
pub tls_verify_certificate: bool,
|
||||||
|
pub tls_connector: Arc<RwLock<TlsConnector>>,
|
||||||
|
pub tls_certificate_path: Option<PathBuf>,
|
||||||
|
pub tls_key_path: Option<PathBuf>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TlsClientConfig {
|
||||||
|
pub fn tls_connector(&self) -> TlsConnector {
|
||||||
|
self.tls_connector.read().clone()
|
||||||
|
}
|
||||||
|
}
|
8
src/tunnel/client/mod.rs
Normal file
8
src/tunnel/client/mod.rs
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
#![allow(clippy::module_inception)]
|
||||||
|
mod client;
|
||||||
|
mod cnx_pool;
|
||||||
|
mod config;
|
||||||
|
|
||||||
|
pub use client::WsClient;
|
||||||
|
pub use config::TlsClientConfig;
|
||||||
|
pub use config::WsClientConfig;
|
|
@ -5,10 +5,7 @@ pub mod server;
|
||||||
pub mod tls_reloader;
|
pub mod tls_reloader;
|
||||||
mod transport;
|
mod transport;
|
||||||
|
|
||||||
use crate::protocols::tls;
|
use crate::{LocalProtocol, TlsClientConfig};
|
||||||
use crate::{protocols, LocalProtocol, TlsClientConfig, WsClientConfig};
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use bb8::ManageConnection;
|
|
||||||
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
|
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -23,7 +20,6 @@ use std::task::{Context, Poll};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio_rustls::client::TlsStream;
|
use tokio_rustls::client::TlsStream;
|
||||||
use tracing::instrument;
|
|
||||||
use url::Host;
|
use url::Host;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
@ -304,54 +300,6 @@ impl AsyncWrite for TransportStream {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl ManageConnection for WsClientConfig {
|
|
||||||
type Connection = Option<TransportStream>;
|
|
||||||
type Error = anyhow::Error;
|
|
||||||
|
|
||||||
#[instrument(level = "trace", name = "cnx_server", skip_all)]
|
|
||||||
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
|
|
||||||
let so_mark = self.socket_so_mark;
|
|
||||||
let timeout = self.timeout_connect;
|
|
||||||
|
|
||||||
let tcp_stream = if let Some(http_proxy) = &self.http_proxy {
|
|
||||||
protocols::tcp::connect_with_http_proxy(
|
|
||||||
http_proxy,
|
|
||||||
self.remote_addr.host(),
|
|
||||||
self.remote_addr.port(),
|
|
||||||
so_mark,
|
|
||||||
timeout,
|
|
||||||
&self.dns_resolver,
|
|
||||||
)
|
|
||||||
.await?
|
|
||||||
} else {
|
|
||||||
protocols::tcp::connect(
|
|
||||||
self.remote_addr.host(),
|
|
||||||
self.remote_addr.port(),
|
|
||||||
so_mark,
|
|
||||||
timeout,
|
|
||||||
&self.dns_resolver,
|
|
||||||
)
|
|
||||||
.await?
|
|
||||||
};
|
|
||||||
|
|
||||||
if self.remote_addr.tls().is_some() {
|
|
||||||
let tls_stream = tls::connect(self, tcp_stream).await?;
|
|
||||||
Ok(Some(TransportStream::Tls(tls_stream)))
|
|
||||||
} else {
|
|
||||||
Ok(Some(TransportStream::Plain(tcp_stream)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
|
|
||||||
conn.is_none()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn to_host_port(addr: SocketAddr) -> (Host, u16) {
|
pub fn to_host_port(addr: SocketAddr) -> (Host, u16) {
|
||||||
match addr.ip() {
|
match addr.ip() {
|
||||||
IpAddr::V4(ip) => (Host::Ipv4(ip), addr.port()),
|
IpAddr::V4(ip) => (Host::Ipv4(ip), addr.port()),
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use crate::protocols::tls;
|
use crate::protocols::tls;
|
||||||
|
use crate::tunnel::client::WsClientConfig;
|
||||||
use crate::tunnel::tls_reloader::TlsReloaderState::{Client, Server};
|
use crate::tunnel::tls_reloader::TlsReloaderState::{Client, Server};
|
||||||
use crate::{WsClientConfig, WsServerConfig};
|
use crate::WsServerConfig;
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use log::trace;
|
use log::trace;
|
||||||
use notify::{EventKind, RecommendedWatcher, Watcher};
|
use notify::{EventKind, RecommendedWatcher, Watcher};
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
|
use crate::tunnel::client::WsClient;
|
||||||
use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
|
use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
|
||||||
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, TransportScheme};
|
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, TransportScheme};
|
||||||
use crate::WsClientConfig;
|
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::{anyhow, Context};
|
||||||
use bytes::{Bytes, BytesMut};
|
use bytes::{Bytes, BytesMut};
|
||||||
use http_body_util::{BodyExt, BodyStream, StreamBody};
|
use http_body_util::{BodyExt, BodyStream, StreamBody};
|
||||||
|
@ -99,55 +99,57 @@ impl TunnelWrite for Http2TunnelWrite {
|
||||||
|
|
||||||
pub async fn connect(
|
pub async fn connect(
|
||||||
request_id: Uuid,
|
request_id: Uuid,
|
||||||
client_cfg: &WsClientConfig,
|
client: &WsClient,
|
||||||
dest_addr: &RemoteAddr,
|
dest_addr: &RemoteAddr,
|
||||||
) -> anyhow::Result<(Http2TunnelRead, Http2TunnelWrite, Parts)> {
|
) -> anyhow::Result<(Http2TunnelRead, Http2TunnelWrite, Parts)> {
|
||||||
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
|
let mut pooled_cnx = match client.cnx_pool.get().await {
|
||||||
Ok(cnx) => Ok(cnx),
|
Ok(cnx) => Ok(cnx),
|
||||||
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
|
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
// In http2 HOST header does not exist, it is explicitly set in the authority from the request uri
|
// In http2 HOST header does not exist, it is explicitly set in the authority from the request uri
|
||||||
let (headers_file, authority) = client_cfg
|
let (headers_file, authority) =
|
||||||
.http_headers_file
|
client
|
||||||
.as_ref()
|
.config
|
||||||
.map_or((None, None), |headers_file_path| {
|
.http_headers_file
|
||||||
let (host, headers) = headers_from_file(headers_file_path);
|
.as_ref()
|
||||||
let host = if let Some((_, v)) = host {
|
.map_or((None, None), |headers_file_path| {
|
||||||
match (client_cfg.remote_addr.scheme(), client_cfg.remote_addr.port()) {
|
let (host, headers) = headers_from_file(headers_file_path);
|
||||||
(TransportScheme::Http, 80) | (TransportScheme::Https, 443) => {
|
let host = if let Some((_, v)) = host {
|
||||||
Some(v.to_str().unwrap_or("").to_string())
|
match (client.config.remote_addr.scheme(), client.config.remote_addr.port()) {
|
||||||
|
(TransportScheme::Http, 80) | (TransportScheme::Https, 443) => {
|
||||||
|
Some(v.to_str().unwrap_or("").to_string())
|
||||||
|
}
|
||||||
|
(_, port) => Some(format!("{}:{}", v.to_str().unwrap_or(""), port)),
|
||||||
}
|
}
|
||||||
(_, port) => Some(format!("{}:{}", v.to_str().unwrap_or(""), port)),
|
} else {
|
||||||
}
|
None
|
||||||
} else {
|
};
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
(Some(headers), host)
|
(Some(headers), host)
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut req = Request::builder()
|
let mut req = Request::builder()
|
||||||
.method("POST")
|
.method("POST")
|
||||||
.uri(format!(
|
.uri(format!(
|
||||||
"{}://{}/{}/events",
|
"{}://{}/{}/events",
|
||||||
client_cfg.remote_addr.scheme(),
|
client.config.remote_addr.scheme(),
|
||||||
authority
|
authority
|
||||||
.as_deref()
|
.as_deref()
|
||||||
.unwrap_or_else(|| client_cfg.http_header_host.to_str().unwrap_or("")),
|
.unwrap_or_else(|| client.config.http_header_host.to_str().unwrap_or("")),
|
||||||
&client_cfg.http_upgrade_path_prefix
|
&client.config.http_upgrade_path_prefix
|
||||||
))
|
))
|
||||||
.header(COOKIE, tunnel_to_jwt_token(request_id, dest_addr))
|
.header(COOKIE, tunnel_to_jwt_token(request_id, dest_addr))
|
||||||
.header(CONTENT_TYPE, "application/json")
|
.header(CONTENT_TYPE, "application/json")
|
||||||
.version(hyper::Version::HTTP_2);
|
.version(hyper::Version::HTTP_2);
|
||||||
|
|
||||||
let headers = req.headers_mut().unwrap();
|
let headers = req.headers_mut().unwrap();
|
||||||
for (k, v) in &client_cfg.http_headers {
|
for (k, v) in &client.config.http_headers {
|
||||||
let _ = headers.remove(k);
|
let _ = headers.remove(k);
|
||||||
headers.append(k, v.clone());
|
headers.append(k, v.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(auth) = &client_cfg.http_upgrade_credentials {
|
if let Some(auth) = &client.config.http_upgrade_credentials {
|
||||||
let _ = headers.remove(AUTHORIZATION);
|
let _ = headers.remove(AUTHORIZATION);
|
||||||
headers.append(AUTHORIZATION, auth.clone());
|
headers.append(AUTHORIZATION, auth.clone());
|
||||||
}
|
}
|
||||||
|
@ -164,7 +166,7 @@ pub async fn connect(
|
||||||
let req = req.body(body).with_context(|| {
|
let req = req.body(body).with_context(|| {
|
||||||
format!(
|
format!(
|
||||||
"failed to build HTTP request to contact the server {:?}",
|
"failed to build HTTP request to contact the server {:?}",
|
||||||
client_cfg.remote_addr
|
client.config.remote_addr
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
debug!("with HTTP upgrade request {:?}", req);
|
debug!("with HTTP upgrade request {:?}", req);
|
||||||
|
@ -172,11 +174,11 @@ pub async fn connect(
|
||||||
let (mut request_sender, cnx) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
let (mut request_sender, cnx) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||||
.timer(TokioTimer::new())
|
.timer(TokioTimer::new())
|
||||||
.adaptive_window(true)
|
.adaptive_window(true)
|
||||||
.keep_alive_interval(client_cfg.websocket_ping_frequency)
|
.keep_alive_interval(client.config.websocket_ping_frequency)
|
||||||
.keep_alive_while_idle(false)
|
.keep_alive_while_idle(false)
|
||||||
.handshake(TokioIo::new(transport))
|
.handshake(TokioIo::new(transport))
|
||||||
.await
|
.await
|
||||||
.with_context(|| format!("failed to do http2 handshake with the server {:?}", client_cfg.remote_addr))?;
|
.with_context(|| format!("failed to do http2 handshake with the server {:?}", client.config.remote_addr))?;
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(err) = cnx.await {
|
if let Err(err) = cnx.await {
|
||||||
error!("{:?}", err)
|
error!("{:?}", err)
|
||||||
|
@ -186,7 +188,7 @@ pub async fn connect(
|
||||||
let response = request_sender
|
let response = request_sender
|
||||||
.send_request(req)
|
.send_request(req)
|
||||||
.await
|
.await
|
||||||
.with_context(|| format!("failed to send http2 request with the server {:?}", client_cfg.remote_addr))?;
|
.with_context(|| format!("failed to send http2 request with the server {:?}", client.config.remote_addr))?;
|
||||||
|
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
return Err(anyhow!(
|
return Err(anyhow!(
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
|
use crate::tunnel::client::WsClient;
|
||||||
use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
|
use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
|
||||||
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, JWT_HEADER_PREFIX};
|
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, JWT_HEADER_PREFIX};
|
||||||
use crate::WsClientConfig;
|
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::{anyhow, Context};
|
||||||
use bytes::{Bytes, BytesMut};
|
use bytes::{Bytes, BytesMut};
|
||||||
use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
|
use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
|
||||||
|
@ -135,10 +135,11 @@ impl TunnelRead for WebsocketTunnelRead {
|
||||||
|
|
||||||
pub async fn connect(
|
pub async fn connect(
|
||||||
request_id: Uuid,
|
request_id: Uuid,
|
||||||
client_cfg: &WsClientConfig,
|
client: &WsClient,
|
||||||
dest_addr: &RemoteAddr,
|
dest_addr: &RemoteAddr,
|
||||||
) -> anyhow::Result<(WebsocketTunnelRead, WebsocketTunnelWrite, Parts)> {
|
) -> anyhow::Result<(WebsocketTunnelRead, WebsocketTunnelWrite, Parts)> {
|
||||||
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
|
let client_cfg = &client.config;
|
||||||
|
let mut pooled_cnx = match client.cnx_pool.get().await {
|
||||||
Ok(cnx) => Ok(cnx),
|
Ok(cnx) => Ok(cnx),
|
||||||
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
|
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
|
||||||
}?;
|
}?;
|
||||||
|
|
Loading…
Reference in a new issue