diff --git a/src/tunnel/server/handler_http2.rs b/src/tunnel/server/handler_http2.rs index 066df0c..3320df0 100644 --- a/src/tunnel/server/handler_http2.rs +++ b/src/tunnel/server/handler_http2.rs @@ -5,6 +5,7 @@ use crate::tunnel::server::utils::{ use crate::tunnel::server::WsServer; use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; use crate::tunnel::{transport, RemoteAddr}; +use crate::LocalProtocol; use bytes::Bytes; use futures_util::StreamExt; use http_body_util::combinators::BoxBody; @@ -13,7 +14,9 @@ use hyper::body::{Frame, Incoming}; use hyper::header::CONTENT_TYPE; use hyper::{http, Request, Response, StatusCode}; use std::net::SocketAddr; +use std::pin::Pin; use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::{mpsc, oneshot}; use tokio_stream::wrappers::ReceiverStream; use tracing::{info, warn, Instrument, Span}; @@ -22,77 +25,14 @@ pub(super) async fn http_server_upgrade( server: WsServer, restrictions: Arc, restrict_path_prefix: Option, - mut client_addr: SocketAddr, + client_addr: SocketAddr, mut req: Request, ) -> Response>> { - match extract_x_forwarded_for(&req) { - Ok(Some((x_forward_for, x_forward_for_str))) => { - info!("Request X-Forwarded-For: {:?}", x_forward_for); - Span::current().record("forwarded_for", x_forward_for_str); - client_addr.set_ip(x_forward_for); - } - Ok(_) => {} - Err(err) => return err.map(Either::Left), - }; - - let path_prefix = match extract_path_prefix(&req) { - Ok(p) => p, - Err(err) => return err.map(Either::Left), - }; - - if let Some(restrict_path) = restrict_path_prefix { - if path_prefix != restrict_path { - warn!( - "Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)", - path_prefix, restrict_path - ); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Either::Left("Invalid upgrade request".to_string())) - .unwrap(); - } - } - - let jwt = match extract_tunnel_info(&req) { - Ok(jwt) => jwt, - Err(err) => return err.map(Either::Left), - }; - - Span::current().record("id", &jwt.claims.id); - Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); - let remote = match RemoteAddr::try_from(jwt.claims) { - Ok(remote) => remote, - Err(err) => { - warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri()); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Either::Left("Invalid upgrade request".to_string())) - .unwrap(); - } - }; - - let restriction = match validate_tunnel(&remote, path_prefix, &restrictions) { - Ok(matched_restriction) => { - info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); - matched_restriction - } - Err(err) => return err.map(Either::Left), - }; - - let req_protocol = remote.protocol.clone(); - let tunnel = match server.run_tunnel(restriction, remote, client_addr).await { - Ok(ret) => ret, - Err(err) => { - warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Either::Left("Invalid upgrade request".to_string())) - .unwrap(); - } - }; - - let (remote_addr, local_rx, local_tx) = tunnel; - info!("connected to {:?} {}:{}", req_protocol, remote_addr.host, remote_addr.port); + let (remote_addr, local_rx, local_tx, need_cookie) = + match exec_tunnel_request(server, restrictions, restrict_path_prefix, client_addr, &req).await { + Ok(ret) => ret, + Err(err) => return err, + }; let req_content_type = req.headers_mut().remove(CONTENT_TYPE); let ws_rx = BodyStream::new(req.into_body()); @@ -120,8 +60,10 @@ pub(super) async fn http_server_upgrade( .instrument(Span::current()), ); - if let Err(response) = inject_cookie(&req_protocol, &mut response, &remote_addr, Either::Left) { - return response; + if need_cookie { + if let Err(response) = inject_cookie(&mut response, &remote_addr, Either::Left) { + return response; + } } if let Some(content_type) = req_content_type { @@ -130,3 +72,93 @@ pub(super) async fn http_server_upgrade( response } + +pub(super) async fn exec_tunnel_request( + server: WsServer, + restrictions: Arc, + restrict_path_prefix: Option, + mut client_addr: SocketAddr, + req: &Request, +) -> Result< + ( + RemoteAddr, + Pin>, + Pin>, + bool, + ), + Response>>, +> { + match extract_x_forwarded_for(req) { + Ok(Some((x_forward_for, x_forward_for_str))) => { + info!("Request X-Forwarded-For: {:?}", x_forward_for); + Span::current().record("forwarded_for", x_forward_for_str); + client_addr.set_ip(x_forward_for); + } + Ok(_) => {} + Err(err) => return Err(err.map(Either::Left)), + }; + + let path_prefix = match extract_path_prefix(req) { + Ok(p) => p, + Err(err) => return Err(err.map(Either::Left)), + }; + + if let Some(restrict_path) = restrict_path_prefix { + if path_prefix != restrict_path { + warn!( + "Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)", + path_prefix, restrict_path + ); + return Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Either::Left("Invalid upgrade request".to_string())) + .unwrap()); + } + } + + let jwt = match extract_tunnel_info(req) { + Ok(jwt) => jwt, + Err(err) => return Err(err.map(Either::Left)), + }; + + Span::current().record("id", &jwt.claims.id); + Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); + let remote = match RemoteAddr::try_from(jwt.claims) { + Ok(remote) => remote, + Err(err) => { + warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri()); + return Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Either::Left("Invalid upgrade request".to_string())) + .unwrap()); + } + }; + + let restriction = match validate_tunnel(&remote, path_prefix, &restrictions) { + Ok(matched_restriction) => { + info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); + matched_restriction + } + Err(err) => return Err(err.map(Either::Left)), + }; + + let req_protocol = remote.protocol.clone(); + let inject_cookie = matches!( + req_protocol, + LocalProtocol::ReverseSocks5 { .. } | LocalProtocol::ReverseHttpProxy { .. } + ); + let tunnel = match server.run_tunnel(restriction, remote, client_addr).await { + Ok(ret) => ret, + Err(err) => { + warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); + return Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Either::Left("Invalid upgrade request".to_string())) + .unwrap()); + } + }; + + let (remote_addr, local_rx, local_tx) = tunnel; + info!("connected to {:?} {}:{}", req_protocol, remote_addr.host, remote_addr.port); + Ok((remote_addr, local_rx, local_tx, inject_cookie)) +} diff --git a/src/tunnel/server/handler_websocket.rs b/src/tunnel/server/handler_websocket.rs index e78fe2f..bb18633 100644 --- a/src/tunnel/server/handler_websocket.rs +++ b/src/tunnel/server/handler_websocket.rs @@ -1,109 +1,49 @@ use crate::restrictions::types::RestrictionsRules; -use crate::tunnel::server::utils::{ - extract_path_prefix, extract_tunnel_info, extract_x_forwarded_for, inject_cookie, validate_tunnel, -}; +use crate::tunnel::server::handler_http2::exec_tunnel_request; +use crate::tunnel::server::utils::inject_cookie; use crate::tunnel::server::WsServer; +use crate::tunnel::transport; use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; -use crate::tunnel::{transport, RemoteAddr}; +use bytes::Bytes; +use http_body_util::combinators::BoxBody; +use http_body_util::Either; use hyper::body::Incoming; use hyper::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL}; use hyper::{http, Request, Response, StatusCode}; use std::net::SocketAddr; use std::sync::Arc; use tokio::sync::oneshot; -use tracing::{error, info, warn, Instrument, Span}; +use tracing::{error, warn, Instrument, Span}; pub(super) async fn ws_server_upgrade( server: WsServer, restrictions: Arc, restrict_path_prefix: Option, - mut client_addr: SocketAddr, + client_addr: SocketAddr, mut req: Request, -) -> Response { +) -> Response>> { if !fastwebsockets::upgrade::is_upgrade_request(&req) { warn!("Rejecting connection with bad upgrade request: {}", req.uri()); return http::Response::builder() .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) + .body(Either::Left("Invalid upgrade request".to_string())) .unwrap(); } - match extract_x_forwarded_for(&req) { - Ok(Some((x_forward_for, x_forward_for_str))) => { - info!("Request X-Forwarded-For: {:?}", x_forward_for); - Span::current().record("forwarded_for", x_forward_for_str); - client_addr.set_ip(x_forward_for); - } - Ok(_) => {} - Err(err) => return err, - }; + let mask_frame = server.config.websocket_mask_frame; + let (remote_addr, local_rx, local_tx, need_cookie) = + match exec_tunnel_request(server, restrictions, restrict_path_prefix, client_addr, &req).await { + Ok(ret) => ret, + Err(err) => return err, + }; - let path_prefix = match extract_path_prefix(&req) { - Ok(p) => p, - Err(err) => return err, - }; - - if let Some(restrict_path) = restrict_path_prefix { - if path_prefix != restrict_path { - warn!( - "Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)", - path_prefix, restrict_path - ); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap(); - } - } - - let jwt = match extract_tunnel_info(&req) { - Ok(jwt) => jwt, - Err(err) => return err, - }; - - Span::current().record("id", &jwt.claims.id); - Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); - - let remote = match RemoteAddr::try_from(jwt.claims) { - Ok(remote) => remote, - Err(err) => { - warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri()); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap(); - } - }; - - let restriction = match validate_tunnel(&remote, path_prefix, &restrictions) { - Ok(matched_restriction) => { - info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); - matched_restriction - } - Err(err) => return err, - }; - - let req_protocol = remote.protocol.clone(); - let tunnel = match server.run_tunnel(restriction, remote, client_addr).await { - Ok(ret) => ret, - Err(err) => { - warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap(); - } - }; - - let (remote_addr, local_rx, local_tx) = tunnel; - info!("connected to {:?} {}:{}", req_protocol, remote_addr.host, remote_addr.port); let (response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) { Ok(ret) => ret, Err(err) => { warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); return http::Response::builder() .status(StatusCode::BAD_REQUEST) - .body(format!("Invalid upgrade request: {:?}", err)) + .body(Either::Left(format!("Invalid upgrade request: {:?}", err))) .unwrap(); } }; @@ -118,7 +58,7 @@ pub(super) async fn ws_server_upgrade( } }; let (close_tx, close_rx) = oneshot::channel::<()>(); - ws_tx.set_auto_apply_mask(server.config.websocket_mask_frame); + ws_tx.set_auto_apply_mask(mask_frame); tokio::task::spawn( transport::io::propagate_remote_to_local(local_tx, WebsocketTunnelRead::new(ws_rx), close_rx) @@ -132,9 +72,11 @@ pub(super) async fn ws_server_upgrade( .instrument(Span::current()), ); - let mut response = Response::from_parts(response.into_parts().0, "".to_string()); - if let Err(response) = inject_cookie(&req_protocol, &mut response, &remote_addr, |s| s) { - return response; + let mut response = Response::from_parts(response.into_parts().0, Either::Right(BoxBody::default())); + if need_cookie { + if let Err(response) = inject_cookie(&mut response, &remote_addr, Either::Left) { + return response; + } } response diff --git a/src/tunnel/server/server.rs b/src/tunnel/server/server.rs index 212c0f0..a9366b9 100644 --- a/src/tunnel/server/server.rs +++ b/src/tunnel/server/server.rs @@ -258,7 +258,7 @@ impl WsServer { async move { if fastwebsockets::upgrade::is_upgrade_request(&req) { ws_server_upgrade(server.clone(), restrictions.clone(), restrict_path, client_addr, req) - .map(|response| Ok::<_, anyhow::Error>(response.map(Either::Left))) + .map::, _>(Ok) .await } else if req.version() == Version::HTTP_2 { http_server_upgrade( diff --git a/src/tunnel/server/utils.rs b/src/tunnel/server/utils.rs index e18bf22..e3adadc 100644 --- a/src/tunnel/server/utils.rs +++ b/src/tunnel/server/utils.rs @@ -2,7 +2,6 @@ use crate::restrictions::types::{ AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol, }; use crate::tunnel::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; -use crate::LocalProtocol; use hyper::body::{Body, Incoming}; use hyper::header::{HeaderValue, COOKIE, SEC_WEBSOCKET_PROTOCOL}; use hyper::{http, Request, Response, StatusCode}; @@ -216,7 +215,6 @@ pub(super) fn validate_tunnel<'a>( } pub(super) fn inject_cookie( - req_protocol: &LocalProtocol, response: &mut http::Response, remote_addr: &RemoteAddr, mk_body: impl FnOnce(String) -> B, @@ -224,19 +222,14 @@ pub(super) fn inject_cookie( where B: Body, { - if matches!( - req_protocol, - LocalProtocol::ReverseSocks5 { .. } | LocalProtocol::ReverseHttpProxy { .. } - ) { - let Ok(header_val) = HeaderValue::from_str(&tunnel_to_jwt_token(Uuid::from_u128(0), remote_addr)) else { - error!("Bad header value for reverse socks5: {} {}", remote_addr.host, remote_addr.port); - return Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(mk_body("Invalid upgrade request".to_string())) - .unwrap()); - }; - response.headers_mut().insert(COOKIE, header_val); - } + let Ok(header_val) = HeaderValue::from_str(&tunnel_to_jwt_token(Uuid::from_u128(0), remote_addr)) else { + error!("Bad header value for reverse socks5: {} {}", remote_addr.host, remote_addr.port); + return Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(mk_body("Invalid upgrade request".to_string())) + .unwrap()); + }; + response.headers_mut().insert(COOKIE, header_val); Ok(()) }