Cleanup exit wstunnel when stdio tunnel terminate
This commit is contained in:
parent
a79a1bc107
commit
e8a27ea4df
2 changed files with 24 additions and 18 deletions
12
src/main.rs
12
src/main.rs
|
@ -34,6 +34,7 @@ use std::time::Duration;
|
|||
use std::{fmt, io};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::select;
|
||||
|
||||
use tokio_rustls::rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer, ServerName};
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
@ -1188,7 +1189,7 @@ async fn main() {
|
|||
}
|
||||
|
||||
LocalProtocol::Stdio => {
|
||||
let server = stdio::server::run_server().await.unwrap_or_else(|err| {
|
||||
let (server, mut handle) = stdio::server::run_server().await.unwrap_or_else(|err| {
|
||||
panic!("Cannot start STDIO server: {}", err);
|
||||
});
|
||||
tokio::spawn(async move {
|
||||
|
@ -1208,6 +1209,15 @@ async fn main() {
|
|||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
||||
// We need to wait for either a ctrl+c of that the stdio tunnel is closed
|
||||
// to force exit the program
|
||||
select! {
|
||||
_ = handle.closed() => {},
|
||||
_ = tokio::signal::ctrl_c() => {}
|
||||
}
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
std::process::exit(0);
|
||||
}
|
||||
LocalProtocol::ReverseTcp => {}
|
||||
LocalProtocol::ReverseUdp { .. } => {}
|
||||
|
|
30
src/stdio.rs
30
src/stdio.rs
|
@ -1,38 +1,31 @@
|
|||
#[cfg(unix)]
|
||||
pub mod server {
|
||||
use std::pin::Pin;
|
||||
use std::process;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, ReadBuf};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_fd::AsyncFd;
|
||||
use tracing::info;
|
||||
|
||||
pub struct AbortOnDropStdin {
|
||||
pub struct WsStdin {
|
||||
stdin: AsyncFd,
|
||||
_receiver: oneshot::Receiver<()>,
|
||||
}
|
||||
|
||||
// Wrapper around stdin is needed in order to properly abort the process in case the tunnel drop.
|
||||
// As we are going to launch the tunnel in a threadpool, we cant know when the tunnel is dropped.
|
||||
impl Drop for AbortOnDropStdin {
|
||||
fn drop(&mut self) {
|
||||
// Hackish !
|
||||
process::exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for AbortOnDropStdin {
|
||||
impl AsyncRead for WsStdin {
|
||||
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
||||
unsafe { self.map_unchecked_mut(|s| &mut s.stdin) }.poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_server() -> Result<(AbortOnDropStdin, AsyncFd), anyhow::Error> {
|
||||
pub async fn run_server() -> Result<((WsStdin, AsyncFd), oneshot::Sender<()>), anyhow::Error> {
|
||||
info!("Starting STDIO server");
|
||||
|
||||
let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?;
|
||||
let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?;
|
||||
let (tx, rx) = oneshot::channel::<()>();
|
||||
|
||||
Ok((AbortOnDropStdin { stdin }, stdout))
|
||||
Ok(((WsStdin { stdin, _receiver: rx }, stdout), tx))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -42,24 +35,27 @@ pub mod server {
|
|||
use log::error;
|
||||
use scopeguard::guard;
|
||||
use std::io::{Read, Write};
|
||||
use std::sync::mpsc;
|
||||
use std::{io, process, thread};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task::LocalSet;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tokio_util::io::StreamReader;
|
||||
use tracing::info;
|
||||
|
||||
pub async fn run_server() -> Result<(impl AsyncRead, impl AsyncWrite), anyhow::Error> {
|
||||
pub async fn run_server() -> Result<((impl AsyncRead, impl AsyncWrite), oneshot::Sender<()>), anyhow::Error> {
|
||||
info!("Starting STDIO server. Press ctrl+c twice to exit");
|
||||
|
||||
crossterm::terminal::enable_raw_mode()?;
|
||||
|
||||
let stdin = io::stdin();
|
||||
let (send, recv) = tokio::sync::mpsc::unbounded_channel();
|
||||
let (abort_tx, mut abort_rx) = oneshot::channel::<()>();
|
||||
thread::spawn(move || {
|
||||
let _restore_terminal = guard((), move |_| {
|
||||
let _ = crossterm::terminal::disable_raw_mode();
|
||||
process::exit(0);
|
||||
abort_rx.close();
|
||||
});
|
||||
let stdin = stdin;
|
||||
let mut stdin = stdin.lock();
|
||||
|
@ -111,6 +107,6 @@ pub mod server {
|
|||
rt.block_on(local);
|
||||
});
|
||||
|
||||
Ok((stdin, stdout))
|
||||
Ok(((stdin, stdout), abort_tx))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue