Add custom trait for tunnel transport

This commit is contained in:
Σrebe - Romain GERARD 2024-01-13 21:06:57 +01:00
parent 6375e14185
commit 3eef03d8c4
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
7 changed files with 124 additions and 88 deletions

View file

@ -114,9 +114,9 @@ pub async fn connect(client_cfg: &WsClientConfig, tcp_stream: TcpStream) -> anyh
); );
let tls_connector = match &client_cfg.remote_addr { let tls_connector = match &client_cfg.remote_addr {
TransportAddr::WSS { tls, .. } => &tls.tls_connector, TransportAddr::Wss { tls, .. } => &tls.tls_connector,
TransportAddr::HTTPS { tls, .. } => &tls.tls_connector, TransportAddr::Https { tls, .. } => &tls.tls_connector,
TransportAddr::HTTP { .. } | TransportAddr::WS { .. } => { TransportAddr::Http { .. } | TransportAddr::Ws { .. } => {
return Err(anyhow!( return Err(anyhow!(
"Transport does not support TLS: {}", "Transport does not support TLS: {}",
client_cfg.remote_addr.scheme_name() client_cfg.remote_addr.scheme_name()

View file

@ -136,11 +136,12 @@ where
// Forward local tx to websocket tx // Forward local tx to websocket tx
let ping_frequency = client_cfg.websocket_ping_frequency; let ping_frequency = client_cfg.websocket_ping_frequency;
tokio::spawn( tokio::spawn(
super::io::propagate_read(local_rx, ws_tx, close_tx, Some(ping_frequency)).instrument(Span::current()), super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency))
.instrument(Span::current()),
); );
// Forward websocket rx to local rx // Forward websocket rx to local rx
let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await; let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
Ok(()) Ok(())
} }
@ -233,11 +234,12 @@ where
let tunnel = async move { let tunnel = async move {
let ping_frequency = client_config.websocket_ping_frequency; let ping_frequency = client_config.websocket_ping_frequency;
tokio::spawn( tokio::spawn(
super::io::propagate_read(local_rx, ws_tx, close_tx, Some(ping_frequency)).instrument(Span::current()), super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency))
.instrument(Span::current()),
); );
// Forward websocket rx to local rx // Forward websocket rx to local rx
let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await; let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
} }
.instrument(span.clone()); .instrument(span.clone());
tokio::spawn(tunnel); tokio::spawn(tunnel);

View file

@ -1,7 +1,7 @@
pub mod client; pub mod client;
mod io;
pub mod server; pub mod server;
mod tls_reloader; mod tls_reloader;
mod transport;
use crate::{tcp, tls, LocalProtocol, TlsClientConfig, WsClientConfig}; use crate::{tcp, tls, LocalProtocol, TlsClientConfig, WsClientConfig};
use async_trait::async_trait; use async_trait::async_trait;
@ -80,21 +80,21 @@ pub struct RemoteAddr {
#[derive(Clone)] #[derive(Clone)]
pub enum TransportAddr { pub enum TransportAddr {
WSS { Wss {
tls: TlsClientConfig, tls: TlsClientConfig,
host: Host, host: Host,
port: u16, port: u16,
}, },
WS { Ws {
host: Host, host: Host,
port: u16, port: u16,
}, },
HTTPS { Https {
tls: TlsClientConfig, tls: TlsClientConfig,
host: Host, host: Host,
port: u16, port: u16,
}, },
HTTP { Http {
host: Host, host: Host,
port: u16, port: u16,
}, },
@ -109,62 +109,54 @@ impl Debug for TransportAddr {
impl TransportAddr { impl TransportAddr {
pub fn from_str(scheme: &str, host: Host, port: u16, tls: Option<TlsClientConfig>) -> Option<Self> { pub fn from_str(scheme: &str, host: Host, port: u16, tls: Option<TlsClientConfig>) -> Option<Self> {
match scheme { match scheme {
"https" => { "https" => Some(TransportAddr::Https { tls: tls?, host, port }),
let Some(tls) = tls else { return None }; "http" => Some(TransportAddr::Http { host, port }),
"wss" => Some(TransportAddr::Wss { tls: tls?, host, port }),
Some(TransportAddr::HTTPS { tls, host, port }) "ws" => Some(TransportAddr::Ws { host, port }),
}
"http" => Some(TransportAddr::HTTP { host, port }),
"wss" => {
let Some(tls) = tls else { return None };
Some(TransportAddr::WSS { tls, host, port })
}
"ws" => Some(TransportAddr::WS { host, port }),
_ => None, _ => None,
} }
} }
pub fn is_websocket(&self) -> bool { pub fn is_websocket(&self) -> bool {
matches!(self, TransportAddr::WS { .. } | TransportAddr::WSS { .. }) matches!(self, TransportAddr::Ws { .. } | TransportAddr::Wss { .. })
} }
pub fn is_http2(&self) -> bool { pub fn is_http2(&self) -> bool {
matches!(self, TransportAddr::HTTP { .. } | TransportAddr::HTTPS { .. }) matches!(self, TransportAddr::Http { .. } | TransportAddr::Https { .. })
} }
pub fn tls(&self) -> Option<&TlsClientConfig> { pub fn tls(&self) -> Option<&TlsClientConfig> {
match self { match self {
TransportAddr::WSS { tls, .. } => Some(tls), TransportAddr::Wss { tls, .. } => Some(tls),
TransportAddr::HTTPS { tls, .. } => Some(tls), TransportAddr::Https { tls, .. } => Some(tls),
TransportAddr::WS { .. } => None, TransportAddr::Ws { .. } => None,
TransportAddr::HTTP { .. } => None, TransportAddr::Http { .. } => None,
} }
} }
pub fn host(&self) -> &Host { pub fn host(&self) -> &Host {
match self { match self {
TransportAddr::WSS { host, .. } => host, TransportAddr::Wss { host, .. } => host,
TransportAddr::WS { host, .. } => host, TransportAddr::Ws { host, .. } => host,
TransportAddr::HTTPS { host, .. } => host, TransportAddr::Https { host, .. } => host,
TransportAddr::HTTP { host, .. } => host, TransportAddr::Http { host, .. } => host,
} }
} }
pub fn port(&self) -> u16 { pub fn port(&self) -> u16 {
match self { match self {
TransportAddr::WSS { port, .. } => *port, TransportAddr::Wss { port, .. } => *port,
TransportAddr::WS { port, .. } => *port, TransportAddr::Ws { port, .. } => *port,
TransportAddr::HTTPS { port, .. } => *port, TransportAddr::Https { port, .. } => *port,
TransportAddr::HTTP { port, .. } => *port, TransportAddr::Http { port, .. } => *port,
} }
} }
pub fn scheme_name(&self) -> &str { pub fn scheme_name(&self) -> &str {
match self { match self {
TransportAddr::WSS { .. } => "wss", TransportAddr::Wss { .. } => "wss",
TransportAddr::WS { .. } => "ws", TransportAddr::Ws { .. } => "ws",
TransportAddr::HTTPS { .. } => "https", TransportAddr::Https { .. } => "https",
TransportAddr::HTTP { .. } => "http", TransportAddr::Http { .. } => "http",
} }
} }
} }

View file

@ -403,9 +403,11 @@ async fn server_upgrade(
let (close_tx, close_rx) = oneshot::channel::<()>(); let (close_tx, close_rx) = oneshot::channel::<()>();
ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame); ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame);
tokio::task::spawn(super::io::propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current())); tokio::task::spawn(
super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).instrument(Span::current()),
);
let _ = super::io::propagate_read(local_rx, ws_tx, close_tx, None).await; let _ = super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, None).await;
} }
.instrument(Span::current()), .instrument(Span::current()),
); );

