Cleanup exit wstunnel when stdio tunnel terminate

This commit is contained in:
Σrebe - Romain GERARD 2024-05-25 10:31:51 +02:00
parent a79a1bc107
commit e8a27ea4df
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
2 changed files with 24 additions and 18 deletions

View file

@ -34,6 +34,7 @@ use std::time::Duration;
use std::{fmt, io};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::select;
use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer, ServerName};
use tokio_rustls::TlsConnector;
@ -1188,7 +1189,7 @@ async fn main() {
}
LocalProtocol::Stdio => {
let server = stdio::server::run_server().await.unwrap_or_else(|err| {
let (server, mut handle) = stdio::server::run_server().await.unwrap_or_else(|err| {
panic!("Cannot start STDIO server: {}", err);
});
tokio::spawn(async move {
@ -1208,6 +1209,15 @@ async fn main() {
error!("{:?}", err);
}
});
// We need to wait for either a ctrl+c of that the stdio tunnel is closed
// to force exit the program
select! {
_ = handle.closed() => {},
_ = tokio::signal::ctrl_c() => {}
}
tokio::time::sleep(Duration::from_secs(1)).await;
std::process::exit(0);
}
LocalProtocol::ReverseTcp => {}
LocalProtocol::ReverseUdp { .. } => {}

View file

@ -1,38 +1,31 @@
#[cfg(unix)]
pub mod server {
use std::pin::Pin;
use std::process;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};
use tokio::sync::oneshot;
use tokio_fd::AsyncFd;
use tracing::info;
pub struct AbortOnDropStdin {
pub struct WsStdin {
stdin: AsyncFd,
_receiver: oneshot::Receiver<()>,
}
// Wrapper around stdin is needed in order to properly abort the process in case the tunnel drop.
// As we are going to launch the tunnel in a threadpool, we cant know when the tunnel is dropped.
impl Drop for AbortOnDropStdin {
fn drop(&mut self) {
// Hackish !
process::exit(0);
}
}
impl AsyncRead for AbortOnDropStdin {
impl AsyncRead for WsStdin {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
unsafe { self.map_unchecked_mut(|s| &mut s.stdin) }.poll_read(cx, buf)
}
}
pub async fn run_server() -> Result<(AbortOnDropStdin, AsyncFd), anyhow::Error> {
pub async fn run_server() -> Result<((WsStdin, AsyncFd), oneshot::Sender<()>), anyhow::Error> {
info!("Starting STDIO server");
let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?;
let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?;
let (tx, rx) = oneshot::channel::<()>();
Ok((AbortOnDropStdin { stdin }, stdout))
Ok(((WsStdin { stdin, _receiver: rx }, stdout), tx))
}
}
@ -42,24 +35,27 @@ pub mod server {
use log::error;
use scopeguard::guard;
use std::io::{Read, Write};
use std::sync::mpsc;
use std::{io, process, thread};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::sync::oneshot;
use tokio::task::LocalSet;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::io::StreamReader;
use tracing::info;
pub async fn run_server() -> Result<(impl AsyncRead, impl AsyncWrite), anyhow::Error> {
pub async fn run_server() -> Result<((impl AsyncRead, impl AsyncWrite), oneshot::Sender<()>), anyhow::Error> {
info!("Starting STDIO server. Press ctrl+c twice to exit");
crossterm::terminal::enable_raw_mode()?;
let stdin = io::stdin();
let (send, recv) = tokio::sync::mpsc::unbounded_channel();
let (abort_tx, mut abort_rx) = oneshot::channel::<()>();
thread::spawn(move || {
let _restore_terminal = guard((), move |_| {
let _ = crossterm::terminal::disable_raw_mode();
process::exit(0);
abort_rx.close();
});
let stdin = stdin;
let mut stdin = stdin.lock();
@ -111,6 +107,6 @@ pub mod server {
rt.block_on(local);
});
Ok((stdin, stdout))
Ok(((stdin, stdout), abort_tx))
}
}