Improve stdio tunnel on windows

- Handle CTRL+C to exit properly
- Restore terminal mode at exit
- Use logger to stderr
This commit is contained in:
erebe 2024-05-18 11:23:22 +02:00 committed by Σrebe - Romain GERARD
parent 0595e23050
commit 9b82006c6e
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
3 changed files with 41 additions and 28 deletions

View file

@ -724,27 +724,28 @@ async fn main() {
let args = Wstunnel::parse(); let args = Wstunnel::parse();
// Setup logging // Setup logging
match &args.commands { let mut env_filter = EnvFilter::builder().parse(&args.log_lvl).expect("Invalid log level");
// Disable logging if there is a stdio tunnel if !(args.log_lvl.contains("h2::") || args.log_lvl.contains("h2=")) {
Commands::Client(args) env_filter = env_filter.add_directive(Directive::from_str("h2::codec=off").expect("Invalid log directive"));
if args
.local_to_remote
.iter()
.filter(|x| x.local_protocol == LocalProtocol::Stdio)
.count()
> 0 => {}
_ => {
let mut env_filter = EnvFilter::builder().parse(&args.log_lvl).expect("Invalid log level");
if !(args.log_lvl.contains("h2::") || args.log_lvl.contains("h2=")) {
env_filter =
env_filter.add_directive(Directive::from_str("h2::codec=off").expect("Invalid log directive"));
}
tracing_subscriber::fmt()
.with_ansi(args.no_color.is_none())
.with_env_filter(env_filter)
.init();
}
} }
let logger = tracing_subscriber::fmt()
.with_ansi(args.no_color.is_none())
.with_env_filter(env_filter);
// stdio tunnel capture stdio, so need to log into stderr
if let Commands::Client(args) = &args.commands {
if args
.local_to_remote
.iter()
.filter(|x| x.local_protocol == LocalProtocol::Stdio)
.count()
> 0
{
logger.with_writer(io::stderr).init();
}
} else {
logger.init();
};
match args.commands { match args.commands {
Commands::Client(args) => { Commands::Client(args) => {
@ -1018,7 +1019,7 @@ async fn main() {
}); });
} }
#[cfg(not(unix))] #[cfg(not(unix))]
LocalProtocol::Unix { path } => { LocalProtocol::Unix { .. } => {
panic!("Unix socket is not available for non Unix platform") panic!("Unix socket is not available for non Unix platform")
} }
LocalProtocol::Stdio LocalProtocol::Stdio

View file

@ -2,8 +2,9 @@
pub mod server { pub mod server {
use tokio_fd::AsyncFd; use tokio_fd::AsyncFd;
use tracing::info;
pub async fn run_server() -> Result<(AsyncFd, AsyncFd), anyhow::Error> { pub async fn run_server() -> Result<(AsyncFd, AsyncFd), anyhow::Error> {
eprintln!("Starting STDIO server"); info!("Starting STDIO server");
let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?; let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?;
let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?; let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?;
@ -15,31 +16,39 @@ pub mod server {
#[cfg(not(unix))] #[cfg(not(unix))]
pub mod server { pub mod server {
use bytes::BytesMut; use bytes::BytesMut;
use log::error;
use scopeguard::guard;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::{io, thread}; use std::{io, thread};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::task::LocalSet; use tokio::task::LocalSet;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::io::StreamReader; use tokio_util::io::StreamReader;
use tracing::info;
pub async fn run_server() -> Result<(impl AsyncRead, impl AsyncWrite), anyhow::Error> { pub async fn run_server() -> Result<(impl AsyncRead, impl AsyncWrite), anyhow::Error> {
eprintln!("Starting STDIO server"); info!("Starting STDIO server. Press ctrl+c twice to exit");
crossterm::terminal::enable_raw_mode()?; crossterm::terminal::enable_raw_mode()?;
let stdin = io::stdin(); let stdin = io::stdin();
let (send, recv) = tokio::sync::mpsc::unbounded_channel(); let (send, recv) = tokio::sync::mpsc::unbounded_channel();
thread::spawn(move || { thread::spawn(move || {
let _restore_terminal = guard((), move |_| {
let _ = crossterm::terminal::disable_raw_mode();
});
let stdin = stdin; let stdin = stdin;
let mut stdin = stdin.lock(); let mut stdin = stdin.lock();
let mut buf = [0u8; 65536]; let mut buf = [0u8; 65536];
loop { loop {
let n = stdin.read(&mut buf).unwrap(); let n = stdin.read(&mut buf).unwrap_or(0);
if n == 0 { if n == 0 || (n == 1 && buf[0] == 3) {
// ctrl+c send char 3
break; break;
} }
if let Err(err) = send.send(Result::<_, io::Error>::Ok(BytesMut::from(&buf[..n]))) { if let Err(err) = send.send(Result::<_, io::Error>::Ok(BytesMut::from(&buf[..n]))) {
eprintln!("Failed send inout: {:?}", err); error!("Failed send inout: {:?}", err);
break; break;
} }
} }
@ -50,6 +59,9 @@ pub mod server {
let rt = tokio::runtime::Handle::current(); let rt = tokio::runtime::Handle::current();
thread::spawn(move || { thread::spawn(move || {
let task = async move { let task = async move {
let _restore_terminal = guard((), move |_| {
let _ = crossterm::terminal::disable_raw_mode();
});
let mut stdout = io::stdout().lock(); let mut stdout = io::stdout().lock();
let mut buf = [0u8; 65536]; let mut buf = [0u8; 65536];
loop { loop {
@ -62,7 +74,7 @@ pub mod server {
} }
if let Err(err) = stdout.write_all(&buf[..n]) { if let Err(err) = stdout.write_all(&buf[..n]) {
eprintln!("Failed to write to stdout: {:?}", err); error!("Failed to write to stdout: {:?}", err);
break; break;
}; };
let _ = stdout.flush(); let _ = stdout.flush();

View file

@ -167,7 +167,7 @@ async fn run_tunnel(
Ok((remote, Box::pin(local_rx), Box::pin(local_tx))) Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
} }
#[cfg(not(unix))] #[cfg(not(unix))]
LocalProtocol::ReverseUnix { ref path } => { LocalProtocol::ReverseUnix { .. } => {
error!("Received an unsupported target protocol {:?}", remote); error!("Received an unsupported target protocol {:?}", remote);
Err(anyhow::anyhow!("Invalid upgrade request")) Err(anyhow::anyhow!("Invalid upgrade request"))
} }