diff --git a/src/main.rs b/src/main.rs index 6d5b8c3..c6f75f0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,6 +11,7 @@ use crate::tunnel::connectors::{Socks5TunnelConnector, TcpTunnelConnector, UdpTu use crate::tunnel::listeners::{ new_stdio_listener, new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener, }; +use crate::tunnel::server::{TlsServerConfig, WsServer, WsServerConfig}; use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme}; use base64::Engine; use clap::Parser; @@ -20,16 +21,16 @@ use log::debug; use parking_lot::{Mutex, RwLock}; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; +use std::io; use std::io::ErrorKind; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::path::PathBuf; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; -use std::{fmt, io}; use tokio::select; -use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer}; +use tokio_rustls::rustls::pki_types::DnsName; use tracing::{error, info}; use tracing_subscriber::filter::Directive; use tracing_subscriber::EnvFilter; @@ -690,49 +691,6 @@ fn parse_server_url(arg: &str) -> Result { Ok(url) } -#[derive(Debug)] -pub struct TlsServerConfig { - pub tls_certificate: Mutex>>, - pub tls_key: Mutex>, - pub tls_client_ca_certificates: Option>>>, - pub tls_certificate_path: Option, - pub tls_key_path: Option, - pub tls_client_ca_certs_path: Option, -} - -pub struct WsServerConfig { - pub socket_so_mark: Option, - pub bind: SocketAddr, - pub websocket_ping_frequency: Option, - pub timeout_connect: Duration, - pub websocket_mask_frame: bool, - pub tls: Option, - pub dns_resolver: DnsResolver, - pub restriction_config: Option, -} - -impl Debug for WsServerConfig { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("WsServerConfig") - .field("socket_so_mark", &self.socket_so_mark) - .field("bind", &self.bind) - .field("websocket_ping_frequency", &self.websocket_ping_frequency) - .field("timeout_connect", &self.timeout_connect) - .field("websocket_mask_frame", &self.websocket_mask_frame) - .field("restriction_config", &self.restriction_config) - .field("tls", &self.tls.is_some()) - .field( - "mTLS", - &self - .tls - .as_ref() - .map(|x| x.tls_client_ca_certificates.is_some()) - .unwrap_or(false), - ) - .finish() - } -} - #[tokio::main] async fn main() -> anyhow::Result<()> { let args = Wstunnel::parse(); @@ -1194,18 +1152,17 @@ async fn main() -> anyhow::Result<()> { .expect("Cannot create DNS resolver"), restriction_config: args.restrict_config, }; + let server = WsServer::new(server_config); info!( "Starting wstunnel server v{} with config {:?}", env!("CARGO_PKG_VERSION"), - server_config + server.config ); debug!("Restriction rules: {:#?}", restrictions); - tunnel::server::run_server(Arc::new(server_config), restrictions) - .await - .unwrap_or_else(|err| { - panic!("Cannot start wstunnel server: {:?}", err); - }); + server.serve(restrictions).await.unwrap_or_else(|err| { + panic!("Cannot start wstunnel server: {:?}", err); + }); } } diff --git a/src/protocols/tls/server.rs b/src/protocols/tls/server.rs index 5e3d24c..13f627d 100644 --- a/src/protocols/tls/server.rs +++ b/src/protocols/tls/server.rs @@ -1,4 +1,3 @@ -use crate::TlsServerConfig; use anyhow::{anyhow, Context}; use std::fs::File; @@ -10,6 +9,7 @@ use tokio::net::TcpStream; use tokio_rustls::client::TlsStream; use crate::tunnel::client::WsClientConfig; +use crate::tunnel::server::TlsServerConfig; use crate::tunnel::TransportAddr; use tokio_rustls::rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 7d96978..bac5dfc 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -2,7 +2,7 @@ pub mod client; pub mod connectors; pub mod listeners; pub mod server; -pub mod tls_reloader; +mod tls_reloader; mod transport; use crate::{LocalProtocol, TlsClientConfig}; diff --git a/src/tunnel/server/mod.rs b/src/tunnel/server/mod.rs new file mode 100644 index 0000000..2d22698 --- /dev/null +++ b/src/tunnel/server/mod.rs @@ -0,0 +1,6 @@ +#![allow(clippy::module_inception)] +mod server; + +pub use server::TlsServerConfig; +pub use server::WsServer; +pub use server::WsServerConfig; diff --git a/src/tunnel/server.rs b/src/tunnel/server/server.rs similarity index 54% rename from src/tunnel/server.rs rename to src/tunnel/server/server.rs index d9eef43..41496e6 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server/server.rs @@ -5,16 +5,19 @@ use futures_util::{pin_mut, FutureExt, StreamExt}; use http_body_util::combinators::BoxBody; use http_body_util::{BodyStream, Either, StreamBody}; use std::cmp::min; +use std::fmt; +use std::fmt::{Debug, Formatter}; use std::future::Future; use std::net::{IpAddr, SocketAddr}; use std::ops::Deref; +use std::path::PathBuf; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; -use super::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; -use crate::{protocols, LocalProtocol, TlsServerConfig, WsServerConfig}; +use crate::tunnel::{transport, tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; +use crate::{protocols, LocalProtocol}; use hyper::body::{Frame, Incoming}; use hyper::header::{CONTENT_TYPE, COOKIE, SEC_WEBSOCKET_PROTOCOL}; use hyper::http::HeaderValue; @@ -27,6 +30,7 @@ use once_cell::sync::Lazy; use parking_lot::Mutex; use socket2::SockRef; +use crate::protocols::dns::DnsResolver; use crate::protocols::tls; use crate::protocols::udp::{UdpStream, UdpStreamWriter}; use crate::restrictions::config_reloader::RestrictionsRulesReloader; @@ -44,153 +48,407 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpListener; use tokio::select; use tokio::sync::{mpsc, oneshot}; +use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; use tokio_rustls::TlsAcceptor; use tokio_stream::wrappers::ReceiverStream; use tracing::{error, info, span, warn, Instrument, Level, Span}; use url::Host; use uuid::Uuid; -async fn run_tunnel( - server_config: &WsServerConfig, - restriction: &RestrictionConfig, - remote: RemoteAddr, - client_address: SocketAddr, -) -> anyhow::Result<(RemoteAddr, Pin>, Pin>)> { - match remote.protocol { - LocalProtocol::Udp { timeout, .. } => { - let (rx, tx) = UdpTunnelConnector::new( - &remote.host, - remote.port, - server_config.socket_so_mark, - timeout.unwrap_or(Duration::from_secs(10)), - &server_config.dns_resolver, - ) - .connect(&None) - .await?; +#[derive(Debug)] +pub struct TlsServerConfig { + pub tls_certificate: Mutex>>, + pub tls_key: Mutex>, + pub tls_client_ca_certificates: Option>>>, + pub tls_certificate_path: Option, + pub tls_key_path: Option, + pub tls_client_ca_certs_path: Option, +} - Ok((remote, Box::pin(rx), Box::pin(tx))) +pub struct WsServerConfig { + pub socket_so_mark: Option, + pub bind: SocketAddr, + pub websocket_ping_frequency: Option, + pub timeout_connect: Duration, + pub websocket_mask_frame: bool, + pub tls: Option, + pub dns_resolver: DnsResolver, + pub restriction_config: Option, +} + +#[derive(Clone)] +pub struct WsServer { + pub config: Arc, +} + +impl WsServer { + pub fn new(config: WsServerConfig) -> Self { + Self { + config: Arc::new(config), } - LocalProtocol::Tcp { proxy_protocol } => { - let (rx, mut tx) = TcpTunnelConnector::new( - &remote.host, - remote.port, - server_config.socket_so_mark, - Duration::from_secs(10), - &server_config.dns_resolver, - ) - .connect(&None) - .await?; + } - if proxy_protocol { - let header = ppp::v2::Builder::with_addresses( - ppp::v2::Version::Two | ppp::v2::Command::Proxy, - ppp::v2::Protocol::Stream, - (client_address, tx.local_addr().unwrap()), + async fn run_tunnel( + &self, + restriction: &RestrictionConfig, + remote: RemoteAddr, + client_address: SocketAddr, + ) -> anyhow::Result<(RemoteAddr, Pin>, Pin>)> { + match remote.protocol { + LocalProtocol::Udp { timeout, .. } => { + let (rx, tx) = UdpTunnelConnector::new( + &remote.host, + remote.port, + self.config.socket_so_mark, + timeout.unwrap_or(Duration::from_secs(10)), + &self.config.dns_resolver, ) - .build() - .unwrap(); - let _ = tx.write_all(&header).await; + .connect(&None) + .await?; + + Ok((remote, Box::pin(rx), Box::pin(tx))) + } + LocalProtocol::Tcp { proxy_protocol } => { + let (rx, mut tx) = TcpTunnelConnector::new( + &remote.host, + remote.port, + self.config.socket_so_mark, + Duration::from_secs(10), + &self.config.dns_resolver, + ) + .connect(&None) + .await?; + + if proxy_protocol { + let header = ppp::v2::Builder::with_addresses( + ppp::v2::Version::Two | ppp::v2::Command::Proxy, + ppp::v2::Protocol::Stream, + (client_address, tx.local_addr().unwrap()), + ) + .build() + .unwrap(); + let _ = tx.write_all(&header).await; + } + + Ok((remote, Box::pin(rx), Box::pin(tx))) + } + LocalProtocol::ReverseTcp => { + type Item = ::OkReturn; + #[allow(clippy::type_complexity)] + static SERVERS: Lazy, u16), mpsc::Receiver>>> = + Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); + + let remote_port = find_mapped_port(remote.port, restriction); + let local_srv = (remote.host, remote_port); + let listening_server = async { + let bind = format!("{}:{}", local_srv.0, local_srv.1); + TcpTunnelListener::new(bind.parse()?, local_srv.clone(), false).await + }; + let ((local_rx, local_tx), remote) = + run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; + + Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) + } + LocalProtocol::ReverseUdp { timeout } => { + type Item = ((UdpStream, UdpStreamWriter), RemoteAddr); + #[allow(clippy::type_complexity)] + static SERVERS: Lazy, u16), mpsc::Receiver>>> = + Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); + + let remote_port = find_mapped_port(remote.port, restriction); + let local_srv = (remote.host, remote_port); + let listening_server = async { + let bind = format!("{}:{}", local_srv.0, local_srv.1); + new_udp_listener(bind.parse()?, local_srv.clone(), timeout).await + }; + let ((local_rx, local_tx), remote) = + run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; + Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) + } + LocalProtocol::ReverseSocks5 { timeout, credentials } => { + type Item = ::OkReturn; + #[allow(clippy::type_complexity)] + static SERVERS: Lazy, u16), mpsc::Receiver>>> = + Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); + + let remote_port = find_mapped_port(remote.port, restriction); + let local_srv = (remote.host, remote_port); + let listening_server = async { + let bind = format!("{}:{}", local_srv.0, local_srv.1); + Socks5TunnelListener::new(bind.parse()?, timeout, credentials).await + }; + let ((local_rx, local_tx), remote) = + run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; + + Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) + } + LocalProtocol::ReverseHttpProxy { timeout, credentials } => { + type Item = ::OkReturn; + #[allow(clippy::type_complexity)] + static SERVERS: Lazy, u16), mpsc::Receiver>>> = + Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); + + let remote_port = find_mapped_port(remote.port, restriction); + let local_srv = (remote.host, remote_port); + let listening_server = async { + let bind = format!("{}:{}", local_srv.0, local_srv.1); + HttpProxyTunnelListener::new(bind.parse()?, timeout, credentials, false).await + }; + let ((local_rx, local_tx), remote) = + run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; + + Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) + } + #[cfg(unix)] + LocalProtocol::ReverseUnix { ref path } => { + use crate::tunnel::listeners::UnixTunnelListener; + type Item = ::OkReturn; + #[allow(clippy::type_complexity)] + static SERVERS: Lazy, u16), mpsc::Receiver>>> = + Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); + + let remote_port = find_mapped_port(remote.port, restriction); + let local_srv = (remote.host, remote_port); + let listening_server = async { UnixTunnelListener::new(path, local_srv.clone(), false).await }; + let ((local_rx, local_tx), remote) = + run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; + + Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) + } + #[cfg(not(unix))] + LocalProtocol::ReverseUnix { .. } => { + error!("Received an unsupported target protocol {:?}", remote); + Err(anyhow::anyhow!("Invalid upgrade request")) + } + LocalProtocol::Stdio + | LocalProtocol::Socks5 { .. } + | LocalProtocol::TProxyTcp + | LocalProtocol::TProxyUdp { .. } + | LocalProtocol::HttpProxy { .. } + | LocalProtocol::Unix { .. } => { + error!("Received an unsupported target protocol {:?}", remote); + Err(anyhow::anyhow!("Invalid upgrade request")) + } + } + } + + pub async fn serve(self, restrictions: RestrictionsRules) -> anyhow::Result<()> { + info!("Starting wstunnel server listening on {}", self.config.bind); + + // setup upgrade request handler + let mk_websocket_upgrade_fn = |server: WsServer, + restrictions: Arc, + restrict_path: Option, + client_addr: SocketAddr| { + move |req: Request| { + ws_server_upgrade(server.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) + .map::, _>(Ok) + } + }; + + let mk_http_upgrade_fn = |server: WsServer, + restrictions: Arc, + restrict_path: Option, + client_addr: SocketAddr| { + move |req: Request| { + http_server_upgrade(server.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) + .map::, _>(Ok) + } + }; + + let mk_auto_upgrade_fn = |server: WsServer, + restrictions: Arc, + restrict_path: Option, + client_addr: SocketAddr| { + move |req: Request| { + let server = server.clone(); + let restrictions = restrictions.clone(); + let restrict_path = restrict_path.clone(); + async move { + if fastwebsockets::upgrade::is_upgrade_request(&req) { + ws_server_upgrade(server.clone(), restrictions.clone(), restrict_path, client_addr, req) + .map(|response| Ok::<_, anyhow::Error>(response.map(Either::Left))) + .await + } else if req.version() == Version::HTTP_2 { + http_server_upgrade( + server.clone(), + restrictions.clone(), + restrict_path.clone(), + client_addr, + req, + ) + .map::, _>(Ok) + .await + } else { + error!("Invalid protocol version request, got {:?} while expecting either websocket http1 upgrade or http2", req.version()); + Ok(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Either::Left("Invalid protocol request".to_string())) + .unwrap()) + } + } + } + }; + + // Init TLS if needed + let mut tls_context = if let Some(tls_config) = &self.config.tls { + let tls_context = TlsContext { + tls_acceptor: Arc::new(tls::tls_acceptor( + tls_config, + Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]), + )?), + tls_reloader: TlsReloader::new_for_server(self.config.clone())?, + tls_config, + }; + Some(tls_context) + } else { + None + }; + + // Bind server and run forever to serve incoming connections. + let mut restrictions = RestrictionsRulesReloader::new(restrictions, self.config.restriction_config.clone())?; + let mut await_config_reload = Box::pin(restrictions.reload_notifier()); + let listener = TcpListener::bind(&self.config.bind).await?; + + loop { + let cnx = select! { + biased; + + _ = &mut await_config_reload => { + drop(await_config_reload); + restrictions.reload_restrictions_config(); + await_config_reload = Box::pin(restrictions.reload_notifier()); + continue; + }, + + cnx = listener.accept() => { cnx } + }; + + let (stream, peer_addr) = match cnx { + Ok(ret) => ret, + Err(err) => { + warn!("Error while accepting connection {:?}", err); + continue; + } + }; + + if let Err(err) = protocols::tcp::configure_socket(SockRef::from(&stream), &None) { + warn!("Error while configuring server socket {:?}", err); } - Ok((remote, Box::pin(rx), Box::pin(tx))) + let span = span!( + Level::INFO, + "tunnel", + id = tracing::field::Empty, + remote = tracing::field::Empty, + peer = peer_addr.to_string(), + forwarded_for = tracing::field::Empty + ); + + info!("Accepting connection"); + let server = self.clone(); + let restrictions = restrictions.restrictions_rules().clone(); + + // Check if we need to enable TLS or not + match tls_context.as_mut() { + Some(tls) => { + // Reload TLS certificate if needed + let tls_acceptor = tls.tls_acceptor().clone(); + let fut = async move { + info!("Doing TLS handshake"); + let tls_stream = match tls_acceptor.accept(stream).await { + Ok(tls_stream) => hyper_util::rt::TokioIo::new(tls_stream), + Err(err) => { + error!("error while accepting TLS connection {}", err); + return; + } + }; + + let tls_ctx = tls_stream.inner().get_ref().1; + // extract client certificate common name if any + let restrict_path = tls_ctx + .peer_certificates() + .and_then(tls::find_leaf_certificate) + .and_then(|c| tls::cn_from_certificate(&c)); + match tls_ctx.alpn_protocol() { + // http2 + Some(b"h2") => { + let mut conn_builder = http2::Builder::new(TokioExecutor::new()); + if let Some(ping) = server.config.websocket_ping_frequency { + conn_builder.keep_alive_interval(ping); + } + + let http_upgrade_fn = + mk_http_upgrade_fn(server, restrictions.clone(), restrict_path, peer_addr); + let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn)); + if let Err(e) = con_fut.await { + error!("Error while upgrading cnx to http: {:?}", e); + } + } + // websocket + _ => { + let websocket_upgrade_fn = + mk_websocket_upgrade_fn(server, restrictions.clone(), restrict_path, peer_addr); + let conn_fut = http1::Builder::new() + .serve_connection(tls_stream, service_fn(websocket_upgrade_fn)) + .with_upgrades(); + + if let Err(e) = conn_fut.await { + error!("Error while upgrading cnx: {:?}", e); + } + } + }; + } + .instrument(span); + + tokio::spawn(fut); + // Normal + } + // HTTP without TLS + None => { + let fut = async move { + let stream = hyper_util::rt::TokioIo::new(stream); + let mut conn_fut = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + if let Some(ping) = server.config.websocket_ping_frequency { + conn_fut.http2().keep_alive_interval(ping); + } + + let websocket_upgrade_fn = mk_auto_upgrade_fn(server, restrictions.clone(), None, peer_addr); + let upgradable = + conn_fut.serve_connection_with_upgrades(stream, service_fn(websocket_upgrade_fn)); + + if let Err(e) = upgradable.await { + error!("Error while upgrading cnx to websocket: {:?}", e); + } + } + .instrument(span); + + tokio::spawn(fut); + } + } } - LocalProtocol::ReverseTcp => { - type Item = ::OkReturn; - #[allow(clippy::type_complexity)] - static SERVERS: Lazy, u16), mpsc::Receiver>>> = - Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); + } +} - let remote_port = find_mapped_port(remote.port, restriction); - let local_srv = (remote.host, remote_port); - let listening_server = async { - let bind = format!("{}:{}", local_srv.0, local_srv.1); - TcpTunnelListener::new(bind.parse()?, local_srv.clone(), false).await - }; - let ((local_rx, local_tx), remote) = - run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; - - Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) - } - LocalProtocol::ReverseUdp { timeout } => { - type Item = ((UdpStream, UdpStreamWriter), RemoteAddr); - #[allow(clippy::type_complexity)] - static SERVERS: Lazy, u16), mpsc::Receiver>>> = - Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); - - let remote_port = find_mapped_port(remote.port, restriction); - let local_srv = (remote.host, remote_port); - let listening_server = async { - let bind = format!("{}:{}", local_srv.0, local_srv.1); - new_udp_listener(bind.parse()?, local_srv.clone(), timeout).await - }; - let ((local_rx, local_tx), remote) = - run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; - Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) - } - LocalProtocol::ReverseSocks5 { timeout, credentials } => { - type Item = ::OkReturn; - #[allow(clippy::type_complexity)] - static SERVERS: Lazy, u16), mpsc::Receiver>>> = - Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); - - let remote_port = find_mapped_port(remote.port, restriction); - let local_srv = (remote.host, remote_port); - let listening_server = async { - let bind = format!("{}:{}", local_srv.0, local_srv.1); - Socks5TunnelListener::new(bind.parse()?, timeout, credentials).await - }; - let ((local_rx, local_tx), remote) = - run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; - - Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) - } - LocalProtocol::ReverseHttpProxy { timeout, credentials } => { - type Item = ::OkReturn; - #[allow(clippy::type_complexity)] - static SERVERS: Lazy, u16), mpsc::Receiver>>> = - Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); - - let remote_port = find_mapped_port(remote.port, restriction); - let local_srv = (remote.host, remote_port); - let listening_server = async { - let bind = format!("{}:{}", local_srv.0, local_srv.1); - HttpProxyTunnelListener::new(bind.parse()?, timeout, credentials, false).await - }; - let ((local_rx, local_tx), remote) = - run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; - - Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) - } - #[cfg(unix)] - LocalProtocol::ReverseUnix { ref path } => { - use crate::tunnel::listeners::UnixTunnelListener; - type Item = ::OkReturn; - #[allow(clippy::type_complexity)] - static SERVERS: Lazy, u16), mpsc::Receiver>>> = - Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); - - let remote_port = find_mapped_port(remote.port, restriction); - let local_srv = (remote.host, remote_port); - let listening_server = async { UnixTunnelListener::new(path, local_srv.clone(), false).await }; - let ((local_rx, local_tx), remote) = - run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; - - Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) - } - #[cfg(not(unix))] - LocalProtocol::ReverseUnix { .. } => { - error!("Received an unsupported target protocol {:?}", remote); - Err(anyhow::anyhow!("Invalid upgrade request")) - } - LocalProtocol::Stdio - | LocalProtocol::Socks5 { .. } - | LocalProtocol::TProxyTcp - | LocalProtocol::TProxyUdp { .. } - | LocalProtocol::HttpProxy { .. } - | LocalProtocol::Unix { .. } => { - error!("Received an unsupported target protocol {:?}", remote); - Err(anyhow::anyhow!("Invalid upgrade request")) - } +impl Debug for WsServerConfig { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("WsServerConfig") + .field("socket_so_mark", &self.socket_so_mark) + .field("bind", &self.bind) + .field("websocket_ping_frequency", &self.websocket_ping_frequency) + .field("timeout_connect", &self.timeout_connect) + .field("websocket_mask_frame", &self.websocket_mask_frame) + .field("restriction_config", &self.restriction_config) + .field("tls", &self.tls.is_some()) + .field( + "mTLS", + &self + .tls + .as_ref() + .map(|x| x.tls_client_ca_certificates.is_some()) + .unwrap_or(false), + ) + .finish() } } @@ -396,7 +654,7 @@ fn validate_tunnel<'a>( } async fn ws_server_upgrade( - server_config: Arc, + server: WsServer, restrictions: Arc, restrict_path_prefix: Option, mut client_addr: SocketAddr, @@ -466,7 +724,7 @@ async fn ws_server_upgrade( }; let req_protocol = remote.protocol.clone(); - let tunnel = match run_tunnel(&server_config, restriction, remote, client_addr).await { + let tunnel = match server.run_tunnel(restriction, remote, client_addr).await { Ok(ret) => ret, Err(err) => { warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); @@ -500,20 +758,16 @@ async fn ws_server_upgrade( } }; let (close_tx, close_rx) = oneshot::channel::<()>(); - ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame); + ws_tx.set_auto_apply_mask(server.config.websocket_mask_frame); tokio::task::spawn( - super::transport::io::propagate_remote_to_local(local_tx, WebsocketTunnelRead::new(ws_rx), close_rx) + transport::io::propagate_remote_to_local(local_tx, WebsocketTunnelRead::new(ws_rx), close_rx) .instrument(Span::current()), ); - let _ = super::transport::io::propagate_local_to_remote( - local_rx, - WebsocketTunnelWrite::new(ws_tx), - close_tx, - None, - ) - .await; + let _ = + transport::io::propagate_local_to_remote(local_rx, WebsocketTunnelWrite::new(ws_tx), close_tx, None) + .await; } .instrument(Span::current()), ); @@ -539,7 +793,7 @@ async fn ws_server_upgrade( } async fn http_server_upgrade( - server_config: Arc, + server: WsServer, restrictions: Arc, restrict_path_prefix: Option, mut client_addr: SocketAddr, @@ -600,7 +854,7 @@ async fn http_server_upgrade( }; let req_protocol = remote.protocol.clone(); - let tunnel = match run_tunnel(&server_config, restriction, remote, client_addr).await { + let tunnel = match server.run_tunnel(restriction, remote, client_addr).await { Ok(ret) => ret, Err(err) => { warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); @@ -630,13 +884,12 @@ async fn http_server_upgrade( async move { let (close_tx, close_rx) = oneshot::channel::<()>(); tokio::task::spawn( - super::transport::io::propagate_remote_to_local(local_tx, Http2TunnelRead::new(ws_rx), close_rx) + transport::io::propagate_remote_to_local(local_tx, Http2TunnelRead::new(ws_rx), close_rx) .instrument(Span::current()), ); let _ = - super::transport::io::propagate_local_to_remote(local_rx, Http2TunnelWrite::new(ws_tx), close_tx, None) - .await; + transport::io::propagate_local_to_remote(local_rx, Http2TunnelWrite::new(ws_tx), close_tx, None).await; } .instrument(Span::current()), ); @@ -678,211 +931,6 @@ impl TlsContext<'_> { } } -pub async fn run_server(server_config: Arc, restrictions: RestrictionsRules) -> anyhow::Result<()> { - info!("Starting wstunnel server listening on {}", server_config.bind); - - // setup upgrade request handler - let mk_websocket_upgrade_fn = |server_config: Arc, - restrictions: Arc, - restrict_path: Option, - client_addr: SocketAddr| { - move |req: Request| { - ws_server_upgrade( - server_config.clone(), - restrictions.clone(), - restrict_path.clone(), - client_addr, - req, - ) - .map::, _>(Ok) - } - }; - - let mk_http_upgrade_fn = |server_config: Arc, - restrictions: Arc, - restrict_path: Option, - client_addr: SocketAddr| { - move |req: Request| { - http_server_upgrade( - server_config.clone(), - restrictions.clone(), - restrict_path.clone(), - client_addr, - req, - ) - .map::, _>(Ok) - } - }; - - let mk_auto_upgrade_fn = |server_config: Arc, - restrictions: Arc, - restrict_path: Option, - client_addr: SocketAddr| { - move |req: Request| { - let server_config = server_config.clone(); - let restrictions = restrictions.clone(); - let restrict_path = restrict_path.clone(); - async move { - if fastwebsockets::upgrade::is_upgrade_request(&req) { - ws_server_upgrade(server_config.clone(), restrictions.clone(), restrict_path, client_addr, req) - .map(|response| Ok::<_, anyhow::Error>(response.map(Either::Left))) - .await - } else if req.version() == Version::HTTP_2 { - http_server_upgrade( - server_config.clone(), - restrictions.clone(), - restrict_path.clone(), - client_addr, - req, - ) - .map::, _>(Ok) - .await - } else { - error!("Invalid protocol version request, got {:?} while expecting either websocket http1 upgrade or http2", req.version()); - Ok(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Either::Left("Invalid protocol request".to_string())) - .unwrap()) - } - } - } - }; - - // Init TLS if needed - let mut tls_context = if let Some(tls_config) = &server_config.tls { - let tls_context = TlsContext { - tls_acceptor: Arc::new(tls::tls_acceptor(tls_config, Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]))?), - tls_reloader: TlsReloader::new_for_server(server_config.clone())?, - tls_config, - }; - Some(tls_context) - } else { - None - }; - - // Bind server and run forever to serve incoming connections. - let mut restrictions = RestrictionsRulesReloader::new(restrictions, server_config.restriction_config.clone())?; - let mut await_config_reload = Box::pin(restrictions.reload_notifier()); - let listener = TcpListener::bind(&server_config.bind).await?; - - loop { - let cnx = select! { - biased; - - _ = &mut await_config_reload => { - drop(await_config_reload); - restrictions.reload_restrictions_config(); - await_config_reload = Box::pin(restrictions.reload_notifier()); - continue; - }, - - cnx = listener.accept() => { cnx } - }; - - let (stream, peer_addr) = match cnx { - Ok(ret) => ret, - Err(err) => { - warn!("Error while accepting connection {:?}", err); - continue; - } - }; - - if let Err(err) = protocols::tcp::configure_socket(SockRef::from(&stream), &None) { - warn!("Error while configuring server socket {:?}", err); - } - - let span = span!( - Level::INFO, - "tunnel", - id = tracing::field::Empty, - remote = tracing::field::Empty, - peer = peer_addr.to_string(), - forwarded_for = tracing::field::Empty - ); - - info!("Accepting connection"); - let server_config = server_config.clone(); - let restrictions = restrictions.restrictions_rules().clone(); - - // Check if we need to enable TLS or not - match tls_context.as_mut() { - Some(tls) => { - // Reload TLS certificate if needed - let tls_acceptor = tls.tls_acceptor().clone(); - let fut = async move { - info!("Doing TLS handshake"); - let tls_stream = match tls_acceptor.accept(stream).await { - Ok(tls_stream) => hyper_util::rt::TokioIo::new(tls_stream), - Err(err) => { - error!("error while accepting TLS connection {}", err); - return; - } - }; - - let tls_ctx = tls_stream.inner().get_ref().1; - // extract client certificate common name if any - let restrict_path = tls_ctx - .peer_certificates() - .and_then(tls::find_leaf_certificate) - .and_then(|c| tls::cn_from_certificate(&c)); - match tls_ctx.alpn_protocol() { - // http2 - Some(b"h2") => { - let mut conn_builder = http2::Builder::new(TokioExecutor::new()); - if let Some(ping) = server_config.websocket_ping_frequency { - conn_builder.keep_alive_interval(ping); - } - - let http_upgrade_fn = - mk_http_upgrade_fn(server_config, restrictions.clone(), restrict_path, peer_addr); - let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn)); - if let Err(e) = con_fut.await { - error!("Error while upgrading cnx to http: {:?}", e); - } - } - // websocket - _ => { - let websocket_upgrade_fn = - mk_websocket_upgrade_fn(server_config, restrictions.clone(), restrict_path, peer_addr); - let conn_fut = http1::Builder::new() - .serve_connection(tls_stream, service_fn(websocket_upgrade_fn)) - .with_upgrades(); - - if let Err(e) = conn_fut.await { - error!("Error while upgrading cnx: {:?}", e); - } - } - }; - } - .instrument(span); - - tokio::spawn(fut); - // Normal - } - // HTTP without TLS - None => { - let fut = async move { - let stream = hyper_util::rt::TokioIo::new(stream); - let mut conn_fut = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); - if let Some(ping) = server_config.websocket_ping_frequency { - conn_fut.http2().keep_alive_interval(ping); - } - - let websocket_upgrade_fn = mk_auto_upgrade_fn(server_config, restrictions.clone(), None, peer_addr); - let upgradable = conn_fut.serve_connection_with_upgrades(stream, service_fn(websocket_upgrade_fn)); - - if let Err(e) = upgradable.await { - error!("Error while upgrading cnx to websocket: {:?}", e); - } - } - .instrument(span); - - tokio::spawn(fut); - } - } - } -} - #[allow(clippy::type_complexity)] async fn run_listening_server( local_srv: &(Host, u16), diff --git a/src/tunnel/tls_reloader.rs b/src/tunnel/tls_reloader.rs index 8af275b..2cd7346 100644 --- a/src/tunnel/tls_reloader.rs +++ b/src/tunnel/tls_reloader.rs @@ -1,7 +1,7 @@ use crate::protocols::tls; use crate::tunnel::client::WsClientConfig; +use crate::tunnel::server::WsServerConfig; use crate::tunnel::tls_reloader::TlsReloaderState::{Client, Server}; -use crate::WsServerConfig; use anyhow::Context; use log::trace; use notify::{EventKind, RecommendedWatcher, Watcher};