Improve stdio tunnel on windows
- Handle CTRL+C to exit properly - Restore terminal mode at exit - Use logger to stderr
This commit is contained in:
parent
0595e23050
commit
9b82006c6e
3 changed files with 41 additions and 28 deletions
43
src/main.rs
43
src/main.rs
|
@ -724,27 +724,28 @@ async fn main() {
|
||||||
let args = Wstunnel::parse();
|
let args = Wstunnel::parse();
|
||||||
|
|
||||||
// Setup logging
|
// Setup logging
|
||||||
match &args.commands {
|
let mut env_filter = EnvFilter::builder().parse(&args.log_lvl).expect("Invalid log level");
|
||||||
// Disable logging if there is a stdio tunnel
|
if !(args.log_lvl.contains("h2::") || args.log_lvl.contains("h2=")) {
|
||||||
Commands::Client(args)
|
env_filter = env_filter.add_directive(Directive::from_str("h2::codec=off").expect("Invalid log directive"));
|
||||||
if args
|
|
||||||
.local_to_remote
|
|
||||||
.iter()
|
|
||||||
.filter(|x| x.local_protocol == LocalProtocol::Stdio)
|
|
||||||
.count()
|
|
||||||
> 0 => {}
|
|
||||||
_ => {
|
|
||||||
let mut env_filter = EnvFilter::builder().parse(&args.log_lvl).expect("Invalid log level");
|
|
||||||
if !(args.log_lvl.contains("h2::") || args.log_lvl.contains("h2=")) {
|
|
||||||
env_filter =
|
|
||||||
env_filter.add_directive(Directive::from_str("h2::codec=off").expect("Invalid log directive"));
|
|
||||||
}
|
|
||||||
tracing_subscriber::fmt()
|
|
||||||
.with_ansi(args.no_color.is_none())
|
|
||||||
.with_env_filter(env_filter)
|
|
||||||
.init();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
let logger = tracing_subscriber::fmt()
|
||||||
|
.with_ansi(args.no_color.is_none())
|
||||||
|
.with_env_filter(env_filter);
|
||||||
|
|
||||||
|
// stdio tunnel capture stdio, so need to log into stderr
|
||||||
|
if let Commands::Client(args) = &args.commands {
|
||||||
|
if args
|
||||||
|
.local_to_remote
|
||||||
|
.iter()
|
||||||
|
.filter(|x| x.local_protocol == LocalProtocol::Stdio)
|
||||||
|
.count()
|
||||||
|
> 0
|
||||||
|
{
|
||||||
|
logger.with_writer(io::stderr).init();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.init();
|
||||||
|
};
|
||||||
|
|
||||||
match args.commands {
|
match args.commands {
|
||||||
Commands::Client(args) => {
|
Commands::Client(args) => {
|
||||||
|
@ -1018,7 +1019,7 @@ async fn main() {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
#[cfg(not(unix))]
|
#[cfg(not(unix))]
|
||||||
LocalProtocol::Unix { path } => {
|
LocalProtocol::Unix { .. } => {
|
||||||
panic!("Unix socket is not available for non Unix platform")
|
panic!("Unix socket is not available for non Unix platform")
|
||||||
}
|
}
|
||||||
LocalProtocol::Stdio
|
LocalProtocol::Stdio
|
||||||
|
|
24
src/stdio.rs
24
src/stdio.rs
|
@ -2,8 +2,9 @@
|
||||||
pub mod server {
|
pub mod server {
|
||||||
|
|
||||||
use tokio_fd::AsyncFd;
|
use tokio_fd::AsyncFd;
|
||||||
|
use tracing::info;
|
||||||
pub async fn run_server() -> Result<(AsyncFd, AsyncFd), anyhow::Error> {
|
pub async fn run_server() -> Result<(AsyncFd, AsyncFd), anyhow::Error> {
|
||||||
eprintln!("Starting STDIO server");
|
info!("Starting STDIO server");
|
||||||
|
|
||||||
let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?;
|
let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?;
|
||||||
let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?;
|
let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?;
|
||||||
|
@ -15,31 +16,39 @@ pub mod server {
|
||||||
#[cfg(not(unix))]
|
#[cfg(not(unix))]
|
||||||
pub mod server {
|
pub mod server {
|
||||||
use bytes::BytesMut;
|
use bytes::BytesMut;
|
||||||
|
use log::error;
|
||||||
|
use scopeguard::guard;
|
||||||
use std::io::{Read, Write};
|
use std::io::{Read, Write};
|
||||||
use std::{io, thread};
|
use std::{io, thread};
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||||
use tokio::task::LocalSet;
|
use tokio::task::LocalSet;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tokio_util::io::StreamReader;
|
use tokio_util::io::StreamReader;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
pub async fn run_server() -> Result<(impl AsyncRead, impl AsyncWrite), anyhow::Error> {
|
pub async fn run_server() -> Result<(impl AsyncRead, impl AsyncWrite), anyhow::Error> {
|
||||||
eprintln!("Starting STDIO server");
|
info!("Starting STDIO server. Press ctrl+c twice to exit");
|
||||||
|
|
||||||
crossterm::terminal::enable_raw_mode()?;
|
crossterm::terminal::enable_raw_mode()?;
|
||||||
|
|
||||||
let stdin = io::stdin();
|
let stdin = io::stdin();
|
||||||
let (send, recv) = tokio::sync::mpsc::unbounded_channel();
|
let (send, recv) = tokio::sync::mpsc::unbounded_channel();
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
|
let _restore_terminal = guard((), move |_| {
|
||||||
|
let _ = crossterm::terminal::disable_raw_mode();
|
||||||
|
});
|
||||||
let stdin = stdin;
|
let stdin = stdin;
|
||||||
let mut stdin = stdin.lock();
|
let mut stdin = stdin.lock();
|
||||||
let mut buf = [0u8; 65536];
|
let mut buf = [0u8; 65536];
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let n = stdin.read(&mut buf).unwrap();
|
let n = stdin.read(&mut buf).unwrap_or(0);
|
||||||
if n == 0 {
|
if n == 0 || (n == 1 && buf[0] == 3) {
|
||||||
|
// ctrl+c send char 3
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if let Err(err) = send.send(Result::<_, io::Error>::Ok(BytesMut::from(&buf[..n]))) {
|
if let Err(err) = send.send(Result::<_, io::Error>::Ok(BytesMut::from(&buf[..n]))) {
|
||||||
eprintln!("Failed send inout: {:?}", err);
|
error!("Failed send inout: {:?}", err);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -50,6 +59,9 @@ pub mod server {
|
||||||
let rt = tokio::runtime::Handle::current();
|
let rt = tokio::runtime::Handle::current();
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
let task = async move {
|
let task = async move {
|
||||||
|
let _restore_terminal = guard((), move |_| {
|
||||||
|
let _ = crossterm::terminal::disable_raw_mode();
|
||||||
|
});
|
||||||
let mut stdout = io::stdout().lock();
|
let mut stdout = io::stdout().lock();
|
||||||
let mut buf = [0u8; 65536];
|
let mut buf = [0u8; 65536];
|
||||||
loop {
|
loop {
|
||||||
|
@ -62,7 +74,7 @@ pub mod server {
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(err) = stdout.write_all(&buf[..n]) {
|
if let Err(err) = stdout.write_all(&buf[..n]) {
|
||||||
eprintln!("Failed to write to stdout: {:?}", err);
|
error!("Failed to write to stdout: {:?}", err);
|
||||||
break;
|
break;
|
||||||
};
|
};
|
||||||
let _ = stdout.flush();
|
let _ = stdout.flush();
|
||||||
|
|
|
@ -167,7 +167,7 @@ async fn run_tunnel(
|
||||||
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
|
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
|
||||||
}
|
}
|
||||||
#[cfg(not(unix))]
|
#[cfg(not(unix))]
|
||||||
LocalProtocol::ReverseUnix { ref path } => {
|
LocalProtocol::ReverseUnix { .. } => {
|
||||||
error!("Received an unsupported target protocol {:?}", remote);
|
error!("Received an unsupported target protocol {:?}", remote);
|
||||||
Err(anyhow::anyhow!("Invalid upgrade request"))
|
Err(anyhow::anyhow!("Invalid upgrade request"))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue