wstunnel/src/tunnel/client.rs

223 lines
7.1 KiB
Rust
Raw Normal View History

2023-12-01 21:25:01 +00:00
use super::{to_host_port, JwtTunnelConfig, JWT_KEY};
2023-11-26 17:22:28 +00:00
use crate::{LocalToRemote, WsClientConfig};
use anyhow::{anyhow, Context};
use fastwebsockets::WebSocket;
2023-10-28 13:55:14 +00:00
use futures_util::pin_mut;
2023-12-01 21:25:01 +00:00
use hyper::header::{AUTHORIZATION, COOKIE, SEC_WEBSOCKET_VERSION, UPGRADE};
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
use hyper::upgrade::Upgraded;
2023-12-01 21:25:01 +00:00
use hyper::{Body, Request, Response};
use std::future::Future;
use std::ops::{Deref, DerefMut};
2023-10-28 13:55:14 +00:00
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::oneshot;
2023-10-28 13:55:14 +00:00
use tokio_stream::{Stream, StreamExt};
use tracing::log::debug;
2023-10-28 13:55:14 +00:00
use tracing::{error, span, Instrument, Level, Span};
2023-12-01 21:25:01 +00:00
use url::{Host, Url};
use uuid::Uuid;
/// Executor handed to `fastwebsockets::handshake::client` in `connect`.
///
/// Every future it receives is spawned on the ambient tokio runtime and left
/// to run detached (its `JoinHandle` is dropped immediately).
struct SpawnExecutor;

