Compare commits

...
Sign in to create a new pull request.

2 commits
dev ... tls

Author SHA1 Message Date
hex
663d508093 temp 2025-08-10 15:44:37 +02:00
7757ef32f4 temp 2025-08-07 20:08:42 +02:00
8 changed files with 940 additions and 40 deletions

701
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -25,4 +25,9 @@ ansi_colours = "1.2.3"
colour = "2.1.0" colour = "2.1.0"
async-trait = "0.1.88" async-trait = "0.1.88"
http = "1.3.1" http = "1.3.1"
instant-acme = "0.8.2"
rustls = "0.23.31"
safe-path = "0.1.0"
rustls-pemfile = "2.2.0"
tokio-rustls = "0.26.2"

View file

@ -5,6 +5,7 @@ use tokio_postgres::{Client, Row};
const ENDPOINT_TABLE: &str = "endpoints"; const ENDPOINT_TABLE: &str = "endpoints";
const HOSTS_RELATION_TABLE: &str = "hosts"; const HOSTS_RELATION_TABLE: &str = "hosts";
const CERTIFICATES_TABLE: &str = "certs";
#[derive(Debug)] #[derive(Debug)]
pub struct BoxyDatabase { pub struct BoxyDatabase {
@ -18,6 +19,86 @@ pub struct Endpoint {
pub callback: String, pub callback: String,
} }
pub struct Certificate {
pub hostname: String,
pub cert_data: Vec<u8>,
pub key_data: Vec<u8>,
}
impl Certificate {
pub async fn new(hostname: String, cert_data: Vec<u8>, key_data: Vec<u8>) -> Self {
Self {
hostname,
cert_data,
key_data,
}
}
pub async fn get_by_hostname(
db: &BoxyDatabase,
hostname: String,
) -> Result<Self, Box<dyn Error>> {
let row = db
.client
.query_one(
format!(
"SELECT * FROM {HOSTS_RELATION_TABLE}
WHERE hostname = $1"
)
.as_str(),
&[&hostname],
)
.await?;
Ok(row.into())
}
pub async fn get_all(db: &BoxyDatabase) -> Result<Vec<Self>, Box<dyn Error>> {
let mut result: Vec<Self> = Vec::new();
let rows = db
.client
.query(format!("SELECT * FROM {CERTIFICATES_TABLE}").as_str(), &[])
.await?;
for row in rows {
result.push(row.into());
}
Ok(result)
}
}
impl Certificate {
pub async fn delete(self, db: &mut BoxyDatabase) -> Result<(), tokio_postgres::Error> {
let tx = db.client.transaction().await?;
tx.execute(
format!(
"DELETE FROM {CERTIFICATES_TABLE}
WHERE hostname = $1"
)
.as_str(),
&[&self.hostname],
)
.await?;
tx.commit().await?;
warn!("Removed certificate for host {}", self.hostname);
Ok(())
}
}
impl From<Row> for Certificate {
fn from(value: Row) -> Self {
Self {
hostname: value.get("hostname"),
cert_data: value.get("certificate"),
key_data: value.get("key"),
}
}
}
impl Endpoint { impl Endpoint {
pub async fn new(id: Option<u32>, address: IpAddr, port: u16, callback: String) -> Self { pub async fn new(id: Option<u32>, address: IpAddr, port: u16, callback: String) -> Self {
Self { Self {
@ -119,7 +200,7 @@ impl Endpoint {
let id = self.id.unwrap() as i32; let id = self.id.unwrap() as i32;
tx.execute( tx.execute(
format!("DELETE FROM {ENDPOINT_TABLE} where id = $1").as_str(), format!("DELETE FROM {ENDPOINT_TABLE} WHERE id = $1").as_str(),
&[&id], &[&id],
) )
.await?; .await?;
@ -205,6 +286,21 @@ impl BoxyDatabase {
) )
.await?; .await?;
c.execute(
format!(
"CREATE TABLE IF NOT EXISTS {CERTIFICATES_TABLE}
(
hostname text PRIMARY KEY,
certificate bytea,
key bytea
)
"
)
.as_str(),
&[],
)
.await?;
Ok(BoxyDatabase { client: c }) Ok(BoxyDatabase { client: c })
} }
} }

View file

@ -5,6 +5,7 @@ mod matchers;
mod routes; mod routes;
mod server; mod server;
mod services; mod services;
mod tls;
use std::{env, process::exit, sync::Arc, time::Duration}; use std::{env, process::exit, sync::Arc, time::Duration};
@ -15,6 +16,7 @@ use log::{debug, error, info};
use matchers::api::ApiMatcher; use matchers::api::ApiMatcher;
use server::Server; use server::Server;
use services::{controller::ControllerService, matcher::Matcher}; use services::{controller::ControllerService, matcher::Matcher};
use tls::{TlsManager, TlsOption};
use tokio::{ use tokio::{
sync::Mutex, sync::Mutex,
time::{self}, time::{self},
@ -70,6 +72,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
}); });
info!("Connected to database."); info!("Connected to database.");
let database = Box::new(BoxyDatabase::new(client).await.unwrap()); let database = Box::new(BoxyDatabase::new(client).await.unwrap());
@ -81,11 +84,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
database: database_shared.clone(), database: database_shared.clone(),
}; };
let api_server = Server::new(api_matcher.service(), (config.api.listen, config.api.port)) let api_server = Server::new(api_matcher.service(), (config.api.listen, config.api.port), TlsOption::NoTls)
.await .await
.unwrap(); .unwrap();
let proxy_server = Server::new(svc, (config.proxy.listen, config.proxy.port)) let proxy_server = Server::new(svc, (config.proxy.listen, config.proxy.port), TlsOption::Tls(TlsManager.clone))
.await .await
.unwrap(); .unwrap();

