ground 1
Former-commit-id: d125471a7e73cbde30e1d8cb42a9e6d7aac10131 [formerly d14a90382da6197691d28f61151f1278dca23a53] [formerly 51e8380286fd2a4dc2ff577a507d0df2356b1e79 [formerly 0cd5c5c0eaa4a0538a566ba9e6bb5d925da77c1a]] Former-commit-id: b4b5769ad601c8cce35047a3e16ff185e228ea41 [formerly 4c6f9e6bd777c187a240b0a6119c2a4eaa396da4] Former-commit-id: 2aa31860ffe6a5a51f0148527598a7399f968801 Former-commit-id: a826342ca74b913ce45171f92bc6b9d19ac8db08 Former-commit-id: fc039d0217ff4d0c47048755da4f833d86568586 Former-commit-id: d742a7134f042fd67fb1b9399490473babef28d9 [formerly 7e2ea8487c5ce2a5fb39b015eb6d00d8a48654c6] Former-commit-id: f4273cd0403ef19d4cf19c861fa4d730ab10b29d
This commit is contained in:
parent
8c611e9149
commit
8387557459
13 changed files with 1874 additions and 637 deletions
7
.cargo/config.toml
Normal file
7
.cargo/config.toml
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
[build]
|
||||||
|
rustflags = ["--cfg", "uuid_unstable"]
|
||||||
|
|
||||||
|
#[target.'cfg(target_os = "linux")']
|
||||||
|
#rustflags = ["-C", "linker=ld.lld", "-C", "relocation-model=static", "-C", "strip=symbols", "--cfg", "uuid_unstable"]
|
||||||
|
|
||||||
|
#[build]
|
27
.gitignore
vendored
27
.gitignore
vendored
|
@ -1,21 +1,10 @@
|
||||||
dist
|
# Generated by Cargo
|
||||||
cabal-dev
|
# will have compiled files and executables
|
||||||
*.o
|
debug/
|
||||||
*.hi
|
target/
|
||||||
*.chi
|
|
||||||
*.chs.h
|
|
||||||
.virtualenv
|
|
||||||
.hsenv
|
|
||||||
.cabal-sandbox/
|
|
||||||
cabal.sandbox.config
|
|
||||||
cabal.config
|
|
||||||
*.log
|
|
||||||
tags
|
|
||||||
bin/
|
|
||||||
*~
|
|
||||||
.stack-work
|
|
||||||
|
|
||||||
|
# These are backup files generated by rustfmt
|
||||||
|
**/*.rs.bk
|
||||||
|
|
||||||
# Added by cargo
|
# MSVC Windows builds of rustc generate these, which store debugging information
|
||||||
|
*.pdb
|
||||||
/target
|
|
||||||
|
|
814
Cargo.lock
generated
814
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
30
Cargo.toml
30
Cargo.toml
|
@ -8,12 +8,36 @@ edition = "2021"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
clap = { version = "4.4.5", features = ["derive"]}
|
clap = { version = "4.4.5", features = ["derive"]}
|
||||||
url = "2.4.1"
|
url = "2.4.1"
|
||||||
|
anyhow = "1.0.75"
|
||||||
|
|
||||||
reqwest = { version = "0.11.20", features = ["stream", "trust-dns"] }
|
hyper = { version = "0.14.27", features = ["client", "runtime"] }
|
||||||
hyper = { version = "1.0.0-rc.4", features = ["client", "http2"] }
|
#fastwebsockets = { version = "0.4.4", features = ["upgrade"]}
|
||||||
hyper-openssl = {version = "0.9.2", features = []}
|
fastwebsockets = { git = "https://github.com/mmastrac/fastwebsockets", branch = "split", features = ["upgrade", "simd"]}
|
||||||
|
libc = { version = "0.2.148", features = []}
|
||||||
|
once_cell = { version = "1.18.0", features = [] }
|
||||||
|
ahash = { version = "0.8.3", features = []}
|
||||||
|
pin-project = "1"
|
||||||
|
scopeguard = "1.2.0"
|
||||||
|
uuid = { version = "1.4.1", features = ["v7", "serde"] }
|
||||||
|
jsonwebtoken = { version = "8.3.0", default-features = false }
|
||||||
|
rustls-pemfile = { version = "1.0.3", features = [] }
|
||||||
|
|
||||||
|
rustls-native-certs = { version = "0.6.3", features = [] }
|
||||||
tokio = { version = "1.32.0", features = ["full"] }
|
tokio = { version = "1.32.0", features = ["full"] }
|
||||||
|
tokio-rustls = { version = "0.24.1", features = ["tls12", "dangerous_configuration", "early-data"] }
|
||||||
|
tokio-stream = { version = "0.1.14", features = ["net"] }
|
||||||
|
tokio-fd = "0.3.0"
|
||||||
|
futures-util = { version = "0.3.28" }
|
||||||
|
|
||||||
tracing = { version = "0.1.37", features = ["log"] }
|
tracing = { version = "0.1.37", features = ["log"] }
|
||||||
tracing-subscriber = { version = "0.3.17", features = ["env-filter", "fmt", "local-time"] }
|
tracing-subscriber = { version = "0.3.17", features = ["env-filter", "fmt", "local-time"] }
|
||||||
|
base64 = "0.21.4"
|
||||||
|
serde = { version = "1.0.188", features = ["derive"] }
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
lto = "fat"
|
||||||
|
panic = "abort"
|
||||||
|
codegen-units = 1
|
||||||
|
opt-level = 3
|
||||||
|
|
21
certs/cert.pem
Normal file
21
certs/cert.pem
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIDgzCCAmugAwIBAgIUSFStqIolH/v5Mp2u8dNw2kHDEUowDQYJKoZIhvcNAQEF
|
||||||
|
BQAwUDELMAkGA1UEBhMCRlIxDjAMBgNVBAgMBVBhcmlzMQ4wDAYDVQQHDAVQYXJp
|
||||||
|
czEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMCAXDTIzMTAxNTEz
|
||||||
|
Mzk1NVoYDzIwNTEwMzAxMTMzOTU1WjBQMQswCQYDVQQGEwJGUjEOMAwGA1UECAwF
|
||||||
|
UGFyaXMxDjAMBgNVBAcMBVBhcmlzMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRz
|
||||||
|
IFB0eSBMdGQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQD2fM5T4YzD
|
||||||
|
a4By7hHMrTL1BYWgr7OUcYDGAuzZiXEKkvc/zuiYHtek2n/hOvZCPp0pba4tOlfM
|
||||||
|
6BA4qK1cgabH95Q/RfSqttjHoH5hgXUrUZ5YI0n9/7XZRnv5idO5dYqAHElRX70H
|
||||||
|
YVHzU+xwvzfmWPI6QFMF9lhHQzSQyN7P3iZc97nLtTCxmVg0Wgo103CQ3J8Sop07
|
||||||
|
+uZuzkPCkgW8eVEoMPTDZ/pChdW2lvJxGs4BQu92UC73XPhRECbAxXA7JwXxegYl
|
||||||
|
K2pJcNcJWIyRGLEbaVxRMPiBpbIfJbU1nNoSlgGKJb8GuhVK7y4eRxLnOvnLGCFp
|
||||||
|
dl9c3o6iPYH/AgMBAAGjUzBRMB0GA1UdDgQWBBS4Y1uJ52HbmP1YLWETMcVn7fI/
|
||||||
|
SzAfBgNVHSMEGDAWgBS4Y1uJ52HbmP1YLWETMcVn7fI/SzAPBgNVHRMBAf8EBTAD
|
||||||
|
AQH/MA0GCSqGSIb3DQEBBQUAA4IBAQC8R9bx8P1TQsfNIqHhRuSss623VCdPPMgt
|
||||||
|
uJzXsZVYTfKizIo8nIWpy2y+RpJFpgB26XtrBORwZmc+pDjiABInZxUYoQEMmz7K
|
||||||
|
gc6OBAeweVD3QNcxqfO+NLft6tP6r3aqDjfF0w358LbuIRGRE34e5wdYBKqNmcu5
|
||||||
|
Bh9XcWCL7mP3aq+Sl0340Zl+/rPi0sLMNohEYTX6+/XB7qM27Cq/JDJhxGVdKRxO
|
||||||
|
nv/K02yKpY/C+8tJRJ86v5gTFfDtjGpu9EmDhtGCnpeqX55uE4pgeKUdkNgGviKD
|
||||||
|
BTizUWSqnkkuqQdZ+DGT4HVXvKHYyWswbHN19huq7SZK17SOz19o
|
||||||
|
-----END CERTIFICATE-----
|
28
certs/key.pem
Normal file
28
certs/key.pem
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
-----BEGIN PRIVATE KEY-----
|
||||||
|
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQD2fM5T4YzDa4By
|
||||||
|
7hHMrTL1BYWgr7OUcYDGAuzZiXEKkvc/zuiYHtek2n/hOvZCPp0pba4tOlfM6BA4
|
||||||
|
qK1cgabH95Q/RfSqttjHoH5hgXUrUZ5YI0n9/7XZRnv5idO5dYqAHElRX70HYVHz
|
||||||
|
U+xwvzfmWPI6QFMF9lhHQzSQyN7P3iZc97nLtTCxmVg0Wgo103CQ3J8Sop07+uZu
|
||||||
|
zkPCkgW8eVEoMPTDZ/pChdW2lvJxGs4BQu92UC73XPhRECbAxXA7JwXxegYlK2pJ
|
||||||
|
cNcJWIyRGLEbaVxRMPiBpbIfJbU1nNoSlgGKJb8GuhVK7y4eRxLnOvnLGCFpdl9c
|
||||||
|
3o6iPYH/AgMBAAECggEALdKa6uYh7Ix2JyeSAIpsUDe0FWDEkkKdjXIqxPAzpyMW
|
||||||
|
OvMEs478SOXj4yO6dys7vWFqAXd4rhuwNFBLVki2EDO7CB5Bs2DloQr5o7fU5/Y2
|
||||||
|
6Sy6SzF4BYoAby4Lwc0Tr+hSSwHw2sfhW8qMyJML2dNMSL7/kDqxQ6I/SfFF1r+Y
|
||||||
|
3LKaC98/jxiIco5Cgabd065x2NVOshWzIkY5xuvCYjfQlEJWKHbuxIIrcJXApEe4
|
||||||
|
pmexuK8VVb/Prm6Ci1+hsWOgkXuv/3EUZNxeQ87kek7Ggw3U1CXLAJ5H+FuNEVuy
|
||||||
|
mbmfX34GwKeC9tq/4zQifFS0BLaP4ND3AAo2rTvgDQKBgQD8ytKguW9yLVWcX7dl
|
||||||
|
ncEXbSKEycMfrvqJw8NvTt/9O/Uto2ri1JAmFBJ5m9tPZpujOSkdPXRmFx+lFyt+
|
||||||
|
XlkJrn1BYfZYoxMkUp5qbVF3tk32mLZM0K1yyxb5XEfcykbA1S8eLDt7F4h558Sp
|
||||||
|
e2+K60klFDB6b4Yil/a9aN0QIwKBgQD5nYCvKStaw+3YnX7TnuycbfedYiMAX7kC
|
||||||
|
1O7HGzEo+gDm4wvCF2pXPsSNJCG9c9KIpnQQrtL68MtAlVgRGa/GNM4emEVgpLtE
|
||||||
|
W6SFVkEBb3JYY1JAB5umxwO0TcFwTn3ivy2QtrFphWSG8gsZnMIrIfui3X3llY6J
|
||||||
|
Cu8iqvd2dQKBgD3R7/6EOr/mXEhYlAYStTTgaI+ms8Qcy4JDUJj45ggM0KGvlCUS
|
||||||
|
rInTYM1CkzhwtGEPSoGvFLcesotyBh3qPsYCWPlTVqZIgxbf6YPHZiPrfldu8y4H
|
||||||
|
3lLzXZPvwFc7VGA2AkbTtFwe3i5Jwqtb12RWs9WQgWZ/vYLaPOoHKgCXAoGBAKY/
|
||||||
|
uZJwCBkWv5XzJ6JIieyR7UZcM1Wva2iwaywfNznEcM9WTuGBeOkMvBoJA5PLzWAI
|
||||||
|
BOuLlKdfsu+byCDzi7emOdX0sthwPu2DX+sSjI8pK+4kkIZmystkZ1oyI3DqRju7
|
||||||
|
+twUYcsW9eJO2QfA+S2DH7bUcGJ1no41wxnC5rh1AoGBALjSLlnJtumrsiZLUeRB
|
||||||
|
rf4n2QIiXd+7tn2ZKSNYn/621hCHwl6peAI3G57Avxixf7Yv4wAOaIgSWUrenmJ9
|
||||||
|
DbwM06AZqnLwraM1B6c3tujOfjIFU7IjENdELxrTnVVYq9a7HDJzT6ty7m1A7eDA
|
||||||
|
9cEnUJ+dG5G05eV44S4Z/6sF
|
||||||
|
-----END PRIVATE KEY-----
|
16
src/embedded_certificate.rs
Normal file
16
src/embedded_certificate.rs
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use tokio_rustls::rustls::{Certificate, PrivateKey};
|
||||||
|
|
||||||
|
pub static TLS_PRIVATE_KEY: Lazy<PrivateKey> = Lazy::new(|| {
|
||||||
|
let key = include_bytes!("../certs/key.pem");
|
||||||
|
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut key.as_slice())
|
||||||
|
.expect("failed to load embedded tls private key");
|
||||||
|
PrivateKey(keys.remove(0))
|
||||||
|
});
|
||||||
|
pub static TLS_CERTIFICATE: Lazy<Vec<Certificate>> = Lazy::new(|| {
|
||||||
|
let cert = include_bytes!("../certs/cert.pem");
|
||||||
|
let certs = rustls_pemfile::certs(&mut cert.as_slice())
|
||||||
|
.expect("failed to load embedded tls certificate");
|
||||||
|
|
||||||
|
certs.into_iter().map(Certificate).collect()
|
||||||
|
});
|
570
src/main.rs
570
src/main.rs
|
@ -1,20 +1,42 @@
|
||||||
|
mod embedded_certificate;
|
||||||
|
mod stdio;
|
||||||
|
mod tcp;
|
||||||
|
mod tls;
|
||||||
|
mod transport;
|
||||||
|
mod udp;
|
||||||
|
|
||||||
|
use base64::Engine;
|
||||||
|
use clap::Parser;
|
||||||
|
use futures_util::{pin_mut, stream, Stream, StreamExt, TryStreamExt};
|
||||||
|
use hyper::http::HeaderValue;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
use std::collections::BTreeMap;
|
use std::collections::{BTreeMap, HashMap};
|
||||||
|
use std::io;
|
||||||
use std::io::ErrorKind;
|
use std::io::ErrorKind;
|
||||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||||
|
use std::path::PathBuf;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use clap::Parser;
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use hyper::body::Body;
|
use tokio::net::TcpStream;
|
||||||
use hyper::Request;
|
|
||||||
use hyper_openssl::HttpsConnector;
|
|
||||||
use url::{Host, Url, UrlQuery};
|
|
||||||
|
|
||||||
/// Simple program to greet a person
|
use tokio_rustls::rustls::server::DnsName;
|
||||||
|
use tokio_rustls::rustls::{Certificate, PrivateKey, ServerName};
|
||||||
|
|
||||||
|
use tracing::{debug, error, field, instrument, Instrument, Span};
|
||||||
|
|
||||||
|
use tracing_subscriber::EnvFilter;
|
||||||
|
use url::{Host, Url};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
/// Use the websockets protocol to tunnel {TCP,UDP} traffic
|
||||||
|
/// wsTunnelClient <---> wsTunnelServer <---> RemoteHost
|
||||||
|
/// Use secure connection (wss://) to bypass proxies
|
||||||
#[derive(clap::Parser, Debug)]
|
#[derive(clap::Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, verbatim_doc_comment, long_about = None)]
|
||||||
struct Wstunnel {
|
struct Wstunnel {
|
||||||
|
|
||||||
#[command(subcommand)]
|
#[command(subcommand)]
|
||||||
commands: Commands,
|
commands: Commands,
|
||||||
}
|
}
|
||||||
|
@ -22,139 +44,567 @@ struct Wstunnel {
|
||||||
#[derive(clap::Subcommand, Debug)]
|
#[derive(clap::Subcommand, Debug)]
|
||||||
enum Commands {
|
enum Commands {
|
||||||
Client(Client),
|
Client(Client),
|
||||||
Server(Server)
|
Server(Server),
|
||||||
}
|
}
|
||||||
#[derive(clap::Args, Debug)]
|
#[derive(clap::Args, Debug)]
|
||||||
struct Client {
|
struct Client {
|
||||||
/// Name of the person to greet
|
/// Listen on local and forwards traffic from remote
|
||||||
#[arg(short='L', long, value_name = "[BIND:]PORT:HOST:PORT", value_parser = parse_env_var)]
|
/// Can be specified multiple times
|
||||||
|
#[arg(short='L', long, value_name = "{tcp,udp}://[BIND:]PORT:HOST:PORT", value_parser = parse_env_var)]
|
||||||
local_to_remote: Vec<LocalToRemote>,
|
local_to_remote: Vec<LocalToRemote>,
|
||||||
|
|
||||||
|
/// (linux only) Mark network packet with SO_MARK sockoption with the specified value.
|
||||||
|
/// You need to use {root, sudo, capabilities} to run wstunnel when using this option
|
||||||
|
#[arg(long, value_name = "INT", verbatim_doc_comment)]
|
||||||
|
socket_so_mark: Option<u32>,
|
||||||
|
|
||||||
|
/// Domain name that will be use as SNI during TLS handshake
|
||||||
|
/// Warning: If you are behind a CDN (i.e: Cloudflare) you must set this domain also in the http HOST header.
|
||||||
|
/// or it will be flag as fishy as your request rejected
|
||||||
|
#[arg(long, value_name = "DOMAIN_NAME", value_parser = parse_sni_override, verbatim_doc_comment)]
|
||||||
|
tls_sni_override: Option<DnsName>,
|
||||||
|
|
||||||
|
/// Enable TLS certificate verification.
|
||||||
|
/// Disabled by default. The client will happily connect to any server with self signed certificate.
|
||||||
|
#[arg(long, verbatim_doc_comment)]
|
||||||
|
tls_verify_certificate: bool,
|
||||||
|
|
||||||
|
/// Use a specific prefix that will show up in the http path during the upgrade request.
|
||||||
|
/// Useful if you need to route requests server side but don't have vhosts
|
||||||
|
#[arg(long, default_value = "morille", verbatim_doc_comment)]
|
||||||
|
http_upgrade_path_prefix: String,
|
||||||
|
|
||||||
|
/// Pass authorization header with basic auth credentials during the upgrade request.
|
||||||
|
/// If you need more customization, you can use the http_headers option.
|
||||||
|
#[arg(long, value_name = "USER[:PASS]", value_parser = parse_http_credentials, verbatim_doc_comment)]
|
||||||
|
http_upgrade_credentials: Option<HeaderValue>,
|
||||||
|
|
||||||
|
/// Frequency at which the client will send websocket ping to the server.
|
||||||
|
#[arg(long, value_name = "seconds", default_value = "30", value_parser = parse_duration_sec, verbatim_doc_comment)]
|
||||||
|
websocket_ping_frequency_sec: Option<Duration>,
|
||||||
|
|
||||||
|
/// Enable the masking of websocket frames. Default is false
|
||||||
|
/// Enable this option only if you use unsecure (non TLS) websocket server and you see some issues. Otherwise, it is just overhead.
|
||||||
|
#[arg(long, default_value = "false", verbatim_doc_comment)]
|
||||||
|
websocket_mask_frame: bool,
|
||||||
|
|
||||||
|
/// Send custom headers in the upgrade request
|
||||||
|
/// Can be specified multiple time
|
||||||
|
#[arg(short='H', long, value_name = "HEADER_NAME: HEADER_VALUE", value_parser = parse_http_headers, verbatim_doc_comment)]
|
||||||
|
http_headers: Vec<(String, HeaderValue)>,
|
||||||
|
|
||||||
|
/// Address of the wstunnel server
|
||||||
|
/// Example: With TLS wss://wstunnel.example.com or without ws://wstunnel.example.com
|
||||||
|
#[arg(value_name = "ws[s]://wstunnel.server.com[:port]", value_parser = parse_server_url, verbatim_doc_comment)]
|
||||||
|
remote_addr: Url,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(clap::Args, Debug)]
|
#[derive(clap::Args, Debug)]
|
||||||
struct Server {
|
struct Server {
|
||||||
/// Name of the person to greet
|
/// Address of the wstunnel server to bind to
|
||||||
#[arg(short='L', long, value_name = "[BIND:]PORT:HOST:PORT", value_parser = parse_env_var)]
|
/// Example: With TLS wss://0.0.0.0:8080 or without ws://[::]:8080
|
||||||
local_to_remote: String,
|
#[arg(value_name = "ws[s]://0.0.0.0[:port]", value_parser = parse_server_url, verbatim_doc_comment)]
|
||||||
|
remote_addr: Url,
|
||||||
|
|
||||||
|
/// (linux only) Mark network packet with SO_MARK sockoption with the specified value.
|
||||||
|
/// You need to use {root, sudo, capabilities} to run wstunnel when using this option
|
||||||
|
#[arg(long, value_name = "INT", verbatim_doc_comment)]
|
||||||
|
socket_so_mark: Option<i32>,
|
||||||
|
|
||||||
|
/// Frequency at which the server will send websocket ping to client.
|
||||||
|
#[arg(long, value_name = "seconds", value_parser = parse_duration_sec, verbatim_doc_comment)]
|
||||||
|
websocket_ping_frequency_sec: Option<Duration>,
|
||||||
|
|
||||||
|
/// Enable the masking of websocket frames. Default is false
|
||||||
|
/// Enable this option only if you use unsecure (non TLS) websocket server and you see some issues. Otherwise, it is just overhead.
|
||||||
|
#[arg(long, default_value = "false", verbatim_doc_comment)]
|
||||||
|
websocket_mask_frame: bool,
|
||||||
|
|
||||||
|
/// Server will only accept connection from the specified tunnel information.
|
||||||
|
/// Can be specified multiple time
|
||||||
|
/// Example: --restrict-to "google.com:443" --restrict-to "localhost:22"
|
||||||
|
#[arg(long, value_name = "DEST:PORT", verbatim_doc_comment)]
|
||||||
|
restrict_to: Option<Vec<String>>,
|
||||||
|
|
||||||
|
/// [Optional] Use custom certificate (.crt) instead of the default embedded self signed certificate.
|
||||||
|
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
|
||||||
|
tls_certificate: Option<PathBuf>,
|
||||||
|
|
||||||
|
/// [Optional] Use a custom tls key (.key) that the server will use instead of the default embedded one
|
||||||
|
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
|
||||||
|
tls_private_key: Option<PathBuf>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, Debug)]
|
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||||
enum L4Protocol {
|
enum L4Protocol {
|
||||||
TCP, UDP { timeout: Duration }
|
Tcp,
|
||||||
|
Udp { timeout: Option<Duration> },
|
||||||
|
Stdio,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl L4Protocol {
|
impl L4Protocol {
|
||||||
fn new_udp() -> L4Protocol {
|
fn new_udp() -> L4Protocol {
|
||||||
L4Protocol::UDP { timeout: Duration::from_secs(30) }
|
L4Protocol::Udp {
|
||||||
|
timeout: Some(Duration::from_secs(30)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
struct LocalToRemote {
|
pub struct LocalToRemote {
|
||||||
|
socket_so_mark: Option<i32>,
|
||||||
protocol: L4Protocol,
|
protocol: L4Protocol,
|
||||||
local: SocketAddr,
|
local: SocketAddr,
|
||||||
remote: (Host<String>, u16),
|
remote: (Host<String>, u16),
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_env_var(arg: &str) -> Result<LocalToRemote, std::io::Error> {
|
fn parse_duration_sec(arg: &str) -> Result<Duration, io::Error> {
|
||||||
|
use std::io::Error;
|
||||||
|
|
||||||
|
let Ok(secs) = arg.parse::<u64>() else {
|
||||||
|
return Err(Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!("cannot duration of seconds from {}", arg),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Duration::from_secs(secs))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_env_var(arg: &str) -> Result<LocalToRemote, io::Error> {
|
||||||
use std::io::Error;
|
use std::io::Error;
|
||||||
|
|
||||||
let (mut protocol, arg) = match &arg[..6] {
|
let (mut protocol, arg) = match &arg[..6] {
|
||||||
"tcp://" => (L4Protocol::TCP, &arg[6..]),
|
"tcp://" => (L4Protocol::Tcp, &arg[6..]),
|
||||||
"udp://" => (L4Protocol::new_udp(), &arg[6..]),
|
"udp://" => (L4Protocol::new_udp(), &arg[6..]),
|
||||||
_ => (L4Protocol::TCP, arg)
|
_ => match &arg[..8] {
|
||||||
|
"stdio://" => (L4Protocol::Stdio, &arg[8..]),
|
||||||
|
_ => (L4Protocol::Tcp, arg),
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let (bind, remaining) = if arg.starts_with('[') {
|
let (bind, remaining) = if arg.starts_with('[') {
|
||||||
// ipv6 bind
|
// ipv6 bind
|
||||||
let Some((ipv6_str, remaining)) = arg.split_once(']') else {
|
let Some((ipv6_str, remaining)) = arg.split_once(']') else {
|
||||||
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse IPv6 bind from {}", arg)));
|
return Err(Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!("cannot parse IPv6 bind from {}", arg),
|
||||||
|
));
|
||||||
};
|
};
|
||||||
let Ok(ipv6_addr) = Ipv6Addr::from_str(&ipv6_str[1..]) else {
|
let Ok(ipv6_addr) = Ipv6Addr::from_str(&ipv6_str[1..]) else {
|
||||||
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse IPv6 bind from {}", ipv6_str)));
|
return Err(Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!("cannot parse IPv6 bind from {}", ipv6_str),
|
||||||
|
));
|
||||||
};
|
};
|
||||||
|
|
||||||
(IpAddr::V6(ipv6_addr), remaining)
|
(IpAddr::V6(ipv6_addr), remaining)
|
||||||
} else {
|
} else {
|
||||||
// Maybe ipv4 addr
|
// Maybe ipv4 addr
|
||||||
let Some((ipv4_str, remaining)) = arg.split_once(':') else {
|
let Some((ipv4_str, remaining)) = arg.split_once(':') else {
|
||||||
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse IPv4 bind from {}", arg)));
|
return Err(Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!("cannot parse IPv4 bind from {}", arg),
|
||||||
|
));
|
||||||
};
|
};
|
||||||
|
|
||||||
match Ipv4Addr::from_str(ipv4_str) {
|
match Ipv4Addr::from_str(ipv4_str) {
|
||||||
Ok(ip4_addr) => (IpAddr::V4(ip4_addr), remaining),
|
Ok(ip4_addr) => (IpAddr::V4(ip4_addr), remaining),
|
||||||
// Must be the port, so we default to ipv6 bind
|
// Must be the port, so we default to ipv6 bind
|
||||||
Err(_) => (IpAddr::V6(Ipv6Addr::from_str("::1").unwrap()), arg)
|
Err(_) => (IpAddr::V4(Ipv4Addr::from_str("127.0.0.1").unwrap()), arg),
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let Some((port_str, remaining)) = remaining.trim_start_matches(':').split_once(':') else {
|
let Some((port_str, remaining)) = remaining.trim_start_matches(':').split_once(':') else {
|
||||||
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse bind port from {}", remaining)));
|
return Err(Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!("cannot parse bind port from {}", remaining),
|
||||||
|
));
|
||||||
};
|
};
|
||||||
|
|
||||||
let Ok(bind_port): Result<u16, _> = port_str.parse() else {
|
let Ok(bind_port): Result<u16, _> = port_str.parse() else {
|
||||||
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse bind port from {}", port_str)));
|
return Err(Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!("cannot parse bind port from {}", port_str),
|
||||||
|
));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else {
|
let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else {
|
||||||
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse remote from {}", remaining)));
|
return Err(Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!("cannot parse remote from {}", remaining),
|
||||||
|
));
|
||||||
};
|
};
|
||||||
|
|
||||||
let Some(remote_host) = remote.host() else {
|
let Some(remote_host) = remote.host() else {
|
||||||
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse remote host from {}", remaining)));
|
return Err(Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!("cannot parse remote host from {}", remaining),
|
||||||
|
));
|
||||||
};
|
};
|
||||||
|
|
||||||
let Some(remote_port) = remote.port() else {
|
let Some(remote_port) = remote.port() else {
|
||||||
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse remote port from {}", remaining)));
|
return Err(Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!("cannot parse remote port from {}", remaining),
|
||||||
|
));
|
||||||
};
|
};
|
||||||
|
|
||||||
match &mut protocol {
|
|
||||||
L4Protocol::TCP => {}
|
|
||||||
L4Protocol::UDP { ref mut timeout, .. } => {
|
|
||||||
let options: BTreeMap<Cow<'_, str>, Cow<'_, str>> = remote.query_pairs().collect();
|
let options: BTreeMap<Cow<'_, str>, Cow<'_, str>> = remote.query_pairs().collect();
|
||||||
if let Some(duration) = options.get("timeout_sec")
|
match &mut protocol {
|
||||||
|
L4Protocol::Stdio => {}
|
||||||
|
L4Protocol::Tcp => {}
|
||||||
|
L4Protocol::Udp {
|
||||||
|
ref mut timeout, ..
|
||||||
|
} => {
|
||||||
|
if let Some(duration) = options
|
||||||
|
.get("timeout_sec")
|
||||||
.and_then(|x| x.parse::<u64>().ok())
|
.and_then(|x| x.parse::<u64>().ok())
|
||||||
.map(|x| Duration::from_secs(x)) {
|
.map(|d| {
|
||||||
|
if d == 0 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(Duration::from_secs(d))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
{
|
||||||
*timeout = duration;
|
*timeout = duration;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(LocalToRemote {
|
Ok(LocalToRemote {
|
||||||
|
socket_so_mark: options
|
||||||
|
.get("socket_so_mark")
|
||||||
|
.and_then(|x| x.parse::<i32>().ok()),
|
||||||
protocol,
|
protocol,
|
||||||
local: SocketAddr::new(bind, bind_port),
|
local: SocketAddr::new(bind, bind_port),
|
||||||
remote: (remote_host.to_owned(), remote_port)
|
remote: (remote_host.to_owned(), remote_port),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn parse_sni_override(arg: &str) -> Result<DnsName, io::Error> {
|
||||||
println!("Hello, world!");
|
match DnsName::try_from(arg.to_string()) {
|
||||||
|
Ok(val) => Ok(val),
|
||||||
|
Err(err) => Err(io::Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!("Invalid sni override: {}", err),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_http_headers(arg: &str) -> Result<(String, HeaderValue), io::Error> {
|
||||||
|
let Some((key, value)) = arg.split_once(':') else {
|
||||||
|
return Err(io::Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!("cannot parse http header from {}", arg),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
let value = match HeaderValue::from_str(value.trim()) {
|
||||||
|
Ok(value) => value,
|
||||||
|
Err(err) => {
|
||||||
|
return Err(io::Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!(
|
||||||
|
"cannot parse http header value from {} due to {:?}",
|
||||||
|
value, err
|
||||||
|
),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok((key.to_owned(), value))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_http_credentials(arg: &str) -> Result<HeaderValue, io::Error> {
|
||||||
|
let encoded = base64::engine::general_purpose::STANDARD.encode(arg.trim().as_bytes());
|
||||||
|
let Ok(header) = HeaderValue::from_str(&format!("Basic {}", encoded)) else {
|
||||||
|
return Err(io::Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!("cannot parse http credentials {}", arg),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(header)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_server_url(arg: &str) -> Result<Url, io::Error> {
|
||||||
|
let Ok(url) = Url::parse(arg) else {
|
||||||
|
return Err(io::Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!("cannot parse server url {}", arg),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
if url.scheme() != "ws" && url.scheme() != "wss" {
|
||||||
|
return Err(io::Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!("invalid scheme {}", url.scheme()),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if url.host().is_none() {
|
||||||
|
return Err(io::Error::new(
|
||||||
|
ErrorKind::InvalidInput,
|
||||||
|
format!("invalid server host {}", arg),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct TlsClientConfig {
|
||||||
|
pub tls_sni_override: Option<DnsName>,
|
||||||
|
pub tls_verify_certificate: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct TlsServerConfig {
|
||||||
|
pub tls_certificate: Vec<Certificate>,
|
||||||
|
pub tls_key: PrivateKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct WsServerConfig {
|
||||||
|
pub socket_so_mark: Option<i32>,
|
||||||
|
pub bind: SocketAddr,
|
||||||
|
pub restrict_to: Option<Vec<String>>,
|
||||||
|
pub websocket_ping_frequency: Option<Duration>,
|
||||||
|
pub timeout_connect: Duration,
|
||||||
|
pub websocket_mask_frame: bool,
|
||||||
|
pub tls: Option<TlsServerConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct WsClientConfig {
|
||||||
|
pub remote_addr: (Host<String>, u16),
|
||||||
|
pub tls: Option<TlsClientConfig>,
|
||||||
|
pub http_upgrade_path_prefix: String,
|
||||||
|
pub http_upgrade_credentials: Option<HeaderValue>,
|
||||||
|
pub http_headers: HashMap<String, HeaderValue>,
|
||||||
|
pub timeout_connect: Duration,
|
||||||
|
pub websocket_ping_frequency: Duration,
|
||||||
|
pub websocket_mask_frame: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WsClientConfig {
|
||||||
|
pub fn websocket_scheme(&self) -> &'static str {
|
||||||
|
match self.tls {
|
||||||
|
None => "ws",
|
||||||
|
Some(_) => "wss",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn websocket_host_url(&self) -> String {
|
||||||
|
format!("{}:{}", self.remote_addr.0, self.remote_addr.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tls_server_name(&self) -> ServerName {
|
||||||
|
match self
|
||||||
|
.tls
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|tls| tls.tls_sni_override.as_ref())
|
||||||
|
{
|
||||||
|
None => match &self.remote_addr.0 {
|
||||||
|
Host::Domain(domain) => {
|
||||||
|
ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap())
|
||||||
|
}
|
||||||
|
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip)),
|
||||||
|
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip)),
|
||||||
|
},
|
||||||
|
Some(sni_override) => ServerName::DnsName(sni_override.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
let args = Wstunnel::parse();
|
let args = Wstunnel::parse();
|
||||||
|
|
||||||
println!("Hello {:?}!", args)
|
// Setup logging
|
||||||
|
match &args.commands {
|
||||||
|
// Disable logging if there is a stdio tunnel
|
||||||
|
Commands::Client(args)
|
||||||
|
if args
|
||||||
|
.local_to_remote
|
||||||
|
.iter()
|
||||||
|
.filter(|x| x.protocol == L4Protocol::Stdio)
|
||||||
|
.count()
|
||||||
|
> 0 => {}
|
||||||
|
_ => {
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_ansi(true)
|
||||||
|
.with_env_filter(EnvFilter::from_default_env())
|
||||||
|
.init();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let client = reqwest::Client::builder()
|
match args.commands {
|
||||||
.timeout(Duration::from_secs(10))
|
Commands::Client(args) => {
|
||||||
.connect_timeout(Duration::from_secs(10))
|
let tls = match args.remote_addr.scheme() {
|
||||||
.danger_accept_invalid_certs(true)
|
"ws" => None,
|
||||||
.build().unwrap();
|
"wss" => Some(TlsClientConfig {
|
||||||
|
tls_sni_override: args.tls_sni_override,
|
||||||
|
tls_verify_certificate: args.tls_verify_certificate,
|
||||||
|
}),
|
||||||
|
_ => panic!("invalid scheme in server url {}", args.remote_addr.scheme()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let server_config = Arc::new(WsClientConfig {
|
||||||
let mut conn = HttpsConnector::new()?;
|
remote_addr: (
|
||||||
conn.set_callback(move |c, _| {
|
args.remote_addr.host().unwrap().to_owned(),
|
||||||
// Prevent native TLS lib from inferring and verifying a default SNI.
|
args.remote_addr.port_or_known_default().unwrap(),
|
||||||
c.set_use_server_name_indication(false);
|
),
|
||||||
c.set_verify_hostname(false);
|
tls,
|
||||||
|
http_upgrade_path_prefix: args.http_upgrade_path_prefix,
|
||||||
// And set a custom SNI instead.
|
http_upgrade_credentials: args.http_upgrade_credentials,
|
||||||
c.set_hostname("somewhere.com")
|
http_headers: args.http_headers.into_iter().collect(),
|
||||||
|
timeout_connect: Duration::from_secs(10),
|
||||||
|
websocket_ping_frequency: args
|
||||||
|
.websocket_ping_frequency_sec
|
||||||
|
.unwrap_or(Duration::from_secs(30)),
|
||||||
|
websocket_mask_frame: args.websocket_mask_frame,
|
||||||
});
|
});
|
||||||
Client::builder()
|
|
||||||
.build::<_, Body>(conn)
|
|
||||||
.request(Request::get("somewhere-else.com").body(())?)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
reqwest::Proxy::all("https://google.com").unwrap().basic_auth("", "")
|
// Start tunnels
|
||||||
|
for tunnel in args.local_to_remote.into_iter() {
|
||||||
|
let server_config = server_config.clone();
|
||||||
|
|
||||||
|
match &tunnel.protocol {
|
||||||
|
L4Protocol::Tcp => {
|
||||||
|
let server = tcp::run_server(tunnel.local)
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|err| {
|
||||||
|
panic!("Cannot start TCP server on {}: {}", tunnel.local, err)
|
||||||
|
})
|
||||||
|
.map_ok(TcpStream::into_split);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(err) = run_tunnel(server_config, tunnel, server).await {
|
||||||
|
error!("{:?}", err);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
L4Protocol::Udp { timeout } => {
|
||||||
|
let server = udp::run_server(tunnel.local, *timeout)
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|err| {
|
||||||
|
panic!("Cannot start UDP server on {}: {}", tunnel.local, err)
|
||||||
|
})
|
||||||
|
.map_ok(tokio::io::split);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(err) = run_tunnel(server_config, tunnel, server).await {
|
||||||
|
error!("{:?}", err);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
L4Protocol::Stdio => {
|
||||||
|
let server = stdio::run_server().await.unwrap_or_else(|err| {
|
||||||
|
panic!("Cannot start STDIO server: {}", err);
|
||||||
|
});
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(err) = run_tunnel(
|
||||||
|
server_config,
|
||||||
|
tunnel,
|
||||||
|
stream::once(async move { Ok(server) }),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
error!("{:?}", err);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Commands::Server(args) => {
|
||||||
|
let tls_config = if args.remote_addr.scheme() == "wss" {
|
||||||
|
let tls_certificate = if let Some(cert_path) = args.tls_certificate {
|
||||||
|
tls::load_certificates_from_pem(&cert_path)
|
||||||
|
.expect("Cannot load tls certificate")
|
||||||
|
} else {
|
||||||
|
embedded_certificate::TLS_CERTIFICATE.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let tls_key = if let Some(key_path) = args.tls_private_key {
|
||||||
|
tls::load_private_key_from_file(&key_path).expect("Cannot load tls private key")
|
||||||
|
} else {
|
||||||
|
embedded_certificate::TLS_PRIVATE_KEY.clone()
|
||||||
|
};
|
||||||
|
Some(TlsServerConfig {
|
||||||
|
tls_certificate,
|
||||||
|
tls_key,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let server_config = WsServerConfig {
|
||||||
|
socket_so_mark: args.socket_so_mark,
|
||||||
|
bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0],
|
||||||
|
restrict_to: args.restrict_to,
|
||||||
|
websocket_ping_frequency: args.websocket_ping_frequency_sec,
|
||||||
|
timeout_connect: Duration::from_secs(10),
|
||||||
|
websocket_mask_frame: args.websocket_mask_frame,
|
||||||
|
tls: tls_config,
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!("{:?}", server_config);
|
||||||
|
transport::run_server(Arc::new(server_config))
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|err| {
|
||||||
|
panic!("Cannot start wstunnel server: {:?}", err);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::signal::ctrl_c().await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(name="tunnel", level="info", skip_all, fields(id=field::Empty, remote=field::Empty))]
|
||||||
|
async fn run_tunnel<T, R, W>(
|
||||||
|
server_config: Arc<WsClientConfig>,
|
||||||
|
tunnel: LocalToRemote,
|
||||||
|
incoming_cnx: T,
|
||||||
|
) -> anyhow::Result<()>
|
||||||
|
where
|
||||||
|
T: Stream<Item = io::Result<(R, W)>>,
|
||||||
|
R: AsyncRead + Send + 'static,
|
||||||
|
W: AsyncWrite + Send + 'static,
|
||||||
|
{
|
||||||
|
let span = Span::current();
|
||||||
|
let request_id = Uuid::now_v7();
|
||||||
|
span.record("id", request_id.to_string());
|
||||||
|
span.record(
|
||||||
|
"remote",
|
||||||
|
&format!("{}:{}", tunnel.remote.0, tunnel.remote.1),
|
||||||
|
);
|
||||||
|
|
||||||
|
let tunnel = Arc::new(tunnel);
|
||||||
|
pin_mut!(incoming_cnx);
|
||||||
|
|
||||||
|
while let Some(Ok(cnx_stream)) = incoming_cnx.next().await {
|
||||||
|
let server_config = server_config.clone();
|
||||||
|
let tunnel = tunnel.clone();
|
||||||
|
|
||||||
|
tokio::spawn(
|
||||||
|
async move {
|
||||||
|
let ret =
|
||||||
|
transport::connect_to_server(request_id, &server_config, &tunnel, cnx_stream)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if let Err(ret) = ret {
|
||||||
|
error!("{:?}", ret);
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::Ok(())
|
||||||
|
}
|
||||||
|
.instrument(span.clone()),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
19
src/stdio.rs
Normal file
19
src/stdio.rs
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
#![allow(unused_imports)]
|
||||||
|
|
||||||
|
use libc::STDIN_FILENO;
|
||||||
|
use std::os::fd::{AsRawFd, FromRawFd};
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
use tokio::fs::File;
|
||||||
|
use tokio::io::{stdout, AsyncRead, ReadBuf, Stdout};
|
||||||
|
use tokio_fd::AsyncFd;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
pub async fn run_server() -> Result<(AsyncFd, AsyncFd), anyhow::Error> {
|
||||||
|
info!("Starting STDIO server");
|
||||||
|
|
||||||
|
let stdin = AsyncFd::try_from(libc::STDIN_FILENO)?;
|
||||||
|
let stdout = AsyncFd::try_from(libc::STDOUT_FILENO)?;
|
||||||
|
|
||||||
|
Ok((stdin, stdout))
|
||||||
|
}
|
110
src/tcp.rs
Normal file
110
src/tcp.rs
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
use anyhow::{anyhow, Context};
|
||||||
|
use std::{io, vec};
|
||||||
|
|
||||||
|
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||||
|
use std::os::fd::AsRawFd;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||||
|
use tokio::time::timeout;
|
||||||
|
use tokio_stream::wrappers::TcpListenerStream;
|
||||||
|
use tracing::debug;
|
||||||
|
use tracing::log::info;
|
||||||
|
use url::Host;
|
||||||
|
|
||||||
|
fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<i32>) -> Result<(), anyhow::Error> {
|
||||||
|
socket.set_nodelay(true).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"cannot set no_delay on socket: {}",
|
||||||
|
io::Error::last_os_error()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if let Some(so_mark) = so_mark {
|
||||||
|
unsafe {
|
||||||
|
let optval: libc::c_int = *so_mark;
|
||||||
|
let ret = libc::setsockopt(
|
||||||
|
socket.as_raw_fd(),
|
||||||
|
libc::SOL_SOCKET,
|
||||||
|
libc::SO_MARK,
|
||||||
|
&optval as *const _ as *const libc::c_void,
|
||||||
|
std::mem::size_of_val(&optval) as libc::socklen_t,
|
||||||
|
);
|
||||||
|
|
||||||
|
if ret != 0 {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Cannot set SO_MARK on the connection {:?}",
|
||||||
|
io::Error::last_os_error()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
pub async fn connect(
|
||||||
|
host: &Host<String>,
|
||||||
|
port: u16,
|
||||||
|
so_mark: &Option<i32>,
|
||||||
|
connect_timeout: Duration,
|
||||||
|
) -> Result<TcpStream, anyhow::Error> {
|
||||||
|
info!("Opening TCP connection to {}:{}", host, port);
|
||||||
|
|
||||||
|
// TODO: Avoid allocation of vec by extracting the code that does the connection in a separate function
|
||||||
|
let socket_addrs: Vec<SocketAddr> = match host {
|
||||||
|
Host::Domain(domain) => tokio::net::lookup_host(format!("{}:{}", domain, port))
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("cannot resolve domain: {}", domain))?
|
||||||
|
.collect(),
|
||||||
|
Host::Ipv4(ip) => vec![SocketAddr::V4(SocketAddrV4::new(*ip, port))],
|
||||||
|
Host::Ipv6(ip) => vec![SocketAddr::V6(SocketAddrV6::new(*ip, port, 0, 0))],
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut cnx = None;
|
||||||
|
let mut last_err = None;
|
||||||
|
for addr in socket_addrs {
|
||||||
|
debug!("connecting to {}", addr);
|
||||||
|
|
||||||
|
let mut socket = match &addr {
|
||||||
|
SocketAddr::V4(_) => TcpSocket::new_v4()?,
|
||||||
|
SocketAddr::V6(_) => TcpSocket::new_v6()?,
|
||||||
|
};
|
||||||
|
|
||||||
|
configure_socket(&mut socket, so_mark)?;
|
||||||
|
match timeout(connect_timeout, socket.connect(addr)).await {
|
||||||
|
Ok(Ok(stream)) => {
|
||||||
|
cnx = Some(stream);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Ok(Err(err)) => {
|
||||||
|
debug!("Cannot connect to tcp endpoint {addr} reason {err}");
|
||||||
|
last_err = Some(err);
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
debug!(
|
||||||
|
"Cannot connect to tcp endpoint {addr} due to timeout of {}s elapsed",
|
||||||
|
connect_timeout.as_secs()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(cnx) = cnx {
|
||||||
|
Ok(cnx)
|
||||||
|
} else {
|
||||||
|
Err(anyhow!(
|
||||||
|
"Cannot connect to tcp endpoint {}:{} reason {:?}",
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
last_err
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_server(bind: SocketAddr) -> Result<TcpListenerStream, anyhow::Error> {
|
||||||
|
info!("Starting TCP server listening cnx on {}", bind);
|
||||||
|
|
||||||
|
let listener = TcpListener::bind(bind)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("Cannot create TCP server {:?}", bind))?;
|
||||||
|
Ok(TcpListenerStream::new(listener))
|
||||||
|
}
|
125
src/tls.rs
Normal file
125
src/tls.rs
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
use crate::{TlsClientConfig, TlsServerConfig, WsClientConfig};
|
||||||
|
use anyhow::{anyhow, Context};
|
||||||
|
use std::fs::File;
|
||||||
|
|
||||||
|
use std::io::BufReader;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::SystemTime;
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio_rustls::client::TlsStream;
|
||||||
|
use tokio_rustls::rustls::client::{ServerCertVerified, ServerCertVerifier};
|
||||||
|
|
||||||
|
use tokio_rustls::rustls::{Certificate, ClientConfig, PrivateKey, ServerName};
|
||||||
|
use tokio_rustls::{rustls, TlsAcceptor, TlsConnector};
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
pub struct NullVerifier;
|
||||||
|
impl ServerCertVerifier for NullVerifier {
|
||||||
|
fn verify_server_cert(
|
||||||
|
&self,
|
||||||
|
_end_entity: &Certificate,
|
||||||
|
_intermediates: &[Certificate],
|
||||||
|
_server_name: &ServerName,
|
||||||
|
_scts: &mut dyn Iterator<Item = &[u8]>,
|
||||||
|
_ocsp_response: &[u8],
|
||||||
|
_now: SystemTime,
|
||||||
|
) -> Result<ServerCertVerified, rustls::Error> {
|
||||||
|
Ok(ServerCertVerified::assertion())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_certificates_from_pem(path: &Path) -> anyhow::Result<Vec<Certificate>> {
|
||||||
|
let file = File::open(path)?;
|
||||||
|
let mut reader = BufReader::new(file);
|
||||||
|
let certs = rustls_pemfile::certs(&mut reader)?;
|
||||||
|
|
||||||
|
Ok(certs.into_iter().map(Certificate).collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_private_key_from_file(path: &Path) -> anyhow::Result<PrivateKey> {
|
||||||
|
let file = File::open(path)?;
|
||||||
|
let mut reader = BufReader::new(file);
|
||||||
|
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut reader)?;
|
||||||
|
|
||||||
|
match keys.len() {
|
||||||
|
0 => Err(anyhow!("No PKCS8-encoded private key found in {path:?}")),
|
||||||
|
1 => Ok(PrivateKey(keys.remove(0))),
|
||||||
|
_ => Err(anyhow!(
|
||||||
|
"More than one PKCS8-encoded private key found in {path:?}"
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tls_connector(
|
||||||
|
tls_cfg: &TlsClientConfig,
|
||||||
|
alpn_protocols: Option<Vec<Vec<u8>>>,
|
||||||
|
) -> anyhow::Result<TlsConnector> {
|
||||||
|
let mut root_store = rustls::RootCertStore::empty();
|
||||||
|
|
||||||
|
// Load system certificates and add them to the root store
|
||||||
|
let certs = rustls_native_certs::load_native_certs()
|
||||||
|
.with_context(|| "Cannot load system certificates")?;
|
||||||
|
for cert in certs {
|
||||||
|
root_store.add(&Certificate(cert.0)).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut config = ClientConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_root_certificates(root_store)
|
||||||
|
.with_no_client_auth();
|
||||||
|
|
||||||
|
// To bypass certificate verification
|
||||||
|
if !tls_cfg.tls_verify_certificate {
|
||||||
|
config
|
||||||
|
.dangerous()
|
||||||
|
.set_certificate_verifier(Arc::new(NullVerifier));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(alpn_protocols) = alpn_protocols {
|
||||||
|
config.alpn_protocols = alpn_protocols;
|
||||||
|
}
|
||||||
|
let tls_connector = TlsConnector::from(Arc::new(config));
|
||||||
|
Ok(tls_connector)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tls_acceptor(
|
||||||
|
tls_cfg: &TlsServerConfig,
|
||||||
|
alpn_protocols: Option<Vec<Vec<u8>>>,
|
||||||
|
) -> anyhow::Result<TlsAcceptor> {
|
||||||
|
let mut config = rustls::ServerConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_no_client_auth()
|
||||||
|
.with_single_cert(tls_cfg.tls_certificate.clone(), tls_cfg.tls_key.clone())
|
||||||
|
.with_context(|| "invalid tls certificate or private key")?;
|
||||||
|
|
||||||
|
if let Some(alpn_protocols) = alpn_protocols {
|
||||||
|
config.alpn_protocols = alpn_protocols;
|
||||||
|
}
|
||||||
|
Ok(TlsAcceptor::from(Arc::new(config)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn connect(
|
||||||
|
server_cfg: &WsClientConfig,
|
||||||
|
tls_cfg: &TlsClientConfig,
|
||||||
|
tcp_stream: TcpStream,
|
||||||
|
) -> anyhow::Result<TlsStream<TcpStream>> {
|
||||||
|
let sni = server_cfg.tls_server_name();
|
||||||
|
info!(
|
||||||
|
"Doing TLS handshake using sni {sni:?} with the server {}:{}",
|
||||||
|
server_cfg.remote_addr.0, server_cfg.remote_addr.1
|
||||||
|
);
|
||||||
|
|
||||||
|
let tls_connector = tls_connector(tls_cfg, Some(vec![b"http/1.1".to_vec()]))?;
|
||||||
|
let tls_stream = tls_connector
|
||||||
|
.connect(sni, tcp_stream)
|
||||||
|
.await
|
||||||
|
.with_context(|| {
|
||||||
|
format!(
|
||||||
|
"failed to do TLS handshake with the server {:?}",
|
||||||
|
server_cfg.remote_addr
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(tls_stream)
|
||||||
|
}
|
503
src/transport.rs
Normal file
503
src/transport.rs
Normal file
|
@ -0,0 +1,503 @@
|
||||||
|
#![allow(unused_imports)]
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::net::Ipv4Addr;
|
||||||
|
use std::ops::{Deref, Not};
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use crate::{tcp, tls, L4Protocol, LocalToRemote, WsClientConfig, WsServerConfig};
|
||||||
|
use anyhow::Context;
|
||||||
|
use fastwebsockets::upgrade::UpgradeFut;
|
||||||
|
use fastwebsockets::{
|
||||||
|
Frame, OpCode, Payload, WebSocket, WebSocketError, WebSocketRead, WebSocketWrite,
|
||||||
|
};
|
||||||
|
use futures_util::{pin_mut, StreamExt};
|
||||||
|
use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_VERSION, UPGRADE, X_FRAME_OPTIONS};
|
||||||
|
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
|
||||||
|
use hyper::server::conn::Http;
|
||||||
|
use hyper::service::service_fn;
|
||||||
|
use hyper::upgrade::Upgraded;
|
||||||
|
use hyper::{http, Body, Request, Response, StatusCode};
|
||||||
|
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation};
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use tokio::io::{
|
||||||
|
AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Interest, ReadHalf, WriteHalf,
|
||||||
|
};
|
||||||
|
use tokio::net::{TcpListener, TcpStream, UdpSocket};
|
||||||
|
use tokio::select;
|
||||||
|
use tokio::sync::oneshot;
|
||||||
|
use tokio::time::error::Elapsed;
|
||||||
|
use tokio::time::timeout;
|
||||||
|
|
||||||
|
use crate::udp::{MyUdpSocket, UdpStream};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tokio_rustls::TlsAcceptor;
|
||||||
|
use tracing::log::debug;
|
||||||
|
use tracing::{error, field, info, instrument, trace, warn, Instrument, Span};
|
||||||
|
use url::quirks::host;
|
||||||
|
use url::Host;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
struct SpawnExecutor;
|
||||||
|
|
||||||
|
impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
|
||||||
|
where
|
||||||
|
Fut: Future + Send + 'static,
|
||||||
|
Fut::Output: Send + 'static,
|
||||||
|
{
|
||||||
|
fn execute(&self, fut: Fut) {
|
||||||
|
tokio::task::spawn(fut);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
struct JwtTunnelConfig {
|
||||||
|
pub id: String,
|
||||||
|
pub p: L4Protocol,
|
||||||
|
pub r: String,
|
||||||
|
pub rp: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
static JWT_SECRET: &[u8; 15] = b"champignonfrais";
|
||||||
|
static JWT_KEY: Lazy<(Header, EncodingKey)> = Lazy::new(|| {
|
||||||
|
(
|
||||||
|
Header::new(Algorithm::HS256),
|
||||||
|
EncodingKey::from_secret(JWT_SECRET),
|
||||||
|
)
|
||||||
|
});
|
||||||
|
static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
|
||||||
|
let mut validation = Validation::new(Algorithm::HS256);
|
||||||
|
validation.required_spec_claims = HashSet::with_capacity(0);
|
||||||
|
(validation, DecodingKey::from_secret(JWT_SECRET))
|
||||||
|
});
|
||||||
|
|
||||||
|
pub async fn connect(
|
||||||
|
request_id: Uuid,
|
||||||
|
server_cfg: &WsClientConfig,
|
||||||
|
tunnel_cfg: &LocalToRemote,
|
||||||
|
) -> anyhow::Result<WebSocket<Upgraded>> {
|
||||||
|
let (host, port) = &server_cfg.remote_addr;
|
||||||
|
let tcp_stream = tcp::connect(
|
||||||
|
host,
|
||||||
|
*port,
|
||||||
|
&tunnel_cfg.socket_so_mark,
|
||||||
|
server_cfg.timeout_connect,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let data = JwtTunnelConfig {
|
||||||
|
id: request_id.to_string(),
|
||||||
|
p: tunnel_cfg.protocol,
|
||||||
|
r: tunnel_cfg.remote.0.to_string(),
|
||||||
|
rp: tunnel_cfg.remote.1,
|
||||||
|
};
|
||||||
|
let (alg, secret) = JWT_KEY.deref();
|
||||||
|
let mut req = Request::builder()
|
||||||
|
.method("GET")
|
||||||
|
.uri(format!(
|
||||||
|
"/{}/events?bearer={}",
|
||||||
|
&server_cfg.http_upgrade_path_prefix,
|
||||||
|
jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(),
|
||||||
|
))
|
||||||
|
.header(HOST, server_cfg.remote_addr.0.to_string())
|
||||||
|
.header(UPGRADE, "websocket")
|
||||||
|
.header(CONNECTION, "upgrade")
|
||||||
|
.header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key())
|
||||||
|
.header(SEC_WEBSOCKET_VERSION, "13")
|
||||||
|
.version(hyper::Version::HTTP_11);
|
||||||
|
|
||||||
|
for (k, v) in &server_cfg.http_headers {
|
||||||
|
req = req.header(k.clone(), v.clone());
|
||||||
|
}
|
||||||
|
if let Some(auth) = &server_cfg.http_upgrade_credentials {
|
||||||
|
req = req.header(AUTHORIZATION, auth.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let req = req.body(Body::empty()).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"failed to build HTTP request to contact the server {:?}",
|
||||||
|
server_cfg.remote_addr
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
debug!("with HTTP upgrade request {:?}", req);
|
||||||
|
|
||||||
|
let ws_handshake = match &server_cfg.tls {
|
||||||
|
None => fastwebsockets::handshake::client(&SpawnExecutor, req, tcp_stream).await,
|
||||||
|
Some(tls_cfg) => {
|
||||||
|
let tls_stream = tls::connect(server_cfg, tls_cfg, tcp_stream).await?;
|
||||||
|
fastwebsockets::handshake::client(&SpawnExecutor, req, tls_stream).await
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let (ws, _) = ws_handshake.with_context(|| {
|
||||||
|
format!(
|
||||||
|
"failed to do websocket handshake with the server {:?}",
|
||||||
|
server_cfg.remote_addr
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(ws)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn connect_to_server<R, W>(
|
||||||
|
request_id: Uuid,
|
||||||
|
server_config: &WsClientConfig,
|
||||||
|
remote_cfg: &LocalToRemote,
|
||||||
|
duplex_stream: (R, W),
|
||||||
|
) -> anyhow::Result<()>
|
||||||
|
where
|
||||||
|
R: AsyncRead + Send + 'static,
|
||||||
|
W: AsyncWrite + Send + 'static,
|
||||||
|
{
|
||||||
|
let mut ws = connect(request_id, server_config, remote_cfg).await?;
|
||||||
|
ws.set_auto_apply_mask(server_config.websocket_mask_frame);
|
||||||
|
|
||||||
|
let (ws_rx, ws_tx) = ws.split(tokio::io::split);
|
||||||
|
let (local_rx, local_tx) = duplex_stream;
|
||||||
|
let (close_tx, close_rx) = oneshot::channel::<()>();
|
||||||
|
|
||||||
|
// Forward local tx to websocket tx
|
||||||
|
let ping_frequency = server_config.websocket_ping_frequency;
|
||||||
|
tokio::spawn(
|
||||||
|
propagate_read(local_rx, ws_tx, close_tx, ping_frequency).instrument(Span::current()),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Forward websocket rx to local rx
|
||||||
|
let _ = propagate_write(local_tx, ws_rx, close_rx, server_config.timeout_connect).await;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn from_query(
|
||||||
|
server_config: &WsServerConfig,
|
||||||
|
query: &str,
|
||||||
|
) -> anyhow::Result<(
|
||||||
|
L4Protocol,
|
||||||
|
Host,
|
||||||
|
u16,
|
||||||
|
Pin<Box<dyn AsyncRead + Send>>,
|
||||||
|
Pin<Box<dyn AsyncWrite + Send>>,
|
||||||
|
)> {
|
||||||
|
let jwt: TokenData<JwtTunnelConfig> = match query.split_once('=') {
|
||||||
|
Some(("bearer", jwt)) => {
|
||||||
|
let (validation, decode_key) = JWT_DECODE.deref();
|
||||||
|
match jsonwebtoken::decode(jwt, decode_key, validation) {
|
||||||
|
Ok(jwt) => jwt,
|
||||||
|
err => {
|
||||||
|
error!("error while decoding jwt for tunnel info {:?}", err);
|
||||||
|
return Err(anyhow::anyhow!("Invalid upgrade request"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_err => return Err(anyhow::anyhow!("Invalid upgrade request")),
|
||||||
|
};
|
||||||
|
|
||||||
|
Span::current().record("id", jwt.claims.id);
|
||||||
|
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
|
||||||
|
if let Some(allowed_dests) = &server_config.restrict_to {
|
||||||
|
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
|
||||||
|
if allowed_dests
|
||||||
|
.iter()
|
||||||
|
.any(|dest| dest == &requested_dest)
|
||||||
|
.not()
|
||||||
|
{
|
||||||
|
warn!(
|
||||||
|
"Rejecting connection with not allowed destination: {}",
|
||||||
|
requested_dest
|
||||||
|
);
|
||||||
|
return Err(anyhow::anyhow!("Invalid upgrade request"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match jwt.claims.p {
|
||||||
|
L4Protocol::Udp { .. } => {
|
||||||
|
let host = Host::parse(&jwt.claims.r)?;
|
||||||
|
let cnx = Arc::new(UdpSocket::bind("[::]:0").await?);
|
||||||
|
cnx.connect((host.to_string(), jwt.claims.rp)).await?;
|
||||||
|
Ok((
|
||||||
|
L4Protocol::Udp { timeout: None },
|
||||||
|
host,
|
||||||
|
jwt.claims.rp,
|
||||||
|
Box::pin(MyUdpSocket::new(cnx.clone())),
|
||||||
|
Box::pin(MyUdpSocket::new(cnx)),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
L4Protocol::Tcp { .. } => {
|
||||||
|
let host = Host::parse(&jwt.claims.r)?;
|
||||||
|
let port = jwt.claims.rp;
|
||||||
|
let (rx, tx) = tcp::connect(
|
||||||
|
&host,
|
||||||
|
port,
|
||||||
|
&server_config.socket_so_mark,
|
||||||
|
Duration::from_secs(10),
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
.into_split();
|
||||||
|
|
||||||
|
Ok((jwt.claims.p, host, port, Box::pin(rx), Box::pin(tx)))
|
||||||
|
}
|
||||||
|
_ => Err(anyhow::anyhow!("Invalid upgrade request")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn server_upgrade(
|
||||||
|
server_config: Arc<WsServerConfig>,
|
||||||
|
mut req: Request<Body>,
|
||||||
|
) -> Result<Response<Body>, anyhow::Error> {
|
||||||
|
if let Some(x) = req.headers().get("X-Forwarded-For") {
|
||||||
|
info!("Request X-Forwarded-For: {:?}", x);
|
||||||
|
Span::current().record("forwarded_for", x.to_str().unwrap_or_default());
|
||||||
|
}
|
||||||
|
|
||||||
|
if !req.uri().path().ends_with("/events") {
|
||||||
|
warn!(
|
||||||
|
"Rejecting connection with bad upgrade request: {}",
|
||||||
|
req.uri()
|
||||||
|
);
|
||||||
|
return Ok(http::Response::builder()
|
||||||
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
.body(Body::from("Invalid upgrade request".to_string()))
|
||||||
|
.unwrap_or_default());
|
||||||
|
}
|
||||||
|
|
||||||
|
let (protocol, dest, port, local_rx, local_tx) =
|
||||||
|
match from_query(&server_config, req.uri().query().unwrap_or_default()).await {
|
||||||
|
Ok(ret) => ret,
|
||||||
|
Err(err) => {
|
||||||
|
warn!(
|
||||||
|
"Rejecting connection with bad upgrade request: {} {}",
|
||||||
|
err,
|
||||||
|
req.uri()
|
||||||
|
);
|
||||||
|
return Ok(http::Response::builder()
|
||||||
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
|
||||||
|
.unwrap_or_default());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
info!("connected to {:?} {:?} {:?}", protocol, dest, port);
|
||||||
|
let (response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) {
|
||||||
|
Ok(ret) => ret,
|
||||||
|
Err(err) => {
|
||||||
|
warn!(
|
||||||
|
"Rejecting connection with bad upgrade request: {} {}",
|
||||||
|
err,
|
||||||
|
req.uri()
|
||||||
|
);
|
||||||
|
return Ok(http::Response::builder()
|
||||||
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
|
||||||
|
.unwrap_or_default());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
tokio::spawn(
|
||||||
|
async move {
|
||||||
|
let (ws_rx, mut ws_tx) = fut.await.unwrap().split(tokio::io::split);
|
||||||
|
let (close_tx, close_rx) = oneshot::channel::<()>();
|
||||||
|
let connect_timeout = server_config.timeout_connect;
|
||||||
|
let ping_frequency = server_config
|
||||||
|
.websocket_ping_frequency
|
||||||
|
.unwrap_or(Duration::MAX);
|
||||||
|
ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame);
|
||||||
|
|
||||||
|
tokio::task::spawn(
|
||||||
|
propagate_write(local_tx, ws_rx, close_rx, connect_timeout)
|
||||||
|
.instrument(Span::current()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let _ = propagate_read(local_rx, ws_tx, close_tx, ping_frequency).await;
|
||||||
|
}
|
||||||
|
.instrument(Span::current()),
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(name="tunnel", level="info", skip_all, fields(id=field::Empty, remote=field::Empty, peer=field::Empty, forwarded_for=field::Empty))]
|
||||||
|
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
|
||||||
|
info!(
|
||||||
|
"Starting wstunnel server listening on {}",
|
||||||
|
server_config.bind
|
||||||
|
);
|
||||||
|
|
||||||
|
let config = server_config.clone();
|
||||||
|
let upgrade_fn = move |req: Request<Body>| server_upgrade(config.clone(), req);
|
||||||
|
|
||||||
|
let listener = TcpListener::bind(&server_config.bind).await?;
|
||||||
|
let tls_acceptor = if let Some(tls) = &server_config.tls {
|
||||||
|
Some(tls::tls_acceptor(tls, Some(vec![b"http/1.1".to_vec()]))?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let (stream, peer_addr) = listener.accept().await?;
|
||||||
|
let _ = stream.set_nodelay(true);
|
||||||
|
|
||||||
|
Span::current().record("peer", peer_addr.to_string());
|
||||||
|
info!("Accepting connection");
|
||||||
|
let upgrade_fn = upgrade_fn.clone();
|
||||||
|
// TLS
|
||||||
|
if let Some(tls_acceptor) = &tls_acceptor {
|
||||||
|
let tls_acceptor = tls_acceptor.clone();
|
||||||
|
let fut = async move {
|
||||||
|
info!("Doing TLS handshake");
|
||||||
|
let tls_stream = match tls_acceptor.accept(stream).await {
|
||||||
|
Ok(tls_stream) => tls_stream,
|
||||||
|
Err(err) => {
|
||||||
|
error!("error while accepting TLS connection {}", err);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let conn_fut = Http::new()
|
||||||
|
.http1_only(true)
|
||||||
|
.serve_connection(tls_stream, service_fn(upgrade_fn))
|
||||||
|
.with_upgrades();
|
||||||
|
|
||||||
|
if let Err(e) = conn_fut.await {
|
||||||
|
error!("Error while upgrading cnx to websocket: {:?}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.instrument(Span::current());
|
||||||
|
|
||||||
|
tokio::spawn(fut);
|
||||||
|
// Normal
|
||||||
|
} else {
|
||||||
|
let conn_fut = Http::new()
|
||||||
|
.http1_only(true)
|
||||||
|
.serve_connection(stream, service_fn(upgrade_fn))
|
||||||
|
.with_upgrades();
|
||||||
|
|
||||||
|
let fut = async move {
|
||||||
|
if let Err(e) = conn_fut.await {
|
||||||
|
error!("Error while upgrading cnx to weboscket: {:?}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.instrument(Span::current());
|
||||||
|
|
||||||
|
tokio::spawn(fut);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn propagate_read(
|
||||||
|
local_rx: impl AsyncRead,
|
||||||
|
mut ws_tx: WebSocketWrite<WriteHalf<Upgraded>>,
|
||||||
|
mut close_tx: oneshot::Sender<()>,
|
||||||
|
ping_frequency: Duration,
|
||||||
|
) -> Result<(), WebSocketError> {
|
||||||
|
let _guard = scopeguard::guard((), |_| {
|
||||||
|
info!("Closing local tx ==> websocket tx tunnel");
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut buffer = vec![0u8; 8 * 1024];
|
||||||
|
pin_mut!(local_rx);
|
||||||
|
loop {
|
||||||
|
let read = select! {
|
||||||
|
biased;
|
||||||
|
|
||||||
|
read_len = local_rx.read(buffer.as_mut_slice()) => read_len,
|
||||||
|
|
||||||
|
_ = close_tx.closed() => break,
|
||||||
|
|
||||||
|
_ = timeout(ping_frequency, futures_util::future::pending::<()>()) => {
|
||||||
|
debug!("sending ping to keep websocket connection alive");
|
||||||
|
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::Borrowed(&[]))).await?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let read_len = match read {
|
||||||
|
Ok(read_len) if read_len > 0 => read_len,
|
||||||
|
Ok(_) => break,
|
||||||
|
Err(err) => {
|
||||||
|
warn!(
|
||||||
|
"error while reading incoming bytes from local tx tunnel {}",
|
||||||
|
err
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
trace!("read {} bytes", read_len);
|
||||||
|
match ws_tx
|
||||||
|
.write_frame(Frame::binary(Payload::Borrowed(&buffer[..read_len])))
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(err) => {
|
||||||
|
warn!("error while writing to websocket tx tunnel {}", err);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if read_len == buffer.len() {
|
||||||
|
buffer.resize(read_len * 2, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn propagate_write(
|
||||||
|
local_tx: impl AsyncWrite,
|
||||||
|
mut ws_rx: WebSocketRead<ReadHalf<Upgraded>>,
|
||||||
|
mut close_rx: oneshot::Receiver<()>,
|
||||||
|
timeout_connect: Duration,
|
||||||
|
) -> Result<(), WebSocketError> {
|
||||||
|
let _guard = scopeguard::guard((), |_| {
|
||||||
|
info!("Closing local rx <== websocket rx tunnel");
|
||||||
|
});
|
||||||
|
let mut x = |x: Frame<'_>| {
|
||||||
|
debug!("frame {:?} {:?}", x.opcode, x.payload);
|
||||||
|
futures_util::future::ready(anyhow::Ok(()))
|
||||||
|
};
|
||||||
|
|
||||||
|
pin_mut!(local_tx);
|
||||||
|
loop {
|
||||||
|
let ret = select! {
|
||||||
|
biased;
|
||||||
|
ret = timeout(timeout_connect, ws_rx.read_frame(&mut x)) => ret,
|
||||||
|
|
||||||
|
_ = &mut close_rx => break,
|
||||||
|
};
|
||||||
|
|
||||||
|
let msg = match ret {
|
||||||
|
Ok(Ok(msg)) => msg,
|
||||||
|
Ok(Err(err)) => {
|
||||||
|
error!("error while reading from websocket rx {}", err);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
trace!("frame {:?}", err);
|
||||||
|
// TODO: Check that the connection is not closed (no easy method to know if a tx is closed ...)
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
trace!("frame {:?} {:?}", msg.opcode, msg.payload);
|
||||||
|
let ret = match msg.opcode {
|
||||||
|
OpCode::Continuation | OpCode::Text | OpCode::Binary => {
|
||||||
|
local_tx.write_all(msg.payload.as_ref()).await
|
||||||
|
}
|
||||||
|
OpCode::Close => break,
|
||||||
|
OpCode::Ping => Ok(()),
|
||||||
|
OpCode::Pong => Ok(()),
|
||||||
|
};
|
||||||
|
|
||||||
|
match ret {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(err) => {
|
||||||
|
error!("error while writing bytes to local for rx tunnel {}", err);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
241
src/udp.rs
Normal file
241
src/udp.rs
Normal file
|
@ -0,0 +1,241 @@
|
||||||
|
#![allow(unused_imports)]
|
||||||
|
|
||||||
|
use anyhow::Context;
|
||||||
|
use futures_util::future::join;
|
||||||
|
use futures_util::{stream, FutureExt, Stream};
|
||||||
|
use hyper::server;
|
||||||
|
use libc::poll;
|
||||||
|
use pin_project::{pin_project, pinned_drop};
|
||||||
|
use std::collections::hash_map::Entry;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::io;
|
||||||
|
use std::io::{Error, ErrorKind, IoSlice};
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::pin::{pin, Pin};
|
||||||
|
use std::sync::{Arc, RwLock, Weak};
|
||||||
|
use std::task::Poll;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf};
|
||||||
|
use tokio::net::UdpSocket;
|
||||||
|
use tokio::time::Sleep;
|
||||||
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
|
const DEFAULT_UDP_BUFFER_SIZE: usize = 8 * 1024;
|
||||||
|
|
||||||
|
struct UdpServer {
|
||||||
|
listener: UdpSocket,
|
||||||
|
std_socket: std::net::UdpSocket,
|
||||||
|
buffer: Vec<u8>,
|
||||||
|
peers: HashMap<SocketAddr, DuplexStream, ahash::RandomState>,
|
||||||
|
keys_to_delete: Arc<RwLock<Vec<SocketAddr>>>,
|
||||||
|
pub cnx_timeout: Option<Duration>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UdpServer {
|
||||||
|
pub fn new(listener: UdpSocket, timeout: Option<Duration>) -> Self {
|
||||||
|
let socket = listener.into_std().unwrap();
|
||||||
|
let listener = UdpSocket::from_std(socket.try_clone().unwrap()).unwrap();
|
||||||
|
Self {
|
||||||
|
listener,
|
||||||
|
std_socket: socket,
|
||||||
|
peers: HashMap::with_hasher(ahash::RandomState::new()),
|
||||||
|
buffer: vec![0u8; DEFAULT_UDP_BUFFER_SIZE],
|
||||||
|
keys_to_delete: Default::default(),
|
||||||
|
cnx_timeout: timeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clean_dead_keys(&mut self) {
|
||||||
|
let nb_key_to_delete = self.keys_to_delete.read().unwrap().len();
|
||||||
|
if nb_key_to_delete == 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!("Cleaning {} dead udp peers", nb_key_to_delete);
|
||||||
|
let mut keys_to_delete = self.keys_to_delete.write().unwrap();
|
||||||
|
for key in keys_to_delete.iter() {
|
||||||
|
self.peers.remove(key);
|
||||||
|
}
|
||||||
|
keys_to_delete.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clone_socket(&self) -> UdpSocket {
|
||||||
|
UdpSocket::from_std(self.std_socket.try_clone().unwrap()).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pin_project(PinnedDrop)]
|
||||||
|
pub struct UdpStream {
|
||||||
|
socket: UdpSocket,
|
||||||
|
peer: SocketAddr,
|
||||||
|
#[pin]
|
||||||
|
deadline: Option<Sleep>,
|
||||||
|
#[pin]
|
||||||
|
io: DuplexStream,
|
||||||
|
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsMut<DuplexStream> for UdpStream {
|
||||||
|
fn as_mut(&mut self) -> &mut DuplexStream {
|
||||||
|
&mut self.io
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pinned_drop]
|
||||||
|
impl PinnedDrop for UdpStream {
|
||||||
|
fn drop(self: Pin<&mut Self>) {
|
||||||
|
if let Some(keys_to_delete) = self.keys_to_delete.upgrade() {
|
||||||
|
keys_to_delete.write().unwrap().push(self.peer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for UdpStream {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
let project = self.project();
|
||||||
|
if let Some(deadline) = project.deadline.as_pin_mut() {
|
||||||
|
if deadline.poll(cx).is_ready() {
|
||||||
|
return Poll::Ready(Err(Error::new(
|
||||||
|
ErrorKind::TimedOut,
|
||||||
|
format!("UDP stream timeout with {}", project.peer),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
project.io.poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for UdpStream {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<Result<usize, Error>> {
|
||||||
|
self.socket.poll_send_to(cx, buf, self.peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> Poll<Result<(), Error>> {
|
||||||
|
self.socket.poll_send_ready(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut std::task::Context<'_>,
|
||||||
|
) -> Poll<Result<(), Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_server(
|
||||||
|
bind: SocketAddr,
|
||||||
|
timeout: Option<Duration>,
|
||||||
|
) -> Result<impl Stream<Item = io::Result<UdpStream>>, anyhow::Error> {
|
||||||
|
info!(
|
||||||
|
"Starting UDP server listening cnx on {} with cnx timeout of {}s",
|
||||||
|
bind,
|
||||||
|
timeout.unwrap_or(Duration::from_secs(0)).as_secs()
|
||||||
|
);
|
||||||
|
|
||||||
|
let listener = UdpSocket::bind(bind)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("Cannot create UDP server {:?}", bind))?;
|
||||||
|
|
||||||
|
let udp_server = UdpServer::new(listener, timeout);
|
||||||
|
let stream = stream::unfold(udp_server, |mut server| async {
|
||||||
|
loop {
|
||||||
|
server.clean_dead_keys();
|
||||||
|
let (nb_bytes, peer_addr) = match server.listener.recv_from(&mut server.buffer).await {
|
||||||
|
Ok(ret) => ret,
|
||||||
|
Err(err) => {
|
||||||
|
error!("Cannot read from UDP server. Closing server: {}", err);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match server.peers.entry(peer_addr) {
|
||||||
|
Entry::Occupied(mut peer) => {
|
||||||
|
let ret = peer.get_mut().write_all(&server.buffer[0..nb_bytes]).await;
|
||||||
|
if let Err(err) = ret {
|
||||||
|
info!("Peer {:?} disconnected {:?}", peer_addr, err);
|
||||||
|
peer.remove();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Entry::Vacant(peer) => {
|
||||||
|
let (mut rx, tx) = tokio::io::duplex(DEFAULT_UDP_BUFFER_SIZE);
|
||||||
|
rx.write_all(&server.buffer[0..nb_bytes])
|
||||||
|
.await
|
||||||
|
.unwrap_or_default(); // should never fail
|
||||||
|
peer.insert(rx);
|
||||||
|
let udp_client = UdpStream {
|
||||||
|
socket: server.clone_socket(),
|
||||||
|
peer: peer_addr,
|
||||||
|
deadline: server
|
||||||
|
.cnx_timeout
|
||||||
|
.and_then(|timeout| tokio::time::Instant::now().checked_add(timeout))
|
||||||
|
.map(tokio::time::sleep_until),
|
||||||
|
keys_to_delete: Arc::downgrade(&server.keys_to_delete),
|
||||||
|
io: tx,
|
||||||
|
};
|
||||||
|
return Some((Ok(udp_client), (server)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MyUdpSocket {
|
||||||
|
socket: Arc<UdpSocket>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MyUdpSocket {
|
||||||
|
pub fn new(socket: Arc<UdpSocket>) -> Self {
|
||||||
|
Self { socket }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for MyUdpSocket {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }
|
||||||
|
.poll_recv_from(cx, buf)
|
||||||
|
.map(|x| x.map(|_| ()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for MyUdpSocket {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<Result<usize, Error>> {
|
||||||
|
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }.poll_send(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut std::task::Context<'_>,
|
||||||
|
) -> Poll<Result<(), Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut std::task::Context<'_>,
|
||||||
|
) -> Poll<Result<(), Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue