Refacto: Use proper type for WsClient
This commit is contained in:
parent
5e74ed233d
commit
a33a889b3d
11 changed files with 453 additions and 412 deletions
185
src/main.rs
185
src/main.rs
|
@ -3,15 +3,23 @@ mod protocols;
|
|||
mod restrictions;
|
||||
mod tunnel;
|
||||
|
||||
use crate::protocols::dns::DnsResolver;
|
||||
use crate::protocols::tls;
|
||||
use crate::restrictions::types::RestrictionsRules;
|
||||
use crate::tunnel::client::{TlsClientConfig, WsClient, WsClientConfig};
|
||||
use crate::tunnel::connectors::{Socks5TunnelConnector, TcpTunnelConnector, UdpTunnelConnector};
|
||||
use crate::tunnel::listeners::{
|
||||
new_stdio_listener, new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener,
|
||||
};
|
||||
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme};
|
||||
use base64::Engine;
|
||||
use clap::Parser;
|
||||
use hyper::header::HOST;
|
||||
use hyper::http::{HeaderName, HeaderValue};
|
||||
use log::debug;
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt::{Debug, Formatter};
|
||||
use std::io::ErrorKind;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
|
@ -21,21 +29,8 @@ use std::sync::Arc;
|
|||
use std::time::Duration;
|
||||
use std::{fmt, io};
|
||||
use tokio::select;
|
||||
|
||||
use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer, ServerName};
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer};
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::protocols::dns::DnsResolver;
|
||||
use crate::protocols::tls;
|
||||
use crate::restrictions::types::RestrictionsRules;
|
||||
use crate::tunnel::connectors::{Socks5TunnelConnector, TcpTunnelConnector, UdpTunnelConnector};
|
||||
use crate::tunnel::listeners::{
|
||||
new_stdio_listener, new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener,
|
||||
};
|
||||
use crate::tunnel::tls_reloader::TlsReloader;
|
||||
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme};
|
||||
use tracing_subscriber::filter::Directive;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use url::{Host, Url};
|
||||
|
@ -695,22 +690,6 @@ fn parse_server_url(arg: &str) -> Result<Url, io::Error> {
|
|||
Ok(url)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TlsClientConfig {
|
||||
pub tls_sni_disabled: bool,
|
||||
pub tls_sni_override: Option<DnsName<'static>>,
|
||||
pub tls_verify_certificate: bool,
|
||||
tls_connector: Arc<RwLock<TlsConnector>>,
|
||||
pub tls_certificate_path: Option<PathBuf>,
|
||||
pub tls_key_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl TlsClientConfig {
|
||||
pub fn tls_connector(&self) -> TlsConnector {
|
||||
self.tls_connector.read().clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TlsServerConfig {
|
||||
pub tls_certificate: Mutex<Vec<CertificateDer<'static>>>,
|
||||
|
@ -754,59 +733,6 @@ impl Debug for WsServerConfig {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WsClientConfig {
|
||||
pub remote_addr: TransportAddr,
|
||||
pub socket_so_mark: Option<u32>,
|
||||
pub http_upgrade_path_prefix: String,
|
||||
pub http_upgrade_credentials: Option<HeaderValue>,
|
||||
pub http_headers: HashMap<HeaderName, HeaderValue>,
|
||||
pub http_headers_file: Option<PathBuf>,
|
||||
pub http_header_host: HeaderValue,
|
||||
pub timeout_connect: Duration,
|
||||
pub websocket_ping_frequency: Duration,
|
||||
pub websocket_mask_frame: bool,
|
||||
pub http_proxy: Option<Url>,
|
||||
cnx_pool: Option<bb8::Pool<WsClientConfig>>,
|
||||
tls_reloader: Option<Arc<TlsReloader>>,
|
||||
pub dns_resolver: DnsResolver,
|
||||
}
|
||||
|
||||
impl WsClientConfig {
|
||||
pub const fn websocket_scheme(&self) -> &'static str {
|
||||
match self.remote_addr.tls().is_some() {
|
||||
false => "ws",
|
||||
true => "wss",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cnx_pool(&self) -> &bb8::Pool<Self> {
|
||||
self.cnx_pool.as_ref().unwrap()
|
||||
}
|
||||
|
||||
pub fn websocket_host_url(&self) -> String {
|
||||
format!("{}:{}", self.remote_addr.host(), self.remote_addr.port())
|
||||
}
|
||||
|
||||
pub fn tls_server_name(&self) -> ServerName<'static> {
|
||||
static INVALID_DNS_NAME: Lazy<DnsName> = Lazy::new(|| DnsName::try_from("dns-name-invalid.com").unwrap());
|
||||
|
||||
self.remote_addr
|
||||
.tls()
|
||||
.and_then(|tls| tls.tls_sni_override.as_ref())
|
||||
.map_or_else(
|
||||
|| match &self.remote_addr.host() {
|
||||
Host::Domain(domain) => ServerName::DnsName(
|
||||
DnsName::try_from(domain.clone()).unwrap_or_else(|_| INVALID_DNS_NAME.clone()),
|
||||
),
|
||||
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip).into()),
|
||||
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip).into()),
|
||||
},
|
||||
|sni_override| ServerName::DnsName(sni_override.clone()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let args = Wstunnel::parse();
|
||||
|
@ -866,24 +792,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
TransportScheme::from_str(args.remote_addr.scheme()).expect("invalid scheme in server url");
|
||||
let tls = match transport_scheme {
|
||||
TransportScheme::Ws | TransportScheme::Http => None,
|
||||
TransportScheme::Wss => Some(TlsClientConfig {
|
||||
tls_connector: Arc::new(RwLock::new(
|
||||
tls::tls_connector(
|
||||
args.tls_verify_certificate,
|
||||
transport_scheme.alpn_protocols(),
|
||||
!args.tls_sni_disable,
|
||||
tls_certificate,
|
||||
tls_key,
|
||||
)
|
||||
.expect("Cannot create tls connector"),
|
||||
)),
|
||||
tls_sni_override: args.tls_sni_override,
|
||||
tls_verify_certificate: args.tls_verify_certificate,
|
||||
tls_sni_disabled: args.tls_sni_disable,
|
||||
tls_certificate_path: args.tls_certificate.clone(),
|
||||
tls_key_path: args.tls_private_key.clone(),
|
||||
}),
|
||||
TransportScheme::Https => Some(TlsClientConfig {
|
||||
TransportScheme::Wss | TransportScheme::Https => Some(TlsClientConfig {
|
||||
tls_connector: Arc::new(RwLock::new(
|
||||
tls::tls_connector(
|
||||
args.tls_verify_certificate,
|
||||
|
@ -936,7 +845,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
} else {
|
||||
None
|
||||
};
|
||||
let mut client_config = WsClientConfig {
|
||||
let client_config = WsClientConfig {
|
||||
remote_addr: TransportAddr::new(
|
||||
TransportScheme::from_str(args.remote_addr.scheme()).unwrap(),
|
||||
args.remote_addr.host().unwrap().to_owned(),
|
||||
|
@ -953,8 +862,6 @@ async fn main() -> anyhow::Result<()> {
|
|||
timeout_connect: Duration::from_secs(10),
|
||||
websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)),
|
||||
websocket_mask_frame: args.websocket_mask_frame,
|
||||
cnx_pool: None,
|
||||
tls_reloader: None,
|
||||
dns_resolver: DnsResolver::new_from_urls(
|
||||
&args.dns_resolver,
|
||||
http_proxy.clone(),
|
||||
|
@ -965,28 +872,16 @@ async fn main() -> anyhow::Result<()> {
|
|||
http_proxy,
|
||||
};
|
||||
|
||||
let tls_reloader =
|
||||
TlsReloader::new_for_client(Arc::new(client_config.clone())).expect("Cannot create tls reloader");
|
||||
client_config.tls_reloader = Some(Arc::new(tls_reloader));
|
||||
let pool = bb8::Pool::builder()
|
||||
.max_size(1000)
|
||||
.min_idle(Some(args.connection_min_idle))
|
||||
.max_lifetime(Some(Duration::from_secs(30)))
|
||||
.connection_timeout(args.connection_retry_max_backoff_sec)
|
||||
.retry_connection(true)
|
||||
.build(client_config.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
client_config.cnx_pool = Some(pool);
|
||||
let client_config = Arc::new(client_config);
|
||||
let client =
|
||||
WsClient::new(client_config, args.connection_min_idle, args.connection_retry_max_backoff_sec).await?;
|
||||
|
||||
// Start tunnels
|
||||
for tunnel in args.remote_to_local.into_iter() {
|
||||
let client_config = client_config.clone();
|
||||
let client = client.clone();
|
||||
match &tunnel.local_protocol {
|
||||
LocalProtocol::Tcp { proxy_protocol: _ } => {
|
||||
tokio::spawn(async move {
|
||||
let cfg = client_config.clone();
|
||||
let cfg = client.config.clone();
|
||||
let tcp_connector = TcpTunnelConnector::new(
|
||||
&tunnel.remote.0,
|
||||
tunnel.remote.1,
|
||||
|
@ -1000,9 +895,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
host,
|
||||
port,
|
||||
};
|
||||
if let Err(err) =
|
||||
tunnel::client::run_reverse_tunnel(client_config, remote, tcp_connector).await
|
||||
{
|
||||
if let Err(err) = client.run_reverse_tunnel(remote, tcp_connector).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -1011,7 +904,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
let timeout = *timeout;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let cfg = client_config.clone();
|
||||
let cfg = client.config.clone();
|
||||
let (host, port) = to_host_port(tunnel.local);
|
||||
let remote = RemoteAddr {
|
||||
protocol: LocalProtocol::ReverseUdp { timeout },
|
||||
|
@ -1026,9 +919,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
&cfg.dns_resolver,
|
||||
);
|
||||
|
||||
if let Err(err) =
|
||||
tunnel::client::run_reverse_tunnel(client_config, remote.clone(), udp_connector).await
|
||||
{
|
||||
if let Err(err) = client.run_reverse_tunnel(remote.clone(), udp_connector).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -1037,7 +928,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
let credentials = credentials.clone();
|
||||
let timeout = *timeout;
|
||||
tokio::spawn(async move {
|
||||
let cfg = client_config.clone();
|
||||
let cfg = client.config.clone();
|
||||
let (host, port) = to_host_port(tunnel.local);
|
||||
let remote = RemoteAddr {
|
||||
protocol: LocalProtocol::ReverseSocks5 { timeout, credentials },
|
||||
|
@ -1047,9 +938,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
let socks_connector =
|
||||
Socks5TunnelConnector::new(cfg.socket_so_mark, cfg.timeout_connect, &cfg.dns_resolver);
|
||||
|
||||
if let Err(err) =
|
||||
tunnel::client::run_reverse_tunnel(client_config, remote, socks_connector).await
|
||||
{
|
||||
if let Err(err) = client.run_reverse_tunnel(remote, socks_connector).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -1060,7 +949,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
let credentials = credentials.clone();
|
||||
let timeout = *timeout;
|
||||
tokio::spawn(async move {
|
||||
let cfg = client_config.clone();
|
||||
let cfg = client.config.clone();
|
||||
let (host, port) = to_host_port(tunnel.local);
|
||||
let remote = RemoteAddr {
|
||||
protocol: LocalProtocol::ReverseHttpProxy { timeout, credentials },
|
||||
|
@ -1075,9 +964,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
&cfg.dns_resolver,
|
||||
);
|
||||
|
||||
if let Err(err) =
|
||||
tunnel::client::run_reverse_tunnel(client_config, remote.clone(), tcp_connector).await
|
||||
{
|
||||
if let Err(err) = client.run_reverse_tunnel(remote.clone(), tcp_connector).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -1086,7 +973,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
LocalProtocol::Unix { path } => {
|
||||
let path = path.clone();
|
||||
tokio::spawn(async move {
|
||||
let cfg = client_config.clone();
|
||||
let cfg = client.config.clone();
|
||||
let tcp_connector = TcpTunnelConnector::new(
|
||||
&tunnel.remote.0,
|
||||
tunnel.remote.1,
|
||||
|
@ -1101,9 +988,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
host,
|
||||
port,
|
||||
};
|
||||
if let Err(err) =
|
||||
tunnel::client::run_reverse_tunnel(client_config, remote, tcp_connector).await
|
||||
{
|
||||
if let Err(err) = client.run_reverse_tunnel(remote, tcp_connector).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -1126,14 +1011,14 @@ async fn main() -> anyhow::Result<()> {
|
|||
}
|
||||
|
||||
for tunnel in args.local_to_remote.into_iter() {
|
||||
let client_config = client_config.clone();
|
||||
let client = client.clone();
|
||||
|
||||
match &tunnel.local_protocol {
|
||||
LocalProtocol::Tcp { proxy_protocol } => {
|
||||
let server =
|
||||
TcpTunnelListener::new(tunnel.local, tunnel.remote.clone(), *proxy_protocol).await?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
||||
if let Err(err) = client.run_tunnel(server).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -1144,7 +1029,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
let server = TproxyTcpTunnelListener::new(tunnel.local, false).await?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
||||
if let Err(err) = client.run_tunnel(server).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -1154,7 +1039,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
use crate::tunnel::listeners::UnixTunnelListener;
|
||||
let server = UnixTunnelListener::new(path, tunnel.remote.clone(), false).await?; // TODO: support proxy protocol
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
||||
if let Err(err) = client.run_tunnel(server).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -1169,7 +1054,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
use crate::tunnel::listeners::new_tproxy_udp;
|
||||
let server = new_tproxy_udp(tunnel.local, *timeout).await?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
||||
if let Err(err) = client.run_tunnel(server).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -1182,7 +1067,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
let server = new_udp_listener(tunnel.local, tunnel.remote.clone(), *timeout).await?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
||||
if let Err(err) = client.run_tunnel(server).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -1190,7 +1075,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
LocalProtocol::Socks5 { timeout, credentials } => {
|
||||
let server = Socks5TunnelListener::new(tunnel.local, *timeout, credentials.clone()).await?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
||||
if let Err(err) = client.run_tunnel(server).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -1204,7 +1089,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
HttpProxyTunnelListener::new(tunnel.local, *timeout, credentials.clone(), *proxy_protocol)
|
||||
.await?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
||||
if let Err(err) = client.run_tunnel(server).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -1213,7 +1098,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
LocalProtocol::Stdio => {
|
||||
let (server, mut handle) = new_stdio_listener(tunnel.remote.clone(), false).await?; // TODO: support proxy protocol
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
|
||||
if let Err(err) = client.run_tunnel(server).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{TlsServerConfig, WsClientConfig};
|
||||
use crate::TlsServerConfig;
|
||||
use anyhow::{anyhow, Context};
|
||||
use std::fs::File;
|
||||
|
||||
|
@ -9,6 +9,7 @@ use std::sync::Arc;
|
|||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::client::TlsStream;
|
||||
|
||||
use crate::tunnel::client::WsClientConfig;
|
||||
use crate::tunnel::TransportAddr;
|
||||
use tokio_rustls::rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
|
||||
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
|
||||
|
|
|
@ -1,176 +0,0 @@
|
|||
use super::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE};
|
||||
use crate::tunnel::connectors::TunnelConnector;
|
||||
use crate::tunnel::listeners::TunnelListener;
|
||||
use crate::tunnel::transport::{TunnelReader, TunnelWriter};
|
||||
use crate::{tunnel, WsClientConfig};
|
||||
use futures_util::pin_mut;
|
||||
use hyper::header::COOKIE;
|
||||
use jsonwebtoken::TokenData;
|
||||
use log::debug;
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{error, event, span, Instrument, Level, Span};
|
||||
use url::Host;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn connect_to_server<R, W>(
|
||||
request_id: Uuid,
|
||||
client_cfg: &WsClientConfig,
|
||||
remote_cfg: &RemoteAddr,
|
||||
duplex_stream: (R, W),
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
R: AsyncRead + Send + 'static,
|
||||
W: AsyncWrite + Send + 'static,
|
||||
{
|
||||
// Connect to server with the correct protocol
|
||||
let (ws_rx, ws_tx, response) = match client_cfg.remote_addr.scheme() {
|
||||
TransportScheme::Ws | TransportScheme::Wss => {
|
||||
tunnel::transport::websocket::connect(request_id, client_cfg, remote_cfg)
|
||||
.await
|
||||
.map(|(r, w, response)| (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response))?
|
||||
}
|
||||
TransportScheme::Http | TransportScheme::Https => {
|
||||
tunnel::transport::http2::connect(request_id, client_cfg, remote_cfg)
|
||||
.await
|
||||
.map(|(r, w, response)| (TunnelReader::Http2(r), TunnelWriter::Http2(w), response))?
|
||||
}
|
||||
};
|
||||
|
||||
debug!("Server response: {:?}", response);
|
||||
let (local_rx, local_tx) = duplex_stream;
|
||||
let (close_tx, close_rx) = oneshot::channel::<()>();
|
||||
|
||||
// Forward local tx to websocket tx
|
||||
let ping_frequency = client_cfg.websocket_ping_frequency;
|
||||
tokio::spawn(
|
||||
super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency))
|
||||
.instrument(Span::current()),
|
||||
);
|
||||
|
||||
// Forward websocket rx to local rx
|
||||
let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_tunnel(client_config: Arc<WsClientConfig>, incoming_cnx: impl TunnelListener) -> anyhow::Result<()> {
|
||||
pin_mut!(incoming_cnx);
|
||||
while let Some(cnx) = incoming_cnx.next().await {
|
||||
let (cnx_stream, remote_addr) = match cnx {
|
||||
Ok((cnx_stream, remote_addr)) => (cnx_stream, remote_addr),
|
||||
Err(err) => {
|
||||
error!("Error accepting connection: {:?}", err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let request_id = Uuid::now_v7();
|
||||
let span = span!(
|
||||
Level::INFO,
|
||||
"tunnel",
|
||||
id = request_id.to_string(),
|
||||
remote = format!("{}:{}", remote_addr.host, remote_addr.port)
|
||||
);
|
||||
let client_config = client_config.clone();
|
||||
|
||||
let tunnel = async move {
|
||||
let _ = connect_to_server(request_id, &client_config, &remote_addr, cnx_stream)
|
||||
.await
|
||||
.map_err(|err| error!("{:?}", err));
|
||||
}
|
||||
.instrument(span);
|
||||
|
||||
tokio::spawn(tunnel);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_reverse_tunnel(
|
||||
client_cfg: Arc<WsClientConfig>,
|
||||
remote_addr: RemoteAddr,
|
||||
connector: impl TunnelConnector,
|
||||
) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let client_config = client_cfg.clone();
|
||||
let request_id = Uuid::now_v7();
|
||||
let span = span!(
|
||||
Level::INFO,
|
||||
"tunnel",
|
||||
id = request_id.to_string(),
|
||||
remote = format!("{}:{}", remote_addr.host, remote_addr.port)
|
||||
);
|
||||
// Correctly configure tunnel cfg
|
||||
let (ws_rx, ws_tx, response) = match client_cfg.remote_addr.scheme() {
|
||||
TransportScheme::Ws | TransportScheme::Wss => {
|
||||
match tunnel::transport::websocket::connect(request_id, &client_cfg, &remote_addr)
|
||||
.instrument(span.clone())
|
||||
.await
|
||||
{
|
||||
Ok((r, w, response)) => (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response),
|
||||
Err(err) => {
|
||||
event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err);
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
TransportScheme::Http | TransportScheme::Https => {
|
||||
match tunnel::transport::http2::connect(request_id, &client_cfg, &remote_addr)
|
||||
.instrument(span.clone())
|
||||
.await
|
||||
{
|
||||
Ok((r, w, response)) => (TunnelReader::Http2(r), TunnelWriter::Http2(w), response),
|
||||
Err(err) => {
|
||||
event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err);
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Connect to endpoint
|
||||
event!(parent: &span, Level::DEBUG, "Server response: {:?}", response);
|
||||
let remote = response
|
||||
.headers
|
||||
.get(COOKIE)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| {
|
||||
let (validation, decode_key) = JWT_DECODE.deref();
|
||||
let jwt: Option<TokenData<JwtTunnelConfig>> = jsonwebtoken::decode(h, decode_key, validation).ok();
|
||||
jwt
|
||||
})
|
||||
.map(|jwt| RemoteAddr {
|
||||
protocol: jwt.claims.p,
|
||||
host: Host::parse(&jwt.claims.r).unwrap_or_else(|_| Host::Domain(String::new())),
|
||||
port: jwt.claims.rp,
|
||||
});
|
||||
|
||||
let (local_rx, local_tx) = match connector.connect(&remote).instrument(span.clone()).await {
|
||||
Ok(s) => s,
|
||||
Err(err) => {
|
||||
event!(parent: &span, Level::ERROR, "Cannot connect to {remote:?}: {err:?}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let (close_tx, close_rx) = oneshot::channel::<()>();
|
||||
let tunnel = async move {
|
||||
let ping_frequency = client_config.websocket_ping_frequency;
|
||||
tokio::spawn(
|
||||
super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency))
|
||||
.in_current_span(),
|
||||
);
|
||||
|
||||
// Forward websocket rx to local rx
|
||||
let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
|
||||
}
|
||||
.instrument(span.clone());
|
||||
tokio::spawn(tunnel);
|
||||
}
|
||||
}
|
221
src/tunnel/client/client.rs
Normal file
221
src/tunnel/client/client.rs
Normal file
|
@ -0,0 +1,221 @@
|
|||
use crate::tunnel;
|
||||
use crate::tunnel::client::cnx_pool::WsConnection;
|
||||
use crate::tunnel::client::WsClientConfig;
|
||||
use crate::tunnel::connectors::TunnelConnector;
|
||||
use crate::tunnel::listeners::TunnelListener;
|
||||
use crate::tunnel::tls_reloader::TlsReloader;
|
||||
use crate::tunnel::transport::{TunnelReader, TunnelWriter};
|
||||
use crate::tunnel::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE};
|
||||
use anyhow::Context;
|
||||
use futures_util::pin_mut;
|
||||
use hyper::header::COOKIE;
|
||||
use jsonwebtoken::TokenData;
|
||||
use log::debug;
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{error, event, span, Instrument, Level, Span};
|
||||
use url::Host;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WsClient {
|
||||
pub config: Arc<WsClientConfig>,
|
||||
pub cnx_pool: bb8::Pool<WsConnection>,
|
||||
_tls_reloader: Arc<TlsReloader>,
|
||||
}
|
||||
|
||||
impl WsClient {
|
||||
pub async fn new(
|
||||
config: WsClientConfig,
|
||||
connection_min_idle: u32,
|
||||
connection_retry_max_backoff_sec: Duration,
|
||||
) -> anyhow::Result<Self> {
|
||||
let config = Arc::new(config);
|
||||
let cnx = WsConnection::new(config.clone());
|
||||
let tls_reloader = TlsReloader::new_for_client(config.clone()).with_context(|| "Cannot create tls reloader")?;
|
||||
let cnx_pool = bb8::Pool::builder()
|
||||
.max_size(1000)
|
||||
.min_idle(Some(connection_min_idle))
|
||||
.max_lifetime(Some(Duration::from_secs(30)))
|
||||
.connection_timeout(connection_retry_max_backoff_sec)
|
||||
.retry_connection(true)
|
||||
.build(cnx)
|
||||
.await?;
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
cnx_pool,
|
||||
_tls_reloader: Arc::new(tls_reloader),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl WsClient {
|
||||
async fn connect_to_server<R, W>(
|
||||
&self,
|
||||
request_id: Uuid,
|
||||
remote_cfg: &RemoteAddr,
|
||||
duplex_stream: (R, W),
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
R: AsyncRead + Send + 'static,
|
||||
W: AsyncWrite + Send + 'static,
|
||||
{
|
||||
// Connect to server with the correct protocol
|
||||
let (ws_rx, ws_tx, response) = match self.config.remote_addr.scheme() {
|
||||
TransportScheme::Ws | TransportScheme::Wss => {
|
||||
tunnel::transport::websocket::connect(request_id, self, remote_cfg)
|
||||
.await
|
||||
.map(|(r, w, response)| (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response))?
|
||||
}
|
||||
TransportScheme::Http | TransportScheme::Https => {
|
||||
tunnel::transport::http2::connect(request_id, self, remote_cfg)
|
||||
.await
|
||||
.map(|(r, w, response)| (TunnelReader::Http2(r), TunnelWriter::Http2(w), response))?
|
||||
}
|
||||
};
|
||||
|
||||
debug!("Server response: {:?}", response);
|
||||
let (local_rx, local_tx) = duplex_stream;
|
||||
let (close_tx, close_rx) = oneshot::channel::<()>();
|
||||
|
||||
// Forward local tx to websocket tx
|
||||
let ping_frequency = self.config.websocket_ping_frequency;
|
||||
tokio::spawn(
|
||||
super::super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency))
|
||||
.instrument(Span::current()),
|
||||
);
|
||||
|
||||
// Forward websocket rx to local rx
|
||||
let _ = super::super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_tunnel(self, tunnel_listener: impl TunnelListener) -> anyhow::Result<()> {
|
||||
pin_mut!(tunnel_listener);
|
||||
while let Some(cnx) = tunnel_listener.next().await {
|
||||
let (cnx_stream, remote_addr) = match cnx {
|
||||
Ok((cnx_stream, remote_addr)) => (cnx_stream, remote_addr),
|
||||
Err(err) => {
|
||||
error!("Error accepting connection: {:?}", err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let request_id = Uuid::now_v7();
|
||||
let span = span!(
|
||||
Level::INFO,
|
||||
"tunnel",
|
||||
id = request_id.to_string(),
|
||||
remote = format!("{}:{}", remote_addr.host, remote_addr.port)
|
||||
);
|
||||
let client = self.clone();
|
||||
let tunnel = async move {
|
||||
let _ = client
|
||||
.connect_to_server(request_id, &remote_addr, cnx_stream)
|
||||
.await
|
||||
.map_err(|err| error!("{:?}", err));
|
||||
}
|
||||
.instrument(span);
|
||||
|
||||
tokio::spawn(tunnel);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_reverse_tunnel(
|
||||
self,
|
||||
remote_addr: RemoteAddr,
|
||||
connector: impl TunnelConnector,
|
||||
) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let client = self.clone();
|
||||
let request_id = Uuid::now_v7();
|
||||
let span = span!(
|
||||
Level::INFO,
|
||||
"tunnel",
|
||||
id = request_id.to_string(),
|
||||
remote = format!("{}:{}", remote_addr.host, remote_addr.port)
|
||||
);
|
||||
// Correctly configure tunnel cfg
|
||||
let (ws_rx, ws_tx, response) = match client.config.remote_addr.scheme() {
|
||||
TransportScheme::Ws | TransportScheme::Wss => {
|
||||
match tunnel::transport::websocket::connect(request_id, &client, &remote_addr)
|
||||
.instrument(span.clone())
|
||||
.await
|
||||
{
|
||||
Ok((r, w, response)) => (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response),
|
||||
Err(err) => {
|
||||
event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err);
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
TransportScheme::Http | TransportScheme::Https => {
|
||||
match tunnel::transport::http2::connect(request_id, &client, &remote_addr)
|
||||
.instrument(span.clone())
|
||||
.await
|
||||
{
|
||||
Ok((r, w, response)) => (TunnelReader::Http2(r), TunnelWriter::Http2(w), response),
|
||||
Err(err) => {
|
||||
event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err);
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Connect to endpoint
|
||||
event!(parent: &span, Level::DEBUG, "Server response: {:?}", response);
|
||||
let remote = response
|
||||
.headers
|
||||
.get(COOKIE)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| {
|
||||
let (validation, decode_key) = JWT_DECODE.deref();
|
||||
let jwt: Option<TokenData<JwtTunnelConfig>> = jsonwebtoken::decode(h, decode_key, validation).ok();
|
||||
jwt
|
||||
})
|
||||
.map(|jwt| RemoteAddr {
|
||||
protocol: jwt.claims.p,
|
||||
host: Host::parse(&jwt.claims.r).unwrap_or_else(|_| Host::Domain(String::new())),
|
||||
port: jwt.claims.rp,
|
||||
});
|
||||
|
||||
let (local_rx, local_tx) = match connector.connect(&remote).instrument(span.clone()).await {
|
||||
Ok(s) => s,
|
||||
Err(err) => {
|
||||
event!(parent: &span, Level::ERROR, "Cannot connect to {remote:?}: {err:?}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let (close_tx, close_rx) = oneshot::channel::<()>();
|
||||
let tunnel = async move {
|
||||
let ping_frequency = client.config.websocket_ping_frequency;
|
||||
tokio::spawn(
|
||||
super::super::transport::io::propagate_local_to_remote(
|
||||
local_rx,
|
||||
ws_tx,
|
||||
close_tx,
|
||||
Some(ping_frequency),
|
||||
)
|
||||
.in_current_span(),
|
||||
);
|
||||
|
||||
// Forward websocket rx to local rx
|
||||
let _ = super::super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
|
||||
}
|
||||
.instrument(span.clone());
|
||||
tokio::spawn(tunnel);
|
||||
}
|
||||
}
|
||||
}
|
74
src/tunnel/client/cnx_pool.rs
Normal file
74
src/tunnel/client/cnx_pool.rs
Normal file
|
@ -0,0 +1,74 @@
|
|||
use crate::protocols;
|
||||
use crate::protocols::tls;
|
||||
use crate::tunnel::client::WsClientConfig;
|
||||
use crate::tunnel::TransportStream;
|
||||
use async_trait::async_trait;
|
||||
use bb8::ManageConnection;
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
use tracing::instrument;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WsConnection(Arc<WsClientConfig>);
|
||||
|
||||
impl WsConnection {
|
||||
pub fn new(config: Arc<WsClientConfig>) -> Self {
|
||||
Self(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for WsConnection {
|
||||
type Target = WsClientConfig;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ManageConnection for WsConnection {
|
||||
type Connection = Option<TransportStream>;
|
||||
type Error = anyhow::Error;
|
||||
|
||||
#[instrument(level = "trace", name = "cnx_server", skip_all)]
|
||||
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
|
||||
let so_mark = self.socket_so_mark;
|
||||
let timeout = self.timeout_connect;
|
||||
|
||||
let tcp_stream = if let Some(http_proxy) = &self.http_proxy {
|
||||
protocols::tcp::connect_with_http_proxy(
|
||||
http_proxy,
|
||||
self.remote_addr.host(),
|
||||
self.remote_addr.port(),
|
||||
so_mark,
|
||||
timeout,
|
||||
&self.dns_resolver,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
protocols::tcp::connect(
|
||||
self.remote_addr.host(),
|
||||
self.remote_addr.port(),
|
||||
so_mark,
|
||||
timeout,
|
||||
&self.dns_resolver,
|
||||
)
|
||||
.await?
|
||||
};
|
||||
|
||||
if self.remote_addr.tls().is_some() {
|
||||
let tls_stream = tls::connect(self, tcp_stream).await?;
|
||||
Ok(Some(TransportStream::Tls(tls_stream)))
|
||||
} else {
|
||||
Ok(Some(TransportStream::Plain(tcp_stream)))
|
||||
}
|
||||
}
|
||||
|
||||
async fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
|
||||
conn.is_none()
|
||||
}
|
||||
}
|
76
src/tunnel/client/config.rs
Normal file
76
src/tunnel/client/config.rs
Normal file
|
@ -0,0 +1,76 @@
|
|||
use crate::protocols::dns::DnsResolver;
|
||||
use crate::tunnel::TransportAddr;
|
||||
use hyper::header::{HeaderName, HeaderValue};
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::RwLock;
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio_rustls::rustls::pki_types::{DnsName, ServerName};
|
||||
use tokio_rustls::TlsConnector;
|
||||
use url::{Host, Url};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WsClientConfig {
|
||||
pub remote_addr: TransportAddr,
|
||||
pub socket_so_mark: Option<u32>,
|
||||
pub http_upgrade_path_prefix: String,
|
||||
pub http_upgrade_credentials: Option<HeaderValue>,
|
||||
pub http_headers: HashMap<HeaderName, HeaderValue>,
|
||||
pub http_headers_file: Option<PathBuf>,
|
||||
pub http_header_host: HeaderValue,
|
||||
pub timeout_connect: Duration,
|
||||
pub websocket_ping_frequency: Duration,
|
||||
pub websocket_mask_frame: bool,
|
||||
pub http_proxy: Option<Url>,
|
||||
pub dns_resolver: DnsResolver,
|
||||
}
|
||||
|
||||
impl WsClientConfig {
|
||||
pub const fn websocket_scheme(&self) -> &'static str {
|
||||
match self.remote_addr.tls().is_some() {
|
||||
false => "ws",
|
||||
true => "wss",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn websocket_host_url(&self) -> String {
|
||||
format!("{}:{}", self.remote_addr.host(), self.remote_addr.port())
|
||||
}
|
||||
|
||||
pub fn tls_server_name(&self) -> ServerName<'static> {
|
||||
static INVALID_DNS_NAME: Lazy<DnsName> = Lazy::new(|| DnsName::try_from("dns-name-invalid.com").unwrap());
|
||||
|
||||
self.remote_addr
|
||||
.tls()
|
||||
.and_then(|tls| tls.tls_sni_override.as_ref())
|
||||
.map_or_else(
|
||||
|| match &self.remote_addr.host() {
|
||||
Host::Domain(domain) => ServerName::DnsName(
|
||||
DnsName::try_from(domain.clone()).unwrap_or_else(|_| INVALID_DNS_NAME.clone()),
|
||||
),
|
||||
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip).into()),
|
||||
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip).into()),
|
||||
},
|
||||
|sni_override| ServerName::DnsName(sni_override.clone()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TlsClientConfig {
|
||||
pub tls_sni_disabled: bool,
|
||||
pub tls_sni_override: Option<DnsName<'static>>,
|
||||
pub tls_verify_certificate: bool,
|
||||
pub tls_connector: Arc<RwLock<TlsConnector>>,
|
||||
pub tls_certificate_path: Option<PathBuf>,
|
||||
pub tls_key_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl TlsClientConfig {
|
||||
pub fn tls_connector(&self) -> TlsConnector {
|
||||
self.tls_connector.read().clone()
|
||||
}
|
||||
}
|
8
src/tunnel/client/mod.rs
Normal file
8
src/tunnel/client/mod.rs
Normal file
|
@ -0,0 +1,8 @@
|
|||
#![allow(clippy::module_inception)]
|
||||
mod client;
|
||||
mod cnx_pool;
|
||||
mod config;
|
||||
|
||||
pub use client::WsClient;
|
||||
pub use config::TlsClientConfig;
|
||||
pub use config::WsClientConfig;
|
|
@ -5,10 +5,7 @@ pub mod server;
|
|||
pub mod tls_reloader;
|
||||
mod transport;
|
||||
|
||||
use crate::protocols::tls;
|
||||
use crate::{protocols, LocalProtocol, TlsClientConfig, WsClientConfig};
|
||||
use async_trait::async_trait;
|
||||
use bb8::ManageConnection;
|
||||
use crate::{LocalProtocol, TlsClientConfig};
|
||||
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -23,7 +20,6 @@ use std::task::{Context, Poll};
|
|||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::client::TlsStream;
|
||||
use tracing::instrument;
|
||||
use url::Host;
|
||||
use uuid::Uuid;
|
||||
|
||||
|
@ -304,54 +300,6 @@ impl AsyncWrite for TransportStream {
|
|||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ManageConnection for WsClientConfig {
|
||||
type Connection = Option<TransportStream>;
|
||||
type Error = anyhow::Error;
|
||||
|
||||
#[instrument(level = "trace", name = "cnx_server", skip_all)]
|
||||
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
|
||||
let so_mark = self.socket_so_mark;
|
||||
let timeout = self.timeout_connect;
|
||||
|
||||
let tcp_stream = if let Some(http_proxy) = &self.http_proxy {
|
||||
protocols::tcp::connect_with_http_proxy(
|
||||
http_proxy,
|
||||
self.remote_addr.host(),
|
||||
self.remote_addr.port(),
|
||||
so_mark,
|
||||
timeout,
|
||||
&self.dns_resolver,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
protocols::tcp::connect(
|
||||
self.remote_addr.host(),
|
||||
self.remote_addr.port(),
|
||||
so_mark,
|
||||
timeout,
|
||||
&self.dns_resolver,
|
||||
)
|
||||
.await?
|
||||
};
|
||||
|
||||
if self.remote_addr.tls().is_some() {
|
||||
let tls_stream = tls::connect(self, tcp_stream).await?;
|
||||
Ok(Some(TransportStream::Tls(tls_stream)))
|
||||
} else {
|
||||
Ok(Some(TransportStream::Plain(tcp_stream)))
|
||||
}
|
||||
}
|
||||
|
||||
async fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
|
||||
conn.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_host_port(addr: SocketAddr) -> (Host, u16) {
|
||||
match addr.ip() {
|
||||
IpAddr::V4(ip) => (Host::Ipv4(ip), addr.port()),
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use crate::protocols::tls;
|
||||
use crate::tunnel::client::WsClientConfig;
|
||||
use crate::tunnel::tls_reloader::TlsReloaderState::{Client, Server};
|
||||
use crate::{WsClientConfig, WsServerConfig};
|
||||
use crate::WsServerConfig;
|
||||
use anyhow::Context;
|
||||
use log::trace;
|
||||
use notify::{EventKind, RecommendedWatcher, Watcher};
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::tunnel::client::WsClient;
|
||||
use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
|
||||
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, TransportScheme};
|
||||
use crate::WsClientConfig;
|
||||
use anyhow::{anyhow, Context};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use http_body_util::{BodyExt, BodyStream, StreamBody};
|
||||
|
@ -99,22 +99,24 @@ impl TunnelWrite for Http2TunnelWrite {
|
|||
|
||||
pub async fn connect(
|
||||
request_id: Uuid,
|
||||
client_cfg: &WsClientConfig,
|
||||
client: &WsClient,
|
||||
dest_addr: &RemoteAddr,
|
||||
) -> anyhow::Result<(Http2TunnelRead, Http2TunnelWrite, Parts)> {
|
||||
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
|
||||
let mut pooled_cnx = match client.cnx_pool.get().await {
|
||||
Ok(cnx) => Ok(cnx),
|
||||
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
|
||||
}?;
|
||||
|
||||
// In http2 HOST header does not exist, it is explicitly set in the authority from the request uri
|
||||
let (headers_file, authority) = client_cfg
|
||||
let (headers_file, authority) =
|
||||
client
|
||||
.config
|
||||
.http_headers_file
|
||||
.as_ref()
|
||||
.map_or((None, None), |headers_file_path| {
|
||||
let (host, headers) = headers_from_file(headers_file_path);
|
||||
let host = if let Some((_, v)) = host {
|
||||
match (client_cfg.remote_addr.scheme(), client_cfg.remote_addr.port()) {
|
||||
match (client.config.remote_addr.scheme(), client.config.remote_addr.port()) {
|
||||
(TransportScheme::Http, 80) | (TransportScheme::Https, 443) => {
|
||||
Some(v.to_str().unwrap_or("").to_string())
|
||||
}
|
||||
|
@ -131,23 +133,23 @@ pub async fn connect(
|
|||
.method("POST")
|
||||
.uri(format!(
|
||||
"{}://{}/{}/events",
|
||||
client_cfg.remote_addr.scheme(),
|
||||
client.config.remote_addr.scheme(),
|
||||
authority
|
||||
.as_deref()
|
||||
.unwrap_or_else(|| client_cfg.http_header_host.to_str().unwrap_or("")),
|
||||
&client_cfg.http_upgrade_path_prefix
|
||||
.unwrap_or_else(|| client.config.http_header_host.to_str().unwrap_or("")),
|
||||
&client.config.http_upgrade_path_prefix
|
||||
))
|
||||
.header(COOKIE, tunnel_to_jwt_token(request_id, dest_addr))
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.version(hyper::Version::HTTP_2);
|
||||
|
||||
let headers = req.headers_mut().unwrap();
|
||||
for (k, v) in &client_cfg.http_headers {
|
||||
for (k, v) in &client.config.http_headers {
|
||||
let _ = headers.remove(k);
|
||||
headers.append(k, v.clone());
|
||||
}
|
||||
|
||||
if let Some(auth) = &client_cfg.http_upgrade_credentials {
|
||||
if let Some(auth) = &client.config.http_upgrade_credentials {
|
||||
let _ = headers.remove(AUTHORIZATION);
|
||||
headers.append(AUTHORIZATION, auth.clone());
|
||||
}
|
||||
|
@ -164,7 +166,7 @@ pub async fn connect(
|
|||
let req = req.body(body).with_context(|| {
|
||||
format!(
|
||||
"failed to build HTTP request to contact the server {:?}",
|
||||
client_cfg.remote_addr
|
||||
client.config.remote_addr
|
||||
)
|
||||
})?;
|
||||
debug!("with HTTP upgrade request {:?}", req);
|
||||
|
@ -172,11 +174,11 @@ pub async fn connect(
|
|||
let (mut request_sender, cnx) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||
.timer(TokioTimer::new())
|
||||
.adaptive_window(true)
|
||||
.keep_alive_interval(client_cfg.websocket_ping_frequency)
|
||||
.keep_alive_interval(client.config.websocket_ping_frequency)
|
||||
.keep_alive_while_idle(false)
|
||||
.handshake(TokioIo::new(transport))
|
||||
.await
|
||||
.with_context(|| format!("failed to do http2 handshake with the server {:?}", client_cfg.remote_addr))?;
|
||||
.with_context(|| format!("failed to do http2 handshake with the server {:?}", client.config.remote_addr))?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = cnx.await {
|
||||
error!("{:?}", err)
|
||||
|
@ -186,7 +188,7 @@ pub async fn connect(
|
|||
let response = request_sender
|
||||
.send_request(req)
|
||||
.await
|
||||
.with_context(|| format!("failed to send http2 request with the server {:?}", client_cfg.remote_addr))?;
|
||||
.with_context(|| format!("failed to send http2 request with the server {:?}", client.config.remote_addr))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(anyhow!(
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::tunnel::client::WsClient;
|
||||
use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
|
||||
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, JWT_HEADER_PREFIX};
|
||||
use crate::WsClientConfig;
|
||||
use anyhow::{anyhow, Context};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
|
||||
|
@ -135,10 +135,11 @@ impl TunnelRead for WebsocketTunnelRead {
|
|||
|
||||
pub async fn connect(
|
||||
request_id: Uuid,
|
||||
client_cfg: &WsClientConfig,
|
||||
client: &WsClient,
|
||||
dest_addr: &RemoteAddr,
|
||||
) -> anyhow::Result<(WebsocketTunnelRead, WebsocketTunnelWrite, Parts)> {
|
||||
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
|
||||
let client_cfg = &client.config;
|
||||
let mut pooled_cnx = match client.cnx_pool.get().await {
|
||||
Ok(cnx) => Ok(cnx),
|
||||
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
|
||||
}?;
|
||||
|
|
Loading…
Reference in a new issue