View file

@ -11,7 +11,7 @@ use tokio::{net::TcpStream, sync::Mutex};
use crate::{ use crate::{
config::{Client, Config}, config::{Client, Config},
db::BoxyDatabase, db::BoxyDatabase,
routes::api::{AddHost, RegisterEndpoint, RemoveHost}, routes::api::{AddHostToEndpoint, RegisterEndpoint, RemoveHost},
server::{GeneralResponse, custom_resp}, server::{GeneralResponse, custom_resp},
services::matcher::Matcher, services::matcher::Matcher,
}; };
@ -119,7 +119,7 @@ impl Matcher for ApiMatcher {
fn retrieve(&self) -> Vec<Arc<dyn crate::services::matcher::Route<Self> + Sync + Send>> { fn retrieve(&self) -> Vec<Arc<dyn crate::services::matcher::Route<Self> + Sync + Send>> {
vec![ vec![
Arc::new(RegisterEndpoint {}), Arc::new(RegisterEndpoint {}),
Arc::new(AddHost {}), Arc::new(AddHostToEndpoint {}),
Arc::new(RemoveHost {}), Arc::new(RemoveHost {}),
] ]
} }

View file

@ -10,10 +10,14 @@ use crate::{
services::matcher::Route, services::matcher::Route,
}; };
pub struct AddHost {} pub struct LinkHost {}
pub struct UnlinkHost {}
pub struct GetHostStatus {}
pub struct RegisterEndpoint {}
pub struct DeregisterEndpoint {}
#[async_trait] #[async_trait]
impl Route<ApiMatcher> for AddHost { impl Route<ApiMatcher> for LinkHost {
fn matcher(&self, _: &ApiMatcher, req: &hyper::Request<hyper::body::Incoming>) -> bool { fn matcher(&self, _: &ApiMatcher, req: &hyper::Request<hyper::body::Incoming>) -> bool {
req.uri().path().starts_with("/endpoint/") && req.method() == Method::POST req.uri().path().starts_with("/endpoint/") && req.method() == Method::POST
} }
@ -63,8 +67,6 @@ impl Route<ApiMatcher> for AddHost {
} }
} }
pub struct RegisterEndpoint {}
#[async_trait] #[async_trait]
impl Route<ApiMatcher> for RegisterEndpoint { impl Route<ApiMatcher> for RegisterEndpoint {
fn matcher(&self, _: &ApiMatcher, req: &hyper::Request<hyper::body::Incoming>) -> bool { fn matcher(&self, _: &ApiMatcher, req: &hyper::Request<hyper::body::Incoming>) -> bool {
@ -113,10 +115,9 @@ impl Route<ApiMatcher> for RegisterEndpoint {
} }
} }
pub struct RemoveHost {}
#[async_trait] #[async_trait]
impl Route<ApiMatcher> for RemoveHost { impl Route<ApiMatcher> for DeregisterEndpoint {
fn matcher(&self, _: &ApiMatcher, req: &hyper::Request<hyper::body::Incoming>) -> bool { fn matcher(&self, _: &ApiMatcher, req: &hyper::Request<hyper::body::Incoming>) -> bool {
req.uri().path().starts_with("/endpoint/") && req.method() == Method::DELETE req.uri().path().starts_with("/endpoint/") && req.method() == Method::DELETE
} }

View file

@ -1,16 +1,26 @@
use std::{any::type_name_of_val, error::Error}; use std::{any::type_name_of_val, collections::HashMap, error::Error, sync::Arc};
use http_body_util::{Either, Full}; use http_body_util::{Either, Full};
use hyper::{ use hyper::{
Request, Response, StatusCode, Request, Response, StatusCode,
body::{Body, Bytes, Incoming}, body::{Body, Bytes, Incoming},
rt::{Read, Write},
server::conn::http1, server::conn::http1,
service::{HttpService, Service}, service::{HttpService, Service},
}; };
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use json::JsonValue; use json::JsonValue;
use log::{error, info}; use log::{error, info};
use rustls::{
ConfigBuilder, ServerConfig,
pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
server::Acceptor,
sign::SingleCertAndKey,
};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::{LazyConfigAcceptor, StartHandshake};
use crate::tls::TlsOption;
pub type GeneralResponse = Response<GeneralBody>; pub type GeneralResponse = Response<GeneralBody>;
pub type GeneralBody = Either<Incoming, Full<Bytes>>; pub type GeneralBody = Either<Incoming, Full<Bytes>>;
@ -23,6 +33,7 @@ pub fn to_general_response(res: Response<Incoming>) -> GeneralResponse {
pub struct Server<S> { pub struct Server<S> {
listener: TcpListener, listener: TcpListener,
service: S, service: S,
tls: TlsOption,
} }
pub trait TcpIntercept { pub trait TcpIntercept {
@ -59,7 +70,7 @@ pub async fn json_to_vec(v: JsonValue) -> Option<Vec<String>> {
impl<S> Server<S> impl<S> Server<S>
where where
S: TcpIntercept, S: TcpIntercept + Sync,
S: Service<Request<Incoming>> + Clone + Send + 'static, S: Service<Request<Incoming>> + Clone + Send + 'static,
S: HttpService<Incoming> + Clone + Send, S: HttpService<Incoming> + Clone + Send,
<S::ResBody as Body>::Error: Into<Box<dyn Error + Send + Sync>>, <S::ResBody as Body>::Error: Into<Box<dyn Error + Send + Sync>>,
@ -67,7 +78,7 @@ where
S::ResBody: Send, S::ResBody: Send,
<S as HttpService<Incoming>>::Future: Send, <S as HttpService<Incoming>>::Future: Send,
{ {
pub async fn handle(&self) { pub async fn handle(&'static self) {
info!( info!(
"Server started at http://{} for service: {}", "Server started at http://{} for service: {}",
self.listener.local_addr().unwrap(), self.listener.local_addr().unwrap(),
@ -75,31 +86,76 @@ where
); );
loop { loop {
let (stream, _) = self.listener.accept().await.unwrap(); let (tcp_stream, _) = self.listener.accept().await.unwrap();
let mut svc_clone = self.service.clone(); let mut svc_clone = self.service.clone();
svc_clone.stream(&stream);
let io = TokioIo::new(stream);
tokio::task::spawn(async move { tokio::task::spawn(async move {
if let Err(err) = http1::Builder::new() svc_clone.stream(&tcp_stream);
.writev(false)
.serve_connection(io, svc_clone) match &self.tls {
.await TlsOption::NoTls => {
{ if let Err(err) = http1::Builder::new()
error!("Error while trying to serve connection: {err}") .writev(false)
.serve_connection(TokioIo::new(tcp_stream), svc_clone)
.await
{
error!("Error while trying to serve connection: {err}")
};
}
TlsOption::Tls(x) => {
let acceptor = LazyConfigAcceptor::new(Acceptor::default(), tcp_stream);
match acceptor.await {
Ok(y) => {
let mut manager = x.lock().await;
let hello = y.client_hello();
let hostname = hello.server_name().unwrap();
let raw_certificate =
manager.get_certificate(hostname.clone()).await.unwrap();
let cert_chain = CertificateDer::pem_slice_iter(
raw_certificate.cert_data.as_slice(),
)
.map(|cert| cert.unwrap())
.collect();
let key = PrivateKeyDer::from_pem_slice(
raw_certificate.key_data.as_slice(),
)
.unwrap();
let config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert_chain, key)
.unwrap();
let stream = y.into_stream(Arc::new(config)).await.unwrap();
if let Err(err) = http1::Builder::new()
.writev(false)
.serve_connection(TokioIo::new(stream), svc_clone)
.await
{
error!("Error while trying to serve connection: {err}")
}
}
Err(e) => error!("Error while initiating handshake: {e}"),
}
}
}; };
}); });
} }
} }
pub async fn new(service: S, a: (String, u16)) -> Result<Self, Box<dyn Error>> { pub async fn new(service: S, a: (String, u16), tls: TlsOption) -> Result<Self, Box<dyn Error>> {
Ok(Self { Ok(Self {
listener: TcpListener::bind(&a).await?, listener: TcpListener::bind(&a).await?,
service, service,
tls,
}) })
} }
} }
/* /*
*/ */

66
src/tls.rs Normal file
View file

@ -0,0 +1,66 @@
use std::{collections::HashMap, error::Error, sync::Arc};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};
use tokio::{
fs::{self, File},
sync::Mutex,
};
use crate::db::BoxyDatabase;
pub enum TlsOption {
NoTls,
Tls(Mutex<TlsManager>),
}
pub struct TlsManager {
pub certs_path: String,
pub certificates: HashMap<String, RawCertificate>,
pub database: Arc<Mutex<&'static mut BoxyDatabase>>,
}
impl RawCertificate {
pub fn new(cert_data: Vec<u8>, key_data: Vec<u8>) -> Self {
Self {
cert_data,
key_data,
}
}
}
impl TlsManager {
pub async fn get_certificate(
&mut self,
hostname: &str,
) -> Result<&RawCertificate, Box<dyn Error>> {
if self.certificates.contains_key(hostname) {
return Ok(self.certificates.get(hostname).unwrap());
}
let path_to_pem =
safe_path::scoped_join(self.certs_path.clone(), format!("{hostname}.pem"))?;
let path_to_key =
safe_path::scoped_join(self.certs_path.clone(), format!("{hostname}.key"))?;
let cert_file = fs::read(path_to_pem).await.unwrap();
let key_file = fs::read(path_to_key).await.unwrap();
self.certificates.insert(
hostname.to_string(),
RawCertificate::new(cert_file, key_file),
);
// fucking borrow checker
Ok(self.certificates.get(hostname).unwrap())
}
}
impl TlsManager {
pub async fn new(path: String) -> Self {
Self {
certs_path: path,
certificates: HashMap::new(),
}
}
}