From 0dded01b7f5ba923da20c9a953e10bb1f529caf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Tue, 30 Jul 2024 22:24:31 +0200 Subject: [PATCH] refacto: simplify server request handling --- src/tunnel/server/handler_http2.rs | 121 +++---------------------- src/tunnel/server/handler_websocket.rs | 33 +++---- src/tunnel/server/server.rs | 91 ++++++++++++++++++- src/tunnel/server/utils.rs | 57 +++++------- 4 files changed, 135 insertions(+), 167 deletions(-) diff --git a/src/tunnel/server/handler_http2.rs b/src/tunnel/server/handler_http2.rs index 3320df0..2a7ae12 100644 --- a/src/tunnel/server/handler_http2.rs +++ b/src/tunnel/server/handler_http2.rs @@ -1,25 +1,20 @@ use crate::restrictions::types::RestrictionsRules; -use crate::tunnel::server::utils::{ - extract_path_prefix, extract_tunnel_info, extract_x_forwarded_for, inject_cookie, validate_tunnel, -}; +use crate::tunnel::server::utils::{bad_request, inject_cookie}; use crate::tunnel::server::WsServer; +use crate::tunnel::transport; use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; -use crate::tunnel::{transport, RemoteAddr}; -use crate::LocalProtocol; use bytes::Bytes; use futures_util::StreamExt; use http_body_util::combinators::BoxBody; use http_body_util::{BodyStream, Either, StreamBody}; use hyper::body::{Frame, Incoming}; use hyper::header::CONTENT_TYPE; -use hyper::{http, Request, Response, StatusCode}; +use hyper::{Request, Response, StatusCode}; use std::net::SocketAddr; -use std::pin::Pin; use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::{mpsc, oneshot}; use tokio_stream::wrappers::ReceiverStream; -use tracing::{info, warn, Instrument, Span}; +use tracing::{Instrument, Span}; pub(super) async fn http_server_upgrade( server: WsServer, @@ -28,11 +23,13 @@ pub(super) async fn http_server_upgrade( client_addr: SocketAddr, mut req: Request, ) -> Response>> { - let (remote_addr, local_rx, local_tx, need_cookie) = - match exec_tunnel_request(server, restrictions, restrict_path_prefix, client_addr, &req).await { - Ok(ret) => ret, - Err(err) => return err, - }; + let (remote_addr, local_rx, local_tx, need_cookie) = match server + .handle_tunnel_request(restrictions, restrict_path_prefix, client_addr, &req) + .await + { + Ok(ret) => ret, + Err(err) => return err, + }; let req_content_type = req.headers_mut().remove(CONTENT_TYPE); let ws_rx = BodyStream::new(req.into_body()); @@ -60,10 +57,8 @@ pub(super) async fn http_server_upgrade( .instrument(Span::current()), ); - if need_cookie { - if let Err(response) = inject_cookie(&mut response, &remote_addr, Either::Left) { - return response; - } + if need_cookie && inject_cookie(&mut response, &remote_addr).is_err() { + return bad_request(); } if let Some(content_type) = req_content_type { @@ -72,93 +67,3 @@ pub(super) async fn http_server_upgrade( response } - -pub(super) async fn exec_tunnel_request( - server: WsServer, - restrictions: Arc, - restrict_path_prefix: Option, - mut client_addr: SocketAddr, - req: &Request, -) -> Result< - ( - RemoteAddr, - Pin>, - Pin>, - bool, - ), - Response>>, -> { - match extract_x_forwarded_for(req) { - Ok(Some((x_forward_for, x_forward_for_str))) => { - info!("Request X-Forwarded-For: {:?}", x_forward_for); - Span::current().record("forwarded_for", x_forward_for_str); - client_addr.set_ip(x_forward_for); - } - Ok(_) => {} - Err(err) => return Err(err.map(Either::Left)), - }; - - let path_prefix = match extract_path_prefix(req) { - Ok(p) => p, - Err(err) => return Err(err.map(Either::Left)), - }; - - if let Some(restrict_path) = restrict_path_prefix { - if path_prefix != restrict_path { - warn!( - "Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)", - path_prefix, restrict_path - ); - return Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Either::Left("Invalid upgrade request".to_string())) - .unwrap()); - } - } - - let jwt = match extract_tunnel_info(req) { - Ok(jwt) => jwt, - Err(err) => return Err(err.map(Either::Left)), - }; - - Span::current().record("id", &jwt.claims.id); - Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); - let remote = match RemoteAddr::try_from(jwt.claims) { - Ok(remote) => remote, - Err(err) => { - warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri()); - return Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Either::Left("Invalid upgrade request".to_string())) - .unwrap()); - } - }; - - let restriction = match validate_tunnel(&remote, path_prefix, &restrictions) { - Ok(matched_restriction) => { - info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); - matched_restriction - } - Err(err) => return Err(err.map(Either::Left)), - }; - - let req_protocol = remote.protocol.clone(); - let inject_cookie = matches!( - req_protocol, - LocalProtocol::ReverseSocks5 { .. } | LocalProtocol::ReverseHttpProxy { .. } - ); - let tunnel = match server.run_tunnel(restriction, remote, client_addr).await { - Ok(ret) => ret, - Err(err) => { - warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); - return Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Either::Left("Invalid upgrade request".to_string())) - .unwrap()); - } - }; - - let (remote_addr, local_rx, local_tx) = tunnel; - info!("connected to {:?} {}:{}", req_protocol, remote_addr.host, remote_addr.port); - Ok((remote_addr, local_rx, local_tx, inject_cookie)) -} diff --git a/src/tunnel/server/handler_websocket.rs b/src/tunnel/server/handler_websocket.rs index bb18633..0c40ffc 100644 --- a/src/tunnel/server/handler_websocket.rs +++ b/src/tunnel/server/handler_websocket.rs @@ -1,6 +1,5 @@ use crate::restrictions::types::RestrictionsRules; -use crate::tunnel::server::handler_http2::exec_tunnel_request; -use crate::tunnel::server::utils::inject_cookie; +use crate::tunnel::server::utils::{bad_request, inject_cookie}; use crate::tunnel::server::WsServer; use crate::tunnel::transport; use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; @@ -9,7 +8,7 @@ use http_body_util::combinators::BoxBody; use http_body_util::Either; use hyper::body::Incoming; use hyper::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL}; -use hyper::{http, Request, Response, StatusCode}; +use hyper::{Request, Response}; use std::net::SocketAddr; use std::sync::Arc; use tokio::sync::oneshot; @@ -24,27 +23,23 @@ pub(super) async fn ws_server_upgrade( ) -> Response>> { if !fastwebsockets::upgrade::is_upgrade_request(&req) { warn!("Rejecting connection with bad upgrade request: {}", req.uri()); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Either::Left("Invalid upgrade request".to_string())) - .unwrap(); + return bad_request(); } let mask_frame = server.config.websocket_mask_frame; - let (remote_addr, local_rx, local_tx, need_cookie) = - match exec_tunnel_request(server, restrictions, restrict_path_prefix, client_addr, &req).await { - Ok(ret) => ret, - Err(err) => return err, - }; + let (remote_addr, local_rx, local_tx, need_cookie) = match server + .handle_tunnel_request(restrictions, restrict_path_prefix, client_addr, &req) + .await + { + Ok(ret) => ret, + Err(err) => return err, + }; let (response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) { Ok(ret) => ret, Err(err) => { warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Either::Left(format!("Invalid upgrade request: {:?}", err))) - .unwrap(); + return bad_request(); } }; @@ -73,10 +68,8 @@ pub(super) async fn ws_server_upgrade( ); let mut response = Response::from_parts(response.into_parts().0, Either::Right(BoxBody::default())); - if need_cookie { - if let Err(response) = inject_cookie(&mut response, &remote_addr, Either::Left) { - return response; - } + if need_cookie && inject_cookie(&mut response, &remote_addr).is_err() { + return bad_request(); } response diff --git a/src/tunnel/server/server.rs b/src/tunnel/server/server.rs index a9366b9..d4b8986 100644 --- a/src/tunnel/server/server.rs +++ b/src/tunnel/server/server.rs @@ -6,6 +6,8 @@ use std::fmt; use std::fmt::{Debug, Formatter}; use std::future::Future; +use bytes::Bytes; +use http_body_util::combinators::BoxBody; use std::net::SocketAddr; use std::ops::Deref; use std::path::PathBuf; @@ -18,7 +20,7 @@ use crate::{protocols, LocalProtocol}; use hyper::body::Incoming; use hyper::server::conn::{http1, http2}; use hyper::service::service_fn; -use hyper::{http, Request, StatusCode, Version}; +use hyper::{http, Request, Response, StatusCode, Version}; use hyper_util::rt::TokioExecutor; use once_cell::sync::Lazy; use parking_lot::Mutex; @@ -35,7 +37,9 @@ use crate::tunnel::listeners::{ }; use crate::tunnel::server::handler_http2::http_server_upgrade; use crate::tunnel::server::handler_websocket::ws_server_upgrade; -use crate::tunnel::server::utils::find_mapped_port; +use crate::tunnel::server::utils::{ + bad_request, extract_path_prefix, extract_tunnel_info, extract_x_forwarded_for, find_mapped_port, validate_tunnel, +}; use crate::tunnel::tls_reloader::TlsReloader; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpListener; @@ -79,7 +83,88 @@ impl WsServer { } } - pub(super) async fn run_tunnel( + pub(super) async fn handle_tunnel_request( + &self, + restrictions: Arc, + restrict_path_prefix: Option, + mut client_addr: SocketAddr, + req: &Request, + ) -> Result< + ( + RemoteAddr, + Pin>, + Pin>, + bool, + ), + Response>>, + > { + match extract_x_forwarded_for(req) { + Ok(Some((x_forward_for, x_forward_for_str))) => { + info!("Request X-Forwarded-For: {:?}", x_forward_for); + Span::current().record("forwarded_for", x_forward_for_str); + client_addr.set_ip(x_forward_for); + } + Ok(_) => {} + Err(_err) => return Err(bad_request()), + }; + + let path_prefix = match extract_path_prefix(req) { + Ok(p) => p, + Err(_err) => return Err(bad_request()), + }; + + if let Some(restrict_path) = restrict_path_prefix { + if path_prefix != restrict_path { + warn!( + "Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)", + path_prefix, restrict_path + ); + return Err(bad_request()); + } + } + + let jwt = match extract_tunnel_info(req) { + Ok(jwt) => jwt, + Err(_err) => return Err(bad_request()), + }; + + Span::current().record("id", &jwt.claims.id); + Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); + let remote = match RemoteAddr::try_from(jwt.claims) { + Ok(remote) => remote, + Err(err) => { + warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri()); + return Err(bad_request()); + } + }; + + let restriction = match validate_tunnel(&remote, path_prefix, &restrictions) { + Ok(matched_restriction) => { + info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); + matched_restriction + } + Err(_err) => return Err(bad_request()), + }; + + let req_protocol = remote.protocol.clone(); + let inject_cookie = matches!( + req_protocol, + LocalProtocol::ReverseSocks5 { .. } | LocalProtocol::ReverseHttpProxy { .. } + ); + let tunnel = match self.exec_tunnel(restriction, remote, client_addr).await { + Ok(ret) => ret, + Err(err) => { + warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); + return Err(bad_request()); + } + }; + + let (remote_addr, local_rx, local_tx) = tunnel; + info!("connected to {:?} {}:{}", req_protocol, remote_addr.host, remote_addr.port); + Ok((remote_addr, local_rx, local_tx, inject_cookie)) + } + + async fn exec_tunnel( &self, restriction: &RestrictionConfig, remote: RemoteAddr, diff --git a/src/tunnel/server/utils.rs b/src/tunnel/server/utils.rs index e3adadc..e0a2153 100644 --- a/src/tunnel/server/utils.rs +++ b/src/tunnel/server/utils.rs @@ -2,6 +2,9 @@ use crate::restrictions::types::{ AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol, }; use crate::tunnel::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; +use bytes::Bytes; +use http_body_util::combinators::BoxBody; +use http_body_util::Either; use hyper::body::{Body, Incoming}; use hyper::header::{HeaderValue, COOKIE, SEC_WEBSOCKET_PROTOCOL}; use hyper::{http, Request, Response, StatusCode}; @@ -13,6 +16,13 @@ use tracing::{error, info, warn}; use url::Host; use uuid::Uuid; +pub(super) fn bad_request() -> Response>> { + http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Either::Left("Invalid request".to_string())) + .unwrap() +} + /// Checks if the requested (remote) port has been mapped in the configuration to another port. /// If it is not mapped the original port number is returned. #[inline] @@ -37,7 +47,7 @@ pub(super) fn find_mapped_port(req_port: u16, restriction: &RestrictionConfig) - } #[inline] -pub(super) fn extract_x_forwarded_for(req: &Request) -> Result, Response> { +pub(super) fn extract_x_forwarded_for(req: &Request) -> Result, ()> { let Some(x_forward_for) = req.headers().get("X-Forwarded-For") else { return Ok(None); }; @@ -50,38 +60,29 @@ pub(super) fn extract_x_forwarded_for(req: &Request) -> Result) -> Result<&str, Response> { +pub(super) fn extract_path_prefix(req: &Request) -> Result<&str, ()> { let path = req.uri().path(); let min_len = min(path.len(), 1); if &path[0..min_len] != "/" { warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri()); - return Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap()); + return Err(()); } let Some((l, r)) = path[min_len..].split_once('/') else { warn!("Rejecting connection with bad upgrade request: {}", req.uri()); - return Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".into()) - .unwrap()); + return Err(()); }; if !r.ends_with("events") { warn!("Rejecting connection with bad upgrade request: {}", req.uri()); - return Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".into()) - .unwrap()); + return Err(()); } Ok(l) } #[inline] -pub(super) fn extract_tunnel_info(req: &Request) -> Result, Response> { +pub(super) fn extract_tunnel_info(req: &Request) -> Result, ()> { let jwt = req .headers() .get(SEC_WEBSOCKET_PROTOCOL) @@ -100,10 +101,7 @@ pub(super) fn extract_tunnel_info(req: &Request) -> Result( remote: &RemoteAddr, path_prefix: &str, restrictions: &'a RestrictionsRules, -) -> Result<&'a RestrictionConfig, Response> { +) -> Result<&'a RestrictionConfig, ()> { for restriction in &restrictions.restrictions { if !restriction.r#match.iter().all(|m| match m { MatchConfig::Any => true, @@ -208,26 +206,13 @@ pub(super) fn validate_tunnel<'a>( } warn!("Rejecting connection with not allowed destination: {:?}", remote); - Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap()) + Err(()) } -pub(super) fn inject_cookie( - response: &mut http::Response, - remote_addr: &RemoteAddr, - mk_body: impl FnOnce(String) -> B, -) -> Result<(), Response> -where - B: Body, -{ +pub(super) fn inject_cookie(response: &mut http::Response, remote_addr: &RemoteAddr) -> Result<(), ()> { let Ok(header_val) = HeaderValue::from_str(&tunnel_to_jwt_token(Uuid::from_u128(0), remote_addr)) else { error!("Bad header value for reverse socks5: {} {}", remote_addr.host, remote_addr.port); - return Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(mk_body("Invalid upgrade request".to_string())) - .unwrap()); + return Err(()); }; response.headers_mut().insert(COOKIE, header_val);