Add suport for http2 as transport for tunnel

This commit is contained in:
Σrebe - Romain GERARD 2024-01-14 19:18:57 +01:00
parent cf3500dffb
commit 459a0667b1
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
9 changed files with 606 additions and 168 deletions

View file

@ -0,0 +1,168 @@
use crate::tunnel::transport::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr};
use crate::WsClientConfig;
use anyhow::{anyhow, Context};
use bytes::{Bytes, BytesMut};
use http_body_util::{BodyExt, BodyStream, StreamBody};
use hyper::body::{Frame, Incoming};
use hyper::header::{AUTHORIZATION, COOKIE, HOST};
use hyper::http::response::Parts;
use hyper::Request;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use log::{debug, error, warn};
use std::io;
use std::io::ErrorKind;
use std::ops::DerefMut;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use uuid::Uuid;
pub struct Http2TunnelRead {
inner: BodyStream<Incoming>,
}
impl Http2TunnelRead {
pub fn new(inner: BodyStream<Incoming>) -> Self {
Self { inner }
}
}
impl TunnelRead for Http2TunnelRead {
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> Result<(), io::Error> {
loop {
match self.inner.next().await {
Some(Ok(frame)) => match frame.into_data() {
Ok(data) => {
return match writer.write_all(data.as_ref()).await {
Ok(_) => Ok(()),
Err(err) => Err(io::Error::new(ErrorKind::ConnectionAborted, err)),
}
}
Err(err) => {
warn!("{:?}", err);
continue;
}
},
Some(Err(err)) => {
return Err(io::Error::new(ErrorKind::ConnectionAborted, err));
}
None => return Err(io::Error::new(ErrorKind::BrokenPipe, "closed")),
}
}
}
}
pub struct Http2TunnelWrite {
inner: mpsc::Sender<Bytes>,
buf: BytesMut,
}
impl Http2TunnelWrite {
pub fn new(inner: mpsc::Sender<Bytes>) -> Self {
Self {
inner,
buf: BytesMut::with_capacity(MAX_PACKET_LENGTH * 20), // ~ 1Mb
}
}
}
impl TunnelWrite for Http2TunnelWrite {
fn buf_mut(&mut self) -> &mut BytesMut {
&mut self.buf
}
async fn write(&mut self) -> Result<(), io::Error> {
let data = self.buf.split().freeze();
let ret = match self.inner.send(data).await {
Ok(_) => Ok(()),
Err(err) => Err(io::Error::new(ErrorKind::ConnectionAborted, err)),
};
if self.buf.capacity() < MAX_PACKET_LENGTH {
//info!("read {} Kb {} Kb", self.buf.capacity() / 1024, old_capa / 1024);
self.buf.reserve(MAX_PACKET_LENGTH * 4)
}
ret
}
async fn ping(&mut self) -> Result<(), io::Error> {
Ok(())
}
async fn close(&mut self) -> Result<(), io::Error> {
Ok(())
}
}
pub async fn connect(
request_id: Uuid,
client_cfg: &WsClientConfig,
dest_addr: &RemoteAddr,
) -> anyhow::Result<(Http2TunnelRead, Http2TunnelWrite, Parts)> {
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
Ok(cnx) => Ok(cnx),
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
}?;
let mut req = Request::builder()
.method("POST")
.uri(format!(
"{}://{}:{}/{}/events",
client_cfg.remote_addr.scheme(),
client_cfg.remote_addr.host(),
client_cfg.remote_addr.port(),
&client_cfg.http_upgrade_path_prefix
))
.header(HOST, &client_cfg.http_header_host)
.header(COOKIE, tunnel_to_jwt_token(request_id, dest_addr))
.version(hyper::Version::HTTP_2);
for (k, v) in &client_cfg.http_headers {
req = req.header(k, v);
}
if let Some(auth) = &client_cfg.http_upgrade_credentials {
req = req.header(AUTHORIZATION, auth);
}
let (tx, rx) = mpsc::channel::<Bytes>(1024);
let body = StreamBody::new(ReceiverStream::new(rx).map(|s| -> anyhow::Result<Frame<Bytes>> { Ok(Frame::data(s)) }));
let req = req.body(body).with_context(|| {
format!(
"failed to build HTTP request to contact the server {:?}",
client_cfg.remote_addr
)
})?;
debug!("with HTTP upgrade request {:?}", req);
let transport = pooled_cnx.deref_mut().take().unwrap();
let (mut request_sender, cnx) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.keep_alive_interval(client_cfg.websocket_ping_frequency)
.keep_alive_while_idle(false)
.handshake(TokioIo::new(transport))
.await
.with_context(|| format!("failed to do http2 handshake with the server {:?}", client_cfg.remote_addr))?;
tokio::spawn(async move {
if let Err(err) = cnx.await {
error!("{:?}", err)
}
});
let response = request_sender
.send_request(req)
.await
.with_context(|| format!("failed to send http2 request with the server {:?}", client_cfg.remote_addr))?;
if !response.status().is_success() {
return Err(anyhow!(
"Http2 server rejected the connection: {:?}: {:?}",
response.status(),
String::from_utf8(response.into_body().collect().await?.to_bytes().to_vec()).unwrap_or_default()
));
}
let (parts, body) = response.into_parts();
Ok((Http2TunnelRead::new(BodyStream::new(body)), Http2TunnelWrite::new(tx), parts))
}

