diff --git a/src/tunnel/client/client.rs b/src/tunnel/client/client.rs index 7cd7bb3..54fa375 100644 --- a/src/tunnel/client/client.rs +++ b/src/tunnel/client/client.rs @@ -4,14 +4,13 @@ use crate::tunnel::client::WsClientConfig; use crate::tunnel::connectors::TunnelConnector; use crate::tunnel::listeners::TunnelListener; use crate::tunnel::tls_reloader::TlsReloader; -use crate::tunnel::transport::{TunnelReader, TunnelWriter}; -use crate::tunnel::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE}; +use crate::tunnel::transport::io::{TunnelReader, TunnelWriter}; +use crate::tunnel::transport::jwt_token_to_tunnel; +use crate::tunnel::{RemoteAddr, TransportScheme}; use anyhow::Context; use futures_util::pin_mut; use hyper::header::COOKIE; -use jsonwebtoken::TokenData; use log::debug; -use std::ops::Deref; use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite}; @@ -179,11 +178,7 @@ impl WsClient { .headers .get(COOKIE) .and_then(|h| h.to_str().ok()) - .and_then(|h| { - let (validation, decode_key) = JWT_DECODE.deref(); - let jwt: Option> = jsonwebtoken::decode(h, decode_key, validation).ok(); - jwt - }) + .and_then(|h| jwt_token_to_tunnel(h).ok()) .map(|jwt| RemoteAddr { protocol: jwt.claims.p, host: Host::parse(&jwt.claims.r).unwrap_or_else(|_| Host::Domain(String::new())), diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 15cc12a..6987fe4 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -6,68 +6,13 @@ mod tls_reloader; mod transport; use crate::TlsClientConfig; -use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation}; -use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; -use std::collections::HashSet; use std::fmt::{Debug, Display, Formatter}; -use std::net::{IpAddr, SocketAddr}; -use std::ops::Deref; +use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::path::PathBuf; use std::str::FromStr; use std::time::Duration; use url::Host; -use uuid::Uuid; - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct JwtTunnelConfig { - pub id: String, // tunnel id - pub p: LocalProtocol, // protocol to use - pub r: String, // remote host - pub rp: u16, // remote port -} - -impl JwtTunnelConfig { - fn new(request_id: Uuid, dest: &RemoteAddr) -> Self { - Self { - id: request_id.to_string(), - p: match dest.protocol { - LocalProtocol::Tcp { .. } => dest.protocol.clone(), - LocalProtocol::Udp { .. } => dest.protocol.clone(), - LocalProtocol::Stdio => LocalProtocol::Tcp { proxy_protocol: false }, - LocalProtocol::Socks5 { .. } => LocalProtocol::Tcp { proxy_protocol: false }, - LocalProtocol::HttpProxy { .. } => dest.protocol.clone(), - LocalProtocol::ReverseTcp => LocalProtocol::ReverseTcp, - LocalProtocol::ReverseUdp { .. } => dest.protocol.clone(), - LocalProtocol::ReverseSocks5 { .. } => dest.protocol.clone(), - LocalProtocol::TProxyTcp => LocalProtocol::Tcp { proxy_protocol: false }, - LocalProtocol::TProxyUdp { timeout } => LocalProtocol::Udp { timeout }, - LocalProtocol::Unix { .. } => LocalProtocol::Tcp { proxy_protocol: false }, - LocalProtocol::ReverseUnix { .. } => dest.protocol.clone(), - LocalProtocol::ReverseHttpProxy { .. } => dest.protocol.clone(), - }, - r: dest.host.to_string(), - rp: dest.port, - } - } -} - -fn tunnel_to_jwt_token(request_id: Uuid, tunnel: &RemoteAddr) -> String { - let cfg = JwtTunnelConfig::new(request_id, tunnel); - let (alg, secret) = JWT_KEY.deref(); - jsonwebtoken::encode(alg, &cfg, secret).unwrap_or_default() -} - -static JWT_HEADER_PREFIX: &str = "authorization.bearer."; -static JWT_SECRET: &[u8; 15] = b"champignonfrais"; -static JWT_KEY: Lazy<(Header, EncodingKey)> = - Lazy::new(|| (Header::new(Algorithm::HS256), EncodingKey::from_secret(JWT_SECRET))); - -static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| { - let mut validation = Validation::new(Algorithm::HS256); - validation.required_spec_claims = HashSet::with_capacity(0); - (validation, DecodingKey::from_secret(JWT_SECRET)) -}); #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub enum LocalProtocol { @@ -122,6 +67,10 @@ impl LocalProtocol { | Self::ReverseHttpProxy { .. } ) } + + pub const fn is_dynamic_reverse_tunnel(&self) -> bool { + matches!(self, |Self::ReverseSocks5 { .. }| Self::ReverseHttpProxy { .. }) + } } #[derive(Debug, Clone)] @@ -285,20 +234,17 @@ impl TransportAddr { } } -impl TryFrom for RemoteAddr { - type Error = anyhow::Error; - fn try_from(jwt: JwtTunnelConfig) -> anyhow::Result { - Ok(Self { - protocol: jwt.p, - host: Host::parse(&jwt.r)?, - port: jwt.rp, - }) - } -} - pub fn to_host_port(addr: SocketAddr) -> (Host, u16) { match addr.ip() { IpAddr::V4(ip) => (Host::Ipv4(ip), addr.port()), IpAddr::V6(ip) => (Host::Ipv6(ip), addr.port()), } } + +pub fn try_to_sock_addr((host, port): (Host, u16)) -> anyhow::Result { + match host { + Host::Domain(_) => Err(anyhow::anyhow!("Cannot convert domain to socket address")), + Host::Ipv4(ip) => Ok(SocketAddr::V4(SocketAddrV4::new(ip, port))), + Host::Ipv6(ip) => Ok(SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0))), + } +} diff --git a/src/tunnel/server/server.rs b/src/tunnel/server/server.rs index ae4102f..538be2c 100644 --- a/src/tunnel/server/server.rs +++ b/src/tunnel/server/server.rs @@ -13,7 +13,7 @@ use std::sync::Arc; use std::time::Duration; use crate::protocols; -use crate::tunnel::{LocalProtocol, RemoteAddr}; +use crate::tunnel::{try_to_sock_addr, LocalProtocol, RemoteAddr}; use hyper::body::Incoming; use hyper::server::conn::{http1, http2}; use hyper::service::service_fn; @@ -33,8 +33,7 @@ use crate::tunnel::server::handler_http2::http_server_upgrade; use crate::tunnel::server::handler_websocket::ws_server_upgrade; use crate::tunnel::server::reverse_tunnel::ReverseTunnelServer; use crate::tunnel::server::utils::{ - bad_request, extract_path_prefix, extract_tunnel_info, extract_x_forwarded_for, find_mapped_port, try_to_sock_aadr, - validate_tunnel, + bad_request, extract_path_prefix, extract_tunnel_info, extract_x_forwarded_for, find_mapped_port, validate_tunnel, }; use crate::tunnel::tls_reloader::TlsReloader; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; @@ -143,10 +142,7 @@ impl WsServer { }; let req_protocol = remote.protocol.clone(); - let inject_cookie = matches!( - req_protocol, - LocalProtocol::ReverseSocks5 { .. } | LocalProtocol::ReverseHttpProxy { .. } - ); + let inject_cookie = req_protocol.is_dynamic_reverse_tunnel(); let tunnel = match self.exec_tunnel(restriction, remote, client_addr).await { Ok(ret) => ret, Err(err) => { @@ -213,7 +209,7 @@ impl WsServer { let remote_port = find_mapped_port(remote.port, restriction); let local_srv = (remote.host, remote_port); - let bind = try_to_sock_aadr(local_srv.clone())?; + let bind = try_to_sock_addr(local_srv.clone())?; let listening_server = async { TcpTunnelListener::new(bind, local_srv.clone(), false).await }; let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?; @@ -224,7 +220,7 @@ impl WsServer { let remote_port = find_mapped_port(remote.port, restriction); let local_srv = (remote.host, remote_port); - let bind = try_to_sock_aadr(local_srv.clone())?; + let bind = try_to_sock_addr(local_srv.clone())?; let listening_server = async { UdpTunnelListener::new(bind, local_srv.clone(), timeout).await }; let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?; Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) @@ -234,7 +230,7 @@ impl WsServer { let remote_port = find_mapped_port(remote.port, restriction); let local_srv = (remote.host, remote_port); - let bind = try_to_sock_aadr(local_srv.clone())?; + let bind = try_to_sock_addr(local_srv.clone())?; let listening_server = async { Socks5TunnelListener::new(bind, timeout, credentials).await }; let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?; @@ -246,7 +242,7 @@ impl WsServer { let remote_port = find_mapped_port(remote.port, restriction); let local_srv = (remote.host, remote_port); - let bind = try_to_sock_aadr(local_srv.clone())?; + let bind = try_to_sock_addr(local_srv.clone())?; let listening_server = async { HttpProxyTunnelListener::new(bind, timeout, credentials, false).await }; let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?; @@ -259,7 +255,7 @@ impl WsServer { let remote_port = find_mapped_port(remote.port, restriction); let local_srv = (remote.host, remote_port); - let bind = try_to_sock_aadr(local_srv.clone())?; + let bind = try_to_sock_addr(local_srv.clone())?; let listening_server = async { UnixTunnelListener::new(path, local_srv.clone(), false).await }; let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?; diff --git a/src/tunnel/server/utils.rs b/src/tunnel/server/utils.rs index 1cff621..a77013e 100644 --- a/src/tunnel/server/utils.rs +++ b/src/tunnel/server/utils.rs @@ -1,7 +1,8 @@ use crate::restrictions::types::{ AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol, }; -use crate::tunnel::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; +use crate::tunnel::transport::{jwt_token_to_tunnel, tunnel_to_jwt_token, JwtTunnelConfig, JWT_HEADER_PREFIX}; +use crate::tunnel::RemoteAddr; use bytes::Bytes; use http_body_util::combinators::BoxBody; use http_body_util::Either; @@ -10,8 +11,7 @@ use hyper::header::{HeaderValue, COOKIE, SEC_WEBSOCKET_PROTOCOL}; use hyper::{http, Request, Response, StatusCode}; use jsonwebtoken::TokenData; use std::cmp::min; -use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}; -use std::ops::Deref; +use std::net::IpAddr; use tracing::{error, info, warn}; use url::Host; use uuid::Uuid; @@ -92,8 +92,7 @@ pub(super) fn extract_tunnel_info(req: &Request) -> Result jwt, err => { warn!( @@ -218,11 +217,3 @@ pub(super) fn inject_cookie(response: &mut http::Response, remote_add Ok(()) } - -pub fn try_to_sock_aadr((host, port): (Host, u16)) -> anyhow::Result { - match host { - Host::Domain(_) => Err(anyhow::anyhow!("Cannot convert domain to socket address")), - Host::Ipv4(ip) => Ok(SocketAddr::V4(SocketAddrV4::new(ip, port))), - Host::Ipv6(ip) => Ok(SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0))), - } -} diff --git a/src/tunnel/transport/http2.rs b/src/tunnel/transport/http2.rs index 951f2df..c1cc713 100644 --- a/src/tunnel/transport/http2.rs +++ b/src/tunnel/transport/http2.rs @@ -1,6 +1,8 @@ +use super::io::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; use crate::tunnel::client::WsClient; -use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; -use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, TransportScheme}; +use crate::tunnel::transport::headers_from_file; +use crate::tunnel::transport::jwt::tunnel_to_jwt_token; +use crate::tunnel::{RemoteAddr, TransportScheme}; use anyhow::{anyhow, Context}; use bytes::{Bytes, BytesMut}; use http_body_util::{BodyExt, BodyStream, StreamBody}; diff --git a/src/tunnel/transport/io.rs b/src/tunnel/transport/io.rs index a3706a3..14457a7 100644 --- a/src/tunnel/transport/io.rs +++ b/src/tunnel/transport/io.rs @@ -1,6 +1,8 @@ -use crate::tunnel::transport::{TunnelRead, TunnelWrite}; -use bytes::BufMut; +use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; +use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; +use bytes::{BufMut, BytesMut}; use futures_util::{pin_mut, FutureExt}; +use std::future::Future; use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::select; @@ -9,6 +11,71 @@ use tokio::time::Instant; use tracing::log::debug; use tracing::{error, info, warn}; +pub(super) static MAX_PACKET_LENGTH: usize = 64 * 1024; + +pub trait TunnelWrite: Send + 'static { + fn buf_mut(&mut self) -> &mut BytesMut; + fn write(&mut self) -> impl Future> + Send; + fn ping(&mut self) -> impl Future> + Send; + fn close(&mut self) -> impl Future> + Send; +} + +pub trait TunnelRead: Send + 'static { + fn copy( + &mut self, + writer: impl AsyncWrite + Unpin + Send, + ) -> impl Future> + Send; +} + +pub enum TunnelReader { + Websocket(WebsocketTunnelRead), + Http2(Http2TunnelRead), +} + +impl TunnelRead for TunnelReader { + async fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> Result<(), std::io::Error> { + match self { + Self::Websocket(s) => s.copy(writer).await, + Self::Http2(s) => s.copy(writer).await, + } + } +} + +pub enum TunnelWriter { + Websocket(WebsocketTunnelWrite), + Http2(Http2TunnelWrite), +} + +impl TunnelWrite for TunnelWriter { + fn buf_mut(&mut self) -> &mut BytesMut { + match self { + Self::Websocket(s) => s.buf_mut(), + Self::Http2(s) => s.buf_mut(), + } + } + + async fn write(&mut self) -> Result<(), std::io::Error> { + match self { + Self::Websocket(s) => s.write().await, + Self::Http2(s) => s.write().await, + } + } + + async fn ping(&mut self) -> Result<(), std::io::Error> { + match self { + Self::Websocket(s) => s.ping().await, + Self::Http2(s) => s.ping().await, + } + } + + async fn close(&mut self) -> Result<(), std::io::Error> { + match self { + Self::Websocket(s) => s.close().await, + Self::Http2(s) => s.close().await, + } + } +} + pub async fn propagate_local_to_remote( local_rx: impl AsyncRead, mut ws_tx: impl TunnelWrite, diff --git a/src/tunnel/transport/jwt.rs b/src/tunnel/transport/jwt.rs new file mode 100644 index 0000000..74bc416 --- /dev/null +++ b/src/tunnel/transport/jwt.rs @@ -0,0 +1,75 @@ +use crate::tunnel::{LocalProtocol, RemoteAddr}; +use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation}; +use once_cell::sync::Lazy; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; +use std::ops::Deref; +use url::Host; +use uuid::Uuid; + +pub static JWT_HEADER_PREFIX: &str = "authorization.bearer."; +static JWT_SECRET: &[u8; 15] = b"champignonfrais"; +static JWT_KEY: Lazy<(Header, EncodingKey)> = + Lazy::new(|| (Header::new(Algorithm::HS256), EncodingKey::from_secret(JWT_SECRET))); + +static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| { + let mut validation = Validation::new(Algorithm::HS256); + validation.required_spec_claims = HashSet::with_capacity(0); + (validation, DecodingKey::from_secret(JWT_SECRET)) +}); + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JwtTunnelConfig { + pub id: String, // tunnel id + pub p: LocalProtocol, // protocol to use + pub r: String, // remote host + pub rp: u16, // remote port +} + +impl JwtTunnelConfig { + fn new(request_id: Uuid, dest: &RemoteAddr) -> Self { + Self { + id: request_id.to_string(), + p: match dest.protocol { + LocalProtocol::Tcp { .. } => dest.protocol.clone(), + LocalProtocol::Udp { .. } => dest.protocol.clone(), + LocalProtocol::Stdio => LocalProtocol::Tcp { proxy_protocol: false }, + LocalProtocol::Socks5 { .. } => LocalProtocol::Tcp { proxy_protocol: false }, + LocalProtocol::HttpProxy { .. } => dest.protocol.clone(), + LocalProtocol::ReverseTcp => LocalProtocol::ReverseTcp, + LocalProtocol::ReverseUdp { .. } => dest.protocol.clone(), + LocalProtocol::ReverseSocks5 { .. } => dest.protocol.clone(), + LocalProtocol::TProxyTcp => LocalProtocol::Tcp { proxy_protocol: false }, + LocalProtocol::TProxyUdp { timeout } => LocalProtocol::Udp { timeout }, + LocalProtocol::Unix { .. } => LocalProtocol::Tcp { proxy_protocol: false }, + LocalProtocol::ReverseUnix { .. } => dest.protocol.clone(), + LocalProtocol::ReverseHttpProxy { .. } => dest.protocol.clone(), + }, + r: dest.host.to_string(), + rp: dest.port, + } + } +} + +pub fn tunnel_to_jwt_token(request_id: Uuid, tunnel: &RemoteAddr) -> String { + let cfg = JwtTunnelConfig::new(request_id, tunnel); + let (alg, secret) = JWT_KEY.deref(); + jsonwebtoken::encode(alg, &cfg, secret).unwrap_or_default() +} + +pub fn jwt_token_to_tunnel(token: &str) -> anyhow::Result> { + let (validation, decode_key) = JWT_DECODE.deref(); + let jwt: TokenData = jsonwebtoken::decode(token, decode_key, validation)?; + Ok(jwt) +} + +impl TryFrom for RemoteAddr { + type Error = anyhow::Error; + fn try_from(jwt: JwtTunnelConfig) -> anyhow::Result { + Ok(Self { + protocol: jwt.p, + host: Host::parse(&jwt.r)?, + port: jwt.rp, + }) + } +} diff --git a/src/tunnel/transport/mod.rs b/src/tunnel/transport/mod.rs index 56b9788..059e696 100644 --- a/src/tunnel/transport/mod.rs +++ b/src/tunnel/transport/mod.rs @@ -1,89 +1,23 @@ -use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; -use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; -use bytes::BytesMut; +use hyper::header::HOST; use hyper::http::{HeaderName, HeaderValue}; -use std::future::Future; use std::io::{BufRead, BufReader}; use std::path::Path; use std::str::FromStr; -use tokio::io::AsyncWrite; use tracing::error; pub mod http2; pub mod io; +mod jwt; pub mod websocket; - -static MAX_PACKET_LENGTH: usize = 64 * 1024; - -pub trait TunnelWrite: Send + 'static { - fn buf_mut(&mut self) -> &mut BytesMut; - fn write(&mut self) -> impl Future> + Send; - fn ping(&mut self) -> impl Future> + Send; - fn close(&mut self) -> impl Future> + Send; -} - -pub trait TunnelRead: Send + 'static { - fn copy( - &mut self, - writer: impl AsyncWrite + Unpin + Send, - ) -> impl Future> + Send; -} - -pub enum TunnelReader { - Websocket(WebsocketTunnelRead), - Http2(Http2TunnelRead), -} - -impl TunnelRead for TunnelReader { - async fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> Result<(), std::io::Error> { - match self { - Self::Websocket(s) => s.copy(writer).await, - Self::Http2(s) => s.copy(writer).await, - } - } -} - -pub enum TunnelWriter { - Websocket(WebsocketTunnelWrite), - Http2(Http2TunnelWrite), -} - -impl TunnelWrite for TunnelWriter { - fn buf_mut(&mut self) -> &mut BytesMut { - match self { - Self::Websocket(s) => s.buf_mut(), - Self::Http2(s) => s.buf_mut(), - } - } - - async fn write(&mut self) -> Result<(), std::io::Error> { - match self { - Self::Websocket(s) => s.write().await, - Self::Http2(s) => s.write().await, - } - } - - async fn ping(&mut self) -> Result<(), std::io::Error> { - match self { - Self::Websocket(s) => s.ping().await, - Self::Http2(s) => s.ping().await, - } - } - - async fn close(&mut self) -> Result<(), std::io::Error> { - match self { - Self::Websocket(s) => s.close().await, - Self::Http2(s) => s.close().await, - } - } -} +pub use jwt::jwt_token_to_tunnel; +pub use jwt::tunnel_to_jwt_token; +pub use jwt::JwtTunnelConfig; +pub use jwt::JWT_HEADER_PREFIX; #[allow(clippy::type_complexity)] #[inline] pub fn headers_from_file(path: &Path) -> (Option<(HeaderName, HeaderValue)>, Vec<(HeaderName, HeaderValue)>) { - static HOST_HEADER: HeaderName = HeaderName::from_static("host"); - let file = match std::fs::File::open(path) { Ok(file) => file, Err(err) => { @@ -100,7 +34,7 @@ pub fn headers_from_file(path: &Path) -> (Option<(HeaderName, HeaderValue)>, Vec let (header, value) = line.split_once(':')?; let header = HeaderName::from_str(header.trim()).ok()?; let value = HeaderValue::from_str(value.trim()).ok()?; - if header == HOST_HEADER { + if header == HOST { host_header = Some((header, value)); return None; } diff --git a/src/tunnel/transport/websocket.rs b/src/tunnel/transport/websocket.rs index f1ebe9d..317a5db 100644 --- a/src/tunnel/transport/websocket.rs +++ b/src/tunnel/transport/websocket.rs @@ -1,6 +1,8 @@ +use super::io::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; use crate::tunnel::client::WsClient; -use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; -use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, JWT_HEADER_PREFIX}; +use crate::tunnel::transport::headers_from_file; +use crate::tunnel::transport::jwt::{tunnel_to_jwt_token, JWT_HEADER_PREFIX}; +use crate::tunnel::RemoteAddr; use anyhow::{anyhow, Context}; use bytes::{Bytes, BytesMut}; use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};