Correctly close reverse tunnel server
This commit is contained in:
parent
7374e5ea75
commit
6e0386c416
1 changed files with 74 additions and 30 deletions
|
@ -1,6 +1,8 @@
|
||||||
use ahash::{HashMap, HashMapExt};
|
use ahash::{HashMap, HashMapExt};
|
||||||
use futures_util::{Stream, StreamExt};
|
use anyhow::anyhow;
|
||||||
|
use futures_util::{pin_mut, Stream, StreamExt};
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
|
use std::future::Future;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::ops::{Deref, Not};
|
use std::ops::{Deref, Not};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
|
@ -18,9 +20,9 @@ use parking_lot::Mutex;
|
||||||
|
|
||||||
use crate::udp::UdpStream;
|
use crate::udp::UdpStream;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio::sync::oneshot;
|
use tokio::select;
|
||||||
use tokio_stream::wrappers::TcpListenerStream;
|
use tokio::sync::{mpsc, oneshot};
|
||||||
use tracing::{error, info, span, warn, Instrument, Level, Span};
|
use tracing::{error, info, span, warn, Instrument, Level, Span};
|
||||||
use url::Host;
|
use url::Host;
|
||||||
|
|
||||||
|
@ -48,7 +50,7 @@ async fn from_query(
|
||||||
_err => return Err(anyhow::anyhow!("Invalid upgrade request")),
|
_err => return Err(anyhow::anyhow!("Invalid upgrade request")),
|
||||||
};
|
};
|
||||||
|
|
||||||
Span::current().record("id", jwt.claims.id);
|
Span::current().record("id", &jwt.claims.id);
|
||||||
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
|
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
|
||||||
if let Some(allowed_dests) = &server_config.restrict_to {
|
if let Some(allowed_dests) = &server_config.restrict_to {
|
||||||
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
|
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
|
||||||
|
@ -81,42 +83,27 @@ async fn from_query(
|
||||||
}
|
}
|
||||||
LocalProtocol::ReverseTcp => {
|
LocalProtocol::ReverseTcp => {
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
static REVERSE: Lazy<Mutex<HashMap<(Host<String>, u16), TcpListenerStream>>> =
|
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<TcpStream>>>> =
|
||||||
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
||||||
|
|
||||||
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
|
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
|
||||||
let listening_server = REVERSE.lock().remove(&local_srv);
|
|
||||||
let mut listening_server = if let Some(listening_server) = listening_server {
|
|
||||||
listening_server
|
|
||||||
} else {
|
|
||||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||||
tcp::run_server(bind.parse()?).await?
|
let listening_server = tcp::run_server(bind.parse()?);
|
||||||
};
|
let tcp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||||
|
let (local_rx, local_tx) = tcp.into_split();
|
||||||
let tcp = listening_server.next().await.unwrap()?;
|
|
||||||
let (local_rx, local_tx) = tokio::io::split(tcp);
|
|
||||||
REVERSE.lock().insert(local_srv.clone(), listening_server);
|
|
||||||
|
|
||||||
Ok((jwt.claims.p, local_srv.0, local_srv.1, Box::pin(local_rx), Box::pin(local_tx)))
|
Ok((jwt.claims.p, local_srv.0, local_srv.1, Box::pin(local_rx), Box::pin(local_tx)))
|
||||||
}
|
}
|
||||||
LocalProtocol::ReverseUdp { timeout } => {
|
LocalProtocol::ReverseUdp { timeout } => {
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
static REVERSE: Lazy<
|
static SERVERS: Lazy<Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<UdpStream>>>> =
|
||||||
Mutex<HashMap<(Host<String>, u16), Pin<Box<dyn Stream<Item = io::Result<UdpStream>> + Send>>>>,
|
Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
||||||
> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0)));
|
|
||||||
|
|
||||||
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
|
let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);
|
||||||
let listening_server = REVERSE.lock().remove(&local_srv);
|
|
||||||
let mut listening_server = if let Some(listening_server) = listening_server {
|
|
||||||
listening_server
|
|
||||||
} else {
|
|
||||||
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
let bind = format!("{}:{}", local_srv.0, local_srv.1);
|
||||||
Box::pin(udp::run_server(bind.parse()?, timeout).await?)
|
let listening_server = udp::run_server(bind.parse()?, timeout);
|
||||||
};
|
let udp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
|
||||||
|
|
||||||
let udp = listening_server.next().await.unwrap()?;
|
|
||||||
let (local_rx, local_tx) = tokio::io::split(udp);
|
let (local_rx, local_tx) = tokio::io::split(udp);
|
||||||
REVERSE.lock().insert(local_srv.clone(), listening_server);
|
|
||||||
|
|
||||||
Ok((jwt.claims.p, local_srv.0, local_srv.1, Box::pin(local_rx), Box::pin(local_tx)))
|
Ok((jwt.claims.p, local_srv.0, local_srv.1, Box::pin(local_rx), Box::pin(local_tx)))
|
||||||
}
|
}
|
||||||
|
@ -124,6 +111,63 @@ async fn from_query(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
|
async fn run_listening_server<T, Fut, FutOut>(
|
||||||
|
local_srv: &(Host, u16),
|
||||||
|
servers: &Mutex<HashMap<(Host<String>, u16), mpsc::Receiver<T>>>,
|
||||||
|
gen_listening_server: Fut,
|
||||||
|
) -> anyhow::Result<T>
|
||||||
|
where
|
||||||
|
Fut: Future<Output = anyhow::Result<FutOut>>,
|
||||||
|
FutOut: Stream<Item = io::Result<T>> + Send + 'static,
|
||||||
|
T: Send + 'static,
|
||||||
|
{
|
||||||
|
let listening_server = servers.lock().remove(local_srv);
|
||||||
|
let mut listening_server = if let Some(listening_server) = listening_server {
|
||||||
|
listening_server
|
||||||
|
} else {
|
||||||
|
let listening_server = gen_listening_server.await?;
|
||||||
|
let (tx, rx) = mpsc::channel::<T>(1);
|
||||||
|
let fut = async move {
|
||||||
|
pin_mut!(listening_server);
|
||||||
|
loop {
|
||||||
|
select! {
|
||||||
|
biased;
|
||||||
|
cnx = listening_server.next() => {
|
||||||
|
match cnx {
|
||||||
|
None => break,
|
||||||
|
Some(Err(err)) => {
|
||||||
|
warn!("Error while listening for incoming connections {err:?}");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Some(Ok(cnx)) => {
|
||||||
|
if tx.send_timeout(cnx, Duration::from_secs(30)).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
_ = tx.closed() => {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
info!("Stopping listening server");
|
||||||
|
};
|
||||||
|
|
||||||
|
tokio::spawn(fut.instrument(Span::current()));
|
||||||
|
rx
|
||||||
|
};
|
||||||
|
|
||||||
|
let cnx = listening_server
|
||||||
|
.recv()
|
||||||
|
.await
|
||||||
|
.ok_or_else(|| anyhow!("listening server stopped"))?;
|
||||||
|
servers.lock().insert(local_srv.clone(), listening_server);
|
||||||
|
Ok(cnx)
|
||||||
|
}
|
||||||
|
|
||||||
async fn server_upgrade(
|
async fn server_upgrade(
|
||||||
server_config: Arc<WsServerConfig>,
|
server_config: Arc<WsServerConfig>,
|
||||||
mut req: Request<Body>,
|
mut req: Request<Body>,
|
||||||
|
|
Loading…
Reference in a new issue