Add server restriction for http path prefix

Former-commit-id: cd23bc1dc834eb5bcae7c1903186b18920d0c77c [formerly c4cf29aaf8715fe6110a8047a8481c00d8b89406] [formerly b57721363e7e6848d96a17332c7737f734fff58f [formerly a32bcb0bf854ea21500fbd7e6556c77c3bb12d86]]
Former-commit-id: 236a241138b6f8f014c60d39ab61a3a59e4311aa [formerly 6a6fe77a9e88309c50af8d79cf659f79122e8dd6]
Former-commit-id: a1763d2ee5f0be5ab6c03548b07b0660f85d54c5
Former-commit-id: 177d318424a2d2cfa6c81e08a9960bd065010280
Former-commit-id: 4cc14b5a2b13cc9872aeedcd9dd89fbecd7b8fa3
Former-commit-id: 4657ba8c01362551f269a78c52372dbf9af164da [formerly 87b44b99887048179f349e6142a4d17d9127c872]
Former-commit-id: cf32d458d507adeecf98911cfef7211f7528f614
This commit is contained in:
Σrebe - Romain GERARD 2023-10-21 14:14:22 +02:00
parent 4e524fe550
commit e11a04eda8
2 changed files with 34 additions and 5 deletions

View file

@ -137,6 +137,13 @@ struct Server {
#[arg(long, value_name = "DEST:PORT", verbatim_doc_comment)]
restrict_to: Option<Vec<String>>,
/// Server will only accept connection from if this specific path prefix is used during websocket upgrade.
/// Useful if you specify in the client a custom path prefix and you want the server to only allow this one.
/// The path prefix act as a secret to authenticate clients
/// Disabled by default. Accept all path prefix
#[arg(long, verbatim_doc_comment)]
restrict_http_upgrade_path_prefix: Option<String>,
/// [Optional] Use custom certificate (.crt) instead of the default embedded self signed certificate.
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
tls_certificate: Option<PathBuf>,
@ -414,6 +421,7 @@ pub struct WsServerConfig {
pub socket_so_mark: Option<i32>,
pub bind: SocketAddr,
pub restrict_to: Option<Vec<String>>,
pub restrict_http_upgrade_path_prefix: Option<String>,
pub websocket_ping_frequency: Option<Duration>,
pub timeout_connect: Duration,
pub websocket_mask_frame: bool,
@ -634,6 +642,7 @@ async fn main() {
socket_so_mark: args.socket_so_mark,
bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0],
restrict_to: args.restrict_to,
restrict_http_upgrade_path_prefix: args.restrict_http_upgrade_path_prefix,
websocket_ping_frequency: args.websocket_ping_frequency_sec,
timeout_connect: Duration::from_secs(10),
websocket_mask_frame: args.websocket_mask_frame,

View file

@ -1,3 +1,4 @@
use std::cmp::min;
use std::collections::HashSet;
use std::future::Future;
use std::ops::{Deref, Not};
@ -151,7 +152,7 @@ pub async fn connect(
pub async fn connect_to_server<R, W>(
request_id: Uuid,
server_config: &WsClientConfig,
client_cfg: &WsClientConfig,
remote_cfg: &LocalToRemote,
duplex_stream: (R, W),
) -> anyhow::Result<()>
@ -159,15 +160,15 @@ where
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
let mut ws = connect(request_id, server_config, remote_cfg).await?;
ws.set_auto_apply_mask(server_config.websocket_mask_frame);
let mut ws = connect(request_id, client_cfg, remote_cfg).await?;
ws.set_auto_apply_mask(client_cfg.websocket_mask_frame);
let (ws_rx, ws_tx) = ws.split(tokio::io::split);
let (local_rx, local_tx) = duplex_stream;
let (close_tx, close_rx) = oneshot::channel::<()>();
// Forward local tx to websocket tx
let ping_frequency = server_config.websocket_ping_frequency;
let ping_frequency = client_cfg.websocket_ping_frequency;
tokio::spawn(
propagate_read(local_rx, ws_tx, close_tx, ping_frequency).instrument(Span::current()),
);
@ -266,10 +267,29 @@ async fn server_upgrade(
);
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid upgrade request".to_string()))
.body(Body::from("Invalid upgrade request"))
.unwrap_or_default());
}
if let Some(path_prefix) = &server_config.restrict_http_upgrade_path_prefix {
let path = req.uri().path();
let min_len = min(path.len(), 1);
let max_len = min(path.len(), path_prefix.len() + 1);
if &path[0..min_len] != "/"
|| &path[min_len..max_len] != path_prefix.as_str()
|| !path[max_len..].starts_with('/')
{
warn!(
"Rejecting connection with bad path prefix in upgrade request: {}",
req.uri()
);
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid upgrade request"))
.unwrap_or_default());
}
}
let (protocol, dest, port, local_rx, local_tx) =
match from_query(&server_config, req.uri().query().unwrap_or_default()).await {
Ok(ret) => ret,