View file

@ -1,4 +1,5 @@
use crate::tunnel::transport::{TunnelRead, TunnelWrite};
use bytes::BufMut;
use futures_util::{pin_mut, FutureExt};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
@ -6,7 +7,7 @@ use tokio::select;
use tokio::sync::oneshot;
use tokio::time::Instant;
use tracing::log::debug;
use tracing::{error, info, trace, warn};
use tracing::{error, info, warn};
pub async fn propagate_local_to_remote(
local_rx: impl AsyncRead,
@ -19,7 +20,6 @@ pub async fn propagate_local_to_remote(
});
static MAX_PACKET_LENGTH: usize = 64 * 1024;
let mut buffer = vec![0u8; MAX_PACKET_LENGTH];
// We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration
// We reuse the future to avoid creating a timer in the tight loop
@ -32,21 +32,26 @@ pub async fn propagate_local_to_remote(
pin_mut!(should_close);
pin_mut!(local_rx);
loop {
debug_assert!(
ws_tx.buf_mut().chunk_mut().len() >= MAX_PACKET_LENGTH,
"buffer must be large enough to receive a whole packet length"
);
let read_len = select! {
biased;
read_len = local_rx.read(&mut buffer) => read_len,
read_len = local_rx.read_buf(ws_tx.buf_mut()) => read_len,
_ = &mut should_close => break,
_ = timeout.tick(), if ping_frequency.is_some() => {
debug!("sending ping to keep websocket connection alive");
debug!("sending ping to keep connection alive");
ws_tx.ping().await?;
continue;
}
};
let read_len = match read_len {
let _read_len = match read_len {
Ok(0) => break,
Ok(read_len) => read_len,
Err(err) => {
@ -56,27 +61,10 @@ pub async fn propagate_local_to_remote(
};
//debug!("read {} wasted {}% usable {} capa {}", read_len, 100 - (read_len * 100 / buffer.capacity()), buffer.as_slice().len(), buffer.capacity());
if let Err(err) = ws_tx.write(&buffer[..read_len]).await {
warn!("error while writing to websocket tx tunnel {}", err);
if let Err(err) = ws_tx.write().await {
warn!("error while writing to tx tunnel {}", err);
break;
}
// If the buffer has been completely filled with previous read, Double it !
// For the buffer to not be a bottleneck when the TCP window scale
// For udp, the buffer will never grows.
if buffer.capacity() == read_len {
buffer.clear();
let new_size = buffer.capacity() + (buffer.capacity() / 4); // grow buffer by 1.25 %
buffer.reserve_exact(new_size);
buffer.resize(buffer.capacity(), 0);
trace!(
"Buffer {} Mb {} {} {}",
buffer.capacity() as f64 / 1024.0 / 1024.0,
new_size,
buffer.as_slice().len(),
buffer.capacity()
)
}
}
// Send normal close
@ -103,7 +91,7 @@ pub async fn propagate_remote_to_local(
};
if let Err(err) = msg {
error!("error while reading from websocket rx {}", err);
error!("error while reading from tunnel rx {}", err);
break;
}
}

View file

@ -1,15 +1,74 @@
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite};
use bytes::BytesMut;
use std::future::Future;
use tokio::io::AsyncWrite;
pub mod http2;
pub mod io;
pub mod websocket;
static MAX_PACKET_LENGTH: usize = 64 * 1024;
pub trait TunnelWrite: Send + 'static {
fn write(&mut self, buf: &[u8]) -> impl Future<Output = anyhow::Result<()>> + Send;
fn ping(&mut self) -> impl Future<Output = anyhow::Result<()>> + Send;
fn close(&mut self) -> impl Future<Output = anyhow::Result<()>> + Send;
fn buf_mut(&mut self) -> &mut BytesMut;
fn write(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
fn ping(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
fn close(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
}
pub trait TunnelRead: Send + 'static {
fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> impl Future<Output = anyhow::Result<()>> + Send;
fn copy(
&mut self,
writer: impl AsyncWrite + Unpin + Send,
) -> impl Future<Output = Result<(), std::io::Error>> + Send;
}
pub enum TunnelReader {
Websocket(WebsocketTunnelRead),
Http2(Http2TunnelRead),
}
impl TunnelRead for TunnelReader {
async fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> Result<(), std::io::Error> {
match self {
TunnelReader::Websocket(s) => s.copy(writer).await,
TunnelReader::Http2(s) => s.copy(writer).await,
}
}
}
pub enum TunnelWriter {
Websocket(WebsocketTunnelWrite),
Http2(Http2TunnelWrite),
}
impl TunnelWrite for TunnelWriter {
fn buf_mut(&mut self) -> &mut BytesMut {
match self {
TunnelWriter::Websocket(s) => s.buf_mut(),
TunnelWriter::Http2(s) => s.buf_mut(),
}
}
async fn write(&mut self) -> Result<(), std::io::Error> {
match self {
TunnelWriter::Websocket(s) => s.write().await,
TunnelWriter::Http2(s) => s.write().await,
}
}
async fn ping(&mut self) -> Result<(), std::io::Error> {
match self {
TunnelWriter::Websocket(s) => s.ping().await,
TunnelWriter::Http2(s) => s.ping().await,
}
}
async fn close(&mut self) -> Result<(), std::io::Error> {
match self {
TunnelWriter::Websocket(s) => s.close().await,
TunnelWriter::Http2(s) => s.close().await,
}
}
}

View file

@ -1,40 +1,105 @@
use crate::tunnel::transport::{TunnelRead, TunnelWrite};
use crate::tunnel::transport::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, JWT_HEADER_PREFIX};
use crate::WsClientConfig;
use anyhow::{anyhow, Context};
use bytes::Bytes;
use bytes::{Bytes, BytesMut};
use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
use http_body_util::Empty;
use hyper::body::Incoming;
use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE};
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
use hyper::http::response::Parts;
use hyper::upgrade::Upgraded;
use hyper::{Request, Response};
use hyper::Request;
use hyper_util::rt::TokioExecutor;
use hyper_util::rt::TokioIo;
use log::debug;
use std::io;
use std::io::ErrorKind;
use std::ops::DerefMut;
use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use tracing::trace;
use uuid::Uuid;
impl TunnelWrite for WebSocketWrite<WriteHalf<TokioIo<Upgraded>>> {
async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()> {
self.write_frame(Frame::binary(Payload::Borrowed(buf)))
.await
.with_context(|| "cannot send ws frame")
pub struct WebsocketTunnelWrite {
inner: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
buf: BytesMut,
}
impl WebsocketTunnelWrite {
pub fn new(ws: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>) -> Self {
Self {
inner: ws,
buf: BytesMut::with_capacity(MAX_PACKET_LENGTH),
}
}
}
impl TunnelWrite for WebsocketTunnelWrite {
fn buf_mut(&mut self) -> &mut BytesMut {
&mut self.buf
}
async fn ping(&mut self) -> anyhow::Result<()> {
self.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut [])))
.await
.with_context(|| "cannot send ws ping")
async fn write(&mut self) -> Result<(), io::Error> {
let read_len = self.buf.len();
let buf = &mut self.buf;
let ret = self
.inner
.write_frame(Frame::binary(Payload::BorrowedMut(&mut buf[..read_len])))
.await;
if let Err(err) = ret {
return Err(io::Error::new(ErrorKind::ConnectionAborted, err));
}
// If the buffer has been completely filled with previous read, Grows it !
// For the buffer to not be a bottleneck when the TCP window scale
// For udp, the buffer will never grows.
buf.clear();
if buf.capacity() == read_len {
let new_size = buf.capacity() + (buf.capacity() / 4); // grow buffer by 1.25 %
buf.reserve(new_size);
buf.resize(buf.capacity(), 0);
trace!(
"Buffer {} Mb {} {} {}",
buf.capacity() as f64 / 1024.0 / 1024.0,
new_size,
buf.len(),
buf.capacity()
)
}
Ok(())
}
async fn close(&mut self) -> anyhow::Result<()> {
self.write_frame(Frame::close(1000, &[]))
async fn ping(&mut self) -> Result<(), io::Error> {
if let Err(err) = self
.inner
.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut [])))
.await
.with_context(|| "cannot close websocket cnx")
{
return Err(io::Error::new(ErrorKind::BrokenPipe, err));
}
Ok(())
}
async fn close(&mut self) -> Result<(), io::Error> {
if let Err(err) = self.inner.write_frame(Frame::close(1000, &[])).await {
return Err(io::Error::new(ErrorKind::BrokenPipe, err));
}
Ok(())
}
}
pub struct WebsocketTunnelRead {
inner: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>,
}
impl WebsocketTunnelRead {
pub fn new(ws: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>) -> Self {
Self { inner: ws }
}
}
@ -42,21 +107,24 @@ fn frame_reader(x: Frame<'_>) -> futures_util::future::Ready<anyhow::Result<()>>
debug!("frame {:?} {:?}", x.opcode, x.payload);
futures_util::future::ready(anyhow::Ok(()))
}
impl TunnelRead for WebSocketRead<ReadHalf<TokioIo<Upgraded>>> {
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> anyhow::Result<()> {
impl TunnelRead for WebsocketTunnelRead {
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> Result<(), io::Error> {
loop {
let msg = self
.read_frame(&mut frame_reader)
.await
.with_context(|| "error while reading from websocket")?;
let msg = match self.inner.read_frame(&mut frame_reader).await {
Ok(msg) => msg,
Err(err) => return Err(io::Error::new(ErrorKind::ConnectionAborted, err)),
};
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
match msg.opcode {
OpCode::Continuation | OpCode::Text | OpCode::Binary => {
writer.write_all(msg.payload.as_ref()).await.with_context(|| "")?;
return Ok(());
return match writer.write_all(msg.payload.as_ref()).await {
Ok(_) => Ok(()),
Err(err) => Err(io::Error::new(ErrorKind::ConnectionAborted, err)),
}
}
OpCode::Close => return Err(anyhow!("websocket close")),
OpCode::Close => return Err(io::Error::new(ErrorKind::NotConnected, "websocket close")),
OpCode::Ping => continue,
OpCode::Pong => continue,
};
@ -68,7 +136,7 @@ pub async fn connect(
request_id: Uuid,
client_cfg: &WsClientConfig,
dest_addr: &RemoteAddr,
) -> anyhow::Result<((impl TunnelRead, impl TunnelWrite), Response<Incoming>)> {
) -> anyhow::Result<(WebsocketTunnelRead, WebsocketTunnelWrite, Parts)> {
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
Ok(cnx) => Ok(cnx),
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
@ -76,7 +144,7 @@ pub async fn connect(
let mut req = Request::builder()
.method("GET")
.uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix,))
.uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix))
.header(HOST, &client_cfg.http_header_host)
.header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade")
@ -109,5 +177,11 @@ pub async fn connect(
ws.set_auto_apply_mask(client_cfg.websocket_mask_frame);
Ok((ws.split(tokio::io::split), response))
let (ws_rx, ws_tx) = ws.split(tokio::io::split);
Ok((
WebsocketTunnelRead::new(ws_rx),
WebsocketTunnelWrite::new(ws_tx),
response.into_parts().0,
))
}