chore(client): Cleanup implementation

This commit is contained in:
Σrebe - Romain GERARD 2023-10-28 15:55:14 +02:00
parent 7fe768a078
commit 2499d993e2
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
3 changed files with 156 additions and 100 deletions

View file

@ -9,7 +9,7 @@ mod udp;
use base64::Engine;
use clap::Parser;
use futures_util::{pin_mut, stream, Stream, StreamExt, TryStreamExt};
use futures_util::{stream, TryStreamExt};
use hyper::header::HOST;
use hyper::http::{HeaderName, HeaderValue};
use serde::{Deserialize, Serialize};
@ -22,16 +22,14 @@ use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::rustls::server::DnsName;
use tokio_rustls::rustls::{Certificate, PrivateKey, ServerName};
use tracing::{error, info, span, Instrument, Level};
use tracing::{error, info, Level};
use tracing_subscriber::EnvFilter;
use url::{Host, Url};
use uuid::Uuid;
/// Use the websockets protocol to tunnel {TCP,UDP} traffic
/// wsTunnelClient <---> wsTunnelServer <---> RemoteHost
@ -588,7 +586,9 @@ async fn main() {
.map_ok(move |stream| (stream.into_split(), remote.clone()));
tokio::spawn(async move {
if let Err(err) = run_tunnel(client_config, tunnel, server).await {
if let Err(err) =
tunnel::client::run_tunnel(client_config, tunnel, server).await
{
error!("{:?}", err);
}
});
@ -604,7 +604,9 @@ async fn main() {
.map_ok(move |stream| (tokio::io::split(stream), remote.clone()));
tokio::spawn(async move {
if let Err(err) = run_tunnel(client_config, tunnel, server).await {
if let Err(err) =
tunnel::client::run_tunnel(client_config, tunnel, server).await
{
error!("{:?}", err);
}
});
@ -618,7 +620,9 @@ async fn main() {
.map_ok(|(stream, remote_dest)| (stream.into_split(), remote_dest));
tokio::spawn(async move {
if let Err(err) = run_tunnel(client_config, tunnel, server).await {
if let Err(err) =
tunnel::client::run_tunnel(client_config, tunnel, server).await
{
error!("{:?}", err);
}
});
@ -630,7 +634,7 @@ async fn main() {
panic!("Cannot start STDIO server: {}", err);
});
tokio::spawn(async move {
if let Err(err) = run_tunnel(
if let Err(err) = tunnel::client::run_tunnel(
client_config,
tunnel.clone(),
stream::once(async move { Ok((server, tunnel.remote)) }),
@ -693,49 +697,3 @@ async fn main() {
tokio::signal::ctrl_c().await.unwrap();
}
async fn run_tunnel<T, R, W>(
client_config: Arc<WsClientConfig>,
tunnel: LocalToRemote,
incoming_cnx: T,
) -> anyhow::Result<()>
where
T: Stream<Item = anyhow::Result<((R, W), (Host, u16))>>,
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
pin_mut!(incoming_cnx);
while let Some(Ok((cnx_stream, remote_dest))) = incoming_cnx.next().await {
let request_id = Uuid::now_v7();
let span = span!(
Level::INFO,
"tunnel",
id = request_id.to_string(),
remote = format!("{}:{}", remote_dest.0, remote_dest.1)
);
let server_config = client_config.clone();
let mut tunnel = tunnel.clone();
tunnel.remote = remote_dest;
tokio::spawn(
async move {
let ret = tunnel::client::connect_to_server(
request_id,
&server_config,
&tunnel,
cnx_stream,
)
.await;
if let Err(ret) = ret {
error!("{:?}", ret);
}
anyhow::Ok(())
}
.instrument(span.clone()),
);
}
Ok(())
}

View file

@ -1,18 +1,22 @@
use super::{JwtTunnelConfig, MaybeTlsStream, JWT_KEY};
use crate::{LocalProtocol, LocalToRemote, WsClientConfig};
use super::{JwtTunnelConfig, JWT_KEY};
use crate::{LocalToRemote, WsClientConfig};
use anyhow::{anyhow, Context};
use fastwebsockets::WebSocket;
use futures_util::pin_mut;
use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_VERSION, UPGRADE};
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
use hyper::upgrade::Upgraded;
use hyper::{Body, Request};
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::oneshot;
use tokio_stream::{Stream, StreamExt};
use tracing::log::debug;
use tracing::{Instrument, Span};
use tracing::{error, span, Instrument, Level, Span};
use url::Host;
use uuid::Uuid;
struct SpawnExecutor;
@ -27,36 +31,30 @@ where
}
}
fn tunnel_to_jwt_token(request_id: Uuid, tunnel: &LocalToRemote) -> String {
let cfg = JwtTunnelConfig::new(request_id, tunnel);
let (alg, secret) = JWT_KEY.deref();
jsonwebtoken::encode(alg, &cfg, secret).unwrap_or_default()
}
pub async fn connect(
request_id: Uuid,
client_cfg: &WsClientConfig,
tunnel_cfg: &LocalToRemote,
) -> anyhow::Result<WebSocket<Upgraded>> {
let mut tcp_stream = match client_cfg.cnx_pool().get().await {
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
Ok(tcp_stream) => tcp_stream,
Err(err) => Err(anyhow!(
"failed to get a connection to the server from the pool: {err:?}"
))?,
};
let data = JwtTunnelConfig {
id: request_id.to_string(),
p: match tunnel_cfg.local_protocol {
LocalProtocol::Tcp => LocalProtocol::Tcp,
LocalProtocol::Udp { .. } => tunnel_cfg.local_protocol,
LocalProtocol::Stdio => LocalProtocol::Tcp,
LocalProtocol::Socks5 => LocalProtocol::Tcp,
},
r: tunnel_cfg.remote.0.to_string(),
rp: tunnel_cfg.remote.1,
};
let (alg, secret) = JWT_KEY.deref();
let mut req = Request::builder()
.method("GET")
.uri(format!(
"/{}/events?bearer={}",
&client_cfg.http_upgrade_path_prefix,
jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(),
tunnel_to_jwt_token(request_id, tunnel_cfg)
))
.header(HOST, &client_cfg.http_header_host)
.header(UPGRADE, "websocket")
@ -79,26 +77,20 @@ pub async fn connect(
)
})?;
debug!("with HTTP upgrade request {:?}", req);
let ws_handshake = match tcp_stream.deref_mut() {
MaybeTlsStream::Plain(cnx) => {
fastwebsockets::handshake::client(&SpawnExecutor, req, cnx.take().unwrap()).await
}
MaybeTlsStream::Tls(cnx) => {
fastwebsockets::handshake::client(&SpawnExecutor, req, cnx.take().unwrap()).await
}
};
let (ws, _) = ws_handshake.with_context(|| {
format!(
"failed to do websocket handshake with the server {:?}",
client_cfg.remote_addr
)
})?;
let transport = pooled_cnx.deref_mut().take().unwrap();
let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, transport)
.await
.with_context(|| {
format!(
"failed to do websocket handshake with the server {:?}",
client_cfg.remote_addr
)
})?;
Ok(ws)
}
pub async fn connect_to_server<R, W>(
async fn connect_to_server<R, W>(
request_id: Uuid,
client_cfg: &WsClientConfig,
remote_cfg: &LocalToRemote,
@ -127,3 +119,39 @@ where
Ok(())
}
pub async fn run_tunnel<T, R, W>(
client_config: Arc<WsClientConfig>,
tunnel_cfg: LocalToRemote,
incoming_cnx: T,
) -> anyhow::Result<()>
where
T: Stream<Item = anyhow::Result<((R, W), (Host, u16))>>,
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
pin_mut!(incoming_cnx);
while let Some(Ok((cnx_stream, remote_dest))) = incoming_cnx.next().await {
let request_id = Uuid::now_v7();
let span = span!(
Level::INFO,
"tunnel",
id = request_id.to_string(),
remote = format!("{}:{}", remote_dest.0, remote_dest.1)
);
let mut tunnel_cfg = tunnel_cfg.clone();
tunnel_cfg.remote = remote_dest;
let client_config = client_config.clone();
let tunnel = async move {
let _ = connect_to_server(request_id, &client_config, &tunnel_cfg, cnx_stream)
.await
.map_err(|err| error!("{:?}", err));
}
.instrument(span);
tokio::spawn(tunnel);
}
Ok(())
}

View file

@ -2,15 +2,20 @@ pub mod client;
mod io;
pub mod server;
use crate::{tcp, tls, LocalProtocol, WsClientConfig};
use crate::{tcp, tls, LocalProtocol, LocalToRemote, WsClientConfig};
use async_trait::async_trait;
use bb8::ManageConnection;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::io::{Error, IoSlice};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)]
struct JwtTunnelConfig {
@ -20,6 +25,22 @@ struct JwtTunnelConfig {
pub rp: u16,
}
impl JwtTunnelConfig {
fn new(request_id: Uuid, tunnel: &LocalToRemote) -> Self {
Self {
id: request_id.to_string(),
p: match tunnel.local_protocol {
LocalProtocol::Tcp => LocalProtocol::Tcp,
LocalProtocol::Udp { .. } => tunnel.local_protocol,
LocalProtocol::Stdio => LocalProtocol::Tcp,
LocalProtocol::Socks5 => LocalProtocol::Tcp,
},
r: tunnel.remote.0.to_string(),
rp: tunnel.remote.1,
}
}
}
static JWT_SECRET: &[u8; 15] = b"champignonfrais";
static JWT_KEY: Lazy<(Header, EncodingKey)> = Lazy::new(|| {
(
@ -34,23 +55,72 @@ static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
(validation, DecodingKey::from_secret(JWT_SECRET))
});
pub enum MaybeTlsStream {
Plain(Option<TcpStream>),
Tls(Option<TlsStream<TcpStream>>),
pub enum TransportStream {
Plain(TcpStream),
Tls(TlsStream<TcpStream>),
}
impl MaybeTlsStream {
pub fn is_used(&self) -> bool {
match self {
MaybeTlsStream::Plain(Some(_)) | MaybeTlsStream::Tls(Some(_)) => false,
MaybeTlsStream::Plain(None) | MaybeTlsStream::Tls(None) => true,
impl AsyncRead for TransportStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.get_mut() {
TransportStream::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf),
TransportStream::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf),
}
}
}
impl AsyncWrite for TransportStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
match self.get_mut() {
TransportStream::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf),
TransportStream::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match self.get_mut() {
TransportStream::Plain(cnx) => Pin::new(cnx).poll_flush(cx),
TransportStream::Tls(cnx) => Pin::new(cnx).poll_flush(cx),
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match self.get_mut() {
TransportStream::Plain(cnx) => Pin::new(cnx).poll_shutdown(cx),
TransportStream::Tls(cnx) => Pin::new(cnx).poll_shutdown(cx),
}
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, Error>> {
match self.get_mut() {
TransportStream::Plain(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
TransportStream::Tls(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
}
}
fn is_write_vectored(&self) -> bool {
match &self {
TransportStream::Plain(cnx) => cnx.is_write_vectored(),
TransportStream::Tls(cnx) => cnx.is_write_vectored(),
}
}
}
#[async_trait]
impl ManageConnection for WsClientConfig {
type Connection = MaybeTlsStream;
type Connection = Option<TransportStream>;
type Error = anyhow::Error;
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
@ -65,10 +135,10 @@ impl ManageConnection for WsClientConfig {
};
match &self.tls {
None => Ok(MaybeTlsStream::Plain(Some(tcp_stream))),
None => Ok(Some(TransportStream::Plain(tcp_stream))),
Some(tls_cfg) => {
let tls_stream = tls::connect(self, tls_cfg, tcp_stream).await?;
Ok(MaybeTlsStream::Tls(Some(tls_stream)))
Ok(Some(TransportStream::Tls(tls_stream)))
}
}
}
@ -78,6 +148,6 @@ impl ManageConnection for WsClientConfig {
}
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
conn.is_used()
conn.is_none()
}
}