diff --git a/Cargo.lock b/Cargo.lock index fcd81c9..6cb7652 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -414,9 +414,8 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "fast-socks5" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d449e348301d5fb9b0e5781510d8235ffe3bbac3286bd305462736a9e7043039" +version = "0.9.1" +source = "git+https://github.com/erebe/fast-socks5.git?branch=master#1912e35f8f5621827927096171d96c54816bf1ba" dependencies = [ "anyhow", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index 1b01d0b..fdaccb4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ base64 = "0.21.5" bb8 = { version = "0.8", features = [] } bytes = { version = "1.5.0", features = [] } clap = { version = "4.4.11", features = ["derive", "env"] } -fast-socks5 = { version = "0.9.2", features = [] } +fast-socks5 = { git = "https://github.com/erebe/fast-socks5.git", branch = "master", features = [] } fastwebsockets = { version = "0.6.0", features = ["upgrade", "simd", "unstable-split"] } futures-util = { version = "0.3.29" } hickory-resolver = { version = "0.24.0", features = ["tokio", "dns-over-https-rustls", "dns-over-rustls"] } diff --git a/src/main.rs b/src/main.rs index 60f283a..ed8a9cb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ mod dns; mod embedded_certificate; mod socks5; +mod socks5_udp; mod stdio; mod tcp; mod tls; @@ -236,7 +237,7 @@ enum LocalProtocol { Tcp, Udp { timeout: Option }, Stdio, - Socks5, + Socks5 { timeout: Option }, TProxyTcp, TProxyUdp { timeout: Option }, ReverseTcp, @@ -368,9 +369,14 @@ fn parse_tunnel_arg(arg: &str) -> Result { "socks5:/" => { let (local_bind, remaining) = parse_local_bind(&arg[9..])?; let x = format!("0.0.0.0:0?{}", remaining); - let (dest_host, dest_port, _options) = parse_tunnel_dest(&x)?; + let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?; + let timeout = options + .get("timeout_sec") + .and_then(|x| x.parse::().ok()) + .map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) }) + .unwrap_or(Some(Duration::from_secs(30))); Ok(LocalToRemote { - local_protocol: LocalProtocol::Socks5, + local_protocol: LocalProtocol::Socks5 { timeout }, local: local_bind, remote: (dest_host, dest_port), }) @@ -693,7 +699,7 @@ async fn main() { } }); } - LocalProtocol::Socks5 => { + LocalProtocol::Socks5 { .. } => { tunnel.local_protocol = LocalProtocol::ReverseSocks5; tokio::spawn(async move { let cfg = client_config.clone(); @@ -725,7 +731,10 @@ async fn main() { .await .unwrap_or_else(|err| panic!("Cannot start TCP server on {}: {}", tunnel.local, err)) .map_err(anyhow::Error::new) - .map_ok(move |stream| (stream.into_split(), remote.clone())); + .map_ok(move |stream| { + let remote = remote.clone(); + (stream.into_split(), (LocalProtocol::Tcp, remote.0, remote.1)) + }); tokio::spawn(async move { if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await { @@ -741,8 +750,8 @@ async fn main() { .map_err(anyhow::Error::new) .map_ok(move |stream| { // In TProxy mode local destination is the final ip:port destination - let dest = to_host_port(stream.local_addr().unwrap()); - (stream.into_split(), dest) + let (host, port) = to_host_port(stream.local_addr().unwrap()); + (stream.into_split(), (LocalProtocol::Tcp, host, port)) }); tokio::spawn(async move { @@ -753,8 +762,9 @@ async fn main() { } #[cfg(target_os = "linux")] LocalProtocol::TProxyUdp { timeout } => { + let timeout = *timeout; let server = - udp::run_server(tunnel.local, *timeout, udp::configure_tproxy, udp::mk_send_socket_tproxy) + udp::run_server(tunnel.local, timeout, udp::configure_tproxy, udp::mk_send_socket_tproxy) .await .unwrap_or_else(|err| { panic!("Cannot start TProxy UDP server on {}: {}", tunnel.local, err) @@ -762,8 +772,8 @@ async fn main() { .map_err(anyhow::Error::new) .map_ok(move |stream| { // In TProxy mode local destination is the final ip:port destination - let dest = to_host_port(stream.local_addr().unwrap()); - (tokio::io::split(stream), dest) + let (host, port) = to_host_port(stream.local_addr().unwrap()); + (tokio::io::split(stream), (LocalProtocol::Udp { timeout: timeout }, host, port)) }); tokio::spawn(async move { @@ -777,12 +787,15 @@ async fn main() { panic!("Transparent proxy is not available for non Linux platform") } LocalProtocol::Udp { timeout } => { - let remote = tunnel.remote.clone(); - let server = udp::run_server(tunnel.local, *timeout, |_| Ok(()), |s| Ok(s.clone())) + let (host, port) = tunnel.remote.clone(); + let timeout = *timeout; + let server = udp::run_server(tunnel.local, timeout.clone(), |_| Ok(()), |s| Ok(s.clone())) .await .unwrap_or_else(|err| panic!("Cannot start UDP server on {}: {}", tunnel.local, err)) .map_err(anyhow::Error::new) - .map_ok(move |stream| (tokio::io::split(stream), remote.clone())); + .map_ok(move |stream| { + (tokio::io::split(stream), (LocalProtocol::Udp { timeout }, host.clone(), port)) + }); tokio::spawn(async move { if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await { @@ -790,11 +803,14 @@ async fn main() { } }); } - LocalProtocol::Socks5 => { - let server = socks5::run_server(tunnel.local) + LocalProtocol::Socks5 { timeout } => { + let server = socks5::run_server(tunnel.local, *timeout) .await .unwrap_or_else(|err| panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err)) - .map_ok(|(stream, remote_dest)| (tokio::io::split(stream), remote_dest)); + .map_ok(|(stream, (host, port))| { + let proto = stream.local_protocol(); + (tokio::io::split(stream), (proto, host, port)) + }); tokio::spawn(async move { if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await { @@ -811,7 +827,9 @@ async fn main() { if let Err(err) = tunnel::client::run_tunnel( client_config, tunnel.clone(), - stream::once(async move { Ok((server, tunnel.remote)) }), + stream::once(async move { + Ok((server, (LocalProtocol::Tcp, tunnel.remote.0, tunnel.remote.1))) + }), ) .await { diff --git a/src/socks5.rs b/src/socks5.rs index 8d57012..d96efa0 100644 --- a/src/socks5.rs +++ b/src/socks5.rs @@ -1,4 +1,5 @@ -use crate::udp::UdpStream; +use crate::socks5_udp::Socks5UdpStream; +use crate::{socks5_udp, LocalProtocol}; use anyhow::Context; use fast_socks5::server::{Config, DenyAuthentication, Socks5Server}; use fast_socks5::util::target_addr::TargetAddr; @@ -8,29 +9,43 @@ use std::io::{Error, IoSlice}; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::pin::Pin; use std::task::Poll; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio::net::TcpStream; +use tokio::select; use tracing::{info, warn}; use url::Host; #[allow(clippy::type_complexity)] pub struct Socks5Listener { - stream: Pin> + Send>>, + socks_server: Pin> + Send>>, } -pub enum Socks5Protocol { +pub enum Socks5Stream { Tcp(TcpStream), - Udp(UdpStream), + Udp(Socks5UdpStream), } -impl Stream for Socks5Listener { - type Item = anyhow::Result<(Socks5Protocol, (Host, u16))>; - fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { - unsafe { self.map_unchecked_mut(|x| &mut x.stream) }.poll_next(cx) +impl Socks5Stream { + pub fn local_protocol(&self) -> LocalProtocol { + match self { + Socks5Stream::Tcp(_) => LocalProtocol::Tcp, + Socks5Stream::Udp(s) => LocalProtocol::Udp { + timeout: s.watchdog_deadline.as_ref().map(|x| x.period()), + }, + } } } -pub async fn run_server(bind: SocketAddr) -> Result { +impl Stream for Socks5Listener { + type Item = anyhow::Result<(Socks5Stream, (Host, u16))>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + unsafe { self.map_unchecked_mut(|x| &mut x.socks_server) }.poll_next(cx) + } +} + +pub async fn run_server(bind: SocketAddr, timeout: Option) -> Result { info!("Starting SOCKS5 server listening cnx on {}", bind); let server = Socks5Server::::bind(bind) @@ -43,17 +58,39 @@ pub async fn run_server(bind: SocketAddr) -> Result return None, - Some(Err(err)) => { + let cnx = select! { + biased; + + cnx = acceptor.next() => match cnx { + None => return None, + Some(Err(err)) => { + drop(acceptor); + return Some((Err(anyhow::Error::new(err)), (server, udp_server))); + } + Some(Ok(cnx)) => cnx, + }, + + // new incoming udp stream + udp_conn = udp_server.next() => { drop(acceptor); - return Some((Err(anyhow::Error::new(err)), server)); + return match udp_conn { + Some(Ok(stream)) => { + let dest = stream.destination(); + Some((Ok((Socks5Stream::Udp(stream), dest)), (server, udp_server))) + } + Some(Err(err)) => { + Some((Err(anyhow::Error::new(err)), (server, udp_server))) + } + None => { + None + } + }; } - Some(Ok(cnx)) => cnx, }; let cnx = match cnx.upgrade_to_socks5().await { @@ -68,10 +105,6 @@ pub async fn run_server(bind: SocketAddr) -> Result (Host::Ipv4(*ip.ip()), ip.port()), @@ -79,6 +112,28 @@ pub async fn run_server(bind: SocketAddr) -> Result (Host::Domain(host.clone()), *port), }; + // Special case for UDP Associate where we return the bind addr of the udp server + if let Some(fast_socks5::Socks5Command::UDPAssociate) = cnx.cmd() { + let mut cnx = cnx.into_inner(); + let ret = cnx.write_all(&new_reply(&ReplyError::Succeeded, bind)).await; + + if let Err(err) = ret { + warn!("Cannot reply to socks5 udp client: {}", err); + continue; + } + tokio::spawn(async move { + let mut buf = [0u8; 8]; + loop { + match cnx.read(&mut buf).await { + Ok(0) => return, + Err(_) => return, + _ => {} + } + } + }); + continue; + }; + let mut cnx = cnx.into_inner(); let ret = cnx .write_all(&new_reply( @@ -93,12 +148,12 @@ pub async fn run_server(bind: SocketAddr) -> Result Vec { reply } -impl Unpin for Socks5Protocol {} -impl AsyncRead for Socks5Protocol { +impl Unpin for Socks5Stream {} +impl AsyncRead for Socks5Stream { fn poll_read( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { match self.get_mut() { - Socks5Protocol::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf), - Socks5Protocol::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf), + Socks5Stream::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf), + Socks5Stream::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf), } } } -impl AsyncWrite for Socks5Protocol { +impl AsyncWrite for Socks5Stream { fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll> { match self.get_mut() { - Socks5Protocol::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf), - Socks5Protocol::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf), + Socks5Stream::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf), + Socks5Stream::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { match self.get_mut() { - Socks5Protocol::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_flush(cx), - Socks5Protocol::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_flush(cx), + Socks5Stream::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_flush(cx), + Socks5Stream::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_flush(cx), } } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { match self.get_mut() { - Socks5Protocol::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx), - Socks5Protocol::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx), + Socks5Stream::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx), + Socks5Stream::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx), } } @@ -172,15 +227,15 @@ impl AsyncWrite for Socks5Protocol { bufs: &[IoSlice<'_>], ) -> Poll> { match self.get_mut() { - Socks5Protocol::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write_vectored(cx, bufs), - Socks5Protocol::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_write_vectored(cx, bufs), + Socks5Stream::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write_vectored(cx, bufs), + Socks5Stream::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_write_vectored(cx, bufs), } } fn is_write_vectored(&self) -> bool { match self { - Socks5Protocol::Tcp(s) => s.is_write_vectored(), - Socks5Protocol::Udp(s) => s.is_write_vectored(), + Socks5Stream::Tcp(s) => s.is_write_vectored(), + Socks5Stream::Udp(s) => s.is_write_vectored(), } } } diff --git a/src/socks5_udp.rs b/src/socks5_udp.rs new file mode 100644 index 0000000..d8e02dd --- /dev/null +++ b/src/socks5_udp.rs @@ -0,0 +1,267 @@ +use anyhow::Context; +use futures_util::{stream, Stream}; + +use parking_lot::RwLock; +use pin_project::{pin_project, pinned_drop}; +use std::collections::HashMap; +use std::io; +use std::io::{Error, ErrorKind}; +use std::net::SocketAddr; + +use crate::tunnel::to_host_port; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use fast_socks5::new_udp_header; +use fast_socks5::util::target_addr::TargetAddr; +use log::warn; +use std::pin::{pin, Pin}; +use std::sync::{Arc, Weak}; +use std::task::{ready, Poll}; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::net::UdpSocket; +use tokio::sync::mpsc; +use tokio::time::Interval; +use tracing::{debug, error, info}; +use url::Host; + +struct IoInner { + sender: mpsc::Sender, +} +struct Socks5UdpServer { + listener: Arc, + peers: HashMap>, ahash::RandomState>, + keys_to_delete: Arc>>, + cnx_timeout: Option, +} + +impl Socks5UdpServer { + pub fn new(listener: UdpSocket, timeout: Option) -> Self { + let socket = socket2::SockRef::from(&listener); + + // Increase receive buffer + if let Err(err) = socket.set_recv_buffer_size(64 * 1024 * 1024) { + warn!("Cannot set UDP server recv buffer: {}", err); + } + + if let Err(err) = socket.set_send_buffer_size(64 * 1024 * 1024) { + warn!("Cannot set UDP server recv buffer: {}", err); + } + + Self { + listener: Arc::new(listener), + peers: HashMap::with_hasher(ahash::RandomState::new()), + keys_to_delete: Default::default(), + cnx_timeout: timeout, + } + } + #[inline] + pub fn clean_dead_keys(&mut self) { + let nb_key_to_delete = self.keys_to_delete.read().len(); + if nb_key_to_delete == 0 { + return; + } + + debug!("Cleaning {} dead udp peers", nb_key_to_delete); + let mut keys_to_delete = self.keys_to_delete.write(); + for key in keys_to_delete.iter() { + self.peers.remove(key); + } + keys_to_delete.clear(); + } +} + +#[pin_project(PinnedDrop)] +pub struct Socks5UdpStream { + #[pin] + recv_data: mpsc::Receiver, + send_socket: Arc, + destination: TargetAddr, + peer: SocketAddr, + udp_header: Vec, + #[pin] + pub watchdog_deadline: Option, + data_read_before_deadline: bool, + io: Pin>, + keys_to_delete: Weak>>, +} + +#[pinned_drop] +impl PinnedDrop for Socks5UdpStream { + fn drop(self: Pin<&mut Self>) { + if let Some(keys_to_delete) = self.keys_to_delete.upgrade() { + keys_to_delete.write().push(self.destination.clone()); + } + } +} + +impl Socks5UdpStream { + fn new( + send_socket: Arc, + peer: SocketAddr, + destination: TargetAddr, + watchdog_deadline: Option, + keys_to_delete: Weak>>, + ) -> (Self, Pin>) { + let (tx, rx) = mpsc::channel(1024); + let io = Arc::pin(IoInner { sender: tx }); + let udp_header = match &destination { + TargetAddr::Ip(ip) => new_udp_header(*ip).unwrap(), + TargetAddr::Domain(h, p) => new_udp_header((h.as_str(), *p)).unwrap(), + }; + let s = Self { + recv_data: rx, + send_socket, + peer, + destination, + watchdog_deadline: watchdog_deadline + .map(|timeout| tokio::time::interval_at(tokio::time::Instant::now() + timeout, timeout)), + data_read_before_deadline: false, + io: io.clone(), + keys_to_delete, + udp_header, + }; + + (s, io) + } + + pub fn destination(&self) -> (Host, u16) { + match &self.destination { + TargetAddr::Ip(sock_addr) => to_host_port(*sock_addr), + TargetAddr::Domain(h, p) => (Host::Domain(h.clone()), *p), + } + } +} + +impl AsyncRead for Socks5UdpStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + obuf: &mut ReadBuf<'_>, + ) -> Poll> { + let mut project = self.project(); + // Look that the timeout for client has not elapsed + if let Some(mut deadline) = project.watchdog_deadline.as_pin_mut() { + if deadline.poll_tick(cx).is_ready() { + if !*project.data_read_before_deadline { + return Poll::Ready(Err(Error::new( + ErrorKind::TimedOut, + format!("UDP stream timeout with {}", project.peer), + ))); + }; + + *project.data_read_before_deadline = false; + while deadline.poll_tick(cx).is_ready() {} + } + } + + let Some(data) = ready!(project.recv_data.poll_recv(cx)) else { + return Poll::Ready(Err(Error::from(ErrorKind::UnexpectedEof))); + }; + if obuf.remaining() < data.len() { + return Poll::Ready(Err(Error::new( + ErrorKind::InvalidData, + "udp dst buffer does not have enough space left. Can't fragment", + ))); + } + + obuf.put_slice(data.chunk()); + *project.data_read_before_deadline = true; + + Poll::Ready(Ok(())) + } +} + +impl AsyncWrite for Socks5UdpStream { + fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll> { + let this = self.project(); + let header_len = this.udp_header.len(); + this.udp_header.extend_from_slice(buf); + let ret = this + .send_socket + .poll_send_to(cx, this.udp_header.as_slice(), *this.peer); + this.udp_header.truncate(header_len); + ret.map(|r| r.map(|write_len| write_len - header_len)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + self.send_socket.poll_send_ready(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +pub async fn run_server( + bind: SocketAddr, + timeout: Option, + configure_listener: impl Fn(&UdpSocket) -> anyhow::Result<()>, + mk_send_socket: impl Fn(&Arc) -> anyhow::Result>, +) -> Result>, anyhow::Error> { + info!( + "Starting SOCKS5 UDP server listening cnx on {} with cnx timeout of {}s", + bind, + timeout.unwrap_or(Duration::from_secs(0)).as_secs() + ); + + let listener = UdpSocket::bind(bind) + .await + .with_context(|| format!("Cannot create UDP server {:?}", bind))?; + configure_listener(&listener)?; + + let udp_server = Socks5UdpServer::new(listener, timeout); + static MAX_PACKET_LENGTH: usize = 64 * 1024; + let buffer = BytesMut::with_capacity(MAX_PACKET_LENGTH * 10); + let stream = stream::unfold( + (udp_server, mk_send_socket, buffer), + |(mut server, mk_send_socket, mut buf)| async move { + loop { + server.clean_dead_keys(); + if buf.remaining_mut() < MAX_PACKET_LENGTH { + buf.reserve(MAX_PACKET_LENGTH); + } + + let peer_addr = match server.listener.recv_buf_from(&mut buf).await { + Ok((_read_len, peer_addr)) => peer_addr, + Err(err) => { + error!("Cannot read from UDP server. Closing server: {}", err); + return None; + } + }; + + let (destination_addr, data) = { + let payload = buf.split().freeze(); + let (frag, destination_addr, data) = fast_socks5::parse_udp_request(payload.chunk()).await.unwrap(); + // We don't support udp fragmentation + if frag != 0 { + continue; + } + (destination_addr, payload.slice_ref(data)) + }; + + match server.peers.get(&destination_addr) { + Some(io) => { + if let Err(_) = io.sender.send(data).await { + server.peers.remove(&destination_addr); + } + } + None => { + info!("New UDP connection for {}", destination_addr); + let (udp_client, io) = Socks5UdpStream::new( + mk_send_socket(&server.listener).ok()?, + peer_addr, + destination_addr.clone(), + server.cnx_timeout, + Arc::downgrade(&server.keys_to_delete), + ); + let _ = io.sender.send(data).await; + server.peers.insert(destination_addr, io); + return Some((Ok(udp_client), (server, mk_send_socket, buf))); + } + } + } + }, + ); + + Ok(stream) +} diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index 808718e..a3e7e48 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -1,5 +1,5 @@ use super::{to_host_port, JwtTunnelConfig, JWT_HEADER_PREFIX, JWT_KEY}; -use crate::{LocalToRemote, WsClientConfig}; +use crate::{LocalProtocol, LocalToRemote, WsClientConfig}; use anyhow::{anyhow, Context}; use base64::Engine; @@ -111,7 +111,7 @@ pub async fn run_tunnel( incoming_cnx: T, ) -> anyhow::Result<()> where - T: Stream>, + T: Stream>, R: AsyncRead + Send + 'static, W: AsyncWrite + Send + 'static, { @@ -122,10 +122,11 @@ where Level::INFO, "tunnel", id = request_id.to_string(), - remote = format!("{}:{}", remote_dest.0, remote_dest.1) + remote = format!("{}:{}", remote_dest.1, remote_dest.2) ); let mut tunnel_cfg = tunnel_cfg.clone(); - tunnel_cfg.remote = remote_dest; + tunnel_cfg.local_protocol = remote_dest.0; + tunnel_cfg.remote = (remote_dest.1, remote_dest.2); let client_config = client_config.clone(); let tunnel = async move { diff --git a/src/tunnel/io.rs b/src/tunnel/io.rs index 0f11037..607d09e 100644 --- a/src/tunnel/io.rs +++ b/src/tunnel/io.rs @@ -1,7 +1,6 @@ use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite}; use futures_util::{pin_mut, FutureExt}; use hyper::upgrade::Upgraded; -use std::cmp::max; use hyper_util::rt::TokioIo; use std::time::Duration; diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index bdcb4cf..5dc24fc 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -36,7 +36,7 @@ impl JwtTunnelConfig { LocalProtocol::Tcp => LocalProtocol::Tcp, LocalProtocol::Udp { .. } => tunnel.local_protocol, LocalProtocol::Stdio => LocalProtocol::Tcp, - LocalProtocol::Socks5 => LocalProtocol::Tcp, + LocalProtocol::Socks5 { .. } => LocalProtocol::Tcp, LocalProtocol::ReverseTcp => LocalProtocol::ReverseTcp, LocalProtocol::ReverseUdp { .. } => tunnel.local_protocol, LocalProtocol::ReverseSocks5 => LocalProtocol::ReverseSocks5, diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index bb7664f..70600ab 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -22,7 +22,7 @@ use jsonwebtoken::TokenData; use once_cell::sync::Lazy; use parking_lot::Mutex; -use crate::socks5::Socks5Protocol; +use crate::socks5::Socks5Stream; use crate::tunnel::tls_reloader::TlsReloader; use crate::udp::UdpStream; use tokio::io::{AsyncRead, AsyncWrite}; @@ -105,12 +105,12 @@ async fn run_tunnel( } LocalProtocol::ReverseSocks5 => { #[allow(clippy::type_complexity)] - static SERVERS: Lazy, u16), mpsc::Receiver<(Socks5Protocol, (Host, u16))>>>> = + static SERVERS: Lazy, u16), mpsc::Receiver<(Socks5Stream, (Host, u16))>>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp); let bind = format!("{}:{}", local_srv.0, local_srv.1); - let listening_server = socks5::run_server(bind.parse()?); + let listening_server = socks5::run_server(bind.parse()?, None); let (tcp, local_srv) = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; let (local_rx, local_tx) = tokio::io::split(tcp);