diff --git a/src/main.rs b/src/main.rs index b448c7c..8e752d4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -40,7 +40,7 @@ use tokio_rustls::TlsConnector; use tracing::{error, info}; use crate::dns::DnsResolver; -use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr}; +use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme}; use crate::udp::MyUdpSocket; use tracing_subscriber::filter::Directive; use tracing_subscriber::EnvFilter; @@ -525,7 +525,7 @@ fn parse_server_url(arg: &str) -> Result { )); }; - if url.scheme() != "ws" && url.scheme() != "wss" { + if !TransportScheme::values().iter().any(|x| x.to_str() == url.scheme()) { return Err(io::Error::new( ErrorKind::InvalidInput, format!("invalid scheme {}", url.scheme()), @@ -658,15 +658,21 @@ async fn main() { match args.commands { Commands::Client(args) => { - let tls = match args.remote_addr.scheme() { - "ws" => None, - "wss" => Some(TlsClientConfig { + let tls = match TransportScheme::from_str(args.remote_addr.scheme()).expect("invalid scheme in server url") + { + TransportScheme::Ws | TransportScheme::Http => None, + TransportScheme::Wss => Some(TlsClientConfig { tls_connector: tls::tls_connector(args.tls_verify_certificate, Some(vec![b"http/1.1".to_vec()])) .expect("Cannot create tls connector"), tls_sni_override: args.tls_sni_override, tls_verify_certificate: args.tls_verify_certificate, }), - _ => panic!("invalid scheme in server url {}", args.remote_addr.scheme()), + TransportScheme::Https => Some(TlsClientConfig { + tls_connector: tls::tls_connector(args.tls_verify_certificate, Some(vec![b"h2".to_vec()])) + .expect("Cannot create tls connector"), + tls_sni_override: args.tls_sni_override, + tls_verify_certificate: args.tls_verify_certificate, + }), }; // Extract host header from http_headers @@ -680,8 +686,8 @@ async fn main() { HeaderValue::from_str(&host).unwrap() }; let mut client_config = WsClientConfig { - remote_addr: TransportAddr::from_str( - args.remote_addr.scheme(), + remote_addr: TransportAddr::new( + TransportScheme::from_str(args.remote_addr.scheme()).unwrap(), args.remote_addr.host().unwrap().to_owned(), args.remote_addr.port_or_known_default().unwrap(), tls, diff --git a/src/tls.rs b/src/tls.rs index 8bf35f4..ccd81ef 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -117,10 +117,7 @@ pub async fn connect(client_cfg: &WsClientConfig, tcp_stream: TcpStream) -> anyh TransportAddr::Wss { tls, .. } => &tls.tls_connector, TransportAddr::Https { tls, .. } => &tls.tls_connector, TransportAddr::Http { .. } | TransportAddr::Ws { .. } => { - return Err(anyhow!( - "Transport does not support TLS: {}", - client_cfg.remote_addr.scheme_name() - )) + return Err(anyhow!("Transport does not support TLS: {}", client_cfg.remote_addr.scheme())) } }; let tls_stream = tls_connector.connect(sni, tcp_stream).await.with_context(|| { diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index 42aac00..ed3f71b 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -1,75 +1,18 @@ -use super::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX}; -use crate::WsClientConfig; -use anyhow::{anyhow, Context}; - -use bytes::Bytes; -use fastwebsockets::WebSocket; +use super::{JwtTunnelConfig, RemoteAddr, JWT_DECODE}; +use crate::{tunnel, WsClientConfig}; use futures_util::pin_mut; -use http_body_util::Empty; -use hyper::body::Incoming; -use hyper::header::{AUTHORIZATION, COOKIE, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE}; -use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY}; -use hyper::upgrade::Upgraded; -use hyper::{Request, Response}; -use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper::header::COOKIE; use jsonwebtoken::TokenData; use std::future::Future; -use std::ops::{Deref, DerefMut}; +use std::ops::Deref; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::oneshot; use tokio_stream::{Stream, StreamExt}; -use tracing::log::debug; use tracing::{error, span, Instrument, Level, Span}; use url::Host; use uuid::Uuid; -async fn connect( - request_id: Uuid, - client_cfg: &WsClientConfig, - dest_addr: &RemoteAddr, -) -> anyhow::Result<(WebSocket>, Response)> { - let mut pooled_cnx = match client_cfg.cnx_pool().get().await { - Ok(cnx) => Ok(cnx), - Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")), - }?; - - let mut req = Request::builder() - .method("GET") - .uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix,)) - .header(HOST, &client_cfg.http_header_host) - .header(UPGRADE, "websocket") - .header(CONNECTION, "upgrade") - .header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key()) - .header(SEC_WEBSOCKET_VERSION, "13") - .header( - SEC_WEBSOCKET_PROTOCOL, - format!("v1, {}{}", JWT_HEADER_PREFIX, tunnel_to_jwt_token(request_id, dest_addr)), - ) - .version(hyper::Version::HTTP_11); - - for (k, v) in &client_cfg.http_headers { - req = req.header(k, v); - } - if let Some(auth) = &client_cfg.http_upgrade_credentials { - req = req.header(AUTHORIZATION, auth); - } - - let req = req.body(Empty::::new()).with_context(|| { - format!( - "failed to build HTTP request to contact the server {:?}", - client_cfg.remote_addr - ) - })?; - debug!("with HTTP upgrade request {:?}", req); - let transport = pooled_cnx.deref_mut().take().unwrap(); - let (ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport) - .await - .with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?; - - Ok((ws, response)) -} - //async fn connect_http2( // request_id: Uuid, // client_cfg: &WsClientConfig, @@ -126,10 +69,7 @@ where R: AsyncRead + Send + 'static, W: AsyncWrite + Send + 'static, { - let (mut ws, _) = connect(request_id, client_cfg, remote_cfg).await?; - ws.set_auto_apply_mask(client_cfg.websocket_mask_frame); - - let (ws_rx, ws_tx) = ws.split(tokio::io::split); + let ((ws_rx, ws_tx), _) = tunnel::transport::websocket::connect(request_id, client_cfg, remote_cfg).await?; let (local_rx, local_tx) = duplex_stream; let (close_tx, close_rx) = oneshot::channel::<()>(); @@ -198,10 +138,10 @@ where let _span = span.enter(); // Correctly configure tunnel cfg - let (mut ws, response) = connect(request_id, &client_config, &remote_addr) - .instrument(span.clone()) - .await?; - ws.set_auto_apply_mask(client_config.websocket_mask_frame); + let ((ws_rx, ws_tx), response) = + tunnel::transport::websocket::connect(request_id, &client_config, &remote_addr) + .instrument(span.clone()) + .await?; // Connect to endpoint let remote = response @@ -228,7 +168,6 @@ where }; let (local_rx, local_tx) = tokio::io::split(stream); - let (ws_rx, ws_tx) = ws.split(tokio::io::split); let (close_tx, close_rx) = oneshot::channel::<()>(); let tunnel = async move { diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 16dcbdb..2564476 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -10,11 +10,12 @@ use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation}; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use std::collections::HashSet; -use std::fmt::{Debug, Formatter}; +use std::fmt::{Debug, Display, Formatter}; use std::io::{Error, IoSlice}; use std::net::{IpAddr, SocketAddr}; use std::ops::Deref; use std::pin::Pin; +use std::str::FromStr; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::TcpStream; @@ -78,23 +79,73 @@ pub struct RemoteAddr { pub port: u16, } +#[derive(Copy, Clone, Debug)] +pub enum TransportScheme { + Ws, + Wss, + Http, + Https, +} + +impl TransportScheme { + pub fn values() -> &'static [TransportScheme] { + &[ + TransportScheme::Ws, + TransportScheme::Wss, + TransportScheme::Http, + TransportScheme::Https, + ] + } + pub fn to_str(self) -> &'static str { + match self { + TransportScheme::Ws => "ws", + TransportScheme::Wss => "wss", + TransportScheme::Http => "http", + TransportScheme::Https => "https", + } + } +} +impl FromStr for TransportScheme { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "https" => Ok(TransportScheme::Https), + "http" => Ok(TransportScheme::Http), + "wss" => Ok(TransportScheme::Wss), + "ws" => Ok(TransportScheme::Ws), + _ => Err(()), + } + } +} + +impl Display for TransportScheme { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(self.to_str()) + } +} + #[derive(Clone)] pub enum TransportAddr { Wss { tls: TlsClientConfig, + scheme: TransportScheme, host: Host, port: u16, }, Ws { + scheme: TransportScheme, host: Host, port: u16, }, Https { + scheme: TransportScheme, tls: TlsClientConfig, host: Host, port: u16, }, Http { + scheme: TransportScheme, host: Host, port: u16, }, @@ -102,18 +153,35 @@ pub enum TransportAddr { impl Debug for TransportAddr { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!("{}://{}:{}", self.scheme_name(), self.host(), self.port())) + f.write_fmt(format_args!("{}://{}:{}", self.scheme(), self.host(), self.port())) } } impl TransportAddr { - pub fn from_str(scheme: &str, host: Host, port: u16, tls: Option) -> Option { + pub fn new(scheme: TransportScheme, host: Host, port: u16, tls: Option) -> Option { match scheme { - "https" => Some(TransportAddr::Https { tls: tls?, host, port }), - "http" => Some(TransportAddr::Http { host, port }), - "wss" => Some(TransportAddr::Wss { tls: tls?, host, port }), - "ws" => Some(TransportAddr::Ws { host, port }), - _ => None, + TransportScheme::Https => Some(TransportAddr::Https { + scheme: TransportScheme::Https, + tls: tls?, + host, + port, + }), + TransportScheme::Http => Some(TransportAddr::Http { + scheme: TransportScheme::Http, + host, + port, + }), + TransportScheme::Wss => Some(TransportAddr::Wss { + scheme: TransportScheme::Wss, + tls: tls?, + host, + port, + }), + TransportScheme::Ws => Some(TransportAddr::Ws { + scheme: TransportScheme::Ws, + host, + port, + }), } } pub fn is_websocket(&self) -> bool { @@ -151,12 +219,12 @@ impl TransportAddr { } } - pub fn scheme_name(&self) -> &str { + pub fn scheme(&self) -> &TransportScheme { match self { - TransportAddr::Wss { .. } => "wss", - TransportAddr::Ws { .. } => "ws", - TransportAddr::Https { .. } => "https", - TransportAddr::Http { .. } => "http", + TransportAddr::Wss { scheme, .. } => scheme, + TransportAddr::Ws { scheme, .. } => scheme, + TransportAddr::Https { scheme, .. } => scheme, + TransportAddr::Http { scheme, .. } => scheme, } } } diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 9beb07e..0a82ac9 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -438,7 +438,7 @@ impl TlsContext<'_> { #[inline] pub fn tls_acceptor(&mut self) -> &Arc { if self.tls_reloader.should_reload_certificate() { - match tls::tls_acceptor(self.tls_config, Some(vec![b"http/1.1".to_vec()])) { + match tls::tls_acceptor(self.tls_config, Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()])) { Ok(acceptor) => self.tls_acceptor = Arc::new(acceptor), Err(err) => error!("Cannot reload TLS certificate {:?}", err), }; @@ -462,7 +462,7 @@ pub async fn run_server(server_config: Arc) -> anyhow::Result<() // Init TLS if needed let mut tls_context = if let Some(tls_config) = &server_config.tls { let tls_context = TlsContext { - tls_acceptor: Arc::new(tls::tls_acceptor(tls_config, Some(vec![b"http/1.1".to_vec()]))?), + tls_acceptor: Arc::new(tls::tls_acceptor(tls_config, Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]))?), tls_reloader: TlsReloader::new(server_config.clone())?, tls_config, }; diff --git a/src/tunnel/transport/io.rs b/src/tunnel/transport/io.rs index 46b86ba..bd964d5 100644 --- a/src/tunnel/transport/io.rs +++ b/src/tunnel/transport/io.rs @@ -15,7 +15,7 @@ pub async fn propagate_local_to_remote( ping_frequency: Option, ) -> anyhow::Result<()> { let _guard = scopeguard::guard((), |_| { - info!("Closing local ==>> remote tunnel"); + info!("Closing local => remote tunnel"); }); static MAX_PACKET_LENGTH: usize = 64 * 1024; @@ -86,12 +86,12 @@ pub async fn propagate_local_to_remote( } pub async fn propagate_remote_to_local( - local_tx: impl AsyncWrite, + local_tx: impl AsyncWrite + Send, mut ws_rx: impl TunnelRead, mut close_rx: oneshot::Receiver<()>, ) -> anyhow::Result<()> { let _guard = scopeguard::guard((), |_| { - info!("Closing local <<== remote tunnel"); + info!("Closing local <= remote tunnel"); }); pin_mut!(local_tx); diff --git a/src/tunnel/transport/mod.rs b/src/tunnel/transport/mod.rs index 8d24e11..ff18337 100644 --- a/src/tunnel/transport/mod.rs +++ b/src/tunnel/transport/mod.rs @@ -1,14 +1,15 @@ +use std::future::Future; use tokio::io::AsyncWrite; pub mod io; pub mod websocket; -pub trait TunnelWrite { - async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()>; - async fn ping(&mut self) -> anyhow::Result<()>; - async fn close(&mut self) -> anyhow::Result<()>; +pub trait TunnelWrite: Send + 'static { + fn write(&mut self, buf: &[u8]) -> impl Future> + Send; + fn ping(&mut self) -> impl Future> + Send; + fn close(&mut self) -> impl Future> + Send; } -pub trait TunnelRead { - async fn copy(&mut self, writer: impl AsyncWrite + Unpin) -> anyhow::Result<()>; +pub trait TunnelRead: Send + 'static { + fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> impl Future> + Send; } diff --git a/src/tunnel/transport/websocket.rs b/src/tunnel/transport/websocket.rs index b8d5b68..d302cbe 100644 --- a/src/tunnel/transport/websocket.rs +++ b/src/tunnel/transport/websocket.rs @@ -1,11 +1,22 @@ use crate::tunnel::transport::{TunnelRead, TunnelWrite}; +use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, JWT_HEADER_PREFIX}; +use crate::WsClientConfig; use anyhow::{anyhow, Context}; +use bytes::Bytes; use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite}; +use http_body_util::Empty; +use hyper::body::Incoming; +use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE}; +use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY}; use hyper::upgrade::Upgraded; +use hyper::{Request, Response}; +use hyper_util::rt::TokioExecutor; use hyper_util::rt::TokioIo; use log::debug; +use std::ops::DerefMut; use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tracing::trace; +use uuid::Uuid; impl TunnelWrite for WebSocketWrite>> { async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()> { @@ -32,7 +43,7 @@ fn frame_reader(x: Frame<'_>) -> futures_util::future::Ready> futures_util::future::ready(anyhow::Ok(())) } impl TunnelRead for WebSocketRead>> { - async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin) -> anyhow::Result<()> { + async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> anyhow::Result<()> { loop { let msg = self .read_frame(&mut frame_reader) @@ -52,3 +63,51 @@ impl TunnelRead for WebSocketRead>> { } } } + +pub async fn connect( + request_id: Uuid, + client_cfg: &WsClientConfig, + dest_addr: &RemoteAddr, +) -> anyhow::Result<((impl TunnelRead, impl TunnelWrite), Response)> { + let mut pooled_cnx = match client_cfg.cnx_pool().get().await { + Ok(cnx) => Ok(cnx), + Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")), + }?; + + let mut req = Request::builder() + .method("GET") + .uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix,)) + .header(HOST, &client_cfg.http_header_host) + .header(UPGRADE, "websocket") + .header(CONNECTION, "upgrade") + .header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key()) + .header(SEC_WEBSOCKET_VERSION, "13") + .header( + SEC_WEBSOCKET_PROTOCOL, + format!("v1, {}{}", JWT_HEADER_PREFIX, tunnel_to_jwt_token(request_id, dest_addr)), + ) + .version(hyper::Version::HTTP_11); + + for (k, v) in &client_cfg.http_headers { + req = req.header(k, v); + } + if let Some(auth) = &client_cfg.http_upgrade_credentials { + req = req.header(AUTHORIZATION, auth); + } + + let req = req.body(Empty::::new()).with_context(|| { + format!( + "failed to build HTTP request to contact the server {:?}", + client_cfg.remote_addr + ) + })?; + debug!("with HTTP upgrade request {:?}", req); + let transport = pooled_cnx.deref_mut().take().unwrap(); + let (mut ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport) + .await + .with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?; + + ws.set_auto_apply_mask(client_cfg.websocket_mask_frame); + + Ok((ws.split(tokio::io::split), response)) +}