From e8a27ea4dfce45753fd98e1e4683e874eddf981b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Sat, 25 May 2024 10:31:51 +0200 Subject: [PATCH] Cleanup exit wstunnel when stdio tunnel terminate --- src/main.rs | 12 +++++++++++- src/stdio.rs | 30 +++++++++++++----------------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/src/main.rs b/src/main.rs index e3d51d6..1e4852f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -34,6 +34,7 @@ use std::time::Duration; use std::{fmt, io}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; +use tokio::select; use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer, ServerName}; use tokio_rustls::TlsConnector; @@ -1188,7 +1189,7 @@ async fn main() { } LocalProtocol::Stdio => { - let server = stdio::server::run_server().await.unwrap_or_else(|err| { + let (server, mut handle) = stdio::server::run_server().await.unwrap_or_else(|err| { panic!("Cannot start STDIO server: {}", err); }); tokio::spawn(async move { @@ -1208,6 +1209,15 @@ async fn main() { error!("{:?}", err); } }); + + // We need to wait for either a ctrl+c of that the stdio tunnel is closed + // to force exit the program + select! { + _ = handle.closed() => {}, + _ = tokio::signal::ctrl_c() => {} + } + tokio::time::sleep(Duration::from_secs(1)).await; + std::process::exit(0); } LocalProtocol::ReverseTcp => {} LocalProtocol::ReverseUdp { .. } => {} diff --git a/src/stdio.rs b/src/stdio.rs index 127c8d2..29481c0 100644 --- a/src/stdio.rs +++ b/src/stdio.rs @@ -1,38 +1,31 @@ #[cfg(unix)] pub mod server { use std::pin::Pin; - use std::process; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, ReadBuf}; + use tokio::sync::oneshot; use tokio_fd::AsyncFd; use tracing::info; - pub struct AbortOnDropStdin { + pub struct WsStdin { stdin: AsyncFd, + _receiver: oneshot::Receiver<()>, } - // Wrapper around stdin is needed in order to properly abort the process in case the tunnel drop. - // As we are going to launch the tunnel in a threadpool, we cant know when the tunnel is dropped. - impl Drop for AbortOnDropStdin { - fn drop(&mut self) { - // Hackish ! - process::exit(0); - } - } - - impl AsyncRead for AbortOnDropStdin { + impl AsyncRead for WsStdin { fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { unsafe { self.map_unchecked_mut(|s| &mut s.stdin) }.poll_read(cx, buf) } } - pub async fn run_server() -> Result<(AbortOnDropStdin, AsyncFd), anyhow::Error> { + pub async fn run_server() -> Result<((WsStdin, AsyncFd), oneshot::Sender<()>), anyhow::Error> { info!("Starting STDIO server"); let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?; let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?; + let (tx, rx) = oneshot::channel::<()>(); - Ok((AbortOnDropStdin { stdin }, stdout)) + Ok(((WsStdin { stdin, _receiver: rx }, stdout), tx)) } } @@ -42,24 +35,27 @@ pub mod server { use log::error; use scopeguard::guard; use std::io::{Read, Write}; + use std::sync::mpsc; use std::{io, process, thread}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; + use tokio::sync::oneshot; use tokio::task::LocalSet; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_util::io::StreamReader; use tracing::info; - pub async fn run_server() -> Result<(impl AsyncRead, impl AsyncWrite), anyhow::Error> { + pub async fn run_server() -> Result<((impl AsyncRead, impl AsyncWrite), oneshot::Sender<()>), anyhow::Error> { info!("Starting STDIO server. Press ctrl+c twice to exit"); crossterm::terminal::enable_raw_mode()?; let stdin = io::stdin(); let (send, recv) = tokio::sync::mpsc::unbounded_channel(); + let (abort_tx, mut abort_rx) = oneshot::channel::<()>(); thread::spawn(move || { let _restore_terminal = guard((), move |_| { let _ = crossterm::terminal::disable_raw_mode(); - process::exit(0); + abort_rx.close(); }); let stdin = stdin; let mut stdin = stdin.lock(); @@ -111,6 +107,6 @@ pub mod server { rt.block_on(local); }); - Ok((stdin, stdout)) + Ok(((stdin, stdout), abort_tx)) } }