diff --git a/Cargo.toml b/Cargo.toml index 41dc3c4..310ec7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "wstunnel" -version = "9.2.2" +version = "9.2.3" edition = "2021" repository = "https://github.com/erebe/wstunnel.git" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/tunnel/transport/http2.rs b/src/tunnel/transport/http2.rs index 3afadd9..9b6eadb 100644 --- a/src/tunnel/transport/http2.rs +++ b/src/tunnel/transport/http2.rs @@ -1,5 +1,5 @@ use crate::tunnel::transport::{headers_from_file, TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; -use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr}; +use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, TransportScheme}; use crate::WsClientConfig; use anyhow::{anyhow, Context}; use bytes::{Bytes, BytesMut}; @@ -7,7 +7,6 @@ use http_body_util::{BodyExt, BodyStream, StreamBody}; use hyper::body::{Frame, Incoming}; use hyper::header::{AUTHORIZATION, CONTENT_TYPE, COOKIE}; use hyper::http::response::Parts; -use hyper::http::HeaderName; use hyper::Request; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use log::{debug, error, warn}; @@ -110,12 +109,18 @@ pub async fn connect( // In http2 HOST header does not exist, it is explicitly set in the authority from the request uri let (headers_file, authority) = if let Some(headers_file_path) = &client_cfg.http_headers_file { - let headers = headers_from_file(headers_file_path); - let host = headers - .iter() - .find(|(h, _)| h == HeaderName::from_static("host")) - .and_then(|(_, v)| v.to_str().ok()) - .map(|v| v.to_string()); + let (host, headers) = headers_from_file(headers_file_path); + let host = if let Some((_, v)) = host { + match (client_cfg.remote_addr.scheme(), client_cfg.remote_addr.port()) { + (TransportScheme::Http, 80) | (TransportScheme::Https, 443) => { + Some(v.to_str().unwrap_or("").to_string()) + } + (_, port) => Some(format!("{}:{}", v.to_str().unwrap_or(""), port)), + } + } else { + None + }; + (Some(headers), host) } else { (None, None) diff --git a/src/tunnel/transport/mod.rs b/src/tunnel/transport/mod.rs index aeb6bac..f41301c 100644 --- a/src/tunnel/transport/mod.rs +++ b/src/tunnel/transport/mod.rs @@ -79,24 +79,34 @@ impl TunnelWrite for TunnelWriter { } } +#[allow(clippy::type_complexity)] #[inline] -pub fn headers_from_file(path: &Path) -> Vec<(HeaderName, HeaderValue)> { +pub fn headers_from_file(path: &Path) -> (Option<(HeaderName, HeaderValue)>, Vec<(HeaderName, HeaderValue)>) { + static HOST_HEADER: HeaderName = HeaderName::from_static("host"); + let file = match std::fs::File::open(path) { Ok(file) => file, Err(err) => { error!("Cannot read headers from file: {:?}: {:?}", path, err); - return vec![]; + return (None, vec![]); } }; - BufReader::new(file) + let mut host_header = None; + let headers = BufReader::new(file) .lines() .filter_map(|line| { let line = line.ok()?; let (header, value) = line.split_once(':')?; let header = HeaderName::from_str(header.trim()).ok()?; let value = HeaderValue::from_str(value.trim()).ok()?; + if header == HOST_HEADER { + host_header = Some((header, value)); + return None; + } Some((header, value)) }) - .collect() + .collect(); + + (host_header, headers) } diff --git a/src/tunnel/transport/websocket.rs b/src/tunnel/transport/websocket.rs index 583adf4..f433da5 100644 --- a/src/tunnel/transport/websocket.rs +++ b/src/tunnel/transport/websocket.rs @@ -167,10 +167,15 @@ pub async fn connect( } if let Some(headers_file_path) = &client_cfg.http_headers_file { - for (k, v) in headers_from_file(headers_file_path) { + let (host, headers_file) = headers_from_file(headers_file_path); + for (k, v) in headers_file { let _ = headers.remove(&k); headers.append(k, v); } + if let Some((host, val)) = host { + let _ = headers.remove(&host); + headers.append(host, val); + } } let req = req.body(Empty::::new()).with_context(|| {