diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index 094fdf6..cf03256 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -4,6 +4,7 @@ use crate::{tunnel, WsClientConfig}; use futures_util::pin_mut; use hyper::header::COOKIE; use jsonwebtoken::TokenData; +use log::debug; use std::future::Future; use std::ops::Deref; use std::sync::Arc; @@ -25,19 +26,20 @@ where W: AsyncWrite + Send + 'static, { // Connect to server with the correct protocol - let (ws_rx, ws_tx) = match client_cfg.remote_addr.scheme() { + let (ws_rx, ws_tx, response) = match client_cfg.remote_addr.scheme() { TransportScheme::Ws | TransportScheme::Wss => { tunnel::transport::websocket::connect(request_id, client_cfg, remote_cfg) .await - .map(|(r, w, _response)| (TunnelReader::Websocket(r), TunnelWriter::Websocket(w)))? + .map(|(r, w, response)| (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response))? } TransportScheme::Http | TransportScheme::Https => { tunnel::transport::http2::connect(request_id, client_cfg, remote_cfg) .await - .map(|(r, w, _response)| (TunnelReader::Http2(r), TunnelWriter::Http2(w)))? + .map(|(r, w, response)| (TunnelReader::Http2(r), TunnelWriter::Http2(w), response))? } }; + debug!("Server response: {:?}", response); let (local_rx, local_tx) = duplex_stream; let (close_tx, close_rx) = oneshot::channel::<()>(); @@ -121,6 +123,7 @@ where }; // Connect to endpoint + debug!("Server response: {:?}", response); let remote = response .headers .get(COOKIE) diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 8d7788b..a684249 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -20,7 +20,7 @@ use hyper::header::{COOKIE, SEC_WEBSOCKET_PROTOCOL}; use hyper::http::HeaderValue; use hyper::server::conn::{http1, http2}; use hyper::service::service_fn; -use hyper::{http, Request, Response, StatusCode}; +use hyper::{http, Request, Response, StatusCode, Version}; use hyper_util::rt::TokioExecutor; use jsonwebtoken::TokenData; use once_cell::sync::Lazy; @@ -569,14 +569,20 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() move |req: Request| { let server_config = server_config.clone(); async move { - if !fastwebsockets::upgrade::is_upgrade_request(&req) { + if fastwebsockets::upgrade::is_upgrade_request(&req) { + ws_server_upgrade(server_config.clone(), client_addr, req) + .map(|response| Ok::<_, anyhow::Error>(response.map(Either::Left))) + .await + } else if req.version() == Version::HTTP_2 { http_server_upgrade(server_config.clone(), client_addr, req) .map::, _>(Ok) .await } else { - ws_server_upgrade(server_config.clone(), client_addr, req) - .map(|response| Ok::<_, anyhow::Error>(response.map(Either::Left))) - .await + error!("Invalid protocol version request, got {:?} while expecting either websocket http1 upgrade or http2", req.version()); + Ok(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Either::Left("Invalid protocol request".to_string())) + .unwrap()) } } }