Prepare host header ahead of time

This commit is contained in:
Σrebe - Romain GERARD 2023-10-26 22:04:41 +02:00
parent 79632bb058
commit f813d925d6
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
2 changed files with 21 additions and 13 deletions

View file

@ -10,7 +10,8 @@ mod udp;
use base64::Engine; use base64::Engine;
use clap::Parser; use clap::Parser;
use futures_util::{pin_mut, stream, Stream, StreamExt, TryStreamExt}; use futures_util::{pin_mut, stream, Stream, StreamExt, TryStreamExt};
use hyper::http::HeaderValue; use hyper::header::HOST;
use hyper::http::{HeaderName, HeaderValue};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap}; use std::collections::{BTreeMap, HashMap};
use std::fmt::{Debug, Formatter}; use std::fmt::{Debug, Formatter};
@ -119,7 +120,7 @@ struct Client {
/// Send custom headers in the upgrade request /// Send custom headers in the upgrade request
/// Can be specified multiple time /// Can be specified multiple time
#[arg(short='H', long, value_name = "HEADER_NAME: HEADER_VALUE", value_parser = parse_http_headers, verbatim_doc_comment)] #[arg(short='H', long, value_name = "HEADER_NAME: HEADER_VALUE", value_parser = parse_http_headers, verbatim_doc_comment)]
http_headers: Vec<(String, HeaderValue)>, http_headers: Vec<(HeaderName, HeaderValue)>,
/// Address of the wstunnel server /// Address of the wstunnel server
/// Example: With TLS wss://wstunnel.example.com or without ws://wstunnel.example.com /// Example: With TLS wss://wstunnel.example.com or without ws://wstunnel.example.com
@ -449,7 +450,8 @@ pub struct WsClientConfig {
pub tls: Option<TlsClientConfig>, pub tls: Option<TlsClientConfig>,
pub http_upgrade_path_prefix: String, pub http_upgrade_path_prefix: String,
pub http_upgrade_credentials: Option<HeaderValue>, pub http_upgrade_credentials: Option<HeaderValue>,
pub http_headers: HashMap<String, HeaderValue>, pub http_headers: HashMap<HeaderName, HeaderValue>,
pub host_http_header: HeaderValue,
pub timeout_connect: Duration, pub timeout_connect: Duration,
pub websocket_ping_frequency: Duration, pub websocket_ping_frequency: Duration,
pub websocket_mask_frame: bool, pub websocket_mask_frame: bool,
@ -528,6 +530,13 @@ async fn main() {
_ => panic!("invalid scheme in server url {}", args.remote_addr.scheme()), _ => panic!("invalid scheme in server url {}", args.remote_addr.scheme()),
}; };
// Extract host header from http_headers
let host_header =
if let Some((_, host_val)) = args.http_headers.iter().find(|(h, _)| *h == HOST) {
host_val.clone()
} else {
HeaderValue::from_str(&args.remote_addr.host().unwrap().to_string()).unwrap()
};
let mut client_config = WsClientConfig { let mut client_config = WsClientConfig {
remote_addr: ( remote_addr: (
args.remote_addr.host().unwrap().to_owned(), args.remote_addr.host().unwrap().to_owned(),
@ -537,7 +546,12 @@ async fn main() {
tls, tls,
http_upgrade_path_prefix: args.http_upgrade_path_prefix, http_upgrade_path_prefix: args.http_upgrade_path_prefix,
http_upgrade_credentials: args.http_upgrade_credentials, http_upgrade_credentials: args.http_upgrade_credentials,
http_headers: args.http_headers.into_iter().collect(), http_headers: args
.http_headers
.into_iter()
.filter(|(k, _)| k != HOST)
.collect(),
host_http_header: host_header,
timeout_connect: Duration::from_secs(10), timeout_connect: Duration::from_secs(10),
websocket_ping_frequency: args websocket_ping_frequency: args
.websocket_ping_frequency_sec .websocket_ping_frequency_sec

View file

@ -58,24 +58,18 @@ pub async fn connect(
&client_cfg.http_upgrade_path_prefix, &client_cfg.http_upgrade_path_prefix,
jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(), jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(),
)) ))
.header(HOST, &client_cfg.host_http_header)
.header(UPGRADE, "websocket") .header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade") .header(CONNECTION, "upgrade")
.header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key()) .header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key())
.header(SEC_WEBSOCKET_VERSION, "13") .header(SEC_WEBSOCKET_VERSION, "13")
.version(hyper::Version::HTTP_11); .version(hyper::Version::HTTP_11);
let mut contains_host = false;
for (k, v) in &client_cfg.http_headers { for (k, v) in &client_cfg.http_headers {
if k == HOST.as_str() { req = req.header(k, v);
contains_host = true;
}
req = req.header(k.clone(), v.clone());
}
if !contains_host {
req = req.header(HOST, client_cfg.remote_addr.0.to_string());
} }
if let Some(auth) = &client_cfg.http_upgrade_credentials { if let Some(auth) = &client_cfg.http_upgrade_credentials {
req = req.header(AUTHORIZATION, auth.clone()); req = req.header(AUTHORIZATION, auth);
} }
let req = req.body(Body::empty()).with_context(|| { let req = req.body(Body::empty()).with_context(|| {