diff --git a/src/main.rs b/src/main.rs index b1e28fe..4a0bc8b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -196,6 +196,12 @@ struct Client { #[arg(short='H', long, value_name = "HEADER_NAME: HEADER_VALUE", value_parser = parse_http_headers, verbatim_doc_comment)] http_headers: Vec<(HeaderName, HeaderValue)>, + /// Send custom headers in the upgrade request reading them from a file. + /// It overrides http_headers specified from command line. + /// File is read everytime and file format must contains lines with `HEADER_NAME: HEADER_VALUE` + #[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)] + http_headers_file_path: Option, + /// Address of the wstunnel server /// You can either use websocket or http2 as transport protocol. Use websocket if you are unsure. /// Example: For websocket with TLS wss://wstunnel.example.com or without ws://wstunnel.example.com @@ -598,6 +604,7 @@ pub struct WsClientConfig { pub http_upgrade_path_prefix: String, pub http_upgrade_credentials: Option, pub http_headers: HashMap, + pub http_headers_file: Option, pub http_header_host: HeaderValue, pub timeout_connect: Duration, pub websocket_ping_frequency: Duration, @@ -696,6 +703,11 @@ async fn main() { }; HeaderValue::from_str(&host).unwrap() }; + if let Some(path) = &args.http_headers_file_path { + if !path.exists() { + panic!("http headers file does not exists: {}", path.display()); + } + } let mut client_config = WsClientConfig { remote_addr: TransportAddr::new( TransportScheme::from_str(args.remote_addr.scheme()).unwrap(), @@ -708,6 +720,7 @@ async fn main() { http_upgrade_path_prefix: args.http_upgrade_path_prefix, http_upgrade_credentials: args.http_upgrade_credentials, http_headers: args.http_headers.into_iter().filter(|(k, _)| k != HOST).collect(), + http_headers_file: args.http_headers_file_path, http_header_host: host_header, timeout_connect: Duration::from_secs(10), websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)), diff --git a/src/tunnel/transport/http2.rs b/src/tunnel/transport/http2.rs index 74ae543..617987f 100644 --- a/src/tunnel/transport/http2.rs +++ b/src/tunnel/transport/http2.rs @@ -1,4 +1,4 @@ -use crate::tunnel::transport::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; +use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr}; use crate::WsClientConfig; use anyhow::{anyhow, Context}; @@ -124,11 +124,19 @@ pub async fn connect( let _ = headers.remove(k); headers.append(k, v.clone()); } + if let Some(auth) = &client_cfg.http_upgrade_credentials { let _ = headers.remove(AUTHORIZATION); headers.append(AUTHORIZATION, auth.clone()); } + if let Some(headers_file_path) = &client_cfg.http_headers_file { + for (k, v) in headers_from_file(headers_file_path) { + let _ = headers.remove(&k); + headers.append(k, v); + } + } + let (tx, rx) = mpsc::channel::(1024); let body = StreamBody::new(ReceiverStream::new(rx).map(|s| -> anyhow::Result> { Ok(Frame::data(s)) })); let req = req.body(body).with_context(|| { diff --git a/src/tunnel/transport/mod.rs b/src/tunnel/transport/mod.rs index 480230b..aeb6bac 100644 --- a/src/tunnel/transport/mod.rs +++ b/src/tunnel/transport/mod.rs @@ -1,8 +1,14 @@ use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; use bytes::BytesMut; +use hyper::http::{HeaderName, HeaderValue}; use std::future::Future; +use std::io::{BufRead, BufReader}; +use std::path::Path; +use std::str::FromStr; + use tokio::io::AsyncWrite; +use tracing::error; pub mod http2; pub mod io; @@ -72,3 +78,25 @@ impl TunnelWrite for TunnelWriter { } } } + +#[inline] +pub fn headers_from_file(path: &Path) -> Vec<(HeaderName, HeaderValue)> { + let file = match std::fs::File::open(path) { + Ok(file) => file, + Err(err) => { + error!("Cannot read headers from file: {:?}: {:?}", path, err); + return vec![]; + } + }; + + BufReader::new(file) + .lines() + .filter_map(|line| { + let line = line.ok()?; + let (header, value) = line.split_once(':')?; + let header = HeaderName::from_str(header.trim()).ok()?; + let value = HeaderValue::from_str(value.trim()).ok()?; + Some((header, value)) + }) + .collect() +} diff --git a/src/tunnel/transport/websocket.rs b/src/tunnel/transport/websocket.rs index 3201661..583adf4 100644 --- a/src/tunnel/transport/websocket.rs +++ b/src/tunnel/transport/websocket.rs @@ -1,4 +1,4 @@ -use crate::tunnel::transport::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; +use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, JWT_HEADER_PREFIX}; use crate::WsClientConfig; use anyhow::{anyhow, Context}; @@ -160,11 +160,19 @@ pub async fn connect( let _ = headers.remove(k); headers.append(k, v.clone()); } + if let Some(auth) = &client_cfg.http_upgrade_credentials { let _ = headers.remove(AUTHORIZATION); headers.append(AUTHORIZATION, auth.clone()); } + if let Some(headers_file_path) = &client_cfg.http_headers_file { + for (k, v) in headers_from_file(headers_file_path) { + let _ = headers.remove(&k); + headers.append(k, v); + } + } + let req = req.body(Empty::::new()).with_context(|| { format!( "failed to build HTTP request to contact the server {:?}",