wstunnel/src/tcp.rs
Σrebe - Romain GERARD 8387557459 ground 1
Former-commit-id: d125471a7e73cbde30e1d8cb42a9e6d7aac10131 [formerly d14a90382da6197691d28f61151f1278dca23a53] [formerly 51e8380286fd2a4dc2ff577a507d0df2356b1e79 [formerly 0cd5c5c0eaa4a0538a566ba9e6bb5d925da77c1a]]
Former-commit-id: b4b5769ad601c8cce35047a3e16ff185e228ea41 [formerly 4c6f9e6bd777c187a240b0a6119c2a4eaa396da4]
Former-commit-id: 2aa31860ffe6a5a51f0148527598a7399f968801
Former-commit-id: a826342ca74b913ce45171f92bc6b9d19ac8db08
Former-commit-id: fc039d0217ff4d0c47048755da4f833d86568586
Former-commit-id: d742a7134f042fd67fb1b9399490473babef28d9 [formerly 7e2ea8487c5ce2a5fb39b015eb6d00d8a48654c6]
Former-commit-id: f4273cd0403ef19d4cf19c861fa4d730ab10b29d
2023-10-15 17:56:05 +02:00

110 lines
3.4 KiB
Rust

use anyhow::{anyhow, Context};
use std::{io, vec};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::os::fd::AsRawFd;
use std::time::Duration;
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::time::timeout;
use tokio_stream::wrappers::TcpListenerStream;
use tracing::debug;
use tracing::log::info;
use url::Host;
fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<i32>) -> Result<(), anyhow::Error> {
socket.set_nodelay(true).with_context(|| {
format!(
"cannot set no_delay on socket: {}",
io::Error::last_os_error()
)
})?;
if let Some(so_mark) = so_mark {
unsafe {
let optval: libc::c_int = *so_mark;
let ret = libc::setsockopt(
socket.as_raw_fd(),
libc::SOL_SOCKET,
libc::SO_MARK,
&optval as *const _ as *const libc::c_void,
std::mem::size_of_val(&optval) as libc::socklen_t,
);
if ret != 0 {
return Err(anyhow!(
"Cannot set SO_MARK on the connection {:?}",
io::Error::last_os_error()
));
}
}
}
Ok(())
}
pub async fn connect(
host: &Host<String>,
port: u16,
so_mark: &Option<i32>,
connect_timeout: Duration,
) -> Result<TcpStream, anyhow::Error> {
info!("Opening TCP connection to {}:{}", host, port);
// TODO: Avoid allocation of vec by extracting the code that does the connection in a separate function
let socket_addrs: Vec<SocketAddr> = match host {
Host::Domain(domain) => tokio::net::lookup_host(format!("{}:{}", domain, port))
.await
.with_context(|| format!("cannot resolve domain: {}", domain))?
.collect(),
Host::Ipv4(ip) => vec![SocketAddr::V4(SocketAddrV4::new(*ip, port))],
Host::Ipv6(ip) => vec![SocketAddr::V6(SocketAddrV6::new(*ip, port, 0, 0))],
};
let mut cnx = None;
let mut last_err = None;
for addr in socket_addrs {
debug!("connecting to {}", addr);
let mut socket = match &addr {
SocketAddr::V4(_) => TcpSocket::new_v4()?,
SocketAddr::V6(_) => TcpSocket::new_v6()?,
};
configure_socket(&mut socket, so_mark)?;
match timeout(connect_timeout, socket.connect(addr)).await {
Ok(Ok(stream)) => {
cnx = Some(stream);
break;
}
Ok(Err(err)) => {
debug!("Cannot connect to tcp endpoint {addr} reason {err}");
last_err = Some(err);
}
Err(_) => {
debug!(
"Cannot connect to tcp endpoint {addr} due to timeout of {}s elapsed",
connect_timeout.as_secs()
);
}
}
}
if let Some(cnx) = cnx {
Ok(cnx)
} else {
Err(anyhow!(
"Cannot connect to tcp endpoint {}:{} reason {:?}",
host,
port,
last_err
))
}
}
/// Binds a TCP listener on `bind` and exposes incoming connections as a `TcpListenerStream`.
///
/// # Errors
/// Returns an error, annotated with the requested address, when `TcpListener::bind` fails.
pub async fn run_server(bind: SocketAddr) -> Result<TcpListenerStream, anyhow::Error> {
    info!("Starting TCP server listening cnx on {}", bind);
    TcpListener::bind(bind)
        .await
        .map(TcpListenerStream::new)
        .with_context(|| format!("Cannot create TCP server {:?}", bind))
}