Cleanup exit wstunnel when stdio tunnel terminate

This commit is contained in:
Σrebe - Romain GERARD 2024-05-25 10:31:51 +02:00
parent a79a1bc107
commit e8a27ea4df
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
2 changed files with 24 additions and 18 deletions

View file

@ -34,6 +34,7 @@ use std::time::Duration;
use std::{fmt, io}; use std::{fmt, io};
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::select;
use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer, ServerName}; use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer, ServerName};
use tokio_rustls::TlsConnector; use tokio_rustls::TlsConnector;
@ -1188,7 +1189,7 @@ async fn main() {
} }
LocalProtocol::Stdio => { LocalProtocol::Stdio => {
let server = stdio::server::run_server().await.unwrap_or_else(|err| { let (server, mut handle) = stdio::server::run_server().await.unwrap_or_else(|err| {
panic!("Cannot start STDIO server: {}", err); panic!("Cannot start STDIO server: {}", err);
}); });
tokio::spawn(async move { tokio::spawn(async move {
@ -1208,6 +1209,15 @@ async fn main() {
error!("{:?}", err); error!("{:?}", err);
} }
}); });
// We need to wait for either a ctrl+c of that the stdio tunnel is closed
// to force exit the program
select! {
_ = handle.closed() => {},
_ = tokio::signal::ctrl_c() => {}
}
tokio::time::sleep(Duration::from_secs(1)).await;
std::process::exit(0);
} }
LocalProtocol::ReverseTcp => {} LocalProtocol::ReverseTcp => {}
LocalProtocol::ReverseUdp { .. } => {} LocalProtocol::ReverseUdp { .. } => {}

View file

@ -1,38 +1,31 @@
#[cfg(unix)] #[cfg(unix)]
pub mod server { pub mod server {
use std::pin::Pin; use std::pin::Pin;
use std::process;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf}; use tokio::io::{AsyncRead, ReadBuf};
use tokio::sync::oneshot;
use tokio_fd::AsyncFd; use tokio_fd::AsyncFd;
use tracing::info; use tracing::info;
pub struct AbortOnDropStdin { pub struct WsStdin {
stdin: AsyncFd, stdin: AsyncFd,
_receiver: oneshot::Receiver<()>,
} }
// Wrapper around stdin is needed in order to properly abort the process in case the tunnel drop. impl AsyncRead for WsStdin {
// As we are going to launch the tunnel in a threadpool, we cant know when the tunnel is dropped.
impl Drop for AbortOnDropStdin {
fn drop(&mut self) {
// Hackish !
process::exit(0);
}
}
impl AsyncRead for AbortOnDropStdin {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> { fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
unsafe { self.map_unchecked_mut(|s| &mut s.stdin) }.poll_read(cx, buf) unsafe { self.map_unchecked_mut(|s| &mut s.stdin) }.poll_read(cx, buf)
} }
} }
pub async fn run_server() -> Result<(AbortOnDropStdin, AsyncFd), anyhow::Error> { pub async fn run_server() -> Result<((WsStdin, AsyncFd), oneshot::Sender<()>), anyhow::Error> {
info!("Starting STDIO server"); info!("Starting STDIO server");
let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?; let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?;
let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?; let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?;
let (tx, rx) = oneshot::channel::<()>();
Ok((AbortOnDropStdin { stdin }, stdout)) Ok(((WsStdin { stdin, _receiver: rx }, stdout), tx))
} }
} }
@ -42,24 +35,27 @@ pub mod server {
use log::error; use log::error;
use scopeguard::guard; use scopeguard::guard;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::sync::mpsc;
use std::{io, process, thread}; use std::{io, process, thread};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::sync::oneshot;
use tokio::task::LocalSet; use tokio::task::LocalSet;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::io::StreamReader; use tokio_util::io::StreamReader;
use tracing::info; use tracing::info;
pub async fn run_server() -> Result<(impl AsyncRead, impl AsyncWrite), anyhow::Error> { pub async fn run_server() -> Result<((impl AsyncRead, impl AsyncWrite), oneshot::Sender<()>), anyhow::Error> {
info!("Starting STDIO server. Press ctrl+c twice to exit"); info!("Starting STDIO server. Press ctrl+c twice to exit");
crossterm::terminal::enable_raw_mode()?; crossterm::terminal::enable_raw_mode()?;
let stdin = io::stdin(); let stdin = io::stdin();
let (send, recv) = tokio::sync::mpsc::unbounded_channel(); let (send, recv) = tokio::sync::mpsc::unbounded_channel();
let (abort_tx, mut abort_rx) = oneshot::channel::<()>();
thread::spawn(move || { thread::spawn(move || {
let _restore_terminal = guard((), move |_| { let _restore_terminal = guard((), move |_| {
let _ = crossterm::terminal::disable_raw_mode(); let _ = crossterm::terminal::disable_raw_mode();
process::exit(0); abort_rx.close();
}); });
let stdin = stdin; let stdin = stdin;
let mut stdin = stdin.lock(); let mut stdin = stdin.lock();
@ -111,6 +107,6 @@ pub mod server {
rt.block_on(local); rt.block_on(local);
}); });
Ok((stdin, stdout)) Ok(((stdin, stdout), abort_tx))
} }
} }