diff --git a/src/stdio.rs b/src/stdio.rs index 18fe02c..127c8d2 100644 --- a/src/stdio.rs +++ b/src/stdio.rs @@ -1,15 +1,38 @@ #[cfg(unix)] pub mod server { - + use std::pin::Pin; + use std::process; + use std::task::{Context, Poll}; + use tokio::io::{AsyncRead, ReadBuf}; use tokio_fd::AsyncFd; use tracing::info; - pub async fn run_server() -> Result<(AsyncFd, AsyncFd), anyhow::Error> { + + pub struct AbortOnDropStdin { + stdin: AsyncFd, + } + + // Wrapper around stdin is needed in order to properly abort the process in case the tunnel drop. + // As we are going to launch the tunnel in a threadpool, we cant know when the tunnel is dropped. + impl Drop for AbortOnDropStdin { + fn drop(&mut self) { + // Hackish ! + process::exit(0); + } + } + + impl AsyncRead for AbortOnDropStdin { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + unsafe { self.map_unchecked_mut(|s| &mut s.stdin) }.poll_read(cx, buf) + } + } + + pub async fn run_server() -> Result<(AbortOnDropStdin, AsyncFd), anyhow::Error> { info!("Starting STDIO server"); let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?; let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?; - Ok((stdin, stdout)) + Ok((AbortOnDropStdin { stdin }, stdout)) } } @@ -19,7 +42,7 @@ pub mod server { use log::error; use scopeguard::guard; use std::io::{Read, Write}; - use std::{io, thread}; + use std::{io, process, thread}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::task::LocalSet; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -36,6 +59,7 @@ pub mod server { thread::spawn(move || { let _restore_terminal = guard((), move |_| { let _ = crossterm::terminal::disable_raw_mode(); + process::exit(0); }); let stdin = stdin; let mut stdin = stdin.lock();