diff --git a/src/main.rs b/src/main.rs index 3c198f1..5b99938 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ mod udp; use base64::Engine; use clap::Parser; -use futures_util::{pin_mut, stream, Stream, StreamExt, TryStreamExt}; +use futures_util::{stream, TryStreamExt}; use hyper::header::HOST; use hyper::http::{HeaderName, HeaderValue}; use serde::{Deserialize, Serialize}; @@ -22,16 +22,14 @@ use std::str::FromStr; use std::sync::Arc; use std::time::Duration; use std::{fmt, io}; -use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::rustls::server::DnsName; use tokio_rustls::rustls::{Certificate, PrivateKey, ServerName}; -use tracing::{error, info, span, Instrument, Level}; +use tracing::{error, info, Level}; use tracing_subscriber::EnvFilter; use url::{Host, Url}; -use uuid::Uuid; /// Use the websockets protocol to tunnel {TCP,UDP} traffic /// wsTunnelClient <---> wsTunnelServer <---> RemoteHost @@ -588,7 +586,9 @@ async fn main() { .map_ok(move |stream| (stream.into_split(), remote.clone())); tokio::spawn(async move { - if let Err(err) = run_tunnel(client_config, tunnel, server).await { + if let Err(err) = + tunnel::client::run_tunnel(client_config, tunnel, server).await + { error!("{:?}", err); } }); @@ -604,7 +604,9 @@ async fn main() { .map_ok(move |stream| (tokio::io::split(stream), remote.clone())); tokio::spawn(async move { - if let Err(err) = run_tunnel(client_config, tunnel, server).await { + if let Err(err) = + tunnel::client::run_tunnel(client_config, tunnel, server).await + { error!("{:?}", err); } }); @@ -618,7 +620,9 @@ async fn main() { .map_ok(|(stream, remote_dest)| (stream.into_split(), remote_dest)); tokio::spawn(async move { - if let Err(err) = run_tunnel(client_config, tunnel, server).await { + if let Err(err) = + tunnel::client::run_tunnel(client_config, tunnel, server).await + { error!("{:?}", err); } }); @@ -630,7 +634,7 @@ async fn main() { panic!("Cannot start STDIO server: {}", err); }); tokio::spawn(async move { - if let Err(err) = run_tunnel( + if let Err(err) = tunnel::client::run_tunnel( client_config, tunnel.clone(), stream::once(async move { Ok((server, tunnel.remote)) }), @@ -693,49 +697,3 @@ async fn main() { tokio::signal::ctrl_c().await.unwrap(); } - -async fn run_tunnel( - client_config: Arc, - tunnel: LocalToRemote, - incoming_cnx: T, -) -> anyhow::Result<()> -where - T: Stream>, - R: AsyncRead + Send + 'static, - W: AsyncWrite + Send + 'static, -{ - pin_mut!(incoming_cnx); - while let Some(Ok((cnx_stream, remote_dest))) = incoming_cnx.next().await { - let request_id = Uuid::now_v7(); - let span = span!( - Level::INFO, - "tunnel", - id = request_id.to_string(), - remote = format!("{}:{}", remote_dest.0, remote_dest.1) - ); - let server_config = client_config.clone(); - let mut tunnel = tunnel.clone(); - tunnel.remote = remote_dest; - - tokio::spawn( - async move { - let ret = tunnel::client::connect_to_server( - request_id, - &server_config, - &tunnel, - cnx_stream, - ) - .await; - - if let Err(ret) = ret { - error!("{:?}", ret); - } - - anyhow::Ok(()) - } - .instrument(span.clone()), - ); - } - - Ok(()) -} diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index 95ac7cf..415152d 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -1,18 +1,22 @@ -use super::{JwtTunnelConfig, MaybeTlsStream, JWT_KEY}; -use crate::{LocalProtocol, LocalToRemote, WsClientConfig}; +use super::{JwtTunnelConfig, JWT_KEY}; +use crate::{LocalToRemote, WsClientConfig}; use anyhow::{anyhow, Context}; use fastwebsockets::WebSocket; +use futures_util::pin_mut; use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_VERSION, UPGRADE}; use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY}; use hyper::upgrade::Upgraded; use hyper::{Body, Request}; use std::future::Future; use std::ops::{Deref, DerefMut}; +use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::oneshot; +use tokio_stream::{Stream, StreamExt}; use tracing::log::debug; -use tracing::{Instrument, Span}; +use tracing::{error, span, Instrument, Level, Span}; +use url::Host; use uuid::Uuid; struct SpawnExecutor; @@ -27,36 +31,30 @@ where } } +fn tunnel_to_jwt_token(request_id: Uuid, tunnel: &LocalToRemote) -> String { + let cfg = JwtTunnelConfig::new(request_id, tunnel); + let (alg, secret) = JWT_KEY.deref(); + jsonwebtoken::encode(alg, &cfg, secret).unwrap_or_default() +} + pub async fn connect( request_id: Uuid, client_cfg: &WsClientConfig, tunnel_cfg: &LocalToRemote, ) -> anyhow::Result> { - let mut tcp_stream = match client_cfg.cnx_pool().get().await { + let mut pooled_cnx = match client_cfg.cnx_pool().get().await { Ok(tcp_stream) => tcp_stream, Err(err) => Err(anyhow!( "failed to get a connection to the server from the pool: {err:?}" ))?, }; - let data = JwtTunnelConfig { - id: request_id.to_string(), - p: match tunnel_cfg.local_protocol { - LocalProtocol::Tcp => LocalProtocol::Tcp, - LocalProtocol::Udp { .. } => tunnel_cfg.local_protocol, - LocalProtocol::Stdio => LocalProtocol::Tcp, - LocalProtocol::Socks5 => LocalProtocol::Tcp, - }, - r: tunnel_cfg.remote.0.to_string(), - rp: tunnel_cfg.remote.1, - }; - let (alg, secret) = JWT_KEY.deref(); let mut req = Request::builder() .method("GET") .uri(format!( "/{}/events?bearer={}", &client_cfg.http_upgrade_path_prefix, - jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(), + tunnel_to_jwt_token(request_id, tunnel_cfg) )) .header(HOST, &client_cfg.http_header_host) .header(UPGRADE, "websocket") @@ -79,26 +77,20 @@ pub async fn connect( ) })?; debug!("with HTTP upgrade request {:?}", req); - let ws_handshake = match tcp_stream.deref_mut() { - MaybeTlsStream::Plain(cnx) => { - fastwebsockets::handshake::client(&SpawnExecutor, req, cnx.take().unwrap()).await - } - MaybeTlsStream::Tls(cnx) => { - fastwebsockets::handshake::client(&SpawnExecutor, req, cnx.take().unwrap()).await - } - }; - - let (ws, _) = ws_handshake.with_context(|| { - format!( - "failed to do websocket handshake with the server {:?}", - client_cfg.remote_addr - ) - })?; + let transport = pooled_cnx.deref_mut().take().unwrap(); + let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, transport) + .await + .with_context(|| { + format!( + "failed to do websocket handshake with the server {:?}", + client_cfg.remote_addr + ) + })?; Ok(ws) } -pub async fn connect_to_server( +async fn connect_to_server( request_id: Uuid, client_cfg: &WsClientConfig, remote_cfg: &LocalToRemote, @@ -127,3 +119,39 @@ where Ok(()) } + +pub async fn run_tunnel( + client_config: Arc, + tunnel_cfg: LocalToRemote, + incoming_cnx: T, +) -> anyhow::Result<()> +where + T: Stream>, + R: AsyncRead + Send + 'static, + W: AsyncWrite + Send + 'static, +{ + pin_mut!(incoming_cnx); + while let Some(Ok((cnx_stream, remote_dest))) = incoming_cnx.next().await { + let request_id = Uuid::now_v7(); + let span = span!( + Level::INFO, + "tunnel", + id = request_id.to_string(), + remote = format!("{}:{}", remote_dest.0, remote_dest.1) + ); + let mut tunnel_cfg = tunnel_cfg.clone(); + tunnel_cfg.remote = remote_dest; + let client_config = client_config.clone(); + + let tunnel = async move { + let _ = connect_to_server(request_id, &client_config, &tunnel_cfg, cnx_stream) + .await + .map_err(|err| error!("{:?}", err)); + } + .instrument(span); + + tokio::spawn(tunnel); + } + + Ok(()) +} diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 8b1fe60..0709d1d 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -2,15 +2,20 @@ pub mod client; mod io; pub mod server; -use crate::{tcp, tls, LocalProtocol, WsClientConfig}; +use crate::{tcp, tls, LocalProtocol, LocalToRemote, WsClientConfig}; use async_trait::async_trait; use bb8::ManageConnection; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation}; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use std::collections::HashSet; +use std::io::{Error, IoSlice}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::TcpStream; use tokio_rustls::client::TlsStream; +use uuid::Uuid; #[derive(Debug, Clone, Serialize, Deserialize)] struct JwtTunnelConfig { @@ -20,6 +25,22 @@ struct JwtTunnelConfig { pub rp: u16, } +impl JwtTunnelConfig { + fn new(request_id: Uuid, tunnel: &LocalToRemote) -> Self { + Self { + id: request_id.to_string(), + p: match tunnel.local_protocol { + LocalProtocol::Tcp => LocalProtocol::Tcp, + LocalProtocol::Udp { .. } => tunnel.local_protocol, + LocalProtocol::Stdio => LocalProtocol::Tcp, + LocalProtocol::Socks5 => LocalProtocol::Tcp, + }, + r: tunnel.remote.0.to_string(), + rp: tunnel.remote.1, + } + } +} + static JWT_SECRET: &[u8; 15] = b"champignonfrais"; static JWT_KEY: Lazy<(Header, EncodingKey)> = Lazy::new(|| { ( @@ -34,23 +55,72 @@ static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| { (validation, DecodingKey::from_secret(JWT_SECRET)) }); -pub enum MaybeTlsStream { - Plain(Option), - Tls(Option>), +pub enum TransportStream { + Plain(TcpStream), + Tls(TlsStream), } -impl MaybeTlsStream { - pub fn is_used(&self) -> bool { - match self { - MaybeTlsStream::Plain(Some(_)) | MaybeTlsStream::Tls(Some(_)) => false, - MaybeTlsStream::Plain(None) | MaybeTlsStream::Tls(None) => true, +impl AsyncRead for TransportStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + TransportStream::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf), + TransportStream::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TransportStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + TransportStream::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf), + TransportStream::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + TransportStream::Plain(cnx) => Pin::new(cnx).poll_flush(cx), + TransportStream::Tls(cnx) => Pin::new(cnx).poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + TransportStream::Plain(cnx) => Pin::new(cnx).poll_shutdown(cx), + TransportStream::Tls(cnx) => Pin::new(cnx).poll_shutdown(cx), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + TransportStream::Plain(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs), + TransportStream::Tls(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs), + } + } + + fn is_write_vectored(&self) -> bool { + match &self { + TransportStream::Plain(cnx) => cnx.is_write_vectored(), + TransportStream::Tls(cnx) => cnx.is_write_vectored(), } } } #[async_trait] impl ManageConnection for WsClientConfig { - type Connection = MaybeTlsStream; + type Connection = Option; type Error = anyhow::Error; async fn connect(&self) -> Result { @@ -65,10 +135,10 @@ impl ManageConnection for WsClientConfig { }; match &self.tls { - None => Ok(MaybeTlsStream::Plain(Some(tcp_stream))), + None => Ok(Some(TransportStream::Plain(tcp_stream))), Some(tls_cfg) => { let tls_stream = tls::connect(self, tls_cfg, tcp_stream).await?; - Ok(MaybeTlsStream::Tls(Some(tls_stream))) + Ok(Some(TransportStream::Tls(tls_stream))) } } } @@ -78,6 +148,6 @@ impl ManageConnection for WsClientConfig { } fn has_broken(&self, conn: &mut Self::Connection) -> bool { - conn.is_used() + conn.is_none() } }