chore(client): Cleanup implementation
This commit is contained in:
parent
7fe768a078
commit
2499d993e2
3 changed files with 156 additions and 100 deletions
66
src/main.rs
66
src/main.rs
|
@ -9,7 +9,7 @@ mod udp;
|
|||
|
||||
use base64::Engine;
|
||||
use clap::Parser;
|
||||
use futures_util::{pin_mut, stream, Stream, StreamExt, TryStreamExt};
|
||||
use futures_util::{stream, TryStreamExt};
|
||||
use hyper::header::HOST;
|
||||
use hyper::http::{HeaderName, HeaderValue};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -22,16 +22,14 @@ use std::str::FromStr;
|
|||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, io};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use tokio_rustls::rustls::server::DnsName;
|
||||
use tokio_rustls::rustls::{Certificate, PrivateKey, ServerName};
|
||||
|
||||
use tracing::{error, info, span, Instrument, Level};
|
||||
use tracing::{error, info, Level};
|
||||
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use url::{Host, Url};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Use the websockets protocol to tunnel {TCP,UDP} traffic
|
||||
/// wsTunnelClient <---> wsTunnelServer <---> RemoteHost
|
||||
|
@ -588,7 +586,9 @@ async fn main() {
|
|||
.map_ok(move |stream| (stream.into_split(), remote.clone()));
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = run_tunnel(client_config, tunnel, server).await {
|
||||
if let Err(err) =
|
||||
tunnel::client::run_tunnel(client_config, tunnel, server).await
|
||||
{
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -604,7 +604,9 @@ async fn main() {
|
|||
.map_ok(move |stream| (tokio::io::split(stream), remote.clone()));
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = run_tunnel(client_config, tunnel, server).await {
|
||||
if let Err(err) =
|
||||
tunnel::client::run_tunnel(client_config, tunnel, server).await
|
||||
{
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -618,7 +620,9 @@ async fn main() {
|
|||
.map_ok(|(stream, remote_dest)| (stream.into_split(), remote_dest));
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = run_tunnel(client_config, tunnel, server).await {
|
||||
if let Err(err) =
|
||||
tunnel::client::run_tunnel(client_config, tunnel, server).await
|
||||
{
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -630,7 +634,7 @@ async fn main() {
|
|||
panic!("Cannot start STDIO server: {}", err);
|
||||
});
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = run_tunnel(
|
||||
if let Err(err) = tunnel::client::run_tunnel(
|
||||
client_config,
|
||||
tunnel.clone(),
|
||||
stream::once(async move { Ok((server, tunnel.remote)) }),
|
||||
|
@ -693,49 +697,3 @@ async fn main() {
|
|||
|
||||
tokio::signal::ctrl_c().await.unwrap();
|
||||
}
|
||||
|
||||
async fn run_tunnel<T, R, W>(
|
||||
client_config: Arc<WsClientConfig>,
|
||||
tunnel: LocalToRemote,
|
||||
incoming_cnx: T,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
T: Stream<Item = anyhow::Result<((R, W), (Host, u16))>>,
|
||||
R: AsyncRead + Send + 'static,
|
||||
W: AsyncWrite + Send + 'static,
|
||||
{
|
||||
pin_mut!(incoming_cnx);
|
||||
while let Some(Ok((cnx_stream, remote_dest))) = incoming_cnx.next().await {
|
||||
let request_id = Uuid::now_v7();
|
||||
let span = span!(
|
||||
Level::INFO,
|
||||
"tunnel",
|
||||
id = request_id.to_string(),
|
||||
remote = format!("{}:{}", remote_dest.0, remote_dest.1)
|
||||
);
|
||||
let server_config = client_config.clone();
|
||||
let mut tunnel = tunnel.clone();
|
||||
tunnel.remote = remote_dest;
|
||||
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let ret = tunnel::client::connect_to_server(
|
||||
request_id,
|
||||
&server_config,
|
||||
&tunnel,
|
||||
cnx_stream,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(ret) = ret {
|
||||
error!("{:?}", ret);
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.instrument(span.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -1,18 +1,22 @@
|
|||
use super::{JwtTunnelConfig, MaybeTlsStream, JWT_KEY};
|
||||
use crate::{LocalProtocol, LocalToRemote, WsClientConfig};
|
||||
use super::{JwtTunnelConfig, JWT_KEY};
|
||||
use crate::{LocalToRemote, WsClientConfig};
|
||||
use anyhow::{anyhow, Context};
|
||||
|
||||
use fastwebsockets::WebSocket;
|
||||
use futures_util::pin_mut;
|
||||
use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_VERSION, UPGRADE};
|
||||
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
|
||||
use hyper::upgrade::Upgraded;
|
||||
use hyper::{Body, Request};
|
||||
use std::future::Future;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
use tracing::log::debug;
|
||||
use tracing::{Instrument, Span};
|
||||
use tracing::{error, span, Instrument, Level, Span};
|
||||
use url::Host;
|
||||
use uuid::Uuid;
|
||||
|
||||
struct SpawnExecutor;
|
||||
|
@ -27,36 +31,30 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn tunnel_to_jwt_token(request_id: Uuid, tunnel: &LocalToRemote) -> String {
|
||||
let cfg = JwtTunnelConfig::new(request_id, tunnel);
|
||||
let (alg, secret) = JWT_KEY.deref();
|
||||
jsonwebtoken::encode(alg, &cfg, secret).unwrap_or_default()
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
request_id: Uuid,
|
||||
client_cfg: &WsClientConfig,
|
||||
tunnel_cfg: &LocalToRemote,
|
||||
) -> anyhow::Result<WebSocket<Upgraded>> {
|
||||
let mut tcp_stream = match client_cfg.cnx_pool().get().await {
|
||||
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
|
||||
Ok(tcp_stream) => tcp_stream,
|
||||
Err(err) => Err(anyhow!(
|
||||
"failed to get a connection to the server from the pool: {err:?}"
|
||||
))?,
|
||||
};
|
||||
|
||||
let data = JwtTunnelConfig {
|
||||
id: request_id.to_string(),
|
||||
p: match tunnel_cfg.local_protocol {
|
||||
LocalProtocol::Tcp => LocalProtocol::Tcp,
|
||||
LocalProtocol::Udp { .. } => tunnel_cfg.local_protocol,
|
||||
LocalProtocol::Stdio => LocalProtocol::Tcp,
|
||||
LocalProtocol::Socks5 => LocalProtocol::Tcp,
|
||||
},
|
||||
r: tunnel_cfg.remote.0.to_string(),
|
||||
rp: tunnel_cfg.remote.1,
|
||||
};
|
||||
let (alg, secret) = JWT_KEY.deref();
|
||||
let mut req = Request::builder()
|
||||
.method("GET")
|
||||
.uri(format!(
|
||||
"/{}/events?bearer={}",
|
||||
&client_cfg.http_upgrade_path_prefix,
|
||||
jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(),
|
||||
tunnel_to_jwt_token(request_id, tunnel_cfg)
|
||||
))
|
||||
.header(HOST, &client_cfg.http_header_host)
|
||||
.header(UPGRADE, "websocket")
|
||||
|
@ -79,26 +77,20 @@ pub async fn connect(
|
|||
)
|
||||
})?;
|
||||
debug!("with HTTP upgrade request {:?}", req);
|
||||
let ws_handshake = match tcp_stream.deref_mut() {
|
||||
MaybeTlsStream::Plain(cnx) => {
|
||||
fastwebsockets::handshake::client(&SpawnExecutor, req, cnx.take().unwrap()).await
|
||||
}
|
||||
MaybeTlsStream::Tls(cnx) => {
|
||||
fastwebsockets::handshake::client(&SpawnExecutor, req, cnx.take().unwrap()).await
|
||||
}
|
||||
};
|
||||
|
||||
let (ws, _) = ws_handshake.with_context(|| {
|
||||
format!(
|
||||
"failed to do websocket handshake with the server {:?}",
|
||||
client_cfg.remote_addr
|
||||
)
|
||||
})?;
|
||||
let transport = pooled_cnx.deref_mut().take().unwrap();
|
||||
let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, transport)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to do websocket handshake with the server {:?}",
|
||||
client_cfg.remote_addr
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(ws)
|
||||
}
|
||||
|
||||
pub async fn connect_to_server<R, W>(
|
||||
async fn connect_to_server<R, W>(
|
||||
request_id: Uuid,
|
||||
client_cfg: &WsClientConfig,
|
||||
remote_cfg: &LocalToRemote,
|
||||
|
@ -127,3 +119,39 @@ where
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_tunnel<T, R, W>(
|
||||
client_config: Arc<WsClientConfig>,
|
||||
tunnel_cfg: LocalToRemote,
|
||||
incoming_cnx: T,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
T: Stream<Item = anyhow::Result<((R, W), (Host, u16))>>,
|
||||
R: AsyncRead + Send + 'static,
|
||||
W: AsyncWrite + Send + 'static,
|
||||
{
|
||||
pin_mut!(incoming_cnx);
|
||||
while let Some(Ok((cnx_stream, remote_dest))) = incoming_cnx.next().await {
|
||||
let request_id = Uuid::now_v7();
|
||||
let span = span!(
|
||||
Level::INFO,
|
||||
"tunnel",
|
||||
id = request_id.to_string(),
|
||||
remote = format!("{}:{}", remote_dest.0, remote_dest.1)
|
||||
);
|
||||
let mut tunnel_cfg = tunnel_cfg.clone();
|
||||
tunnel_cfg.remote = remote_dest;
|
||||
let client_config = client_config.clone();
|
||||
|
||||
let tunnel = async move {
|
||||
let _ = connect_to_server(request_id, &client_config, &tunnel_cfg, cnx_stream)
|
||||
.await
|
||||
.map_err(|err| error!("{:?}", err));
|
||||
}
|
||||
.instrument(span);
|
||||
|
||||
tokio::spawn(tunnel);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -2,15 +2,20 @@ pub mod client;
|
|||
mod io;
|
||||
pub mod server;
|
||||
|
||||
use crate::{tcp, tls, LocalProtocol, WsClientConfig};
|
||||
use crate::{tcp, tls, LocalProtocol, LocalToRemote, WsClientConfig};
|
||||
use async_trait::async_trait;
|
||||
use bb8::ManageConnection;
|
||||
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
use std::io::{Error, IoSlice};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::client::TlsStream;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct JwtTunnelConfig {
|
||||
|
@ -20,6 +25,22 @@ struct JwtTunnelConfig {
|
|||
pub rp: u16,
|
||||
}
|
||||
|
||||
impl JwtTunnelConfig {
|
||||
fn new(request_id: Uuid, tunnel: &LocalToRemote) -> Self {
|
||||
Self {
|
||||
id: request_id.to_string(),
|
||||
p: match tunnel.local_protocol {
|
||||
LocalProtocol::Tcp => LocalProtocol::Tcp,
|
||||
LocalProtocol::Udp { .. } => tunnel.local_protocol,
|
||||
LocalProtocol::Stdio => LocalProtocol::Tcp,
|
||||
LocalProtocol::Socks5 => LocalProtocol::Tcp,
|
||||
},
|
||||
r: tunnel.remote.0.to_string(),
|
||||
rp: tunnel.remote.1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static JWT_SECRET: &[u8; 15] = b"champignonfrais";
|
||||
static JWT_KEY: Lazy<(Header, EncodingKey)> = Lazy::new(|| {
|
||||
(
|
||||
|
@ -34,23 +55,72 @@ static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
|
|||
(validation, DecodingKey::from_secret(JWT_SECRET))
|
||||
});
|
||||
|
||||
pub enum MaybeTlsStream {
|
||||
Plain(Option<TcpStream>),
|
||||
Tls(Option<TlsStream<TcpStream>>),
|
||||
pub enum TransportStream {
|
||||
Plain(TcpStream),
|
||||
Tls(TlsStream<TcpStream>),
|
||||
}
|
||||
|
||||
impl MaybeTlsStream {
|
||||
pub fn is_used(&self) -> bool {
|
||||
match self {
|
||||
MaybeTlsStream::Plain(Some(_)) | MaybeTlsStream::Tls(Some(_)) => false,
|
||||
MaybeTlsStream::Plain(None) | MaybeTlsStream::Tls(None) => true,
|
||||
impl AsyncRead for TransportStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
TransportStream::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf),
|
||||
TransportStream::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TransportStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, Error>> {
|
||||
match self.get_mut() {
|
||||
TransportStream::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf),
|
||||
TransportStream::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
match self.get_mut() {
|
||||
TransportStream::Plain(cnx) => Pin::new(cnx).poll_flush(cx),
|
||||
TransportStream::Tls(cnx) => Pin::new(cnx).poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
match self.get_mut() {
|
||||
TransportStream::Plain(cnx) => Pin::new(cnx).poll_shutdown(cx),
|
||||
TransportStream::Tls(cnx) => Pin::new(cnx).poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[IoSlice<'_>],
|
||||
) -> Poll<Result<usize, Error>> {
|
||||
match self.get_mut() {
|
||||
TransportStream::Plain(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
|
||||
TransportStream::Tls(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
match &self {
|
||||
TransportStream::Plain(cnx) => cnx.is_write_vectored(),
|
||||
TransportStream::Tls(cnx) => cnx.is_write_vectored(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ManageConnection for WsClientConfig {
|
||||
type Connection = MaybeTlsStream;
|
||||
type Connection = Option<TransportStream>;
|
||||
type Error = anyhow::Error;
|
||||
|
||||
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
|
||||
|
@ -65,10 +135,10 @@ impl ManageConnection for WsClientConfig {
|
|||
};
|
||||
|
||||
match &self.tls {
|
||||
None => Ok(MaybeTlsStream::Plain(Some(tcp_stream))),
|
||||
None => Ok(Some(TransportStream::Plain(tcp_stream))),
|
||||
Some(tls_cfg) => {
|
||||
let tls_stream = tls::connect(self, tls_cfg, tcp_stream).await?;
|
||||
Ok(MaybeTlsStream::Tls(Some(tls_stream)))
|
||||
Ok(Some(TransportStream::Tls(tls_stream)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -78,6 +148,6 @@ impl ManageConnection for WsClientConfig {
|
|||
}
|
||||
|
||||
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
|
||||
conn.is_used()
|
||||
conn.is_none()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue