diff --git a/src/tcp.rs b/src/tcp.rs index 786c6da..8f704f5 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -5,7 +5,9 @@ use crate::dns::DnsResolver; use base64::Engine; use bytes::BytesMut; use log::warn; +use socket2::TcpKeepalive; use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; + use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpSocket, TcpStream}; @@ -16,22 +18,31 @@ use tracing::{debug, instrument}; use url::{Host, Url}; fn configure_socket(socket: &mut TcpSocket, so_mark: &Option) -> Result<(), anyhow::Error> { + let socket = socket2::SockRef::from(&socket); socket .set_nodelay(true) - .with_context(|| format!("cannot set no_delay on socket: {}", io::Error::last_os_error()))?; + .with_context(|| format!("cannot set no_delay on socket: {:?}", io::Error::last_os_error()))?; + + #[cfg(not(target_os = "windows"))] + let tcp_keepalive = TcpKeepalive::new() + .with_time(Duration::from_secs(60)) + .with_interval(Duration::from_secs(10)) + .with_retries(3); + + #[cfg(target_os = "windows")] + let tcp_keepalive = TcpKeepalive::new() + .with_time(Duration::from_secs(60)) + .with_interval(Duration::from_secs(10)); + + socket + .set_tcp_keepalive(&tcp_keepalive) + .with_context(|| format!("cannot set tcp_keepalive on socket: {:?}", io::Error::last_os_error()))?; #[cfg(target_os = "linux")] if let Some(so_mark) = so_mark { - use std::os::fd::AsFd; - - let ret = nix::sys::socket::setsockopt(&socket.as_fd(), nix::sys::socket::sockopt::Mark, so_mark); - if let Err(err) = ret { - return Err(anyhow!( - "Cannot set SO_MARK on the connection {:?} {:?}", - err, - io::Error::last_os_error() - )); - } + socket + .set_mark(*so_mark) + .with_context(|| format!("cannot set SO_MARK on socket: {:?}", io::Error::last_os_error()))?; } Ok(())