diff --git a/Cargo.lock b/Cargo.lock index 120582a..7cbf7f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -109,6 +109,12 @@ version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "asn1-rs" version = "0.6.2" @@ -3121,6 +3127,7 @@ version = "10.1.4" dependencies = [ "ahash", "anyhow", + "arc-swap", "async-channel", "async-trait", "base64 0.22.1", diff --git a/Cargo.toml b/Cargo.toml index 7bb9e16..648d6ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ futures-util = { version = "0.3.30" } hickory-resolver = { version = "0.24.1", features = ["tokio", "dns-over-https-rustls", "dns-over-rustls", "native-certs"] } ppp = { version = "2.2.0", features = [] } async-channel = { version = "2.3.1", features = [] } +arc-swap = { version = "1.7.1", features = [] } # For config file parsing regex = { version = "1.11.0", default-features = false, features = ["std", "perf"] } diff --git a/src/restrictions/config_reloader.rs b/src/restrictions/config_reloader.rs index e31f6ba..99b22b7 100644 --- a/src/restrictions/config_reloader.rs +++ b/src/restrictions/config_reloader.rs @@ -1,6 +1,7 @@ use super::types::RestrictionsRules; use crate::restrictions::config_reloader::RestrictionsRulesReloaderState::{Config, Static}; use anyhow::Context; +use arc_swap::ArcSwap; use log::trace; use notify::{EventKind, RecommendedWatcher, Watcher}; use parking_lot::Mutex; @@ -8,33 +9,32 @@ use std::path::PathBuf; use std::sync::Arc; use std::thread; use std::time::Duration; -use tokio::sync::futures::Notified; -use tokio::sync::Notify; use tracing::{error, info, warn}; struct ConfigReloaderState { fs_watcher: Mutex, config_path: PathBuf, - should_reload_config: Notify, } +#[derive(Clone)] enum RestrictionsRulesReloaderState { - Static(Notify), + Static, Config(Arc), } impl RestrictionsRulesReloaderState { fn fs_watcher(&self) -> &Mutex { match self { - Static(_) => unreachable!(), + Static => unreachable!(), Config(this) => &this.fs_watcher, } } } +#[derive(Clone)] pub struct RestrictionsRulesReloader { state: RestrictionsRulesReloaderState, - restrictions: Arc, + restrictions: Arc>, } impl RestrictionsRulesReloader { @@ -44,37 +44,40 @@ impl RestrictionsRulesReloader { config_path } else { return Ok(Self { - state: Static(Notify::new()), - restrictions: Arc::new(restrictions_rules), + state: Static, + restrictions: Arc::new(ArcSwap::from_pointee(restrictions_rules)), }); }; - - let this = Arc::new(ConfigReloaderState { - fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?), - should_reload_config: Notify::new(), - config_path, - }); + let reloader = Self { + state: Config(Arc::new(ConfigReloaderState { + fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?), + config_path, + })), + restrictions: Arc::new(ArcSwap::from_pointee(restrictions_rules)), + }; info!("Starting to watch restriction config file for changes to reload them"); let mut watcher = notify::recommended_watcher({ - let this = Config(this.clone()); + let reloader = reloader.clone(); - move |event: notify::Result| Self::handle_config_fs_event(&this, event) + move |event: notify::Result| Self::handle_config_fs_event(&reloader, event) }) .with_context(|| "Cannot create restriction config watcher")?; - watcher.watch(&this.config_path, notify::RecursiveMode::NonRecursive)?; - *this.fs_watcher.lock() = watcher; + match &reloader.state { + Static => {} + Config(cfg) => { + watcher.watch(&cfg.config_path, notify::RecursiveMode::NonRecursive)?; + *cfg.fs_watcher.lock() = watcher + } + } - Ok(Self { - state: Config(this), - restrictions: Arc::new(restrictions_rules), - }) + Ok(reloader) } - pub fn reload_restrictions_config(&mut self) { + pub fn reload_restrictions_config(&self) { let restrictions = match &self.state { - Static(_) => return, + Static => return, Config(st) => match RestrictionsRules::from_config_file(&st.config_path) { Ok(restrictions) => { info!("Restrictions config file has been reloaded"); @@ -87,21 +90,14 @@ impl RestrictionsRulesReloader { }, }; - self.restrictions = Arc::new(restrictions); + self.restrictions.store(Arc::new(restrictions)); } - pub const fn restrictions_rules(&self) -> &Arc { + pub const fn restrictions_rules(&self) -> &Arc> { &self.restrictions } - pub fn reload_notifier(&self) -> Notified { - match &self.state { - Static(st) => st.notified(), - Config(st) => st.should_reload_config.notified(), - } - } - - fn try_rewatch_config(this: RestrictionsRulesReloaderState, path: PathBuf) { + fn try_rewatch_config(this: RestrictionsRulesReloader, path: PathBuf) { thread::spawn(move || { while !path.exists() { warn!( @@ -110,7 +106,7 @@ impl RestrictionsRulesReloader { ); thread::sleep(Duration::from_secs(10)); } - let mut watcher = this.fs_watcher().lock(); + let mut watcher = this.state.fs_watcher().lock(); let _ = watcher.unwatch(&path); let Ok(_) = watcher .watch(&path, notify::RecursiveMode::NonRecursive) @@ -123,23 +119,20 @@ impl RestrictionsRulesReloader { }; drop(watcher); - // Generate a fake event to force-reload the certificate + // Generate a fake event to force-reload the config let event = notify::Event { kind: EventKind::Create(notify::event::CreateKind::Any), paths: vec![path], attrs: Default::default(), }; - match &this { - Static(_) => Self::handle_config_fs_event(&this, Ok(event)), - Config(_) => Self::handle_config_fs_event(&this, Ok(event)), - } + Self::handle_config_fs_event(&this, Ok(event)) }); } - fn handle_config_fs_event(this: &RestrictionsRulesReloaderState, event: notify::Result) { - let this = match this { - Static(_) => return, + fn handle_config_fs_event(reloader: &RestrictionsRulesReloader, event: notify::Result) { + let this = match &reloader.state { + Static => return, Config(st) => st, }; @@ -159,11 +152,11 @@ impl RestrictionsRulesReloader { if let Some(path) = event.paths.iter().find(|p| p.ends_with(&this.config_path)) { match event.kind { EventKind::Create(_) | EventKind::Modify(_) => { - this.should_reload_config.notify_one(); + reloader.reload_restrictions_config(); } EventKind::Remove(_) => { warn!("Restriction config file has been removed, trying to re-set a watch for it"); - Self::try_rewatch_config(Config(this.clone()), path.to_path_buf()); + Self::try_rewatch_config(reloader.clone(), path.to_path_buf()); } EventKind::Access(_) | EventKind::Other | EventKind::Any => { trace!("Ignoring event {:?}", event); diff --git a/src/tunnel/server/server.rs b/src/tunnel/server/server.rs index 357f6e8..929965f 100644 --- a/src/tunnel/server/server.rs +++ b/src/tunnel/server/server.rs @@ -4,16 +4,11 @@ use http_body_util::Either; use std::fmt; use std::fmt::{Debug, Formatter}; -use bytes::Bytes; -use http_body_util::combinators::BoxBody; -use std::net::SocketAddr; -use std::path::PathBuf; -use std::pin::Pin; -use std::sync::{Arc, LazyLock}; -use std::time::Duration; - use crate::protocols; use crate::tunnel::{try_to_sock_addr, LocalProtocol, RemoteAddr}; +use arc_swap::ArcSwap; +use bytes::Bytes; +use http_body_util::combinators::BoxBody; use hyper::body::Incoming; use hyper::server::conn::{http1, http2}; use hyper::service::service_fn; @@ -21,6 +16,11 @@ use hyper::{http, Request, Response, StatusCode, Version}; use hyper_util::rt::{TokioExecutor, TokioTimer}; use parking_lot::Mutex; use socket2::SockRef; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::pin::Pin; +use std::sync::{Arc, LazyLock}; +use std::time::Duration; use crate::protocols::dns::DnsResolver; use crate::protocols::tls; @@ -37,7 +37,6 @@ use crate::tunnel::server::utils::{ use crate::tunnel::tls_reloader::TlsReloader; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpListener; -use tokio::select; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; use tokio_rustls::TlsAcceptor; use tracing::{error, info, span, warn, Instrument, Level, Span}; @@ -285,29 +284,41 @@ impl WsServer { // setup upgrade request handler let mk_websocket_upgrade_fn = |server: WsServer, - restrictions: Arc, + restrictions: Arc>, restrict_path: Option, client_addr: SocketAddr| { move |req: Request| { - ws_server_upgrade(server.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) - .map::, _>(Ok) - .instrument(mk_span()) + ws_server_upgrade( + server.clone(), + restrictions.load().clone(), + restrict_path.clone(), + client_addr, + req, + ) + .map::, _>(Ok) + .instrument(mk_span()) } }; let mk_http_upgrade_fn = |server: WsServer, - restrictions: Arc, + restrictions: Arc>, restrict_path: Option, client_addr: SocketAddr| { move |req: Request| { - http_server_upgrade(server.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) - .map::, _>(Ok) - .instrument(mk_span()) + http_server_upgrade( + server.clone(), + restrictions.load().clone(), + restrict_path.clone(), + client_addr, + req, + ) + .map::, _>(Ok) + .instrument(mk_span()) } }; let mk_auto_upgrade_fn = |server: WsServer, - restrictions: Arc, + restrictions: Arc>, restrict_path: Option, client_addr: SocketAddr| { move |req: Request| { @@ -316,13 +327,13 @@ impl WsServer { let restrict_path = restrict_path.clone(); async move { if fastwebsockets::upgrade::is_upgrade_request(&req) { - ws_server_upgrade(server.clone(), restrictions.clone(), restrict_path, client_addr, req) + ws_server_upgrade(server.clone(), restrictions.load().clone(), restrict_path, client_addr, req) .map::, _>(Ok) .await } else if req.version() == Version::HTTP_2 { http_server_upgrade( server.clone(), - restrictions.clone(), + restrictions.load().clone(), restrict_path.clone(), client_addr, req, @@ -357,25 +368,11 @@ impl WsServer { }; // Bind server and run forever to serve incoming connections. - let mut restrictions = RestrictionsRulesReloader::new(restrictions, self.config.restriction_config.clone())?; - let mut await_config_reload = Box::pin(restrictions.reload_notifier()); + let restrictions = RestrictionsRulesReloader::new(restrictions, self.config.restriction_config.clone())?; let listener = TcpListener::bind(&self.config.bind).await?; loop { - let cnx = select! { - biased; - - _ = &mut await_config_reload => { - drop(await_config_reload); - restrictions.reload_restrictions_config(); - await_config_reload = Box::pin(restrictions.reload_notifier()); - continue; - }, - - cnx = listener.accept() => { cnx } - }; - - let (stream, peer_addr) = match cnx { + let (stream, peer_addr) = match listener.accept().await { Ok(ret) => ret, Err(err) => { warn!("Error while accepting connection {:?}", err); @@ -423,7 +420,7 @@ impl WsServer { } let http_upgrade_fn = - mk_http_upgrade_fn(server, restrictions.clone(), restrict_path, peer_addr); + mk_http_upgrade_fn(server, restrictions, restrict_path, peer_addr); let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn)); if let Err(e) = con_fut.await { error!("Error while upgrading cnx to http: {:?}", e); @@ -432,7 +429,7 @@ impl WsServer { // websocket _ => { let websocket_upgrade_fn = - mk_websocket_upgrade_fn(server, restrictions.clone(), restrict_path, peer_addr); + mk_websocket_upgrade_fn(server, restrictions, restrict_path, peer_addr); let conn_fut = http1::Builder::new() .timer(TokioTimer::new()) // https://github.com/erebe/wstunnel/issues/358 @@ -460,7 +457,7 @@ impl WsServer { conn_fut.http2().keep_alive_interval(ping); } - let websocket_upgrade_fn = mk_auto_upgrade_fn(server, restrictions.clone(), None, peer_addr); + let websocket_upgrade_fn = mk_auto_upgrade_fn(server, restrictions, None, peer_addr); let upgradable = conn_fut.serve_connection_with_upgrades(stream, service_fn(websocket_upgrade_fn));