diff --git a/Cargo.lock b/Cargo.lock index 11b6e2c..9ca3e52 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -156,6 +156,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" + [[package]] name = "block-buffer" version = "0.10.4" @@ -198,9 +204,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.4.8" +version = "4.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2275f18819641850fa26c89acc84d465c1bf91ce57bc2748b28c420473352f64" +checksum = "41fffed7514f420abec6d183b1d3acfd9099c79c3a10a06ade4f8203f1411272" dependencies = [ "clap_builder", "clap_derive", @@ -208,9 +214,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.8" +version = "4.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07cdf1b148b25c1e1f7a42225e30a0d99a615cd4637eae7365548dd4529b95bc" +checksum = "63361bae7eef3771745f02d8d892bec2fee5f6e34af316ba556e7f97a7069ff1" dependencies = [ "anstream", "anstyle", @@ -654,6 +660,15 @@ version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + [[package]] name = "miniz_oxide" version = "0.7.1" @@ -674,6 +689,18 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "nix" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" +dependencies = [ + "bitflags 2.4.1", + "cfg-if", + "libc", + "memoffset", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -857,7 +884,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -1004,7 +1031,7 @@ version = "2.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de" dependencies = [ - "bitflags", + "bitflags 1.3.2", "core-foundation", "core-foundation-sys", "libc", @@ -1630,6 +1657,7 @@ dependencies = [ "jsonwebtoken", "libc", "log", + "nix", "once_cell", "parking_lot", "pin-project", diff --git a/Cargo.toml b/Cargo.toml index 7fc59b6..608595c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ repository = "https://github.com/erebe/wstunnel.git" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -clap = { version = "4.4.8", features = ["derive"]} +clap = { version = "4.4.10", features = ["derive"]} url = "2.5.0" anyhow = "1.0.75" @@ -24,6 +24,7 @@ rustls-pemfile = { version = "1.0.4", features = [] } bytes = { version = "1.5.0", features = [] } parking_lot = "0.12.1" urlencoding = "2.1.3" +nix = { version = "0.27.1", features = ["socket", "net", "uio"] } rustls-native-certs = { version = "0.6.3", features = [] } tokio = { version = "1.34.0", features = ["full"] } diff --git a/src/main.rs b/src/main.rs index a55d47d..537001d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -51,17 +51,18 @@ enum Commands { struct Client { /// Listen on local and forwards traffic from remote. Can be specified multiple times /// examples: - /// 'tcp://1212:google.com:443' => listen locally on tcp on port 1212 and forward to google.com on port 443 + /// 'tcp://1212:google.com:443' => listen locally on tcp on port 1212 and forward to google.com on port 443 /// - /// 'udp://1212:1.1.1.1:53' => listen locally on udp on port 1212 and forward to cloudflare dns 1.1.1.1 on port 53 - /// 'udp://1212:1.1.1.1:53?timeout_sec=10' timeout_sec on udp force close the tunnel after 10sec. Set it to 0 to disable the timeout [default: 30] + /// 'udp://1212:1.1.1.1:53' => listen locally on udp on port 1212 and forward to cloudflare dns 1.1.1.1 on port 53 + /// 'udp://1212:1.1.1.1:53?timeout_sec=10' timeout_sec on udp force close the tunnel after 10sec. Set it to 0 to disable the timeout [default: 30] /// - /// 'socks5://[::1]:1212' => listen locally with socks5 on port 1212 and forward dynamically requested tunnel + /// 'socks5://[::1]:1212' => listen locally with socks5 on port 1212 and forward dynamically requested tunnel /// - /// 'tproxy+tcp://[::1]:1212' => listen locally on tcp on port 1212 as a *transparent proxy* and forward dynamically requested tunnel - /// linux only and requires sudo/CAP_NET_ADMIN + /// 'tproxy+tcp://[::1]:1212' => listen locally on tcp on port 1212 as a *transparent proxy* and forward dynamically requested tunnel + /// 'tproxy+udp://[::1]:1212?timeout_sec=10' listen locally on udp on port 1212 as a *transparent proxy* and forward dynamically requested tunnel + /// linux only and requires sudo/CAP_NET_ADMIN /// - /// 'stdio://google.com:443' => listen for data from stdio, mainly for `ssh -o ProxyCommand="wstunnel client -L stdio://%h:%p ws://localhost:8080" my-server` + /// 'stdio://google.com:443' => listen for data from stdio, mainly for `ssh -o ProxyCommand="wstunnel client -L stdio://%h:%p ws://localhost:8080" my-server` #[arg(short='L', long, value_name = "{tcp,udp,socks5,stdio}://[BIND:]PORT:HOST:PORT", value_parser = parse_tunnel_arg, verbatim_doc_comment)] local_to_remote: Vec, @@ -180,6 +181,7 @@ enum LocalProtocol { Stdio, Socks5, TProxyTcp, + TProxyUdp { timeout: Option }, ReverseTcp, ReverseUdp { timeout: Option }, ReverseSocks5, @@ -334,6 +336,21 @@ fn parse_tunnel_arg(arg: &str) -> Result { remote: (dest_host, dest_port), }) } + "tproxy+u" => { + let (local_bind, remaining) = parse_local_bind(&arg["tproxy+udp://".len()..])?; + let x = format!("0.0.0.0:0?{}", remaining); + let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?; + let timeout = options + .get("timeout_sec") + .and_then(|x| x.parse::().ok()) + .map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) }) + .unwrap_or(Some(Duration::from_secs(30))); + Ok(LocalToRemote { + local_protocol: LocalProtocol::TProxyUdp { timeout }, + local: local_bind, + remote: (dest_host, dest_port), + }) + } _ => Err(Error::new( ErrorKind::InvalidInput, format!("Invalid local protocol for tunnel {}", arg), @@ -644,7 +661,7 @@ async fn main() { LocalProtocol::TProxyTcp => { let server = tcp::run_server(tunnel.local, true) .await - .unwrap_or_else(|err| panic!("Cannot start TProxy server on {}: {}", tunnel.local, err)) + .unwrap_or_else(|err| panic!("Cannot start TProxy TCP server on {}: {}", tunnel.local, err)) .map_err(anyhow::Error::new) .map_ok(move |stream| { // In TProxy mode local destination is the final ip:port destination @@ -658,13 +675,34 @@ async fn main() { } }); } + #[cfg(target_os = "linux")] + LocalProtocol::TProxyUdp { timeout } => { + let server = + udp::run_server(tunnel.local, *timeout, udp::configure_tproxy, udp::mk_send_socket_tproxy) + .await + .unwrap_or_else(|err| { + panic!("Cannot start TProxy UDP server on {}: {}", tunnel.local, err) + }) + .map_err(anyhow::Error::new) + .map_ok(move |stream| { + // In TProxy mode local destination is the final ip:port destination + let dest = to_host_port(stream.local_addr().unwrap()); + (tokio::io::split(stream), dest) + }); + + tokio::spawn(async move { + if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await { + error!("{:?}", err); + } + }); + } #[cfg(not(target_os = "linux"))] - LocalProtocol::TProxyTcp => { + LocalProtocol::TProxyTcp | LocalProtocol::TProxyUdp { .. } => { panic!("Transparent proxy is not available for non Linux platform") } LocalProtocol::Udp { timeout } => { let remote = tunnel.remote.clone(); - let server = udp::run_server(tunnel.local, *timeout) + let server = udp::run_server(tunnel.local, *timeout, |_| Ok(()), |s| Ok(s.clone())) .await .unwrap_or_else(|err| panic!("Cannot start UDP server on {}: {}", tunnel.local, err)) .map_err(anyhow::Error::new) diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index c83ec00..c536ec5 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -40,6 +40,7 @@ impl JwtTunnelConfig { LocalProtocol::ReverseUdp { .. } => tunnel.local_protocol, LocalProtocol::ReverseSocks5 => LocalProtocol::ReverseSocks5, LocalProtocol::TProxyTcp => LocalProtocol::Tcp, + LocalProtocol::TProxyUdp { timeout } => LocalProtocol::Udp { timeout }, }, r: tunnel.remote.0.to_string(), rp: tunnel.remote.1, diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index faba112..ec7bec1 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -103,7 +103,8 @@ async fn from_query( let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp); let bind = format!("{}:{}", local_srv.0, local_srv.1); - let listening_server = udp::run_server(bind.parse()?, timeout); + let listening_server = + udp::run_server(bind.parse()?, timeout, |_| Ok(()), |send_socket| Ok(send_socket.clone())); let udp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; let (local_rx, local_tx) = tokio::io::split(udp); diff --git a/src/udp.rs b/src/udp.rs index e3cc4d3..644e264 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -23,15 +23,15 @@ use tokio::time::{timeout, Interval}; use tracing::{debug, error, info}; use url::Host; -struct IoInner { - has_data_to_read: Notify, - has_read_data: Notify, +pub struct IoInner { + pub has_data_to_read: Notify, + pub has_read_data: Notify, } -struct UdpServer { - listener: Arc, - peers: HashMap>, ahash::RandomState>, - keys_to_delete: Arc>>, - cnx_timeout: Option, +pub struct UdpServer { + pub listener: Arc, + pub peers: HashMap>, ahash::RandomState>, + pub keys_to_delete: Arc>>, + pub cnx_timeout: Option, } impl UdpServer { @@ -55,7 +55,7 @@ impl UdpServer { } } #[inline] - fn clean_dead_keys(&mut self) { + pub fn clean_dead_keys(&mut self) { let nb_key_to_delete = self.keys_to_delete.read().len(); if nb_key_to_delete == 0 { return; @@ -68,14 +68,15 @@ impl UdpServer { } keys_to_delete.clear(); } - fn clone_socket(&self) -> Arc { + pub fn clone_socket(&self) -> Arc { self.listener.clone() } } #[pin_project(PinnedDrop)] pub struct UdpStream { - socket: Arc, + recv_socket: Arc, + send_socket: Arc, peer: SocketAddr, #[pin] watchdog_deadline: Option, @@ -103,8 +104,9 @@ impl PinnedDrop for UdpStream { } impl UdpStream { - fn new( - socket: Arc, + pub fn new( + recv_socket: Arc, + send_socket: Arc, peer: SocketAddr, watchdog_deadline: Option, keys_to_delete: Weak>>, @@ -116,7 +118,8 @@ impl UdpStream { has_read_data, }); let mut s = Self { - socket, + recv_socket, + send_socket, peer, watchdog_deadline: watchdog_deadline .map(|timeout| tokio::time::interval_at(tokio::time::Instant::now() + timeout, timeout)), @@ -132,6 +135,10 @@ impl UdpStream { (s, io) } + + pub fn local_addr(&self) -> io::Result { + self.send_socket.local_addr() + } } impl AsyncRead for UdpStream { @@ -161,7 +168,7 @@ impl AsyncRead for UdpStream { project.pending_notification.as_mut().set(None); } - let peer = ready!(project.socket.poll_recv_from(cx, obuf))?; + let peer = ready!(project.recv_socket.poll_recv_from(cx, obuf))?; debug_assert_eq!(peer, *project.peer); *project.data_read_before_deadline = true; @@ -179,11 +186,11 @@ impl AsyncRead for UdpStream { impl AsyncWrite for UdpStream { fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll> { - self.socket.poll_send_to(cx, buf, self.peer) + self.send_socket.poll_send_to(cx, buf, self.peer) } fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { - self.socket.poll_send_ready(cx) + self.send_socket.poll_send_ready(cx) } fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll> { @@ -194,6 +201,8 @@ impl AsyncWrite for UdpStream { pub async fn run_server( bind: SocketAddr, timeout: Option, + configure_listener: impl Fn(&UdpSocket) -> anyhow::Result<()>, + mk_send_socket: impl Fn(&Arc) -> anyhow::Result>, ) -> Result>, anyhow::Error> { info!( "Starting UDP server listening cnx on {} with cnx timeout of {}s", @@ -204,46 +213,51 @@ pub async fn run_server( let listener = UdpSocket::bind(bind) .await .with_context(|| format!("Cannot create UDP server {:?}", bind))?; + configure_listener(&listener)?; let udp_server = UdpServer::new(listener, timeout); - let stream = stream::unfold((udp_server, None), |(mut server, peer_with_data)| async move { - // New returned peer hasn't read its data yet, await for it. - if let Some(await_peer) = peer_with_data { - if let Some(peer) = server.peers.get(&await_peer) { - peer.has_read_data.notified().await; - } - }; - - loop { - server.clean_dead_keys(); - let peer_addr = match server.listener.peek_sender().await { - Ok(ret) => ret, - Err(err) => { - error!("Cannot read from UDP server. Closing server: {}", err); - return None; + let stream = stream::unfold( + (udp_server, None, mk_send_socket), + |(mut server, peer_with_data, mk_send_socket)| async move { + // New returned peer hasn't read its data yet, await for it. + if let Some(await_peer) = peer_with_data { + if let Some(peer) = server.peers.get(&await_peer) { + peer.has_read_data.notified().await; } }; - match server.peers.get(&peer_addr) { - Some(io) => { - io.has_data_to_read.notify_one(); - io.has_read_data.notified().await; - } - None => { - info!("New UDP connection from {}", peer_addr); - let (udp_client, io) = UdpStream::new( - server.clone_socket(), - peer_addr, - server.cnx_timeout, - Arc::downgrade(&server.keys_to_delete), - ); - io.has_data_to_read.notify_waiters(); - server.peers.insert(peer_addr, io); - return Some((Ok(udp_client), (server, Some(peer_addr)))); + loop { + server.clean_dead_keys(); + let peer_addr = match server.listener.peek_sender().await { + Ok(ret) => ret, + Err(err) => { + error!("Cannot read from UDP server. Closing server: {}", err); + return None; + } + }; + + match server.peers.get(&peer_addr) { + Some(io) => { + io.has_data_to_read.notify_one(); + io.has_read_data.notified().await; + } + None => { + info!("New UDP connection from {}", peer_addr); + let (udp_client, io) = UdpStream::new( + server.clone_socket(), + mk_send_socket(&server.listener).ok()?, + peer_addr, + server.cnx_timeout, + Arc::downgrade(&server.keys_to_delete), + ); + io.has_data_to_read.notify_waiters(); + server.peers.insert(peer_addr, io); + return Some((Ok(udp_client), (server, Some(peer_addr), mk_send_socket))); + } } } - } - }); + }, + ); Ok(stream) } @@ -336,6 +350,75 @@ pub async fn connect(host: &Host, port: u16, connect_timeout: Duration) } } +#[cfg(target_os = "linux")] +pub fn configure_tproxy(listener: &UdpSocket) -> anyhow::Result<()> { + use std::net::IpAddr; + use std::os::fd::AsFd; + + socket2::SockRef::from(&listener).set_ip_transparent(true)?; + match listener.local_addr().unwrap().ip() { + IpAddr::V4(_) => { + nix::sys::socket::setsockopt(&listener.as_fd(), nix::sys::socket::sockopt::Ipv4OrigDstAddr, &true)?; + } + IpAddr::V6(_) => { + nix::sys::socket::setsockopt(&listener.as_fd(), nix::sys::socket::sockopt::Ipv6OrigDstAddr, &true)?; + } + }; + Ok(()) +} + +#[cfg(target_os = "linux")] +pub fn mk_send_socket_tproxy(listener: &Arc) -> anyhow::Result> { + use nix::cmsg_space; + use nix::sys::socket::{ControlMessageOwned, RecvMsg, SockaddrIn}; + use socket2::{Domain, Protocol, SockAddr, Socket, Type}; + use std::io::IoSliceMut; + use std::net::IpAddr; + use std::os::fd::AsRawFd; + + let mut x = cmsg_space!(libc::sockaddr_in6); + let mut buf = [0; 8]; + let mut io = [IoSliceMut::new(&mut buf)]; + let msg: nix::Result> = nix::sys::socket::recvmsg( + listener.as_raw_fd(), + &mut io, + Some(&mut x), + nix::sys::socket::MsgFlags::MSG_PEEK, + ); + + let mut remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); + let msg = msg.unwrap(); + for cmsg in msg.cmsgs() { + match cmsg { + ControlMessageOwned::Ipv4OrigDstAddr(ip) => { + remote_addr = SocketAddr::new( + IpAddr::V4(Ipv4Addr::from(u32::from_be(ip.sin_addr.s_addr))), + u16::from_be(ip.sin_port), + ); + } + ControlMessageOwned::Ipv6OrigDstAddr(ip) => { + remote_addr = SocketAddr::new( + IpAddr::V6(Ipv6Addr::from(u128::from_be_bytes(ip.sin6_addr.s6_addr))), + u16::from_be(ip.sin6_port), + ); + } + _ => { + warn!("Unknown control message {:?}", cmsg); + } + } + } + + let socket = Socket::new(Domain::for_address(remote_addr), Type::DGRAM, Some(Protocol::UDP)).unwrap(); + socket.set_ip_transparent(true).unwrap(); + socket.set_reuse_address(true).unwrap(); + socket.set_reuse_port(true).unwrap(); + socket.bind(&SockAddr::from(remote_addr)).unwrap(); + socket.set_nonblocking(true).unwrap(); + let socket = UdpSocket::from_std(std::net::UdpSocket::from(socket)).unwrap(); + + Ok(Arc::new(socket)) +} + #[cfg(test)] mod tests { use super::*; @@ -347,7 +430,7 @@ mod tests { #[tokio::test] async fn test_udp_server() { let server_addr: SocketAddr = "[::1]:1234".parse().unwrap(); - let server = run_server(server_addr, None).await.unwrap(); + let server = run_server(server_addr, None, |_| Ok(()), |l| Ok(l.clone())).await.unwrap(); pin_mut!(server); // Should timeout @@ -393,7 +476,7 @@ mod tests { #[tokio::test] async fn test_multiple_client() { let server_addr: SocketAddr = "[::1]:1235".parse().unwrap(); - let mut server = Box::pin(run_server(server_addr, None).await.unwrap()); + let mut server = Box::pin(run_server(server_addr, None, |_| Ok(()), |l| Ok(l.clone())).await.unwrap()); // Send some data to the server let client = UdpSocket::bind("[::1]:0").await.unwrap(); @@ -459,7 +542,7 @@ mod tests { async fn test_udp_should_timeout() { let server_addr: SocketAddr = "[::1]:1237".parse().unwrap(); let socket_timeout = Duration::from_secs(1); - let server = run_server(server_addr, Some(socket_timeout)).await.unwrap(); + let server = run_server(server_addr, Some(socket_timeout), |_| Ok(()), |l| Ok(l.clone())).await.unwrap(); pin_mut!(server); // Send some data to the server