diff --git a/src/main.rs b/src/main.rs index 17e39c1..f5ab9bc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -137,6 +137,13 @@ struct Server { #[arg(long, value_name = "DEST:PORT", verbatim_doc_comment)] restrict_to: Option>, + /// Server will only accept connection from if this specific path prefix is used during websocket upgrade. + /// Useful if you specify in the client a custom path prefix and you want the server to only allow this one. + /// The path prefix act as a secret to authenticate clients + /// Disabled by default. Accept all path prefix + #[arg(long, verbatim_doc_comment)] + restrict_http_upgrade_path_prefix: Option, + /// [Optional] Use custom certificate (.crt) instead of the default embedded self signed certificate. #[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)] tls_certificate: Option, @@ -414,6 +421,7 @@ pub struct WsServerConfig { pub socket_so_mark: Option, pub bind: SocketAddr, pub restrict_to: Option>, + pub restrict_http_upgrade_path_prefix: Option, pub websocket_ping_frequency: Option, pub timeout_connect: Duration, pub websocket_mask_frame: bool, @@ -634,6 +642,7 @@ async fn main() { socket_so_mark: args.socket_so_mark, bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0], restrict_to: args.restrict_to, + restrict_http_upgrade_path_prefix: args.restrict_http_upgrade_path_prefix, websocket_ping_frequency: args.websocket_ping_frequency_sec, timeout_connect: Duration::from_secs(10), websocket_mask_frame: args.websocket_mask_frame, diff --git a/src/transport.rs b/src/transport.rs index 725eee0..ce1c9fc 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -1,3 +1,4 @@ +use std::cmp::min; use std::collections::HashSet; use std::future::Future; use std::ops::{Deref, Not}; @@ -151,7 +152,7 @@ pub async fn connect( pub async fn connect_to_server( request_id: Uuid, - server_config: &WsClientConfig, + client_cfg: &WsClientConfig, remote_cfg: &LocalToRemote, duplex_stream: (R, W), ) -> anyhow::Result<()> @@ -159,15 +160,15 @@ where R: AsyncRead + Send + 'static, W: AsyncWrite + Send + 'static, { - let mut ws = connect(request_id, server_config, remote_cfg).await?; - ws.set_auto_apply_mask(server_config.websocket_mask_frame); + let mut ws = connect(request_id, client_cfg, remote_cfg).await?; + ws.set_auto_apply_mask(client_cfg.websocket_mask_frame); let (ws_rx, ws_tx) = ws.split(tokio::io::split); let (local_rx, local_tx) = duplex_stream; let (close_tx, close_rx) = oneshot::channel::<()>(); // Forward local tx to websocket tx - let ping_frequency = server_config.websocket_ping_frequency; + let ping_frequency = client_cfg.websocket_ping_frequency; tokio::spawn( propagate_read(local_rx, ws_tx, close_tx, ping_frequency).instrument(Span::current()), ); @@ -266,10 +267,29 @@ async fn server_upgrade( ); return Ok(http::Response::builder() .status(StatusCode::BAD_REQUEST) - .body(Body::from("Invalid upgrade request".to_string())) + .body(Body::from("Invalid upgrade request")) .unwrap_or_default()); } + if let Some(path_prefix) = &server_config.restrict_http_upgrade_path_prefix { + let path = req.uri().path(); + let min_len = min(path.len(), 1); + let max_len = min(path.len(), path_prefix.len() + 1); + if &path[0..min_len] != "/" + || &path[min_len..max_len] != path_prefix.as_str() + || !path[max_len..].starts_with('/') + { + warn!( + "Rejecting connection with bad path prefix in upgrade request: {}", + req.uri() + ); + return Ok(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from("Invalid upgrade request")) + .unwrap_or_default()); + } + } + let (protocol, dest, port, local_rx, local_tx) = match from_query(&server_config, req.uri().query().unwrap_or_default()).await { Ok(ret) => ret,