From 459a0667b198cd6477415801a93fe25f3704119a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Sun, 14 Jan 2024 19:18:57 +0100 Subject: [PATCH] Add suport for http2 as transport for tunnel --- Cargo.lock | 22 ++- Cargo.toml | 4 +- src/main.rs | 15 +- src/tunnel/client.rs | 89 ++++------- src/tunnel/server.rs | 241 ++++++++++++++++++++++++------ src/tunnel/transport/http2.rs | 168 +++++++++++++++++++++ src/tunnel/transport/io.rs | 38 ++--- src/tunnel/transport/mod.rs | 67 ++++++++- src/tunnel/transport/websocket.rs | 130 ++++++++++++---- 9 files changed, 606 insertions(+), 168 deletions(-) create mode 100644 src/tunnel/transport/http2.rs diff --git a/Cargo.lock b/Cargo.lock index 066b306..2afb6f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -613,6 +613,25 @@ dependencies = [ "tracing", ] +[[package]] +name = "h2" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "991910e35c615d8cab86b5ab04be67e6ad24d2bf5f4f11fdbbed26da999bbeab" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 1.0.0", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "hashbrown" version = "0.14.3" @@ -651,7 +670,7 @@ dependencies = [ "futures-channel", "futures-io", "futures-util", - "h2", + "h2 0.3.22", "http 0.2.11", "idna 0.4.0", "ipnet", @@ -776,6 +795,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", + "h2 0.4.1", "http 1.0.0", "http-body", "httparse", diff --git a/Cargo.toml b/Cargo.toml index 40b469d..2af9efe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,8 +20,8 @@ futures-util = { version = "0.3.30" } hickory-resolver = { version = "0.24.0", features = ["tokio", "dns-over-https-rustls", "dns-over-rustls"] } ppp = { version = "2.2.0", features = [] } -hyper = { version = "1.1.0", features = ["client", "http1"] } -hyper-util = { version = "0.1.2", features = ["tokio"] } +hyper = { version = "1.1.0", features = ["client", "http1", "http2"] } +hyper-util = { version = "0.1.2", features = ["tokio", "server", "server-auto"] } http-body-util = { version = "0.1.0" } jsonwebtoken = { version = "9.2.0", default-features = false } log = "0.4.20" diff --git a/src/main.rs b/src/main.rs index 8e752d4..1b2cd74 100644 --- a/src/main.rs +++ b/src/main.rs @@ -72,7 +72,7 @@ struct Wstunnel { nb_worker_threads: Option, /// Control the log verbosity. i.e: TRACE, DEBUG, INFO, WARN, ERROR, OFF - /// for more details: https://docs.rs/env_logger/0.10.1/env_logger/#enabling-logging + /// for more details: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#example-syntax #[arg( long, global = true, @@ -81,7 +81,7 @@ struct Wstunnel { env = "RUST_LOG", default_value = "INFO" )] - log_lvl: Directive, + log_lvl: String, } #[derive(clap::Subcommand, Debug)] @@ -645,13 +645,14 @@ async fn main() { .count() > 0 => {} _ => { + let mut env_filter = EnvFilter::builder().parse(&args.log_lvl).expect("Invalid log level"); + if !(args.log_lvl.contains("h2::") || args.log_lvl.contains("h2=")) { + env_filter = + env_filter.add_directive(Directive::from_str("h2::codec=off").expect("Invalid log directive")); + } tracing_subscriber::fmt() .with_ansi(args.no_color.is_none()) - .with_env_filter( - EnvFilter::builder() - .with_default_directive(args.log_lvl) - .from_env_lossy(), - ) + .with_env_filter(env_filter) .init(); } } diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index ed3f71b..094fdf6 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -1,4 +1,5 @@ -use super::{JwtTunnelConfig, RemoteAddr, JWT_DECODE}; +use super::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE}; +use crate::tunnel::transport::{TunnelReader, TunnelWriter}; use crate::{tunnel, WsClientConfig}; use futures_util::pin_mut; use hyper::header::COOKIE; @@ -13,52 +14,6 @@ use tracing::{error, span, Instrument, Level, Span}; use url::Host; use uuid::Uuid; -//async fn connect_http2( -// request_id: Uuid, -// client_cfg: &WsClientConfig, -// dest_addr: &RemoteAddr, -//) -> anyhow::Result> { -// let mut pooled_cnx = match client_cfg.cnx_pool().get().await { -// Ok(cnx) => Ok(cnx), -// Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")), -// }?; -// -// let mut req = Request::builder() -// .method("GET") -// .uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix)) -// .header(HOST, &client_cfg.http_header_host) -// .header(COOKIE, tunnel_to_jwt_token(request_id, dest_addr)) -// .version(hyper::Version::HTTP_2); -// -// for (k, v) in &client_cfg.http_headers { -// req = req.header(k, v); -// } -// if let Some(auth) = &client_cfg.http_upgrade_credentials { -// req = req.header(AUTHORIZATION, auth); -// } -// -// let x: Vec = vec![]; -// //let bosy = StreamBody::new(stream::iter(vec![anyhow::Result::Ok(hyper::body::Frame::data(x.as_slice()))])); -// let req = req.body(Empty::::new()).with_context(|| { -// format!( -// "failed to build HTTP request to contact the server {:?}", -// client_cfg.remote_addr -// ) -// })?; -// debug!("with HTTP upgrade request {:?}", req); -// let transport = pooled_cnx.deref_mut().take().unwrap(); -// let (mut request_sender, cnx) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()).handshake(TokioIo::new(transport)).await -// .with_context(|| format!("failed to do http2 handshake with the server {:?}", client_cfg.remote_addr))?; -// tokio::spawn(cnx); -// -// let response = request_sender.send_request(req) -// .await -// .with_context(|| format!("failed to send http2 request with the server {:?}", client_cfg.remote_addr))?; -// -// // TODO: verify response is ok -// Ok(BodyStream::new(response.into_body())) -//} - async fn connect_to_server( request_id: Uuid, client_cfg: &WsClientConfig, @@ -69,7 +24,20 @@ where R: AsyncRead + Send + 'static, W: AsyncWrite + Send + 'static, { - let ((ws_rx, ws_tx), _) = tunnel::transport::websocket::connect(request_id, client_cfg, remote_cfg).await?; + // Connect to server with the correct protocol + let (ws_rx, ws_tx) = match client_cfg.remote_addr.scheme() { + TransportScheme::Ws | TransportScheme::Wss => { + tunnel::transport::websocket::connect(request_id, client_cfg, remote_cfg) + .await + .map(|(r, w, _response)| (TunnelReader::Websocket(r), TunnelWriter::Websocket(w)))? + } + TransportScheme::Http | TransportScheme::Https => { + tunnel::transport::http2::connect(request_id, client_cfg, remote_cfg) + .await + .map(|(r, w, _response)| (TunnelReader::Http2(r), TunnelWriter::Http2(w)))? + } + }; + let (local_rx, local_tx) = duplex_stream; let (close_tx, close_rx) = oneshot::channel::<()>(); @@ -117,7 +85,7 @@ where } pub async fn run_reverse_tunnel( - client_config: Arc, + client_cfg: Arc, remote_addr: RemoteAddr, connect_to_dest: F, ) -> anyhow::Result<()> @@ -127,7 +95,7 @@ where T: AsyncRead + AsyncWrite + Send + 'static, { loop { - let client_config = client_config.clone(); + let client_config = client_cfg.clone(); let request_id = Uuid::now_v7(); let span = span!( Level::INFO, @@ -136,16 +104,25 @@ where remote = format!("{}:{}", remote_addr.host, remote_addr.port) ); let _span = span.enter(); - // Correctly configure tunnel cfg - let ((ws_rx, ws_tx), response) = - tunnel::transport::websocket::connect(request_id, &client_config, &remote_addr) - .instrument(span.clone()) - .await?; + let (ws_rx, ws_tx, response) = match client_cfg.remote_addr.scheme() { + TransportScheme::Ws | TransportScheme::Wss => { + tunnel::transport::websocket::connect(request_id, &client_cfg, &remote_addr) + .instrument(span.clone()) + .await + .map(|(r, w, response)| (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response))? + } + TransportScheme::Http | TransportScheme::Https => { + tunnel::transport::http2::connect(request_id, &client_cfg, &remote_addr) + .instrument(span.clone()) + .await + .map(|(r, w, response)| (TunnelReader::Http2(r), TunnelWriter::Http2(w), response))? + } + }; // Connect to endpoint let remote = response - .headers() + .headers .get(COOKIE) .and_then(|h| h.to_str().ok()) .and_then(|h| { diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 0a82ac9..8d7788b 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -1,6 +1,9 @@ use ahash::{HashMap, HashMapExt}; use anyhow::anyhow; +use bytes::Bytes; use futures_util::{pin_mut, FutureExt, Stream, StreamExt}; +use http_body_util::combinators::BoxBody; +use http_body_util::{BodyStream, Either, StreamBody}; use std::cmp::min; use std::fmt::Debug; use std::future::Future; @@ -12,24 +15,28 @@ use std::time::Duration; use super::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; use crate::{socks5, tcp, tls, udp, LocalProtocol, TlsServerConfig, WsServerConfig}; -use hyper::body::Incoming; +use hyper::body::{Frame, Incoming}; use hyper::header::{COOKIE, SEC_WEBSOCKET_PROTOCOL}; use hyper::http::HeaderValue; -use hyper::server::conn::http1; +use hyper::server::conn::{http1, http2}; use hyper::service::service_fn; use hyper::{http, Request, Response, StatusCode}; +use hyper_util::rt::TokioExecutor; use jsonwebtoken::TokenData; use once_cell::sync::Lazy; use parking_lot::Mutex; use crate::socks5::Socks5Stream; use crate::tunnel::tls_reloader::TlsReloader; +use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; +use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; use crate::udp::UdpStream; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::select; use tokio::sync::{mpsc, oneshot}; use tokio_rustls::TlsAcceptor; +use tokio_stream::wrappers::ReceiverStream; use tracing::{error, info, span, warn, Instrument, Level, Span}; use url::Host; use uuid::Uuid; @@ -284,6 +291,7 @@ fn extract_tunnel_info(req: &Request) -> Result, mut client_addr: SocketAddr, mut req: Request, @@ -404,10 +412,17 @@ async fn server_upgrade( ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame); tokio::task::spawn( - super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).instrument(Span::current()), + super::transport::io::propagate_remote_to_local(local_tx, WebsocketTunnelRead::new(ws_rx), close_rx) + .instrument(Span::current()), ); - let _ = super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, None).await; + let _ = super::transport::io::propagate_local_to_remote( + local_rx, + WebsocketTunnelWrite::new(ws_tx), + close_tx, + None, + ) + .await; } .instrument(Span::current()), ); @@ -429,6 +444,92 @@ async fn server_upgrade( Response::from_parts(response.into_parts().0, "".to_string()) } +async fn http_server_upgrade( + server_config: Arc, + mut client_addr: SocketAddr, + req: Request, +) -> Response>> { + match extract_x_forwarded_for(&req) { + Ok(Some((x_forward_for, x_forward_for_str))) => { + info!("Request X-Forwarded-For: {:?}", x_forward_for); + Span::current().record("forwarded_for", x_forward_for_str); + client_addr.set_ip(x_forward_for); + } + Ok(_) => {} + Err(err) => return err.map(Either::Left), + }; + + if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) { + return err.map(Either::Left); + } + + let jwt = match extract_tunnel_info(&req) { + Ok(jwt) => jwt, + Err(err) => return err.map(Either::Left), + }; + + Span::current().record("id", &jwt.claims.id); + Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); + + if let Err(err) = validate_destination(&req, &jwt, &server_config.restrict_to) { + return err.map(Either::Left); + } + + let req_protocol = jwt.claims.p.clone(); + let tunnel = match run_tunnel(&server_config, jwt, client_addr).await { + Ok(ret) => ret, + Err(err) => { + warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Either::Left("Invalid upgrade request".to_string())) + .unwrap(); + } + }; + + let (remote_addr, local_rx, local_tx) = tunnel; + info!("connected to {:?} {}:{}", req_protocol, remote_addr.host, remote_addr.port); + + let ws_rx = BodyStream::new(req.into_body()); + let (ws_tx, rx) = mpsc::channel::(1024); + let body = BoxBody::new(StreamBody::new( + ReceiverStream::new(rx).map(|s| -> anyhow::Result> { Ok(Frame::data(s)) }), + )); + + let mut response = Response::builder() + .status(StatusCode::OK) + .body(Either::Right(body)) + .expect("bug: failed to build response"); + + tokio::spawn( + async move { + let (close_tx, close_rx) = oneshot::channel::<()>(); + tokio::task::spawn( + super::transport::io::propagate_remote_to_local(local_tx, Http2TunnelRead::new(ws_rx), close_rx) + .instrument(Span::current()), + ); + + let _ = + super::transport::io::propagate_local_to_remote(local_rx, Http2TunnelWrite::new(ws_tx), close_tx, None) + .await; + } + .instrument(Span::current()), + ); + + if req_protocol == LocalProtocol::ReverseSocks5 { + let Ok(header_val) = HeaderValue::from_str(&tunnel_to_jwt_token(Uuid::from_u128(0), &remote_addr)) else { + error!("Bad header value for reverse socks5: {} {}", remote_addr.host, remote_addr.port); + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Either::Left("Invalid upgrade request".to_string())) + .unwrap(); + }; + response.headers_mut().insert(COOKIE, header_val); + } + + response +} + struct TlsContext<'a> { tls_acceptor: Arc, tls_reloader: TlsReloader, @@ -452,10 +553,32 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() info!("Starting wstunnel server listening on {}", server_config.bind); // setup upgrade request handler - // FIXME: Avoid double clone of the arc for each request - let mk_upgrade_fn = |server_config: Arc, client_addr: SocketAddr| { + let mk_websocket_upgrade_fn = |server_config: Arc, client_addr: SocketAddr| { move |req: Request| { - server_upgrade(server_config.clone(), client_addr, req).map::, _>(Ok) + ws_server_upgrade(server_config.clone(), client_addr, req).map::, _>(Ok) + } + }; + + let mk_http_upgrade_fn = |server_config: Arc, client_addr: SocketAddr| { + move |req: Request| { + http_server_upgrade(server_config.clone(), client_addr, req).map::, _>(Ok) + } + }; + + let mk_auto_upgrade_fn = |server_config: Arc, client_addr: SocketAddr| { + move |req: Request| { + let server_config = server_config.clone(); + async move { + if !fastwebsockets::upgrade::is_upgrade_request(&req) { + http_server_upgrade(server_config.clone(), client_addr, req) + .map::, _>(Ok) + .await + } else { + ws_server_upgrade(server_config.clone(), client_addr, req) + .map(|response| Ok::<_, anyhow::Error>(response.map(Either::Left))) + .await + } + } } }; @@ -493,47 +616,75 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() ); info!("Accepting connection"); - let upgrade_fn = mk_upgrade_fn(server_config.clone(), peer_addr); - // TLS - if let Some(tls) = tls_context.as_mut() { - // Reload TLS certificate if needed - let tls_acceptor = tls.tls_acceptor().clone(); - let fut = async move { - info!("Doing TLS handshake"); - let tls_stream = match tls_acceptor.accept(stream).await { - Ok(tls_stream) => hyper_util::rt::TokioIo::new(tls_stream), - Err(err) => { - error!("error while accepting TLS connection {}", err); - return; + let server_config = server_config.clone(); + + // Check if we need to enable TLS or not + match tls_context.as_mut() { + Some(tls) => { + // Reload TLS certificate if needed + let tls_acceptor = tls.tls_acceptor().clone(); + let fut = async move { + info!("Doing TLS handshake"); + let tls_stream = match tls_acceptor.accept(stream).await { + Ok(tls_stream) => hyper_util::rt::TokioIo::new(tls_stream), + Err(err) => { + error!("error while accepting TLS connection {}", err); + return; + } + }; + + match tls_stream.inner().get_ref().1.alpn_protocol() { + // http2 + Some(b"h2") => { + let mut conn_builder = http2::Builder::new(TokioExecutor::new()); + if let Some(ping) = server_config.websocket_ping_frequency { + conn_builder.keep_alive_interval(ping); + } + + let http_upgrade_fn = mk_http_upgrade_fn(server_config, peer_addr); + let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn)); + if let Err(e) = con_fut.await { + error!("Error while upgrading cnx to http: {:?}", e); + } + } + // websocket + _ => { + let websocket_upgrade_fn = mk_websocket_upgrade_fn(server_config, peer_addr); + let conn_fut = http1::Builder::new() + .serve_connection(tls_stream, service_fn(websocket_upgrade_fn)) + .with_upgrades(); + + if let Err(e) = conn_fut.await { + error!("Error while upgrading cnx: {:?}", e); + } + } + }; + } + .instrument(span); + + tokio::spawn(fut); + // Normal + } + // HTTP without TLS + None => { + let fut = async move { + let stream = hyper_util::rt::TokioIo::new(stream); + let mut conn_fut = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + if let Some(ping) = server_config.websocket_ping_frequency { + conn_fut.http2().keep_alive_interval(ping); } - }; - let conn_fut = http1::Builder::new() - .serve_connection(tls_stream, service_fn(upgrade_fn)) - .with_upgrades(); + let websocket_upgrade_fn = mk_auto_upgrade_fn(server_config, peer_addr); + let upgradable = conn_fut.serve_connection_with_upgrades(stream, service_fn(websocket_upgrade_fn)); - if let Err(e) = conn_fut.await { - error!("Error while upgrading cnx to websocket: {:?}", e); + if let Err(e) = upgradable.await { + error!("Error while upgrading cnx to websocket: {:?}", e); + } } + .instrument(span); + + tokio::spawn(fut); } - .instrument(span); - - tokio::spawn(fut); - // Normal - } else { - let stream = hyper_util::rt::TokioIo::new(stream); - let conn_fut = http1::Builder::new() - .serve_connection(stream, service_fn(upgrade_fn)) - .with_upgrades(); - - let fut = async move { - if let Err(e) = conn_fut.await { - error!("Error while upgrading cnx to websocket: {:?}", e); - } - } - .instrument(span); - - tokio::spawn(fut); - }; + } } } diff --git a/src/tunnel/transport/http2.rs b/src/tunnel/transport/http2.rs new file mode 100644 index 0000000..7c8ff39 --- /dev/null +++ b/src/tunnel/transport/http2.rs @@ -0,0 +1,168 @@ +use crate::tunnel::transport::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; +use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr}; +use crate::WsClientConfig; +use anyhow::{anyhow, Context}; +use bytes::{Bytes, BytesMut}; +use http_body_util::{BodyExt, BodyStream, StreamBody}; +use hyper::body::{Frame, Incoming}; +use hyper::header::{AUTHORIZATION, COOKIE, HOST}; +use hyper::http::response::Parts; +use hyper::Request; +use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; +use log::{debug, error, warn}; +use std::io; +use std::io::ErrorKind; +use std::ops::DerefMut; +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::StreamExt; +use uuid::Uuid; + +pub struct Http2TunnelRead { + inner: BodyStream, +} + +impl Http2TunnelRead { + pub fn new(inner: BodyStream) -> Self { + Self { inner } + } +} + +impl TunnelRead for Http2TunnelRead { + async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> Result<(), io::Error> { + loop { + match self.inner.next().await { + Some(Ok(frame)) => match frame.into_data() { + Ok(data) => { + return match writer.write_all(data.as_ref()).await { + Ok(_) => Ok(()), + Err(err) => Err(io::Error::new(ErrorKind::ConnectionAborted, err)), + } + } + Err(err) => { + warn!("{:?}", err); + continue; + } + }, + Some(Err(err)) => { + return Err(io::Error::new(ErrorKind::ConnectionAborted, err)); + } + None => return Err(io::Error::new(ErrorKind::BrokenPipe, "closed")), + } + } + } +} + +pub struct Http2TunnelWrite { + inner: mpsc::Sender, + buf: BytesMut, +} + +impl Http2TunnelWrite { + pub fn new(inner: mpsc::Sender) -> Self { + Self { + inner, + buf: BytesMut::with_capacity(MAX_PACKET_LENGTH * 20), // ~ 1Mb + } + } +} + +impl TunnelWrite for Http2TunnelWrite { + fn buf_mut(&mut self) -> &mut BytesMut { + &mut self.buf + } + + async fn write(&mut self) -> Result<(), io::Error> { + let data = self.buf.split().freeze(); + let ret = match self.inner.send(data).await { + Ok(_) => Ok(()), + Err(err) => Err(io::Error::new(ErrorKind::ConnectionAborted, err)), + }; + + if self.buf.capacity() < MAX_PACKET_LENGTH { + //info!("read {} Kb {} Kb", self.buf.capacity() / 1024, old_capa / 1024); + self.buf.reserve(MAX_PACKET_LENGTH * 4) + } + + ret + } + + async fn ping(&mut self) -> Result<(), io::Error> { + Ok(()) + } + + async fn close(&mut self) -> Result<(), io::Error> { + Ok(()) + } +} + +pub async fn connect( + request_id: Uuid, + client_cfg: &WsClientConfig, + dest_addr: &RemoteAddr, +) -> anyhow::Result<(Http2TunnelRead, Http2TunnelWrite, Parts)> { + let mut pooled_cnx = match client_cfg.cnx_pool().get().await { + Ok(cnx) => Ok(cnx), + Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")), + }?; + + let mut req = Request::builder() + .method("POST") + .uri(format!( + "{}://{}:{}/{}/events", + client_cfg.remote_addr.scheme(), + client_cfg.remote_addr.host(), + client_cfg.remote_addr.port(), + &client_cfg.http_upgrade_path_prefix + )) + .header(HOST, &client_cfg.http_header_host) + .header(COOKIE, tunnel_to_jwt_token(request_id, dest_addr)) + .version(hyper::Version::HTTP_2); + + for (k, v) in &client_cfg.http_headers { + req = req.header(k, v); + } + if let Some(auth) = &client_cfg.http_upgrade_credentials { + req = req.header(AUTHORIZATION, auth); + } + + let (tx, rx) = mpsc::channel::(1024); + let body = StreamBody::new(ReceiverStream::new(rx).map(|s| -> anyhow::Result> { Ok(Frame::data(s)) })); + let req = req.body(body).with_context(|| { + format!( + "failed to build HTTP request to contact the server {:?}", + client_cfg.remote_addr + ) + })?; + debug!("with HTTP upgrade request {:?}", req); + let transport = pooled_cnx.deref_mut().take().unwrap(); + let (mut request_sender, cnx) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) + .timer(TokioTimer::new()) + .keep_alive_interval(client_cfg.websocket_ping_frequency) + .keep_alive_while_idle(false) + .handshake(TokioIo::new(transport)) + .await + .with_context(|| format!("failed to do http2 handshake with the server {:?}", client_cfg.remote_addr))?; + tokio::spawn(async move { + if let Err(err) = cnx.await { + error!("{:?}", err) + } + }); + + let response = request_sender + .send_request(req) + .await + .with_context(|| format!("failed to send http2 request with the server {:?}", client_cfg.remote_addr))?; + + if !response.status().is_success() { + return Err(anyhow!( + "Http2 server rejected the connection: {:?}: {:?}", + response.status(), + String::from_utf8(response.into_body().collect().await?.to_bytes().to_vec()).unwrap_or_default() + )); + } + + let (parts, body) = response.into_parts(); + Ok((Http2TunnelRead::new(BodyStream::new(body)), Http2TunnelWrite::new(tx), parts)) +} diff --git a/src/tunnel/transport/io.rs b/src/tunnel/transport/io.rs index bd964d5..6a782a0 100644 --- a/src/tunnel/transport/io.rs +++ b/src/tunnel/transport/io.rs @@ -1,4 +1,5 @@ use crate::tunnel::transport::{TunnelRead, TunnelWrite}; +use bytes::BufMut; use futures_util::{pin_mut, FutureExt}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; @@ -6,7 +7,7 @@ use tokio::select; use tokio::sync::oneshot; use tokio::time::Instant; use tracing::log::debug; -use tracing::{error, info, trace, warn}; +use tracing::{error, info, warn}; pub async fn propagate_local_to_remote( local_rx: impl AsyncRead, @@ -19,7 +20,6 @@ pub async fn propagate_local_to_remote( }); static MAX_PACKET_LENGTH: usize = 64 * 1024; - let mut buffer = vec![0u8; MAX_PACKET_LENGTH]; // We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration // We reuse the future to avoid creating a timer in the tight loop @@ -32,21 +32,26 @@ pub async fn propagate_local_to_remote( pin_mut!(should_close); pin_mut!(local_rx); loop { + debug_assert!( + ws_tx.buf_mut().chunk_mut().len() >= MAX_PACKET_LENGTH, + "buffer must be large enough to receive a whole packet length" + ); + let read_len = select! { biased; - read_len = local_rx.read(&mut buffer) => read_len, + read_len = local_rx.read_buf(ws_tx.buf_mut()) => read_len, _ = &mut should_close => break, _ = timeout.tick(), if ping_frequency.is_some() => { - debug!("sending ping to keep websocket connection alive"); + debug!("sending ping to keep connection alive"); ws_tx.ping().await?; continue; } }; - let read_len = match read_len { + let _read_len = match read_len { Ok(0) => break, Ok(read_len) => read_len, Err(err) => { @@ -56,27 +61,10 @@ pub async fn propagate_local_to_remote( }; //debug!("read {} wasted {}% usable {} capa {}", read_len, 100 - (read_len * 100 / buffer.capacity()), buffer.as_slice().len(), buffer.capacity()); - if let Err(err) = ws_tx.write(&buffer[..read_len]).await { - warn!("error while writing to websocket tx tunnel {}", err); + if let Err(err) = ws_tx.write().await { + warn!("error while writing to tx tunnel {}", err); break; } - - // If the buffer has been completely filled with previous read, Double it ! - // For the buffer to not be a bottleneck when the TCP window scale - // For udp, the buffer will never grows. - if buffer.capacity() == read_len { - buffer.clear(); - let new_size = buffer.capacity() + (buffer.capacity() / 4); // grow buffer by 1.25 % - buffer.reserve_exact(new_size); - buffer.resize(buffer.capacity(), 0); - trace!( - "Buffer {} Mb {} {} {}", - buffer.capacity() as f64 / 1024.0 / 1024.0, - new_size, - buffer.as_slice().len(), - buffer.capacity() - ) - } } // Send normal close @@ -103,7 +91,7 @@ pub async fn propagate_remote_to_local( }; if let Err(err) = msg { - error!("error while reading from websocket rx {}", err); + error!("error while reading from tunnel rx {}", err); break; } } diff --git a/src/tunnel/transport/mod.rs b/src/tunnel/transport/mod.rs index ff18337..480230b 100644 --- a/src/tunnel/transport/mod.rs +++ b/src/tunnel/transport/mod.rs @@ -1,15 +1,74 @@ +use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; +use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; +use bytes::BytesMut; use std::future::Future; use tokio::io::AsyncWrite; +pub mod http2; pub mod io; pub mod websocket; +static MAX_PACKET_LENGTH: usize = 64 * 1024; + pub trait TunnelWrite: Send + 'static { - fn write(&mut self, buf: &[u8]) -> impl Future> + Send; - fn ping(&mut self) -> impl Future> + Send; - fn close(&mut self) -> impl Future> + Send; + fn buf_mut(&mut self) -> &mut BytesMut; + fn write(&mut self) -> impl Future> + Send; + fn ping(&mut self) -> impl Future> + Send; + fn close(&mut self) -> impl Future> + Send; } pub trait TunnelRead: Send + 'static { - fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> impl Future> + Send; + fn copy( + &mut self, + writer: impl AsyncWrite + Unpin + Send, + ) -> impl Future> + Send; +} + +pub enum TunnelReader { + Websocket(WebsocketTunnelRead), + Http2(Http2TunnelRead), +} + +impl TunnelRead for TunnelReader { + async fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> Result<(), std::io::Error> { + match self { + TunnelReader::Websocket(s) => s.copy(writer).await, + TunnelReader::Http2(s) => s.copy(writer).await, + } + } +} + +pub enum TunnelWriter { + Websocket(WebsocketTunnelWrite), + Http2(Http2TunnelWrite), +} + +impl TunnelWrite for TunnelWriter { + fn buf_mut(&mut self) -> &mut BytesMut { + match self { + TunnelWriter::Websocket(s) => s.buf_mut(), + TunnelWriter::Http2(s) => s.buf_mut(), + } + } + + async fn write(&mut self) -> Result<(), std::io::Error> { + match self { + TunnelWriter::Websocket(s) => s.write().await, + TunnelWriter::Http2(s) => s.write().await, + } + } + + async fn ping(&mut self) -> Result<(), std::io::Error> { + match self { + TunnelWriter::Websocket(s) => s.ping().await, + TunnelWriter::Http2(s) => s.ping().await, + } + } + + async fn close(&mut self) -> Result<(), std::io::Error> { + match self { + TunnelWriter::Websocket(s) => s.close().await, + TunnelWriter::Http2(s) => s.close().await, + } + } } diff --git a/src/tunnel/transport/websocket.rs b/src/tunnel/transport/websocket.rs index d302cbe..8eb57a9 100644 --- a/src/tunnel/transport/websocket.rs +++ b/src/tunnel/transport/websocket.rs @@ -1,40 +1,105 @@ -use crate::tunnel::transport::{TunnelRead, TunnelWrite}; +use crate::tunnel::transport::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, JWT_HEADER_PREFIX}; use crate::WsClientConfig; use anyhow::{anyhow, Context}; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite}; use http_body_util::Empty; -use hyper::body::Incoming; use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE}; use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY}; +use hyper::http::response::Parts; use hyper::upgrade::Upgraded; -use hyper::{Request, Response}; +use hyper::Request; use hyper_util::rt::TokioExecutor; use hyper_util::rt::TokioIo; use log::debug; +use std::io; +use std::io::ErrorKind; use std::ops::DerefMut; use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tracing::trace; use uuid::Uuid; -impl TunnelWrite for WebSocketWrite>> { - async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()> { - self.write_frame(Frame::binary(Payload::Borrowed(buf))) - .await - .with_context(|| "cannot send ws frame") +pub struct WebsocketTunnelWrite { + inner: WebSocketWrite>>, + buf: BytesMut, +} + +impl WebsocketTunnelWrite { + pub fn new(ws: WebSocketWrite>>) -> Self { + Self { + inner: ws, + buf: BytesMut::with_capacity(MAX_PACKET_LENGTH), + } + } +} + +impl TunnelWrite for WebsocketTunnelWrite { + fn buf_mut(&mut self) -> &mut BytesMut { + &mut self.buf } - async fn ping(&mut self) -> anyhow::Result<()> { - self.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))) - .await - .with_context(|| "cannot send ws ping") + async fn write(&mut self) -> Result<(), io::Error> { + let read_len = self.buf.len(); + let buf = &mut self.buf; + + let ret = self + .inner + .write_frame(Frame::binary(Payload::BorrowedMut(&mut buf[..read_len]))) + .await; + + if let Err(err) = ret { + return Err(io::Error::new(ErrorKind::ConnectionAborted, err)); + } + + // If the buffer has been completely filled with previous read, Grows it ! + // For the buffer to not be a bottleneck when the TCP window scale + // For udp, the buffer will never grows. + buf.clear(); + if buf.capacity() == read_len { + let new_size = buf.capacity() + (buf.capacity() / 4); // grow buffer by 1.25 % + buf.reserve(new_size); + buf.resize(buf.capacity(), 0); + trace!( + "Buffer {} Mb {} {} {}", + buf.capacity() as f64 / 1024.0 / 1024.0, + new_size, + buf.len(), + buf.capacity() + ) + } + + Ok(()) } - async fn close(&mut self) -> anyhow::Result<()> { - self.write_frame(Frame::close(1000, &[])) + async fn ping(&mut self) -> Result<(), io::Error> { + if let Err(err) = self + .inner + .write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))) .await - .with_context(|| "cannot close websocket cnx") + { + return Err(io::Error::new(ErrorKind::BrokenPipe, err)); + } + + Ok(()) + } + + async fn close(&mut self) -> Result<(), io::Error> { + if let Err(err) = self.inner.write_frame(Frame::close(1000, &[])).await { + return Err(io::Error::new(ErrorKind::BrokenPipe, err)); + } + + Ok(()) + } +} + +pub struct WebsocketTunnelRead { + inner: WebSocketRead>>, +} + +impl WebsocketTunnelRead { + pub fn new(ws: WebSocketRead>>) -> Self { + Self { inner: ws } } } @@ -42,21 +107,24 @@ fn frame_reader(x: Frame<'_>) -> futures_util::future::Ready> debug!("frame {:?} {:?}", x.opcode, x.payload); futures_util::future::ready(anyhow::Ok(())) } -impl TunnelRead for WebSocketRead>> { - async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> anyhow::Result<()> { + +impl TunnelRead for WebsocketTunnelRead { + async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> Result<(), io::Error> { loop { - let msg = self - .read_frame(&mut frame_reader) - .await - .with_context(|| "error while reading from websocket")?; + let msg = match self.inner.read_frame(&mut frame_reader).await { + Ok(msg) => msg, + Err(err) => return Err(io::Error::new(ErrorKind::ConnectionAborted, err)), + }; trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload); match msg.opcode { OpCode::Continuation | OpCode::Text | OpCode::Binary => { - writer.write_all(msg.payload.as_ref()).await.with_context(|| "")?; - return Ok(()); + return match writer.write_all(msg.payload.as_ref()).await { + Ok(_) => Ok(()), + Err(err) => Err(io::Error::new(ErrorKind::ConnectionAborted, err)), + } } - OpCode::Close => return Err(anyhow!("websocket close")), + OpCode::Close => return Err(io::Error::new(ErrorKind::NotConnected, "websocket close")), OpCode::Ping => continue, OpCode::Pong => continue, }; @@ -68,7 +136,7 @@ pub async fn connect( request_id: Uuid, client_cfg: &WsClientConfig, dest_addr: &RemoteAddr, -) -> anyhow::Result<((impl TunnelRead, impl TunnelWrite), Response)> { +) -> anyhow::Result<(WebsocketTunnelRead, WebsocketTunnelWrite, Parts)> { let mut pooled_cnx = match client_cfg.cnx_pool().get().await { Ok(cnx) => Ok(cnx), Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")), @@ -76,7 +144,7 @@ pub async fn connect( let mut req = Request::builder() .method("GET") - .uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix,)) + .uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix)) .header(HOST, &client_cfg.http_header_host) .header(UPGRADE, "websocket") .header(CONNECTION, "upgrade") @@ -109,5 +177,11 @@ pub async fn connect( ws.set_auto_apply_mask(client_cfg.websocket_mask_frame); - Ok((ws.split(tokio::io::split), response)) + let (ws_rx, ws_tx) = ws.split(tokio::io::split); + + Ok(( + WebsocketTunnelRead::new(ws_rx), + WebsocketTunnelWrite::new(ws_tx), + response.into_parts().0, + )) }