fix(#368): Update the restrictions before each request

instead of only once per conections, to avoid using a stale
restriction config when multiple request arrive on the same tcp stream.
This commit is contained in:
Σrebe - Romain GERARD 2024-10-12 19:06:29 +02:00
parent 6ae1eae4a6
commit 4b43dfc268
4 changed files with 82 additions and 84 deletions

7
Cargo.lock generated
View file

@ -109,6 +109,12 @@ version = "1.0.89"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6"
[[package]]
name = "arc-swap"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
[[package]] [[package]]
name = "asn1-rs" name = "asn1-rs"
version = "0.6.2" version = "0.6.2"
@ -3121,6 +3127,7 @@ version = "10.1.4"
dependencies = [ dependencies = [
"ahash", "ahash",
"anyhow", "anyhow",
"arc-swap",
"async-channel", "async-channel",
"async-trait", "async-trait",
"base64 0.22.1", "base64 0.22.1",

View file

@ -21,6 +21,7 @@ futures-util = { version = "0.3.30" }
hickory-resolver = { version = "0.24.1", features = ["tokio", "dns-over-https-rustls", "dns-over-rustls", "native-certs"] } hickory-resolver = { version = "0.24.1", features = ["tokio", "dns-over-https-rustls", "dns-over-rustls", "native-certs"] }
ppp = { version = "2.2.0", features = [] } ppp = { version = "2.2.0", features = [] }
async-channel = { version = "2.3.1", features = [] } async-channel = { version = "2.3.1", features = [] }
arc-swap = { version = "1.7.1", features = [] }
# For config file parsing # For config file parsing
regex = { version = "1.11.0", default-features = false, features = ["std", "perf"] } regex = { version = "1.11.0", default-features = false, features = ["std", "perf"] }

View file

@ -1,6 +1,7 @@
use super::types::RestrictionsRules; use super::types::RestrictionsRules;
use crate::restrictions::config_reloader::RestrictionsRulesReloaderState::{Config, Static}; use crate::restrictions::config_reloader::RestrictionsRulesReloaderState::{Config, Static};
use anyhow::Context; use anyhow::Context;
use arc_swap::ArcSwap;
use log::trace; use log::trace;
use notify::{EventKind, RecommendedWatcher, Watcher}; use notify::{EventKind, RecommendedWatcher, Watcher};
use parking_lot::Mutex; use parking_lot::Mutex;
@ -8,33 +9,32 @@ use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
use tokio::sync::futures::Notified;
use tokio::sync::Notify;
use tracing::{error, info, warn}; use tracing::{error, info, warn};
struct ConfigReloaderState { struct ConfigReloaderState {
fs_watcher: Mutex<RecommendedWatcher>, fs_watcher: Mutex<RecommendedWatcher>,
config_path: PathBuf, config_path: PathBuf,
should_reload_config: Notify,
} }
#[derive(Clone)]
enum RestrictionsRulesReloaderState { enum RestrictionsRulesReloaderState {
Static(Notify), Static,
Config(Arc<ConfigReloaderState>), Config(Arc<ConfigReloaderState>),
} }
impl RestrictionsRulesReloaderState { impl RestrictionsRulesReloaderState {
fn fs_watcher(&self) -> &Mutex<RecommendedWatcher> { fn fs_watcher(&self) -> &Mutex<RecommendedWatcher> {
match self { match self {
Static(_) => unreachable!(), Static => unreachable!(),
Config(this) => &this.fs_watcher, Config(this) => &this.fs_watcher,
} }
} }
} }
#[derive(Clone)]
pub struct RestrictionsRulesReloader { pub struct RestrictionsRulesReloader {
state: RestrictionsRulesReloaderState, state: RestrictionsRulesReloaderState,
restrictions: Arc<RestrictionsRules>, restrictions: Arc<ArcSwap<RestrictionsRules>>,
} }
impl RestrictionsRulesReloader { impl RestrictionsRulesReloader {
@ -44,37 +44,40 @@ impl RestrictionsRulesReloader {
config_path config_path
} else { } else {
return Ok(Self { return Ok(Self {
state: Static(Notify::new()), state: Static,
restrictions: Arc::new(restrictions_rules), restrictions: Arc::new(ArcSwap::from_pointee(restrictions_rules)),
}); });
}; };
let reloader = Self {
let this = Arc::new(ConfigReloaderState { state: Config(Arc::new(ConfigReloaderState {
fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?), fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?),
should_reload_config: Notify::new(), config_path,
config_path, })),
}); restrictions: Arc::new(ArcSwap::from_pointee(restrictions_rules)),
};
info!("Starting to watch restriction config file for changes to reload them"); info!("Starting to watch restriction config file for changes to reload them");
let mut watcher = notify::recommended_watcher({ let mut watcher = notify::recommended_watcher({
let this = Config(this.clone()); let reloader = reloader.clone();
move |event: notify::Result<notify::Event>| Self::handle_config_fs_event(&this, event) move |event: notify::Result<notify::Event>| Self::handle_config_fs_event(&reloader, event)
}) })
.with_context(|| "Cannot create restriction config watcher")?; .with_context(|| "Cannot create restriction config watcher")?;
watcher.watch(&this.config_path, notify::RecursiveMode::NonRecursive)?; match &reloader.state {
*this.fs_watcher.lock() = watcher; Static => {}
Config(cfg) => {
watcher.watch(&cfg.config_path, notify::RecursiveMode::NonRecursive)?;
*cfg.fs_watcher.lock() = watcher
}
}
Ok(Self { Ok(reloader)
state: Config(this),
restrictions: Arc::new(restrictions_rules),
})
} }
pub fn reload_restrictions_config(&mut self) { pub fn reload_restrictions_config(&self) {
let restrictions = match &self.state { let restrictions = match &self.state {
Static(_) => return, Static => return,
Config(st) => match RestrictionsRules::from_config_file(&st.config_path) { Config(st) => match RestrictionsRules::from_config_file(&st.config_path) {
Ok(restrictions) => { Ok(restrictions) => {
info!("Restrictions config file has been reloaded"); info!("Restrictions config file has been reloaded");
@ -87,21 +90,14 @@ impl RestrictionsRulesReloader {
}, },
}; };
self.restrictions = Arc::new(restrictions); self.restrictions.store(Arc::new(restrictions));
} }
pub const fn restrictions_rules(&self) -> &Arc<RestrictionsRules> { pub const fn restrictions_rules(&self) -> &Arc<ArcSwap<RestrictionsRules>> {
&self.restrictions &self.restrictions
} }
pub fn reload_notifier(&self) -> Notified { fn try_rewatch_config(this: RestrictionsRulesReloader, path: PathBuf) {
match &self.state {
Static(st) => st.notified(),
Config(st) => st.should_reload_config.notified(),
}
}
fn try_rewatch_config(this: RestrictionsRulesReloaderState, path: PathBuf) {
thread::spawn(move || { thread::spawn(move || {
while !path.exists() { while !path.exists() {
warn!( warn!(
@ -110,7 +106,7 @@ impl RestrictionsRulesReloader {
); );
thread::sleep(Duration::from_secs(10)); thread::sleep(Duration::from_secs(10));
} }
let mut watcher = this.fs_watcher().lock(); let mut watcher = this.state.fs_watcher().lock();
let _ = watcher.unwatch(&path); let _ = watcher.unwatch(&path);
let Ok(_) = watcher let Ok(_) = watcher
.watch(&path, notify::RecursiveMode::NonRecursive) .watch(&path, notify::RecursiveMode::NonRecursive)
@ -123,23 +119,20 @@ impl RestrictionsRulesReloader {
}; };
drop(watcher); drop(watcher);
// Generate a fake event to force-reload the certificate // Generate a fake event to force-reload the config
let event = notify::Event { let event = notify::Event {
kind: EventKind::Create(notify::event::CreateKind::Any), kind: EventKind::Create(notify::event::CreateKind::Any),
paths: vec![path], paths: vec![path],
attrs: Default::default(), attrs: Default::default(),
}; };
match &this { Self::handle_config_fs_event(&this, Ok(event))
Static(_) => Self::handle_config_fs_event(&this, Ok(event)),
Config(_) => Self::handle_config_fs_event(&this, Ok(event)),
}
}); });
} }
fn handle_config_fs_event(this: &RestrictionsRulesReloaderState, event: notify::Result<notify::Event>) { fn handle_config_fs_event(reloader: &RestrictionsRulesReloader, event: notify::Result<notify::Event>) {
let this = match this { let this = match &reloader.state {
Static(_) => return, Static => return,
Config(st) => st, Config(st) => st,
}; };
@ -159,11 +152,11 @@ impl RestrictionsRulesReloader {
if let Some(path) = event.paths.iter().find(|p| p.ends_with(&this.config_path)) { if let Some(path) = event.paths.iter().find(|p| p.ends_with(&this.config_path)) {
match event.kind { match event.kind {
EventKind::Create(_) | EventKind::Modify(_) => { EventKind::Create(_) | EventKind::Modify(_) => {
this.should_reload_config.notify_one(); reloader.reload_restrictions_config();
} }
EventKind::Remove(_) => { EventKind::Remove(_) => {
warn!("Restriction config file has been removed, trying to re-set a watch for it"); warn!("Restriction config file has been removed, trying to re-set a watch for it");
Self::try_rewatch_config(Config(this.clone()), path.to_path_buf()); Self::try_rewatch_config(reloader.clone(), path.to_path_buf());
} }
EventKind::Access(_) | EventKind::Other | EventKind::Any => { EventKind::Access(_) | EventKind::Other | EventKind::Any => {
trace!("Ignoring event {:?}", event); trace!("Ignoring event {:?}", event);

View file

@ -4,16 +4,11 @@ use http_body_util::Either;
use std::fmt; use std::fmt;
use std::fmt::{Debug, Formatter}; use std::fmt::{Debug, Formatter};
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::{Arc, LazyLock};
use std::time::Duration;
use crate::protocols; use crate::protocols;
use crate::tunnel::{try_to_sock_addr, LocalProtocol, RemoteAddr}; use crate::tunnel::{try_to_sock_addr, LocalProtocol, RemoteAddr};
use arc_swap::ArcSwap;
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use hyper::body::Incoming; use hyper::body::Incoming;
use hyper::server::conn::{http1, http2}; use hyper::server::conn::{http1, http2};
use hyper::service::service_fn; use hyper::service::service_fn;
@ -21,6 +16,11 @@ use hyper::{http, Request, Response, StatusCode, Version};
use hyper_util::rt::{TokioExecutor, TokioTimer}; use hyper_util::rt::{TokioExecutor, TokioTimer};
use parking_lot::Mutex; use parking_lot::Mutex;
use socket2::SockRef; use socket2::SockRef;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::{Arc, LazyLock};
use std::time::Duration;
use crate::protocols::dns::DnsResolver; use crate::protocols::dns::DnsResolver;
use crate::protocols::tls; use crate::protocols::tls;
@ -37,7 +37,6 @@ use crate::tunnel::server::utils::{
use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::tls_reloader::TlsReloader;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::select;
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use tracing::{error, info, span, warn, Instrument, Level, Span}; use tracing::{error, info, span, warn, Instrument, Level, Span};
@ -285,29 +284,41 @@ impl WsServer {
// setup upgrade request handler // setup upgrade request handler
let mk_websocket_upgrade_fn = |server: WsServer, let mk_websocket_upgrade_fn = |server: WsServer,
restrictions: Arc<RestrictionsRules>, restrictions: Arc<ArcSwap<RestrictionsRules>>,
restrict_path: Option<String>, restrict_path: Option<String>,
client_addr: SocketAddr| { client_addr: SocketAddr| {
move |req: Request<Incoming>| { move |req: Request<Incoming>| {
ws_server_upgrade(server.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) ws_server_upgrade(
.map::<anyhow::Result<_>, _>(Ok) server.clone(),
.instrument(mk_span()) restrictions.load().clone(),
restrict_path.clone(),
client_addr,
req,
)
.map::<anyhow::Result<_>, _>(Ok)
.instrument(mk_span())
} }
}; };
let mk_http_upgrade_fn = |server: WsServer, let mk_http_upgrade_fn = |server: WsServer,
restrictions: Arc<RestrictionsRules>, restrictions: Arc<ArcSwap<RestrictionsRules>>,
restrict_path: Option<String>, restrict_path: Option<String>,
client_addr: SocketAddr| { client_addr: SocketAddr| {
move |req: Request<Incoming>| { move |req: Request<Incoming>| {
http_server_upgrade(server.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req) http_server_upgrade(
.map::<anyhow::Result<_>, _>(Ok) server.clone(),
.instrument(mk_span()) restrictions.load().clone(),
restrict_path.clone(),
client_addr,
req,
)
.map::<anyhow::Result<_>, _>(Ok)
.instrument(mk_span())
} }
}; };
let mk_auto_upgrade_fn = |server: WsServer, let mk_auto_upgrade_fn = |server: WsServer,
restrictions: Arc<RestrictionsRules>, restrictions: Arc<ArcSwap<RestrictionsRules>>,
restrict_path: Option<String>, restrict_path: Option<String>,
client_addr: SocketAddr| { client_addr: SocketAddr| {
move |req: Request<Incoming>| { move |req: Request<Incoming>| {
@ -316,13 +327,13 @@ impl WsServer {
let restrict_path = restrict_path.clone(); let restrict_path = restrict_path.clone();
async move { async move {
if fastwebsockets::upgrade::is_upgrade_request(&req) { if fastwebsockets::upgrade::is_upgrade_request(&req) {
ws_server_upgrade(server.clone(), restrictions.clone(), restrict_path, client_addr, req) ws_server_upgrade(server.clone(), restrictions.load().clone(), restrict_path, client_addr, req)
.map::<anyhow::Result<_>, _>(Ok) .map::<anyhow::Result<_>, _>(Ok)
.await .await
} else if req.version() == Version::HTTP_2 { } else if req.version() == Version::HTTP_2 {
http_server_upgrade( http_server_upgrade(
server.clone(), server.clone(),
restrictions.clone(), restrictions.load().clone(),
restrict_path.clone(), restrict_path.clone(),
client_addr, client_addr,
req, req,
@ -357,25 +368,11 @@ impl WsServer {
}; };
// Bind server and run forever to serve incoming connections. // Bind server and run forever to serve incoming connections.
let mut restrictions = RestrictionsRulesReloader::new(restrictions, self.config.restriction_config.clone())?; let restrictions = RestrictionsRulesReloader::new(restrictions, self.config.restriction_config.clone())?;
let mut await_config_reload = Box::pin(restrictions.reload_notifier());
let listener = TcpListener::bind(&self.config.bind).await?; let listener = TcpListener::bind(&self.config.bind).await?;
loop { loop {
let cnx = select! { let (stream, peer_addr) = match listener.accept().await {
biased;
_ = &mut await_config_reload => {
drop(await_config_reload);
restrictions.reload_restrictions_config();
await_config_reload = Box::pin(restrictions.reload_notifier());
continue;
},
cnx = listener.accept() => { cnx }
};
let (stream, peer_addr) = match cnx {
Ok(ret) => ret, Ok(ret) => ret,
Err(err) => { Err(err) => {
warn!("Error while accepting connection {:?}", err); warn!("Error while accepting connection {:?}", err);
@ -423,7 +420,7 @@ impl WsServer {
} }
let http_upgrade_fn = let http_upgrade_fn =
mk_http_upgrade_fn(server, restrictions.clone(), restrict_path, peer_addr); mk_http_upgrade_fn(server, restrictions, restrict_path, peer_addr);
let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn)); let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn));
if let Err(e) = con_fut.await { if let Err(e) = con_fut.await {
error!("Error while upgrading cnx to http: {:?}", e); error!("Error while upgrading cnx to http: {:?}", e);
@ -432,7 +429,7 @@ impl WsServer {
// websocket // websocket
_ => { _ => {
let websocket_upgrade_fn = let websocket_upgrade_fn =
mk_websocket_upgrade_fn(server, restrictions.clone(), restrict_path, peer_addr); mk_websocket_upgrade_fn(server, restrictions, restrict_path, peer_addr);
let conn_fut = http1::Builder::new() let conn_fut = http1::Builder::new()
.timer(TokioTimer::new()) .timer(TokioTimer::new())
// https://github.com/erebe/wstunnel/issues/358 // https://github.com/erebe/wstunnel/issues/358
@ -460,7 +457,7 @@ impl WsServer {
conn_fut.http2().keep_alive_interval(ping); conn_fut.http2().keep_alive_interval(ping);
} }
let websocket_upgrade_fn = mk_auto_upgrade_fn(server, restrictions.clone(), None, peer_addr); let websocket_upgrade_fn = mk_auto_upgrade_fn(server, restrictions, None, peer_addr);
let upgradable = let upgradable =
conn_fut.serve_connection_with_upgrades(stream, service_fn(websocket_upgrade_fn)); conn_fut.serve_connection_with_upgrades(stream, service_fn(websocket_upgrade_fn));