feat(restriction): Auto-reload restriction file

This commit is contained in:
Σrebe - Romain GERARD 2024-05-01 12:07:18 +02:00
parent 368f6657fd
commit 5ef14d1a8c
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
4 changed files with 228 additions and 25 deletions

View file

@ -26,6 +26,7 @@ use jsonwebtoken::TokenData;
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use crate::restrictions::config_reloader::RestrictionsRulesReloader;
use crate::restrictions::types::{
AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol,
};
@ -418,6 +419,7 @@ fn validate_tunnel<'a>(
async fn ws_server_upgrade(
server_config: Arc<WsServerConfig>,
restrictions: Arc<RestrictionsRules>,
mut client_addr: SocketAddr,
mut req: Request<Incoming>,
) -> Response<String> {
@ -463,7 +465,7 @@ async fn ws_server_upgrade(
}
};
match validate_tunnel(&remote, path_prefix, &server_config.restrictions) {
match validate_tunnel(&remote, path_prefix, &restrictions) {
Ok(matched_restriction) => {
info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name);
}
@ -542,6 +544,7 @@ async fn ws_server_upgrade(
async fn http_server_upgrade(
server_config: Arc<WsServerConfig>,
restrictions: Arc<RestrictionsRules>,
mut client_addr: SocketAddr,
mut req: Request<Incoming>,
) -> Response<Either<String, BoxBody<Bytes, anyhow::Error>>> {
@ -578,7 +581,7 @@ async fn http_server_upgrade(
}
};
match validate_tunnel(&remote, path_prefix, &server_config.restrictions) {
match validate_tunnel(&remote, path_prefix, &restrictions) {
Ok(matched_restriction) => {
info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name);
}
@ -664,32 +667,39 @@ impl TlsContext<'_> {
}
}
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
pub async fn run_server(server_config: Arc<WsServerConfig>, restrictions: RestrictionsRules) -> anyhow::Result<()> {
info!("Starting wstunnel server listening on {}", server_config.bind);
// setup upgrade request handler
let mk_websocket_upgrade_fn = |server_config: Arc<WsServerConfig>, client_addr: SocketAddr| {
move |req: Request<Incoming>| {
ws_server_upgrade(server_config.clone(), client_addr, req).map::<anyhow::Result<_>, _>(Ok)
}
};
let mk_websocket_upgrade_fn =
|server_config: Arc<WsServerConfig>, restrictions: Arc<RestrictionsRules>, client_addr: SocketAddr| {
move |req: Request<Incoming>| {
ws_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req)
.map::<anyhow::Result<_>, _>(Ok)
}
};
let mk_http_upgrade_fn = |server_config: Arc<WsServerConfig>, client_addr: SocketAddr| {
move |req: Request<Incoming>| {
http_server_upgrade(server_config.clone(), client_addr, req).map::<anyhow::Result<_>, _>(Ok)
}
};
let mk_http_upgrade_fn =
|server_config: Arc<WsServerConfig>, restrictions: Arc<RestrictionsRules>, client_addr: SocketAddr| {
move |req: Request<Incoming>| {
http_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req)
.map::<anyhow::Result<_>, _>(Ok)
}
};
let mk_auto_upgrade_fn = |server_config: Arc<WsServerConfig>, client_addr: SocketAddr| {
let mk_auto_upgrade_fn = |server_config: Arc<WsServerConfig>,
restrictions: Arc<RestrictionsRules>,
client_addr: SocketAddr| {
move |req: Request<Incoming>| {
let server_config = server_config.clone();
let restrictions = restrictions.clone();
async move {
if fastwebsockets::upgrade::is_upgrade_request(&req) {
ws_server_upgrade(server_config.clone(), client_addr, req)
ws_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req)
.map(|response| Ok::<_, anyhow::Error>(response.map(Either::Left)))
.await
} else if req.version() == Version::HTTP_2 {
http_server_upgrade(server_config.clone(), client_addr, req)
http_server_upgrade(server_config.clone(), restrictions.clone(), client_addr, req)
.map::<anyhow::Result<_>, _>(Ok)
.await
} else {
@ -716,9 +726,21 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
};
// Bind server and run forever to serve incoming connections.
let mut restrictions = RestrictionsRulesReloader::new(restrictions, server_config.restriction_config.clone())?;
let listener = TcpListener::bind(&server_config.bind).await?;
loop {
let (stream, peer_addr) = match listener.accept().await {
let cnx = select! {
biased;
_ = restrictions.wait_for_reload() => {
restrictions.reload_restrictions_config();
continue;
},
cnx = listener.accept() => { cnx }
};
let (stream, peer_addr) = match cnx {
Ok(ret) => ret,
Err(err) => {
warn!("Error while accepting connection {:?}", err);
@ -738,6 +760,7 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
info!("Accepting connection");
let server_config = server_config.clone();
let restrictions = restrictions.restrictions_rules().clone();
// Check if we need to enable TLS or not
match tls_context.as_mut() {
@ -762,7 +785,7 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
conn_builder.keep_alive_interval(ping);
}
let http_upgrade_fn = mk_http_upgrade_fn(server_config, peer_addr);
let http_upgrade_fn = mk_http_upgrade_fn(server_config, restrictions.clone(), peer_addr);
let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn));
if let Err(e) = con_fut.await {
error!("Error while upgrading cnx to http: {:?}", e);
@ -770,7 +793,8 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
}
// websocket
_ => {
let websocket_upgrade_fn = mk_websocket_upgrade_fn(server_config, peer_addr);
let websocket_upgrade_fn =
mk_websocket_upgrade_fn(server_config, restrictions.clone(), peer_addr);
let conn_fut = http1::Builder::new()
.serve_connection(tls_stream, service_fn(websocket_upgrade_fn))
.with_upgrades();
@ -795,7 +819,7 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
conn_fut.http2().keep_alive_interval(ping);
}
let websocket_upgrade_fn = mk_auto_upgrade_fn(server_config, peer_addr);
let websocket_upgrade_fn = mk_auto_upgrade_fn(server_config, restrictions.clone(), peer_addr);
let upgradable = conn_fut.serve_connection_with_upgrades(stream, service_fn(websocket_upgrade_fn));
if let Err(e) = upgradable.await {