diff --git a/src/main.rs b/src/main.rs index 537001d..7118e87 100644 --- a/src/main.rs +++ b/src/main.rs @@ -161,9 +161,9 @@ struct Server { /// Server will only accept connection from if this specific path prefix is used during websocket upgrade. /// Useful if you specify in the client a custom path prefix and you want the server to only allow this one. /// The path prefix act as a secret to authenticate clients - /// Disabled by default. Accept all path prefix + /// Disabled by default. Accept all path prefix. Can be specified multiple time #[arg(long, verbatim_doc_comment)] - restrict_http_upgrade_path_prefix: Option, + restrict_http_upgrade_path_prefix: Option>, /// [Optional] Use custom certificate (.crt) instead of the default embedded self signed certificate. #[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)] @@ -441,7 +441,7 @@ pub struct WsServerConfig { pub socket_so_mark: Option, pub bind: SocketAddr, pub restrict_to: Option>, - pub restrict_http_upgrade_path_prefix: Option, + pub restrict_http_upgrade_path_prefix: Option>, pub websocket_ping_frequency: Option, pub timeout_connect: Duration, pub websocket_mask_frame: bool, diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index ec7bec1..cae18b4 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -202,12 +202,15 @@ async fn server_upgrade( .unwrap_or_default()); } - if let Some(path_prefix) = &server_config.restrict_http_upgrade_path_prefix { + if let Some(paths_prefix) = &server_config.restrict_http_upgrade_path_prefix { let path = req.uri().path(); let min_len = min(path.len(), 1); - let max_len = min(path.len(), path_prefix.len() + 1); + let mut max_len = 0; if &path[0..min_len] != "/" - || &path[min_len..max_len] != path_prefix.as_str() + || !paths_prefix.iter().any(|p| { + max_len = min(path.len(), p.len() + 1); + p == &path[min_len..max_len] + }) || !path[max_len..].starts_with('/') { warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri()); diff --git a/src/udp.rs b/src/udp.rs index 644e264..bdfa4bb 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -430,7 +430,9 @@ mod tests { #[tokio::test] async fn test_udp_server() { let server_addr: SocketAddr = "[::1]:1234".parse().unwrap(); - let server = run_server(server_addr, None, |_| Ok(()), |l| Ok(l.clone())).await.unwrap(); + let server = run_server(server_addr, None, |_| Ok(()), |l| Ok(l.clone())) + .await + .unwrap(); pin_mut!(server); // Should timeout @@ -476,7 +478,11 @@ mod tests { #[tokio::test] async fn test_multiple_client() { let server_addr: SocketAddr = "[::1]:1235".parse().unwrap(); - let mut server = Box::pin(run_server(server_addr, None, |_| Ok(()), |l| Ok(l.clone())).await.unwrap()); + let mut server = Box::pin( + run_server(server_addr, None, |_| Ok(()), |l| Ok(l.clone())) + .await + .unwrap(), + ); // Send some data to the server let client = UdpSocket::bind("[::1]:0").await.unwrap(); @@ -542,7 +548,9 @@ mod tests { async fn test_udp_should_timeout() { let server_addr: SocketAddr = "[::1]:1237".parse().unwrap(); let socket_timeout = Duration::from_secs(1); - let server = run_server(server_addr, Some(socket_timeout), |_| Ok(()), |l| Ok(l.clone())).await.unwrap(); + let server = run_server(server_addr, Some(socket_timeout), |_| Ok(()), |l| Ok(l.clone())) + .await + .unwrap(); pin_mut!(server); // Send some data to the server