From 9b82006c6e49f484220b1cafcac0c6a20f18cfb7 Mon Sep 17 00:00:00 2001 From: erebe Date: Sat, 18 May 2024 11:23:22 +0200 Subject: [PATCH] Improve stdio tunnel on windows - Handle CTRL+C to exit properly - Restore terminal mode at exit - Use logger to stderr --- src/main.rs | 43 ++++++++++++++++++++++--------------------- src/stdio.rs | 24 ++++++++++++++++++------ src/tunnel/server.rs | 2 +- 3 files changed, 41 insertions(+), 28 deletions(-) diff --git a/src/main.rs b/src/main.rs index 808f6bf..78321a4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -724,27 +724,28 @@ async fn main() { let args = Wstunnel::parse(); // Setup logging - match &args.commands { - // Disable logging if there is a stdio tunnel - Commands::Client(args) - if args - .local_to_remote - .iter() - .filter(|x| x.local_protocol == LocalProtocol::Stdio) - .count() - > 0 => {} - _ => { - let mut env_filter = EnvFilter::builder().parse(&args.log_lvl).expect("Invalid log level"); - if !(args.log_lvl.contains("h2::") || args.log_lvl.contains("h2=")) { - env_filter = - env_filter.add_directive(Directive::from_str("h2::codec=off").expect("Invalid log directive")); - } - tracing_subscriber::fmt() - .with_ansi(args.no_color.is_none()) - .with_env_filter(env_filter) - .init(); - } + let mut env_filter = EnvFilter::builder().parse(&args.log_lvl).expect("Invalid log level"); + if !(args.log_lvl.contains("h2::") || args.log_lvl.contains("h2=")) { + env_filter = env_filter.add_directive(Directive::from_str("h2::codec=off").expect("Invalid log directive")); } + let logger = tracing_subscriber::fmt() + .with_ansi(args.no_color.is_none()) + .with_env_filter(env_filter); + + // stdio tunnel capture stdio, so need to log into stderr + if let Commands::Client(args) = &args.commands { + if args + .local_to_remote + .iter() + .filter(|x| x.local_protocol == LocalProtocol::Stdio) + .count() + > 0 + { + logger.with_writer(io::stderr).init(); + } + } else { + logger.init(); + }; match args.commands { Commands::Client(args) => { @@ -1018,7 +1019,7 @@ async fn main() { }); } #[cfg(not(unix))] - LocalProtocol::Unix { path } => { + LocalProtocol::Unix { .. } => { panic!("Unix socket is not available for non Unix platform") } LocalProtocol::Stdio diff --git a/src/stdio.rs b/src/stdio.rs index 074b61f..18fe02c 100644 --- a/src/stdio.rs +++ b/src/stdio.rs @@ -2,8 +2,9 @@ pub mod server { use tokio_fd::AsyncFd; + use tracing::info; pub async fn run_server() -> Result<(AsyncFd, AsyncFd), anyhow::Error> { - eprintln!("Starting STDIO server"); + info!("Starting STDIO server"); let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?; let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?; @@ -15,31 +16,39 @@ pub mod server { #[cfg(not(unix))] pub mod server { use bytes::BytesMut; + use log::error; + use scopeguard::guard; use std::io::{Read, Write}; use std::{io, thread}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::task::LocalSet; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_util::io::StreamReader; + use tracing::info; pub async fn run_server() -> Result<(impl AsyncRead, impl AsyncWrite), anyhow::Error> { - eprintln!("Starting STDIO server"); + info!("Starting STDIO server. Press ctrl+c twice to exit"); crossterm::terminal::enable_raw_mode()?; let stdin = io::stdin(); let (send, recv) = tokio::sync::mpsc::unbounded_channel(); thread::spawn(move || { + let _restore_terminal = guard((), move |_| { + let _ = crossterm::terminal::disable_raw_mode(); + }); let stdin = stdin; let mut stdin = stdin.lock(); let mut buf = [0u8; 65536]; + loop { - let n = stdin.read(&mut buf).unwrap(); - if n == 0 { + let n = stdin.read(&mut buf).unwrap_or(0); + if n == 0 || (n == 1 && buf[0] == 3) { + // ctrl+c send char 3 break; } if let Err(err) = send.send(Result::<_, io::Error>::Ok(BytesMut::from(&buf[..n]))) { - eprintln!("Failed send inout: {:?}", err); + error!("Failed send inout: {:?}", err); break; } } @@ -50,6 +59,9 @@ pub mod server { let rt = tokio::runtime::Handle::current(); thread::spawn(move || { let task = async move { + let _restore_terminal = guard((), move |_| { + let _ = crossterm::terminal::disable_raw_mode(); + }); let mut stdout = io::stdout().lock(); let mut buf = [0u8; 65536]; loop { @@ -62,7 +74,7 @@ pub mod server { } if let Err(err) = stdout.write_all(&buf[..n]) { - eprintln!("Failed to write to stdout: {:?}", err); + error!("Failed to write to stdout: {:?}", err); break; }; let _ = stdout.flush(); diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 63efbf5..d53ccd1 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -167,7 +167,7 @@ async fn run_tunnel( Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) } #[cfg(not(unix))] - LocalProtocol::ReverseUnix { ref path } => { + LocalProtocol::ReverseUnix { .. } => { error!("Received an unsupported target protocol {:?}", remote); Err(anyhow::anyhow!("Invalid upgrade request")) }