use super::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE}; use crate::tunnel::transport::{TunnelReader, TunnelWriter}; use crate::{tunnel, WsClientConfig}; use futures_util::pin_mut; use hyper::header::COOKIE; use jsonwebtoken::TokenData; use log::debug; use std::future::Future; use std::ops::Deref; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::oneshot; use tokio_stream::{Stream, StreamExt}; use tracing::{error, span, Instrument, Level, Span, event}; use url::Host; use uuid::Uuid; async fn connect_to_server( request_id: Uuid, client_cfg: &WsClientConfig, remote_cfg: &RemoteAddr, duplex_stream: (R, W), ) -> anyhow::Result<()> where R: AsyncRead + Send + 'static, W: AsyncWrite + Send + 'static, { // Connect to server with the correct protocol let (ws_rx, ws_tx, response) = match client_cfg.remote_addr.scheme() { TransportScheme::Ws | TransportScheme::Wss => { tunnel::transport::websocket::connect(request_id, client_cfg, remote_cfg) .await .map(|(r, w, response)| (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response))? } TransportScheme::Http | TransportScheme::Https => { tunnel::transport::http2::connect(request_id, client_cfg, remote_cfg) .await .map(|(r, w, response)| (TunnelReader::Http2(r), TunnelWriter::Http2(w), response))? } }; debug!("Server response: {:?}", response); let (local_rx, local_tx) = duplex_stream; let (close_tx, close_rx) = oneshot::channel::<()>(); // Forward local tx to websocket tx let ping_frequency = client_cfg.websocket_ping_frequency; tokio::spawn( super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency)) .instrument(Span::current()), ); // Forward websocket rx to local rx let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await; Ok(()) } pub async fn run_tunnel(client_config: Arc, incoming_cnx: T) -> anyhow::Result<()> where T: Stream>, R: AsyncRead + Send + 'static, W: AsyncWrite + Send + 'static, { pin_mut!(incoming_cnx); while let Some(Ok((cnx_stream, remote_addr))) = incoming_cnx.next().await { let request_id = Uuid::now_v7(); let span = span!( Level::INFO, "tunnel", id = request_id.to_string(), remote = format!("{}:{}", remote_addr.host, remote_addr.port) ); let client_config = client_config.clone(); let tunnel = async move { let _ = connect_to_server(request_id, &client_config, &remote_addr, cnx_stream) .await .map_err(|err| error!("{:?}", err)); } .instrument(span); tokio::spawn(tunnel); } Ok(()) } pub async fn run_reverse_tunnel( client_cfg: Arc, remote_addr: RemoteAddr, connect_to_dest: F, ) -> anyhow::Result<()> where F: Fn(Option) -> Fut, Fut: Future>, T: AsyncRead + AsyncWrite + Send + 'static, { loop { let client_config = client_cfg.clone(); let request_id = Uuid::now_v7(); let span = span!( Level::INFO, "tunnel", id = request_id.to_string(), remote = format!("{}:{}", remote_addr.host, remote_addr.port) ); // Correctly configure tunnel cfg let (ws_rx, ws_tx, response) = match client_cfg.remote_addr.scheme() { TransportScheme::Ws | TransportScheme::Wss => { match tunnel::transport::websocket::connect(request_id, &client_cfg, &remote_addr) .instrument(span.clone()) .await { Ok((r, w, response)) => (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response), Err(err) => { event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err); tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; continue; } } } TransportScheme::Http | TransportScheme::Https => { match tunnel::transport::http2::connect(request_id, &client_cfg, &remote_addr) .instrument(span.clone()) .await { Ok((r, w, response)) => (TunnelReader::Http2(r), TunnelWriter::Http2(w), response), Err(err) => { event!(parent: &span, Level::ERROR, "Retrying in 1sec, cannot connect to remote server: {:?}", err); tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; continue; } } } }; // Connect to endpoint event!(parent: &span, Level::DEBUG, "Server response: {:?}", response); let remote = response .headers .get(COOKIE) .and_then(|h| h.to_str().ok()) .and_then(|h| { let (validation, decode_key) = JWT_DECODE.deref(); let jwt: Option> = jsonwebtoken::decode(h, decode_key, validation).ok(); jwt }) .map(|jwt| RemoteAddr { protocol: jwt.claims.p, host: Host::parse(&jwt.claims.r).unwrap_or_else(|_| Host::Domain(String::new())), port: jwt.claims.rp, }); let stream = match connect_to_dest(remote).instrument(span.clone()).await { Ok(s) => s, Err(err) => { event!(parent: &span, Level::ERROR, "Cannot connect to xxxx: {err:?}"); continue; } }; let (local_rx, local_tx) = tokio::io::split(stream); let (close_tx, close_rx) = oneshot::channel::<()>(); let tunnel = async move { let ping_frequency = client_config.websocket_ping_frequency; tokio::spawn( super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency)) .in_current_span(), ); // Forward websocket rx to local rx let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await; } .instrument(span.clone()); tokio::spawn(tunnel); } }