impl<F> hyper::rt::Executor<F> for SpawnExecutor
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    fn execute(&self, fut: F) {
        // Dropping the JoinHandle detaches the task; nobody awaits its output.
        drop(tokio::task::spawn(fut));
    }
}
2023-10-28 13:55:14 +00:00
fn tunnel_to_jwt_token(request_id: Uuid, tunnel: &LocalToRemote) -> String {
let cfg = JwtTunnelConfig::new(request_id, tunnel);
let (alg, secret) = JWT_KEY.deref();
jsonwebtoken::encode(alg, &cfg, secret).unwrap_or_default()
}
pub async fn connect(
request_id: Uuid,
client_cfg: &WsClientConfig,
tunnel_cfg: &LocalToRemote,
2023-12-01 21:25:01 +00:00
) -> anyhow::Result<(WebSocket<Upgraded>, Response<Body>)> {
2023-10-28 13:55:14 +00:00
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
Ok(tcp_stream) => tcp_stream,
2023-10-30 07:13:38 +00:00
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}"))?,
};
let mut req = Request::builder()
.method("GET")
.uri(format!(
"/{}/events?bearer={}",
&client_cfg.http_upgrade_path_prefix,
2023-10-28 13:55:14 +00:00
tunnel_to_jwt_token(request_id, tunnel_cfg)
))
2023-10-27 07:15:15 +00:00
.header(HOST, &client_cfg.http_header_host)
.header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade")
.header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key())
.header(SEC_WEBSOCKET_VERSION, "13")
.version(hyper::Version::HTTP_11);
for (k, v) in &client_cfg.http_headers {
2023-10-26 20:04:41 +00:00
req = req.header(k, v);
}
if let Some(auth) = &client_cfg.http_upgrade_credentials {
2023-10-26 20:04:41 +00:00
req = req.header(AUTHORIZATION, auth);
}
let req = req.body(Body::empty()).with_context(|| {
format!(
"failed to build HTTP request to contact the server {:?}",
client_cfg.remote_addr
)
})?;
debug!("with HTTP upgrade request {:?}", req);
2023-10-28 13:55:14 +00:00
let transport = pooled_cnx.deref_mut().take().unwrap();
2023-12-01 21:25:01 +00:00
let (ws, response) = fastwebsockets::handshake::client(&SpawnExecutor, req, transport)
2023-10-28 13:55:14 +00:00
.await
2023-10-30 07:13:38 +00:00
.with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?;
2023-12-01 21:25:01 +00:00
Ok((ws, response))
}
2023-10-28 13:55:14 +00:00
async fn connect_to_server<R, W>(
request_id: Uuid,
client_cfg: &WsClientConfig,
remote_cfg: &LocalToRemote,
duplex_stream: (R, W),
) -> anyhow::Result<()>
where
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
2023-12-01 21:25:01 +00:00
let (mut ws, _) = connect(request_id, client_cfg, remote_cfg).await?;
ws.set_auto_apply_mask(client_cfg.websocket_mask_frame);
let (ws_rx, ws_tx) = ws.split(tokio::io::split);
let (local_rx, local_tx) = duplex_stream;
let (close_tx, close_rx) = oneshot::channel::<()>();
// Forward local tx to websocket tx
let ping_frequency = client_cfg.websocket_ping_frequency;
tokio::spawn(
super::io::propagate_read(local_rx, ws_tx, close_tx, Some(ping_frequency)).instrument(Span::current()),
);
// Forward websocket rx to local rx
let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await;
Ok(())
}
2023-10-28 13:55:14 +00:00
pub async fn run_tunnel<T, R, W>(
client_config: Arc<WsClientConfig>,
tunnel_cfg: LocalToRemote,
incoming_cnx: T,
) -> anyhow::Result<()>
where
T: Stream<Item = anyhow::Result<((R, W), (Host, u16))>>,
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
pin_mut!(incoming_cnx);
while let Some(Ok((cnx_stream, remote_dest))) = incoming_cnx.next().await {
let request_id = Uuid::now_v7();
let span = span!(
Level::INFO,
"tunnel",
id = request_id.to_string(),
remote = format!("{}:{}", remote_dest.0, remote_dest.1)
);
let mut tunnel_cfg = tunnel_cfg.clone();
tunnel_cfg.remote = remote_dest;
let client_config = client_config.clone();
let tunnel = async move {
let _ = connect_to_server(request_id, &client_config, &tunnel_cfg, cnx_stream)
.await
.map_err(|err| error!("{:?}", err));
}
.instrument(span);
tokio::spawn(tunnel);
}
Ok(())
}
2023-11-26 14:47:49 +00:00
2023-11-26 17:22:28 +00:00
pub async fn run_reverse_tunnel<F, Fut, T>(
2023-11-26 14:47:49 +00:00
client_config: Arc<WsClientConfig>,
mut tunnel_cfg: LocalToRemote,
2023-11-26 17:22:28 +00:00
connect_to_dest: F,
) -> anyhow::Result<()>
where
2023-12-01 21:25:01 +00:00
F: Fn((Host, u16)) -> Fut,
2023-11-26 17:22:28 +00:00
Fut: Future<Output = anyhow::Result<T>>,
T: AsyncRead + AsyncWrite + Send + 'static,
{
2023-11-26 14:47:49 +00:00
// Invert local with remote
2023-12-01 21:25:01 +00:00
let remote_ori = tunnel_cfg.remote;
tunnel_cfg.remote = to_host_port(tunnel_cfg.local);
2023-11-26 14:47:49 +00:00
loop {
let client_config = client_config.clone();
let request_id = Uuid::now_v7();
let span = span!(
Level::INFO,
"tunnel",
id = request_id.to_string(),
remote = format!("{}:{}", tunnel_cfg.remote.0, tunnel_cfg.remote.1)
);
let _span = span.enter();
// Correctly configure tunnel cfg
2023-12-01 21:25:01 +00:00
let (mut ws, response) = connect(request_id, &client_config, &tunnel_cfg)
2023-11-26 14:47:49 +00:00
.instrument(span.clone())
.await?;
ws.set_auto_apply_mask(client_config.websocket_mask_frame);
// Connect to endpoint
2023-12-01 21:25:01 +00:00
let remote: (Host, u16) = response
.headers()
.get(COOKIE)
.and_then(|h| {
h.to_str()
.ok()
.and_then(|s| Url::parse(s).ok())
.and_then(|url| match (url.host(), url.port()) {
(Some(h), Some(p)) => Some((h.to_owned(), p)),
_ => None,
})
})
.unwrap_or(remote_ori.clone());
let stream = connect_to_dest(remote.clone()).instrument(span.clone()).await;
2023-11-26 14:47:49 +00:00
let stream = match stream {
Ok(s) => s,
Err(err) => {
error!("Cannot connect to {remote:?}: {err:?}");
continue;
}
};
let (local_rx, local_tx) = tokio::io::split(stream);
let (ws_rx, ws_tx) = ws.split(tokio::io::split);
let (close_tx, close_rx) = oneshot::channel::<()>();
let tunnel = async move {
let ping_frequency = client_config.websocket_ping_frequency;
tokio::spawn(
super::io::propagate_read(local_rx, ws_tx, close_tx, Some(ping_frequency)).instrument(Span::current()),
);
// Forward websocket rx to local rx
let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await;
}
.instrument(span.clone());
tokio::spawn(tunnel);
}
}