View file

@ -1,24 +1,21 @@
use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite}; use crate::tunnel::transport::{TunnelRead, TunnelWrite};
use futures_util::{pin_mut, FutureExt}; use futures_util::{pin_mut, FutureExt};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::select; use tokio::select;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tokio::time::Instant; use tokio::time::Instant;
use tracing::log::debug; use tracing::log::debug;
use tracing::{error, info, trace, warn}; use tracing::{error, info, trace, warn};
pub(super) async fn propagate_read( pub async fn propagate_local_to_remote(
local_rx: impl AsyncRead, local_rx: impl AsyncRead,
mut ws_tx: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>, mut ws_tx: impl TunnelWrite,
mut close_tx: oneshot::Sender<()>, mut close_tx: oneshot::Sender<()>,
ping_frequency: Option<Duration>, ping_frequency: Option<Duration>,
) -> Result<(), WebSocketError> { ) -> anyhow::Result<()> {
let _guard = scopeguard::guard((), |_| { let _guard = scopeguard::guard((), |_| {
info!("Closing local tx ==> websocket tx tunnel"); info!("Closing local ==>> remote tunnel");
}); });
static MAX_PACKET_LENGTH: usize = 64 * 1024; static MAX_PACKET_LENGTH: usize = 64 * 1024;
@ -44,8 +41,7 @@ pub(super) async fn propagate_read(
_ = timeout.tick(), if ping_frequency.is_some() => { _ = timeout.tick(), if ping_frequency.is_some() => {
debug!("sending ping to keep websocket connection alive"); debug!("sending ping to keep websocket connection alive");
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?; ws_tx.ping().await?;
continue; continue;
} }
}; };
@ -60,10 +56,7 @@ pub(super) async fn propagate_read(
}; };
//debug!("read {} wasted {}% usable {} capa {}", read_len, 100 - (read_len * 100 / buffer.capacity()), buffer.as_slice().len(), buffer.capacity()); //debug!("read {} wasted {}% usable {} capa {}", read_len, 100 - (read_len * 100 / buffer.capacity()), buffer.as_slice().len(), buffer.capacity());
if let Err(err) = ws_tx if let Err(err) = ws_tx.write(&buffer[..read_len]).await {
.write_frame(Frame::binary(Payload::BorrowedMut(&mut buffer[..read_len])))
.await
{
warn!("error while writing to websocket tx tunnel {}", err); warn!("error while writing to websocket tx tunnel {}", err);
break; break;
} }
@ -87,51 +80,30 @@ pub(super) async fn propagate_read(
} }
// Send normal close // Send normal close
let _ = ws_tx.write_frame(Frame::close(1000, &[])).await; let _ = ws_tx.close().await;
Ok(()) Ok(())
} }
pub(super) async fn propagate_write( pub async fn propagate_remote_to_local(
local_tx: impl AsyncWrite, local_tx: impl AsyncWrite,
mut ws_rx: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>, mut ws_rx: impl TunnelRead,
mut close_rx: oneshot::Receiver<()>, mut close_rx: oneshot::Receiver<()>,
) -> Result<(), WebSocketError> { ) -> anyhow::Result<()> {
let _guard = scopeguard::guard((), |_| { let _guard = scopeguard::guard((), |_| {
info!("Closing local rx <== websocket rx tunnel"); info!("Closing local <<== remote tunnel");
}); });
let mut x = |x: Frame<'_>| {
debug!("frame {:?} {:?}", x.opcode, x.payload);
futures_util::future::ready(anyhow::Ok(()))
};
pin_mut!(local_tx); pin_mut!(local_tx);
loop { loop {
let msg = select! { let msg = select! {
biased; biased;
msg = ws_rx.read_frame(&mut x) => msg, msg = ws_rx.copy(&mut local_tx) => msg,
_ = &mut close_rx => break, _ = &mut close_rx => break,
}; };
let msg = match msg { if let Err(err) = msg {
Ok(msg) => msg, error!("error while reading from websocket rx {}", err);
Err(err) => {
error!("error while reading from websocket rx {}", err);
break;
}
};
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
let ret = match msg.opcode {
OpCode::Continuation | OpCode::Text | OpCode::Binary => local_tx.write_all(msg.payload.as_ref()).await,
OpCode::Close => break,
OpCode::Ping => Ok(()),
OpCode::Pong => Ok(()),
};
if let Err(err) = ret {
error!("error while writing bytes to local for rx tunnel {}", err);
break; break;
} }
} }

View file

@ -0,0 +1,14 @@
use tokio::io::AsyncWrite;
pub mod io;
pub mod websocket;
pub trait TunnelWrite {
async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()>;
async fn ping(&mut self) -> anyhow::Result<()>;
async fn close(&mut self) -> anyhow::Result<()>;
}
pub trait TunnelRead {
async fn copy(&mut self, writer: impl AsyncWrite + Unpin) -> anyhow::Result<()>;
}

View file

@ -0,0 +1,54 @@
use crate::tunnel::transport::{TunnelRead, TunnelWrite};
use anyhow::{anyhow, Context};
use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use log::debug;
use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use tracing::trace;
impl TunnelWrite for WebSocketWrite<WriteHalf<TokioIo<Upgraded>>> {
async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()> {
self.write_frame(Frame::binary(Payload::Borrowed(buf)))
.await
.with_context(|| "cannot send ws frame")
}
async fn ping(&mut self) -> anyhow::Result<()> {
self.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut [])))
.await
.with_context(|| "cannot send ws ping")
}
async fn close(&mut self) -> anyhow::Result<()> {
self.write_frame(Frame::close(1000, &[]))
.await
.with_context(|| "cannot close websocket cnx")
}
}
fn frame_reader(x: Frame<'_>) -> futures_util::future::Ready<anyhow::Result<()>> {
debug!("frame {:?} {:?}", x.opcode, x.payload);
futures_util::future::ready(anyhow::Ok(()))
}
impl TunnelRead for WebSocketRead<ReadHalf<TokioIo<Upgraded>>> {
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin) -> anyhow::Result<()> {
loop {
let msg = self
.read_frame(&mut frame_reader)
.await
.with_context(|| "error while reading from websocket")?;
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
match msg.opcode {
OpCode::Continuation | OpCode::Text | OpCode::Binary => {
writer.write_all(msg.payload.as_ref()).await.with_context(|| "")?;
return Ok(());
}
OpCode::Close => return Err(anyhow!("websocket close")),
OpCode::Ping => continue,
OpCode::Pong => continue,
};
}
}
}