Exit wstunnel when stdio tunnel terminate

This commit is contained in:
Σrebe - Romain GERARD 2024-05-24 20:50:13 +02:00
parent 904c775324
commit ad7d752f98
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4

View file

@ -1,15 +1,38 @@
#[cfg(unix)] #[cfg(unix)]
pub mod server { pub mod server {
use std::pin::Pin;
use std::process;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};
use tokio_fd::AsyncFd; use tokio_fd::AsyncFd;
use tracing::info; use tracing::info;
pub async fn run_server() -> Result<(AsyncFd, AsyncFd), anyhow::Error> {
pub struct AbortOnDropStdin {
stdin: AsyncFd,
}
// Wrapper around stdin is needed in order to properly abort the process in case the tunnel drop.
// As we are going to launch the tunnel in a threadpool, we cant know when the tunnel is dropped.
impl Drop for AbortOnDropStdin {
fn drop(&mut self) {
// Hackish !
process::exit(0);
}
}
impl AsyncRead for AbortOnDropStdin {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
unsafe { self.map_unchecked_mut(|s| &mut s.stdin) }.poll_read(cx, buf)
}
}
pub async fn run_server() -> Result<(AbortOnDropStdin, AsyncFd), anyhow::Error> {
info!("Starting STDIO server"); info!("Starting STDIO server");
let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?; let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?;
let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?; let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?;
Ok((stdin, stdout)) Ok((AbortOnDropStdin { stdin }, stdout))
} }
} }
@ -19,7 +42,7 @@ pub mod server {
use log::error; use log::error;
use scopeguard::guard; use scopeguard::guard;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::{io, thread}; use std::{io, process, thread};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::task::LocalSet; use tokio::task::LocalSet;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
@ -36,6 +59,7 @@ pub mod server {
thread::spawn(move || { thread::spawn(move || {
let _restore_terminal = guard((), move |_| { let _restore_terminal = guard((), move |_| {
let _ = crossterm::terminal::disable_raw_mode(); let _ = crossterm::terminal::disable_raw_mode();
process::exit(0);
}); });
let stdin = stdin; let stdin = stdin;
let mut stdin = stdin.lock(); let mut stdin = stdin.lock();