Pass tunnel info into sec-websocket-protocol header

This commit is contained in:
Σrebe - Romain GERARD 2023-12-26 21:16:34 +01:00
parent f752ce67fb
commit 259da14d4d
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
3 changed files with 29 additions and 24 deletions

View file

@ -10,10 +10,10 @@ use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use super::{JwtTunnelConfig, JWT_DECODE};
use super::{JwtTunnelConfig, JWT_DECODE, JWT_HEADER_PREFIX};
use crate::{socks5, tcp, tls, udp, LocalProtocol, WsServerConfig};
use hyper::body::Incoming;
use hyper::header::COOKIE;
use hyper::header::{COOKIE, SEC_WEBSOCKET_PROTOCOL};
use hyper::http::HeaderValue;
use hyper::server::conn::http1;
use hyper::service::service_fn;
@ -221,22 +221,23 @@ fn validate_url(
#[inline]
fn extract_tunnel_info(req: &Request<Incoming>) -> Result<TokenData<JwtTunnelConfig>, Response<String>> {
let jwt: TokenData<JwtTunnelConfig> = match req.uri().query().unwrap_or_default().split_once('=') {
Some(("bearer", jwt)) => {
let (validation, decode_key) = JWT_DECODE.deref();
match jsonwebtoken::decode(jwt, decode_key, validation) {
Ok(jwt) => jwt,
err => {
error!("error while decoding jwt for tunnel info {:?}", err);
return Err(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body("Invalid upgrade request".to_string())
.unwrap());
}
}
}
let jwt = req
.headers()
.get(SEC_WEBSOCKET_PROTOCOL)
.and_then(|header| header.to_str().ok())
.and_then(|header| header.split_once(JWT_HEADER_PREFIX))
.map(|(_prefix, jwt)| jwt)
.unwrap_or_default();
let (validation, decode_key) = JWT_DECODE.deref();
let jwt = match jsonwebtoken::decode(jwt, decode_key, validation) {
Ok(jwt) => jwt,
err => {
error!("Missing jwt tunnel config from request {:?}", err);
warn!(
"error while decoding jwt for tunnel info {:?} header {:?}",
err,
req.headers().get(SEC_WEBSOCKET_PROTOCOL)
);
return Err(http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body("Invalid upgrade request".to_string())
@ -358,6 +359,9 @@ async fn server_upgrade(server_config: Arc<WsServerConfig>, mut req: Request<Inc
};
response.headers_mut().insert(COOKIE, header_val);
}
response
.headers_mut()
.insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("v1"));
Response::from_parts(response.into_parts().0, "".to_string())
}