From f19efa37f19fb3365dac56e11a0632c6d3d45b5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Mon, 25 Dec 2023 18:06:44 +0100 Subject: [PATCH] cleanup code --- src/tunnel/server.rs | 191 +++++++++++++++++++++++++++++-------------- 1 file changed, 131 insertions(+), 60 deletions(-) diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 8217a18..32746c7 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -1,7 +1,7 @@ use ahash::{HashMap, HashMapExt}; use anyhow::anyhow; use base64::Engine; -use futures_util::{pin_mut, Stream, StreamExt}; +use futures_util::{pin_mut, FutureExt, Stream, StreamExt}; use std::cmp::min; use std::fmt::Debug; use std::future::Future; @@ -30,9 +30,9 @@ use tokio::sync::{mpsc, oneshot}; use tracing::{error, info, span, warn, Instrument, Level, Span}; use url::Host; -async fn from_query( +async fn run_tunnel( server_config: &WsServerConfig, - query: &str, + jwt: TokenData, ) -> anyhow::Result<( LocalProtocol, Host, @@ -40,30 +40,6 @@ async fn from_query( Pin>, Pin>, )> { - let jwt: TokenData = match query.split_once('=') { - Some(("bearer", jwt)) => { - let (validation, decode_key) = JWT_DECODE.deref(); - match jsonwebtoken::decode(jwt, decode_key, validation) { - Ok(jwt) => jwt, - err => { - error!("error while decoding jwt for tunnel info {:?}", err); - return Err(anyhow::anyhow!("Invalid upgrade request")); - } - } - } - _err => return Err(anyhow::anyhow!("Invalid upgrade request")), - }; - - Span::current().record("id", &jwt.claims.id); - Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); - if let Some(allowed_dests) = &server_config.restrict_to { - let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp); - if allowed_dests.iter().any(|dest| dest == &requested_dest).not() { - warn!("Rejecting connection with not allowed destination: {}", requested_dest); - return Err(anyhow::anyhow!("Invalid upgrade request")); - } - } - match jwt.claims.p { LocalProtocol::Udp { timeout, .. } => { let host = Host::parse(&jwt.claims.r)?; @@ -199,24 +175,29 @@ where Ok(cnx) } -async fn server_upgrade( - server_config: Arc, - mut req: Request, -) -> Result, anyhow::Error> { - if let Some(x) = req.headers().get("X-Forwarded-For") { - info!("Request X-Forwarded-For: {:?}", x); - Span::current().record("forwarded_for", x.to_str().unwrap_or_default()); - } +#[inline] +fn extract_x_forwarded_for(req: &Request) -> Result, Response> { + let Some(x_forward_for) = req.headers().get("X-Forwarded-For") else { + return Ok(None); + }; + Ok(Some(x_forward_for.to_str().unwrap_or_default())) +} + +#[inline] +fn validate_url( + req: &Request, + path_restriction_prefix: &Option>, +) -> Result<(), Response> { if !req.uri().path().ends_with("/events") { warn!("Rejecting connection with bad upgrade request: {}", req.uri()); - return Ok(http::Response::builder() + return Err(http::Response::builder() .status(StatusCode::BAD_REQUEST) .body("Invalid upgrade request".into()) .unwrap()); } - if let Some(paths_prefix) = &server_config.restrict_http_upgrade_path_prefix { + if let Some(paths_prefix) = &path_restriction_prefix { let path = req.uri().path(); let min_len = min(path.len(), 1); let mut max_len = 0; @@ -228,34 +209,121 @@ async fn server_upgrade( || !path[max_len..].starts_with('/') { warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri()); - return Ok(http::Response::builder() + return Err(http::Response::builder() .status(StatusCode::BAD_REQUEST) .body("Invalid upgrade request".to_string()) .unwrap()); } } - let (protocol, dest, port, local_rx, local_tx) = - match from_query(&server_config, req.uri().query().unwrap_or_default()).await { - Ok(ret) => ret, - Err(err) => { - warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); - return Ok(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap()); - } - }; + Ok(()) +} +#[inline] +fn extract_tunnel_info(req: &Request) -> Result, Response> { + let jwt: TokenData = match req.uri().query().unwrap_or_default().split_once('=') { + Some(("bearer", jwt)) => { + let (validation, decode_key) = JWT_DECODE.deref(); + match jsonwebtoken::decode(jwt, decode_key, validation) { + Ok(jwt) => jwt, + err => { + error!("error while decoding jwt for tunnel info {:?}", err); + return Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap()); + } + } + } + err => { + error!("Missing jwt tunnel config from request {:?}", err); + return Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap()); + } + }; + + Ok(jwt) +} + +#[inline] +fn validate_destination( + _req: &Request, + jwt: &TokenData, + destination_restriction: &Option>, +) -> Result<(), Response> { + let Some(allowed_dests) = &destination_restriction else { + return Ok(()); + }; + + let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp); + if allowed_dests.iter().any(|dest| dest == &requested_dest).not() { + warn!("Rejecting connection with not allowed destination: {}", requested_dest); + return Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap()); + } + + Ok(()) +} + +async fn server_upgrade(server_config: Arc, mut req: Request) -> Response { + if !fastwebsockets::upgrade::is_upgrade_request(&req) { + warn!("Rejecting connection with bad upgrade request: {}", req.uri()); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap(); + } + + match extract_x_forwarded_for(&req) { + Ok(Some(x_forward_for)) => { + info!("Request X-Forwarded-For: {:?}", x_forward_for); + Span::current().record("forwarded_for", x_forward_for); + } + Ok(_) => {} + Err(err) => return err, + } + + if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) { + return err; + } + + let jwt = match extract_tunnel_info(&req) { + Ok(jwt) => jwt, + Err(err) => return err, + }; + + Span::current().record("id", &jwt.claims.id); + Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); + + if let Err(err) = validate_destination(&req, &jwt, &server_config.restrict_to) { + return err; + } + + let tunnel = match run_tunnel(&server_config, jwt).await { + Ok(ret) => ret, + Err(err) => { + warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap(); + } + }; + + let (protocol, dest, port, local_rx, local_tx) = tunnel; info!("connected to {:?} {:?} {:?}", protocol, dest, port); let (mut response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) { Ok(ret) => ret, Err(err) => { warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); - return Ok(http::Response::builder() + return http::Response::builder() .status(StatusCode::BAD_REQUEST) .body(format!("Invalid upgrade request: {:?}", err)) - .unwrap()); + .unwrap(); } }; @@ -279,23 +347,26 @@ async fn server_upgrade( ); if protocol == LocalProtocol::ReverseSocks5 { - response.headers_mut().insert( - COOKIE, - HeaderValue::from_str( - &base64::engine::general_purpose::STANDARD.encode(format!("fake://{}:{}", dest, port)), - )?, - ); + let Ok(header_val) = HeaderValue::from_str( + &base64::engine::general_purpose::STANDARD.encode(format!("fake://{}:{}", dest, port)), + ) else { + error!("Bad headervalue for reverse socks5: {} {}", dest, port); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap(); + }; + response.headers_mut().insert(COOKIE, header_val); } - let response = Response::from_parts(response.into_parts().0, "".to_string()); - Ok(response) + Response::from_parts(response.into_parts().0, "".to_string()) } pub async fn run_server(server_config: Arc) -> anyhow::Result<()> { info!("Starting wstunnel server listening on {}", server_config.bind); let config = server_config.clone(); - let upgrade_fn = move |req: Request| server_upgrade(config.clone(), req); + let upgrade_fn = move |req: Request| server_upgrade(config.clone(), req).map::, _>(Ok); let listener = TcpListener::bind(&server_config.bind).await?; let tls_acceptor = if let Some(tls) = &server_config.tls {