Prepare host header ahead of time
This commit is contained in:
parent
79632bb058
commit
f813d925d6
2 changed files with 21 additions and 13 deletions
22
src/main.rs
22
src/main.rs
|
@ -10,7 +10,8 @@ mod udp;
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use futures_util::{pin_mut, stream, Stream, StreamExt, TryStreamExt};
|
use futures_util::{pin_mut, stream, Stream, StreamExt, TryStreamExt};
|
||||||
use hyper::http::HeaderValue;
|
use hyper::header::HOST;
|
||||||
|
use hyper::http::{HeaderName, HeaderValue};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::{BTreeMap, HashMap};
|
use std::collections::{BTreeMap, HashMap};
|
||||||
use std::fmt::{Debug, Formatter};
|
use std::fmt::{Debug, Formatter};
|
||||||
|
@ -119,7 +120,7 @@ struct Client {
|
||||||
/// Send custom headers in the upgrade request
|
/// Send custom headers in the upgrade request
|
||||||
/// Can be specified multiple time
|
/// Can be specified multiple time
|
||||||
#[arg(short='H', long, value_name = "HEADER_NAME: HEADER_VALUE", value_parser = parse_http_headers, verbatim_doc_comment)]
|
#[arg(short='H', long, value_name = "HEADER_NAME: HEADER_VALUE", value_parser = parse_http_headers, verbatim_doc_comment)]
|
||||||
http_headers: Vec<(String, HeaderValue)>,
|
http_headers: Vec<(HeaderName, HeaderValue)>,
|
||||||
|
|
||||||
/// Address of the wstunnel server
|
/// Address of the wstunnel server
|
||||||
/// Example: With TLS wss://wstunnel.example.com or without ws://wstunnel.example.com
|
/// Example: With TLS wss://wstunnel.example.com or without ws://wstunnel.example.com
|
||||||
|
@ -449,7 +450,8 @@ pub struct WsClientConfig {
|
||||||
pub tls: Option<TlsClientConfig>,
|
pub tls: Option<TlsClientConfig>,
|
||||||
pub http_upgrade_path_prefix: String,
|
pub http_upgrade_path_prefix: String,
|
||||||
pub http_upgrade_credentials: Option<HeaderValue>,
|
pub http_upgrade_credentials: Option<HeaderValue>,
|
||||||
pub http_headers: HashMap<String, HeaderValue>,
|
pub http_headers: HashMap<HeaderName, HeaderValue>,
|
||||||
|
pub host_http_header: HeaderValue,
|
||||||
pub timeout_connect: Duration,
|
pub timeout_connect: Duration,
|
||||||
pub websocket_ping_frequency: Duration,
|
pub websocket_ping_frequency: Duration,
|
||||||
pub websocket_mask_frame: bool,
|
pub websocket_mask_frame: bool,
|
||||||
|
@ -528,6 +530,13 @@ async fn main() {
|
||||||
_ => panic!("invalid scheme in server url {}", args.remote_addr.scheme()),
|
_ => panic!("invalid scheme in server url {}", args.remote_addr.scheme()),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Extract host header from http_headers
|
||||||
|
let host_header =
|
||||||
|
if let Some((_, host_val)) = args.http_headers.iter().find(|(h, _)| *h == HOST) {
|
||||||
|
host_val.clone()
|
||||||
|
} else {
|
||||||
|
HeaderValue::from_str(&args.remote_addr.host().unwrap().to_string()).unwrap()
|
||||||
|
};
|
||||||
let mut client_config = WsClientConfig {
|
let mut client_config = WsClientConfig {
|
||||||
remote_addr: (
|
remote_addr: (
|
||||||
args.remote_addr.host().unwrap().to_owned(),
|
args.remote_addr.host().unwrap().to_owned(),
|
||||||
|
@ -537,7 +546,12 @@ async fn main() {
|
||||||
tls,
|
tls,
|
||||||
http_upgrade_path_prefix: args.http_upgrade_path_prefix,
|
http_upgrade_path_prefix: args.http_upgrade_path_prefix,
|
||||||
http_upgrade_credentials: args.http_upgrade_credentials,
|
http_upgrade_credentials: args.http_upgrade_credentials,
|
||||||
http_headers: args.http_headers.into_iter().collect(),
|
http_headers: args
|
||||||
|
.http_headers
|
||||||
|
.into_iter()
|
||||||
|
.filter(|(k, _)| k != HOST)
|
||||||
|
.collect(),
|
||||||
|
host_http_header: host_header,
|
||||||
timeout_connect: Duration::from_secs(10),
|
timeout_connect: Duration::from_secs(10),
|
||||||
websocket_ping_frequency: args
|
websocket_ping_frequency: args
|
||||||
.websocket_ping_frequency_sec
|
.websocket_ping_frequency_sec
|
||||||
|
|
|
@ -58,24 +58,18 @@ pub async fn connect(
|
||||||
&client_cfg.http_upgrade_path_prefix,
|
&client_cfg.http_upgrade_path_prefix,
|
||||||
jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(),
|
jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(),
|
||||||
))
|
))
|
||||||
|
.header(HOST, &client_cfg.host_http_header)
|
||||||
.header(UPGRADE, "websocket")
|
.header(UPGRADE, "websocket")
|
||||||
.header(CONNECTION, "upgrade")
|
.header(CONNECTION, "upgrade")
|
||||||
.header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key())
|
.header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key())
|
||||||
.header(SEC_WEBSOCKET_VERSION, "13")
|
.header(SEC_WEBSOCKET_VERSION, "13")
|
||||||
.version(hyper::Version::HTTP_11);
|
.version(hyper::Version::HTTP_11);
|
||||||
|
|
||||||
let mut contains_host = false;
|
|
||||||
for (k, v) in &client_cfg.http_headers {
|
for (k, v) in &client_cfg.http_headers {
|
||||||
if k == HOST.as_str() {
|
req = req.header(k, v);
|
||||||
contains_host = true;
|
|
||||||
}
|
|
||||||
req = req.header(k.clone(), v.clone());
|
|
||||||
}
|
|
||||||
if !contains_host {
|
|
||||||
req = req.header(HOST, client_cfg.remote_addr.0.to_string());
|
|
||||||
}
|
}
|
||||||
if let Some(auth) = &client_cfg.http_upgrade_credentials {
|
if let Some(auth) = &client_cfg.http_upgrade_credentials {
|
||||||
req = req.header(AUTHORIZATION, auth.clone());
|
req = req.header(AUTHORIZATION, auth);
|
||||||
}
|
}
|
||||||
|
|
||||||
let req = req.body(Body::empty()).with_context(|| {
|
let req = req.body(Body::empty()).with_context(|| {
|
||||||
|
|
Loading…
Reference in a new issue