This commit is contained in:
Σrebe - Romain GERARD 2024-08-04 23:02:38 +02:00
parent a468428791
commit 8c4d091b9e
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
9 changed files with 188 additions and 180 deletions

View file

@ -4,14 +4,13 @@ use crate::tunnel::client::WsClientConfig;
use crate::tunnel::connectors::TunnelConnector; use crate::tunnel::connectors::TunnelConnector;
use crate::tunnel::listeners::TunnelListener; use crate::tunnel::listeners::TunnelListener;
use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::tls_reloader::TlsReloader;
use crate::tunnel::transport::{TunnelReader, TunnelWriter}; use crate::tunnel::transport::io::{TunnelReader, TunnelWriter};
use crate::tunnel::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE}; use crate::tunnel::transport::jwt_token_to_tunnel;
use crate::tunnel::{RemoteAddr, TransportScheme};
use anyhow::Context; use anyhow::Context;
use futures_util::pin_mut; use futures_util::pin_mut;
use hyper::header::COOKIE; use hyper::header::COOKIE;
use jsonwebtoken::TokenData;
use log::debug; use log::debug;
use std::ops::Deref;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
@ -179,11 +178,7 @@ impl WsClient {
.headers .headers
.get(COOKIE) .get(COOKIE)
.and_then(|h| h.to_str().ok()) .and_then(|h| h.to_str().ok())
.and_then(|h| { .and_then(|h| jwt_token_to_tunnel(h).ok())
let (validation, decode_key) = JWT_DECODE.deref();
let jwt: Option<TokenData<JwtTunnelConfig>> = jsonwebtoken::decode(h, decode_key, validation).ok();
jwt
})
.map(|jwt| RemoteAddr { .map(|jwt| RemoteAddr {
protocol: jwt.claims.p, protocol: jwt.claims.p,
host: Host::parse(&jwt.claims.r).unwrap_or_else(|_| Host::Domain(String::new())), host: Host::parse(&jwt.claims.r).unwrap_or_else(|_| Host::Domain(String::new())),

View file

@ -6,68 +6,13 @@ mod tls_reloader;
mod transport; mod transport;
use crate::TlsClientConfig; use crate::TlsClientConfig;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fmt::{Debug, Display, Formatter}; use std::fmt::{Debug, Display, Formatter};
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::ops::Deref;
use std::path::PathBuf; use std::path::PathBuf;
use std::str::FromStr; use std::str::FromStr;
use std::time::Duration; use std::time::Duration;
use url::Host; use url::Host;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)]
struct JwtTunnelConfig {
pub id: String, // tunnel id
pub p: LocalProtocol, // protocol to use
pub r: String, // remote host
pub rp: u16, // remote port
}
impl JwtTunnelConfig {
fn new(request_id: Uuid, dest: &RemoteAddr) -> Self {
Self {
id: request_id.to_string(),
p: match dest.protocol {
LocalProtocol::Tcp { .. } => dest.protocol.clone(),
LocalProtocol::Udp { .. } => dest.protocol.clone(),
LocalProtocol::Stdio => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::Socks5 { .. } => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::HttpProxy { .. } => dest.protocol.clone(),
LocalProtocol::ReverseTcp => LocalProtocol::ReverseTcp,
LocalProtocol::ReverseUdp { .. } => dest.protocol.clone(),
LocalProtocol::ReverseSocks5 { .. } => dest.protocol.clone(),
LocalProtocol::TProxyTcp => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::TProxyUdp { timeout } => LocalProtocol::Udp { timeout },
LocalProtocol::Unix { .. } => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::ReverseUnix { .. } => dest.protocol.clone(),
LocalProtocol::ReverseHttpProxy { .. } => dest.protocol.clone(),
},
r: dest.host.to_string(),
rp: dest.port,
}
}
}
fn tunnel_to_jwt_token(request_id: Uuid, tunnel: &RemoteAddr) -> String {
let cfg = JwtTunnelConfig::new(request_id, tunnel);
let (alg, secret) = JWT_KEY.deref();
jsonwebtoken::encode(alg, &cfg, secret).unwrap_or_default()
}
static JWT_HEADER_PREFIX: &str = "authorization.bearer.";
static JWT_SECRET: &[u8; 15] = b"champignonfrais";
static JWT_KEY: Lazy<(Header, EncodingKey)> =
Lazy::new(|| (Header::new(Algorithm::HS256), EncodingKey::from_secret(JWT_SECRET)));
static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::with_capacity(0);
(validation, DecodingKey::from_secret(JWT_SECRET))
});
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum LocalProtocol { pub enum LocalProtocol {
@ -122,6 +67,10 @@ impl LocalProtocol {
| Self::ReverseHttpProxy { .. } | Self::ReverseHttpProxy { .. }
) )
} }
pub const fn is_dynamic_reverse_tunnel(&self) -> bool {
matches!(self, |Self::ReverseSocks5 { .. }| Self::ReverseHttpProxy { .. })
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -285,20 +234,17 @@ impl TransportAddr {
} }
} }
impl TryFrom<JwtTunnelConfig> for RemoteAddr {
type Error = anyhow::Error;
fn try_from(jwt: JwtTunnelConfig) -> anyhow::Result<Self> {
Ok(Self {
protocol: jwt.p,
host: Host::parse(&jwt.r)?,
port: jwt.rp,
})
}
}
pub fn to_host_port(addr: SocketAddr) -> (Host, u16) { pub fn to_host_port(addr: SocketAddr) -> (Host, u16) {
match addr.ip() { match addr.ip() {
IpAddr::V4(ip) => (Host::Ipv4(ip), addr.port()), IpAddr::V4(ip) => (Host::Ipv4(ip), addr.port()),
IpAddr::V6(ip) => (Host::Ipv6(ip), addr.port()), IpAddr::V6(ip) => (Host::Ipv6(ip), addr.port()),
} }
} }
pub fn try_to_sock_addr((host, port): (Host, u16)) -> anyhow::Result<SocketAddr> {
match host {
Host::Domain(_) => Err(anyhow::anyhow!("Cannot convert domain to socket address")),
Host::Ipv4(ip) => Ok(SocketAddr::V4(SocketAddrV4::new(ip, port))),
Host::Ipv6(ip) => Ok(SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0))),
}
}

View file

@ -13,7 +13,7 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use crate::protocols; use crate::protocols;
use crate::tunnel::{LocalProtocol, RemoteAddr}; use crate::tunnel::{try_to_sock_addr, LocalProtocol, RemoteAddr};
use hyper::body::Incoming; use hyper::body::Incoming;
use hyper::server::conn::{http1, http2}; use hyper::server::conn::{http1, http2};
use hyper::service::service_fn; use hyper::service::service_fn;
@ -33,8 +33,7 @@ use crate::tunnel::server::handler_http2::http_server_upgrade;
use crate::tunnel::server::handler_websocket::ws_server_upgrade; use crate::tunnel::server::handler_websocket::ws_server_upgrade;
use crate::tunnel::server::reverse_tunnel::ReverseTunnelServer; use crate::tunnel::server::reverse_tunnel::ReverseTunnelServer;
use crate::tunnel::server::utils::{ use crate::tunnel::server::utils::{
bad_request, extract_path_prefix, extract_tunnel_info, extract_x_forwarded_for, find_mapped_port, try_to_sock_aadr, bad_request, extract_path_prefix, extract_tunnel_info, extract_x_forwarded_for, find_mapped_port, validate_tunnel,
validate_tunnel,
}; };
use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::tls_reloader::TlsReloader;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@ -143,10 +142,7 @@ impl WsServer {
}; };
let req_protocol = remote.protocol.clone(); let req_protocol = remote.protocol.clone();
let inject_cookie = matches!( let inject_cookie = req_protocol.is_dynamic_reverse_tunnel();
req_protocol,
LocalProtocol::ReverseSocks5 { .. } | LocalProtocol::ReverseHttpProxy { .. }
);
let tunnel = match self.exec_tunnel(restriction, remote, client_addr).await { let tunnel = match self.exec_tunnel(restriction, remote, client_addr).await {
Ok(ret) => ret, Ok(ret) => ret,
Err(err) => { Err(err) => {
@ -213,7 +209,7 @@ impl WsServer {
let remote_port = find_mapped_port(remote.port, restriction); let remote_port = find_mapped_port(remote.port, restriction);
let local_srv = (remote.host, remote_port); let local_srv = (remote.host, remote_port);
let bind = try_to_sock_aadr(local_srv.clone())?; let bind = try_to_sock_addr(local_srv.clone())?;
let listening_server = async { TcpTunnelListener::new(bind, local_srv.clone(), false).await }; let listening_server = async { TcpTunnelListener::new(bind, local_srv.clone(), false).await };
let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?; let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?;
@ -224,7 +220,7 @@ impl WsServer {
let remote_port = find_mapped_port(remote.port, restriction); let remote_port = find_mapped_port(remote.port, restriction);
let local_srv = (remote.host, remote_port); let local_srv = (remote.host, remote_port);
let bind = try_to_sock_aadr(local_srv.clone())?; let bind = try_to_sock_addr(local_srv.clone())?;
let listening_server = async { UdpTunnelListener::new(bind, local_srv.clone(), timeout).await }; let listening_server = async { UdpTunnelListener::new(bind, local_srv.clone(), timeout).await };
let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?; let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?;
Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
@ -234,7 +230,7 @@ impl WsServer {
let remote_port = find_mapped_port(remote.port, restriction); let remote_port = find_mapped_port(remote.port, restriction);
let local_srv = (remote.host, remote_port); let local_srv = (remote.host, remote_port);
let bind = try_to_sock_aadr(local_srv.clone())?; let bind = try_to_sock_addr(local_srv.clone())?;
let listening_server = async { Socks5TunnelListener::new(bind, timeout, credentials).await }; let listening_server = async { Socks5TunnelListener::new(bind, timeout, credentials).await };
let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?; let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?;
@ -246,7 +242,7 @@ impl WsServer {
let remote_port = find_mapped_port(remote.port, restriction); let remote_port = find_mapped_port(remote.port, restriction);
let local_srv = (remote.host, remote_port); let local_srv = (remote.host, remote_port);
let bind = try_to_sock_aadr(local_srv.clone())?; let bind = try_to_sock_addr(local_srv.clone())?;
let listening_server = async { HttpProxyTunnelListener::new(bind, timeout, credentials, false).await }; let listening_server = async { HttpProxyTunnelListener::new(bind, timeout, credentials, false).await };
let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?; let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?;
@ -259,7 +255,7 @@ impl WsServer {
let remote_port = find_mapped_port(remote.port, restriction); let remote_port = find_mapped_port(remote.port, restriction);
let local_srv = (remote.host, remote_port); let local_srv = (remote.host, remote_port);
let bind = try_to_sock_aadr(local_srv.clone())?; let bind = try_to_sock_addr(local_srv.clone())?;
let listening_server = async { UnixTunnelListener::new(path, local_srv.clone(), false).await }; let listening_server = async { UnixTunnelListener::new(path, local_srv.clone(), false).await };
let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?; let ((local_rx, local_tx), remote) = SERVERS.run_listening_server(bind, listening_server).await?;

View file

@ -1,7 +1,8 @@
use crate::restrictions::types::{ use crate::restrictions::types::{
AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol, AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol,
}; };
use crate::tunnel::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; use crate::tunnel::transport::{jwt_token_to_tunnel, tunnel_to_jwt_token, JwtTunnelConfig, JWT_HEADER_PREFIX};
use crate::tunnel::RemoteAddr;
use bytes::Bytes; use bytes::Bytes;
use http_body_util::combinators::BoxBody; use http_body_util::combinators::BoxBody;
use http_body_util::Either; use http_body_util::Either;
@ -10,8 +11,7 @@ use hyper::header::{HeaderValue, COOKIE, SEC_WEBSOCKET_PROTOCOL};
use hyper::{http, Request, Response, StatusCode}; use hyper::{http, Request, Response, StatusCode};
use jsonwebtoken::TokenData; use jsonwebtoken::TokenData;
use std::cmp::min; use std::cmp::min;
use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::net::IpAddr;
use std::ops::Deref;
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use url::Host; use url::Host;
use uuid::Uuid; use uuid::Uuid;
@ -92,8 +92,7 @@ pub(super) fn extract_tunnel_info(req: &Request<Incoming>) -> Result<TokenData<J
.or_else(|| req.headers().get(COOKIE).and_then(|header| header.to_str().ok())) .or_else(|| req.headers().get(COOKIE).and_then(|header| header.to_str().ok()))
.unwrap_or_default(); .unwrap_or_default();
let (validation, decode_key) = JWT_DECODE.deref(); let jwt = match jwt_token_to_tunnel(jwt) {
let jwt = match jsonwebtoken::decode(jwt, decode_key, validation) {
Ok(jwt) => jwt, Ok(jwt) => jwt,
err => { err => {
warn!( warn!(
@ -218,11 +217,3 @@ pub(super) fn inject_cookie(response: &mut http::Response<impl Body>, remote_add
Ok(()) Ok(())
} }
pub fn try_to_sock_aadr((host, port): (Host, u16)) -> anyhow::Result<SocketAddr> {
match host {
Host::Domain(_) => Err(anyhow::anyhow!("Cannot convert domain to socket address")),
Host::Ipv4(ip) => Ok(SocketAddr::V4(SocketAddrV4::new(ip, port))),
Host::Ipv6(ip) => Ok(SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0))),
}
}

View file

@ -1,6 +1,8 @@
use super::io::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
use crate::tunnel::client::WsClient; use crate::tunnel::client::WsClient;
use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; use crate::tunnel::transport::headers_from_file;
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, TransportScheme}; use crate::tunnel::transport::jwt::tunnel_to_jwt_token;
use crate::tunnel::{RemoteAddr, TransportScheme};
use anyhow::{anyhow, Context}; use anyhow::{anyhow, Context};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use http_body_util::{BodyExt, BodyStream, StreamBody}; use http_body_util::{BodyExt, BodyStream, StreamBody};

View file

@ -1,6 +1,8 @@
use crate::tunnel::transport::{TunnelRead, TunnelWrite}; use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
use bytes::BufMut; use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite};
use bytes::{BufMut, BytesMut};
use futures_util::{pin_mut, FutureExt}; use futures_util::{pin_mut, FutureExt};
use std::future::Future;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::select; use tokio::select;
@ -9,6 +11,71 @@ use tokio::time::Instant;
use tracing::log::debug; use tracing::log::debug;
use tracing::{error, info, warn}; use tracing::{error, info, warn};
pub(super) static MAX_PACKET_LENGTH: usize = 64 * 1024;
pub trait TunnelWrite: Send + 'static {
fn buf_mut(&mut self) -> &mut BytesMut;
fn write(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
fn ping(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
fn close(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
}
pub trait TunnelRead: Send + 'static {
fn copy(
&mut self,
writer: impl AsyncWrite + Unpin + Send,
) -> impl Future<Output = Result<(), std::io::Error>> + Send;
}
pub enum TunnelReader {
Websocket(WebsocketTunnelRead),
Http2(Http2TunnelRead),
}
impl TunnelRead for TunnelReader {
async fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> Result<(), std::io::Error> {
match self {
Self::Websocket(s) => s.copy(writer).await,
Self::Http2(s) => s.copy(writer).await,
}
}
}
pub enum TunnelWriter {
Websocket(WebsocketTunnelWrite),
Http2(Http2TunnelWrite),
}
impl TunnelWrite for TunnelWriter {
fn buf_mut(&mut self) -> &mut BytesMut {
match self {
Self::Websocket(s) => s.buf_mut(),
Self::Http2(s) => s.buf_mut(),
}
}
async fn write(&mut self) -> Result<(), std::io::Error> {
match self {
Self::Websocket(s) => s.write().await,
Self::Http2(s) => s.write().await,
}
}
async fn ping(&mut self) -> Result<(), std::io::Error> {
match self {
Self::Websocket(s) => s.ping().await,
Self::Http2(s) => s.ping().await,
}
}
async fn close(&mut self) -> Result<(), std::io::Error> {
match self {
Self::Websocket(s) => s.close().await,
Self::Http2(s) => s.close().await,
}
}
}
pub async fn propagate_local_to_remote( pub async fn propagate_local_to_remote(
local_rx: impl AsyncRead, local_rx: impl AsyncRead,
mut ws_tx: impl TunnelWrite, mut ws_tx: impl TunnelWrite,

View file

@ -0,0 +1,75 @@
use crate::tunnel::{LocalProtocol, RemoteAddr};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::ops::Deref;
use url::Host;
use uuid::Uuid;
pub static JWT_HEADER_PREFIX: &str = "authorization.bearer.";
static JWT_SECRET: &[u8; 15] = b"champignonfrais";
static JWT_KEY: Lazy<(Header, EncodingKey)> =
Lazy::new(|| (Header::new(Algorithm::HS256), EncodingKey::from_secret(JWT_SECRET)));
static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::with_capacity(0);
(validation, DecodingKey::from_secret(JWT_SECRET))
});
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtTunnelConfig {
pub id: String, // tunnel id
pub p: LocalProtocol, // protocol to use
pub r: String, // remote host
pub rp: u16, // remote port
}
impl JwtTunnelConfig {
fn new(request_id: Uuid, dest: &RemoteAddr) -> Self {
Self {
id: request_id.to_string(),
p: match dest.protocol {
LocalProtocol::Tcp { .. } => dest.protocol.clone(),
LocalProtocol::Udp { .. } => dest.protocol.clone(),
LocalProtocol::Stdio => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::Socks5 { .. } => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::HttpProxy { .. } => dest.protocol.clone(),
LocalProtocol::ReverseTcp => LocalProtocol::ReverseTcp,
LocalProtocol::ReverseUdp { .. } => dest.protocol.clone(),
LocalProtocol::ReverseSocks5 { .. } => dest.protocol.clone(),
LocalProtocol::TProxyTcp => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::TProxyUdp { timeout } => LocalProtocol::Udp { timeout },
LocalProtocol::Unix { .. } => LocalProtocol::Tcp { proxy_protocol: false },
LocalProtocol::ReverseUnix { .. } => dest.protocol.clone(),
LocalProtocol::ReverseHttpProxy { .. } => dest.protocol.clone(),
},
r: dest.host.to_string(),
rp: dest.port,
}
}
}
pub fn tunnel_to_jwt_token(request_id: Uuid, tunnel: &RemoteAddr) -> String {
let cfg = JwtTunnelConfig::new(request_id, tunnel);
let (alg, secret) = JWT_KEY.deref();
jsonwebtoken::encode(alg, &cfg, secret).unwrap_or_default()
}
pub fn jwt_token_to_tunnel(token: &str) -> anyhow::Result<TokenData<JwtTunnelConfig>> {
let (validation, decode_key) = JWT_DECODE.deref();
let jwt: TokenData<JwtTunnelConfig> = jsonwebtoken::decode(token, decode_key, validation)?;
Ok(jwt)
}
impl TryFrom<JwtTunnelConfig> for RemoteAddr {
type Error = anyhow::Error;
fn try_from(jwt: JwtTunnelConfig) -> anyhow::Result<Self> {
Ok(Self {
protocol: jwt.p,
host: Host::parse(&jwt.r)?,
port: jwt.rp,
})
}
}

View file

@ -1,89 +1,23 @@
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; use hyper::header::HOST;
use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite};
use bytes::BytesMut;
use hyper::http::{HeaderName, HeaderValue}; use hyper::http::{HeaderName, HeaderValue};
use std::future::Future;
use std::io::{BufRead, BufReader}; use std::io::{BufRead, BufReader};
use std::path::Path; use std::path::Path;
use std::str::FromStr; use std::str::FromStr;
use tokio::io::AsyncWrite;
use tracing::error; use tracing::error;
pub mod http2; pub mod http2;
pub mod io; pub mod io;
mod jwt;
pub mod websocket; pub mod websocket;
pub use jwt::jwt_token_to_tunnel;
static MAX_PACKET_LENGTH: usize = 64 * 1024; pub use jwt::tunnel_to_jwt_token;
pub use jwt::JwtTunnelConfig;
pub trait TunnelWrite: Send + 'static { pub use jwt::JWT_HEADER_PREFIX;
fn buf_mut(&mut self) -> &mut BytesMut;
fn write(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
fn ping(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
fn close(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
}
pub trait TunnelRead: Send + 'static {
fn copy(
&mut self,
writer: impl AsyncWrite + Unpin + Send,
) -> impl Future<Output = Result<(), std::io::Error>> + Send;
}
pub enum TunnelReader {
Websocket(WebsocketTunnelRead),
Http2(Http2TunnelRead),
}
impl TunnelRead for TunnelReader {
async fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> Result<(), std::io::Error> {
match self {
Self::Websocket(s) => s.copy(writer).await,
Self::Http2(s) => s.copy(writer).await,
}
}
}
pub enum TunnelWriter {
Websocket(WebsocketTunnelWrite),
Http2(Http2TunnelWrite),
}
impl TunnelWrite for TunnelWriter {
fn buf_mut(&mut self) -> &mut BytesMut {
match self {
Self::Websocket(s) => s.buf_mut(),
Self::Http2(s) => s.buf_mut(),
}
}
async fn write(&mut self) -> Result<(), std::io::Error> {
match self {
Self::Websocket(s) => s.write().await,
Self::Http2(s) => s.write().await,
}
}
async fn ping(&mut self) -> Result<(), std::io::Error> {
match self {
Self::Websocket(s) => s.ping().await,
Self::Http2(s) => s.ping().await,
}
}
async fn close(&mut self) -> Result<(), std::io::Error> {
match self {
Self::Websocket(s) => s.close().await,
Self::Http2(s) => s.close().await,
}
}
}
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
#[inline] #[inline]
pub fn headers_from_file(path: &Path) -> (Option<(HeaderName, HeaderValue)>, Vec<(HeaderName, HeaderValue)>) { pub fn headers_from_file(path: &Path) -> (Option<(HeaderName, HeaderValue)>, Vec<(HeaderName, HeaderValue)>) {
static HOST_HEADER: HeaderName = HeaderName::from_static("host");
let file = match std::fs::File::open(path) { let file = match std::fs::File::open(path) {
Ok(file) => file, Ok(file) => file,
Err(err) => { Err(err) => {
@ -100,7 +34,7 @@ pub fn headers_from_file(path: &Path) -> (Option<(HeaderName, HeaderValue)>, Vec
let (header, value) = line.split_once(':')?; let (header, value) = line.split_once(':')?;
let header = HeaderName::from_str(header.trim()).ok()?; let header = HeaderName::from_str(header.trim()).ok()?;
let value = HeaderValue::from_str(value.trim()).ok()?; let value = HeaderValue::from_str(value.trim()).ok()?;
if header == HOST_HEADER { if header == HOST {
host_header = Some((header, value)); host_header = Some((header, value));
return None; return None;
} }

View file

@ -1,6 +1,8 @@
use super::io::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
use crate::tunnel::client::WsClient; use crate::tunnel::client::WsClient;
use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; use crate::tunnel::transport::headers_from_file;
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, JWT_HEADER_PREFIX}; use crate::tunnel::transport::jwt::{tunnel_to_jwt_token, JWT_HEADER_PREFIX};
use crate::tunnel::RemoteAddr;
use anyhow::{anyhow, Context}; use anyhow::{anyhow, Context};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite}; use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};