Prep work for new transport

This commit is contained in:
Σrebe - Romain GERARD 2024-01-13 18:42:15 +01:00
parent 62f6a0287d
commit 6375e14185
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
4 changed files with 215 additions and 40 deletions

View file

@ -24,7 +24,7 @@ use tracing::{error, span, Instrument, Level, Span};
use url::Host;
use uuid::Uuid;
pub async fn connect(
async fn connect(
request_id: Uuid,
client_cfg: &WsClientConfig,
dest_addr: &RemoteAddr,
@ -70,6 +70,52 @@ pub async fn connect(
Ok((ws, response))
}
//async fn connect_http2(
// request_id: Uuid,
// client_cfg: &WsClientConfig,
// dest_addr: &RemoteAddr,
//) -> anyhow::Result<BodyStream<Incoming>> {
// let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
// Ok(cnx) => Ok(cnx),
// Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
// }?;
//
// let mut req = Request::builder()
// .method("GET")
// .uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix))
// .header(HOST, &client_cfg.http_header_host)
// .header(COOKIE, tunnel_to_jwt_token(request_id, dest_addr))
// .version(hyper::Version::HTTP_2);
//
// for (k, v) in &client_cfg.http_headers {
// req = req.header(k, v);
// }
// if let Some(auth) = &client_cfg.http_upgrade_credentials {
// req = req.header(AUTHORIZATION, auth);
// }
//
// let x: Vec<u8> = vec![];
// //let bosy = StreamBody::new(stream::iter(vec![anyhow::Result::Ok(hyper::body::Frame::data(x.as_slice()))]));
// let req = req.body(Empty::<Bytes>::new()).with_context(|| {
// format!(
// "failed to build HTTP request to contact the server {:?}",
// client_cfg.remote_addr
// )
// })?;
// debug!("with HTTP upgrade request {:?}", req);
// let transport = pooled_cnx.deref_mut().take().unwrap();
// let (mut request_sender, cnx) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()).handshake(TokioIo::new(transport)).await
// .with_context(|| format!("failed to do http2 handshake with the server {:?}", client_cfg.remote_addr))?;
// tokio::spawn(cnx);
//
// let response = request_sender.send_request(req)
// .await
// .with_context(|| format!("failed to send http2 request with the server {:?}", client_cfg.remote_addr))?;
//
// // TODO: verify response is ok
// Ok(BodyStream::new(response.into_body()))
//}
async fn connect_to_server<R, W>(
request_id: Uuid,
client_cfg: &WsClientConfig,

View file

@ -3,13 +3,14 @@ mod io;
pub mod server;
mod tls_reloader;
use crate::{tcp, tls, LocalProtocol, WsClientConfig};
use crate::{tcp, tls, LocalProtocol, TlsClientConfig, WsClientConfig};
use async_trait::async_trait;
use bb8::ManageConnection;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fmt::{Debug, Formatter};
use std::io::{Error, IoSlice};
use std::net::{IpAddr, SocketAddr};
use std::ops::Deref;
@ -77,6 +78,97 @@ pub struct RemoteAddr {
pub port: u16,
}
#[derive(Clone)]
pub enum TransportAddr {
WSS {
tls: TlsClientConfig,
host: Host,
port: u16,
},
WS {
host: Host,
port: u16,
},
HTTPS {
tls: TlsClientConfig,
host: Host,
port: u16,
},
HTTP {
host: Host,
port: u16,
},
}
impl Debug for TransportAddr {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{}://{}:{}", self.scheme_name(), self.host(), self.port()))
}
}
impl TransportAddr {
pub fn from_str(scheme: &str, host: Host, port: u16, tls: Option<TlsClientConfig>) -> Option<Self> {
match scheme {
"https" => {
let Some(tls) = tls else { return None };
Some(TransportAddr::HTTPS { tls, host, port })
}
"http" => Some(TransportAddr::HTTP { host, port }),
"wss" => {
let Some(tls) = tls else { return None };
Some(TransportAddr::WSS { tls, host, port })
}
"ws" => Some(TransportAddr::WS { host, port }),
_ => None,
}
}
pub fn is_websocket(&self) -> bool {
matches!(self, TransportAddr::WS { .. } | TransportAddr::WSS { .. })
}
pub fn is_http2(&self) -> bool {
matches!(self, TransportAddr::HTTP { .. } | TransportAddr::HTTPS { .. })
}
pub fn tls(&self) -> Option<&TlsClientConfig> {
match self {
TransportAddr::WSS { tls, .. } => Some(tls),
TransportAddr::HTTPS { tls, .. } => Some(tls),
TransportAddr::WS { .. } => None,
TransportAddr::HTTP { .. } => None,
}
}
pub fn host(&self) -> &Host {
match self {
TransportAddr::WSS { host, .. } => host,
TransportAddr::WS { host, .. } => host,
TransportAddr::HTTPS { host, .. } => host,
TransportAddr::HTTP { host, .. } => host,
}
}
pub fn port(&self) -> u16 {
match self {
TransportAddr::WSS { port, .. } => *port,
TransportAddr::WS { port, .. } => *port,
TransportAddr::HTTPS { port, .. } => *port,
TransportAddr::HTTP { port, .. } => *port,
}
}
pub fn scheme_name(&self) -> &str {
match self {
TransportAddr::WSS { .. } => "wss",
TransportAddr::WS { .. } => "ws",
TransportAddr::HTTPS { .. } => "https",
TransportAddr::HTTP { .. } => "http",
}
}
}
impl TryFrom<JwtTunnelConfig> for RemoteAddr {
type Error = anyhow::Error;
fn try_from(jwt: JwtTunnelConfig) -> anyhow::Result<Self> {
@ -150,22 +242,35 @@ impl ManageConnection for WsClientConfig {
#[instrument(level = "trace", name = "cnx_server", skip_all)]
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
let (host, port) = &self.remote_addr;
let so_mark = self.socket_so_mark;
let timeout = self.timeout_connect;
let tcp_stream = if let Some(http_proxy) = &self.http_proxy {
tcp::connect_with_http_proxy(http_proxy, host, *port, so_mark, timeout, &self.dns_resolver).await?
tcp::connect_with_http_proxy(
http_proxy,
self.remote_addr.host(),
self.remote_addr.port(),
so_mark,
timeout,
&self.dns_resolver,
)
.await?
} else {
tcp::connect(host, *port, so_mark, timeout, &self.dns_resolver).await?
tcp::connect(
self.remote_addr.host(),
self.remote_addr.port(),
so_mark,
timeout,
&self.dns_resolver,
)
.await?
};
match &self.tls {
None => Ok(Some(TransportStream::Plain(tcp_stream))),
Some(tls_cfg) => {
let tls_stream = tls::connect(self, tls_cfg, tcp_stream).await?;
Ok(Some(TransportStream::Tls(tls_stream)))
}
if self.remote_addr.tls().is_some() {
let tls_stream = tls::connect(self, tcp_stream).await?;
Ok(Some(TransportStream::Tls(tls_stream)))
} else {
Ok(Some(TransportStream::Plain(tcp_stream)))
}
}