diff --git a/src/tls.rs b/src/tls.rs index 7f5b247..8bf35f4 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -114,9 +114,9 @@ pub async fn connect(client_cfg: &WsClientConfig, tcp_stream: TcpStream) -> anyh ); let tls_connector = match &client_cfg.remote_addr { - TransportAddr::WSS { tls, .. } => &tls.tls_connector, - TransportAddr::HTTPS { tls, .. } => &tls.tls_connector, - TransportAddr::HTTP { .. } | TransportAddr::WS { .. } => { + TransportAddr::Wss { tls, .. } => &tls.tls_connector, + TransportAddr::Https { tls, .. } => &tls.tls_connector, + TransportAddr::Http { .. } | TransportAddr::Ws { .. } => { return Err(anyhow!( "Transport does not support TLS: {}", client_cfg.remote_addr.scheme_name() diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index 06e27b5..42aac00 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -136,11 +136,12 @@ where // Forward local tx to websocket tx let ping_frequency = client_cfg.websocket_ping_frequency; tokio::spawn( - super::io::propagate_read(local_rx, ws_tx, close_tx, Some(ping_frequency)).instrument(Span::current()), + super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency)) + .instrument(Span::current()), ); // Forward websocket rx to local rx - let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await; + let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await; Ok(()) } @@ -233,11 +234,12 @@ where let tunnel = async move { let ping_frequency = client_config.websocket_ping_frequency; tokio::spawn( - super::io::propagate_read(local_rx, ws_tx, close_tx, Some(ping_frequency)).instrument(Span::current()), + super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency)) + .instrument(Span::current()), ); // Forward websocket rx to local rx - let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await; + let _ = super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).await; } .instrument(span.clone()); tokio::spawn(tunnel); diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index b8d8225..16dcbdb 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -1,7 +1,7 @@ pub mod client; -mod io; pub mod server; mod tls_reloader; +mod transport; use crate::{tcp, tls, LocalProtocol, TlsClientConfig, WsClientConfig}; use async_trait::async_trait; @@ -80,21 +80,21 @@ pub struct RemoteAddr { #[derive(Clone)] pub enum TransportAddr { - WSS { + Wss { tls: TlsClientConfig, host: Host, port: u16, }, - WS { + Ws { host: Host, port: u16, }, - HTTPS { + Https { tls: TlsClientConfig, host: Host, port: u16, }, - HTTP { + Http { host: Host, port: u16, }, @@ -109,62 +109,54 @@ impl Debug for TransportAddr { impl TransportAddr { pub fn from_str(scheme: &str, host: Host, port: u16, tls: Option) -> Option { match scheme { - "https" => { - let Some(tls) = tls else { return None }; - - Some(TransportAddr::HTTPS { tls, host, port }) - } - "http" => Some(TransportAddr::HTTP { host, port }), - "wss" => { - let Some(tls) = tls else { return None }; - - Some(TransportAddr::WSS { tls, host, port }) - } - "ws" => Some(TransportAddr::WS { host, port }), + "https" => Some(TransportAddr::Https { tls: tls?, host, port }), + "http" => Some(TransportAddr::Http { host, port }), + "wss" => Some(TransportAddr::Wss { tls: tls?, host, port }), + "ws" => Some(TransportAddr::Ws { host, port }), _ => None, } } pub fn is_websocket(&self) -> bool { - matches!(self, TransportAddr::WS { .. } | TransportAddr::WSS { .. }) + matches!(self, TransportAddr::Ws { .. } | TransportAddr::Wss { .. }) } pub fn is_http2(&self) -> bool { - matches!(self, TransportAddr::HTTP { .. } | TransportAddr::HTTPS { .. }) + matches!(self, TransportAddr::Http { .. } | TransportAddr::Https { .. }) } pub fn tls(&self) -> Option<&TlsClientConfig> { match self { - TransportAddr::WSS { tls, .. } => Some(tls), - TransportAddr::HTTPS { tls, .. } => Some(tls), - TransportAddr::WS { .. } => None, - TransportAddr::HTTP { .. } => None, + TransportAddr::Wss { tls, .. } => Some(tls), + TransportAddr::Https { tls, .. } => Some(tls), + TransportAddr::Ws { .. } => None, + TransportAddr::Http { .. } => None, } } pub fn host(&self) -> &Host { match self { - TransportAddr::WSS { host, .. } => host, - TransportAddr::WS { host, .. } => host, - TransportAddr::HTTPS { host, .. } => host, - TransportAddr::HTTP { host, .. } => host, + TransportAddr::Wss { host, .. } => host, + TransportAddr::Ws { host, .. } => host, + TransportAddr::Https { host, .. } => host, + TransportAddr::Http { host, .. } => host, } } pub fn port(&self) -> u16 { match self { - TransportAddr::WSS { port, .. } => *port, - TransportAddr::WS { port, .. } => *port, - TransportAddr::HTTPS { port, .. } => *port, - TransportAddr::HTTP { port, .. } => *port, + TransportAddr::Wss { port, .. } => *port, + TransportAddr::Ws { port, .. } => *port, + TransportAddr::Https { port, .. } => *port, + TransportAddr::Http { port, .. } => *port, } } pub fn scheme_name(&self) -> &str { match self { - TransportAddr::WSS { .. } => "wss", - TransportAddr::WS { .. } => "ws", - TransportAddr::HTTPS { .. } => "https", - TransportAddr::HTTP { .. } => "http", + TransportAddr::Wss { .. } => "wss", + TransportAddr::Ws { .. } => "ws", + TransportAddr::Https { .. } => "https", + TransportAddr::Http { .. } => "http", } } } diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index ddd5e1e..9beb07e 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -403,9 +403,11 @@ async fn server_upgrade( let (close_tx, close_rx) = oneshot::channel::<()>(); ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame); - tokio::task::spawn(super::io::propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current())); + tokio::task::spawn( + super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).instrument(Span::current()), + ); - let _ = super::io::propagate_read(local_rx, ws_tx, close_tx, None).await; + let _ = super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, None).await; } .instrument(Span::current()), ); diff --git a/src/tunnel/io.rs b/src/tunnel/transport/io.rs similarity index 62% rename from src/tunnel/io.rs rename to src/tunnel/transport/io.rs index 607d09e..46b86ba 100644 --- a/src/tunnel/io.rs +++ b/src/tunnel/transport/io.rs @@ -1,24 +1,21 @@ -use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite}; +use crate::tunnel::transport::{TunnelRead, TunnelWrite}; use futures_util::{pin_mut, FutureExt}; -use hyper::upgrade::Upgraded; - -use hyper_util::rt::TokioIo; use std::time::Duration; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::select; use tokio::sync::oneshot; use tokio::time::Instant; use tracing::log::debug; use tracing::{error, info, trace, warn}; -pub(super) async fn propagate_read( +pub async fn propagate_local_to_remote( local_rx: impl AsyncRead, - mut ws_tx: WebSocketWrite>>, + mut ws_tx: impl TunnelWrite, mut close_tx: oneshot::Sender<()>, ping_frequency: Option, -) -> Result<(), WebSocketError> { +) -> anyhow::Result<()> { let _guard = scopeguard::guard((), |_| { - info!("Closing local tx ==> websocket tx tunnel"); + info!("Closing local ==>> remote tunnel"); }); static MAX_PACKET_LENGTH: usize = 64 * 1024; @@ -44,8 +41,7 @@ pub(super) async fn propagate_read( _ = timeout.tick(), if ping_frequency.is_some() => { debug!("sending ping to keep websocket connection alive"); - ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?; - + ws_tx.ping().await?; continue; } }; @@ -60,10 +56,7 @@ pub(super) async fn propagate_read( }; //debug!("read {} wasted {}% usable {} capa {}", read_len, 100 - (read_len * 100 / buffer.capacity()), buffer.as_slice().len(), buffer.capacity()); - if let Err(err) = ws_tx - .write_frame(Frame::binary(Payload::BorrowedMut(&mut buffer[..read_len]))) - .await - { + if let Err(err) = ws_tx.write(&buffer[..read_len]).await { warn!("error while writing to websocket tx tunnel {}", err); break; } @@ -87,51 +80,30 @@ pub(super) async fn propagate_read( } // Send normal close - let _ = ws_tx.write_frame(Frame::close(1000, &[])).await; + let _ = ws_tx.close().await; Ok(()) } -pub(super) async fn propagate_write( +pub async fn propagate_remote_to_local( local_tx: impl AsyncWrite, - mut ws_rx: WebSocketRead>>, + mut ws_rx: impl TunnelRead, mut close_rx: oneshot::Receiver<()>, -) -> Result<(), WebSocketError> { +) -> anyhow::Result<()> { let _guard = scopeguard::guard((), |_| { - info!("Closing local rx <== websocket rx tunnel"); + info!("Closing local <<== remote tunnel"); }); - let mut x = |x: Frame<'_>| { - debug!("frame {:?} {:?}", x.opcode, x.payload); - futures_util::future::ready(anyhow::Ok(())) - }; pin_mut!(local_tx); loop { let msg = select! { biased; - msg = ws_rx.read_frame(&mut x) => msg, - + msg = ws_rx.copy(&mut local_tx) => msg, _ = &mut close_rx => break, }; - let msg = match msg { - Ok(msg) => msg, - Err(err) => { - error!("error while reading from websocket rx {}", err); - break; - } - }; - - trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload); - let ret = match msg.opcode { - OpCode::Continuation | OpCode::Text | OpCode::Binary => local_tx.write_all(msg.payload.as_ref()).await, - OpCode::Close => break, - OpCode::Ping => Ok(()), - OpCode::Pong => Ok(()), - }; - - if let Err(err) = ret { - error!("error while writing bytes to local for rx tunnel {}", err); + if let Err(err) = msg { + error!("error while reading from websocket rx {}", err); break; } } diff --git a/src/tunnel/transport/mod.rs b/src/tunnel/transport/mod.rs new file mode 100644 index 0000000..8d24e11 --- /dev/null +++ b/src/tunnel/transport/mod.rs @@ -0,0 +1,14 @@ +use tokio::io::AsyncWrite; + +pub mod io; +pub mod websocket; + +pub trait TunnelWrite { + async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()>; + async fn ping(&mut self) -> anyhow::Result<()>; + async fn close(&mut self) -> anyhow::Result<()>; +} + +pub trait TunnelRead { + async fn copy(&mut self, writer: impl AsyncWrite + Unpin) -> anyhow::Result<()>; +} diff --git a/src/tunnel/transport/websocket.rs b/src/tunnel/transport/websocket.rs new file mode 100644 index 0000000..b8d5b68 --- /dev/null +++ b/src/tunnel/transport/websocket.rs @@ -0,0 +1,54 @@ +use crate::tunnel::transport::{TunnelRead, TunnelWrite}; +use anyhow::{anyhow, Context}; +use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite}; +use hyper::upgrade::Upgraded; +use hyper_util::rt::TokioIo; +use log::debug; +use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; +use tracing::trace; + +impl TunnelWrite for WebSocketWrite>> { + async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()> { + self.write_frame(Frame::binary(Payload::Borrowed(buf))) + .await + .with_context(|| "cannot send ws frame") + } + + async fn ping(&mut self) -> anyhow::Result<()> { + self.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))) + .await + .with_context(|| "cannot send ws ping") + } + + async fn close(&mut self) -> anyhow::Result<()> { + self.write_frame(Frame::close(1000, &[])) + .await + .with_context(|| "cannot close websocket cnx") + } +} + +fn frame_reader(x: Frame<'_>) -> futures_util::future::Ready> { + debug!("frame {:?} {:?}", x.opcode, x.payload); + futures_util::future::ready(anyhow::Ok(())) +} +impl TunnelRead for WebSocketRead>> { + async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin) -> anyhow::Result<()> { + loop { + let msg = self + .read_frame(&mut frame_reader) + .await + .with_context(|| "error while reading from websocket")?; + + trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload); + match msg.opcode { + OpCode::Continuation | OpCode::Text | OpCode::Binary => { + writer.write_all(msg.payload.as_ref()).await.with_context(|| "")?; + return Ok(()); + } + OpCode::Close => return Err(anyhow!("websocket close")), + OpCode::Ping => continue, + OpCode::Pong => continue, + }; + } + } +}