Add custom trait for tunnel transport

This commit is contained in:
Σrebe - Romain GERARD 2024-01-13 21:06:57 +01:00
parent 6375e14185
commit 3eef03d8c4
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
7 changed files with 124 additions and 88 deletions

112
src/tunnel/transport/io.rs Normal file
View file

@ -0,0 +1,112 @@
use crate::tunnel::transport::{TunnelRead, TunnelWrite};
use futures_util::{pin_mut, FutureExt};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::select;
use tokio::sync::oneshot;
use tokio::time::Instant;
use tracing::log::debug;
use tracing::{error, info, trace, warn};
pub async fn propagate_local_to_remote(
local_rx: impl AsyncRead,
mut ws_tx: impl TunnelWrite,
mut close_tx: oneshot::Sender<()>,
ping_frequency: Option<Duration>,
) -> anyhow::Result<()> {
let _guard = scopeguard::guard((), |_| {
info!("Closing local ==>> remote tunnel");
});
static MAX_PACKET_LENGTH: usize = 64 * 1024;
let mut buffer = vec![0u8; MAX_PACKET_LENGTH];
// We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration
// We reuse the future to avoid creating a timer in the tight loop
let frequency = ping_frequency.unwrap_or(Duration::from_secs(3600 * 24));
let start_at = Instant::now().checked_add(frequency).unwrap_or(Instant::now());
let timeout = tokio::time::interval_at(start_at, frequency);
let should_close = close_tx.closed().fuse();
pin_mut!(timeout);
pin_mut!(should_close);
pin_mut!(local_rx);
loop {
let read_len = select! {
biased;
read_len = local_rx.read(&mut buffer) => read_len,
_ = &mut should_close => break,
_ = timeout.tick(), if ping_frequency.is_some() => {
debug!("sending ping to keep websocket connection alive");
ws_tx.ping().await?;
continue;
}
};
let read_len = match read_len {
Ok(0) => break,
Ok(read_len) => read_len,
Err(err) => {
warn!("error while reading incoming bytes from local tx tunnel: {}", err);
break;
}
};
//debug!("read {} wasted {}% usable {} capa {}", read_len, 100 - (read_len * 100 / buffer.capacity()), buffer.as_slice().len(), buffer.capacity());
if let Err(err) = ws_tx.write(&buffer[..read_len]).await {
warn!("error while writing to websocket tx tunnel {}", err);
break;
}
// If the buffer has been completely filled with previous read, Double it !
// For the buffer to not be a bottleneck when the TCP window scale
// For udp, the buffer will never grows.
if buffer.capacity() == read_len {
buffer.clear();
let new_size = buffer.capacity() + (buffer.capacity() / 4); // grow buffer by 1.25 %
buffer.reserve_exact(new_size);
buffer.resize(buffer.capacity(), 0);
trace!(
"Buffer {} Mb {} {} {}",
buffer.capacity() as f64 / 1024.0 / 1024.0,
new_size,
buffer.as_slice().len(),
buffer.capacity()
)
}
}
// Send normal close
let _ = ws_tx.close().await;
Ok(())
}
pub async fn propagate_remote_to_local(
local_tx: impl AsyncWrite,
mut ws_rx: impl TunnelRead,
mut close_rx: oneshot::Receiver<()>,
) -> anyhow::Result<()> {
let _guard = scopeguard::guard((), |_| {
info!("Closing local <<== remote tunnel");
});
pin_mut!(local_tx);
loop {
let msg = select! {
biased;
msg = ws_rx.copy(&mut local_tx) => msg,
_ = &mut close_rx => break,
};
if let Err(err) = msg {
error!("error while reading from websocket rx {}", err);
break;
}
}
Ok(())
}

View file

@ -0,0 +1,14 @@
use tokio::io::AsyncWrite;
pub mod io;
pub mod websocket;
pub trait TunnelWrite {
async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()>;
async fn ping(&mut self) -> anyhow::Result<()>;
async fn close(&mut self) -> anyhow::Result<()>;
}
pub trait TunnelRead {
async fn copy(&mut self, writer: impl AsyncWrite + Unpin) -> anyhow::Result<()>;
}

View file

@ -0,0 +1,54 @@
use crate::tunnel::transport::{TunnelRead, TunnelWrite};
use anyhow::{anyhow, Context};
use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use log::debug;
use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use tracing::trace;
impl TunnelWrite for WebSocketWrite<WriteHalf<TokioIo<Upgraded>>> {
async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()> {
self.write_frame(Frame::binary(Payload::Borrowed(buf)))
.await
.with_context(|| "cannot send ws frame")
}
async fn ping(&mut self) -> anyhow::Result<()> {
self.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut [])))
.await
.with_context(|| "cannot send ws ping")
}
async fn close(&mut self) -> anyhow::Result<()> {
self.write_frame(Frame::close(1000, &[]))
.await
.with_context(|| "cannot close websocket cnx")
}
}
fn frame_reader(x: Frame<'_>) -> futures_util::future::Ready<anyhow::Result<()>> {
debug!("frame {:?} {:?}", x.opcode, x.payload);
futures_util::future::ready(anyhow::Ok(()))
}
impl TunnelRead for WebSocketRead<ReadHalf<TokioIo<Upgraded>>> {
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin) -> anyhow::Result<()> {
loop {
let msg = self
.read_frame(&mut frame_reader)
.await
.with_context(|| "error while reading from websocket")?;
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
match msg.opcode {
OpCode::Continuation | OpCode::Text | OpCode::Binary => {
writer.write_all(msg.payload.as_ref()).await.with_context(|| "")?;
return Ok(());
}
OpCode::Close => return Err(anyhow!("websocket close")),
OpCode::Ping => continue,
OpCode::Pong => continue,
};
}
}
}