Add custom trait for tunnel transport
This commit is contained in:
parent
6375e14185
commit
3eef03d8c4
7 changed files with 124 additions and 88 deletions
112
src/tunnel/transport/io.rs
Normal file
112
src/tunnel/transport/io.rs
Normal file
|
@ -0,0 +1,112 @@
|
|||
use crate::tunnel::transport::{TunnelRead, TunnelWrite};
|
||||
use futures_util::{pin_mut, FutureExt};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||
use tokio::select;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::Instant;
|
||||
use tracing::log::debug;
|
||||
use tracing::{error, info, trace, warn};
|
||||
|
||||
pub async fn propagate_local_to_remote(
|
||||
local_rx: impl AsyncRead,
|
||||
mut ws_tx: impl TunnelWrite,
|
||||
mut close_tx: oneshot::Sender<()>,
|
||||
ping_frequency: Option<Duration>,
|
||||
) -> anyhow::Result<()> {
|
||||
let _guard = scopeguard::guard((), |_| {
|
||||
info!("Closing local ==>> remote tunnel");
|
||||
});
|
||||
|
||||
static MAX_PACKET_LENGTH: usize = 64 * 1024;
|
||||
let mut buffer = vec![0u8; MAX_PACKET_LENGTH];
|
||||
|
||||
// We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration
|
||||
// We reuse the future to avoid creating a timer in the tight loop
|
||||
let frequency = ping_frequency.unwrap_or(Duration::from_secs(3600 * 24));
|
||||
let start_at = Instant::now().checked_add(frequency).unwrap_or(Instant::now());
|
||||
let timeout = tokio::time::interval_at(start_at, frequency);
|
||||
let should_close = close_tx.closed().fuse();
|
||||
|
||||
pin_mut!(timeout);
|
||||
pin_mut!(should_close);
|
||||
pin_mut!(local_rx);
|
||||
loop {
|
||||
let read_len = select! {
|
||||
biased;
|
||||
|
||||
read_len = local_rx.read(&mut buffer) => read_len,
|
||||
|
||||
_ = &mut should_close => break,
|
||||
|
||||
_ = timeout.tick(), if ping_frequency.is_some() => {
|
||||
debug!("sending ping to keep websocket connection alive");
|
||||
ws_tx.ping().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let read_len = match read_len {
|
||||
Ok(0) => break,
|
||||
Ok(read_len) => read_len,
|
||||
Err(err) => {
|
||||
warn!("error while reading incoming bytes from local tx tunnel: {}", err);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
//debug!("read {} wasted {}% usable {} capa {}", read_len, 100 - (read_len * 100 / buffer.capacity()), buffer.as_slice().len(), buffer.capacity());
|
||||
if let Err(err) = ws_tx.write(&buffer[..read_len]).await {
|
||||
warn!("error while writing to websocket tx tunnel {}", err);
|
||||
break;
|
||||
}
|
||||
|
||||
// If the buffer has been completely filled with previous read, Double it !
|
||||
// For the buffer to not be a bottleneck when the TCP window scale
|
||||
// For udp, the buffer will never grows.
|
||||
if buffer.capacity() == read_len {
|
||||
buffer.clear();
|
||||
let new_size = buffer.capacity() + (buffer.capacity() / 4); // grow buffer by 1.25 %
|
||||
buffer.reserve_exact(new_size);
|
||||
buffer.resize(buffer.capacity(), 0);
|
||||
trace!(
|
||||
"Buffer {} Mb {} {} {}",
|
||||
buffer.capacity() as f64 / 1024.0 / 1024.0,
|
||||
new_size,
|
||||
buffer.as_slice().len(),
|
||||
buffer.capacity()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Send normal close
|
||||
let _ = ws_tx.close().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn propagate_remote_to_local(
|
||||
local_tx: impl AsyncWrite,
|
||||
mut ws_rx: impl TunnelRead,
|
||||
mut close_rx: oneshot::Receiver<()>,
|
||||
) -> anyhow::Result<()> {
|
||||
let _guard = scopeguard::guard((), |_| {
|
||||
info!("Closing local <<== remote tunnel");
|
||||
});
|
||||
|
||||
pin_mut!(local_tx);
|
||||
loop {
|
||||
let msg = select! {
|
||||
biased;
|
||||
msg = ws_rx.copy(&mut local_tx) => msg,
|
||||
_ = &mut close_rx => break,
|
||||
};
|
||||
|
||||
if let Err(err) = msg {
|
||||
error!("error while reading from websocket rx {}", err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
14
src/tunnel/transport/mod.rs
Normal file
14
src/tunnel/transport/mod.rs
Normal file
|
@ -0,0 +1,14 @@
|
|||
use tokio::io::AsyncWrite;
|
||||
|
||||
pub mod io;
|
||||
pub mod websocket;
|
||||
|
||||
pub trait TunnelWrite {
|
||||
async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()>;
|
||||
async fn ping(&mut self) -> anyhow::Result<()>;
|
||||
async fn close(&mut self) -> anyhow::Result<()>;
|
||||
}
|
||||
|
||||
pub trait TunnelRead {
|
||||
async fn copy(&mut self, writer: impl AsyncWrite + Unpin) -> anyhow::Result<()>;
|
||||
}
|
54
src/tunnel/transport/websocket.rs
Normal file
54
src/tunnel/transport/websocket.rs
Normal file
|
@ -0,0 +1,54 @@
|
|||
use crate::tunnel::transport::{TunnelRead, TunnelWrite};
|
||||
use anyhow::{anyhow, Context};
|
||||
use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
|
||||
use hyper::upgrade::Upgraded;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use log::debug;
|
||||
use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
|
||||
use tracing::trace;
|
||||
|
||||
impl TunnelWrite for WebSocketWrite<WriteHalf<TokioIo<Upgraded>>> {
|
||||
async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()> {
|
||||
self.write_frame(Frame::binary(Payload::Borrowed(buf)))
|
||||
.await
|
||||
.with_context(|| "cannot send ws frame")
|
||||
}
|
||||
|
||||
async fn ping(&mut self) -> anyhow::Result<()> {
|
||||
self.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut [])))
|
||||
.await
|
||||
.with_context(|| "cannot send ws ping")
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> anyhow::Result<()> {
|
||||
self.write_frame(Frame::close(1000, &[]))
|
||||
.await
|
||||
.with_context(|| "cannot close websocket cnx")
|
||||
}
|
||||
}
|
||||
|
||||
fn frame_reader(x: Frame<'_>) -> futures_util::future::Ready<anyhow::Result<()>> {
|
||||
debug!("frame {:?} {:?}", x.opcode, x.payload);
|
||||
futures_util::future::ready(anyhow::Ok(()))
|
||||
}
|
||||
impl TunnelRead for WebSocketRead<ReadHalf<TokioIo<Upgraded>>> {
|
||||
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let msg = self
|
||||
.read_frame(&mut frame_reader)
|
||||
.await
|
||||
.with_context(|| "error while reading from websocket")?;
|
||||
|
||||
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
|
||||
match msg.opcode {
|
||||
OpCode::Continuation | OpCode::Text | OpCode::Binary => {
|
||||
writer.write_all(msg.payload.as_ref()).await.with_context(|| "")?;
|
||||
return Ok(());
|
||||
}
|
||||
OpCode::Close => return Err(anyhow!("websocket close")),
|
||||
OpCode::Ping => continue,
|
||||
OpCode::Pong => continue,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue