diff --git a/Cargo.lock b/Cargo.lock index 2abb276..98ad290 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -385,11 +385,14 @@ dependencies = [ [[package]] name = "fastwebsockets" -version = "0.5.0" -source = "git+https://github.com/denoland/fastwebsockets?rev=35a1930fdbcdfe9034bb531fcd1690ba9f7737ec#35a1930fdbcdfe9034bb531fcd1690ba9f7737ec" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f63dd7b57f9b33b1741fa631c9522eb35d43e96dcca4a6a91d5e4ca7c93acdc1" dependencies = [ "base64", + "http-body-util", "hyper", + "hyper-util", "pin-project", "rand", "sha1", @@ -561,9 +564,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.11" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" +checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea" dependencies = [ "bytes", "fnv", @@ -572,12 +575,24 @@ dependencies = [ [[package]] name = "http-body" -version = "0.4.6" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" dependencies = [ "bytes", "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840" +dependencies = [ + "bytes", + "futures-util", + "http", + "http-body", "pin-project-lite", ] @@ -595,13 +610,12 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "0.14.27" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" +checksum = "403f9214f3e703236b221f1a9cd88ec8b4adfa5296de01ab96216361f4692f56" dependencies = [ "bytes", "futures-channel", - "futures-core", "futures-util", "http", "http-body", @@ -609,11 +623,28 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.4.10", "tokio", + "want", +] + +[[package]] +name = "hyper-util" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ca339002caeb0d159cc6e023dff48e199f081e42fa039895c7c6f38b37f2e9d" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "hyper", + "pin-project-lite", + "socket2", + "tokio", + "tower", "tower-service", "tracing", - "want", ] [[package]] @@ -1235,16 +1266,6 @@ version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" -[[package]] -name = "socket2" -version = "0.4.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f7916fc008ca5542385b89a3d3ce689953c143e9304a9bf8beec1de48994c0d" -dependencies = [ - "libc", - "winapi", -] - [[package]] name = "socket2" version = "0.5.5" @@ -1402,7 +1423,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2 0.5.5", + "socket2", "tokio-macros", "windows-sys 0.48.0", ] @@ -1462,6 +1483,28 @@ dependencies = [ "tokio", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" + [[package]] name = "tower-service" version = "0.3.2" @@ -1860,7 +1903,9 @@ dependencies = [ "fast-socks5", "fastwebsockets", "futures-util", + "http-body-util", "hyper", + "hyper-util", "jsonwebtoken", "log", "nix", @@ -1871,7 +1916,7 @@ dependencies = [ "rustls-pemfile", "scopeguard", "serde", - "socket2 0.5.5", + "socket2", "testcontainers", "tokio", "tokio-fd", diff --git a/Cargo.toml b/Cargo.toml index 2b3b6b9..e5f87f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,16 +13,18 @@ base64 = "0.21.5" bb8 = { version = "0.8", features = [] } bytes = { version = "1.5.0", features = [] } -clap = { version = "4.4.10", features = ["derive"] } -fast-socks5 = { version = "0.9.1", features = [] } -fastwebsockets = { git = "https://github.com/denoland/fastwebsockets", rev = '35a1930fdbcdfe9034bb531fcd1690ba9f7737ec', features = ["upgrade", "simd", "unstable-split"] } +clap = { version = "4.4.11", features = ["derive"] } +fast-socks5 = { version = "0.9.2", features = [] } +fastwebsockets = { version = "0.6.0", features = ["upgrade", "simd", "unstable-split"] } futures-util = { version = "0.3.29" } -hyper = { version = "0.14.27", features = ["client", "runtime"] } -jsonwebtoken = { version = "9.1.0", default-features = false } +hyper = { version = "1.0.1", features = ["client", "http1"] } +hyper-util = { version = "0.1.0", features = ["tokio"] } +http-body-util = { version = "0.1.0" } +jsonwebtoken = { version = "9.2.0", default-features = false } log = "0.4.20" nix = { version = "0.27.1", features = ["socket", "net", "uio"] } -once_cell = { version = "1.18.0", features = [] } +once_cell = { version = "1.19.0", features = [] } parking_lot = "0.12.1" pin-project = "1" @@ -31,7 +33,7 @@ rustls-pemfile = { version = "2.0.0", features = [] } scopeguard = "1.2.0" serde = { version = "1.0.193", features = ["derive"] } socket2 = { version = "0.5.5", features = [] } -tokio = { version = "1.34.0", features = ["full"] } +tokio = { version = "1.35.0", features = ["full"] } tokio-rustls = { version = "0.24.1", features = ["tls12", "dangerous_configuration", "early-data"] } tokio-stream = { version = "0.1.14", features = ["net"] } diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index 77163bd..cab8021 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -3,12 +3,16 @@ use crate::{LocalToRemote, WsClientConfig}; use anyhow::{anyhow, Context}; use base64::Engine; +use bytes::Bytes; use fastwebsockets::WebSocket; use futures_util::pin_mut; +use http_body_util::Empty; +use hyper::body::Incoming; use hyper::header::{AUTHORIZATION, COOKIE, SEC_WEBSOCKET_VERSION, UPGRADE}; use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY}; use hyper::upgrade::Upgraded; -use hyper::{Body, Request, Response}; +use hyper::{Request, Response}; +use hyper_util::rt::{TokioExecutor, TokioIo}; use std::future::Future; use std::ops::{Deref, DerefMut}; use std::sync::Arc; @@ -20,18 +24,6 @@ use tracing::{error, span, Instrument, Level, Span}; use url::{Host, Url}; use uuid::Uuid; -struct SpawnExecutor; - -impl hyper::rt::Executor for SpawnExecutor -where - Fut: Future + Send + 'static, - Fut::Output: Send + 'static, -{ - fn execute(&self, fut: Fut) { - tokio::task::spawn(fut); - } -} - fn tunnel_to_jwt_token(request_id: Uuid, tunnel: &LocalToRemote) -> String { let cfg = JwtTunnelConfig::new(request_id, tunnel); let (alg, secret) = JWT_KEY.deref(); @@ -42,7 +34,7 @@ pub async fn connect( request_id: Uuid, client_cfg: &WsClientConfig, tunnel_cfg: &LocalToRemote, -) -> anyhow::Result<(WebSocket, Response)> { +) -> anyhow::Result<(WebSocket>, Response)> { let mut pooled_cnx = match client_cfg.cnx_pool().get().await { Ok(tcp_stream) => tcp_stream, Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}"))?, @@ -69,7 +61,7 @@ pub async fn connect( req = req.header(AUTHORIZATION, auth); } - let req = req.body(Body::empty()).with_context(|| { + let req = req.body(Empty::::new()).with_context(|| { format!( "failed to build HTTP request to contact the server {:?}", client_cfg.remote_addr @@ -77,7 +69,7 @@ pub async fn connect( })?; debug!("with HTTP upgrade request {:?}", req); let transport = pooled_cnx.deref_mut().take().unwrap(); - let (ws, response) = fastwebsockets::handshake::client(&SpawnExecutor, req, transport) + let (ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport) .await .with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?; diff --git a/src/tunnel/io.rs b/src/tunnel/io.rs index 088e8ec..683d1b0 100644 --- a/src/tunnel/io.rs +++ b/src/tunnel/io.rs @@ -2,6 +2,7 @@ use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebS use futures_util::{pin_mut, FutureExt}; use hyper::upgrade::Upgraded; +use hyper_util::rt::TokioIo; use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tokio::select; @@ -12,7 +13,7 @@ use tracing::{error, info, trace, warn}; pub(super) async fn propagate_read( local_rx: impl AsyncRead, - mut ws_tx: WebSocketWrite>, + mut ws_tx: WebSocketWrite>>, mut close_tx: oneshot::Sender<()>, ping_frequency: Option, ) -> Result<(), WebSocketError> { @@ -84,7 +85,7 @@ pub(super) async fn propagate_read( pub(super) async fn propagate_write( local_tx: impl AsyncWrite, - mut ws_rx: WebSocketRead>, + mut ws_rx: WebSocketRead>>, mut close_rx: oneshot::Receiver<()>, ) -> Result<(), WebSocketError> { let _guard = scopeguard::guard((), |_| { diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 653356a..808a541 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -12,11 +12,12 @@ use std::time::Duration; use super::{JwtTunnelConfig, JWT_DECODE}; use crate::{socks5, tcp, tls, udp, LocalProtocol, WsServerConfig}; +use hyper::body::Incoming; use hyper::header::COOKIE; use hyper::http::HeaderValue; -use hyper::server::conn::Http; +use hyper::server::conn::http1; use hyper::service::service_fn; -use hyper::{http, Body, Request, Response, StatusCode}; +use hyper::{http, Request, Response, StatusCode}; use jsonwebtoken::TokenData; use once_cell::sync::Lazy; use parking_lot::Mutex; @@ -188,8 +189,8 @@ where async fn server_upgrade( server_config: Arc, - mut req: Request, -) -> Result, anyhow::Error> { + mut req: Request, +) -> Result, anyhow::Error> { if let Some(x) = req.headers().get("X-Forwarded-For") { info!("Request X-Forwarded-For: {:?}", x); Span::current().record("forwarded_for", x.to_str().unwrap_or_default()); @@ -199,8 +200,8 @@ async fn server_upgrade( warn!("Rejecting connection with bad upgrade request: {}", req.uri()); return Ok(http::Response::builder() .status(StatusCode::BAD_REQUEST) - .body(Body::from("Invalid upgrade request")) - .unwrap_or_default()); + .body("Invalid upgrade request".into()) + .unwrap()); } if let Some(paths_prefix) = &server_config.restrict_http_upgrade_path_prefix { @@ -217,8 +218,8 @@ async fn server_upgrade( warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri()); return Ok(http::Response::builder() .status(StatusCode::BAD_REQUEST) - .body(Body::from("Invalid upgrade request")) - .unwrap_or_default()); + .body("Invalid upgrade request".to_string()) + .unwrap()); } } @@ -229,8 +230,8 @@ async fn server_upgrade( warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); return Ok(http::Response::builder() .status(StatusCode::BAD_REQUEST) - .body(Body::from(format!("Invalid upgrade request: {:?}", err))) - .unwrap_or_default()); + .body("Invalid upgrade request".to_string()) + .unwrap()); } }; @@ -241,8 +242,8 @@ async fn server_upgrade( warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); return Ok(http::Response::builder() .status(StatusCode::BAD_REQUEST) - .body(Body::from(format!("Invalid upgrade request: {:?}", err))) - .unwrap_or_default()); + .body(format!("Invalid upgrade request: {:?}", err)) + .unwrap()); } }; @@ -273,6 +274,8 @@ async fn server_upgrade( )?, ); } + + let response = Response::from_parts(response.into_parts().0, "".to_string()); Ok(response) } @@ -280,7 +283,7 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() info!("Starting wstunnel server listening on {}", server_config.bind); let config = server_config.clone(); - let upgrade_fn = move |req: Request| server_upgrade(config.clone(), req); + let upgrade_fn = move |req: Request| server_upgrade(config.clone(), req); let listener = TcpListener::bind(&server_config.bind).await?; let tls_acceptor = if let Some(tls) = &server_config.tls { @@ -310,14 +313,14 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() let fut = async move { info!("Doing TLS handshake"); let tls_stream = match tls_acceptor.accept(stream).await { - Ok(tls_stream) => tls_stream, + Ok(tls_stream) => hyper_util::rt::TokioIo::new(tls_stream), Err(err) => { error!("error while accepting TLS connection {}", err); return; } }; - let conn_fut = Http::new() - .http1_only(true) + + let conn_fut = http1::Builder::new() .serve_connection(tls_stream, service_fn(upgrade_fn)) .with_upgrades(); @@ -330,8 +333,8 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() tokio::spawn(fut); // Normal } else { - let conn_fut = Http::new() - .http1_only(true) + let stream = hyper_util::rt::TokioIo::new(stream); + let conn_fut = http1::Builder::new() .serve_connection(stream, service_fn(upgrade_fn)) .with_upgrades();