From b9bf0f005d6a3089426196c1b314ad9c8a8e3eff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Sun, 7 Jan 2024 18:37:50 +0100 Subject: [PATCH] cleanup --- src/main.rs | 121 +++++++++++++++++++++++++++++++++---------- src/socks5.rs | 2 +- src/socks5_udp.rs | 88 +++++++++++++++---------------- src/tunnel/client.rs | 74 +++++++++++--------------- src/tunnel/mod.rs | 28 +++++++--- src/tunnel/server.rs | 70 +++++++++++++++---------- 6 files changed, 231 insertions(+), 152 deletions(-) diff --git a/src/main.rs b/src/main.rs index d52445f..c0ac3f4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,7 @@ mod tls; mod tunnel; mod udp; +use anyhow::anyhow; use base64::Engine; use clap::Parser; use futures_util::{stream, TryStreamExt}; @@ -26,6 +27,8 @@ use std::str::FromStr; use std::sync::Arc; use std::time::Duration; use std::{fmt, io}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; use tokio_rustls::rustls::server::DnsName; use tokio_rustls::rustls::{Certificate, PrivateKey, ServerName}; @@ -33,7 +36,8 @@ use tokio_rustls::rustls::{Certificate, PrivateKey, ServerName}; use tracing::{error, info}; use crate::dns::DnsResolver; -use crate::tunnel::to_host_port; +use crate::tunnel::{to_host_port, RemoteAddr}; +use crate::udp::MyUdpSocket; use tracing_subscriber::filter::Directive; use tracing_subscriber::EnvFilter; use url::{Host, Url}; @@ -656,11 +660,10 @@ async fn main() { let client_config = Arc::new(client_config); // Start tunnels - for mut tunnel in args.remote_to_local.into_iter() { + for tunnel in args.remote_to_local.into_iter() { let client_config = client_config.clone(); match &tunnel.local_protocol { LocalProtocol::Tcp => { - tunnel.local_protocol = LocalProtocol::ReverseTcp; tokio::spawn(async move { let remote = tunnel.remote.clone(); let cfg = client_config.clone(); @@ -675,43 +678,82 @@ async fn main() { .await }; + let (host, port) = to_host_port(tunnel.local); + let remote = RemoteAddr { + protocol: LocalProtocol::ReverseTcp, + host, + port, + }; if let Err(err) = - tunnel::client::run_reverse_tunnel(client_config, tunnel, connect_to_dest).await + tunnel::client::run_reverse_tunnel(client_config, remote, connect_to_dest).await { error!("{:?}", err); } }); } LocalProtocol::Udp { timeout } => { - tunnel.local_protocol = LocalProtocol::ReverseUdp { timeout: *timeout }; + let timeout = *timeout; tokio::spawn(async move { let cfg = client_config.clone(); - let remote = tunnel.remote.clone(); + let (host, port) = to_host_port(tunnel.local); + let remote = RemoteAddr { + protocol: LocalProtocol::ReverseUdp { timeout }, + host, + port, + }; let connect_to_dest = |_| async { - udp::connect(&remote.0, remote.1, cfg.timeout_connect, &cfg.dns_resolver).await + udp::connect(&tunnel.remote.0, tunnel.remote.1, cfg.timeout_connect, &cfg.dns_resolver) + .await }; if let Err(err) = - tunnel::client::run_reverse_tunnel(client_config, tunnel, connect_to_dest).await + tunnel::client::run_reverse_tunnel(client_config, remote, connect_to_dest).await { error!("{:?}", err); } }); } LocalProtocol::Socks5 { .. } => { - tunnel.local_protocol = LocalProtocol::ReverseSocks5; + trait T: AsyncWrite + AsyncRead + Unpin + Send {} + impl T for TcpStream {} + impl T for MyUdpSocket {} + tokio::spawn(async move { let cfg = client_config.clone(); - let connect_to_dest = |remote: (Host, u16)| { + let (host, port) = to_host_port(tunnel.local); + let remote = RemoteAddr { + protocol: LocalProtocol::ReverseSocks5, + host, + port, + }; + let connect_to_dest = |remote: Option| { let so_mark = cfg.socket_so_mark; let timeout = cfg.timeout_connect; let dns_resolver = &cfg.dns_resolver; - async move { tcp::connect(&remote.0, remote.1, so_mark, timeout, dns_resolver).await } + async move { + let Some(remote) = remote else { + return Err(anyhow!("Missing remote destination for reverse socks5")); + }; + + match remote.protocol { + LocalProtocol::Tcp => { + tcp::connect(&remote.host, remote.port, so_mark, timeout, dns_resolver) + .await + .map(|s| Box::new(s) as Box) + } + LocalProtocol::Udp { .. } => { + udp::connect(&remote.host, remote.port, timeout, dns_resolver) + .await + .map(|s| Box::new(s) as Box) + } + _ => Err(anyhow!("Invalid protocol for reverse socks5 {:?}", remote.protocol)), + } + } }; if let Err(err) = - tunnel::client::run_reverse_tunnel(client_config, tunnel, connect_to_dest).await + tunnel::client::run_reverse_tunnel(client_config, remote, connect_to_dest).await { error!("{:?}", err); } @@ -732,12 +774,16 @@ async fn main() { .unwrap_or_else(|err| panic!("Cannot start TCP server on {}: {}", tunnel.local, err)) .map_err(anyhow::Error::new) .map_ok(move |stream| { - let remote = remote.clone(); - (stream.into_split(), (LocalProtocol::Tcp, remote.0, remote.1)) + let remote = RemoteAddr { + protocol: LocalProtocol::Tcp, + host: remote.0.clone(), + port: remote.1, + }; + (stream.into_split(), remote) }); tokio::spawn(async move { - if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await { + if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { error!("{:?}", err); } }); @@ -751,11 +797,16 @@ async fn main() { .map_ok(move |stream| { // In TProxy mode local destination is the final ip:port destination let (host, port) = to_host_port(stream.local_addr().unwrap()); - (stream.into_split(), (LocalProtocol::Tcp, host, port)) + let remote = RemoteAddr { + protocol: LocalProtocol::Tcp, + host, + port, + }; + (stream.into_split(), remote) }); tokio::spawn(async move { - if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await { + if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { error!("{:?}", err); } }); @@ -773,11 +824,16 @@ async fn main() { .map_ok(move |stream| { // In TProxy mode local destination is the final ip:port destination let (host, port) = to_host_port(stream.local_addr().unwrap()); - (tokio::io::split(stream), (LocalProtocol::Udp { timeout }, host, port)) + let remote = RemoteAddr { + protocol: LocalProtocol::Udp { timeout }, + host, + port, + }; + (tokio::io::split(stream), remote) }); tokio::spawn(async move { - if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await { + if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { error!("{:?}", err); } }); @@ -794,11 +850,16 @@ async fn main() { .unwrap_or_else(|err| panic!("Cannot start UDP server on {}: {}", tunnel.local, err)) .map_err(anyhow::Error::new) .map_ok(move |stream| { - (tokio::io::split(stream), (LocalProtocol::Udp { timeout }, host.clone(), port)) + let remote = RemoteAddr { + protocol: LocalProtocol::Udp { timeout }, + host: host.clone(), + port, + }; + (tokio::io::split(stream), remote) }); tokio::spawn(async move { - if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await { + if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { error!("{:?}", err); } }); @@ -808,12 +869,16 @@ async fn main() { .await .unwrap_or_else(|err| panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err)) .map_ok(|(stream, (host, port))| { - let proto = stream.local_protocol(); - (tokio::io::split(stream), (proto, host, port)) + let remote = RemoteAddr { + protocol: stream.local_protocol(), + host, + port, + }; + (tokio::io::split(stream), remote) }); tokio::spawn(async move { - if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await { + if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { error!("{:?}", err); } }); @@ -826,9 +891,13 @@ async fn main() { tokio::spawn(async move { if let Err(err) = tunnel::client::run_tunnel( client_config, - tunnel.clone(), stream::once(async move { - Ok((server, (LocalProtocol::Tcp, tunnel.remote.0, tunnel.remote.1))) + let remote = RemoteAddr { + protocol: LocalProtocol::Tcp, + host: tunnel.remote.0, + port: tunnel.remote.1, + }; + Ok((server, remote)) }), ) .await diff --git a/src/socks5.rs b/src/socks5.rs index d96efa0..b413221 100644 --- a/src/socks5.rs +++ b/src/socks5.rs @@ -58,7 +58,7 @@ pub async fn run_server(bind: SocketAddr, timeout: Option) -> Result, - configure_listener: impl Fn(&UdpSocket) -> anyhow::Result<()>, - mk_send_socket: impl Fn(&Arc) -> anyhow::Result>, ) -> Result>, anyhow::Error> { info!( "Starting SOCKS5 UDP server listening cnx on {} with cnx timeout of {}s", @@ -207,61 +205,57 @@ pub async fn run_server( let listener = UdpSocket::bind(bind) .await .with_context(|| format!("Cannot create UDP server {:?}", bind))?; - configure_listener(&listener)?; let udp_server = Socks5UdpServer::new(listener, timeout); static MAX_PACKET_LENGTH: usize = 64 * 1024; let buffer = BytesMut::with_capacity(MAX_PACKET_LENGTH * 10); - let stream = stream::unfold( - (udp_server, mk_send_socket, buffer), - |(mut server, mk_send_socket, mut buf)| async move { - loop { - server.clean_dead_keys(); - if buf.remaining_mut() < MAX_PACKET_LENGTH { - buf.reserve(MAX_PACKET_LENGTH); + let stream = stream::unfold((udp_server, buffer), |(mut server, mut buf)| async move { + loop { + server.clean_dead_keys(); + if buf.remaining_mut() < MAX_PACKET_LENGTH { + buf.reserve(MAX_PACKET_LENGTH); + } + + let peer_addr = match server.listener.recv_buf_from(&mut buf).await { + Ok((_read_len, peer_addr)) => peer_addr, + Err(err) => { + error!("Cannot read from UDP server. Closing server: {}", err); + return None; } + }; - let peer_addr = match server.listener.recv_buf_from(&mut buf).await { - Ok((_read_len, peer_addr)) => peer_addr, - Err(err) => { - error!("Cannot read from UDP server. Closing server: {}", err); - return None; - } - }; + let (destination_addr, data) = { + let payload = buf.split().freeze(); + let (frag, destination_addr, data) = fast_socks5::parse_udp_request(payload.chunk()).await.unwrap(); + // We don't support udp fragmentation + if frag != 0 { + continue; + } + (destination_addr, payload.slice_ref(data)) + }; - let (destination_addr, data) = { - let payload = buf.split().freeze(); - let (frag, destination_addr, data) = fast_socks5::parse_udp_request(payload.chunk()).await.unwrap(); - // We don't support udp fragmentation - if frag != 0 { - continue; - } - (destination_addr, payload.slice_ref(data)) - }; - - match server.peers.get(&destination_addr) { - Some(io) => { - if io.sender.send(data).await.is_err() { - server.peers.remove(&destination_addr); - } - } - None => { - info!("New UDP connection for {}", destination_addr); - let (udp_client, io) = Socks5UdpStream::new( - mk_send_socket(&server.listener).ok()?, - peer_addr, - destination_addr.clone(), - server.cnx_timeout, - Arc::downgrade(&server.keys_to_delete), - ); - let _ = io.sender.send(data).await; - server.peers.insert(destination_addr, io); - return Some((Ok(udp_client), (server, mk_send_socket, buf))); + match server.peers.get(&destination_addr) { + Some(io) => { + if io.sender.send(data).await.is_err() { + server.peers.remove(&destination_addr); } } + None => { + info!("New UDP connection for {}", destination_addr); + let (udp_client, io) = Socks5UdpStream::new( + server.listener.clone(), + peer_addr, + destination_addr.clone(), + server.cnx_timeout, + Arc::downgrade(&server.keys_to_delete), + ); + let _ = io.sender.send(data).await; + server.peers.insert(destination_addr, io); + return Some((Ok(udp_client), (server, buf))); + } } - }, - ); + } + }); Ok(stream) } diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index a3e7e48..c0b2d56 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -1,8 +1,7 @@ -use super::{to_host_port, JwtTunnelConfig, JWT_HEADER_PREFIX, JWT_KEY}; -use crate::{LocalProtocol, LocalToRemote, WsClientConfig}; +use super::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; +use crate::WsClientConfig; use anyhow::{anyhow, Context}; -use base64::Engine; use bytes::Bytes; use fastwebsockets::WebSocket; use futures_util::pin_mut; @@ -13,6 +12,7 @@ use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY}; use hyper::upgrade::Upgraded; use hyper::{Request, Response}; use hyper_util::rt::{TokioExecutor, TokioIo}; +use jsonwebtoken::TokenData; use std::future::Future; use std::ops::{Deref, DerefMut}; use std::sync::Arc; @@ -21,24 +21,18 @@ use tokio::sync::oneshot; use tokio_stream::{Stream, StreamExt}; use tracing::log::debug; use tracing::{error, span, Instrument, Level, Span}; -use url::{Host, Url}; +use url::Host; use uuid::Uuid; -fn tunnel_to_jwt_token(request_id: Uuid, tunnel: &LocalToRemote) -> String { - let cfg = JwtTunnelConfig::new(request_id, tunnel); - let (alg, secret) = JWT_KEY.deref(); - jsonwebtoken::encode(alg, &cfg, secret).unwrap_or_default() -} - pub async fn connect( request_id: Uuid, client_cfg: &WsClientConfig, - tunnel_cfg: &LocalToRemote, + dest_addr: &RemoteAddr, ) -> anyhow::Result<(WebSocket>, Response)> { let mut pooled_cnx = match client_cfg.cnx_pool().get().await { - Ok(tcp_stream) => tcp_stream, - Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}"))?, - }; + Ok(cnx) => Ok(cnx), + Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")), + }?; let mut req = Request::builder() .method("GET") @@ -50,7 +44,7 @@ pub async fn connect( .header(SEC_WEBSOCKET_VERSION, "13") .header( SEC_WEBSOCKET_PROTOCOL, - format!("v1, {}{}", JWT_HEADER_PREFIX, tunnel_to_jwt_token(request_id, tunnel_cfg)), + format!("v1, {}{}", JWT_HEADER_PREFIX, tunnel_to_jwt_token(request_id, dest_addr)), ) .version(hyper::Version::HTTP_11); @@ -79,7 +73,7 @@ pub async fn connect( async fn connect_to_server( request_id: Uuid, client_cfg: &WsClientConfig, - remote_cfg: &LocalToRemote, + remote_cfg: &RemoteAddr, duplex_stream: (R, W), ) -> anyhow::Result<()> where @@ -105,32 +99,25 @@ where Ok(()) } -pub async fn run_tunnel( - client_config: Arc, - tunnel_cfg: LocalToRemote, - incoming_cnx: T, -) -> anyhow::Result<()> +pub async fn run_tunnel(client_config: Arc, incoming_cnx: T) -> anyhow::Result<()> where - T: Stream>, + T: Stream>, R: AsyncRead + Send + 'static, W: AsyncWrite + Send + 'static, { pin_mut!(incoming_cnx); - while let Some(Ok((cnx_stream, remote_dest))) = incoming_cnx.next().await { + while let Some(Ok((cnx_stream, remote_addr))) = incoming_cnx.next().await { let request_id = Uuid::now_v7(); let span = span!( Level::INFO, "tunnel", id = request_id.to_string(), - remote = format!("{}:{}", remote_dest.1, remote_dest.2) + remote = format!("{}:{}", remote_addr.host, remote_addr.port) ); - let mut tunnel_cfg = tunnel_cfg.clone(); - tunnel_cfg.local_protocol = remote_dest.0; - tunnel_cfg.remote = (remote_dest.1, remote_dest.2); let client_config = client_config.clone(); let tunnel = async move { - let _ = connect_to_server(request_id, &client_config, &tunnel_cfg, cnx_stream) + let _ = connect_to_server(request_id, &client_config, &remote_addr, cnx_stream) .await .map_err(|err| error!("{:?}", err)); } @@ -144,18 +131,14 @@ where pub async fn run_reverse_tunnel( client_config: Arc, - mut tunnel_cfg: LocalToRemote, + remote_addr: RemoteAddr, connect_to_dest: F, ) -> anyhow::Result<()> where - F: Fn((Host, u16)) -> Fut, + F: Fn(Option) -> Fut, Fut: Future>, T: AsyncRead + AsyncWrite + Send + 'static, { - // Invert local with remote - let remote_ori = tunnel_cfg.remote; - tunnel_cfg.remote = to_host_port(tunnel_cfg.local); - loop { let client_config = client_config.clone(); let request_id = Uuid::now_v7(); @@ -163,12 +146,12 @@ where Level::INFO, "tunnel", id = request_id.to_string(), - remote = format!("{}:{}", tunnel_cfg.remote.0, tunnel_cfg.remote.1) + remote = format!("{}:{}", remote_addr.host, remote_addr.port) ); let _span = span.enter(); // Correctly configure tunnel cfg - let (mut ws, response) = connect(request_id, &client_config, &tunnel_cfg) + let (mut ws, response) = connect(request_id, &client_config, &remote_addr) .instrument(span.clone()) .await?; ws.set_auto_apply_mask(client_config.websocket_mask_frame); @@ -178,18 +161,21 @@ where .headers() .get(COOKIE) .and_then(|h| h.to_str().ok()) - .and_then(|h| base64::engine::general_purpose::STANDARD.decode(h).ok()) - .and_then(|h| Url::parse(&String::from_utf8_lossy(&h)).ok()) - .and_then(|url| match (url.host(), url.port_or_known_default()) { - (Some(h), Some(p)) => Some((h.to_owned(), p)), - _ => None, + .and_then(|h| { + let (validation, decode_key) = JWT_DECODE.deref(); + let jwt: Option> = jsonwebtoken::decode(h, decode_key, validation).ok(); + jwt }) - .unwrap_or(remote_ori.clone()); + .map(|jwt| RemoteAddr { + protocol: jwt.claims.p, + host: Host::parse(&jwt.claims.r).unwrap_or_else(|_| Host::Domain(String::new())), + port: jwt.claims.rp, + }); - let stream = match connect_to_dest(remote.clone()).instrument(span.clone()).await { + let stream = match connect_to_dest(remote).instrument(span.clone()).await { Ok(s) => s, Err(err) => { - error!("Cannot connect to {remote:?}: {err:?}"); + error!("Cannot connect to xxxx: {err:?}"); continue; } }; diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 5dc24fc..b235063 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -3,7 +3,7 @@ mod io; pub mod server; mod tls_reloader; -use crate::{tcp, tls, LocalProtocol, LocalToRemote, WsClientConfig}; +use crate::{tcp, tls, LocalProtocol, WsClientConfig}; use async_trait::async_trait; use bb8::ManageConnection; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation}; @@ -12,6 +12,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashSet; use std::io::{Error, IoSlice}; use std::net::{IpAddr, SocketAddr}; +use std::ops::Deref; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; @@ -29,26 +30,32 @@ struct JwtTunnelConfig { } impl JwtTunnelConfig { - fn new(request_id: Uuid, tunnel: &LocalToRemote) -> Self { + fn new(request_id: Uuid, dest: &RemoteAddr) -> Self { Self { id: request_id.to_string(), - p: match tunnel.local_protocol { + p: match dest.protocol { LocalProtocol::Tcp => LocalProtocol::Tcp, - LocalProtocol::Udp { .. } => tunnel.local_protocol, + LocalProtocol::Udp { .. } => dest.protocol, LocalProtocol::Stdio => LocalProtocol::Tcp, LocalProtocol::Socks5 { .. } => LocalProtocol::Tcp, LocalProtocol::ReverseTcp => LocalProtocol::ReverseTcp, - LocalProtocol::ReverseUdp { .. } => tunnel.local_protocol, + LocalProtocol::ReverseUdp { .. } => dest.protocol, LocalProtocol::ReverseSocks5 => LocalProtocol::ReverseSocks5, LocalProtocol::TProxyTcp => LocalProtocol::Tcp, LocalProtocol::TProxyUdp { timeout } => LocalProtocol::Udp { timeout }, }, - r: tunnel.remote.0.to_string(), - rp: tunnel.remote.1, + r: dest.host.to_string(), + rp: dest.port, } } } +fn tunnel_to_jwt_token(request_id: Uuid, tunnel: &RemoteAddr) -> String { + let cfg = JwtTunnelConfig::new(request_id, tunnel); + let (alg, secret) = JWT_KEY.deref(); + jsonwebtoken::encode(alg, &cfg, secret).unwrap_or_default() +} + static JWT_HEADER_PREFIX: &str = "authorization.bearer."; static JWT_SECRET: &[u8; 15] = b"champignonfrais"; static JWT_KEY: Lazy<(Header, EncodingKey)> = @@ -60,6 +67,13 @@ static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| { (validation, DecodingKey::from_secret(JWT_SECRET)) }); +#[derive(Debug)] +pub struct RemoteAddr { + pub protocol: LocalProtocol, + pub host: Host, + pub port: u16, +} + pub enum TransportStream { Plain(TcpStream), Tls(TlsStream), diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 4ba7c23..a2c75f6 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -1,6 +1,5 @@ use ahash::{HashMap, HashMapExt}; use anyhow::anyhow; -use base64::Engine; use futures_util::{pin_mut, FutureExt, Stream, StreamExt}; use std::cmp::min; use std::fmt::Debug; @@ -10,7 +9,7 @@ use std::pin::Pin; use std::sync::Arc; use std::time::Duration; -use super::{JwtTunnelConfig, JWT_DECODE, JWT_HEADER_PREFIX}; +use super::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; use crate::{socks5, tcp, tls, udp, LocalProtocol, TlsServerConfig, WsServerConfig}; use hyper::body::Incoming; use hyper::header::{COOKIE, SEC_WEBSOCKET_PROTOCOL}; @@ -32,17 +31,12 @@ use tokio::sync::{mpsc, oneshot}; use tokio_rustls::TlsAcceptor; use tracing::{error, info, span, warn, Instrument, Level, Span}; use url::Host; +use uuid::Uuid; async fn run_tunnel( server_config: &WsServerConfig, jwt: TokenData, -) -> anyhow::Result<( - LocalProtocol, - Host, - u16, - Pin>, - Pin>, -)> { +) -> anyhow::Result<(RemoteAddr, Pin>, Pin>)> { match jwt.claims.p { LocalProtocol::Udp { timeout, .. } => { let host = Host::parse(&jwt.claims.r)?; @@ -53,13 +47,13 @@ async fn run_tunnel( &server_config.dns_resolver, ) .await?; - Ok(( - LocalProtocol::Udp { timeout: None }, + + let remote = RemoteAddr { + protocol: jwt.claims.p, host, - jwt.claims.rp, - Box::pin(cnx.clone()), - Box::pin(cnx), - )) + port: jwt.claims.rp, + }; + Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx))) } LocalProtocol::Tcp => { let host = Host::parse(&jwt.claims.r)?; @@ -74,7 +68,12 @@ async fn run_tunnel( .await? .into_split(); - Ok((jwt.claims.p, host, port, Box::pin(rx), Box::pin(tx))) + let remote = RemoteAddr { + protocol: jwt.claims.p, + host, + port, + }; + Ok((remote, Box::pin(rx), Box::pin(tx))) } LocalProtocol::ReverseTcp => { #[allow(clippy::type_complexity)] @@ -87,7 +86,12 @@ async fn run_tunnel( let tcp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; let (local_rx, local_tx) = tcp.into_split(); - Ok((jwt.claims.p, local_srv.0, local_srv.1, Box::pin(local_rx), Box::pin(local_tx))) + let remote = RemoteAddr { + protocol: jwt.claims.p, + host: local_srv.0, + port: local_srv.1, + }; + Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) } LocalProtocol::ReverseUdp { timeout } => { #[allow(clippy::type_complexity)] @@ -101,7 +105,12 @@ async fn run_tunnel( let udp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; let (local_rx, local_tx) = tokio::io::split(udp); - Ok((jwt.claims.p, local_srv.0, local_srv.1, Box::pin(local_rx), Box::pin(local_tx))) + let remote = RemoteAddr { + protocol: jwt.claims.p, + host: local_srv.0, + port: local_srv.1, + }; + Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) } LocalProtocol::ReverseSocks5 => { #[allow(clippy::type_complexity)] @@ -112,10 +121,15 @@ async fn run_tunnel( let bind = format!("{}:{}", local_srv.0, local_srv.1); let listening_server = socks5::run_server(bind.parse()?, None); let (stream, local_srv) = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; - let proto = stream.local_protocol(); + let protocol = stream.local_protocol(); let (local_rx, local_tx) = tokio::io::split(stream); - Ok((proto, local_srv.0, local_srv.1, Box::pin(local_rx), Box::pin(local_tx))) + let remote = RemoteAddr { + protocol, + host: local_srv.0, + port: local_srv.1, + }; + Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) } _ => Err(anyhow::anyhow!("Invalid upgrade request")), } @@ -308,6 +322,7 @@ async fn server_upgrade(server_config: Arc, mut req: Request ret, Err(err) => { @@ -319,8 +334,11 @@ async fn server_upgrade(server_config: Arc, mut req: Request ret, Err(err) => { @@ -351,11 +369,9 @@ async fn server_upgrade(server_config: Arc, mut req: Request