128 lines
4.1 KiB
Rust
128 lines
4.1 KiB
Rust
use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite};
|
|
use futures_util::pin_mut;
|
|
use hyper::upgrade::Upgraded;
|
|
use std::pin::Pin;
|
|
|
|
use std::time::Duration;
|
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
|
|
use tokio::select;
|
|
use tokio::sync::oneshot;
|
|
use tracing::log::debug;
|
|
use tracing::{error, info, trace, warn};
|
|
|
|
pub(super) async fn propagate_read(
|
|
local_rx: impl AsyncRead,
|
|
mut ws_tx: WebSocketWrite<WriteHalf<Upgraded>>,
|
|
mut close_tx: oneshot::Sender<()>,
|
|
ping_frequency: Duration,
|
|
) -> Result<(), WebSocketError> {
|
|
let _guard = scopeguard::guard((), |_| {
|
|
info!("Closing local tx ==> websocket tx tunnel");
|
|
});
|
|
|
|
static JUMBO_FRAME_SIZE: usize = 9 * 1024; // enough for a jumbo frame
|
|
let mut buffer = vec![0u8; JUMBO_FRAME_SIZE];
|
|
|
|
// We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration
|
|
// We reuse the future to avoid creating a timer in the tight loop
|
|
let mut timeout_unpin = tokio::time::sleep(ping_frequency);
|
|
let mut timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) };
|
|
|
|
pin_mut!(local_rx);
|
|
loop {
|
|
let read_len = select! {
|
|
biased;
|
|
|
|
read_len = local_rx.read(&mut buffer) => read_len,
|
|
|
|
_ = close_tx.closed() => break,
|
|
|
|
_ = &mut timeout => {
|
|
debug!("sending ping to keep websocket connection alive");
|
|
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?;
|
|
|
|
timeout_unpin = tokio::time::sleep(ping_frequency);
|
|
timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) };
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let read_len = match read_len {
|
|
Ok(0) => break,
|
|
Ok(read_len) => read_len,
|
|
Err(err) => {
|
|
warn!("error while reading incoming bytes from local tx tunnel {}", err);
|
|
break;
|
|
}
|
|
};
|
|
|
|
trace!("read {} bytes", read_len);
|
|
if let Err(err) = ws_tx
|
|
.write_frame(Frame::binary(Payload::BorrowedMut(&mut buffer[..read_len])))
|
|
.await
|
|
{
|
|
warn!("error while writing to websocket tx tunnel {}", err);
|
|
break;
|
|
}
|
|
|
|
// If the buffer has been completely filled with previous read, Double it !
|
|
// For the buffer to not be a bottleneck when the TCP window scale
|
|
// For udp, the buffer will never grows.
|
|
if buffer.capacity() == read_len {
|
|
buffer.clear();
|
|
buffer.resize(buffer.capacity() * 2, 0);
|
|
}
|
|
}
|
|
|
|
// Send normal close
|
|
let _ = ws_tx.write_frame(Frame::close(1000, &[])).await;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub(super) async fn propagate_write(
|
|
local_tx: impl AsyncWrite,
|
|
mut ws_rx: WebSocketRead<ReadHalf<Upgraded>>,
|
|
mut close_rx: oneshot::Receiver<()>,
|
|
) -> Result<(), WebSocketError> {
|
|
let _guard = scopeguard::guard((), |_| {
|
|
info!("Closing local rx <== websocket rx tunnel");
|
|
});
|
|
let mut x = |x: Frame<'_>| {
|
|
debug!("frame {:?} {:?}", x.opcode, x.payload);
|
|
futures_util::future::ready(anyhow::Ok(()))
|
|
};
|
|
|
|
pin_mut!(local_tx);
|
|
loop {
|
|
let msg = select! {
|
|
biased;
|
|
msg = ws_rx.read_frame(&mut x) => msg,
|
|
|
|
_ = &mut close_rx => break,
|
|
};
|
|
|
|
let msg = match msg {
|
|
Ok(msg) => msg,
|
|
Err(err) => {
|
|
error!("error while reading from websocket rx {}", err);
|
|
break;
|
|
}
|
|
};
|
|
|
|
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
|
|
let ret = match msg.opcode {
|
|
OpCode::Continuation | OpCode::Text | OpCode::Binary => local_tx.write_all(msg.payload.as_ref()).await,
|
|
OpCode::Close => break,
|
|
OpCode::Ping => Ok(()),
|
|
OpCode::Pong => Ok(()),
|
|
};
|
|
|
|
if let Err(err) = ret {
|
|
error!("error while writing bytes to local for rx tunnel {}", err);
|
|
break;
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|