Add custom trait for tunnel transport
This commit is contained in:
parent
6375e14185
commit
3eef03d8c4
7 changed files with 124 additions and 88 deletions
|
@ -114,9 +114,9 @@ pub async fn connect(client_cfg: &WsClientConfig, tcp_stream: TcpStream) -> anyh
|
||||||
);
|
);
|
||||||
|
|
||||||
let tls_connector = match &client_cfg.remote_addr {
|
let tls_connector = match &client_cfg.remote_addr {
|
||||||
TransportAddr::WSS { tls, .. } => &tls.tls_connector,
|
TransportAddr::Wss { tls, .. } => &tls.tls_connector,
|
||||||
TransportAddr::HTTPS { tls, .. } => &tls.tls_connector,
|
TransportAddr::Https { tls, .. } => &tls.tls_connector,
|
||||||
TransportAddr::HTTP { .. } | TransportAddr::WS { .. } => {
|
TransportAddr::Http { .. } | TransportAddr::Ws { .. } => {
|
||||||
return Err(anyhow!(
|
return Err(anyhow!(
|
||||||
"Transport does not support TLS: {}",
|
"Transport does not support TLS: {}",
|
||||||
client_cfg.remote_addr.scheme_name()
|
client_cfg.remote_addr.scheme_name()
|
||||||
|
|
|
@ -136,11 +136,12 @@ where
|
||||||
// Forward local tx to websocket tx
|
// Forward local tx to websocket tx
|
||||||
let ping_frequency = client_cfg.websocket_ping_frequency;
|
let ping_frequency = client_cfg.websocket_ping_frequency;
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
super::io::propagate_read(local_rx, ws_tx, close_tx, Some(ping_frequency)).instrument(Span::current()),
|
super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency))
|
||||||
|
.instrument(Span::current()),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Forward websocket rx to local rx
|
// Forward websocket rx to local rx
|
||||||
let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await;
|
let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -233,11 +234,12 @@ where
|
||||||
let tunnel = async move {
|
let tunnel = async move {
|
||||||
let ping_frequency = client_config.websocket_ping_frequency;
|
let ping_frequency = client_config.websocket_ping_frequency;
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
super::io::propagate_read(local_rx, ws_tx, close_tx, Some(ping_frequency)).instrument(Span::current()),
|
super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency))
|
||||||
|
.instrument(Span::current()),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Forward websocket rx to local rx
|
// Forward websocket rx to local rx
|
||||||
let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await;
|
let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await;
|
||||||
}
|
}
|
||||||
.instrument(span.clone());
|
.instrument(span.clone());
|
||||||
tokio::spawn(tunnel);
|
tokio::spawn(tunnel);
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
pub mod client;
|
pub mod client;
|
||||||
mod io;
|
|
||||||
pub mod server;
|
pub mod server;
|
||||||
mod tls_reloader;
|
mod tls_reloader;
|
||||||
|
mod transport;
|
||||||
|
|
||||||
use crate::{tcp, tls, LocalProtocol, TlsClientConfig, WsClientConfig};
|
use crate::{tcp, tls, LocalProtocol, TlsClientConfig, WsClientConfig};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
@ -80,21 +80,21 @@ pub struct RemoteAddr {
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub enum TransportAddr {
|
pub enum TransportAddr {
|
||||||
WSS {
|
Wss {
|
||||||
tls: TlsClientConfig,
|
tls: TlsClientConfig,
|
||||||
host: Host,
|
host: Host,
|
||||||
port: u16,
|
port: u16,
|
||||||
},
|
},
|
||||||
WS {
|
Ws {
|
||||||
host: Host,
|
host: Host,
|
||||||
port: u16,
|
port: u16,
|
||||||
},
|
},
|
||||||
HTTPS {
|
Https {
|
||||||
tls: TlsClientConfig,
|
tls: TlsClientConfig,
|
||||||
host: Host,
|
host: Host,
|
||||||
port: u16,
|
port: u16,
|
||||||
},
|
},
|
||||||
HTTP {
|
Http {
|
||||||
host: Host,
|
host: Host,
|
||||||
port: u16,
|
port: u16,
|
||||||
},
|
},
|
||||||
|
@ -109,62 +109,54 @@ impl Debug for TransportAddr {
|
||||||
impl TransportAddr {
|
impl TransportAddr {
|
||||||
pub fn from_str(scheme: &str, host: Host, port: u16, tls: Option<TlsClientConfig>) -> Option<Self> {
|
pub fn from_str(scheme: &str, host: Host, port: u16, tls: Option<TlsClientConfig>) -> Option<Self> {
|
||||||
match scheme {
|
match scheme {
|
||||||
"https" => {
|
"https" => Some(TransportAddr::Https { tls: tls?, host, port }),
|
||||||
let Some(tls) = tls else { return None };
|
"http" => Some(TransportAddr::Http { host, port }),
|
||||||
|
"wss" => Some(TransportAddr::Wss { tls: tls?, host, port }),
|
||||||
Some(TransportAddr::HTTPS { tls, host, port })
|
"ws" => Some(TransportAddr::Ws { host, port }),
|
||||||
}
|
|
||||||
"http" => Some(TransportAddr::HTTP { host, port }),
|
|
||||||
"wss" => {
|
|
||||||
let Some(tls) = tls else { return None };
|
|
||||||
|
|
||||||
Some(TransportAddr::WSS { tls, host, port })
|
|
||||||
}
|
|
||||||
"ws" => Some(TransportAddr::WS { host, port }),
|
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub fn is_websocket(&self) -> bool {
|
pub fn is_websocket(&self) -> bool {
|
||||||
matches!(self, TransportAddr::WS { .. } | TransportAddr::WSS { .. })
|
matches!(self, TransportAddr::Ws { .. } | TransportAddr::Wss { .. })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_http2(&self) -> bool {
|
pub fn is_http2(&self) -> bool {
|
||||||
matches!(self, TransportAddr::HTTP { .. } | TransportAddr::HTTPS { .. })
|
matches!(self, TransportAddr::Http { .. } | TransportAddr::Https { .. })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tls(&self) -> Option<&TlsClientConfig> {
|
pub fn tls(&self) -> Option<&TlsClientConfig> {
|
||||||
match self {
|
match self {
|
||||||
TransportAddr::WSS { tls, .. } => Some(tls),
|
TransportAddr::Wss { tls, .. } => Some(tls),
|
||||||
TransportAddr::HTTPS { tls, .. } => Some(tls),
|
TransportAddr::Https { tls, .. } => Some(tls),
|
||||||
TransportAddr::WS { .. } => None,
|
TransportAddr::Ws { .. } => None,
|
||||||
TransportAddr::HTTP { .. } => None,
|
TransportAddr::Http { .. } => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn host(&self) -> &Host {
|
pub fn host(&self) -> &Host {
|
||||||
match self {
|
match self {
|
||||||
TransportAddr::WSS { host, .. } => host,
|
TransportAddr::Wss { host, .. } => host,
|
||||||
TransportAddr::WS { host, .. } => host,
|
TransportAddr::Ws { host, .. } => host,
|
||||||
TransportAddr::HTTPS { host, .. } => host,
|
TransportAddr::Https { host, .. } => host,
|
||||||
TransportAddr::HTTP { host, .. } => host,
|
TransportAddr::Http { host, .. } => host,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn port(&self) -> u16 {
|
pub fn port(&self) -> u16 {
|
||||||
match self {
|
match self {
|
||||||
TransportAddr::WSS { port, .. } => *port,
|
TransportAddr::Wss { port, .. } => *port,
|
||||||
TransportAddr::WS { port, .. } => *port,
|
TransportAddr::Ws { port, .. } => *port,
|
||||||
TransportAddr::HTTPS { port, .. } => *port,
|
TransportAddr::Https { port, .. } => *port,
|
||||||
TransportAddr::HTTP { port, .. } => *port,
|
TransportAddr::Http { port, .. } => *port,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn scheme_name(&self) -> &str {
|
pub fn scheme_name(&self) -> &str {
|
||||||
match self {
|
match self {
|
||||||
TransportAddr::WSS { .. } => "wss",
|
TransportAddr::Wss { .. } => "wss",
|
||||||
TransportAddr::WS { .. } => "ws",
|
TransportAddr::Ws { .. } => "ws",
|
||||||
TransportAddr::HTTPS { .. } => "https",
|
TransportAddr::Https { .. } => "https",
|
||||||
TransportAddr::HTTP { .. } => "http",
|
TransportAddr::Http { .. } => "http",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -403,9 +403,11 @@ async fn server_upgrade(
|
||||||
let (close_tx, close_rx) = oneshot::channel::<()>();
|
let (close_tx, close_rx) = oneshot::channel::<()>();
|
||||||
ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame);
|
ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame);
|
||||||
|
|
||||||
tokio::task::spawn(super::io::propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current()));
|
tokio::task::spawn(
|
||||||
|
super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).instrument(Span::current()),
|
||||||
|
);
|
||||||
|
|
||||||
let _ = super::io::propagate_read(local_rx, ws_tx, close_tx, None).await;
|
let _ = super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, None).await;
|
||||||
}
|
}
|
||||||
.instrument(Span::current()),
|
.instrument(Span::current()),
|
||||||
);
|
);
|
||||||
|
|
|
@ -1,24 +1,21 @@
|
||||||
use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite};
|
use crate::tunnel::transport::{TunnelRead, TunnelWrite};
|
||||||
use futures_util::{pin_mut, FutureExt};
|
use futures_util::{pin_mut, FutureExt};
|
||||||
use hyper::upgrade::Upgraded;
|
|
||||||
|
|
||||||
use hyper_util::rt::TokioIo;
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||||
use tokio::select;
|
use tokio::select;
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tracing::log::debug;
|
use tracing::log::debug;
|
||||||
use tracing::{error, info, trace, warn};
|
use tracing::{error, info, trace, warn};
|
||||||
|
|
||||||
pub(super) async fn propagate_read(
|
pub async fn propagate_local_to_remote(
|
||||||
local_rx: impl AsyncRead,
|
local_rx: impl AsyncRead,
|
||||||
mut ws_tx: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
|
mut ws_tx: impl TunnelWrite,
|
||||||
mut close_tx: oneshot::Sender<()>,
|
mut close_tx: oneshot::Sender<()>,
|
||||||
ping_frequency: Option<Duration>,
|
ping_frequency: Option<Duration>,
|
||||||
) -> Result<(), WebSocketError> {
|
) -> anyhow::Result<()> {
|
||||||
let _guard = scopeguard::guard((), |_| {
|
let _guard = scopeguard::guard((), |_| {
|
||||||
info!("Closing local tx ==> websocket tx tunnel");
|
info!("Closing local ==>> remote tunnel");
|
||||||
});
|
});
|
||||||
|
|
||||||
static MAX_PACKET_LENGTH: usize = 64 * 1024;
|
static MAX_PACKET_LENGTH: usize = 64 * 1024;
|
||||||
|
@ -44,8 +41,7 @@ pub(super) async fn propagate_read(
|
||||||
|
|
||||||
_ = timeout.tick(), if ping_frequency.is_some() => {
|
_ = timeout.tick(), if ping_frequency.is_some() => {
|
||||||
debug!("sending ping to keep websocket connection alive");
|
debug!("sending ping to keep websocket connection alive");
|
||||||
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?;
|
ws_tx.ping().await?;
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -60,10 +56,7 @@ pub(super) async fn propagate_read(
|
||||||
};
|
};
|
||||||
|
|
||||||
//debug!("read {} wasted {}% usable {} capa {}", read_len, 100 - (read_len * 100 / buffer.capacity()), buffer.as_slice().len(), buffer.capacity());
|
//debug!("read {} wasted {}% usable {} capa {}", read_len, 100 - (read_len * 100 / buffer.capacity()), buffer.as_slice().len(), buffer.capacity());
|
||||||
if let Err(err) = ws_tx
|
if let Err(err) = ws_tx.write(&buffer[..read_len]).await {
|
||||||
.write_frame(Frame::binary(Payload::BorrowedMut(&mut buffer[..read_len])))
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
warn!("error while writing to websocket tx tunnel {}", err);
|
warn!("error while writing to websocket tx tunnel {}", err);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -87,53 +80,32 @@ pub(super) async fn propagate_read(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send normal close
|
// Send normal close
|
||||||
let _ = ws_tx.write_frame(Frame::close(1000, &[])).await;
|
let _ = ws_tx.close().await;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) async fn propagate_write(
|
pub async fn propagate_remote_to_local(
|
||||||
local_tx: impl AsyncWrite,
|
local_tx: impl AsyncWrite,
|
||||||
mut ws_rx: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>,
|
mut ws_rx: impl TunnelRead,
|
||||||
mut close_rx: oneshot::Receiver<()>,
|
mut close_rx: oneshot::Receiver<()>,
|
||||||
) -> Result<(), WebSocketError> {
|
) -> anyhow::Result<()> {
|
||||||
let _guard = scopeguard::guard((), |_| {
|
let _guard = scopeguard::guard((), |_| {
|
||||||
info!("Closing local rx <== websocket rx tunnel");
|
info!("Closing local <<== remote tunnel");
|
||||||
});
|
});
|
||||||
let mut x = |x: Frame<'_>| {
|
|
||||||
debug!("frame {:?} {:?}", x.opcode, x.payload);
|
|
||||||
futures_util::future::ready(anyhow::Ok(()))
|
|
||||||
};
|
|
||||||
|
|
||||||
pin_mut!(local_tx);
|
pin_mut!(local_tx);
|
||||||
loop {
|
loop {
|
||||||
let msg = select! {
|
let msg = select! {
|
||||||
biased;
|
biased;
|
||||||
msg = ws_rx.read_frame(&mut x) => msg,
|
msg = ws_rx.copy(&mut local_tx) => msg,
|
||||||
|
|
||||||
_ = &mut close_rx => break,
|
_ = &mut close_rx => break,
|
||||||
};
|
};
|
||||||
|
|
||||||
let msg = match msg {
|
if let Err(err) = msg {
|
||||||
Ok(msg) => msg,
|
|
||||||
Err(err) => {
|
|
||||||
error!("error while reading from websocket rx {}", err);
|
error!("error while reading from websocket rx {}", err);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
|
|
||||||
let ret = match msg.opcode {
|
|
||||||
OpCode::Continuation | OpCode::Text | OpCode::Binary => local_tx.write_all(msg.payload.as_ref()).await,
|
|
||||||
OpCode::Close => break,
|
|
||||||
OpCode::Ping => Ok(()),
|
|
||||||
OpCode::Pong => Ok(()),
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Err(err) = ret {
|
|
||||||
error!("error while writing bytes to local for rx tunnel {}", err);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
14
src/tunnel/transport/mod.rs
Normal file
14
src/tunnel/transport/mod.rs
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
use tokio::io::AsyncWrite;
|
||||||
|
|
||||||
|
pub mod io;
|
||||||
|
pub mod websocket;
|
||||||
|
|
||||||
|
pub trait TunnelWrite {
|
||||||
|
async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()>;
|
||||||
|
async fn ping(&mut self) -> anyhow::Result<()>;
|
||||||
|
async fn close(&mut self) -> anyhow::Result<()>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait TunnelRead {
|
||||||
|
async fn copy(&mut self, writer: impl AsyncWrite + Unpin) -> anyhow::Result<()>;
|
||||||
|
}
|
54
src/tunnel/transport/websocket.rs
Normal file
54
src/tunnel/transport/websocket.rs
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
use crate::tunnel::transport::{TunnelRead, TunnelWrite};
|
||||||
|
use anyhow::{anyhow, Context};
|
||||||
|
use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
|
||||||
|
use hyper::upgrade::Upgraded;
|
||||||
|
use hyper_util::rt::TokioIo;
|
||||||
|
use log::debug;
|
||||||
|
use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
|
||||||
|
use tracing::trace;
|
||||||
|
|
||||||
|
impl TunnelWrite for WebSocketWrite<WriteHalf<TokioIo<Upgraded>>> {
|
||||||
|
async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()> {
|
||||||
|
self.write_frame(Frame::binary(Payload::Borrowed(buf)))
|
||||||
|
.await
|
||||||
|
.with_context(|| "cannot send ws frame")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn ping(&mut self) -> anyhow::Result<()> {
|
||||||
|
self.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut [])))
|
||||||
|
.await
|
||||||
|
.with_context(|| "cannot send ws ping")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn close(&mut self) -> anyhow::Result<()> {
|
||||||
|
self.write_frame(Frame::close(1000, &[]))
|
||||||
|
.await
|
||||||
|
.with_context(|| "cannot close websocket cnx")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn frame_reader(x: Frame<'_>) -> futures_util::future::Ready<anyhow::Result<()>> {
|
||||||
|
debug!("frame {:?} {:?}", x.opcode, x.payload);
|
||||||
|
futures_util::future::ready(anyhow::Ok(()))
|
||||||
|
}
|
||||||
|
impl TunnelRead for WebSocketRead<ReadHalf<TokioIo<Upgraded>>> {
|
||||||
|
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin) -> anyhow::Result<()> {
|
||||||
|
loop {
|
||||||
|
let msg = self
|
||||||
|
.read_frame(&mut frame_reader)
|
||||||
|
.await
|
||||||
|
.with_context(|| "error while reading from websocket")?;
|
||||||
|
|
||||||
|
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
|
||||||
|
match msg.opcode {
|
||||||
|
OpCode::Continuation | OpCode::Text | OpCode::Binary => {
|
||||||
|
writer.write_all(msg.payload.as_ref()).await.with_context(|| "")?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
OpCode::Close => return Err(anyhow!("websocket close")),
|
||||||
|
OpCode::Ping => continue,
|
||||||
|
OpCode::Pong => continue,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue