Former-commit-id: d125471a7e73cbde30e1d8cb42a9e6d7aac10131 [formerly d14a90382da6197691d28f61151f1278dca23a53] [formerly 51e8380286fd2a4dc2ff577a507d0df2356b1e79 [formerly 0cd5c5c0eaa4a0538a566ba9e6bb5d925da77c1a]]
Former-commit-id: b4b5769ad601c8cce35047a3e16ff185e228ea41 [formerly 4c6f9e6bd777c187a240b0a6119c2a4eaa396da4]
Former-commit-id: 2aa31860ffe6a5a51f0148527598a7399f968801
Former-commit-id: a826342ca74b913ce45171f92bc6b9d19ac8db08
Former-commit-id: fc039d0217ff4d0c47048755da4f833d86568586
Former-commit-id: d742a7134f042fd67fb1b9399490473babef28d9 [formerly 7e2ea8487c5ce2a5fb39b015eb6d00d8a48654c6]
Former-commit-id: f4273cd0403ef19d4cf19c861fa4d730ab10b29d
This commit is contained in:
Σrebe - Romain GERARD 2023-10-01 17:16:23 +02:00
parent 8c611e9149
commit 8387557459
13 changed files with 1874 additions and 637 deletions

7
.cargo/config.toml Normal file
View file

@ -0,0 +1,7 @@
[build]
rustflags = ["--cfg", "uuid_unstable"]
#[target.'cfg(target_os = "linux")']
#rustflags = ["-C", "linker=ld.lld", "-C", "relocation-model=static", "-C", "strip=symbols", "--cfg", "uuid_unstable"]
#[build]

27
.gitignore vendored
View file

@ -1,21 +1,10 @@
dist # Generated by Cargo
cabal-dev # will have compiled files and executables
*.o debug/
*.hi target/
*.chi
*.chs.h
.virtualenv
.hsenv
.cabal-sandbox/
cabal.sandbox.config
cabal.config
*.log
tags
bin/
*~
.stack-work
# These are backup files generated by rustfmt
**/*.rs.bk
# Added by cargo # MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
/target

814
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -8,12 +8,36 @@ edition = "2021"
[dependencies] [dependencies]
clap = { version = "4.4.5", features = ["derive"]} clap = { version = "4.4.5", features = ["derive"]}
url = "2.4.1" url = "2.4.1"
anyhow = "1.0.75"
reqwest = { version = "0.11.20", features = ["stream", "trust-dns"] } hyper = { version = "0.14.27", features = ["client", "runtime"] }
hyper = { version = "1.0.0-rc.4", features = ["client", "http2"] } #fastwebsockets = { version = "0.4.4", features = ["upgrade"]}
hyper-openssl = {version = "0.9.2", features = []} fastwebsockets = { git = "https://github.com/mmastrac/fastwebsockets", branch = "split", features = ["upgrade", "simd"]}
libc = { version = "0.2.148", features = []}
once_cell = { version = "1.18.0", features = [] }
ahash = { version = "0.8.3", features = []}
pin-project = "1"
scopeguard = "1.2.0"
uuid = { version = "1.4.1", features = ["v7", "serde"] }
jsonwebtoken = { version = "8.3.0", default-features = false }
rustls-pemfile = { version = "1.0.3", features = [] }
rustls-native-certs = { version = "0.6.3", features = [] }
tokio = { version = "1.32.0", features = ["full"] } tokio = { version = "1.32.0", features = ["full"] }
tokio-rustls = { version = "0.24.1", features = ["tls12", "dangerous_configuration", "early-data"] }
tokio-stream = { version = "0.1.14", features = ["net"] }
tokio-fd = "0.3.0"
futures-util = { version = "0.3.28" }
tracing = { version = "0.1.37", features = ["log"] } tracing = { version = "0.1.37", features = ["log"] }
tracing-subscriber = { version = "0.3.17", features = ["env-filter", "fmt", "local-time"] } tracing-subscriber = { version = "0.3.17", features = ["env-filter", "fmt", "local-time"] }
base64 = "0.21.4"
serde = { version = "1.0.188", features = ["derive"] }
[profile.release]
lto = "fat"
panic = "abort"
codegen-units = 1
opt-level = 3

21
certs/cert.pem Normal file
View file

@ -0,0 +1,21 @@
-----BEGIN CERTIFICATE-----
MIIDgzCCAmugAwIBAgIUSFStqIolH/v5Mp2u8dNw2kHDEUowDQYJKoZIhvcNAQEF
BQAwUDELMAkGA1UEBhMCRlIxDjAMBgNVBAgMBVBhcmlzMQ4wDAYDVQQHDAVQYXJp
czEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMCAXDTIzMTAxNTEz
Mzk1NVoYDzIwNTEwMzAxMTMzOTU1WjBQMQswCQYDVQQGEwJGUjEOMAwGA1UECAwF
UGFyaXMxDjAMBgNVBAcMBVBhcmlzMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRz
IFB0eSBMdGQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQD2fM5T4YzD
a4By7hHMrTL1BYWgr7OUcYDGAuzZiXEKkvc/zuiYHtek2n/hOvZCPp0pba4tOlfM
6BA4qK1cgabH95Q/RfSqttjHoH5hgXUrUZ5YI0n9/7XZRnv5idO5dYqAHElRX70H
YVHzU+xwvzfmWPI6QFMF9lhHQzSQyN7P3iZc97nLtTCxmVg0Wgo103CQ3J8Sop07
+uZuzkPCkgW8eVEoMPTDZ/pChdW2lvJxGs4BQu92UC73XPhRECbAxXA7JwXxegYl
K2pJcNcJWIyRGLEbaVxRMPiBpbIfJbU1nNoSlgGKJb8GuhVK7y4eRxLnOvnLGCFp
dl9c3o6iPYH/AgMBAAGjUzBRMB0GA1UdDgQWBBS4Y1uJ52HbmP1YLWETMcVn7fI/
SzAfBgNVHSMEGDAWgBS4Y1uJ52HbmP1YLWETMcVn7fI/SzAPBgNVHRMBAf8EBTAD
AQH/MA0GCSqGSIb3DQEBBQUAA4IBAQC8R9bx8P1TQsfNIqHhRuSss623VCdPPMgt
uJzXsZVYTfKizIo8nIWpy2y+RpJFpgB26XtrBORwZmc+pDjiABInZxUYoQEMmz7K
gc6OBAeweVD3QNcxqfO+NLft6tP6r3aqDjfF0w358LbuIRGRE34e5wdYBKqNmcu5
Bh9XcWCL7mP3aq+Sl0340Zl+/rPi0sLMNohEYTX6+/XB7qM27Cq/JDJhxGVdKRxO
nv/K02yKpY/C+8tJRJ86v5gTFfDtjGpu9EmDhtGCnpeqX55uE4pgeKUdkNgGviKD
BTizUWSqnkkuqQdZ+DGT4HVXvKHYyWswbHN19huq7SZK17SOz19o
-----END CERTIFICATE-----

28
certs/key.pem Normal file
View file

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQD2fM5T4YzDa4By
7hHMrTL1BYWgr7OUcYDGAuzZiXEKkvc/zuiYHtek2n/hOvZCPp0pba4tOlfM6BA4
qK1cgabH95Q/RfSqttjHoH5hgXUrUZ5YI0n9/7XZRnv5idO5dYqAHElRX70HYVHz
U+xwvzfmWPI6QFMF9lhHQzSQyN7P3iZc97nLtTCxmVg0Wgo103CQ3J8Sop07+uZu
zkPCkgW8eVEoMPTDZ/pChdW2lvJxGs4BQu92UC73XPhRECbAxXA7JwXxegYlK2pJ
cNcJWIyRGLEbaVxRMPiBpbIfJbU1nNoSlgGKJb8GuhVK7y4eRxLnOvnLGCFpdl9c
3o6iPYH/AgMBAAECggEALdKa6uYh7Ix2JyeSAIpsUDe0FWDEkkKdjXIqxPAzpyMW
OvMEs478SOXj4yO6dys7vWFqAXd4rhuwNFBLVki2EDO7CB5Bs2DloQr5o7fU5/Y2
6Sy6SzF4BYoAby4Lwc0Tr+hSSwHw2sfhW8qMyJML2dNMSL7/kDqxQ6I/SfFF1r+Y
3LKaC98/jxiIco5Cgabd065x2NVOshWzIkY5xuvCYjfQlEJWKHbuxIIrcJXApEe4
pmexuK8VVb/Prm6Ci1+hsWOgkXuv/3EUZNxeQ87kek7Ggw3U1CXLAJ5H+FuNEVuy
mbmfX34GwKeC9tq/4zQifFS0BLaP4ND3AAo2rTvgDQKBgQD8ytKguW9yLVWcX7dl
ncEXbSKEycMfrvqJw8NvTt/9O/Uto2ri1JAmFBJ5m9tPZpujOSkdPXRmFx+lFyt+
XlkJrn1BYfZYoxMkUp5qbVF3tk32mLZM0K1yyxb5XEfcykbA1S8eLDt7F4h558Sp
e2+K60klFDB6b4Yil/a9aN0QIwKBgQD5nYCvKStaw+3YnX7TnuycbfedYiMAX7kC
1O7HGzEo+gDm4wvCF2pXPsSNJCG9c9KIpnQQrtL68MtAlVgRGa/GNM4emEVgpLtE
W6SFVkEBb3JYY1JAB5umxwO0TcFwTn3ivy2QtrFphWSG8gsZnMIrIfui3X3llY6J
Cu8iqvd2dQKBgD3R7/6EOr/mXEhYlAYStTTgaI+ms8Qcy4JDUJj45ggM0KGvlCUS
rInTYM1CkzhwtGEPSoGvFLcesotyBh3qPsYCWPlTVqZIgxbf6YPHZiPrfldu8y4H
3lLzXZPvwFc7VGA2AkbTtFwe3i5Jwqtb12RWs9WQgWZ/vYLaPOoHKgCXAoGBAKY/
uZJwCBkWv5XzJ6JIieyR7UZcM1Wva2iwaywfNznEcM9WTuGBeOkMvBoJA5PLzWAI
BOuLlKdfsu+byCDzi7emOdX0sthwPu2DX+sSjI8pK+4kkIZmystkZ1oyI3DqRju7
+twUYcsW9eJO2QfA+S2DH7bUcGJ1no41wxnC5rh1AoGBALjSLlnJtumrsiZLUeRB
rf4n2QIiXd+7tn2ZKSNYn/621hCHwl6peAI3G57Avxixf7Yv4wAOaIgSWUrenmJ9
DbwM06AZqnLwraM1B6c3tujOfjIFU7IjENdELxrTnVVYq9a7HDJzT6ty7m1A7eDA
9cEnUJ+dG5G05eV44S4Z/6sF
-----END PRIVATE KEY-----

View file

@ -0,0 +1,16 @@
use once_cell::sync::Lazy;
use tokio_rustls::rustls::{Certificate, PrivateKey};
pub static TLS_PRIVATE_KEY: Lazy<PrivateKey> = Lazy::new(|| {
let key = include_bytes!("../certs/key.pem");
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut key.as_slice())
.expect("failed to load embedded tls private key");
PrivateKey(keys.remove(0))
});
pub static TLS_CERTIFICATE: Lazy<Vec<Certificate>> = Lazy::new(|| {
let cert = include_bytes!("../certs/cert.pem");
let certs = rustls_pemfile::certs(&mut cert.as_slice())
.expect("failed to load embedded tls certificate");
certs.into_iter().map(Certificate).collect()
});

View file

@ -1,20 +1,42 @@
mod embedded_certificate;
mod stdio;
mod tcp;
mod tls;
mod transport;
mod udp;
use base64::Engine;
use clap::Parser;
use futures_util::{pin_mut, stream, Stream, StreamExt, TryStreamExt};
use hyper::http::HeaderValue;
use serde::{Deserialize, Serialize};
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::BTreeMap; use std::collections::{BTreeMap, HashMap};
use std::io;
use std::io::ErrorKind; use std::io::ErrorKind;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::path::PathBuf;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use clap::Parser; use tokio::io::{AsyncRead, AsyncWrite};
use hyper::body::Body; use tokio::net::TcpStream;
use hyper::Request;
use hyper_openssl::HttpsConnector;
use url::{Host, Url, UrlQuery};
/// Simple program to greet a person use tokio_rustls::rustls::server::DnsName;
use tokio_rustls::rustls::{Certificate, PrivateKey, ServerName};
use tracing::{debug, error, field, instrument, Instrument, Span};
use tracing_subscriber::EnvFilter;
use url::{Host, Url};
use uuid::Uuid;
/// Use the websockets protocol to tunnel {TCP,UDP} traffic
/// wsTunnelClient <---> wsTunnelServer <---> RemoteHost
/// Use secure connection (wss://) to bypass proxies
#[derive(clap::Parser, Debug)] #[derive(clap::Parser, Debug)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, verbatim_doc_comment, long_about = None)]
struct Wstunnel { struct Wstunnel {
#[command(subcommand)] #[command(subcommand)]
commands: Commands, commands: Commands,
} }
@ -22,139 +44,567 @@ struct Wstunnel {
#[derive(clap::Subcommand, Debug)] #[derive(clap::Subcommand, Debug)]
enum Commands { enum Commands {
Client(Client), Client(Client),
Server(Server) Server(Server),
} }
#[derive(clap::Args, Debug)] #[derive(clap::Args, Debug)]
struct Client { struct Client {
/// Name of the person to greet /// Listen on local and forwards traffic from remote
#[arg(short='L', long, value_name = "[BIND:]PORT:HOST:PORT", value_parser = parse_env_var)] /// Can be specified multiple times
#[arg(short='L', long, value_name = "{tcp,udp}://[BIND:]PORT:HOST:PORT", value_parser = parse_env_var)]
local_to_remote: Vec<LocalToRemote>, local_to_remote: Vec<LocalToRemote>,
/// (linux only) Mark network packet with SO_MARK sockoption with the specified value.
/// You need to use {root, sudo, capabilities} to run wstunnel when using this option
#[arg(long, value_name = "INT", verbatim_doc_comment)]
socket_so_mark: Option<u32>,
/// Domain name that will be use as SNI during TLS handshake
/// Warning: If you are behind a CDN (i.e: Cloudflare) you must set this domain also in the http HOST header.
/// or it will be flag as fishy as your request rejected
#[arg(long, value_name = "DOMAIN_NAME", value_parser = parse_sni_override, verbatim_doc_comment)]
tls_sni_override: Option<DnsName>,
/// Enable TLS certificate verification.
/// Disabled by default. The client will happily connect to any server with self signed certificate.
#[arg(long, verbatim_doc_comment)]
tls_verify_certificate: bool,
/// Use a specific prefix that will show up in the http path during the upgrade request.
/// Useful if you need to route requests server side but don't have vhosts
#[arg(long, default_value = "morille", verbatim_doc_comment)]
http_upgrade_path_prefix: String,
/// Pass authorization header with basic auth credentials during the upgrade request.
/// If you need more customization, you can use the http_headers option.
#[arg(long, value_name = "USER[:PASS]", value_parser = parse_http_credentials, verbatim_doc_comment)]
http_upgrade_credentials: Option<HeaderValue>,
/// Frequency at which the client will send websocket ping to the server.
#[arg(long, value_name = "seconds", default_value = "30", value_parser = parse_duration_sec, verbatim_doc_comment)]
websocket_ping_frequency_sec: Option<Duration>,
/// Enable the masking of websocket frames. Default is false
/// Enable this option only if you use unsecure (non TLS) websocket server and you see some issues. Otherwise, it is just overhead.
#[arg(long, default_value = "false", verbatim_doc_comment)]
websocket_mask_frame: bool,
/// Send custom headers in the upgrade request
/// Can be specified multiple time
#[arg(short='H', long, value_name = "HEADER_NAME: HEADER_VALUE", value_parser = parse_http_headers, verbatim_doc_comment)]
http_headers: Vec<(String, HeaderValue)>,
/// Address of the wstunnel server
/// Example: With TLS wss://wstunnel.example.com or without ws://wstunnel.example.com
#[arg(value_name = "ws[s]://wstunnel.server.com[:port]", value_parser = parse_server_url, verbatim_doc_comment)]
remote_addr: Url,
} }
#[derive(clap::Args, Debug)] #[derive(clap::Args, Debug)]
struct Server { struct Server {
/// Name of the person to greet /// Address of the wstunnel server to bind to
#[arg(short='L', long, value_name = "[BIND:]PORT:HOST:PORT", value_parser = parse_env_var)] /// Example: With TLS wss://0.0.0.0:8080 or without ws://[::]:8080
local_to_remote: String, #[arg(value_name = "ws[s]://0.0.0.0[:port]", value_parser = parse_server_url, verbatim_doc_comment)]
remote_addr: Url,
/// (linux only) Mark network packet with SO_MARK sockoption with the specified value.
/// You need to use {root, sudo, capabilities} to run wstunnel when using this option
#[arg(long, value_name = "INT", verbatim_doc_comment)]
socket_so_mark: Option<i32>,
/// Frequency at which the server will send websocket ping to client.
#[arg(long, value_name = "seconds", value_parser = parse_duration_sec, verbatim_doc_comment)]
websocket_ping_frequency_sec: Option<Duration>,
/// Enable the masking of websocket frames. Default is false
/// Enable this option only if you use unsecure (non TLS) websocket server and you see some issues. Otherwise, it is just overhead.
#[arg(long, default_value = "false", verbatim_doc_comment)]
websocket_mask_frame: bool,
/// Server will only accept connection from the specified tunnel information.
/// Can be specified multiple time
/// Example: --restrict-to "google.com:443" --restrict-to "localhost:22"
#[arg(long, value_name = "DEST:PORT", verbatim_doc_comment)]
restrict_to: Option<Vec<String>>,
/// [Optional] Use custom certificate (.crt) instead of the default embedded self signed certificate.
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
tls_certificate: Option<PathBuf>,
/// [Optional] Use a custom tls key (.key) that the server will use instead of the default embedded one
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
tls_private_key: Option<PathBuf>,
} }
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
enum L4Protocol { enum L4Protocol {
TCP, UDP { timeout: Duration } Tcp,
Udp { timeout: Option<Duration> },
Stdio,
} }
impl L4Protocol { impl L4Protocol {
fn new_udp() -> L4Protocol { fn new_udp() -> L4Protocol {
L4Protocol::UDP { timeout: Duration::from_secs(30) } L4Protocol::Udp {
timeout: Some(Duration::from_secs(30)),
}
} }
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
struct LocalToRemote { pub struct LocalToRemote {
socket_so_mark: Option<i32>,
protocol: L4Protocol, protocol: L4Protocol,
local: SocketAddr, local: SocketAddr,
remote: (Host<String>, u16), remote: (Host<String>, u16),
} }
fn parse_env_var(arg: &str) -> Result<LocalToRemote, std::io::Error> { fn parse_duration_sec(arg: &str) -> Result<Duration, io::Error> {
use std::io::Error;
let Ok(secs) = arg.parse::<u64>() else {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("cannot duration of seconds from {}", arg),
));
};
Ok(Duration::from_secs(secs))
}
fn parse_env_var(arg: &str) -> Result<LocalToRemote, io::Error> {
use std::io::Error; use std::io::Error;
let (mut protocol, arg) = match &arg[..6] { let (mut protocol, arg) = match &arg[..6] {
"tcp://" => (L4Protocol::TCP, &arg[6..]), "tcp://" => (L4Protocol::Tcp, &arg[6..]),
"udp://" => (L4Protocol::new_udp(), &arg[6..]), "udp://" => (L4Protocol::new_udp(), &arg[6..]),
_ => (L4Protocol::TCP, arg) _ => match &arg[..8] {
"stdio://" => (L4Protocol::Stdio, &arg[8..]),
_ => (L4Protocol::Tcp, arg),
},
}; };
let (bind, remaining) = if arg.starts_with('[') { let (bind, remaining) = if arg.starts_with('[') {
// ipv6 bind // ipv6 bind
let Some((ipv6_str, remaining)) = arg.split_once(']') else { let Some((ipv6_str, remaining)) = arg.split_once(']') else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse IPv6 bind from {}", arg))); return Err(Error::new(
ErrorKind::InvalidInput,
format!("cannot parse IPv6 bind from {}", arg),
));
}; };
let Ok(ipv6_addr) = Ipv6Addr::from_str(&ipv6_str[1..]) else { let Ok(ipv6_addr) = Ipv6Addr::from_str(&ipv6_str[1..]) else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse IPv6 bind from {}", ipv6_str))); return Err(Error::new(
ErrorKind::InvalidInput,
format!("cannot parse IPv6 bind from {}", ipv6_str),
));
}; };
(IpAddr::V6(ipv6_addr), remaining) (IpAddr::V6(ipv6_addr), remaining)
} else { } else {
// Maybe ipv4 addr // Maybe ipv4 addr
let Some((ipv4_str, remaining)) = arg.split_once(':') else { let Some((ipv4_str, remaining)) = arg.split_once(':') else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse IPv4 bind from {}", arg))); return Err(Error::new(
ErrorKind::InvalidInput,
format!("cannot parse IPv4 bind from {}", arg),
));
}; };
match Ipv4Addr::from_str(ipv4_str) { match Ipv4Addr::from_str(ipv4_str) {
Ok(ip4_addr) => (IpAddr::V4(ip4_addr), remaining), Ok(ip4_addr) => (IpAddr::V4(ip4_addr), remaining),
// Must be the port, so we default to ipv6 bind // Must be the port, so we default to ipv6 bind
Err(_) => (IpAddr::V6(Ipv6Addr::from_str("::1").unwrap()), arg) Err(_) => (IpAddr::V4(Ipv4Addr::from_str("127.0.0.1").unwrap()), arg),
} }
}; };
let Some((port_str, remaining)) = remaining.trim_start_matches(':').split_once(':') else { let Some((port_str, remaining)) = remaining.trim_start_matches(':').split_once(':') else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse bind port from {}", remaining))); return Err(Error::new(
ErrorKind::InvalidInput,
format!("cannot parse bind port from {}", remaining),
));
}; };
let Ok(bind_port): Result<u16, _> = port_str.parse() else { let Ok(bind_port): Result<u16, _> = port_str.parse() else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse bind port from {}", port_str))); return Err(Error::new(
ErrorKind::InvalidInput,
format!("cannot parse bind port from {}", port_str),
));
}; };
let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else { let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse remote from {}", remaining))); return Err(Error::new(
ErrorKind::InvalidInput,
format!("cannot parse remote from {}", remaining),
));
}; };
let Some(remote_host) = remote.host() else { let Some(remote_host) = remote.host() else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse remote host from {}", remaining))); return Err(Error::new(
ErrorKind::InvalidInput,
format!("cannot parse remote host from {}", remaining),
));
}; };
let Some(remote_port) = remote.port() else { let Some(remote_port) = remote.port() else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse remote port from {}", remaining))); return Err(Error::new(
ErrorKind::InvalidInput,
format!("cannot parse remote port from {}", remaining),
));
}; };
match &mut protocol {
L4Protocol::TCP => {}
L4Protocol::UDP { ref mut timeout, .. } => {
let options: BTreeMap<Cow<'_, str>, Cow<'_, str>> = remote.query_pairs().collect(); let options: BTreeMap<Cow<'_, str>, Cow<'_, str>> = remote.query_pairs().collect();
if let Some(duration) = options.get("timeout_sec") match &mut protocol {
L4Protocol::Stdio => {}
L4Protocol::Tcp => {}
L4Protocol::Udp {
ref mut timeout, ..
} => {
if let Some(duration) = options
.get("timeout_sec")
.and_then(|x| x.parse::<u64>().ok()) .and_then(|x| x.parse::<u64>().ok())
.map(|x| Duration::from_secs(x)) { .map(|d| {
if d == 0 {
None
} else {
Some(Duration::from_secs(d))
}
})
{
*timeout = duration; *timeout = duration;
} }
} }
}; };
Ok(LocalToRemote { Ok(LocalToRemote {
socket_so_mark: options
.get("socket_so_mark")
.and_then(|x| x.parse::<i32>().ok()),
protocol, protocol,
local: SocketAddr::new(bind, bind_port), local: SocketAddr::new(bind, bind_port),
remote: (remote_host.to_owned(), remote_port) remote: (remote_host.to_owned(), remote_port),
}) })
} }
fn main() { fn parse_sni_override(arg: &str) -> Result<DnsName, io::Error> {
println!("Hello, world!"); match DnsName::try_from(arg.to_string()) {
Ok(val) => Ok(val),
Err(err) => Err(io::Error::new(
ErrorKind::InvalidInput,
format!("Invalid sni override: {}", err),
)),
}
}
fn parse_http_headers(arg: &str) -> Result<(String, HeaderValue), io::Error> {
let Some((key, value)) = arg.split_once(':') else {
return Err(io::Error::new(
ErrorKind::InvalidInput,
format!("cannot parse http header from {}", arg),
));
};
let value = match HeaderValue::from_str(value.trim()) {
Ok(value) => value,
Err(err) => {
return Err(io::Error::new(
ErrorKind::InvalidInput,
format!(
"cannot parse http header value from {} due to {:?}",
value, err
),
))
}
};
Ok((key.to_owned(), value))
}
fn parse_http_credentials(arg: &str) -> Result<HeaderValue, io::Error> {
let encoded = base64::engine::general_purpose::STANDARD.encode(arg.trim().as_bytes());
let Ok(header) = HeaderValue::from_str(&format!("Basic {}", encoded)) else {
return Err(io::Error::new(
ErrorKind::InvalidInput,
format!("cannot parse http credentials {}", arg),
));
};
Ok(header)
}
fn parse_server_url(arg: &str) -> Result<Url, io::Error> {
let Ok(url) = Url::parse(arg) else {
return Err(io::Error::new(
ErrorKind::InvalidInput,
format!("cannot parse server url {}", arg),
));
};
if url.scheme() != "ws" && url.scheme() != "wss" {
return Err(io::Error::new(
ErrorKind::InvalidInput,
format!("invalid scheme {}", url.scheme()),
));
}
if url.host().is_none() {
return Err(io::Error::new(
ErrorKind::InvalidInput,
format!("invalid server host {}", arg),
));
}
Ok(url)
}
#[derive(Clone, Debug)]
pub struct TlsClientConfig {
pub tls_sni_override: Option<DnsName>,
pub tls_verify_certificate: bool,
}
#[derive(Clone, Debug)]
pub struct TlsServerConfig {
pub tls_certificate: Vec<Certificate>,
pub tls_key: PrivateKey,
}
#[derive(Clone, Debug)]
pub struct WsServerConfig {
pub socket_so_mark: Option<i32>,
pub bind: SocketAddr,
pub restrict_to: Option<Vec<String>>,
pub websocket_ping_frequency: Option<Duration>,
pub timeout_connect: Duration,
pub websocket_mask_frame: bool,
pub tls: Option<TlsServerConfig>,
}
#[derive(Clone, Debug)]
pub struct WsClientConfig {
pub remote_addr: (Host<String>, u16),
pub tls: Option<TlsClientConfig>,
pub http_upgrade_path_prefix: String,
pub http_upgrade_credentials: Option<HeaderValue>,
pub http_headers: HashMap<String, HeaderValue>,
pub timeout_connect: Duration,
pub websocket_ping_frequency: Duration,
pub websocket_mask_frame: bool,
}
impl WsClientConfig {
pub fn websocket_scheme(&self) -> &'static str {
match self.tls {
None => "ws",
Some(_) => "wss",
}
}
pub fn websocket_host_url(&self) -> String {
format!("{}:{}", self.remote_addr.0, self.remote_addr.1)
}
pub fn tls_server_name(&self) -> ServerName {
match self
.tls
.as_ref()
.and_then(|tls| tls.tls_sni_override.as_ref())
{
None => match &self.remote_addr.0 {
Host::Domain(domain) => {
ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap())
}
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip)),
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip)),
},
Some(sni_override) => ServerName::DnsName(sni_override.clone()),
}
}
}
#[tokio::main]
async fn main() {
let args = Wstunnel::parse(); let args = Wstunnel::parse();
println!("Hello {:?}!", args) // Setup logging
match &args.commands {
// Disable logging if there is a stdio tunnel
Commands::Client(args)
if args
.local_to_remote
.iter()
.filter(|x| x.protocol == L4Protocol::Stdio)
.count()
> 0 => {}
_ => {
tracing_subscriber::fmt()
.with_ansi(true)
.with_env_filter(EnvFilter::from_default_env())
.init();
}
}
let client = reqwest::Client::builder() match args.commands {
.timeout(Duration::from_secs(10)) Commands::Client(args) => {
.connect_timeout(Duration::from_secs(10)) let tls = match args.remote_addr.scheme() {
.danger_accept_invalid_certs(true) "ws" => None,
.build().unwrap(); "wss" => Some(TlsClientConfig {
tls_sni_override: args.tls_sni_override,
tls_verify_certificate: args.tls_verify_certificate,
}),
_ => panic!("invalid scheme in server url {}", args.remote_addr.scheme()),
};
let server_config = Arc::new(WsClientConfig {
let mut conn = HttpsConnector::new()?; remote_addr: (
conn.set_callback(move |c, _| { args.remote_addr.host().unwrap().to_owned(),
// Prevent native TLS lib from inferring and verifying a default SNI. args.remote_addr.port_or_known_default().unwrap(),
c.set_use_server_name_indication(false); ),
c.set_verify_hostname(false); tls,
http_upgrade_path_prefix: args.http_upgrade_path_prefix,
// And set a custom SNI instead. http_upgrade_credentials: args.http_upgrade_credentials,
c.set_hostname("somewhere.com") http_headers: args.http_headers.into_iter().collect(),
timeout_connect: Duration::from_secs(10),
websocket_ping_frequency: args
.websocket_ping_frequency_sec
.unwrap_or(Duration::from_secs(30)),
websocket_mask_frame: args.websocket_mask_frame,
}); });
Client::builder()
.build::<_, Body>(conn)
.request(Request::get("somewhere-else.com").body(())?)
.await?;
reqwest::Proxy::all("https://google.com").unwrap().basic_auth("", "") // Start tunnels
for tunnel in args.local_to_remote.into_iter() {
let server_config = server_config.clone();
match &tunnel.protocol {
L4Protocol::Tcp => {
let server = tcp::run_server(tunnel.local)
.await
.unwrap_or_else(|err| {
panic!("Cannot start TCP server on {}: {}", tunnel.local, err)
})
.map_ok(TcpStream::into_split);
tokio::spawn(async move {
if let Err(err) = run_tunnel(server_config, tunnel, server).await {
error!("{:?}", err);
}
});
}
L4Protocol::Udp { timeout } => {
let server = udp::run_server(tunnel.local, *timeout)
.await
.unwrap_or_else(|err| {
panic!("Cannot start UDP server on {}: {}", tunnel.local, err)
})
.map_ok(tokio::io::split);
tokio::spawn(async move {
if let Err(err) = run_tunnel(server_config, tunnel, server).await {
error!("{:?}", err);
}
});
}
L4Protocol::Stdio => {
let server = stdio::run_server().await.unwrap_or_else(|err| {
panic!("Cannot start STDIO server: {}", err);
});
tokio::spawn(async move {
if let Err(err) = run_tunnel(
server_config,
tunnel,
stream::once(async move { Ok(server) }),
)
.await
{
error!("{:?}", err);
}
});
}
}
}
}
Commands::Server(args) => {
let tls_config = if args.remote_addr.scheme() == "wss" {
let tls_certificate = if let Some(cert_path) = args.tls_certificate {
tls::load_certificates_from_pem(&cert_path)
.expect("Cannot load tls certificate")
} else {
embedded_certificate::TLS_CERTIFICATE.clone()
};
let tls_key = if let Some(key_path) = args.tls_private_key {
tls::load_private_key_from_file(&key_path).expect("Cannot load tls private key")
} else {
embedded_certificate::TLS_PRIVATE_KEY.clone()
};
Some(TlsServerConfig {
tls_certificate,
tls_key,
})
} else {
None
};
let server_config = WsServerConfig {
socket_so_mark: args.socket_so_mark,
bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0],
restrict_to: args.restrict_to,
websocket_ping_frequency: args.websocket_ping_frequency_sec,
timeout_connect: Duration::from_secs(10),
websocket_mask_frame: args.websocket_mask_frame,
tls: tls_config,
};
debug!("{:?}", server_config);
transport::run_server(Arc::new(server_config))
.await
.unwrap_or_else(|err| {
panic!("Cannot start wstunnel server: {:?}", err);
});
}
}
tokio::signal::ctrl_c().await.unwrap();
}
#[instrument(name="tunnel", level="info", skip_all, fields(id=field::Empty, remote=field::Empty))]
async fn run_tunnel<T, R, W>(
server_config: Arc<WsClientConfig>,
tunnel: LocalToRemote,
incoming_cnx: T,
) -> anyhow::Result<()>
where
T: Stream<Item = io::Result<(R, W)>>,
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
let span = Span::current();
let request_id = Uuid::now_v7();
span.record("id", request_id.to_string());
span.record(
"remote",
&format!("{}:{}", tunnel.remote.0, tunnel.remote.1),
);
let tunnel = Arc::new(tunnel);
pin_mut!(incoming_cnx);
while let Some(Ok(cnx_stream)) = incoming_cnx.next().await {
let server_config = server_config.clone();
let tunnel = tunnel.clone();
tokio::spawn(
async move {
let ret =
transport::connect_to_server(request_id, &server_config, &tunnel, cnx_stream)
.await;
if let Err(ret) = ret {
error!("{:?}", ret);
}
anyhow::Ok(())
}
.instrument(span.clone()),
);
}
Ok(())
} }

19
src/stdio.rs Normal file
View file

@ -0,0 +1,19 @@
#![allow(unused_imports)]
use libc::STDIN_FILENO;
use std::os::fd::{AsRawFd, FromRawFd};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::fs::File;
use tokio::io::{stdout, AsyncRead, ReadBuf, Stdout};
use tokio_fd::AsyncFd;
use tracing::info;
pub async fn run_server() -> Result<(AsyncFd, AsyncFd), anyhow::Error> {
info!("Starting STDIO server");
let stdin = AsyncFd::try_from(libc::STDIN_FILENO)?;
let stdout = AsyncFd::try_from(libc::STDOUT_FILENO)?;
Ok((stdin, stdout))
}

110
src/tcp.rs Normal file
View file

@ -0,0 +1,110 @@
use anyhow::{anyhow, Context};
use std::{io, vec};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::os::fd::AsRawFd;
use std::time::Duration;
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::time::timeout;
use tokio_stream::wrappers::TcpListenerStream;
use tracing::debug;
use tracing::log::info;
use url::Host;
fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<i32>) -> Result<(), anyhow::Error> {
socket.set_nodelay(true).with_context(|| {
format!(
"cannot set no_delay on socket: {}",
io::Error::last_os_error()
)
})?;
if let Some(so_mark) = so_mark {
unsafe {
let optval: libc::c_int = *so_mark;
let ret = libc::setsockopt(
socket.as_raw_fd(),
libc::SOL_SOCKET,
libc::SO_MARK,
&optval as *const _ as *const libc::c_void,
std::mem::size_of_val(&optval) as libc::socklen_t,
);
if ret != 0 {
return Err(anyhow!(
"Cannot set SO_MARK on the connection {:?}",
io::Error::last_os_error()
));
}
}
}
Ok(())
}
pub async fn connect(
host: &Host<String>,
port: u16,
so_mark: &Option<i32>,
connect_timeout: Duration,
) -> Result<TcpStream, anyhow::Error> {
info!("Opening TCP connection to {}:{}", host, port);
// TODO: Avoid allocation of vec by extracting the code that does the connection in a separate function
let socket_addrs: Vec<SocketAddr> = match host {
Host::Domain(domain) => tokio::net::lookup_host(format!("{}:{}", domain, port))
.await
.with_context(|| format!("cannot resolve domain: {}", domain))?
.collect(),
Host::Ipv4(ip) => vec![SocketAddr::V4(SocketAddrV4::new(*ip, port))],
Host::Ipv6(ip) => vec![SocketAddr::V6(SocketAddrV6::new(*ip, port, 0, 0))],
};
let mut cnx = None;
let mut last_err = None;
for addr in socket_addrs {
debug!("connecting to {}", addr);
let mut socket = match &addr {
SocketAddr::V4(_) => TcpSocket::new_v4()?,
SocketAddr::V6(_) => TcpSocket::new_v6()?,
};
configure_socket(&mut socket, so_mark)?;
match timeout(connect_timeout, socket.connect(addr)).await {
Ok(Ok(stream)) => {
cnx = Some(stream);
break;
}
Ok(Err(err)) => {
debug!("Cannot connect to tcp endpoint {addr} reason {err}");
last_err = Some(err);
}
Err(_) => {
debug!(
"Cannot connect to tcp endpoint {addr} due to timeout of {}s elapsed",
connect_timeout.as_secs()
);
}
}
}
if let Some(cnx) = cnx {
Ok(cnx)
} else {
Err(anyhow!(
"Cannot connect to tcp endpoint {}:{} reason {:?}",
host,
port,
last_err
))
}
}
pub async fn run_server(bind: SocketAddr) -> Result<TcpListenerStream, anyhow::Error> {
info!("Starting TCP server listening cnx on {}", bind);
let listener = TcpListener::bind(bind)
.await
.with_context(|| format!("Cannot create TCP server {:?}", bind))?;
Ok(TcpListenerStream::new(listener))
}

125
src/tls.rs Normal file
View file

@ -0,0 +1,125 @@
use crate::{TlsClientConfig, TlsServerConfig, WsClientConfig};
use anyhow::{anyhow, Context};
use std::fs::File;
use std::io::BufReader;
use std::path::Path;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use tokio_rustls::rustls::client::{ServerCertVerified, ServerCertVerifier};
use tokio_rustls::rustls::{Certificate, ClientConfig, PrivateKey, ServerName};
use tokio_rustls::{rustls, TlsAcceptor, TlsConnector};
use tracing::info;
pub struct NullVerifier;
impl ServerCertVerifier for NullVerifier {
fn verify_server_cert(
&self,
_end_entity: &Certificate,
_intermediates: &[Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
_now: SystemTime,
) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
}
pub fn load_certificates_from_pem(path: &Path) -> anyhow::Result<Vec<Certificate>> {
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let certs = rustls_pemfile::certs(&mut reader)?;
Ok(certs.into_iter().map(Certificate).collect())
}
pub fn load_private_key_from_file(path: &Path) -> anyhow::Result<PrivateKey> {
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut reader)?;
match keys.len() {
0 => Err(anyhow!("No PKCS8-encoded private key found in {path:?}")),
1 => Ok(PrivateKey(keys.remove(0))),
_ => Err(anyhow!(
"More than one PKCS8-encoded private key found in {path:?}"
)),
}
}
pub fn tls_connector(
tls_cfg: &TlsClientConfig,
alpn_protocols: Option<Vec<Vec<u8>>>,
) -> anyhow::Result<TlsConnector> {
let mut root_store = rustls::RootCertStore::empty();
// Load system certificates and add them to the root store
let certs = rustls_native_certs::load_native_certs()
.with_context(|| "Cannot load system certificates")?;
for cert in certs {
root_store.add(&Certificate(cert.0)).unwrap();
}
let mut config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
// To bypass certificate verification
if !tls_cfg.tls_verify_certificate {
config
.dangerous()
.set_certificate_verifier(Arc::new(NullVerifier));
}
if let Some(alpn_protocols) = alpn_protocols {
config.alpn_protocols = alpn_protocols;
}
let tls_connector = TlsConnector::from(Arc::new(config));
Ok(tls_connector)
}
pub fn tls_acceptor(
tls_cfg: &TlsServerConfig,
alpn_protocols: Option<Vec<Vec<u8>>>,
) -> anyhow::Result<TlsAcceptor> {
let mut config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(tls_cfg.tls_certificate.clone(), tls_cfg.tls_key.clone())
.with_context(|| "invalid tls certificate or private key")?;
if let Some(alpn_protocols) = alpn_protocols {
config.alpn_protocols = alpn_protocols;
}
Ok(TlsAcceptor::from(Arc::new(config)))
}
pub async fn connect(
server_cfg: &WsClientConfig,
tls_cfg: &TlsClientConfig,
tcp_stream: TcpStream,
) -> anyhow::Result<TlsStream<TcpStream>> {
let sni = server_cfg.tls_server_name();
info!(
"Doing TLS handshake using sni {sni:?} with the server {}:{}",
server_cfg.remote_addr.0, server_cfg.remote_addr.1
);
let tls_connector = tls_connector(tls_cfg, Some(vec![b"http/1.1".to_vec()]))?;
let tls_stream = tls_connector
.connect(sni, tcp_stream)
.await
.with_context(|| {
format!(
"failed to do TLS handshake with the server {:?}",
server_cfg.remote_addr
)
})?;
Ok(tls_stream)
}

503
src/transport.rs Normal file
View file

@ -0,0 +1,503 @@
#![allow(unused_imports)]
use std::collections::HashSet;
use std::future::Future;
use std::net::Ipv4Addr;
use std::ops::{Deref, Not};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use crate::{tcp, tls, L4Protocol, LocalToRemote, WsClientConfig, WsServerConfig};
use anyhow::Context;
use fastwebsockets::upgrade::UpgradeFut;
use fastwebsockets::{
Frame, OpCode, Payload, WebSocket, WebSocketError, WebSocketRead, WebSocketWrite,
};
use futures_util::{pin_mut, StreamExt};
use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_VERSION, UPGRADE, X_FRAME_OPTIONS};
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
use hyper::server::conn::Http;
use hyper::service::service_fn;
use hyper::upgrade::Upgraded;
use hyper::{http, Body, Request, Response, StatusCode};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation};
use once_cell::sync::Lazy;
use tokio::io::{
AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Interest, ReadHalf, WriteHalf,
};
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::select;
use tokio::sync::oneshot;
use tokio::time::error::Elapsed;
use tokio::time::timeout;
use crate::udp::{MyUdpSocket, UdpStream};
use serde::{Deserialize, Serialize};
use tokio_rustls::TlsAcceptor;
use tracing::log::debug;
use tracing::{error, field, info, instrument, trace, warn, Instrument, Span};
use url::quirks::host;
use url::Host;
use uuid::Uuid;
struct SpawnExecutor;
impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
where
Fut: Future + Send + 'static,
Fut::Output: Send + 'static,
{
fn execute(&self, fut: Fut) {
tokio::task::spawn(fut);
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct JwtTunnelConfig {
pub id: String,
pub p: L4Protocol,
pub r: String,
pub rp: u16,
}
static JWT_SECRET: &[u8; 15] = b"champignonfrais";
static JWT_KEY: Lazy<(Header, EncodingKey)> = Lazy::new(|| {
(
Header::new(Algorithm::HS256),
EncodingKey::from_secret(JWT_SECRET),
)
});
static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::with_capacity(0);
(validation, DecodingKey::from_secret(JWT_SECRET))
});
pub async fn connect(
request_id: Uuid,
server_cfg: &WsClientConfig,
tunnel_cfg: &LocalToRemote,
) -> anyhow::Result<WebSocket<Upgraded>> {
let (host, port) = &server_cfg.remote_addr;
let tcp_stream = tcp::connect(
host,
*port,
&tunnel_cfg.socket_so_mark,
server_cfg.timeout_connect,
)
.await?;
let data = JwtTunnelConfig {
id: request_id.to_string(),
p: tunnel_cfg.protocol,
r: tunnel_cfg.remote.0.to_string(),
rp: tunnel_cfg.remote.1,
};
let (alg, secret) = JWT_KEY.deref();
let mut req = Request::builder()
.method("GET")
.uri(format!(
"/{}/events?bearer={}",
&server_cfg.http_upgrade_path_prefix,
jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(),
))
.header(HOST, server_cfg.remote_addr.0.to_string())
.header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade")
.header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key())
.header(SEC_WEBSOCKET_VERSION, "13")
.version(hyper::Version::HTTP_11);
for (k, v) in &server_cfg.http_headers {
req = req.header(k.clone(), v.clone());
}
if let Some(auth) = &server_cfg.http_upgrade_credentials {
req = req.header(AUTHORIZATION, auth.clone());
}
let req = req.body(Body::empty()).with_context(|| {
format!(
"failed to build HTTP request to contact the server {:?}",
server_cfg.remote_addr
)
})?;
debug!("with HTTP upgrade request {:?}", req);
let ws_handshake = match &server_cfg.tls {
None => fastwebsockets::handshake::client(&SpawnExecutor, req, tcp_stream).await,
Some(tls_cfg) => {
let tls_stream = tls::connect(server_cfg, tls_cfg, tcp_stream).await?;
fastwebsockets::handshake::client(&SpawnExecutor, req, tls_stream).await
}
};
let (ws, _) = ws_handshake.with_context(|| {
format!(
"failed to do websocket handshake with the server {:?}",
server_cfg.remote_addr
)
})?;
Ok(ws)
}
pub async fn connect_to_server<R, W>(
request_id: Uuid,
server_config: &WsClientConfig,
remote_cfg: &LocalToRemote,
duplex_stream: (R, W),
) -> anyhow::Result<()>
where
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
let mut ws = connect(request_id, server_config, remote_cfg).await?;
ws.set_auto_apply_mask(server_config.websocket_mask_frame);
let (ws_rx, ws_tx) = ws.split(tokio::io::split);
let (local_rx, local_tx) = duplex_stream;
let (close_tx, close_rx) = oneshot::channel::<()>();
// Forward local tx to websocket tx
let ping_frequency = server_config.websocket_ping_frequency;
tokio::spawn(
propagate_read(local_rx, ws_tx, close_tx, ping_frequency).instrument(Span::current()),
);
// Forward websocket rx to local rx
let _ = propagate_write(local_tx, ws_rx, close_rx, server_config.timeout_connect).await;
Ok(())
}
async fn from_query(
server_config: &WsServerConfig,
query: &str,
) -> anyhow::Result<(
L4Protocol,
Host,
u16,
Pin<Box<dyn AsyncRead + Send>>,
Pin<Box<dyn AsyncWrite + Send>>,
)> {
let jwt: TokenData<JwtTunnelConfig> = match query.split_once('=') {
Some(("bearer", jwt)) => {
let (validation, decode_key) = JWT_DECODE.deref();
match jsonwebtoken::decode(jwt, decode_key, validation) {
Ok(jwt) => jwt,
err => {
error!("error while decoding jwt for tunnel info {:?}", err);
return Err(anyhow::anyhow!("Invalid upgrade request"));
}
}
}
_err => return Err(anyhow::anyhow!("Invalid upgrade request")),
};
Span::current().record("id", jwt.claims.id);
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
if let Some(allowed_dests) = &server_config.restrict_to {
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
if allowed_dests
.iter()
.any(|dest| dest == &requested_dest)
.not()
{
warn!(
"Rejecting connection with not allowed destination: {}",
requested_dest
);
return Err(anyhow::anyhow!("Invalid upgrade request"));
}
}
match jwt.claims.p {
L4Protocol::Udp { .. } => {
let host = Host::parse(&jwt.claims.r)?;
let cnx = Arc::new(UdpSocket::bind("[::]:0").await?);
cnx.connect((host.to_string(), jwt.claims.rp)).await?;
Ok((
L4Protocol::Udp { timeout: None },
host,
jwt.claims.rp,
Box::pin(MyUdpSocket::new(cnx.clone())),
Box::pin(MyUdpSocket::new(cnx)),
))
}
L4Protocol::Tcp { .. } => {
let host = Host::parse(&jwt.claims.r)?;
let port = jwt.claims.rp;
let (rx, tx) = tcp::connect(
&host,
port,
&server_config.socket_so_mark,
Duration::from_secs(10),
)
.await?
.into_split();
Ok((jwt.claims.p, host, port, Box::pin(rx), Box::pin(tx)))
}
_ => Err(anyhow::anyhow!("Invalid upgrade request")),
}
}
async fn server_upgrade(
server_config: Arc<WsServerConfig>,
mut req: Request<Body>,
) -> Result<Response<Body>, anyhow::Error> {
if let Some(x) = req.headers().get("X-Forwarded-For") {
info!("Request X-Forwarded-For: {:?}", x);
Span::current().record("forwarded_for", x.to_str().unwrap_or_default());
}
if !req.uri().path().ends_with("/events") {
warn!(
"Rejecting connection with bad upgrade request: {}",
req.uri()
);
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid upgrade request".to_string()))
.unwrap_or_default());
}
let (protocol, dest, port, local_rx, local_tx) =
match from_query(&server_config, req.uri().query().unwrap_or_default()).await {
Ok(ret) => ret,
Err(err) => {
warn!(
"Rejecting connection with bad upgrade request: {} {}",
err,
req.uri()
);
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
.unwrap_or_default());
}
};
info!("connected to {:?} {:?} {:?}", protocol, dest, port);
let (response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) {
Ok(ret) => ret,
Err(err) => {
warn!(
"Rejecting connection with bad upgrade request: {} {}",
err,
req.uri()
);
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
.unwrap_or_default());
}
};
tokio::spawn(
async move {
let (ws_rx, mut ws_tx) = fut.await.unwrap().split(tokio::io::split);
let (close_tx, close_rx) = oneshot::channel::<()>();
let connect_timeout = server_config.timeout_connect;
let ping_frequency = server_config
.websocket_ping_frequency
.unwrap_or(Duration::MAX);
ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame);
tokio::task::spawn(
propagate_write(local_tx, ws_rx, close_rx, connect_timeout)
.instrument(Span::current()),
);
let _ = propagate_read(local_rx, ws_tx, close_tx, ping_frequency).await;
}
.instrument(Span::current()),
);
Ok(response)
}
#[instrument(name="tunnel", level="info", skip_all, fields(id=field::Empty, remote=field::Empty, peer=field::Empty, forwarded_for=field::Empty))]
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
info!(
"Starting wstunnel server listening on {}",
server_config.bind
);
let config = server_config.clone();
let upgrade_fn = move |req: Request<Body>| server_upgrade(config.clone(), req);
let listener = TcpListener::bind(&server_config.bind).await?;
let tls_acceptor = if let Some(tls) = &server_config.tls {
Some(tls::tls_acceptor(tls, Some(vec![b"http/1.1".to_vec()]))?)
} else {
None
};
loop {
let (stream, peer_addr) = listener.accept().await?;
let _ = stream.set_nodelay(true);
Span::current().record("peer", peer_addr.to_string());
info!("Accepting connection");
let upgrade_fn = upgrade_fn.clone();
// TLS
if let Some(tls_acceptor) = &tls_acceptor {
let tls_acceptor = tls_acceptor.clone();
let fut = async move {
info!("Doing TLS handshake");
let tls_stream = match tls_acceptor.accept(stream).await {
Ok(tls_stream) => tls_stream,
Err(err) => {
error!("error while accepting TLS connection {}", err);
return;
}
};
let conn_fut = Http::new()
.http1_only(true)
.serve_connection(tls_stream, service_fn(upgrade_fn))
.with_upgrades();
if let Err(e) = conn_fut.await {
error!("Error while upgrading cnx to websocket: {:?}", e);
}
}
.instrument(Span::current());
tokio::spawn(fut);
// Normal
} else {
let conn_fut = Http::new()
.http1_only(true)
.serve_connection(stream, service_fn(upgrade_fn))
.with_upgrades();
let fut = async move {
if let Err(e) = conn_fut.await {
error!("Error while upgrading cnx to weboscket: {:?}", e);
}
}
.instrument(Span::current());
tokio::spawn(fut);
};
}
}
async fn propagate_read(
local_rx: impl AsyncRead,
mut ws_tx: WebSocketWrite<WriteHalf<Upgraded>>,
mut close_tx: oneshot::Sender<()>,
ping_frequency: Duration,
) -> Result<(), WebSocketError> {
let _guard = scopeguard::guard((), |_| {
info!("Closing local tx ==> websocket tx tunnel");
});
let mut buffer = vec![0u8; 8 * 1024];
pin_mut!(local_rx);
loop {
let read = select! {
biased;
read_len = local_rx.read(buffer.as_mut_slice()) => read_len,
_ = close_tx.closed() => break,
_ = timeout(ping_frequency, futures_util::future::pending::<()>()) => {
debug!("sending ping to keep websocket connection alive");
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::Borrowed(&[]))).await?;
continue;
}
};
let read_len = match read {
Ok(read_len) if read_len > 0 => read_len,
Ok(_) => break,
Err(err) => {
warn!(
"error while reading incoming bytes from local tx tunnel {}",
err
);
break;
}
};
trace!("read {} bytes", read_len);
match ws_tx
.write_frame(Frame::binary(Payload::Borrowed(&buffer[..read_len])))
.await
{
Ok(_) => {}
Err(err) => {
warn!("error while writing to websocket tx tunnel {}", err);
break;
}
}
if read_len == buffer.len() {
buffer.resize(read_len * 2, 0);
}
}
Ok(())
}
async fn propagate_write(
local_tx: impl AsyncWrite,
mut ws_rx: WebSocketRead<ReadHalf<Upgraded>>,
mut close_rx: oneshot::Receiver<()>,
timeout_connect: Duration,
) -> Result<(), WebSocketError> {
let _guard = scopeguard::guard((), |_| {
info!("Closing local rx <== websocket rx tunnel");
});
let mut x = |x: Frame<'_>| {
debug!("frame {:?} {:?}", x.opcode, x.payload);
futures_util::future::ready(anyhow::Ok(()))
};
pin_mut!(local_tx);
loop {
let ret = select! {
biased;
ret = timeout(timeout_connect, ws_rx.read_frame(&mut x)) => ret,
_ = &mut close_rx => break,
};
let msg = match ret {
Ok(Ok(msg)) => msg,
Ok(Err(err)) => {
error!("error while reading from websocket rx {}", err);
break;
}
Err(err) => {
trace!("frame {:?}", err);
// TODO: Check that the connection is not closed (no easy method to know if a tx is closed ...)
continue;
}
};
trace!("frame {:?} {:?}", msg.opcode, msg.payload);
let ret = match msg.opcode {
OpCode::Continuation | OpCode::Text | OpCode::Binary => {
local_tx.write_all(msg.payload.as_ref()).await
}
OpCode::Close => break,
OpCode::Ping => Ok(()),
OpCode::Pong => Ok(()),
};
match ret {
Ok(_) => {}
Err(err) => {
error!("error while writing bytes to local for rx tunnel {}", err);
break;
}
}
}
Ok(())
}

241
src/udp.rs Normal file
View file

@ -0,0 +1,241 @@
#![allow(unused_imports)]
use anyhow::Context;
use futures_util::future::join;
use futures_util::{stream, FutureExt, Stream};
use hyper::server;
use libc::poll;
use pin_project::{pin_project, pinned_drop};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::future::Future;
use std::io;
use std::io::{Error, ErrorKind, IoSlice};
use std::net::SocketAddr;
use std::pin::{pin, Pin};
use std::sync::{Arc, RwLock, Weak};
use std::task::Poll;
use std::time::{Duration, Instant};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf};
use tokio::net::UdpSocket;
use tokio::time::Sleep;
use tracing::{debug, error, info};
const DEFAULT_UDP_BUFFER_SIZE: usize = 8 * 1024;
struct UdpServer {
listener: UdpSocket,
std_socket: std::net::UdpSocket,
buffer: Vec<u8>,
peers: HashMap<SocketAddr, DuplexStream, ahash::RandomState>,
keys_to_delete: Arc<RwLock<Vec<SocketAddr>>>,
pub cnx_timeout: Option<Duration>,
}
impl UdpServer {
pub fn new(listener: UdpSocket, timeout: Option<Duration>) -> Self {
let socket = listener.into_std().unwrap();
let listener = UdpSocket::from_std(socket.try_clone().unwrap()).unwrap();
Self {
listener,
std_socket: socket,
peers: HashMap::with_hasher(ahash::RandomState::new()),
buffer: vec![0u8; DEFAULT_UDP_BUFFER_SIZE],
keys_to_delete: Default::default(),
cnx_timeout: timeout,
}
}
fn clean_dead_keys(&mut self) {
let nb_key_to_delete = self.keys_to_delete.read().unwrap().len();
if nb_key_to_delete == 0 {
return;
}
debug!("Cleaning {} dead udp peers", nb_key_to_delete);
let mut keys_to_delete = self.keys_to_delete.write().unwrap();
for key in keys_to_delete.iter() {
self.peers.remove(key);
}
keys_to_delete.clear();
}
fn clone_socket(&self) -> UdpSocket {
UdpSocket::from_std(self.std_socket.try_clone().unwrap()).unwrap()
}
}
#[pin_project(PinnedDrop)]
pub struct UdpStream {
socket: UdpSocket,
peer: SocketAddr,
#[pin]
deadline: Option<Sleep>,
#[pin]
io: DuplexStream,
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
}
impl AsMut<DuplexStream> for UdpStream {
fn as_mut(&mut self) -> &mut DuplexStream {
&mut self.io
}
}
#[pinned_drop]
impl PinnedDrop for UdpStream {
fn drop(self: Pin<&mut Self>) {
if let Some(keys_to_delete) = self.keys_to_delete.upgrade() {
keys_to_delete.write().unwrap().push(self.peer);
}
}
}
impl AsyncRead for UdpStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let project = self.project();
if let Some(deadline) = project.deadline.as_pin_mut() {
if deadline.poll(cx).is_ready() {
return Poll::Ready(Err(Error::new(
ErrorKind::TimedOut,
format!("UDP stream timeout with {}", project.peer),
)));
}
}
project.io.poll_read(cx, buf)
}
}
impl AsyncWrite for UdpStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
self.socket.poll_send_to(cx, buf, self.peer)
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Error>> {
self.socket.poll_send_ready(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
}
pub async fn run_server(
bind: SocketAddr,
timeout: Option<Duration>,
) -> Result<impl Stream<Item = io::Result<UdpStream>>, anyhow::Error> {
info!(
"Starting UDP server listening cnx on {} with cnx timeout of {}s",
bind,
timeout.unwrap_or(Duration::from_secs(0)).as_secs()
);
let listener = UdpSocket::bind(bind)
.await
.with_context(|| format!("Cannot create UDP server {:?}", bind))?;
let udp_server = UdpServer::new(listener, timeout);
let stream = stream::unfold(udp_server, |mut server| async {
loop {
server.clean_dead_keys();
let (nb_bytes, peer_addr) = match server.listener.recv_from(&mut server.buffer).await {
Ok(ret) => ret,
Err(err) => {
error!("Cannot read from UDP server. Closing server: {}", err);
return None;
}
};
match server.peers.entry(peer_addr) {
Entry::Occupied(mut peer) => {
let ret = peer.get_mut().write_all(&server.buffer[0..nb_bytes]).await;
if let Err(err) = ret {
info!("Peer {:?} disconnected {:?}", peer_addr, err);
peer.remove();
}
}
Entry::Vacant(peer) => {
let (mut rx, tx) = tokio::io::duplex(DEFAULT_UDP_BUFFER_SIZE);
rx.write_all(&server.buffer[0..nb_bytes])
.await
.unwrap_or_default(); // should never fail
peer.insert(rx);
let udp_client = UdpStream {
socket: server.clone_socket(),
peer: peer_addr,
deadline: server
.cnx_timeout
.and_then(|timeout| tokio::time::Instant::now().checked_add(timeout))
.map(tokio::time::sleep_until),
keys_to_delete: Arc::downgrade(&server.keys_to_delete),
io: tx,
};
return Some((Ok(udp_client), (server)));
}
}
}
});
Ok(stream)
}
pub struct MyUdpSocket {
socket: Arc<UdpSocket>,
}
impl MyUdpSocket {
pub fn new(socket: Arc<UdpSocket>) -> Self {
Self { socket }
}
}
impl AsyncRead for MyUdpSocket {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }
.poll_recv_from(cx, buf)
.map(|x| x.map(|_| ()))
}
}
impl AsyncWrite for MyUdpSocket {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }.poll_send(cx, buf)
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
}