diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 00ff6a8..b6ff519 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -1,6 +1,8 @@ use ahash::{HashMap, HashMapExt}; -use futures_util::{Stream, StreamExt}; +use anyhow::anyhow; +use futures_util::{pin_mut, Stream, StreamExt}; use std::cmp::min; +use std::future::Future; use std::io; use std::ops::{Deref, Not}; use std::pin::Pin; @@ -18,9 +20,9 @@ use parking_lot::Mutex; use crate::udp::UdpStream; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::TcpListener; -use tokio::sync::oneshot; -use tokio_stream::wrappers::TcpListenerStream; +use tokio::net::{TcpListener, TcpStream}; +use tokio::select; +use tokio::sync::{mpsc, oneshot}; use tracing::{error, info, span, warn, Instrument, Level, Span}; use url::Host; @@ -48,7 +50,7 @@ async fn from_query( _err => return Err(anyhow::anyhow!("Invalid upgrade request")), }; - Span::current().record("id", jwt.claims.id); + Span::current().record("id", &jwt.claims.id); Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); if let Some(allowed_dests) = &server_config.restrict_to { let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp); @@ -81,42 +83,27 @@ async fn from_query( } LocalProtocol::ReverseTcp => { #[allow(clippy::type_complexity)] - static REVERSE: Lazy, u16), TcpListenerStream>>> = + static SERVERS: Lazy, u16), mpsc::Receiver>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp); - let listening_server = REVERSE.lock().remove(&local_srv); - let mut listening_server = if let Some(listening_server) = listening_server { - listening_server - } else { - let bind = format!("{}:{}", local_srv.0, local_srv.1); - tcp::run_server(bind.parse()?).await? - }; - - let tcp = listening_server.next().await.unwrap()?; - let (local_rx, local_tx) = tokio::io::split(tcp); - REVERSE.lock().insert(local_srv.clone(), listening_server); + let bind = format!("{}:{}", local_srv.0, local_srv.1); + let listening_server = tcp::run_server(bind.parse()?); + let tcp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; + let (local_rx, local_tx) = tcp.into_split(); Ok((jwt.claims.p, local_srv.0, local_srv.1, Box::pin(local_rx), Box::pin(local_tx))) } LocalProtocol::ReverseUdp { timeout } => { #[allow(clippy::type_complexity)] - static REVERSE: Lazy< - Mutex, u16), Pin> + Send>>>>, - > = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); + static SERVERS: Lazy, u16), mpsc::Receiver>>> = + Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp); - let listening_server = REVERSE.lock().remove(&local_srv); - let mut listening_server = if let Some(listening_server) = listening_server { - listening_server - } else { - let bind = format!("{}:{}", local_srv.0, local_srv.1); - Box::pin(udp::run_server(bind.parse()?, timeout).await?) - }; - - let udp = listening_server.next().await.unwrap()?; + let bind = format!("{}:{}", local_srv.0, local_srv.1); + let listening_server = udp::run_server(bind.parse()?, timeout); + let udp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; let (local_rx, local_tx) = tokio::io::split(udp); - REVERSE.lock().insert(local_srv.clone(), listening_server); Ok((jwt.claims.p, local_srv.0, local_srv.1, Box::pin(local_rx), Box::pin(local_tx))) } @@ -124,6 +111,63 @@ async fn from_query( } } +#[allow(clippy::type_complexity)] +async fn run_listening_server( + local_srv: &(Host, u16), + servers: &Mutex, u16), mpsc::Receiver>>, + gen_listening_server: Fut, +) -> anyhow::Result +where + Fut: Future>, + FutOut: Stream> + Send + 'static, + T: Send + 'static, +{ + let listening_server = servers.lock().remove(local_srv); + let mut listening_server = if let Some(listening_server) = listening_server { + listening_server + } else { + let listening_server = gen_listening_server.await?; + let (tx, rx) = mpsc::channel::(1); + let fut = async move { + pin_mut!(listening_server); + loop { + select! { + biased; + cnx = listening_server.next() => { + match cnx { + None => break, + Some(Err(err)) => { + warn!("Error while listening for incoming connections {err:?}"); + break; + } + Some(Ok(cnx)) => { + if tx.send_timeout(cnx, Duration::from_secs(30)).await.is_err() { + break; + } + } + } + }, + + _ = tx.closed() => { + break; + } + } + } + info!("Stopping listening server"); + }; + + tokio::spawn(fut.instrument(Span::current())); + rx + }; + + let cnx = listening_server + .recv() + .await + .ok_or_else(|| anyhow!("listening server stopped"))?; + servers.lock().insert(local_srv.clone(), listening_server); + Ok(cnx) +} + async fn server_upgrade( server_config: Arc, mut req: Request,