This commit is contained in:
Σrebe - Romain GERARD 2024-01-07 18:37:50 +01:00
parent ac76f52f6d
commit b9bf0f005d
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
6 changed files with 231 additions and 152 deletions

View file

@ -1,8 +1,7 @@
use super::{to_host_port, JwtTunnelConfig, JWT_HEADER_PREFIX, JWT_KEY};
use crate::{LocalProtocol, LocalToRemote, WsClientConfig};
use super::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX};
use crate::WsClientConfig;
use anyhow::{anyhow, Context};
use base64::Engine;
use bytes::Bytes;
use fastwebsockets::WebSocket;
use futures_util::pin_mut;
@ -13,6 +12,7 @@ use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
use hyper::upgrade::Upgraded;
use hyper::{Request, Response};
use hyper_util::rt::{TokioExecutor, TokioIo};
use jsonwebtoken::TokenData;
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
@ -21,24 +21,18 @@ use tokio::sync::oneshot;
use tokio_stream::{Stream, StreamExt};
use tracing::log::debug;
use tracing::{error, span, Instrument, Level, Span};
use url::{Host, Url};
use url::Host;
use uuid::Uuid;
fn tunnel_to_jwt_token(request_id: Uuid, tunnel: &LocalToRemote) -> String {
let cfg = JwtTunnelConfig::new(request_id, tunnel);
let (alg, secret) = JWT_KEY.deref();
jsonwebtoken::encode(alg, &cfg, secret).unwrap_or_default()
}
pub async fn connect(
request_id: Uuid,
client_cfg: &WsClientConfig,
tunnel_cfg: &LocalToRemote,
dest_addr: &RemoteAddr,
) -> anyhow::Result<(WebSocket<TokioIo<Upgraded>>, Response<Incoming>)> {
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
Ok(tcp_stream) => tcp_stream,
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}"))?,
};
Ok(cnx) => Ok(cnx),
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
}?;
let mut req = Request::builder()
.method("GET")
@ -50,7 +44,7 @@ pub async fn connect(
.header(SEC_WEBSOCKET_VERSION, "13")
.header(
SEC_WEBSOCKET_PROTOCOL,
format!("v1, {}{}", JWT_HEADER_PREFIX, tunnel_to_jwt_token(request_id, tunnel_cfg)),
format!("v1, {}{}", JWT_HEADER_PREFIX, tunnel_to_jwt_token(request_id, dest_addr)),
)
.version(hyper::Version::HTTP_11);
@ -79,7 +73,7 @@ pub async fn connect(
async fn connect_to_server<R, W>(
request_id: Uuid,
client_cfg: &WsClientConfig,
remote_cfg: &LocalToRemote,
remote_cfg: &RemoteAddr,
duplex_stream: (R, W),
) -> anyhow::Result<()>
where
@ -105,32 +99,25 @@ where
Ok(())
}
pub async fn run_tunnel<T, R, W>(
client_config: Arc<WsClientConfig>,
tunnel_cfg: LocalToRemote,
incoming_cnx: T,
) -> anyhow::Result<()>
pub async fn run_tunnel<T, R, W>(client_config: Arc<WsClientConfig>, incoming_cnx: T) -> anyhow::Result<()>
where
T: Stream<Item = anyhow::Result<((R, W), (LocalProtocol, Host, u16))>>,
T: Stream<Item = anyhow::Result<((R, W), RemoteAddr)>>,
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
pin_mut!(incoming_cnx);
while let Some(Ok((cnx_stream, remote_dest))) = incoming_cnx.next().await {
while let Some(Ok((cnx_stream, remote_addr))) = incoming_cnx.next().await {
let request_id = Uuid::now_v7();
let span = span!(
Level::INFO,
"tunnel",
id = request_id.to_string(),
remote = format!("{}:{}", remote_dest.1, remote_dest.2)
remote = format!("{}:{}", remote_addr.host, remote_addr.port)
);
let mut tunnel_cfg = tunnel_cfg.clone();
tunnel_cfg.local_protocol = remote_dest.0;
tunnel_cfg.remote = (remote_dest.1, remote_dest.2);
let client_config = client_config.clone();
let tunnel = async move {
let _ = connect_to_server(request_id, &client_config, &tunnel_cfg, cnx_stream)
let _ = connect_to_server(request_id, &client_config, &remote_addr, cnx_stream)
.await
.map_err(|err| error!("{:?}", err));
}
@ -144,18 +131,14 @@ where
pub async fn run_reverse_tunnel<F, Fut, T>(
client_config: Arc<WsClientConfig>,
mut tunnel_cfg: LocalToRemote,
remote_addr: RemoteAddr,
connect_to_dest: F,
) -> anyhow::Result<()>
where
F: Fn((Host, u16)) -> Fut,
F: Fn(Option<RemoteAddr>) -> Fut,
Fut: Future<Output = anyhow::Result<T>>,
T: AsyncRead + AsyncWrite + Send + 'static,
{
// Invert local with remote
let remote_ori = tunnel_cfg.remote;
tunnel_cfg.remote = to_host_port(tunnel_cfg.local);
loop {
let client_config = client_config.clone();
let request_id = Uuid::now_v7();
@ -163,12 +146,12 @@ where
Level::INFO,
"tunnel",
id = request_id.to_string(),
remote = format!("{}:{}", tunnel_cfg.remote.0, tunnel_cfg.remote.1)
remote = format!("{}:{}", remote_addr.host, remote_addr.port)
);
let _span = span.enter();
// Correctly configure tunnel cfg
let (mut ws, response) = connect(request_id, &client_config, &tunnel_cfg)
let (mut ws, response) = connect(request_id, &client_config, &remote_addr)
.instrument(span.clone())
.await?;
ws.set_auto_apply_mask(client_config.websocket_mask_frame);
@ -178,18 +161,21 @@ where
.headers()
.get(COOKIE)
.and_then(|h| h.to_str().ok())
.and_then(|h| base64::engine::general_purpose::STANDARD.decode(h).ok())
.and_then(|h| Url::parse(&String::from_utf8_lossy(&h)).ok())
.and_then(|url| match (url.host(), url.port_or_known_default()) {
(Some(h), Some(p)) => Some((h.to_owned(), p)),
_ => None,
.and_then(|h| {
let (validation, decode_key) = JWT_DECODE.deref();
let jwt: Option<TokenData<JwtTunnelConfig>> = jsonwebtoken::decode(h, decode_key, validation).ok();
jwt
})
.unwrap_or(remote_ori.clone());
.map(|jwt| RemoteAddr {
protocol: jwt.claims.p,
host: Host::parse(&jwt.claims.r).unwrap_or_else(|_| Host::Domain(String::new())),
port: jwt.claims.rp,
});
let stream = match connect_to_dest(remote.clone()).instrument(span.clone()).await {
let stream = match connect_to_dest(remote).instrument(span.clone()).await {
Ok(s) => s,
Err(err) => {
error!("Cannot connect to {remote:?}: {err:?}");
error!("Cannot connect to xxxx: {err:?}");
continue;
}
};

View file

@ -3,7 +3,7 @@ mod io;
pub mod server;
mod tls_reloader;
use crate::{tcp, tls, LocalProtocol, LocalToRemote, WsClientConfig};
use crate::{tcp, tls, LocalProtocol, WsClientConfig};
use async_trait::async_trait;
use bb8::ManageConnection;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
@ -12,6 +12,7 @@ use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::io::{Error, IoSlice};
use std::net::{IpAddr, SocketAddr};
use std::ops::Deref;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
@ -29,26 +30,32 @@ struct JwtTunnelConfig {
}
impl JwtTunnelConfig {
fn new(request_id: Uuid, tunnel: &LocalToRemote) -> Self {
fn new(request_id: Uuid, dest: &RemoteAddr) -> Self {
Self {
id: request_id.to_string(),
p: match tunnel.local_protocol {
p: match dest.protocol {
LocalProtocol::Tcp => LocalProtocol::Tcp,
LocalProtocol::Udp { .. } => tunnel.local_protocol,
LocalProtocol::Udp { .. } => dest.protocol,
LocalProtocol::Stdio => LocalProtocol::Tcp,
LocalProtocol::Socks5 { .. } => LocalProtocol::Tcp,
LocalProtocol::ReverseTcp => LocalProtocol::ReverseTcp,
LocalProtocol::ReverseUdp { .. } => tunnel.local_protocol,
LocalProtocol::ReverseUdp { .. } => dest.protocol,
LocalProtocol::ReverseSocks5 => LocalProtocol::ReverseSocks5,
LocalProtocol::TProxyTcp => LocalProtocol::Tcp,
LocalProtocol::TProxyUdp { timeout } => LocalProtocol::Udp { timeout },
},
r: tunnel.remote.0.to_string(),
rp: tunnel.remote.1,
r: dest.host.to_string(),
rp: dest.port,
}
}
}
fn tunnel_to_jwt_token(request_id: Uuid, tunnel: &RemoteAddr) -> String {
let cfg = JwtTunnelConfig::new(request_id, tunnel);
let (alg, secret) = JWT_KEY.deref();
jsonwebtoken::encode(alg, &cfg, secret).unwrap_or_default()
}
static JWT_HEADER_PREFIX: &str = "authorization.bearer.";
static JWT_SECRET: &[u8; 15] = b"champignonfrais";
static JWT_KEY: Lazy<(Header, EncodingKey)> =
@ -60,6 +67,13 @@ static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
(validation, DecodingKey::from_secret(JWT_SECRET))
});
#[derive(Debug)]
pub struct RemoteAddr {
pub protocol: LocalProtocol,
pub host: Host,
pub port: u16,
}
pub enum TransportStream {
Plain(TcpStream),
Tls(TlsStream<TcpStream>),

View file

@ -1,6 +1,5 @@
use ahash::{HashMap, HashMapExt};
use anyhow::anyhow;
use base64::Engine;
use futures_util::{pin_mut, FutureExt, Stream, StreamExt};
use std::cmp::min;
use std::fmt::Debug;
@ -10,7 +9,7 @@ use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use super::{JwtTunnelConfig, JWT_DECODE, JWT_HEADER_PREFIX};
use super::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX};
use crate::{socks5, tcp, tls, udp, LocalProtocol, TlsServerConfig, WsServerConfig};
use hyper::body::Incoming;
use hyper::header::{COOKIE, SEC_WEBSOCKET_PROTOCOL};
@ -32,17 +31,12 @@ use tokio::sync::{mpsc, oneshot};
use tokio_rustls::TlsAcceptor;
use tracing::{error, info, span, warn, Instrument, Level, Span};
use url::Host;
use uuid::Uuid;
async fn run_tunnel(
server_config: &WsServerConfig,
jwt: TokenData<JwtTunnelConfig>,
) -> anyhow::Result<(
LocalProtocol,
Host,
u16,
Pin<Box<dyn AsyncRead + Send>>,
Pin<Box<dyn AsyncWrite + Send>>,
)> {
) -> anyhow::Result<(RemoteAddr, Pin<Box<dyn AsyncRead + Send>>, Pin<Box<dyn AsyncWrite + Send>>)> {
match jwt.claims.p {
LocalProtocol::Udp { timeout, .. } => {
let host = Host::parse(&jwt.claims.r)?;
@ -53,13 +47,13 @@ async fn run_tunnel(
&server_config.dns_resolver,
)
.await?;
Ok((
LocalProtocol::Udp { timeout: None },
let remote = RemoteAddr {
protocol: jwt.claims.p,
host,
jwt.claims.rp,
Box::pin(cnx.clone()),
Box::pin(cnx),
))
port: jwt.claims.rp,
};
Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx)))
}
LocalProtocol::Tcp => {
let host = Host::parse(&jwt.claims.r)?;
@ -74,7 +68,12 @@ async fn run_tunnel(
.await?
.into_split();
Ok((jwt.claims.p, host, port, Box::pin(rx), Box::pin(tx)))
let remote = RemoteAddr {
protocol: jwt.claims.p,
host,
port,
};
Ok((remote, Box::pin(rx), Box::pin(tx)))
}
LocalProtocol::ReverseTcp => {
#[allow(clippy::type_complexity)]
@ -87,7 +86,12 @@ async fn run_tunnel(
let tcp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
let (local_rx, local_tx) = tcp.into_split();
Ok((jwt.claims.p, local_srv.0, local_srv.1, Box::pin(local_rx), Box::pin(local_tx)))
let remote = RemoteAddr {
protocol: jwt.claims.p,
host: local_srv.0,
port: local_srv.1,
};
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
}
LocalProtocol::ReverseUdp { timeout } => {
#[allow(clippy::type_complexity)]
@ -101,7 +105,12 @@ async fn run_tunnel(
let udp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
let (local_rx, local_tx) = tokio::io::split(udp);
Ok((jwt.claims.p, local_srv.0, local_srv.1, Box::pin(local_rx), Box::pin(local_tx)))
let remote = RemoteAddr {
protocol: jwt.claims.p,
host: local_srv.0,
port: local_srv.1,
};
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
}
LocalProtocol::ReverseSocks5 => {
#[allow(clippy::type_complexity)]
@ -112,10 +121,15 @@ async fn run_tunnel(
let bind = format!("{}:{}", local_srv.0, local_srv.1);
let listening_server = socks5::run_server(bind.parse()?, None);
let (stream, local_srv) = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
let proto = stream.local_protocol();
let protocol = stream.local_protocol();
let (local_rx, local_tx) = tokio::io::split(stream);
Ok((proto, local_srv.0, local_srv.1, Box::pin(local_rx), Box::pin(local_tx)))
let remote = RemoteAddr {
protocol,
host: local_srv.0,
port: local_srv.1,
};
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
}
_ => Err(anyhow::anyhow!("Invalid upgrade request")),
}
@ -308,6 +322,7 @@ async fn server_upgrade(server_config: Arc<WsServerConfig>, mut req: Request<Inc
return err;
}
let req_protocol = jwt.claims.p;
let tunnel = match run_tunnel(&server_config, jwt).await {
Ok(ret) => ret,
Err(err) => {
@ -319,8 +334,11 @@ async fn server_upgrade(server_config: Arc<WsServerConfig>, mut req: Request<Inc
}
};
let (protocol, dest, port, local_rx, local_tx) = tunnel;
info!("connected to {:?} {:?} {:?}", protocol, dest, port);
let (remote_addr, local_rx, local_tx) = tunnel;
info!(
"connected to {:?} {:?} {:?}",
remote_addr.protocol, remote_addr.host, remote_addr.port
);
let (mut response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) {
Ok(ret) => ret,
Err(err) => {
@ -351,11 +369,9 @@ async fn server_upgrade(server_config: Arc<WsServerConfig>, mut req: Request<Inc
.instrument(Span::current()),
);
if protocol == LocalProtocol::ReverseSocks5 {
let Ok(header_val) = HeaderValue::from_str(
&base64::engine::general_purpose::STANDARD.encode(format!("https://{}:{}", dest, port)),
) else {
error!("Bad headervalue for reverse socks5: {} {}", dest, port);
if req_protocol == LocalProtocol::ReverseSocks5 {
let Ok(header_val) = HeaderValue::from_str(&tunnel_to_jwt_token(Uuid::from_u128(0), &remote_addr)) else {
error!("Bad headervalue for reverse socks5: {} {}", remote_addr.host, remote_addr.port);
return http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body("Invalid upgrade request".to_string())