From e11a04eda8c9112511f72cc8ece57d3c619fbaa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Sat, 21 Oct 2023 14:14:22 +0200 Subject: [PATCH] Add server restriction for http path prefix Former-commit-id: cd23bc1dc834eb5bcae7c1903186b18920d0c77c [formerly c4cf29aaf8715fe6110a8047a8481c00d8b89406] [formerly b57721363e7e6848d96a17332c7737f734fff58f [formerly a32bcb0bf854ea21500fbd7e6556c77c3bb12d86]] Former-commit-id: 236a241138b6f8f014c60d39ab61a3a59e4311aa [formerly 6a6fe77a9e88309c50af8d79cf659f79122e8dd6] Former-commit-id: a1763d2ee5f0be5ab6c03548b07b0660f85d54c5 Former-commit-id: 177d318424a2d2cfa6c81e08a9960bd065010280 Former-commit-id: 4cc14b5a2b13cc9872aeedcd9dd89fbecd7b8fa3 Former-commit-id: 4657ba8c01362551f269a78c52372dbf9af164da [formerly 87b44b99887048179f349e6142a4d17d9127c872] Former-commit-id: cf32d458d507adeecf98911cfef7211f7528f614 --- src/main.rs | 9 +++++++++ src/transport.rs | 30 +++++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/main.rs b/src/main.rs index 17e39c1..f5ab9bc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -137,6 +137,13 @@ struct Server { #[arg(long, value_name = "DEST:PORT", verbatim_doc_comment)] restrict_to: Option>, + /// Server will only accept connection from if this specific path prefix is used during websocket upgrade. + /// Useful if you specify in the client a custom path prefix and you want the server to only allow this one. + /// The path prefix act as a secret to authenticate clients + /// Disabled by default. Accept all path prefix + #[arg(long, verbatim_doc_comment)] + restrict_http_upgrade_path_prefix: Option, + /// [Optional] Use custom certificate (.crt) instead of the default embedded self signed certificate. #[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)] tls_certificate: Option, @@ -414,6 +421,7 @@ pub struct WsServerConfig { pub socket_so_mark: Option, pub bind: SocketAddr, pub restrict_to: Option>, + pub restrict_http_upgrade_path_prefix: Option, pub websocket_ping_frequency: Option, pub timeout_connect: Duration, pub websocket_mask_frame: bool, @@ -634,6 +642,7 @@ async fn main() { socket_so_mark: args.socket_so_mark, bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0], restrict_to: args.restrict_to, + restrict_http_upgrade_path_prefix: args.restrict_http_upgrade_path_prefix, websocket_ping_frequency: args.websocket_ping_frequency_sec, timeout_connect: Duration::from_secs(10), websocket_mask_frame: args.websocket_mask_frame, diff --git a/src/transport.rs b/src/transport.rs index 725eee0..ce1c9fc 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -1,3 +1,4 @@ +use std::cmp::min; use std::collections::HashSet; use std::future::Future; use std::ops::{Deref, Not}; @@ -151,7 +152,7 @@ pub async fn connect( pub async fn connect_to_server( request_id: Uuid, - server_config: &WsClientConfig, + client_cfg: &WsClientConfig, remote_cfg: &LocalToRemote, duplex_stream: (R, W), ) -> anyhow::Result<()> @@ -159,15 +160,15 @@ where R: AsyncRead + Send + 'static, W: AsyncWrite + Send + 'static, { - let mut ws = connect(request_id, server_config, remote_cfg).await?; - ws.set_auto_apply_mask(server_config.websocket_mask_frame); + let mut ws = connect(request_id, client_cfg, remote_cfg).await?; + ws.set_auto_apply_mask(client_cfg.websocket_mask_frame); let (ws_rx, ws_tx) = ws.split(tokio::io::split); let (local_rx, local_tx) = duplex_stream; let (close_tx, close_rx) = oneshot::channel::<()>(); // Forward local tx to websocket tx - let ping_frequency = server_config.websocket_ping_frequency; + let ping_frequency = client_cfg.websocket_ping_frequency; tokio::spawn( propagate_read(local_rx, ws_tx, close_tx, ping_frequency).instrument(Span::current()), ); @@ -266,10 +267,29 @@ async fn server_upgrade( ); return Ok(http::Response::builder() .status(StatusCode::BAD_REQUEST) - .body(Body::from("Invalid upgrade request".to_string())) + .body(Body::from("Invalid upgrade request")) .unwrap_or_default()); } + if let Some(path_prefix) = &server_config.restrict_http_upgrade_path_prefix { + let path = req.uri().path(); + let min_len = min(path.len(), 1); + let max_len = min(path.len(), path_prefix.len() + 1); + if &path[0..min_len] != "/" + || &path[min_len..max_len] != path_prefix.as_str() + || !path[max_len..].starts_with('/') + { + warn!( + "Rejecting connection with bad path prefix in upgrade request: {}", + req.uri() + ); + return Ok(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from("Invalid upgrade request")) + .unwrap_or_default()); + } + } + let (protocol, dest, port, local_rx, local_tx) = match from_query(&server_config, req.uri().query().unwrap_or_default()).await { Ok(ret) => ret,