diff --git a/Cargo.toml b/Cargo.toml index 63b17e1..9f742f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,3 +53,4 @@ lto = "fat" panic = "abort" codegen-units = 1 opt-level = 3 +debug = 1 diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 8235d9e..70e5b41 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -5,15 +5,14 @@ use std::sync::Arc; use std::time::Duration; use super::{JwtTunnelConfig, JWT_DECODE}; -use crate::udp::MyUdpSocket; -use crate::{tcp, tls, LocalProtocol, WsServerConfig}; +use crate::{tcp, tls, LocalProtocol, WsServerConfig, udp}; use hyper::server::conn::Http; use hyper::service::service_fn; use hyper::{http, Body, Request, Response, StatusCode}; use jsonwebtoken::TokenData; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::{TcpListener, UdpSocket}; +use tokio::net::{TcpListener}; use tokio::sync::oneshot; use tracing::{error, info, span, warn, Instrument, Level, Span}; use url::Host; @@ -53,16 +52,15 @@ async fn from_query( } match jwt.claims.p { - LocalProtocol::Udp { .. } => { + LocalProtocol::Udp { timeout, .. } => { let host = Host::parse(&jwt.claims.r)?; - let cnx = Arc::new(UdpSocket::bind("[::]:0").await?); - cnx.connect((host.to_string(), jwt.claims.rp)).await?; + let cnx = udp::connect(&host, jwt.claims.rp, timeout.unwrap_or(Duration::from_secs(10))).await?; Ok(( LocalProtocol::Udp { timeout: None }, host, jwt.claims.rp, - Box::pin(MyUdpSocket::new(cnx.clone())), - Box::pin(MyUdpSocket::new(cnx)), + Box::pin(cnx.clone()), + Box::pin(cnx), )) } LocalProtocol::Tcp { .. } => { diff --git a/src/udp.rs b/src/udp.rs index 617cc62..66f54c8 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -1,4 +1,4 @@ -use anyhow::Context; +use anyhow::{anyhow, Context}; use futures_util::{stream, Stream}; use parking_lot::RwLock; @@ -7,19 +7,21 @@ use std::collections::HashMap; use std::future::Future; use std::io; use std::io::{Error, ErrorKind}; -use std::net::SocketAddr; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::pin::{pin, Pin}; use std::sync::{Arc, Weak}; use std::task::{ready, Poll}; use std::time::Duration; +use log::warn; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::UdpSocket; use tokio::sync::futures::Notified; use tokio::sync::Notify; -use tokio::time::Sleep; +use tokio::time::{Sleep, timeout}; use tracing::{debug, error, info}; +use url::Host; struct IoInner { has_data_to_read: Notify, @@ -29,7 +31,7 @@ struct UdpServer { listener: Arc, peers: HashMap, ahash::RandomState>, keys_to_delete: Arc>>, - pub cnx_timeout: Option, + cnx_timeout: Option, } impl UdpServer { @@ -213,6 +215,7 @@ pub async fn run_server( Ok(stream) } +#[derive(Clone)] pub struct MyUdpSocket { socket: Arc, } @@ -245,6 +248,71 @@ impl AsyncWrite for MyUdpSocket { } } +pub async fn connect( + host: &Host, + port: u16, + connect_timeout: Duration, +) -> anyhow::Result { + info!("Opening UDP connection to {}:{}", host, port); + + let socket_addrs: Vec = match host { + Host::Domain(domain) => timeout(connect_timeout, tokio::net::lookup_host(format!("{}:{}", domain, port))) + .await + .with_context(|| format!("cannot resolve domain: {}", domain))?? + .collect(), + Host::Ipv4(ip) => vec![SocketAddr::V4(SocketAddrV4::new(*ip, port))], + Host::Ipv6(ip) => vec![SocketAddr::V6(SocketAddrV6::new(*ip, port, 0, 0))], + }; + + + let mut cnx = None; + let mut last_err = None; + for addr in socket_addrs { + debug!("connecting to {}", addr); + + let socket = match &addr { + SocketAddr::V4(_) => UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)).await, + SocketAddr::V6(_) => UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)).await, + }; + + let socket = match socket { + Ok(socket) => socket, + Err(err) => { + warn!("cannot bind udp socket {:?}", err); + continue; + }, + }; + + match timeout(connect_timeout, socket.connect(addr)).await { + Ok(Ok(_)) => { + cnx = Some(socket); + break; + } + Ok(Err(err)) => { + debug!("Cannot connect udp socket to specified peer {addr} reason {err}"); + last_err = Some(err); + } + Err(_) => { + debug!( + "Cannot connect udp socket to specified peer {addr} due to timeout of {}s elapsed", + connect_timeout.as_secs() + ); + } + } + } + + if let Some(cnx) = cnx { + Ok(MyUdpSocket::new(Arc::new(cnx))) + } else { + Err(anyhow!( + "Cannot connect to udp peer {}:{} reason {:?}", + host, + port, + last_err + )) + } +} + #[cfg(test)] mod tests { use super::*;