diff --git a/Cargo.lock b/Cargo.lock index 3f6452b..d2b8d91 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -346,13 +346,12 @@ checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" [[package]] name = "cc" -version = "1.1.0" +version = "1.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaff6f8ce506b9773fa786672d63fc7a191ffea1be33f72bbd4aeacefca9ffc8" +checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f" dependencies = [ "jobserver", "libc", - "once_cell", ] [[package]] @@ -402,9 +401,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.9" +version = "4.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64acc1846d54c1fe936a78dc189c34e28d3f5afc348403f28ecf53660b9b8462" +checksum = "35723e6a11662c2afb578bcf0b88bf6ea8e21282a953428f240574fcc3a2b5b3" dependencies = [ "clap_builder", "clap_derive", @@ -412,9 +411,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.9" +version = "4.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fb8393d67ba2e7bfaf28a23458e4e2b543cc73a99595511eb207fdb8aede942" +checksum = "49eb96cbfa7cfa35017b7cd548c75b14c3118c98b423041d70562665e07fb0fa" dependencies = [ "anstream", "anstyle", @@ -424,9 +423,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.8" +version = "4.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bac35c6dafb060fd4d275d9a4ffae97917c13a6327903a8be2153cd964f7085" +checksum = "5d029b67f89d30bbb547c89fd5161293c0aec155fc691d7924b64550662db93e" dependencies = [ "heck 0.5.0", "proc-macro2", @@ -504,7 +503,7 @@ dependencies = [ "bitflags 2.6.0", "crossterm_winapi", "libc", - "mio", + "mio 0.8.11", "parking_lot", "signal-hook", "signal-hook-mio", @@ -1500,6 +1499,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mio" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" +dependencies = [ + "hermit-abi", + "libc", + "wasi", + "windows-sys 0.52.0", +] + [[package]] name = "mirai-annotations" version = "1.12.0" @@ -1543,7 +1554,7 @@ dependencies = [ "kqueue", "libc", "log", - "mio", + "mio 0.8.11", "walkdir", "windows-sys 0.48.0", ] @@ -1592,16 +1603,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "num_cpus" -version = "1.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" -dependencies = [ - "hermit-abi", - "libc", -] - [[package]] name = "num_threads" version = "0.1.7" @@ -2289,7 +2290,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af" dependencies = [ "libc", - "mio", + "mio 0.8.11", "signal-hook", ] @@ -2502,21 +2503,20 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.38.1" +version = "1.39.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb2caba9f80616f438e09748d5acda951967e1ea58508ef53d9c6402485a46df" +checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" dependencies = [ "backtrace", "bytes", "libc", - "mio", - "num_cpus", + "mio 1.0.1", "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -2531,9 +2531,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.3.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 4a84e42..44abf12 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ scopeguard = "1.2.0" bb8 = { version = "0.8", features = [] } bytes = { version = "1.6.1", features = [] } -clap = { version = "4.5.9", features = ["derive", "env"] } +clap = { version = "4.5.11", features = ["derive", "env"] } fast-socks5 = { version = "0.9.6", features = [] } fastwebsockets = { version = "0.8.0", features = ["upgrade", "simd", "unstable-split"] } futures-util = { version = "0.3.30" } @@ -43,7 +43,7 @@ rustls-pemfile = { version = "2.1.2", features = [] } x509-parser = "0.16.0" serde = { version = "1.0.204", features = ["derive"] } socket2 = { version = "0.5.7", features = [] } -tokio = { version = "1.38.1", features = ["full"] } +tokio = { version = "1.39.2", features = ["full"] } tokio-stream = { version = "0.1.15", features = ["net"] } [target.'cfg(any(os = "linux", os = "macos"))'.dependencies] diff --git a/src/tunnel/server/handler_http2.rs b/src/tunnel/server/handler_http2.rs new file mode 100644 index 0000000..066df0c --- /dev/null +++ b/src/tunnel/server/handler_http2.rs @@ -0,0 +1,132 @@ +use crate::restrictions::types::RestrictionsRules; +use crate::tunnel::server::utils::{ + extract_path_prefix, extract_tunnel_info, extract_x_forwarded_for, inject_cookie, validate_tunnel, +}; +use crate::tunnel::server::WsServer; +use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; +use crate::tunnel::{transport, RemoteAddr}; +use bytes::Bytes; +use futures_util::StreamExt; +use http_body_util::combinators::BoxBody; +use http_body_util::{BodyStream, Either, StreamBody}; +use hyper::body::{Frame, Incoming}; +use hyper::header::CONTENT_TYPE; +use hyper::{http, Request, Response, StatusCode}; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot}; +use tokio_stream::wrappers::ReceiverStream; +use tracing::{info, warn, Instrument, Span}; + +pub(super) async fn http_server_upgrade( + server: WsServer, + restrictions: Arc, + restrict_path_prefix: Option, + mut client_addr: SocketAddr, + mut req: Request, +) -> Response>> { + match extract_x_forwarded_for(&req) { + Ok(Some((x_forward_for, x_forward_for_str))) => { + info!("Request X-Forwarded-For: {:?}", x_forward_for); + Span::current().record("forwarded_for", x_forward_for_str); + client_addr.set_ip(x_forward_for); + } + Ok(_) => {} + Err(err) => return err.map(Either::Left), + }; + + let path_prefix = match extract_path_prefix(&req) { + Ok(p) => p, + Err(err) => return err.map(Either::Left), + }; + + if let Some(restrict_path) = restrict_path_prefix { + if path_prefix != restrict_path { + warn!( + "Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)", + path_prefix, restrict_path + ); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Either::Left("Invalid upgrade request".to_string())) + .unwrap(); + } + } + + let jwt = match extract_tunnel_info(&req) { + Ok(jwt) => jwt, + Err(err) => return err.map(Either::Left), + }; + + Span::current().record("id", &jwt.claims.id); + Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); + let remote = match RemoteAddr::try_from(jwt.claims) { + Ok(remote) => remote, + Err(err) => { + warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri()); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Either::Left("Invalid upgrade request".to_string())) + .unwrap(); + } + }; + + let restriction = match validate_tunnel(&remote, path_prefix, &restrictions) { + Ok(matched_restriction) => { + info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); + matched_restriction + } + Err(err) => return err.map(Either::Left), + }; + + let req_protocol = remote.protocol.clone(); + let tunnel = match server.run_tunnel(restriction, remote, client_addr).await { + Ok(ret) => ret, + Err(err) => { + warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Either::Left("Invalid upgrade request".to_string())) + .unwrap(); + } + }; + + let (remote_addr, local_rx, local_tx) = tunnel; + info!("connected to {:?} {}:{}", req_protocol, remote_addr.host, remote_addr.port); + + let req_content_type = req.headers_mut().remove(CONTENT_TYPE); + let ws_rx = BodyStream::new(req.into_body()); + let (ws_tx, rx) = mpsc::channel::(1024); + let body = BoxBody::new(StreamBody::new( + ReceiverStream::new(rx).map(|s| -> anyhow::Result> { Ok(Frame::data(s)) }), + )); + + let mut response = Response::builder() + .status(StatusCode::OK) + .body(Either::Right(body)) + .expect("bug: failed to build response"); + + tokio::spawn( + async move { + let (close_tx, close_rx) = oneshot::channel::<()>(); + tokio::task::spawn( + transport::io::propagate_remote_to_local(local_tx, Http2TunnelRead::new(ws_rx), close_rx) + .instrument(Span::current()), + ); + + let _ = + transport::io::propagate_local_to_remote(local_rx, Http2TunnelWrite::new(ws_tx), close_tx, None).await; + } + .instrument(Span::current()), + ); + + if let Err(response) = inject_cookie(&req_protocol, &mut response, &remote_addr, Either::Left) { + return response; + } + + if let Some(content_type) = req_content_type { + response.headers_mut().insert(CONTENT_TYPE, content_type); + } + + response +} diff --git a/src/tunnel/server/handler_websocket.rs b/src/tunnel/server/handler_websocket.rs new file mode 100644 index 0000000..e78fe2f --- /dev/null +++ b/src/tunnel/server/handler_websocket.rs @@ -0,0 +1,145 @@ +use crate::restrictions::types::RestrictionsRules; +use crate::tunnel::server::utils::{ + extract_path_prefix, extract_tunnel_info, extract_x_forwarded_for, inject_cookie, validate_tunnel, +}; +use crate::tunnel::server::WsServer; +use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; +use crate::tunnel::{transport, RemoteAddr}; +use hyper::body::Incoming; +use hyper::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL}; +use hyper::{http, Request, Response, StatusCode}; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::sync::oneshot; +use tracing::{error, info, warn, Instrument, Span}; + +pub(super) async fn ws_server_upgrade( + server: WsServer, + restrictions: Arc, + restrict_path_prefix: Option, + mut client_addr: SocketAddr, + mut req: Request, +) -> Response { + if !fastwebsockets::upgrade::is_upgrade_request(&req) { + warn!("Rejecting connection with bad upgrade request: {}", req.uri()); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap(); + } + + match extract_x_forwarded_for(&req) { + Ok(Some((x_forward_for, x_forward_for_str))) => { + info!("Request X-Forwarded-For: {:?}", x_forward_for); + Span::current().record("forwarded_for", x_forward_for_str); + client_addr.set_ip(x_forward_for); + } + Ok(_) => {} + Err(err) => return err, + }; + + let path_prefix = match extract_path_prefix(&req) { + Ok(p) => p, + Err(err) => return err, + }; + + if let Some(restrict_path) = restrict_path_prefix { + if path_prefix != restrict_path { + warn!( + "Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)", + path_prefix, restrict_path + ); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap(); + } + } + + let jwt = match extract_tunnel_info(&req) { + Ok(jwt) => jwt, + Err(err) => return err, + }; + + Span::current().record("id", &jwt.claims.id); + Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); + + let remote = match RemoteAddr::try_from(jwt.claims) { + Ok(remote) => remote, + Err(err) => { + warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri()); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap(); + } + }; + + let restriction = match validate_tunnel(&remote, path_prefix, &restrictions) { + Ok(matched_restriction) => { + info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); + matched_restriction + } + Err(err) => return err, + }; + + let req_protocol = remote.protocol.clone(); + let tunnel = match server.run_tunnel(restriction, remote, client_addr).await { + Ok(ret) => ret, + Err(err) => { + warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap(); + } + }; + + let (remote_addr, local_rx, local_tx) = tunnel; + info!("connected to {:?} {}:{}", req_protocol, remote_addr.host, remote_addr.port); + let (response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) { + Ok(ret) => ret, + Err(err) => { + warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(format!("Invalid upgrade request: {:?}", err)) + .unwrap(); + } + }; + + tokio::spawn( + async move { + let (ws_rx, mut ws_tx) = match fut.await { + Ok(ws) => ws.split(tokio::io::split), + Err(err) => { + error!("Error during http upgrade request: {:?}", err); + return; + } + }; + let (close_tx, close_rx) = oneshot::channel::<()>(); + ws_tx.set_auto_apply_mask(server.config.websocket_mask_frame); + + tokio::task::spawn( + transport::io::propagate_remote_to_local(local_tx, WebsocketTunnelRead::new(ws_rx), close_rx) + .instrument(Span::current()), + ); + + let _ = + transport::io::propagate_local_to_remote(local_rx, WebsocketTunnelWrite::new(ws_tx), close_tx, None) + .await; + } + .instrument(Span::current()), + ); + + let mut response = Response::from_parts(response.into_parts().0, "".to_string()); + if let Err(response) = inject_cookie(&req_protocol, &mut response, &remote_addr, |s| s) { + return response; + } + + response + .headers_mut() + .insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("v1")); + + response +} diff --git a/src/tunnel/server/mod.rs b/src/tunnel/server/mod.rs index 2d22698..d318810 100644 --- a/src/tunnel/server/mod.rs +++ b/src/tunnel/server/mod.rs @@ -1,5 +1,8 @@ #![allow(clippy::module_inception)] +mod handler_http2; +mod handler_websocket; mod server; +mod utils; pub use server::TlsServerConfig; pub use server::WsServer; diff --git a/src/tunnel/server/server.rs b/src/tunnel/server/server.rs index 41496e6..212c0f0 100644 --- a/src/tunnel/server/server.rs +++ b/src/tunnel/server/server.rs @@ -1,31 +1,25 @@ use ahash::{HashMap, HashMapExt}; use anyhow::anyhow; -use bytes::Bytes; use futures_util::{pin_mut, FutureExt, StreamExt}; -use http_body_util::combinators::BoxBody; -use http_body_util::{BodyStream, Either, StreamBody}; -use std::cmp::min; +use http_body_util::Either; use std::fmt; use std::fmt::{Debug, Formatter}; use std::future::Future; -use std::net::{IpAddr, SocketAddr}; +use std::net::SocketAddr; use std::ops::Deref; use std::path::PathBuf; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; -use crate::tunnel::{transport, tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; +use crate::tunnel::RemoteAddr; use crate::{protocols, LocalProtocol}; -use hyper::body::{Frame, Incoming}; -use hyper::header::{CONTENT_TYPE, COOKIE, SEC_WEBSOCKET_PROTOCOL}; -use hyper::http::HeaderValue; +use hyper::body::Incoming; use hyper::server::conn::{http1, http2}; use hyper::service::service_fn; -use hyper::{http, Request, Response, StatusCode, Version}; +use hyper::{http, Request, StatusCode, Version}; use hyper_util::rt::TokioExecutor; -use jsonwebtoken::TokenData; use once_cell::sync::Lazy; use parking_lot::Mutex; use socket2::SockRef; @@ -34,26 +28,23 @@ use crate::protocols::dns::DnsResolver; use crate::protocols::tls; use crate::protocols::udp::{UdpStream, UdpStreamWriter}; use crate::restrictions::config_reloader::RestrictionsRulesReloader; -use crate::restrictions::types::{ - AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol, -}; +use crate::restrictions::types::{RestrictionConfig, RestrictionsRules}; use crate::tunnel::connectors::{TcpTunnelConnector, TunnelConnector, UdpTunnelConnector}; use crate::tunnel::listeners::{ new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener, TunnelListener, }; +use crate::tunnel::server::handler_http2::http_server_upgrade; +use crate::tunnel::server::handler_websocket::ws_server_upgrade; +use crate::tunnel::server::utils::find_mapped_port; use crate::tunnel::tls_reloader::TlsReloader; -use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; -use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpListener; use tokio::select; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::mpsc; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; use tokio_rustls::TlsAcceptor; -use tokio_stream::wrappers::ReceiverStream; use tracing::{error, info, span, warn, Instrument, Level, Span}; use url::Host; -use uuid::Uuid; #[derive(Debug)] pub struct TlsServerConfig { @@ -88,7 +79,7 @@ impl WsServer { } } - async fn run_tunnel( + pub(super) async fn run_tunnel( &self, restriction: &RestrictionConfig, remote: RemoteAddr, @@ -452,466 +443,6 @@ impl Debug for WsServerConfig { } } -/// Checks if the requested (remote) port has been mapped in the configuration to another port. -/// If it is not mapped the original port number is returned. -#[inline] -fn find_mapped_port(req_port: u16, restriction: &RestrictionConfig) -> u16 { - // Determine if the requested port is to be mapped to a different port. - let remote_port = restriction - .allow - .iter() - .find_map(|allow| { - if let AllowConfig::ReverseTunnel(allow) = allow { - return allow.port_mapping.get(&req_port).cloned(); - } - None - }) - .unwrap_or(req_port); - - if req_port != remote_port { - info!("Client requested port {} was mapped to {}", req_port, remote_port); - } - - remote_port -} - -#[inline] -fn extract_x_forwarded_for(req: &Request) -> Result, Response> { - let Some(x_forward_for) = req.headers().get("X-Forwarded-For") else { - return Ok(None); - }; - - // X-Forwarded-For: , , - let x_forward_for = x_forward_for.to_str().unwrap_or_default(); - let x_forward_for = x_forward_for.split_once(',').map(|x| x.0).unwrap_or(x_forward_for); - let ip: Option = x_forward_for.parse().ok(); - Ok(ip.map(|ip| (ip, x_forward_for))) -} - -#[inline] -fn extract_path_prefix(req: &Request) -> Result<&str, Response> { - let path = req.uri().path(); - let min_len = min(path.len(), 1); - if &path[0..min_len] != "/" { - warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri()); - return Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap()); - } - - let Some((l, r)) = path[min_len..].split_once('/') else { - warn!("Rejecting connection with bad upgrade request: {}", req.uri()); - return Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".into()) - .unwrap()); - }; - - if !r.ends_with("events") { - warn!("Rejecting connection with bad upgrade request: {}", req.uri()); - return Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".into()) - .unwrap()); - } - - Ok(l) -} - -#[inline] -fn extract_tunnel_info(req: &Request) -> Result, Response> { - let jwt = req - .headers() - .get(SEC_WEBSOCKET_PROTOCOL) - .and_then(|header| header.to_str().ok()) - .and_then(|header| header.split_once(JWT_HEADER_PREFIX)) - .map(|(_prefix, jwt)| jwt) - .or_else(|| req.headers().get(COOKIE).and_then(|header| header.to_str().ok())) - .unwrap_or_default(); - - let (validation, decode_key) = JWT_DECODE.deref(); - let jwt = match jsonwebtoken::decode(jwt, decode_key, validation) { - Ok(jwt) => jwt, - err => { - warn!( - "error while decoding jwt for tunnel info {:?} header {:?}", - err, - req.headers().get(SEC_WEBSOCKET_PROTOCOL) - ); - return Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap()); - } - }; - - Ok(jwt) -} - -#[inline] -fn validate_tunnel<'a>( - remote: &RemoteAddr, - path_prefix: &str, - restrictions: &'a RestrictionsRules, -) -> Result<&'a RestrictionConfig, Response> { - for restriction in &restrictions.restrictions { - if !restriction.r#match.iter().all(|m| match m { - MatchConfig::Any => true, - MatchConfig::PathPrefix(path) => path.is_match(path_prefix), - }) { - continue; - } - - for allow in &restriction.allow { - match allow { - AllowConfig::ReverseTunnel(allow) => { - if !remote.protocol.is_reverse_tunnel() { - continue; - } - - if !allow.port.is_empty() && !allow.port.iter().any(|range| range.contains(&remote.port)) { - continue; - } - - if !allow.protocol.is_empty() - && !allow - .protocol - .contains(&ReverseTunnelConfigProtocol::from(&remote.protocol)) - { - continue; - } - - match &remote.host { - Host::Domain(_) => {} - Host::Ipv4(ip) => { - let ip = IpAddr::V4(*ip); - for cidr in &allow.cidr { - if cidr.contains(&ip) { - return Ok(restriction); - } - } - } - Host::Ipv6(ip) => { - let ip = IpAddr::V6(*ip); - for cidr in &allow.cidr { - if cidr.contains(&ip) { - return Ok(restriction); - } - } - } - } - } - - AllowConfig::Tunnel(allow) => { - if remote.protocol.is_reverse_tunnel() { - continue; - } - - if !allow.port.is_empty() && !allow.port.iter().any(|range| range.contains(&remote.port)) { - continue; - } - - if !allow.protocol.is_empty() - && !allow.protocol.contains(&TunnelConfigProtocol::from(&remote.protocol)) - { - continue; - } - - match &remote.host { - Host::Domain(host) => { - if allow.host.is_match(host) { - return Ok(restriction); - } - } - Host::Ipv4(ip) => { - let ip = IpAddr::V4(*ip); - for cidr in &allow.cidr { - if cidr.contains(&ip) { - return Ok(restriction); - } - } - } - Host::Ipv6(ip) => { - let ip = IpAddr::V6(*ip); - for cidr in &allow.cidr { - if cidr.contains(&ip) { - return Ok(restriction); - } - } - } - } - } - } - } - } - - warn!("Rejecting connection with not allowed destination: {:?}", remote); - Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap()) -} - -async fn ws_server_upgrade( - server: WsServer, - restrictions: Arc, - restrict_path_prefix: Option, - mut client_addr: SocketAddr, - mut req: Request, -) -> Response { - if !fastwebsockets::upgrade::is_upgrade_request(&req) { - warn!("Rejecting connection with bad upgrade request: {}", req.uri()); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap(); - } - - match extract_x_forwarded_for(&req) { - Ok(Some((x_forward_for, x_forward_for_str))) => { - info!("Request X-Forwarded-For: {:?}", x_forward_for); - Span::current().record("forwarded_for", x_forward_for_str); - client_addr.set_ip(x_forward_for); - } - Ok(_) => {} - Err(err) => return err, - }; - - let path_prefix = match extract_path_prefix(&req) { - Ok(p) => p, - Err(err) => return err, - }; - - if let Some(restrict_path) = restrict_path_prefix { - if path_prefix != restrict_path { - warn!( - "Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)", - path_prefix, restrict_path - ); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap(); - } - } - - let jwt = match extract_tunnel_info(&req) { - Ok(jwt) => jwt, - Err(err) => return err, - }; - - Span::current().record("id", &jwt.claims.id); - Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); - - let remote = match RemoteAddr::try_from(jwt.claims) { - Ok(remote) => remote, - Err(err) => { - warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri()); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap(); - } - }; - - let restriction = match validate_tunnel(&remote, path_prefix, &restrictions) { - Ok(matched_restriction) => { - info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); - matched_restriction - } - Err(err) => return err, - }; - - let req_protocol = remote.protocol.clone(); - let tunnel = match server.run_tunnel(restriction, remote, client_addr).await { - Ok(ret) => ret, - Err(err) => { - warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap(); - } - }; - - let (remote_addr, local_rx, local_tx) = tunnel; - info!("connected to {:?} {}:{}", req_protocol, remote_addr.host, remote_addr.port); - let (mut response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) { - Ok(ret) => ret, - Err(err) => { - warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(format!("Invalid upgrade request: {:?}", err)) - .unwrap(); - } - }; - - tokio::spawn( - async move { - let (ws_rx, mut ws_tx) = match fut.await { - Ok(ws) => ws.split(tokio::io::split), - Err(err) => { - error!("Error during http upgrade request: {:?}", err); - return; - } - }; - let (close_tx, close_rx) = oneshot::channel::<()>(); - ws_tx.set_auto_apply_mask(server.config.websocket_mask_frame); - - tokio::task::spawn( - transport::io::propagate_remote_to_local(local_tx, WebsocketTunnelRead::new(ws_rx), close_rx) - .instrument(Span::current()), - ); - - let _ = - transport::io::propagate_local_to_remote(local_rx, WebsocketTunnelWrite::new(ws_tx), close_tx, None) - .await; - } - .instrument(Span::current()), - ); - - if matches!( - req_protocol, - LocalProtocol::ReverseSocks5 { .. } | LocalProtocol::ReverseHttpProxy { .. } - ) { - let Ok(header_val) = HeaderValue::from_str(&tunnel_to_jwt_token(Uuid::from_u128(0), &remote_addr)) else { - error!("Bad headervalue for reverse socks5: {} {}", remote_addr.host, remote_addr.port); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap(); - }; - response.headers_mut().insert(COOKIE, header_val); - } - response - .headers_mut() - .insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("v1")); - - Response::from_parts(response.into_parts().0, "".to_string()) -} - -async fn http_server_upgrade( - server: WsServer, - restrictions: Arc, - restrict_path_prefix: Option, - mut client_addr: SocketAddr, - mut req: Request, -) -> Response>> { - match extract_x_forwarded_for(&req) { - Ok(Some((x_forward_for, x_forward_for_str))) => { - info!("Request X-Forwarded-For: {:?}", x_forward_for); - Span::current().record("forwarded_for", x_forward_for_str); - client_addr.set_ip(x_forward_for); - } - Ok(_) => {} - Err(err) => return err.map(Either::Left), - }; - - let path_prefix = match extract_path_prefix(&req) { - Ok(p) => p, - Err(err) => return err.map(Either::Left), - }; - - if let Some(restrict_path) = restrict_path_prefix { - if path_prefix != restrict_path { - warn!( - "Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)", - path_prefix, restrict_path - ); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Either::Left("Invalid upgrade request".to_string())) - .unwrap(); - } - } - - let jwt = match extract_tunnel_info(&req) { - Ok(jwt) => jwt, - Err(err) => return err.map(Either::Left), - }; - - Span::current().record("id", &jwt.claims.id); - Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); - let remote = match RemoteAddr::try_from(jwt.claims) { - Ok(remote) => remote, - Err(err) => { - warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri()); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Either::Left("Invalid upgrade request".to_string())) - .unwrap(); - } - }; - - let restriction = match validate_tunnel(&remote, path_prefix, &restrictions) { - Ok(matched_restriction) => { - info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); - matched_restriction - } - Err(err) => return err.map(Either::Left), - }; - - let req_protocol = remote.protocol.clone(); - let tunnel = match server.run_tunnel(restriction, remote, client_addr).await { - Ok(ret) => ret, - Err(err) => { - warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Either::Left("Invalid upgrade request".to_string())) - .unwrap(); - } - }; - - let (remote_addr, local_rx, local_tx) = tunnel; - info!("connected to {:?} {}:{}", req_protocol, remote_addr.host, remote_addr.port); - - let req_content_type = req.headers_mut().remove(CONTENT_TYPE); - let ws_rx = BodyStream::new(req.into_body()); - let (ws_tx, rx) = mpsc::channel::(1024); - let body = BoxBody::new(StreamBody::new( - ReceiverStream::new(rx).map(|s| -> anyhow::Result> { Ok(Frame::data(s)) }), - )); - - let mut response = Response::builder() - .status(StatusCode::OK) - .body(Either::Right(body)) - .expect("bug: failed to build response"); - - tokio::spawn( - async move { - let (close_tx, close_rx) = oneshot::channel::<()>(); - tokio::task::spawn( - transport::io::propagate_remote_to_local(local_tx, Http2TunnelRead::new(ws_rx), close_rx) - .instrument(Span::current()), - ); - - let _ = - transport::io::propagate_local_to_remote(local_rx, Http2TunnelWrite::new(ws_tx), close_tx, None).await; - } - .instrument(Span::current()), - ); - - if matches!(req_protocol, LocalProtocol::ReverseSocks5 { .. }) { - let Ok(header_val) = HeaderValue::from_str(&tunnel_to_jwt_token(Uuid::from_u128(0), &remote_addr)) else { - error!("Bad header value for reverse socks5: {} {}", remote_addr.host, remote_addr.port); - return http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Either::Left("Invalid upgrade request".to_string())) - .unwrap(); - }; - response.headers_mut().insert(COOKIE, header_val); - } - - if let Some(content_type) = req_content_type { - response.headers_mut().insert(CONTENT_TYPE, content_type); - } - - response -} - struct TlsContext<'a> { tls_acceptor: Arc, tls_reloader: TlsReloader, diff --git a/src/tunnel/server/utils.rs b/src/tunnel/server/utils.rs new file mode 100644 index 0000000..e18bf22 --- /dev/null +++ b/src/tunnel/server/utils.rs @@ -0,0 +1,242 @@ +use crate::restrictions::types::{ + AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol, +}; +use crate::tunnel::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; +use crate::LocalProtocol; +use hyper::body::{Body, Incoming}; +use hyper::header::{HeaderValue, COOKIE, SEC_WEBSOCKET_PROTOCOL}; +use hyper::{http, Request, Response, StatusCode}; +use jsonwebtoken::TokenData; +use std::cmp::min; +use std::net::IpAddr; +use std::ops::Deref; +use tracing::{error, info, warn}; +use url::Host; +use uuid::Uuid; + +/// Checks if the requested (remote) port has been mapped in the configuration to another port. +/// If it is not mapped the original port number is returned. +#[inline] +pub(super) fn find_mapped_port(req_port: u16, restriction: &RestrictionConfig) -> u16 { + // Determine if the requested port is to be mapped to a different port. + let remote_port = restriction + .allow + .iter() + .find_map(|allow| { + if let AllowConfig::ReverseTunnel(allow) = allow { + return allow.port_mapping.get(&req_port).cloned(); + } + None + }) + .unwrap_or(req_port); + + if req_port != remote_port { + info!("Client requested port {} was mapped to {}", req_port, remote_port); + } + + remote_port +} + +#[inline] +pub(super) fn extract_x_forwarded_for(req: &Request) -> Result, Response> { + let Some(x_forward_for) = req.headers().get("X-Forwarded-For") else { + return Ok(None); + }; + + // X-Forwarded-For: , , + let x_forward_for = x_forward_for.to_str().unwrap_or_default(); + let x_forward_for = x_forward_for.split_once(',').map(|x| x.0).unwrap_or(x_forward_for); + let ip: Option = x_forward_for.parse().ok(); + Ok(ip.map(|ip| (ip, x_forward_for))) +} + +#[inline] +pub(super) fn extract_path_prefix(req: &Request) -> Result<&str, Response> { + let path = req.uri().path(); + let min_len = min(path.len(), 1); + if &path[0..min_len] != "/" { + warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri()); + return Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap()); + } + + let Some((l, r)) = path[min_len..].split_once('/') else { + warn!("Rejecting connection with bad upgrade request: {}", req.uri()); + return Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".into()) + .unwrap()); + }; + + if !r.ends_with("events") { + warn!("Rejecting connection with bad upgrade request: {}", req.uri()); + return Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".into()) + .unwrap()); + } + + Ok(l) +} + +#[inline] +pub(super) fn extract_tunnel_info(req: &Request) -> Result, Response> { + let jwt = req + .headers() + .get(SEC_WEBSOCKET_PROTOCOL) + .and_then(|header| header.to_str().ok()) + .and_then(|header| header.split_once(JWT_HEADER_PREFIX)) + .map(|(_prefix, jwt)| jwt) + .or_else(|| req.headers().get(COOKIE).and_then(|header| header.to_str().ok())) + .unwrap_or_default(); + + let (validation, decode_key) = JWT_DECODE.deref(); + let jwt = match jsonwebtoken::decode(jwt, decode_key, validation) { + Ok(jwt) => jwt, + err => { + warn!( + "error while decoding jwt for tunnel info {:?} header {:?}", + err, + req.headers().get(SEC_WEBSOCKET_PROTOCOL) + ); + return Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap()); + } + }; + + Ok(jwt) +} + +#[inline] +pub(super) fn validate_tunnel<'a>( + remote: &RemoteAddr, + path_prefix: &str, + restrictions: &'a RestrictionsRules, +) -> Result<&'a RestrictionConfig, Response> { + for restriction in &restrictions.restrictions { + if !restriction.r#match.iter().all(|m| match m { + MatchConfig::Any => true, + MatchConfig::PathPrefix(path) => path.is_match(path_prefix), + }) { + continue; + } + + for allow in &restriction.allow { + match allow { + AllowConfig::ReverseTunnel(allow) => { + if !remote.protocol.is_reverse_tunnel() { + continue; + } + + if !allow.port.is_empty() && !allow.port.iter().any(|range| range.contains(&remote.port)) { + continue; + } + + if !allow.protocol.is_empty() + && !allow + .protocol + .contains(&ReverseTunnelConfigProtocol::from(&remote.protocol)) + { + continue; + } + + match &remote.host { + Host::Domain(_) => {} + Host::Ipv4(ip) => { + let ip = IpAddr::V4(*ip); + for cidr in &allow.cidr { + if cidr.contains(&ip) { + return Ok(restriction); + } + } + } + Host::Ipv6(ip) => { + let ip = IpAddr::V6(*ip); + for cidr in &allow.cidr { + if cidr.contains(&ip) { + return Ok(restriction); + } + } + } + } + } + + AllowConfig::Tunnel(allow) => { + if remote.protocol.is_reverse_tunnel() { + continue; + } + + if !allow.port.is_empty() && !allow.port.iter().any(|range| range.contains(&remote.port)) { + continue; + } + + if !allow.protocol.is_empty() + && !allow.protocol.contains(&TunnelConfigProtocol::from(&remote.protocol)) + { + continue; + } + + match &remote.host { + Host::Domain(host) => { + if allow.host.is_match(host) { + return Ok(restriction); + } + } + Host::Ipv4(ip) => { + let ip = IpAddr::V4(*ip); + for cidr in &allow.cidr { + if cidr.contains(&ip) { + return Ok(restriction); + } + } + } + Host::Ipv6(ip) => { + let ip = IpAddr::V6(*ip); + for cidr in &allow.cidr { + if cidr.contains(&ip) { + return Ok(restriction); + } + } + } + } + } + } + } + } + + warn!("Rejecting connection with not allowed destination: {:?}", remote); + Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid upgrade request".to_string()) + .unwrap()) +} + +pub(super) fn inject_cookie( + req_protocol: &LocalProtocol, + response: &mut http::Response, + remote_addr: &RemoteAddr, + mk_body: impl FnOnce(String) -> B, +) -> Result<(), Response> +where + B: Body, +{ + if matches!( + req_protocol, + LocalProtocol::ReverseSocks5 { .. } | LocalProtocol::ReverseHttpProxy { .. } + ) { + let Ok(header_val) = HeaderValue::from_str(&tunnel_to_jwt_token(Uuid::from_u128(0), remote_addr)) else { + error!("Bad header value for reverse socks5: {} {}", remote_addr.host, remote_addr.port); + return Err(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(mk_body("Invalid upgrade request".to_string())) + .unwrap()); + }; + response.headers_mut().insert(COOKIE, header_val); + } + + Ok(()) +}