diff --git a/src/main.rs b/src/main.rs index 0d787db..f02f165 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,8 @@ mod udp; use base64::Engine; use clap::Parser; use futures_util::{pin_mut, stream, Stream, StreamExt, TryStreamExt}; -use hyper::http::HeaderValue; +use hyper::header::HOST; +use hyper::http::{HeaderName, HeaderValue}; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; use std::fmt::{Debug, Formatter}; @@ -119,7 +120,7 @@ struct Client { /// Send custom headers in the upgrade request /// Can be specified multiple time #[arg(short='H', long, value_name = "HEADER_NAME: HEADER_VALUE", value_parser = parse_http_headers, verbatim_doc_comment)] - http_headers: Vec<(String, HeaderValue)>, + http_headers: Vec<(HeaderName, HeaderValue)>, /// Address of the wstunnel server /// Example: With TLS wss://wstunnel.example.com or without ws://wstunnel.example.com @@ -449,7 +450,8 @@ pub struct WsClientConfig { pub tls: Option, pub http_upgrade_path_prefix: String, pub http_upgrade_credentials: Option, - pub http_headers: HashMap, + pub http_headers: HashMap, + pub host_http_header: HeaderValue, pub timeout_connect: Duration, pub websocket_ping_frequency: Duration, pub websocket_mask_frame: bool, @@ -528,6 +530,13 @@ async fn main() { _ => panic!("invalid scheme in server url {}", args.remote_addr.scheme()), }; + // Extract host header from http_headers + let host_header = + if let Some((_, host_val)) = args.http_headers.iter().find(|(h, _)| *h == HOST) { + host_val.clone() + } else { + HeaderValue::from_str(&args.remote_addr.host().unwrap().to_string()).unwrap() + }; let mut client_config = WsClientConfig { remote_addr: ( args.remote_addr.host().unwrap().to_owned(), @@ -537,7 +546,12 @@ async fn main() { tls, http_upgrade_path_prefix: args.http_upgrade_path_prefix, http_upgrade_credentials: args.http_upgrade_credentials, - http_headers: args.http_headers.into_iter().collect(), + http_headers: args + .http_headers + .into_iter() + .filter(|(k, _)| k != HOST) + .collect(), + host_http_header: host_header, timeout_connect: Duration::from_secs(10), websocket_ping_frequency: args .websocket_ping_frequency_sec diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index 8ad8d77..37e66eb 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -58,24 +58,18 @@ pub async fn connect( &client_cfg.http_upgrade_path_prefix, jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(), )) + .header(HOST, &client_cfg.host_http_header) .header(UPGRADE, "websocket") .header(CONNECTION, "upgrade") .header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key()) .header(SEC_WEBSOCKET_VERSION, "13") .version(hyper::Version::HTTP_11); - let mut contains_host = false; for (k, v) in &client_cfg.http_headers { - if k == HOST.as_str() { - contains_host = true; - } - req = req.header(k.clone(), v.clone()); - } - if !contains_host { - req = req.header(HOST, client_cfg.remote_addr.0.to_string()); + req = req.header(k, v); } if let Some(auth) = &client_cfg.http_upgrade_credentials { - req = req.header(AUTHORIZATION, auth.clone()); + req = req.header(AUTHORIZATION, auth); } let req = req.body(Body::empty()).with_context(|| {