diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index cab8021..434330b 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -1,4 +1,4 @@ -use super::{to_host_port, JwtTunnelConfig, JWT_KEY}; +use super::{to_host_port, JwtTunnelConfig, JWT_HEADER_PREFIX, JWT_KEY}; use crate::{LocalToRemote, WsClientConfig}; use anyhow::{anyhow, Context}; @@ -8,7 +8,7 @@ use fastwebsockets::WebSocket; use futures_util::pin_mut; use http_body_util::Empty; use hyper::body::Incoming; -use hyper::header::{AUTHORIZATION, COOKIE, SEC_WEBSOCKET_VERSION, UPGRADE}; +use hyper::header::{AUTHORIZATION, COOKIE, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE}; use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY}; use hyper::upgrade::Upgraded; use hyper::{Request, Response}; @@ -42,16 +42,16 @@ pub async fn connect( let mut req = Request::builder() .method("GET") - .uri(format!( - "/{}/events?bearer={}", - &client_cfg.http_upgrade_path_prefix, - tunnel_to_jwt_token(request_id, tunnel_cfg) - )) + .uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix,)) .header(HOST, &client_cfg.http_header_host) .header(UPGRADE, "websocket") .header(CONNECTION, "upgrade") .header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key()) .header(SEC_WEBSOCKET_VERSION, "13") + .header( + SEC_WEBSOCKET_PROTOCOL, + format!("v1, {}{}", JWT_HEADER_PREFIX, tunnel_to_jwt_token(request_id, tunnel_cfg)), + ) .version(hyper::Version::HTTP_11); for (k, v) in &client_cfg.http_headers { diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index fd68a89..3c80a34 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -49,6 +49,7 @@ impl JwtTunnelConfig { } } +static JWT_HEADER_PREFIX: &str = "authorization.bearer."; static JWT_SECRET: &[u8; 15] = b"champignonfrais"; static JWT_KEY: Lazy<(Header, EncodingKey)> = Lazy::new(|| (Header::new(Algorithm::HS256), EncodingKey::from_secret(JWT_SECRET))); diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 32746c7..e00f93d 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -10,10 +10,10 @@ use std::pin::Pin; use std::sync::Arc; use std::time::Duration; -use super::{JwtTunnelConfig, JWT_DECODE}; +use super::{JwtTunnelConfig, JWT_DECODE, JWT_HEADER_PREFIX}; use crate::{socks5, tcp, tls, udp, LocalProtocol, WsServerConfig}; use hyper::body::Incoming; -use hyper::header::COOKIE; +use hyper::header::{COOKIE, SEC_WEBSOCKET_PROTOCOL}; use hyper::http::HeaderValue; use hyper::server::conn::http1; use hyper::service::service_fn; @@ -221,22 +221,23 @@ fn validate_url( #[inline] fn extract_tunnel_info(req: &Request) -> Result, Response> { - let jwt: TokenData = match req.uri().query().unwrap_or_default().split_once('=') { - Some(("bearer", jwt)) => { - let (validation, decode_key) = JWT_DECODE.deref(); - match jsonwebtoken::decode(jwt, decode_key, validation) { - Ok(jwt) => jwt, - err => { - error!("error while decoding jwt for tunnel info {:?}", err); - return Err(http::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid upgrade request".to_string()) - .unwrap()); - } - } - } + let jwt = req + .headers() + .get(SEC_WEBSOCKET_PROTOCOL) + .and_then(|header| header.to_str().ok()) + .and_then(|header| header.split_once(JWT_HEADER_PREFIX)) + .map(|(_prefix, jwt)| jwt) + .unwrap_or_default(); + + let (validation, decode_key) = JWT_DECODE.deref(); + let jwt = match jsonwebtoken::decode(jwt, decode_key, validation) { + Ok(jwt) => jwt, err => { - error!("Missing jwt tunnel config from request {:?}", err); + warn!( + "error while decoding jwt for tunnel info {:?} header {:?}", + err, + req.headers().get(SEC_WEBSOCKET_PROTOCOL) + ); return Err(http::Response::builder() .status(StatusCode::BAD_REQUEST) .body("Invalid upgrade request".to_string()) @@ -358,6 +359,9 @@ async fn server_upgrade(server_config: Arc, mut req: Request