refacto: split into modules

This commit is contained in:
Σrebe - Romain GERARD 2024-07-28 13:14:08 +02:00
parent 6a07201de1
commit 38cb7ed5f8
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
35 changed files with 745 additions and 596 deletions

View file

@ -1,18 +1,7 @@
mod dns;
mod embedded_certificate; mod embedded_certificate;
mod http_proxy; mod protocols;
mod restrictions; mod restrictions;
mod socks5;
mod socks5_udp;
mod stdio;
mod tcp;
mod tls;
mod tls_utils;
mod tunnel; mod tunnel;
mod types;
mod udp;
#[cfg(unix)]
mod unix_socket;
use anyhow::anyhow; use anyhow::anyhow;
use base64::Engine; use base64::Engine;
@ -41,16 +30,15 @@ use tokio_rustls::TlsConnector;
use tracing::{error, info}; use tracing::{error, info};
use crate::dns::DnsResolver; use crate::protocols::dns::DnsResolver;
use crate::protocols::udp::MyUdpSocket;
use crate::protocols::{socks5, tls, udp};
use crate::restrictions::types::RestrictionsRules; use crate::restrictions::types::RestrictionsRules;
use crate::tls_utils::{cn_from_certificate, find_leaf_certificate}; use crate::tunnel::listeners::{
new_stdio_listener, new_udp_listener, HttpProxyTunnelListener, Socks5TunnelListener, TcpTunnelListener,
};
use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::tls_reloader::TlsReloader;
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme}; use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme};
use crate::types::{
HttpProxyTunnelListener, Socks5TunnelListener, StdioTunnelListener, TProxyUdpTunnelListener, TcpTunnelListener,
TproxyTcpTunnelListener, UdpTunnelListener, UnixTunnelListener,
};
use crate::udp::MyUdpSocket;
use tracing_subscriber::filter::Directive; use tracing_subscriber::filter::Directive;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use url::{Host, Url}; use url::{Host, Url};
@ -823,7 +811,7 @@ impl WsClientConfig {
} }
#[tokio::main] #[tokio::main]
async fn main() { async fn main() -> anyhow::Result<()> {
let args = Wstunnel::parse(); let args = Wstunnel::parse();
// Setup logging // Setup logging
@ -870,8 +858,8 @@ async fn main() {
// to be the common name (CN) of the client's certificate. // to be the common name (CN) of the client's certificate.
tls_certificate tls_certificate
.as_ref() .as_ref()
.and_then(|certs| find_leaf_certificate(certs.as_slice())) .and_then(|certs| tls::find_leaf_certificate(certs.as_slice()))
.and_then(|leaf_cert| cn_from_certificate(&leaf_cert)) .and_then(|leaf_cert| tls::cn_from_certificate(&leaf_cert))
.unwrap_or(args.http_upgrade_path_prefix) .unwrap_or(args.http_upgrade_path_prefix)
} else { } else {
args.http_upgrade_path_prefix args.http_upgrade_path_prefix
@ -1004,7 +992,7 @@ async fn main() {
let remote = tunnel.remote.clone(); let remote = tunnel.remote.clone();
let cfg = client_config.clone(); let cfg = client_config.clone();
let connect_to_dest = |_| async { let connect_to_dest = |_| async {
tcp::connect( protocols::tcp::connect(
&remote.0, &remote.0,
remote.1, remote.1,
cfg.socket_so_mark, cfg.socket_so_mark,
@ -1081,11 +1069,15 @@ async fn main() {
}; };
match remote.protocol { match remote.protocol {
LocalProtocol::Tcp { proxy_protocol: _ } => { LocalProtocol::Tcp { proxy_protocol: _ } => protocols::tcp::connect(
tcp::connect(&remote.host, remote.port, so_mark, timeout, dns_resolver) &remote.host,
.await remote.port,
.map(|s| Box::new(s) as Box<dyn T>) so_mark,
} timeout,
dns_resolver,
)
.await
.map(|s| Box::new(s) as Box<dyn T>),
LocalProtocol::Udp { .. } => { LocalProtocol::Udp { .. } => {
udp::connect(&remote.host, remote.port, timeout, so_mark, dns_resolver) udp::connect(&remote.host, remote.port, timeout, so_mark, dns_resolver)
.await .await
@ -1125,7 +1117,8 @@ async fn main() {
return Err(anyhow!("Missing remote destination for reverse socks5")); return Err(anyhow!("Missing remote destination for reverse socks5"));
}; };
tcp::connect(&remote.host, remote.port, so_mark, timeout, dns_resolver).await protocols::tcp::connect(&remote.host, remote.port, so_mark, timeout, dns_resolver)
.await
} }
}; };
@ -1143,7 +1136,7 @@ async fn main() {
let remote = tunnel.remote.clone(); let remote = tunnel.remote.clone();
let cfg = client_config.clone(); let cfg = client_config.clone();
let connect_to_dest = |_| async { let connect_to_dest = |_| async {
tcp::connect( protocols::tcp::connect(
&remote.0, &remote.0,
remote.1, remote.1,
cfg.socket_so_mark, cfg.socket_so_mark,
@ -1188,10 +1181,8 @@ async fn main() {
match &tunnel.local_protocol { match &tunnel.local_protocol {
LocalProtocol::Tcp { proxy_protocol } => { LocalProtocol::Tcp { proxy_protocol } => {
let server = tcp::run_server(tunnel.local, false) let server =
.await TcpTunnelListener::new(tunnel.local, tunnel.remote.clone(), *proxy_protocol).await?;
.unwrap_or_else(|err| panic!("Cannot start TCP server on {}: {}", tunnel.local, err));
let server = TcpTunnelListener::new(server, tunnel.remote.clone(), *proxy_protocol);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
error!("{:?}", err); error!("{:?}", err);
@ -1200,10 +1191,8 @@ async fn main() {
} }
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
LocalProtocol::TProxyTcp => { LocalProtocol::TProxyTcp => {
let server = tcp::run_server(tunnel.local, true).await.unwrap_or_else(|err| { use crate::tunnel::listeners::TproxyTcpTunnelListener;
panic!("Cannot start TProxy TCP server on {}: {}", tunnel.local, err) let server = TproxyTcpTunnelListener::new(tunnel.local, false).await?;
});
let server = TproxyTcpTunnelListener::new(server, false); // TODO: support proxy protocol
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
@ -1213,11 +1202,8 @@ async fn main() {
} }
#[cfg(unix)] #[cfg(unix)]
LocalProtocol::Unix { path } => { LocalProtocol::Unix { path } => {
let server = unix_socket::run_server(path).await.unwrap_or_else(|err| { use crate::tunnel::listeners::UnixTunnelListener;
panic!("Cannot start Unix domain server on {}: {}", tunnel.local, err) let server = UnixTunnelListener::new(path, tunnel.remote.clone(), false).await?; // TODO: support proxy protocol
});
let server = UnixTunnelListener::new(server, tunnel.remote.clone(), false); // TODO: support proxy protocol
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
error!("{:?}", err); error!("{:?}", err);
@ -1231,14 +1217,8 @@ async fn main() {
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
LocalProtocol::TProxyUdp { timeout } => { LocalProtocol::TProxyUdp { timeout } => {
let server = use crate::tunnel::listeners::new_tproxy_udp;
udp::run_server(tunnel.local, *timeout, udp::configure_tproxy, udp::mk_send_socket_tproxy) let server = new_tproxy_udp(tunnel.local, *timeout).await?;
.await
.unwrap_or_else(|err| {
panic!("Cannot start TProxy UDP server on {}: {}", tunnel.local, err)
});
let server = TProxyUdpTunnelListener::new(server, *timeout);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
error!("{:?}", err); error!("{:?}", err);
@ -1250,10 +1230,7 @@ async fn main() {
panic!("Transparent proxy is not available for non Linux platform") panic!("Transparent proxy is not available for non Linux platform")
} }
LocalProtocol::Udp { timeout } => { LocalProtocol::Udp { timeout } => {
let server = udp::run_server(tunnel.local, *timeout, |_| Ok(()), |s| Ok(s.clone())) let server = new_udp_listener(tunnel.local, tunnel.remote.clone(), *timeout).await?;
.await
.unwrap_or_else(|err| panic!("Cannot start UDP server on {}: {}", tunnel.local, err));
let server = UdpTunnelListener::new(server, tunnel.remote.clone(), *timeout);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
@ -1262,11 +1239,7 @@ async fn main() {
}); });
} }
LocalProtocol::Socks5 { timeout, credentials } => { LocalProtocol::Socks5 { timeout, credentials } => {
let server = socks5::run_server(tunnel.local, *timeout, credentials.clone()) let server = Socks5TunnelListener::new(tunnel.local, *timeout, credentials.clone()).await?;
.await
.unwrap_or_else(|err| panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err));
let server = Socks5TunnelListener::new(server);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
error!("{:?}", err); error!("{:?}", err);
@ -1278,13 +1251,9 @@ async fn main() {
credentials, credentials,
proxy_protocol, proxy_protocol,
} => { } => {
let server = http_proxy::run_server(tunnel.local, *timeout, credentials.clone()) let server =
.await HttpProxyTunnelListener::new(tunnel.local, *timeout, credentials.clone(), *proxy_protocol)
.unwrap_or_else(|err| { .await?;
panic!("Cannot start http proxy server on {}: {}", tunnel.local, err)
});
let server = HttpProxyTunnelListener::new(server, *proxy_protocol);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
error!("{:?}", err); error!("{:?}", err);
@ -1293,10 +1262,7 @@ async fn main() {
} }
LocalProtocol::Stdio => { LocalProtocol::Stdio => {
let (server, mut handle) = stdio::server::run_server().await.unwrap_or_else(|err| { let (server, mut handle) = new_stdio_listener(tunnel.remote.clone(), false).await?; // TODO: support proxy protocol
panic!("Cannot start STDIO server: {}", err);
});
let server = StdioTunnelListener::new(server, tunnel.remote.clone(), false);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
error!("{:?}", err); error!("{:?}", err);
@ -1410,4 +1376,5 @@ async fn main() {
} }
tokio::signal::ctrl_c().await.unwrap(); tokio::signal::ctrl_c().await.unwrap();
Ok(())
} }

3
src/protocols/dns/mod.rs Normal file
View file

@ -0,0 +1,3 @@
mod resolver;
pub use resolver::DnsResolver;

View file

@ -1,4 +1,4 @@
use crate::tcp; use crate::protocols;
use anyhow::{anyhow, Context}; use anyhow::{anyhow, Context};
use futures_util::{FutureExt, TryFutureExt}; use futures_util::{FutureExt, TryFutureExt};
use hickory_resolver::config::{LookupIpStrategy, NameServerConfig, Protocol, ResolverConfig, ResolverOpts}; use hickory_resolver::config::{LookupIpStrategy, NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
@ -205,7 +205,7 @@ impl RuntimeProvider for TokioRuntimeProviderWithSoMark {
}; };
if let Some(proxy) = &proxy { if let Some(proxy) = &proxy {
tcp::connect_with_http_proxy( protocols::tcp::connect_with_http_proxy(
proxy, proxy,
&host, &host,
server_addr.port(), server_addr.port(),
@ -217,7 +217,7 @@ impl RuntimeProvider for TokioRuntimeProviderWithSoMark {
.map(|s| s.map(AsyncIoTokioAsStd)) .map(|s| s.map(AsyncIoTokioAsStd))
.await .await
} else { } else {
tcp::connect( protocols::tcp::connect(
&host, &host,
server_addr.port(), server_addr.port(),
so_mark, so_mark,
@ -261,7 +261,7 @@ impl RuntimeProvider for TokioRuntimeProviderWithSoMark {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::dns::sort_socket_addrs; use super::*;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
#[test] #[test]

View file

@ -0,0 +1,4 @@
mod server;
pub use server::run_server;
pub use server::HttpProxyListener;

View file

@ -33,7 +33,7 @@ impl Stream for HttpProxyListener {
} }
} }
fn server_client( fn handle_request(
credentials: &Option<String>, credentials: &Option<String>,
dest: &Mutex<(Host, u16)>, dest: &Mutex<(Host, u16)>,
req: Request<Incoming>, req: Request<Incoming>,
@ -112,7 +112,7 @@ pub async fn run_server(
let forward_to = Mutex::new((Host::Ipv4(Ipv4Addr::new(0, 0, 0, 0)), 0)); let forward_to = Mutex::new((Host::Ipv4(Ipv4Addr::new(0, 0, 0, 0)), 0));
let conn_fut = http1.serve_connection( let conn_fut = http1.serve_connection(
hyper_util::rt::TokioIo::new(&mut stream), hyper_util::rt::TokioIo::new(&mut stream),
service_fn(|req| server_client(&auth_header, &forward_to, req)), service_fn(|req| handle_request(&auth_header, &forward_to, req)),
); );
match conn_fut.await { match conn_fut.await {
Ok(_) => return Some((Ok((stream, forward_to.into_inner())), (listener, http1, auth_header))), Ok(_) => return Some((Ok((stream, forward_to.into_inner())), (listener, http1, auth_header))),

9
src/protocols/mod.rs Normal file
View file

@ -0,0 +1,9 @@
pub mod dns;
pub mod http_proxy;
pub mod socks5;
pub mod stdio;
pub mod tcp;
pub mod tls;
pub mod udp;
#[cfg(unix)]
pub mod unix_sock;

View file

@ -0,0 +1,6 @@
mod tcp_server;
mod udp_server;
pub use tcp_server::run_server;
pub use tcp_server::Socks5Listener;
pub use tcp_server::Socks5Stream;

View file

@ -1,5 +1,5 @@
use crate::socks5_udp::Socks5UdpStream; use super::udp_server::Socks5UdpStream;
use crate::{socks5_udp, LocalProtocol}; use crate::LocalProtocol;
use anyhow::Context; use anyhow::Context;
use fast_socks5::server::{Config, DenyAuthentication, SimpleUserPassword, Socks5Server}; use fast_socks5::server::{Config, DenyAuthentication, SimpleUserPassword, Socks5Server};
use fast_socks5::util::target_addr::TargetAddr; use fast_socks5::util::target_addr::TargetAddr;
@ -29,7 +29,7 @@ pub enum Socks5Stream {
impl Socks5Stream { impl Socks5Stream {
pub fn local_protocol(&self) -> LocalProtocol { pub fn local_protocol(&self) -> LocalProtocol {
match self { match self {
Self::Tcp(_) => LocalProtocol::Tcp { proxy_protocol: false }, Self::Tcp(_) => LocalProtocol::Tcp { proxy_protocol: false }, // TODO: Implement proxy protocol
Self::Udp(s) => LocalProtocol::Udp { Self::Udp(s) => LocalProtocol::Udp {
timeout: s.watchdog_deadline.as_ref().map(|x| x.period()), timeout: s.watchdog_deadline.as_ref().map(|x| x.period()),
}, },
@ -72,7 +72,7 @@ pub async fn run_server(
cfg.set_execute_command(false); cfg.set_execute_command(false);
cfg.set_udp_support(true); cfg.set_udp_support(true);
let udp_server = socks5_udp::run_server(bind, timeout).await?; let udp_server = super::udp_server::run_server(bind, timeout).await?;
let server = server.with_config(cfg); let server = server.with_config(cfg);
let stream = stream::unfold((server, Box::pin(udp_server)), move |(server, mut udp_server)| async move { let stream = stream::unfold((server, Box::pin(udp_server)), move |(server, mut udp_server)| async move {
let mut acceptor = server.incoming(); let mut acceptor = server.incoming();

View file

@ -0,0 +1,9 @@
#[cfg(unix)]
mod server_unix;
#[cfg(not(unix))]
mod server_windows;
#[cfg(unix)]
pub use server_unix::run_server;
#[cfg(not(unix))]
pub use server_windows::run_server;

View file

@ -0,0 +1,27 @@
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};
use tokio::sync::oneshot;
use tokio_fd::AsyncFd;
use tracing::info;
pub struct WsStdin {
stdin: AsyncFd,
_receiver: oneshot::Receiver<()>,
}
impl AsyncRead for WsStdin {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
unsafe { self.map_unchecked_mut(|s| &mut s.stdin) }.poll_read(cx, buf)
}
}
pub async fn run_server() -> Result<((WsStdin, AsyncFd), oneshot::Sender<()>), anyhow::Error> {
info!("Starting STDIO server");
let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?;
let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?;
let (tx, rx) = oneshot::channel::<()>();
Ok(((WsStdin { stdin, _receiver: rx }, stdout), tx))
}

View file

@ -0,0 +1,82 @@
use bytes::BytesMut;
use log::error;
use parking_lot::Mutex;
use scopeguard::guard;
use std::io::{Read, Write};
use std::sync::Arc;
use std::{io, thread};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::sync::oneshot;
use tokio::task::LocalSet;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::io::StreamReader;
use tracing::info;
pub async fn run_server() -> Result<((impl AsyncRead, impl AsyncWrite), oneshot::Sender<()>), anyhow::Error> {
info!("Starting STDIO server. Press ctrl+c twice to exit");
crossterm::terminal::enable_raw_mode()?;
let stdin = io::stdin();
let (send, recv) = tokio::sync::mpsc::unbounded_channel();
let (abort_tx, abort_rx) = oneshot::channel::<()>();
let abort_rx = Arc::new(Mutex::new(abort_rx));
let abort_rx2 = abort_rx.clone();
thread::spawn(move || {
let _restore_terminal = guard((), move |_| {
let _ = crossterm::terminal::disable_raw_mode();
abort_rx.lock().close();
});
let stdin = stdin;
let mut stdin = stdin.lock();
let mut buf = [0u8; 65536];
loop {
let n = stdin.read(&mut buf).unwrap_or(0);
if n == 0 || (n == 1 && buf[0] == 3) {
// ctrl+c send char 3
break;
}
if let Err(err) = send.send(Result::<_, io::Error>::Ok(BytesMut::from(&buf[..n]))) {
error!("Failed send inout: {:?}", err);
break;
}
}
});
let stdin = StreamReader::new(UnboundedReceiverStream::new(recv));
let (stdout, mut recv) = tokio::io::duplex(65536);
let rt = tokio::runtime::Handle::current();
thread::spawn(move || {
let task = async move {
let _restore_terminal = guard((), move |_| {
let _ = crossterm::terminal::disable_raw_mode();
abort_rx2.lock().close();
});
let mut stdout = io::stdout().lock();
let mut buf = [0u8; 65536];
loop {
let Ok(n) = recv.read(&mut buf).await else {
break;
};
if n == 0 {
break;
}
if let Err(err) = stdout.write_all(&buf[..n]) {
error!("Failed to write to stdout: {:?}", err);
break;
};
let _ = stdout.flush();
}
};
let local = LocalSet::new();
local.spawn_local(task);
rt.block_on(local);
});
Ok(((stdin, stdout), abort_tx))
}

6
src/protocols/tcp/mod.rs Normal file
View file

@ -0,0 +1,6 @@
mod server;
pub use server::configure_socket;
pub use server::connect;
pub use server::connect_with_http_proxy;
pub use server::run_server;

View file

@ -2,13 +2,13 @@ use anyhow::{anyhow, Context};
use std::{io, vec}; use std::{io, vec};
use tokio::task::JoinSet; use tokio::task::JoinSet;
use crate::dns::DnsResolver;
use base64::Engine; use base64::Engine;
use bytes::BytesMut; use bytes::BytesMut;
use log::warn; use log::warn;
use socket2::{SockRef, TcpKeepalive}; use socket2::{SockRef, TcpKeepalive};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use crate::protocols::dns::DnsResolver;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpSocket, TcpStream}; use tokio::net::{TcpListener, TcpSocket, TcpStream};

10
src/protocols/tls/mod.rs Normal file
View file

@ -0,0 +1,10 @@
mod server;
mod utils;
pub use server::connect;
pub use server::load_certificates_from_pem;
pub use server::load_private_key_from_file;
pub use server::tls_acceptor;
pub use server::tls_connector;
pub use utils::cn_from_certificate;
pub use utils::find_leaf_certificate;

11
src/protocols/udp/mod.rs Normal file
View file

@ -0,0 +1,11 @@
mod server;
#[cfg(target_os = "linux")]
pub use server::configure_tproxy;
pub use server::connect;
#[cfg(target_os = "linux")]
pub use server::mk_send_socket_tproxy;
pub use server::run_server;
pub use server::MyUdpSocket;
pub use server::UdpStream;
pub use server::UdpStreamWriter;

View file

@ -5,9 +5,9 @@ use parking_lot::RwLock;
use pin_project::{pin_project, pinned_drop}; use pin_project::{pin_project, pinned_drop};
use std::collections::HashMap; use std::collections::HashMap;
use std::future::Future; use std::future::Future;
use std::{io, task};
use std::io::{Error, ErrorKind}; use std::io::{Error, ErrorKind};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::{io, task};
use tokio::task::JoinSet; use tokio::task::JoinSet;
use log::warn; use log::warn;
@ -20,7 +20,7 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tokio::sync::futures::Notified; use tokio::sync::futures::Notified;
use crate::dns::DnsResolver; use crate::protocols::dns::DnsResolver;
use tokio::sync::Notify; use tokio::sync::Notify;
use tokio::time::{sleep, timeout, Interval}; use tokio::time::{sleep, timeout, Interval};
use tracing::{debug, error, info}; use tracing::{debug, error, info};
@ -173,11 +173,7 @@ impl UdpStream {
} }
impl AsyncRead for UdpStream { impl AsyncRead for UdpStream {
fn poll_read( fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, obuf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
obuf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let mut project = self.project(); let mut project = self.project();
// Look that the timeout for client has not elapsed // Look that the timeout for client has not elapsed
if let Some(mut deadline) = project.watchdog_deadline.as_pin_mut() { if let Some(mut deadline) = project.watchdog_deadline.as_pin_mut() {

View file

@ -0,0 +1,4 @@
mod server;
pub use server::run_server;
pub use server::UnixListenerStream;

View file

@ -1,116 +0,0 @@
#[cfg(unix)]
pub mod server {
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};
use tokio::sync::oneshot;
use tokio_fd::AsyncFd;
use tracing::info;
pub struct WsStdin {
stdin: AsyncFd,
_receiver: oneshot::Receiver<()>,
}
impl AsyncRead for WsStdin {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
unsafe { self.map_unchecked_mut(|s| &mut s.stdin) }.poll_read(cx, buf)
}
}
pub async fn run_server() -> Result<((WsStdin, AsyncFd), oneshot::Sender<()>), anyhow::Error> {
info!("Starting STDIO server");
let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?;
let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?;
let (tx, rx) = oneshot::channel::<()>();
Ok(((WsStdin { stdin, _receiver: rx }, stdout), tx))
}
}
#[cfg(not(unix))]
pub mod server {
use bytes::BytesMut;
use log::error;
use parking_lot::Mutex;
use scopeguard::guard;
use std::io::{Read, Write};
use std::sync::Arc;
use std::{io, thread};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::sync::oneshot;
use tokio::task::LocalSet;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::io::StreamReader;
use tracing::info;
pub async fn run_server() -> Result<((impl AsyncRead, impl AsyncWrite), oneshot::Sender<()>), anyhow::Error> {
info!("Starting STDIO server. Press ctrl+c twice to exit");
crossterm::terminal::enable_raw_mode()?;
let stdin = io::stdin();
let (send, recv) = tokio::sync::mpsc::unbounded_channel();
let (abort_tx, abort_rx) = oneshot::channel::<()>();
let abort_rx = Arc::new(Mutex::new(abort_rx));
let abort_rx2 = abort_rx.clone();
thread::spawn(move || {
let _restore_terminal = guard((), move |_| {
let _ = crossterm::terminal::disable_raw_mode();
abort_rx.lock().close();
});
let stdin = stdin;
let mut stdin = stdin.lock();
let mut buf = [0u8; 65536];
loop {
let n = stdin.read(&mut buf).unwrap_or(0);
if n == 0 || (n == 1 && buf[0] == 3) {
// ctrl+c send char 3
break;
}
if let Err(err) = send.send(Result::<_, io::Error>::Ok(BytesMut::from(&buf[..n]))) {
error!("Failed send inout: {:?}", err);
break;
}
}
});
let stdin = StreamReader::new(UnboundedReceiverStream::new(recv));
let (stdout, mut recv) = tokio::io::duplex(65536);
let rt = tokio::runtime::Handle::current();
thread::spawn(move || {
let task = async move {
let _restore_terminal = guard((), move |_| {
let _ = crossterm::terminal::disable_raw_mode();
abort_rx2.lock().close();
});
let mut stdout = io::stdout().lock();
let mut buf = [0u8; 65536];
loop {
let Ok(n) = recv.read(&mut buf).await else {
break;
};
if n == 0 {
break;
}
if let Err(err) = stdout.write_all(&buf[..n]) {
error!("Failed to write to stdout: {:?}", err);
break;
};
let _ = stdout.flush();
}
};
let local = LocalSet::new();
local.spawn_local(task);
rt.block_on(local);
});
Ok(((stdin, stdout), abort_tx))
}
}

View file

@ -1,6 +1,6 @@
use super::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE}; use super::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE};
use crate::tunnel::listeners::TunnelListener;
use crate::tunnel::transport::{TunnelReader, TunnelWriter}; use crate::tunnel::transport::{TunnelReader, TunnelWriter};
use crate::types::TunnelListener;
use crate::{tunnel, WsClientConfig}; use crate::{tunnel, WsClientConfig};
use futures_util::pin_mut; use futures_util::pin_mut;
use hyper::header::COOKIE; use hyper::header::COOKIE;
@ -57,8 +57,7 @@ where
Ok(()) Ok(())
} }
pub async fn run_tunnel(client_config: Arc<WsClientConfig>, incoming_cnx: impl TunnelListener) -> anyhow::Result<()> pub async fn run_tunnel(client_config: Arc<WsClientConfig>, incoming_cnx: impl TunnelListener) -> anyhow::Result<()> {
{
pin_mut!(incoming_cnx); pin_mut!(incoming_cnx);
while let Some(cnx) = incoming_cnx.next().await { while let Some(cnx) = incoming_cnx.next().await {
let (cnx_stream, remote_addr) = match cnx { let (cnx_stream, remote_addr) = match cnx {

View file

@ -0,0 +1,54 @@
use crate::protocols::http_proxy;
use crate::protocols::http_proxy::HttpProxyListener;
use crate::tunnel::RemoteAddr;
use crate::LocalProtocol;
use anyhow::{anyhow, Context};
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{ready, Poll};
use std::time::Duration;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio_stream::Stream;
pub struct HttpProxyTunnelListener {
listener: HttpProxyListener,
proxy_protocol: bool,
}
impl HttpProxyTunnelListener {
pub async fn new(
bind_addr: SocketAddr,
timeout: Option<Duration>,
credentials: Option<(String, String)>,
proxy_protocol: bool,
) -> anyhow::Result<Self> {
let listener = http_proxy::run_server(bind_addr, timeout, credentials)
.await
.with_context(|| anyhow!("Cannot start http proxy server on {}", bind_addr))?;
Ok(Self {
listener,
proxy_protocol,
})
}
}
impl Stream for HttpProxyTunnelListener {
type Item = anyhow::Result<((OwnedReadHalf, OwnedWriteHalf), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let ret = ready!(Pin::new(&mut this.listener).poll_next(cx));
let ret = match ret {
Some(Ok((stream, (host, port)))) => {
let protocol = LocalProtocol::Tcp {
proxy_protocol: this.proxy_protocol,
};
Some(anyhow::Ok((stream.into_split(), RemoteAddr { protocol, host, port })))
}
Some(Err(err)) => Some(Err(err)),
None => None,
};
Poll::Ready(ret)
}
}

View file

@ -0,0 +1,44 @@
mod tcp;
#[cfg(target_os = "linux")]
mod tproxy;
mod http_proxy;
mod socks5;
mod stdio;
mod udp;
#[cfg(unix)]
mod unix_sock;
#[cfg(target_os = "linux")]
pub use tproxy::new_tproxy_udp;
#[cfg(target_os = "linux")]
pub use tproxy::TproxyTcpTunnelListener;
pub use http_proxy::HttpProxyTunnelListener;
pub use socks5::Socks5TunnelListener;
pub use stdio::new_stdio_listener;
pub use tcp::TcpTunnelListener;
pub use udp::new_udp_listener;
#[cfg(unix)]
pub use unix_sock::UnixTunnelListener;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_stream::Stream;
pub trait TunnelListener:
Stream<Item = anyhow::Result<((Self::Reader, Self::Writer), crate::tunnel::RemoteAddr)>>
{
type Reader: AsyncRead + Send + 'static;
type Writer: AsyncWrite + Send + 'static;
}
impl<T, R, W> TunnelListener for T
where
T: Stream<Item = anyhow::Result<((R, W), crate::tunnel::RemoteAddr)>>,
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
type Reader = R;
type Writer = W;
}

View file

@ -0,0 +1,47 @@
use crate::protocols::socks5;
use crate::protocols::socks5::{Socks5Listener, Socks5Stream};
use crate::tunnel::RemoteAddr;
use anyhow::{anyhow, Context};
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{ready, Poll};
use std::time::Duration;
use tokio::io::{ReadHalf, WriteHalf};
use tokio_stream::Stream;
pub struct Socks5TunnelListener {
listener: Socks5Listener,
}
impl Socks5TunnelListener {
pub async fn new(
bind_addr: SocketAddr,
timeout: Option<Duration>,
credentials: Option<(String, String)>,
) -> anyhow::Result<Self> {
let listener = socks5::run_server(bind_addr, timeout, credentials)
.await
.with_context(|| anyhow!("Cannot start Socks5 server on {}", bind_addr))?;
Ok(Self { listener })
}
}
impl Stream for Socks5TunnelListener {
type Item = anyhow::Result<((ReadHalf<Socks5Stream>, WriteHalf<Socks5Stream>), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let ret = ready!(Pin::new(&mut this.listener).poll_next(cx));
// TODO: Check if tokio::io::split can be avoided
let ret = match ret {
Some(Ok((stream, (host, port)))) => {
let protocol = stream.local_protocol();
Some(anyhow::Ok((tokio::io::split(stream), RemoteAddr { protocol, host, port })))
}
Some(Err(err)) => Some(Err(err)),
None => None,
};
Poll::Ready(ret)
}
}

View file

@ -0,0 +1,70 @@
use crate::protocols::stdio;
use crate::tunnel::RemoteAddr;
use crate::LocalProtocol;
use anyhow::{anyhow, Context};
use std::pin::Pin;
use std::task::Poll;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::oneshot;
use tokio_stream::Stream;
use url::Host;
pub struct StdioTunnelListener<R, W>
where
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
listener: Option<(R, W)>,
dest: (Host, u16),
proxy_protocol: bool,
}
pub async fn new_stdio_listener(
dest: (Host, u16),
proxy_protocol: bool,
) -> anyhow::Result<(
StdioTunnelListener<impl AsyncRead + Send, impl AsyncWrite + Send>,
oneshot::Sender<()>,
)> {
let (listener, handle) = stdio::run_server()
.await
.with_context(|| anyhow!("Cannot start STDIO server"))?;
Ok((
StdioTunnelListener {
listener: Some(listener),
proxy_protocol,
dest,
},
handle,
))
}
impl<R, W> Stream for StdioTunnelListener<R, W>
where
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
type Item = anyhow::Result<((R, W), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = unsafe { self.get_unchecked_mut() };
let ret = match this.listener.take() {
None => None,
Some(stream) => {
let (host, port) = this.dest.clone();
Some(Ok((
stream,
RemoteAddr {
protocol: LocalProtocol::Tcp {
proxy_protocol: this.proxy_protocol,
},
host,
port,
},
)))
}
};
Poll::Ready(ret)
}
}

View file

@ -0,0 +1,57 @@
use crate::tunnel::RemoteAddr;
use crate::{protocols, LocalProtocol};
use anyhow::{anyhow, Context};
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{ready, Poll};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio_stream::wrappers::TcpListenerStream;
use tokio_stream::Stream;
use url::Host;
pub struct TcpTunnelListener {
listener: TcpListenerStream,
dest: (Host, u16),
proxy_protocol: bool,
}
impl TcpTunnelListener {
pub async fn new(bind_addr: SocketAddr, dest: (Host, u16), proxy_protocol: bool) -> anyhow::Result<Self> {
let listener = protocols::tcp::run_server(bind_addr, false)
.await
.with_context(|| anyhow!("Cannot start TCP server on {}", bind_addr))?;
Ok(Self {
listener,
dest,
proxy_protocol,
})
}
}
impl Stream for TcpTunnelListener {
type Item = anyhow::Result<((OwnedReadHalf, OwnedWriteHalf), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let ret = ready!(Pin::new(&mut this.listener).poll_next(cx));
let ret = match ret {
Some(Ok(strean)) => {
let (host, port) = this.dest.clone();
Some(anyhow::Ok((
strean.into_split(),
RemoteAddr {
protocol: LocalProtocol::Tcp {
proxy_protocol: this.proxy_protocol,
},
host,
port,
},
)))
}
Some(Err(err)) => Some(Err(anyhow::Error::new(err))),
None => None,
};
Poll::Ready(ret)
}
}

View file

@ -0,0 +1,107 @@
use crate::protocols::udp;
use crate::protocols::udp::{UdpStream, UdpStreamWriter};
use crate::tunnel::{to_host_port, RemoteAddr};
use crate::{protocols, LocalProtocol};
use anyhow::{anyhow, Context};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{ready, Poll};
use std::time::Duration;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio_stream::wrappers::TcpListenerStream;
use tokio_stream::Stream;
pub struct TproxyTcpTunnelListener {
listener: TcpListenerStream,
proxy_protocol: bool,
}
impl TproxyTcpTunnelListener {
pub async fn new(bind_addr: SocketAddr, proxy_protocol: bool) -> anyhow::Result<Self> {
let listener = protocols::tcp::run_server(bind_addr, true)
.await
.with_context(|| anyhow!("Cannot start TProxy TCP server on {}", bind_addr))?;
Ok(Self {
listener,
proxy_protocol,
})
}
}
impl Stream for TproxyTcpTunnelListener {
type Item = anyhow::Result<((OwnedReadHalf, OwnedWriteHalf), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let ret = ready!(Pin::new(&mut this.listener).poll_next(cx));
let ret = match ret {
Some(Ok(stream)) => {
let (host, port) = to_host_port(stream.local_addr().unwrap());
Some(anyhow::Ok((
stream.into_split(),
RemoteAddr {
protocol: LocalProtocol::Tcp {
proxy_protocol: this.proxy_protocol,
},
host,
port,
},
)))
}
Some(Err(err)) => Some(Err(anyhow::Error::new(err))),
None => None,
};
Poll::Ready(ret)
}
}
// TPROXY UDP
pub struct TProxyUdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
listener: S,
timeout: Option<Duration>,
}
pub async fn new_tproxy_udp(
bind_addr: SocketAddr,
timeout: Option<Duration>,
) -> anyhow::Result<TProxyUdpTunnelListener<impl Stream<Item = io::Result<UdpStream>>>> {
let listener = udp::run_server(bind_addr, timeout, udp::configure_tproxy, udp::mk_send_socket_tproxy)
.await
.with_context(|| anyhow!("Cannot start TProxy UDP server on {}", bind_addr))?;
Ok(TProxyUdpTunnelListener { listener, timeout })
}
impl<S> Stream for TProxyUdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
type Item = anyhow::Result<((UdpStream, UdpStreamWriter), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = unsafe { self.get_unchecked_mut() };
let ret = ready!(unsafe { Pin::new_unchecked(&mut this.listener) }.poll_next(cx));
let ret = match ret {
Some(Ok(stream)) => {
let (host, port) = to_host_port(stream.local_addr().unwrap());
let stream_writer = stream.writer();
Some(anyhow::Ok((
(stream, stream_writer),
RemoteAddr {
protocol: LocalProtocol::Udp { timeout: this.timeout },
host,
port,
},
)))
}
Some(Err(err)) => Some(Err(anyhow::Error::new(err))),
None => None,
};
Poll::Ready(ret)
}
}

View file

@ -0,0 +1,66 @@
use crate::protocols::udp;
use crate::protocols::udp::{UdpStream, UdpStreamWriter};
use crate::tunnel::RemoteAddr;
use crate::LocalProtocol;
use anyhow::{anyhow, Context};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{ready, Poll};
use std::time::Duration;
use tokio_stream::Stream;
use url::Host;
pub struct UdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
listener: S,
dest: (Host, u16),
timeout: Option<Duration>,
}
pub async fn new_udp_listener(
bind_addr: SocketAddr,
dest: (Host, u16),
timeout: Option<Duration>,
) -> anyhow::Result<UdpTunnelListener<impl Stream<Item = io::Result<UdpStream>>>> {
let listener = udp::run_server(bind_addr, timeout, |_| Ok(()), |s| Ok(s.clone()))
.await
.with_context(|| anyhow!("Cannot start UDP server on {}", bind_addr))?;
Ok(UdpTunnelListener {
listener,
dest,
timeout,
})
}
impl<S> Stream for UdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
type Item = anyhow::Result<((UdpStream, UdpStreamWriter), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = unsafe { self.get_unchecked_mut() };
let ret = ready!(unsafe { Pin::new_unchecked(&mut this.listener) }.poll_next(cx));
let ret = match ret {
Some(Ok(stream)) => {
let (host, port) = this.dest.clone();
let stream_writer = stream.writer();
Some(anyhow::Ok((
(stream, stream_writer),
RemoteAddr {
protocol: LocalProtocol::Udp { timeout: this.timeout },
host,
port,
},
)))
}
Some(Err(err)) => Some(Err(anyhow::Error::new(err))),
None => None,
};
Poll::Ready(ret)
}
}

View file

@ -0,0 +1,58 @@
use crate::protocols::unix_sock;
use crate::protocols::unix_sock::UnixListenerStream;
use crate::tunnel::RemoteAddr;
use crate::LocalProtocol;
use anyhow::{anyhow, Context};
use std::path::Path;
use std::pin::Pin;
use std::task::{ready, Poll};
use tokio::net::unix;
use tokio_stream::Stream;
use url::Host;
pub struct UnixTunnelListener {
listener: UnixListenerStream,
dest: (Host, u16),
proxy_protocol: bool,
}
impl UnixTunnelListener {
pub async fn new(path: &Path, dest: (Host, u16), proxy_protocol: bool) -> anyhow::Result<Self> {
let listener = unix_sock::run_server(path)
.await
.with_context(|| anyhow!("Cannot start Unix domain server on {}", path.display()))?;
Ok(Self {
listener,
dest,
proxy_protocol,
})
}
}
impl Stream for UnixTunnelListener {
type Item = anyhow::Result<((unix::OwnedReadHalf, unix::OwnedWriteHalf), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let ret = ready!(Pin::new(&mut this.listener).poll_next(cx));
let ret = match ret {
Some(Ok(stream)) => {
let stream = stream.into_split();
let (host, port) = this.dest.clone();
Some(anyhow::Ok((
stream,
RemoteAddr {
protocol: LocalProtocol::Tcp {
proxy_protocol: this.proxy_protocol,
},
host,
port,
},
)))
}
Some(Err(err)) => Some(Err(anyhow::Error::new(err))),
None => None,
};
Poll::Ready(ret)
}
}

View file

@ -1,9 +1,11 @@
pub mod client; pub mod client;
pub mod listeners;
pub mod server; pub mod server;
pub mod tls_reloader; pub mod tls_reloader;
mod transport; mod transport;
use crate::{tcp, tls, LocalProtocol, TlsClientConfig, WsClientConfig}; use crate::protocols::tls;
use crate::{protocols, LocalProtocol, TlsClientConfig, WsClientConfig};
use async_trait::async_trait; use async_trait::async_trait;
use bb8::ManageConnection; use bb8::ManageConnection;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation}; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
@ -312,7 +314,7 @@ impl ManageConnection for WsClientConfig {
let timeout = self.timeout_connect; let timeout = self.timeout_connect;
let tcp_stream = if let Some(http_proxy) = &self.http_proxy { let tcp_stream = if let Some(http_proxy) = &self.http_proxy {
tcp::connect_with_http_proxy( protocols::tcp::connect_with_http_proxy(
http_proxy, http_proxy,
self.remote_addr.host(), self.remote_addr.host(),
self.remote_addr.port(), self.remote_addr.port(),
@ -322,7 +324,7 @@ impl ManageConnection for WsClientConfig {
) )
.await? .await?
} else { } else {
tcp::connect( protocols::tcp::connect(
self.remote_addr.host(), self.remote_addr.host(),
self.remote_addr.port(), self.remote_addr.port(),
so_mark, so_mark,

View file

@ -15,7 +15,7 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use super::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; use super::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX};
use crate::{http_proxy, socks5, tcp, tls, udp, LocalProtocol, TlsServerConfig, WsServerConfig}; use crate::{protocols, socks5, LocalProtocol, TlsServerConfig, WsServerConfig};
use hyper::body::{Frame, Incoming}; use hyper::body::{Frame, Incoming};
use hyper::header::{CONTENT_TYPE, COOKIE, SEC_WEBSOCKET_PROTOCOL}; use hyper::header::{CONTENT_TYPE, COOKIE, SEC_WEBSOCKET_PROTOCOL};
use hyper::http::HeaderValue; use hyper::http::HeaderValue;
@ -28,16 +28,16 @@ use once_cell::sync::Lazy;
use parking_lot::Mutex; use parking_lot::Mutex;
use socket2::SockRef; use socket2::SockRef;
use crate::protocols::udp::UdpStream;
use crate::protocols::{http_proxy, tls, udp};
use crate::restrictions::config_reloader::RestrictionsRulesReloader; use crate::restrictions::config_reloader::RestrictionsRulesReloader;
use crate::restrictions::types::{ use crate::restrictions::types::{
AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol, AllowConfig, MatchConfig, RestrictionConfig, RestrictionsRules, ReverseTunnelConfigProtocol, TunnelConfigProtocol,
}; };
use crate::socks5::Socks5Stream; use crate::socks5::Socks5Stream;
use crate::tls_utils::{cn_from_certificate, find_leaf_certificate};
use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::tls_reloader::TlsReloader;
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite};
use crate::udp::UdpStream;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::select; use tokio::select;
@ -68,7 +68,7 @@ async fn run_tunnel(
Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx))) Ok((remote, Box::pin(cnx.clone()), Box::pin(cnx)))
} }
LocalProtocol::Tcp { proxy_protocol } => { LocalProtocol::Tcp { proxy_protocol } => {
let mut socket = tcp::connect( let mut socket = protocols::tcp::connect(
&remote.host, &remote.host,
remote.port, remote.port,
server_config.socket_so_mark, server_config.socket_so_mark,
@ -99,7 +99,7 @@ async fn run_tunnel(
let remote_port = find_mapped_port(remote.port, restriction); let remote_port = find_mapped_port(remote.port, restriction);
let local_srv = (remote.host, remote_port); let local_srv = (remote.host, remote_port);
let bind = format!("{}:{}", local_srv.0, local_srv.1); let bind = format!("{}:{}", local_srv.0, local_srv.1);
let listening_server = tcp::run_server(bind.parse()?, false); let listening_server = protocols::tcp::run_server(bind.parse()?, false);
let tcp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; let tcp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
let (local_rx, local_tx) = tcp.into_split(); let (local_rx, local_tx) = tcp.into_split();
@ -172,7 +172,7 @@ async fn run_tunnel(
} }
#[cfg(unix)] #[cfg(unix)]
LocalProtocol::ReverseUnix { ref path } => { LocalProtocol::ReverseUnix { ref path } => {
use crate::unix_socket; use protocols::unix_sock;
use tokio::net::UnixStream; use tokio::net::UnixStream;
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
@ -181,7 +181,7 @@ async fn run_tunnel(
let remote_port = find_mapped_port(remote.port, restriction); let remote_port = find_mapped_port(remote.port, restriction);
let local_srv = (remote.host, remote_port); let local_srv = (remote.host, remote_port);
let listening_server = unix_socket::run_server(path); let listening_server = unix_sock::run_server(path);
let stream = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; let stream = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
let (local_rx, local_tx) = stream.into_split(); let (local_rx, local_tx) = stream.into_split();
@ -862,7 +862,7 @@ pub async fn run_server(server_config: Arc<WsServerConfig>, restrictions: Restri
} }
}; };
if let Err(err) = tcp::configure_socket(SockRef::from(&stream), &None) { if let Err(err) = protocols::tcp::configure_socket(SockRef::from(&stream), &None) {
warn!("Error while configuring server socket {:?}", err); warn!("Error while configuring server socket {:?}", err);
} }
@ -898,8 +898,8 @@ pub async fn run_server(server_config: Arc<WsServerConfig>, restrictions: Restri
// extract client certificate common name if any // extract client certificate common name if any
let restrict_path = tls_ctx let restrict_path = tls_ctx
.peer_certificates() .peer_certificates()
.and_then(find_leaf_certificate) .and_then(tls::find_leaf_certificate)
.and_then(|c| cn_from_certificate(&c)); .and_then(|c| tls::cn_from_certificate(&c));
match tls_ctx.alpn_protocol() { match tls_ctx.alpn_protocol() {
// http2 // http2
Some(b"h2") => { Some(b"h2") => {

View file

@ -1,5 +1,6 @@
use crate::protocols::tls;
use crate::tunnel::tls_reloader::TlsReloaderState::{Client, Server}; use crate::tunnel::tls_reloader::TlsReloaderState::{Client, Server};
use crate::{tls, WsClientConfig, WsServerConfig}; use crate::{WsClientConfig, WsServerConfig};
use anyhow::Context; use anyhow::Context;
use log::trace; use log::trace;
use notify::{EventKind, RecommendedWatcher, Watcher}; use notify::{EventKind, RecommendedWatcher, Watcher};

View file

@ -1,374 +0,0 @@
use crate::http_proxy::HttpProxyListener;
use crate::socks5::{Socks5Listener, Socks5Stream};
use crate::tunnel::{to_host_port, RemoteAddr};
use crate::udp::{UdpStream, UdpStreamWriter};
use crate::unix_socket::UnixListenerStream;
use crate::LocalProtocol;
use std::io;
use std::pin::Pin;
use std::task::{ready, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::unix;
use tokio_stream::wrappers::TcpListenerStream;
use tokio_stream::Stream;
use url::Host;
pub trait TunnelListener: Stream<Item = anyhow::Result<((Self::Reader, Self::Writer), RemoteAddr)>> {
type Reader: AsyncRead + Send + 'static;
type Writer: AsyncWrite + Send + 'static;
}
impl<T, R, W> TunnelListener for T
where
T: Stream<Item = anyhow::Result<((R, W), RemoteAddr)>>,
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
type Reader = R;
type Writer = W;
}
pub struct TcpTunnelListener {
listener: TcpListenerStream,
dest: (Host, u16),
proxy_protocol: bool,
}
impl TcpTunnelListener {
pub fn new(listener: TcpListenerStream, dest: (Host, u16), proxy_protocol: bool) -> Self {
Self {
listener,
dest,
proxy_protocol,
}
}
}
impl Stream for TcpTunnelListener {
type Item = anyhow::Result<((OwnedReadHalf, OwnedWriteHalf), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let ret = ready!(Pin::new(&mut this.listener).poll_next(cx));
let ret = match ret {
Some(Ok(strean)) => {
let (host, port) = this.dest.clone();
Some(anyhow::Ok((
strean.into_split(),
RemoteAddr {
protocol: LocalProtocol::Tcp {
proxy_protocol: this.proxy_protocol,
},
host,
port,
},
)))
}
Some(Err(err)) => Some(Err(anyhow::Error::new(err))),
None => None,
};
Poll::Ready(ret)
}
}
// TPROXY
pub struct TproxyTcpTunnelListener {
listener: TcpListenerStream,
proxy_protocol: bool,
}
impl TproxyTcpTunnelListener {
pub fn new(listener: TcpListenerStream, proxy_protocol: bool) -> Self {
Self {
listener,
proxy_protocol,
}
}
}
impl Stream for TproxyTcpTunnelListener {
type Item = anyhow::Result<((OwnedReadHalf, OwnedWriteHalf), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let ret = ready!(Pin::new(&mut this.listener).poll_next(cx));
let ret = match ret {
Some(Ok(stream)) => {
let (host, port) = to_host_port(stream.local_addr().unwrap());
Some(anyhow::Ok((
stream.into_split(),
RemoteAddr {
protocol: LocalProtocol::Tcp {
proxy_protocol: this.proxy_protocol,
},
host,
port,
},
)))
}
Some(Err(err)) => Some(Err(anyhow::Error::new(err))),
None => None,
};
Poll::Ready(ret)
}
}
// UNIX
pub struct UnixTunnelListener {
listener: UnixListenerStream,
dest: (Host, u16),
proxy_protocol: bool,
}
impl UnixTunnelListener {
pub fn new(listener: UnixListenerStream, dest: (Host, u16), proxy_protocol: bool) -> Self {
Self {
listener,
dest,
proxy_protocol,
}
}
}
impl Stream for UnixTunnelListener {
type Item = anyhow::Result<((unix::OwnedReadHalf, unix::OwnedWriteHalf), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let ret = ready!(Pin::new(&mut this.listener).poll_next(cx));
let ret = match ret {
Some(Ok(stream)) => {
let stream = stream.into_split();
let (host, port) = this.dest.clone();
Some(anyhow::Ok((
stream,
RemoteAddr {
protocol: LocalProtocol::Tcp {
proxy_protocol: this.proxy_protocol,
},
host,
port,
},
)))
}
Some(Err(err)) => Some(Err(anyhow::Error::new(err))),
None => None,
};
Poll::Ready(ret)
}
}
// TPROXY UDP
pub struct TProxyUdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
listener: S,
timeout: Option<Duration>,
}
impl<S> TProxyUdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
pub fn new(listener: S, timeout: Option<Duration>) -> Self {
Self { listener, timeout }
}
}
impl<S> Stream for TProxyUdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
type Item = anyhow::Result<((UdpStream, UdpStreamWriter), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = unsafe { self.get_unchecked_mut() };
let ret = ready!(unsafe { Pin::new_unchecked(&mut this.listener) }.poll_next(cx));
let ret = match ret {
Some(Ok(stream)) => {
let (host, port) = to_host_port(stream.local_addr().unwrap());
let stream_writer = stream.writer();
Some(anyhow::Ok((
(stream, stream_writer),
RemoteAddr {
protocol: LocalProtocol::Udp { timeout: this.timeout },
host,
port,
},
)))
}
Some(Err(err)) => Some(Err(anyhow::Error::new(err))),
None => None,
};
Poll::Ready(ret)
}
}
pub struct UdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
listener: S,
dest: (Host, u16),
timeout: Option<Duration>,
}
impl<S> UdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
pub fn new(listener: S, dest: (Host, u16), timeout: Option<Duration>) -> Self {
Self {
listener,
dest,
timeout,
}
}
}
impl<S> Stream for UdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
type Item = anyhow::Result<((UdpStream, UdpStreamWriter), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = unsafe { self.get_unchecked_mut() };
let ret = ready!(unsafe { Pin::new_unchecked(&mut this.listener) }.poll_next(cx));
let ret = match ret {
Some(Ok(stream)) => {
let (host, port) = this.dest.clone();
let stream_writer = stream.writer();
Some(anyhow::Ok((
(stream, stream_writer),
RemoteAddr {
protocol: LocalProtocol::Udp { timeout: this.timeout },
host,
port,
},
)))
}
Some(Err(err)) => Some(Err(anyhow::Error::new(err))),
None => None,
};
Poll::Ready(ret)
}
}
pub struct Socks5TunnelListener {
listener: Socks5Listener,
}
impl Socks5TunnelListener {
pub fn new(listener: Socks5Listener) -> Self {
Self { listener }
}
}
impl Stream for Socks5TunnelListener {
type Item = anyhow::Result<((ReadHalf<Socks5Stream>, WriteHalf<Socks5Stream>), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let ret = ready!(Pin::new(&mut this.listener).poll_next(cx));
let ret = match ret {
Some(Ok((stream, (host, port)))) => {
let protocol = stream.local_protocol();
Some(anyhow::Ok((tokio::io::split(stream), RemoteAddr { protocol, host, port })))
}
Some(Err(err)) => Some(Err(err)),
None => None,
};
Poll::Ready(ret)
}
}
pub struct HttpProxyTunnelListener {
listener: HttpProxyListener,
proxy_protocol: bool,
}
impl HttpProxyTunnelListener {
pub fn new(listener: HttpProxyListener, proxy_protocol: bool) -> Self {
Self {
listener,
proxy_protocol,
}
}
}
impl Stream for HttpProxyTunnelListener {
type Item = anyhow::Result<((OwnedReadHalf, OwnedWriteHalf), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let ret = ready!(Pin::new(&mut this.listener).poll_next(cx));
let ret = match ret {
Some(Ok((stream, (host, port)))) => {
let protocol = LocalProtocol::Tcp {
proxy_protocol: this.proxy_protocol,
};
Some(anyhow::Ok((stream.into_split(), RemoteAddr { protocol, host, port })))
}
Some(Err(err)) => Some(Err(err)),
None => None,
};
Poll::Ready(ret)
}
}
pub struct StdioTunnelListener<R, W>
where
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
listener: Option<(R, W)>,
dest: (Host, u16),
proxy_protocol: bool,
}
impl<R, W> StdioTunnelListener<R, W>
where
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
pub fn new(listener: (R, W), dest: (Host, u16), proxy_protocol: bool) -> Self {
Self {
listener: Some(listener),
proxy_protocol,
dest,
}
}
}
impl<R, W> Stream for StdioTunnelListener<R, W>
where
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
type Item = anyhow::Result<((R, W), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = unsafe { self.get_unchecked_mut() };
let ret = match this.listener.take() {
None => None,
Some(stream) => {
let (host, port) = this.dest.clone();
Some(Ok((
stream,
RemoteAddr {
protocol: LocalProtocol::Tcp {
proxy_protocol: this.proxy_protocol,
},
host,
port,
},
)))
}
};
Poll::Ready(ret)
}
}