Add flag to read http headers from a file

This commit is contained in:
Σrebe - Romain GERARD 2024-01-25 19:16:11 +01:00
parent 13aa664caf
commit f0cb4ab671
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
4 changed files with 59 additions and 2 deletions

View file

@ -196,6 +196,12 @@ struct Client {
#[arg(short='H', long, value_name = "HEADER_NAME: HEADER_VALUE", value_parser = parse_http_headers, verbatim_doc_comment)] #[arg(short='H', long, value_name = "HEADER_NAME: HEADER_VALUE", value_parser = parse_http_headers, verbatim_doc_comment)]
http_headers: Vec<(HeaderName, HeaderValue)>, http_headers: Vec<(HeaderName, HeaderValue)>,
/// Send custom headers in the upgrade request reading them from a file.
/// It overrides http_headers specified from command line.
/// File is read everytime and file format must contains lines with `HEADER_NAME: HEADER_VALUE`
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
http_headers_file_path: Option<PathBuf>,
/// Address of the wstunnel server /// Address of the wstunnel server
/// You can either use websocket or http2 as transport protocol. Use websocket if you are unsure. /// You can either use websocket or http2 as transport protocol. Use websocket if you are unsure.
/// Example: For websocket with TLS wss://wstunnel.example.com or without ws://wstunnel.example.com /// Example: For websocket with TLS wss://wstunnel.example.com or without ws://wstunnel.example.com
@ -598,6 +604,7 @@ pub struct WsClientConfig {
pub http_upgrade_path_prefix: String, pub http_upgrade_path_prefix: String,
pub http_upgrade_credentials: Option<HeaderValue>, pub http_upgrade_credentials: Option<HeaderValue>,
pub http_headers: HashMap<HeaderName, HeaderValue>, pub http_headers: HashMap<HeaderName, HeaderValue>,
pub http_headers_file: Option<PathBuf>,
pub http_header_host: HeaderValue, pub http_header_host: HeaderValue,
pub timeout_connect: Duration, pub timeout_connect: Duration,
pub websocket_ping_frequency: Duration, pub websocket_ping_frequency: Duration,
@ -696,6 +703,11 @@ async fn main() {
}; };
HeaderValue::from_str(&host).unwrap() HeaderValue::from_str(&host).unwrap()
}; };
if let Some(path) = &args.http_headers_file_path {
if !path.exists() {
panic!("http headers file does not exists: {}", path.display());
}
}
let mut client_config = WsClientConfig { let mut client_config = WsClientConfig {
remote_addr: TransportAddr::new( remote_addr: TransportAddr::new(
TransportScheme::from_str(args.remote_addr.scheme()).unwrap(), TransportScheme::from_str(args.remote_addr.scheme()).unwrap(),
@ -708,6 +720,7 @@ async fn main() {
http_upgrade_path_prefix: args.http_upgrade_path_prefix, http_upgrade_path_prefix: args.http_upgrade_path_prefix,
http_upgrade_credentials: args.http_upgrade_credentials, http_upgrade_credentials: args.http_upgrade_credentials,
http_headers: args.http_headers.into_iter().filter(|(k, _)| k != HOST).collect(), http_headers: args.http_headers.into_iter().filter(|(k, _)| k != HOST).collect(),
http_headers_file: args.http_headers_file_path,
http_header_host: host_header, http_header_host: host_header,
timeout_connect: Duration::from_secs(10), timeout_connect: Duration::from_secs(10),
websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)), websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)),

View file

@ -1,4 +1,4 @@
use crate::tunnel::transport::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr}; use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr};
use crate::WsClientConfig; use crate::WsClientConfig;
use anyhow::{anyhow, Context}; use anyhow::{anyhow, Context};
@ -124,11 +124,19 @@ pub async fn connect(
let _ = headers.remove(k); let _ = headers.remove(k);
headers.append(k, v.clone()); headers.append(k, v.clone());
} }
if let Some(auth) = &client_cfg.http_upgrade_credentials { if let Some(auth) = &client_cfg.http_upgrade_credentials {
let _ = headers.remove(AUTHORIZATION); let _ = headers.remove(AUTHORIZATION);
headers.append(AUTHORIZATION, auth.clone()); headers.append(AUTHORIZATION, auth.clone());
} }
if let Some(headers_file_path) = &client_cfg.http_headers_file {
for (k, v) in headers_from_file(headers_file_path) {
let _ = headers.remove(&k);
headers.append(k, v);
}
}
let (tx, rx) = mpsc::channel::<Bytes>(1024); let (tx, rx) = mpsc::channel::<Bytes>(1024);
let body = StreamBody::new(ReceiverStream::new(rx).map(|s| -> anyhow::Result<Frame<Bytes>> { Ok(Frame::data(s)) })); let body = StreamBody::new(ReceiverStream::new(rx).map(|s| -> anyhow::Result<Frame<Bytes>> { Ok(Frame::data(s)) }));
let req = req.body(body).with_context(|| { let req = req.body(body).with_context(|| {

View file

@ -1,8 +1,14 @@
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite};
use bytes::BytesMut; use bytes::BytesMut;
use hyper::http::{HeaderName, HeaderValue};
use std::future::Future; use std::future::Future;
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::str::FromStr;
use tokio::io::AsyncWrite; use tokio::io::AsyncWrite;
use tracing::error;
pub mod http2; pub mod http2;
pub mod io; pub mod io;
@ -72,3 +78,25 @@ impl TunnelWrite for TunnelWriter {
} }
} }
} }
#[inline]
pub fn headers_from_file(path: &Path) -> Vec<(HeaderName, HeaderValue)> {
let file = match std::fs::File::open(path) {
Ok(file) => file,
Err(err) => {
error!("Cannot read headers from file: {:?}: {:?}", path, err);
return vec![];
}
};
BufReader::new(file)
.lines()
.filter_map(|line| {
let line = line.ok()?;
let (header, value) = line.split_once(':')?;
let header = HeaderName::from_str(header.trim()).ok()?;
let value = HeaderValue::from_str(value.trim()).ok()?;
Some((header, value))
})
.collect()
}

View file

@ -1,4 +1,4 @@
use crate::tunnel::transport::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, JWT_HEADER_PREFIX}; use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, JWT_HEADER_PREFIX};
use crate::WsClientConfig; use crate::WsClientConfig;
use anyhow::{anyhow, Context}; use anyhow::{anyhow, Context};
@ -160,11 +160,19 @@ pub async fn connect(
let _ = headers.remove(k); let _ = headers.remove(k);
headers.append(k, v.clone()); headers.append(k, v.clone());
} }
if let Some(auth) = &client_cfg.http_upgrade_credentials { if let Some(auth) = &client_cfg.http_upgrade_credentials {
let _ = headers.remove(AUTHORIZATION); let _ = headers.remove(AUTHORIZATION);
headers.append(AUTHORIZATION, auth.clone()); headers.append(AUTHORIZATION, auth.clone());
} }
if let Some(headers_file_path) = &client_cfg.http_headers_file {
for (k, v) in headers_from_file(headers_file_path) {
let _ = headers.remove(&k);
headers.append(k, v);
}
}
let req = req.body(Empty::<Bytes>::new()).with_context(|| { let req = req.body(Empty::<Bytes>::new()).with_context(|| {
format!( format!(
"failed to build HTTP request to contact the server {:?}", "failed to build HTTP request to contact the server {:?}",