feat: Add connection pool to speed up creation of tunnel

This commit is contained in:
Σrebe - Romain GERARD 2023-10-23 19:11:12 +02:00
parent a9420e97fd
commit 6570c857ad
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
11 changed files with 715 additions and 583 deletions

2
.gitignore vendored
View file

@ -2,6 +2,8 @@
# will have compiled files and executables
debug/
target/
artifacts/
dist/
# These are backup files generated by rustfmt
**/*.rs.bk

16
Cargo.lock generated
View file

@ -136,6 +136,19 @@ version = "0.21.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2"
[[package]]
name = "bb8"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98b4b0f25f18bcdc3ac72bdb486ed0acf7e185221fd4dc985bc15db5800b0ba2"
dependencies = [
"async-trait",
"futures-channel",
"futures-util",
"parking_lot",
"tokio",
]
[[package]]
name = "bitflags"
version = "1.3.2"
@ -361,6 +374,7 @@ version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533"
dependencies = [
"futures-channel",
"futures-core",
"futures-macro",
"futures-task",
@ -1527,7 +1541,9 @@ version = "7.1.1"
dependencies = [
"ahash",
"anyhow",
"async-trait",
"base64",
"bb8",
"clap",
"fast-socks5",
"fastwebsockets",

View file

@ -36,6 +36,9 @@ base64 = "0.21.4"
serde = { version = "1.0.189", features = ["derive"] }
log = "0.4.20"
bb8 = { version = "0.8", features = [] }
async-trait = "0.1.74"
[target.'cfg(target_family = "unix")'.dependencies]
tokio-fd = "0.3.0"

View file

@ -22,7 +22,7 @@ My inspiration came from [this project](https://www.npmjs.com/package/wstunnel)
* Static tunneling (TCP and UDP)
* Dynamic tunneling (socks5 proxy)
* Support for http proxy (when behind one)
* Support for tls/https server (with embedded self signed certificate, see comment in the example section)
* Support for tls/https server (with embedded self-signed certificate, see comment in the example section)
* Support IPv6
* **Standalone binary for linux x86_64** (so just cp it where you want) [here](https://github.com/erebe/wstunnel/releases)
* Standalone archive for windows
@ -64,9 +64,16 @@ Options:
'tcp://1212:google.com:443' => listen locally on tcp on port 1212 and forward to google.com on port 443
'udp://1212:1.1.1.1:53' => listen locally on udp on port 1212 and forward to cloudflare dns 1.1.1.1 on port 53
'udp://1212:1.1.1.1:53?timeout_sec=10' timeout_sec on udp force close the tunnel after 10sec. Set it to 0 to disable the timeout [default: 30]
'socks5://1212' => listen locally with socks5 on port 1212 and forward dynamically requested tunnel
'socks5://1212?socket_so_mark=2' => each tunnel can have the socket_so_mark option, cf explanation on server command
'socks5://[::1]:1212' => listen locally with socks5 on port 1212 and forward dynamically requested tunnel
'stdio://google.com:443' => listen for data from stdio, mainly for `ssh -o ProxyCommand="wstunnel client -L stdio://%h:%p ws://localhost:8080" my-server`
--socket-so-mark <INT>
(linux only) Mark network packet with SO_MARK sockoption with the specified value.
You need to use {root, sudo, capabilities} to run wstunnel when using this option
-c, --connection-min-idle <INT>
Client will maintain a pool of open connection to the server, in order to speed up the connection process.
This option set the maximum number of connection that will be kept open.
This is useful if you plan to create/destroy a lot of tunnel (i.e: with socks5 to navigate with a browser)
It will avoid the latency of doing tcp + tls handshake with the server [default: 0]
--tls-sni-override <DOMAIN_NAME>
Domain name that will be use as SNI during TLS handshake
Warning: If you are behind a CDN (i.e: Cloudflare) you must set this domain also in the http HOST header.
@ -145,11 +152,12 @@ wstunnel server ws://[::]:8080
This will create a websocket server listening on any interface on port 8080.
On the client side use this command to forward traffic through the websocket tunnel
```bash
wstunnel client -L socks5://8888 ws://myRemoteHost:8080
wstunnel client -L socks5://127.0.0.1:8888 --connection-min-idle 10 ws://myRemoteHost:8080
```
This command will create a socks5 server listening on port 8888 of a loopback interface and will forward traffic.
With firefox you can setup a proxy using this tunnel, by setting in networking preferences 127.0.0.1:8888 and selecting socks5 proxy
Be sure to check the option `Proxy DNS when using SOCKS v5` for the server to resolve DNS name and not your local machine.
or with curl
@ -160,7 +168,7 @@ curl -x socks5h://127.0.0.1:8888 http://google.com/
### As proxy command for SSH
You can specify `stdio` as source port on the client side if you wish to use wstunnel as part of a proxy command for ssh
```
```bash
ssh -o ProxyCommand="wstunnel client -L stdio://%h:%p ws://localhost:8080" my-server
```
@ -169,7 +177,7 @@ An other useful example is when you want to bypass an http proxy (a corporate pr
The most reliable way to do it is to use wstunnel as described below
Start your wstunnel server with tls activated
```
```bash
wstunnel server wss://[::]:443 --restrict-to 127.0.0.1:22
```
The server will listen on any interface using port 443 (https) and restrict traffic to be forwarded only to the ssh daemon.
@ -180,16 +188,32 @@ It was made in order to add the least possible overhead while still being compli
**Do not rely on wstunnel to protect your privacy, as it only forwards traffic that is already secure by design (ex: https)**
Now on the client side start the client with
```
wstunnel client -L tcp://9999:127.0.0.1:22 -p mycorporateproxy:8080 wss://myRemoteHost:443
```bash
wstunnel client -L tcp://9999:127.0.0.1:22 -p http://mycorporateproxy:8080 wss://myRemoteHost:443
```
It will start a tcp server on port 9999 that will contact the corporate proxy, negotiate a tls connection with the remote host and forward traffic to the ssh daemon on the remote host.
You may now access your server from your local machine on ssh by using
```
```bash
ssh -p 9999 login@127.0.0.1
```
### How to secure the access of your wstunnel server
Generate a secret, let's say `h3GywpDrP6gJEdZ6xbJbZZVFmvFZDCa4KcRd`
Now start you server with the following command
```bash
wstunnel server --restrict-http-upgrade-path-prefix h3GywpDrP6gJEdZ6xbJbZZVFmvFZDCa4KcRd wss://[::]:443
```
And start your client with
```bash
wstunnel client --http-upgrade-path-prefix h3GywpDrP6gJEdZ6xbJbZZVFmvFZDCa4KcRd ... wss://myRemoteHost
```
Now your wstunnel server, will only accept connection if the client specify the correct path prefix during the upgrade request.
### Wireguard and wstunnel
https://kirill888.github.io/notes/wireguard-via-websocket/

View file

@ -4,7 +4,7 @@ mod socks5;
mod stdio;
mod tcp;
mod tls;
mod transport;
mod tunnel;
mod udp;
use base64::Engine;
@ -54,12 +54,29 @@ struct Client {
/// 'tcp://1212:google.com:443' => listen locally on tcp on port 1212 and forward to google.com on port 443
/// 'udp://1212:1.1.1.1:53' => listen locally on udp on port 1212 and forward to cloudflare dns 1.1.1.1 on port 53
/// 'udp://1212:1.1.1.1:53?timeout_sec=10' timeout_sec on udp force close the tunnel after 10sec. Set it to 0 to disable the timeout [default: 30]
/// 'socks5://1212' => listen locally with socks5 on port 1212 and forward dynamically requested tunnel
/// 'socks5://1212?socket_so_mark=2' => each tunnel can have the socket_so_mark option, cf explanation on server command
/// 'socks5://[::1]:1212' => listen locally with socks5 on port 1212 and forward dynamically requested tunnel
/// 'stdio://google.com:443' => listen for data from stdio, mainly for `ssh -o ProxyCommand="wstunnel client -L stdio://%h:%p ws://localhost:8080" my-server`
#[arg(short='L', long, value_name = "{tcp,udp,socks5,stdio}://[BIND:]PORT:HOST:PORT", value_parser = parse_tunnel_arg, verbatim_doc_comment)]
local_to_remote: Vec<LocalToRemote>,
/// (linux only) Mark network packet with SO_MARK sockoption with the specified value.
/// You need to use {root, sudo, capabilities} to run wstunnel when using this option
#[arg(long, value_name = "INT", verbatim_doc_comment)]
socket_so_mark: Option<i32>,
/// Client will maintain a pool of open connection to the server, in order to speed up the connection process.
/// This option set the maximum number of connection that will be kept open.
/// This is useful if you plan to create/destroy a lot of tunnel (i.e: with socks5 to navigate with a browser)
/// It will avoid the latency of doing tcp + tls handshake with the server
#[arg(
short = 'c',
long,
value_name = "INT",
default_value = "0",
verbatim_doc_comment
)]
connection_min_idle: u32,
/// Domain name that will be use as SNI during TLS handshake
/// Warning: If you are behind a CDN (i.e: Cloudflare) you must set this domain also in the http HOST header.
/// or it will be flagged as fishy and your request rejected
@ -163,7 +180,6 @@ enum LocalProtocol {
#[derive(Clone, Debug)]
pub struct LocalToRemote {
socket_so_mark: Option<i32>,
local_protocol: LocalProtocol,
local: SocketAddr,
remote: (Host<String>, u16),
@ -262,11 +278,8 @@ fn parse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
match &arg[..6] {
"tcp://" => {
let (local_bind, remaining) = parse_local_bind(&arg[6..])?;
let (dest_host, dest_port, options) = parse_tunnel_dest(remaining)?;
let (dest_host, dest_port, _options) = parse_tunnel_dest(remaining)?;
Ok(LocalToRemote {
socket_so_mark: options
.get("socket_so_mark")
.and_then(|x| x.parse::<i32>().ok()),
local_protocol: LocalProtocol::Tcp,
local: local_bind,
remote: (dest_host, dest_port),
@ -288,9 +301,6 @@ fn parse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
.unwrap_or(Some(Duration::from_secs(30)));
Ok(LocalToRemote {
socket_so_mark: options
.get("socket_so_mark")
.and_then(|x| x.parse::<i32>().ok()),
local_protocol: LocalProtocol::Udp { timeout },
local: local_bind,
remote: (dest_host, dest_port),
@ -300,22 +310,16 @@ fn parse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
"socks5:/" => {
let (local_bind, remaining) = parse_local_bind(&arg[9..])?;
let x = format!("0.0.0.0:0?{}", remaining);
let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?;
let (dest_host, dest_port, _options) = parse_tunnel_dest(&x)?;
Ok(LocalToRemote {
socket_so_mark: options
.get("socket_so_mark")
.and_then(|x| x.parse::<i32>().ok()),
local_protocol: LocalProtocol::Socks5,
local: local_bind,
remote: (dest_host, dest_port),
})
}
"stdio://" => {
let (dest_host, dest_port, options) = parse_tunnel_dest(&arg[8..])?;
let (dest_host, dest_port, _options) = parse_tunnel_dest(&arg[8..])?;
Ok(LocalToRemote {
socket_so_mark: options
.get("socket_so_mark")
.and_then(|x| x.parse::<i32>().ok()),
local_protocol: LocalProtocol::Stdio,
local: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(0), 0)),
remote: (dest_host, dest_port),
@ -441,6 +445,7 @@ impl Debug for WsServerConfig {
#[derive(Clone, Debug)]
pub struct WsClientConfig {
pub remote_addr: (Host<String>, u16),
pub socket_so_mark: Option<i32>,
pub tls: Option<TlsClientConfig>,
pub http_upgrade_path_prefix: String,
pub http_upgrade_credentials: Option<HeaderValue>,
@ -449,6 +454,7 @@ pub struct WsClientConfig {
pub websocket_ping_frequency: Duration,
pub websocket_mask_frame: bool,
pub http_proxy: Option<Url>,
cnx_pool: Option<bb8::Pool<WsClientConfig>>,
}
impl WsClientConfig {
@ -459,6 +465,10 @@ impl WsClientConfig {
}
}
pub fn cnx_pool(&self) -> &bb8::Pool<WsClientConfig> {
self.cnx_pool.as_ref().unwrap()
}
pub fn websocket_host_url(&self) -> String {
format!("{}:{}", self.remote_addr.0, self.remote_addr.1)
}
@ -518,11 +528,12 @@ async fn main() {
_ => panic!("invalid scheme in server url {}", args.remote_addr.scheme()),
};
let client_config = Arc::new(WsClientConfig {
let mut client_config = WsClientConfig {
remote_addr: (
args.remote_addr.host().unwrap().to_owned(),
args.remote_addr.port_or_known_default().unwrap(),
),
socket_so_mark: args.socket_so_mark,
tls,
http_upgrade_path_prefix: args.http_upgrade_path_prefix,
http_upgrade_credentials: args.http_upgrade_credentials,
@ -533,11 +544,23 @@ async fn main() {
.unwrap_or(Duration::from_secs(30)),
websocket_mask_frame: args.websocket_mask_frame,
http_proxy: args.http_proxy,
});
cnx_pool: None,
};
let pool = bb8::Pool::builder()
.max_size(1000)
.min_idle(Some(args.connection_min_idle))
.max_lifetime(Some(Duration::from_secs(30)))
.retry_connection(true)
.build(client_config.clone())
.await
.unwrap();
client_config.cnx_pool = Some(pool);
let client_config = Arc::new(client_config);
// Start tunnels
for tunnel in args.local_to_remote.into_iter() {
let server_config = client_config.clone();
let client_config = client_config.clone();
match &tunnel.local_protocol {
LocalProtocol::Tcp => {
@ -551,7 +574,7 @@ async fn main() {
.map_ok(move |stream| (stream.into_split(), remote.clone()));
tokio::spawn(async move {
if let Err(err) = run_tunnel(server_config, tunnel, server).await {
if let Err(err) = run_tunnel(client_config, tunnel, server).await {
error!("{:?}", err);
}
});
@ -567,7 +590,7 @@ async fn main() {
.map_ok(move |stream| (tokio::io::split(stream), remote.clone()));
tokio::spawn(async move {
if let Err(err) = run_tunnel(server_config, tunnel, server).await {
if let Err(err) = run_tunnel(client_config, tunnel, server).await {
error!("{:?}", err);
}
});
@ -581,7 +604,7 @@ async fn main() {
.map_ok(|(stream, remote_dest)| (stream.into_split(), remote_dest));
tokio::spawn(async move {
if let Err(err) = run_tunnel(server_config, tunnel, server).await {
if let Err(err) = run_tunnel(client_config, tunnel, server).await {
error!("{:?}", err);
}
});
@ -594,7 +617,7 @@ async fn main() {
});
tokio::spawn(async move {
if let Err(err) = run_tunnel(
server_config,
client_config,
tunnel.clone(),
stream::once(async move { Ok((server, tunnel.remote)) }),
)
@ -646,7 +669,7 @@ async fn main() {
};
info!("{:?}", server_config);
transport::run_server(Arc::new(server_config))
tunnel::server::run_server(Arc::new(server_config))
.await
.unwrap_or_else(|err| {
panic!("Cannot start wstunnel server: {:?}", err);
@ -658,7 +681,7 @@ async fn main() {
}
async fn run_tunnel<T, R, W>(
server_config: Arc<WsClientConfig>,
client_config: Arc<WsClientConfig>,
tunnel: LocalToRemote,
incoming_cnx: T,
) -> anyhow::Result<()>
@ -676,14 +699,18 @@ where
id = request_id.to_string(),
remote = format!("{}:{}", remote_dest.0, remote_dest.1)
);
let server_config = server_config.clone();
let server_config = client_config.clone();
let mut tunnel = tunnel.clone();
tunnel.remote = remote_dest;
tokio::spawn(
async move {
let ret =
transport::connect_to_server(request_id, &server_config, &tunnel, cnx_stream)
let ret = tunnel::client::connect_to_server(
request_id,
&server_config,
&tunnel,
cnx_stream,
)
.await;
if let Err(ret) = ret {

View file

@ -100,14 +100,14 @@ pub fn tls_acceptor(
}
pub async fn connect(
server_cfg: &WsClientConfig,
client_cfg: &WsClientConfig,
tls_cfg: &TlsClientConfig,
tcp_stream: TcpStream,
) -> anyhow::Result<TlsStream<TcpStream>> {
let sni = server_cfg.tls_server_name();
let sni = client_cfg.tls_server_name();
info!(
"Doing TLS handshake using sni {sni:?} with the server {}:{}",
server_cfg.remote_addr.0, server_cfg.remote_addr.1
client_cfg.remote_addr.0, client_cfg.remote_addr.1
);
let tls_connector = tls_connector(tls_cfg, Some(vec![b"http/1.1".to_vec()]))?;
@ -117,7 +117,7 @@ pub async fn connect(
.with_context(|| {
format!(
"failed to do TLS handshake with the server {:?}",
server_cfg.remote_addr
client_cfg.remote_addr
)
})?;

View file

@ -1,538 +0,0 @@
use std::cmp::min;
use std::collections::HashSet;
use std::future::Future;
use std::ops::{Deref, Not};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use crate::{tcp, tls, LocalProtocol, LocalToRemote, WsClientConfig, WsServerConfig};
use anyhow::Context;
use fastwebsockets::{
Frame, OpCode, Payload, WebSocket, WebSocketError, WebSocketRead, WebSocketWrite,
};
use futures_util::pin_mut;
use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_VERSION, UPGRADE};
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
use hyper::server::conn::Http;
use hyper::service::service_fn;
use hyper::upgrade::Upgraded;
use hyper::{http, Body, Request, Response, StatusCode};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation};
use once_cell::sync::Lazy;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::net::{TcpListener, UdpSocket};
use tokio::select;
use tokio::sync::oneshot;
use tokio::time::timeout;
use crate::udp::MyUdpSocket;
use serde::{Deserialize, Serialize};
use tracing::log::debug;
use tracing::{error, info, instrument, span, trace, warn, Instrument, Level, Span};
use url::Host;
use uuid::Uuid;
struct SpawnExecutor;
impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
where
Fut: Future + Send + 'static,
Fut::Output: Send + 'static,
{
fn execute(&self, fut: Fut) {
tokio::task::spawn(fut);
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct JwtTunnelConfig {
pub id: String,
pub p: LocalProtocol,
pub r: String,
pub rp: u16,
}
static JWT_SECRET: &[u8; 15] = b"champignonfrais";
static JWT_KEY: Lazy<(Header, EncodingKey)> = Lazy::new(|| {
(
Header::new(Algorithm::HS256),
EncodingKey::from_secret(JWT_SECRET),
)
});
static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::with_capacity(0);
(validation, DecodingKey::from_secret(JWT_SECRET))
});
pub async fn connect(
request_id: Uuid,
server_cfg: &WsClientConfig,
tunnel_cfg: &LocalToRemote,
) -> anyhow::Result<WebSocket<Upgraded>> {
let (host, port) = &server_cfg.remote_addr;
let tcp_stream = if let Some(http_proxy) = &server_cfg.http_proxy {
tcp::connect_with_http_proxy(
http_proxy,
host,
*port,
&tunnel_cfg.socket_so_mark,
server_cfg.timeout_connect,
)
.await?
} else {
tcp::connect(
host,
*port,
&tunnel_cfg.socket_so_mark,
server_cfg.timeout_connect,
)
.await?
};
let data = JwtTunnelConfig {
id: request_id.to_string(),
p: match tunnel_cfg.local_protocol {
LocalProtocol::Tcp => LocalProtocol::Tcp,
LocalProtocol::Udp { .. } => tunnel_cfg.local_protocol,
LocalProtocol::Stdio => LocalProtocol::Tcp,
LocalProtocol::Socks5 => LocalProtocol::Tcp,
},
r: tunnel_cfg.remote.0.to_string(),
rp: tunnel_cfg.remote.1,
};
let (alg, secret) = JWT_KEY.deref();
let mut req = Request::builder()
.method("GET")
.uri(format!(
"/{}/events?bearer={}",
&server_cfg.http_upgrade_path_prefix,
jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(),
))
.header(HOST, server_cfg.remote_addr.0.to_string())
.header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade")
.header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key())
.header(SEC_WEBSOCKET_VERSION, "13")
.version(hyper::Version::HTTP_11);
for (k, v) in &server_cfg.http_headers {
req = req.header(k.clone(), v.clone());
}
if let Some(auth) = &server_cfg.http_upgrade_credentials {
req = req.header(AUTHORIZATION, auth.clone());
}
let req = req.body(Body::empty()).with_context(|| {
format!(
"failed to build HTTP request to contact the server {:?}",
server_cfg.remote_addr
)
})?;
debug!("with HTTP upgrade request {:?}", req);
let ws_handshake = match &server_cfg.tls {
None => fastwebsockets::handshake::client(&SpawnExecutor, req, tcp_stream).await,
Some(tls_cfg) => {
let tls_stream = tls::connect(server_cfg, tls_cfg, tcp_stream).await?;
fastwebsockets::handshake::client(&SpawnExecutor, req, tls_stream).await
}
};
let (ws, _) = ws_handshake.with_context(|| {
format!(
"failed to do websocket handshake with the server {:?}",
server_cfg.remote_addr
)
})?;
Ok(ws)
}
pub async fn connect_to_server<R, W>(
request_id: Uuid,
client_cfg: &WsClientConfig,
remote_cfg: &LocalToRemote,
duplex_stream: (R, W),
) -> anyhow::Result<()>
where
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
let mut ws = connect(request_id, client_cfg, remote_cfg).await?;
ws.set_auto_apply_mask(client_cfg.websocket_mask_frame);
let (ws_rx, ws_tx) = ws.split(tokio::io::split);
let (local_rx, local_tx) = duplex_stream;
let (close_tx, close_rx) = oneshot::channel::<()>();
// Forward local tx to websocket tx
let ping_frequency = client_cfg.websocket_ping_frequency;
tokio::spawn(
propagate_read(local_rx, ws_tx, close_tx, ping_frequency).instrument(Span::current()),
);
// Forward websocket rx to local rx
let _ = propagate_write(local_tx, ws_rx, close_rx).await;
Ok(())
}
async fn from_query(
server_config: &WsServerConfig,
query: &str,
) -> anyhow::Result<(
LocalProtocol,
Host,
u16,
Pin<Box<dyn AsyncRead + Send>>,
Pin<Box<dyn AsyncWrite + Send>>,
)> {
let jwt: TokenData<JwtTunnelConfig> = match query.split_once('=') {
Some(("bearer", jwt)) => {
let (validation, decode_key) = JWT_DECODE.deref();
match jsonwebtoken::decode(jwt, decode_key, validation) {
Ok(jwt) => jwt,
err => {
error!("error while decoding jwt for tunnel info {:?}", err);
return Err(anyhow::anyhow!("Invalid upgrade request"));
}
}
}
_err => return Err(anyhow::anyhow!("Invalid upgrade request")),
};
Span::current().record("id", jwt.claims.id);
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
if let Some(allowed_dests) = &server_config.restrict_to {
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
if allowed_dests
.iter()
.any(|dest| dest == &requested_dest)
.not()
{
warn!(
"Rejecting connection with not allowed destination: {}",
requested_dest
);
return Err(anyhow::anyhow!("Invalid upgrade request"));
}
}
match jwt.claims.p {
LocalProtocol::Udp { .. } => {
let host = Host::parse(&jwt.claims.r)?;
let cnx = Arc::new(UdpSocket::bind("[::]:0").await?);
cnx.connect((host.to_string(), jwt.claims.rp)).await?;
Ok((
LocalProtocol::Udp { timeout: None },
host,
jwt.claims.rp,
Box::pin(MyUdpSocket::new(cnx.clone())),
Box::pin(MyUdpSocket::new(cnx)),
))
}
LocalProtocol::Tcp { .. } => {
let host = Host::parse(&jwt.claims.r)?;
let port = jwt.claims.rp;
let (rx, tx) = tcp::connect(
&host,
port,
&server_config.socket_so_mark,
Duration::from_secs(10),
)
.await?
.into_split();
Ok((jwt.claims.p, host, port, Box::pin(rx), Box::pin(tx)))
}
_ => Err(anyhow::anyhow!("Invalid upgrade request")),
}
}
async fn server_upgrade(
server_config: Arc<WsServerConfig>,
mut req: Request<Body>,
) -> Result<Response<Body>, anyhow::Error> {
if let Some(x) = req.headers().get("X-Forwarded-For") {
info!("Request X-Forwarded-For: {:?}", x);
Span::current().record("forwarded_for", x.to_str().unwrap_or_default());
}
if !req.uri().path().ends_with("/events") {
warn!(
"Rejecting connection with bad upgrade request: {}",
req.uri()
);
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid upgrade request"))
.unwrap_or_default());
}
if let Some(path_prefix) = &server_config.restrict_http_upgrade_path_prefix {
let path = req.uri().path();
let min_len = min(path.len(), 1);
let max_len = min(path.len(), path_prefix.len() + 1);
if &path[0..min_len] != "/"
|| &path[min_len..max_len] != path_prefix.as_str()
|| !path[max_len..].starts_with('/')
{
warn!(
"Rejecting connection with bad path prefix in upgrade request: {}",
req.uri()
);
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid upgrade request"))
.unwrap_or_default());
}
}
let (protocol, dest, port, local_rx, local_tx) =
match from_query(&server_config, req.uri().query().unwrap_or_default()).await {
Ok(ret) => ret,
Err(err) => {
warn!(
"Rejecting connection with bad upgrade request: {} {}",
err,
req.uri()
);
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
.unwrap_or_default());
}
};
info!("connected to {:?} {:?} {:?}", protocol, dest, port);
let (response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) {
Ok(ret) => ret,
Err(err) => {
warn!(
"Rejecting connection with bad upgrade request: {} {}",
err,
req.uri()
);
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
.unwrap_or_default());
}
};
tokio::spawn(
async move {
let (ws_rx, mut ws_tx) = match fut.await {
Ok(ws) => ws.split(tokio::io::split),
Err(err) => {
error!("Error during http upgrade request: {:?}", err);
return;
}
};
let (close_tx, close_rx) = oneshot::channel::<()>();
let ping_frequency = server_config
.websocket_ping_frequency
.unwrap_or(Duration::MAX);
ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame);
tokio::task::spawn(
propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current()),
);
let _ = propagate_read(local_rx, ws_tx, close_tx, ping_frequency).await;
}
.instrument(Span::current()),
);
Ok(response)
}
#[instrument(name="tunnel", level="info", skip_all, fields(id=tracing::field::Empty, remote=tracing::field::Empty, peer=tracing::field::Empty, forwarded_for=tracing::field::Empty))]
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
info!(
"Starting wstunnel server listening on {}",
server_config.bind
);
let config = server_config.clone();
let upgrade_fn = move |req: Request<Body>| server_upgrade(config.clone(), req);
let listener = TcpListener::bind(&server_config.bind).await?;
let tls_acceptor = if let Some(tls) = &server_config.tls {
Some(tls::tls_acceptor(tls, Some(vec![b"http/1.1".to_vec()]))?)
} else {
None
};
loop {
let (stream, peer_addr) = listener.accept().await?;
let _ = stream.set_nodelay(true);
let span = span!(
Level::INFO,
"tunnel",
id = tracing::field::Empty,
remote = tracing::field::Empty,
peer = peer_addr.to_string(),
forwarded_for = tracing::field::Empty
);
info!("Accepting connection");
let upgrade_fn = upgrade_fn.clone();
// TLS
if let Some(tls_acceptor) = &tls_acceptor {
let tls_acceptor = tls_acceptor.clone();
let fut = async move {
info!("Doing TLS handshake");
let tls_stream = match tls_acceptor.accept(stream).await {
Ok(tls_stream) => tls_stream,
Err(err) => {
error!("error while accepting TLS connection {}", err);
return;
}
};
let conn_fut = Http::new()
.http1_only(true)
.serve_connection(tls_stream, service_fn(upgrade_fn))
.with_upgrades();
if let Err(e) = conn_fut.await {
error!("Error while upgrading cnx to websocket: {:?}", e);
}
}
.instrument(span);
tokio::spawn(fut);
// Normal
} else {
let conn_fut = Http::new()
.http1_only(true)
.serve_connection(stream, service_fn(upgrade_fn))
.with_upgrades();
let fut = async move {
if let Err(e) = conn_fut.await {
error!("Error while upgrading cnx to weboscket: {:?}", e);
}
}
.instrument(span);
tokio::spawn(fut);
};
}
}
async fn propagate_read(
local_rx: impl AsyncRead,
mut ws_tx: WebSocketWrite<WriteHalf<Upgraded>>,
mut close_tx: oneshot::Sender<()>,
ping_frequency: Duration,
) -> Result<(), WebSocketError> {
let _guard = scopeguard::guard((), |_| {
info!("Closing local tx ==> websocket tx tunnel");
});
let mut buffer = vec![0u8; 8 * 1024];
pin_mut!(local_rx);
loop {
let read = select! {
biased;
read_len = local_rx.read(buffer.as_mut_slice()) => read_len,
_ = close_tx.closed() => break,
_ = timeout(ping_frequency, futures_util::future::pending::<()>()) => {
debug!("sending ping to keep websocket connection alive");
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::Borrowed(&[]))).await?;
continue;
}
};
let read_len = match read {
Ok(read_len) if read_len > 0 => read_len,
Ok(_) => break,
Err(err) => {
warn!(
"error while reading incoming bytes from local tx tunnel {}",
err
);
break;
}
};
trace!("read {} bytes", read_len);
match ws_tx
.write_frame(Frame::binary(Payload::Borrowed(&buffer[..read_len])))
.await
{
Ok(_) => {}
Err(err) => {
warn!("error while writing to websocket tx tunnel {}", err);
break;
}
}
if read_len == buffer.len() {
buffer.resize(read_len * 2, 0);
}
}
let _ = ws_tx.write_frame(Frame::close(1000, &[])).await;
Ok(())
}
async fn propagate_write(
local_tx: impl AsyncWrite,
mut ws_rx: WebSocketRead<ReadHalf<Upgraded>>,
mut close_rx: oneshot::Receiver<()>,
) -> Result<(), WebSocketError> {
let _guard = scopeguard::guard((), |_| {
info!("Closing local rx <== websocket rx tunnel");
});
let mut x = |x: Frame<'_>| {
debug!("frame {:?} {:?}", x.opcode, x.payload);
futures_util::future::ready(anyhow::Ok(()))
};
pin_mut!(local_tx);
loop {
let ret = select! {
biased;
ret = ws_rx.read_frame(&mut x) => ret,
_ = &mut close_rx => break,
};
let msg = match ret {
Ok(msg) => msg,
Err(err) => {
error!("error while reading from websocket rx {}", err);
break;
}
};
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
let ret = match msg.opcode {
OpCode::Continuation | OpCode::Text | OpCode::Binary => {
local_tx.write_all(msg.payload.as_ref()).await
}
OpCode::Close => break,
OpCode::Ping => Ok(()),
OpCode::Pong => Ok(()),
};
match ret {
Ok(_) => {}
Err(err) => {
error!("error while writing bytes to local for rx tunnel {}", err);
break;
}
}
}
Ok(())
}

129
src/tunnel/client.rs Normal file
View file

@ -0,0 +1,129 @@
use super::{JwtTunnelConfig, MaybeTlsStream, JWT_KEY};
use crate::{LocalProtocol, LocalToRemote, WsClientConfig};
use anyhow::{anyhow, Context};
use fastwebsockets::WebSocket;
use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_VERSION, UPGRADE};
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
use hyper::upgrade::Upgraded;
use hyper::{Body, Request};
use std::future::Future;
use std::ops::{Deref, DerefMut};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::oneshot;
use tracing::log::debug;
use tracing::{Instrument, Span};
use uuid::Uuid;
struct SpawnExecutor;
impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
where
Fut: Future + Send + 'static,
Fut::Output: Send + 'static,
{
fn execute(&self, fut: Fut) {
tokio::task::spawn(fut);
}
}
pub async fn connect(
request_id: Uuid,
client_cfg: &WsClientConfig,
tunnel_cfg: &LocalToRemote,
) -> anyhow::Result<WebSocket<Upgraded>> {
let mut tcp_stream = match client_cfg.cnx_pool().get().await {
Ok(tcp_stream) => tcp_stream,
Err(err) => Err(anyhow!(
"failed to get a connection to the server from the pool: {err:?}"
))?,
};
let data = JwtTunnelConfig {
id: request_id.to_string(),
p: match tunnel_cfg.local_protocol {
LocalProtocol::Tcp => LocalProtocol::Tcp,
LocalProtocol::Udp { .. } => tunnel_cfg.local_protocol,
LocalProtocol::Stdio => LocalProtocol::Tcp,
LocalProtocol::Socks5 => LocalProtocol::Tcp,
},
r: tunnel_cfg.remote.0.to_string(),
rp: tunnel_cfg.remote.1,
};
let (alg, secret) = JWT_KEY.deref();
let mut req = Request::builder()
.method("GET")
.uri(format!(
"/{}/events?bearer={}",
&client_cfg.http_upgrade_path_prefix,
jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(),
))
.header(HOST, client_cfg.remote_addr.0.to_string())
.header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade")
.header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key())
.header(SEC_WEBSOCKET_VERSION, "13")
.version(hyper::Version::HTTP_11);
for (k, v) in &client_cfg.http_headers {
req = req.header(k.clone(), v.clone());
}
if let Some(auth) = &client_cfg.http_upgrade_credentials {
req = req.header(AUTHORIZATION, auth.clone());
}
let req = req.body(Body::empty()).with_context(|| {
format!(
"failed to build HTTP request to contact the server {:?}",
client_cfg.remote_addr
)
})?;
debug!("with HTTP upgrade request {:?}", req);
let ws_handshake = match tcp_stream.deref_mut() {
MaybeTlsStream::Plain(cnx) => {
fastwebsockets::handshake::client(&SpawnExecutor, req, cnx.take().unwrap()).await
}
MaybeTlsStream::Tls(cnx) => {
fastwebsockets::handshake::client(&SpawnExecutor, req, cnx.take().unwrap()).await
}
};
let (ws, _) = ws_handshake.with_context(|| {
format!(
"failed to do websocket handshake with the server {:?}",
client_cfg.remote_addr
)
})?;
Ok(ws)
}
pub async fn connect_to_server<R, W>(
request_id: Uuid,
client_cfg: &WsClientConfig,
remote_cfg: &LocalToRemote,
duplex_stream: (R, W),
) -> anyhow::Result<()>
where
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
let mut ws = connect(request_id, client_cfg, remote_cfg).await?;
ws.set_auto_apply_mask(client_cfg.websocket_mask_frame);
let (ws_rx, ws_tx) = ws.split(tokio::io::split);
let (local_rx, local_tx) = duplex_stream;
let (close_tx, close_rx) = oneshot::channel::<()>();
// Forward local tx to websocket tx
let ping_frequency = client_cfg.websocket_ping_frequency;
tokio::spawn(
super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency)
.instrument(Span::current()),
);
// Forward websocket rx to local rx
let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await;
Ok(())
}

123
src/tunnel/io.rs Normal file
View file

@ -0,0 +1,123 @@
use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite};
use futures_util::pin_mut;
use hyper::upgrade::Upgraded;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::select;
use tokio::sync::oneshot;
use tokio::time::timeout;
use tracing::log::debug;
use tracing::{error, info, trace, warn};
pub(super) async fn propagate_read(
local_rx: impl AsyncRead,
mut ws_tx: WebSocketWrite<WriteHalf<Upgraded>>,
mut close_tx: oneshot::Sender<()>,
ping_frequency: Duration,
) -> Result<(), WebSocketError> {
let _guard = scopeguard::guard((), |_| {
info!("Closing local tx ==> websocket tx tunnel");
});
let mut buffer = vec![0u8; 8 * 1024];
pin_mut!(local_rx);
loop {
let read = select! {
biased;
read_len = local_rx.read(buffer.as_mut_slice()) => read_len,
_ = close_tx.closed() => break,
_ = timeout(ping_frequency, futures_util::future::pending::<()>()) => {
debug!("sending ping to keep websocket connection alive");
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::Borrowed(&[]))).await?;
continue;
}
};
let read_len = match read {
Ok(read_len) if read_len > 0 => read_len,
Ok(_) => break,
Err(err) => {
warn!(
"error while reading incoming bytes from local tx tunnel {}",
err
);
break;
}
};
trace!("read {} bytes", read_len);
match ws_tx
.write_frame(Frame::binary(Payload::Borrowed(&buffer[..read_len])))
.await
{
Ok(_) => {}
Err(err) => {
warn!("error while writing to websocket tx tunnel {}", err);
break;
}
}
if read_len == buffer.len() {
buffer.resize(read_len * 2, 0);
}
}
let _ = ws_tx.write_frame(Frame::close(1000, &[])).await;
Ok(())
}
pub(super) async fn propagate_write(
local_tx: impl AsyncWrite,
mut ws_rx: WebSocketRead<ReadHalf<Upgraded>>,
mut close_rx: oneshot::Receiver<()>,
) -> Result<(), WebSocketError> {
let _guard = scopeguard::guard((), |_| {
info!("Closing local rx <== websocket rx tunnel");
});
let mut x = |x: Frame<'_>| {
debug!("frame {:?} {:?}", x.opcode, x.payload);
futures_util::future::ready(anyhow::Ok(()))
};
pin_mut!(local_tx);
loop {
let ret = select! {
biased;
ret = ws_rx.read_frame(&mut x) => ret,
_ = &mut close_rx => break,
};
let msg = match ret {
Ok(msg) => msg,
Err(err) => {
error!("error while reading from websocket rx {}", err);
break;
}
};
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
let ret = match msg.opcode {
OpCode::Continuation | OpCode::Text | OpCode::Binary => {
local_tx.write_all(msg.payload.as_ref()).await
}
OpCode::Close => break,
OpCode::Ping => Ok(()),
OpCode::Pong => Ok(()),
};
match ret {
Ok(_) => {}
Err(err) => {
error!("error while writing bytes to local for rx tunnel {}", err);
break;
}
}
}
Ok(())
}

83
src/tunnel/mod.rs Normal file
View file

@ -0,0 +1,83 @@
pub mod client;
mod io;
pub mod server;
use crate::{tcp, tls, LocalProtocol, WsClientConfig};
use async_trait::async_trait;
use bb8::ManageConnection;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
#[derive(Debug, Clone, Serialize, Deserialize)]
struct JwtTunnelConfig {
pub id: String,
pub p: LocalProtocol,
pub r: String,
pub rp: u16,
}
static JWT_SECRET: &[u8; 15] = b"champignonfrais";
static JWT_KEY: Lazy<(Header, EncodingKey)> = Lazy::new(|| {
(
Header::new(Algorithm::HS256),
EncodingKey::from_secret(JWT_SECRET),
)
});
static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::with_capacity(0);
(validation, DecodingKey::from_secret(JWT_SECRET))
});
pub enum MaybeTlsStream {
Plain(Option<TcpStream>),
Tls(Option<TlsStream<TcpStream>>),
}
impl MaybeTlsStream {
pub fn is_used(&self) -> bool {
match self {
MaybeTlsStream::Plain(Some(_)) | MaybeTlsStream::Tls(Some(_)) => false,
MaybeTlsStream::Plain(None) | MaybeTlsStream::Tls(None) => true,
}
}
}
#[async_trait]
impl ManageConnection for WsClientConfig {
type Connection = MaybeTlsStream;
type Error = anyhow::Error;
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
let (host, port) = &self.remote_addr;
let so_mark = &self.socket_so_mark;
let timeout = self.timeout_connect;
let tcp_stream = if let Some(http_proxy) = &self.http_proxy {
tcp::connect_with_http_proxy(http_proxy, host, *port, so_mark, timeout).await?
} else {
tcp::connect(host, *port, so_mark, timeout).await?
};
match &self.tls {
None => Ok(MaybeTlsStream::Plain(Some(tcp_stream))),
Some(tls_cfg) => {
let tls_stream = tls::connect(self, tls_cfg, tcp_stream).await?;
Ok(MaybeTlsStream::Tls(Some(tls_stream)))
}
}
}
async fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> {
Ok(())
}
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
conn.is_used()
}
}

263
src/tunnel/server.rs Normal file
View file

@ -0,0 +1,263 @@
use std::cmp::min;
use std::ops::{Deref, Not};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use super::{JwtTunnelConfig, JWT_DECODE};
use crate::udp::MyUdpSocket;
use crate::{tcp, tls, LocalProtocol, WsServerConfig};
use hyper::server::conn::Http;
use hyper::service::service_fn;
use hyper::{http, Body, Request, Response, StatusCode};
use jsonwebtoken::TokenData;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, UdpSocket};
use tokio::sync::oneshot;
use tracing::{error, info, instrument, span, warn, Instrument, Level, Span};
use url::Host;
async fn from_query(
server_config: &WsServerConfig,
query: &str,
) -> anyhow::Result<(
LocalProtocol,
Host,
u16,
Pin<Box<dyn AsyncRead + Send>>,
Pin<Box<dyn AsyncWrite + Send>>,
)> {
let jwt: TokenData<JwtTunnelConfig> = match query.split_once('=') {
Some(("bearer", jwt)) => {
let (validation, decode_key) = JWT_DECODE.deref();
match jsonwebtoken::decode(jwt, decode_key, validation) {
Ok(jwt) => jwt,
err => {
error!("error while decoding jwt for tunnel info {:?}", err);
return Err(anyhow::anyhow!("Invalid upgrade request"));
}
}
}
_err => return Err(anyhow::anyhow!("Invalid upgrade request")),
};
Span::current().record("id", jwt.claims.id);
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
if let Some(allowed_dests) = &server_config.restrict_to {
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
if allowed_dests
.iter()
.any(|dest| dest == &requested_dest)
.not()
{
warn!(
"Rejecting connection with not allowed destination: {}",
requested_dest
);
return Err(anyhow::anyhow!("Invalid upgrade request"));
}
}
match jwt.claims.p {
LocalProtocol::Udp { .. } => {
let host = Host::parse(&jwt.claims.r)?;
let cnx = Arc::new(UdpSocket::bind("[::]:0").await?);
cnx.connect((host.to_string(), jwt.claims.rp)).await?;
Ok((
LocalProtocol::Udp { timeout: None },
host,
jwt.claims.rp,
Box::pin(MyUdpSocket::new(cnx.clone())),
Box::pin(MyUdpSocket::new(cnx)),
))
}
LocalProtocol::Tcp { .. } => {
let host = Host::parse(&jwt.claims.r)?;
let port = jwt.claims.rp;
let (rx, tx) = tcp::connect(
&host,
port,
&server_config.socket_so_mark,
Duration::from_secs(10),
)
.await?
.into_split();
Ok((jwt.claims.p, host, port, Box::pin(rx), Box::pin(tx)))
}
_ => Err(anyhow::anyhow!("Invalid upgrade request")),
}
}
async fn server_upgrade(
server_config: Arc<WsServerConfig>,
mut req: Request<Body>,
) -> Result<Response<Body>, anyhow::Error> {
if let Some(x) = req.headers().get("X-Forwarded-For") {
info!("Request X-Forwarded-For: {:?}", x);
Span::current().record("forwarded_for", x.to_str().unwrap_or_default());
}
if !req.uri().path().ends_with("/events") {
warn!(
"Rejecting connection with bad upgrade request: {}",
req.uri()
);
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid upgrade request"))
.unwrap_or_default());
}
if let Some(path_prefix) = &server_config.restrict_http_upgrade_path_prefix {
let path = req.uri().path();
let min_len = min(path.len(), 1);
let max_len = min(path.len(), path_prefix.len() + 1);
if &path[0..min_len] != "/"
|| &path[min_len..max_len] != path_prefix.as_str()
|| !path[max_len..].starts_with('/')
{
warn!(
"Rejecting connection with bad path prefix in upgrade request: {}",
req.uri()
);
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid upgrade request"))
.unwrap_or_default());
}
}
let (protocol, dest, port, local_rx, local_tx) =
match from_query(&server_config, req.uri().query().unwrap_or_default()).await {
Ok(ret) => ret,
Err(err) => {
warn!(
"Rejecting connection with bad upgrade request: {} {}",
err,
req.uri()
);
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
.unwrap_or_default());
}
};
info!("connected to {:?} {:?} {:?}", protocol, dest, port);
let (response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) {
Ok(ret) => ret,
Err(err) => {
warn!(
"Rejecting connection with bad upgrade request: {} {}",
err,
req.uri()
);
return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
.unwrap_or_default());
}
};
tokio::spawn(
async move {
let (ws_rx, mut ws_tx) = match fut.await {
Ok(ws) => ws.split(tokio::io::split),
Err(err) => {
error!("Error during http upgrade request: {:?}", err);
return;
}
};
let (close_tx, close_rx) = oneshot::channel::<()>();
let ping_frequency = server_config
.websocket_ping_frequency
.unwrap_or(Duration::MAX);
ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame);
tokio::task::spawn(
super::io::propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current()),
);
let _ = super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency).await;
}
.instrument(Span::current()),
);
Ok(response)
}
#[instrument(name="tunnel", level="info", skip_all, fields(id=tracing::field::Empty, remote=tracing::field::Empty, peer=tracing::field::Empty, forwarded_for=tracing::field::Empty))]
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
info!(
"Starting wstunnel server listening on {}",
server_config.bind
);
let config = server_config.clone();
let upgrade_fn = move |req: Request<Body>| server_upgrade(config.clone(), req);
let listener = TcpListener::bind(&server_config.bind).await?;
let tls_acceptor = if let Some(tls) = &server_config.tls {
Some(tls::tls_acceptor(tls, Some(vec![b"http/1.1".to_vec()]))?)
} else {
None
};
loop {
let (stream, peer_addr) = listener.accept().await?;
let _ = stream.set_nodelay(true);
let span = span!(
Level::INFO,
"tunnel",
id = tracing::field::Empty,
remote = tracing::field::Empty,
peer = peer_addr.to_string(),
forwarded_for = tracing::field::Empty
);
info!("Accepting connection");
let upgrade_fn = upgrade_fn.clone();
// TLS
if let Some(tls_acceptor) = &tls_acceptor {
let tls_acceptor = tls_acceptor.clone();
let fut = async move {
info!("Doing TLS handshake");
let tls_stream = match tls_acceptor.accept(stream).await {
Ok(tls_stream) => tls_stream,
Err(err) => {
error!("error while accepting TLS connection {}", err);
return;
}
};
let conn_fut = Http::new()
.http1_only(true)
.serve_connection(tls_stream, service_fn(upgrade_fn))
.with_upgrades();
if let Err(e) = conn_fut.await {
error!("Error while upgrading cnx to websocket: {:?}", e);
}
}
.instrument(span);
tokio::spawn(fut);
// Normal
} else {
let conn_fut = Http::new()
.http1_only(true)
.serve_connection(stream, service_fn(upgrade_fn))
.with_upgrades();
let fut = async move {
if let Err(e) = conn_fut.await {
error!("Error while upgrading cnx to weboscket: {:?}", e);
}
}
.instrument(span);
tokio::spawn(fut);
};
}
}