diff --git a/src/db.rs b/src/db.rs index 3fa8d5d..1666502 100644 --- a/src/db.rs +++ b/src/db.rs @@ -5,6 +5,7 @@ use tokio_postgres::{Client, Row}; const ENDPOINT_TABLE: &str = "endpoints"; const HOSTS_RELATION_TABLE: &str = "hosts"; +const CERTIFICATES_TABLE: &str = "certs"; #[derive(Debug)] pub struct BoxyDatabase { @@ -18,6 +19,86 @@ pub struct Endpoint { pub callback: String, } +pub struct Certificate { + pub hostname: String, + pub cert_data: Vec, + pub key_data: Vec, +} + +impl Certificate { + pub async fn new(hostname: String, cert_data: Vec, key_data: Vec) -> Self { + Self { + hostname, + cert_data, + key_data, + } + } + pub async fn get_by_hostname( + db: &BoxyDatabase, + hostname: String, + ) -> Result> { + let row = db + .client + .query_one( + format!( + "SELECT * FROM {HOSTS_RELATION_TABLE} + WHERE hostname = $1" + ) + .as_str(), + &[&hostname], + ) + .await?; + + Ok(row.into()) + } + pub async fn get_all(db: &BoxyDatabase) -> Result, Box> { + let mut result: Vec = Vec::new(); + + let rows = db + .client + .query(format!("SELECT * FROM {CERTIFICATES_TABLE}").as_str(), &[]) + .await?; + + for row in rows { + result.push(row.into()); + } + + Ok(result) + } +} + +impl Certificate { + pub async fn delete(self, db: &mut BoxyDatabase) -> Result<(), tokio_postgres::Error> { + let tx = db.client.transaction().await?; + + tx.execute( + format!( + "DELETE FROM {CERTIFICATES_TABLE} + WHERE hostname = $1" + ) + .as_str(), + &[&self.hostname], + ) + .await?; + + tx.commit().await?; + + warn!("Removed certificate for host {}", self.hostname); + + Ok(()) + } +} + +impl From for Certificate { + fn from(value: Row) -> Self { + Self { + hostname: value.get("hostname"), + cert_data: value.get("certificate"), + key_data: value.get("key"), + } + } +} + impl Endpoint { pub async fn new(id: Option, address: IpAddr, port: u16, callback: String) -> Self { Self { @@ -119,7 +200,7 @@ impl Endpoint { let id = self.id.unwrap() as i32; tx.execute( - format!("DELETE FROM {ENDPOINT_TABLE} where id = $1").as_str(), + format!("DELETE FROM {ENDPOINT_TABLE} WHERE id = $1").as_str(), &[&id], ) .await?; @@ -205,6 +286,21 @@ impl BoxyDatabase { ) .await?; + c.execute( + format!( + "CREATE TABLE IF NOT EXISTS {CERTIFICATES_TABLE} + ( + hostname text PRIMARY KEY, + certificate bytea, + key bytea + ) + " + ) + .as_str(), + &[], + ) + .await?; + Ok(BoxyDatabase { client: c }) } } diff --git a/src/main.rs b/src/main.rs index abb24be..9967716 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,7 @@ use log::{debug, error, info}; use matchers::api::ApiMatcher; use server::Server; use services::{controller::ControllerService, matcher::Matcher}; -use tls::TlsOption; +use tls::{TlsManager, TlsOption}; use tokio::{ sync::Mutex, time::{self}, @@ -72,6 +72,7 @@ async fn main() -> Result<(), Box> { }); info!("Connected to database."); + let database = Box::new(BoxyDatabase::new(client).await.unwrap()); @@ -87,7 +88,7 @@ async fn main() -> Result<(), Box> { .await .unwrap(); - let proxy_server = Server::new(svc, (config.proxy.listen, config.proxy.port), TlsOption::NoTls) + let proxy_server = Server::new(svc, (config.proxy.listen, config.proxy.port), TlsOption::Tls(TlsManager.clone)) .await .unwrap(); diff --git a/src/matchers/api.rs b/src/matchers/api.rs index 10ad87c..fe24cab 100644 --- a/src/matchers/api.rs +++ b/src/matchers/api.rs @@ -11,7 +11,7 @@ use tokio::{net::TcpStream, sync::Mutex}; use crate::{ config::{Client, Config}, db::BoxyDatabase, - routes::api::{AddHost, RegisterEndpoint, RemoveHost}, + routes::api::{AddHostToEndpoint, RegisterEndpoint, RemoveHost}, server::{GeneralResponse, custom_resp}, services::matcher::Matcher, }; @@ -119,7 +119,7 @@ impl Matcher for ApiMatcher { fn retrieve(&self) -> Vec + Sync + Send>> { vec![ Arc::new(RegisterEndpoint {}), - Arc::new(AddHost {}), + Arc::new(AddHostToEndpoint {}), Arc::new(RemoveHost {}), ] } diff --git a/src/routes/api.rs b/src/routes/api.rs index 8461fb4..bba2248 100644 --- a/src/routes/api.rs +++ b/src/routes/api.rs @@ -10,10 +10,14 @@ use crate::{ services::matcher::Route, }; -pub struct AddHost {} +pub struct LinkHost {} +pub struct UnlinkHost {} +pub struct GetHostStatus {} +pub struct RegisterEndpoint {} +pub struct DeregisterEndpoint {} #[async_trait] -impl Route for AddHost { +impl Route for LinkHost { fn matcher(&self, _: &ApiMatcher, req: &hyper::Request) -> bool { req.uri().path().starts_with("/endpoint/") && req.method() == Method::POST } @@ -63,8 +67,6 @@ impl Route for AddHost { } } -pub struct RegisterEndpoint {} - #[async_trait] impl Route for RegisterEndpoint { fn matcher(&self, _: &ApiMatcher, req: &hyper::Request) -> bool { @@ -113,10 +115,9 @@ impl Route for RegisterEndpoint { } } -pub struct RemoveHost {} #[async_trait] -impl Route for RemoveHost { +impl Route for DeregisterEndpoint { fn matcher(&self, _: &ApiMatcher, req: &hyper::Request) -> bool { req.uri().path().starts_with("/endpoint/") && req.method() == Method::DELETE } diff --git a/src/server.rs b/src/server.rs index e797ba0..db99e76 100644 --- a/src/server.rs +++ b/src/server.rs @@ -11,7 +11,12 @@ use hyper::{ use hyper_util::rt::TokioIo; use json::JsonValue; use log::{error, info}; -use rustls::server::Acceptor; +use rustls::{ + ConfigBuilder, ServerConfig, + pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject}, + server::Acceptor, + sign::SingleCertAndKey, +}; use tokio::net::{TcpListener, TcpStream}; use tokio_rustls::{LazyConfigAcceptor, StartHandshake}; @@ -73,7 +78,7 @@ where S::ResBody: Send, >::Future: Send, { - pub async fn handle(&self) { + pub async fn handle(&'static self) { info!( "Server started at http://{} for service: {}", self.listener.local_addr().unwrap(), @@ -84,11 +89,10 @@ where let (tcp_stream, _) = self.listener.accept().await.unwrap(); let mut svc_clone = self.service.clone(); - let tls = self.tls.clone(); tokio::task::spawn(async move { svc_clone.stream(&tcp_stream); - match tls { + match &self.tls { TlsOption::NoTls => { if let Err(err) = http1::Builder::new() .writev(false) @@ -103,14 +107,31 @@ where match acceptor.await { Ok(y) => { + let mut manager = x.lock().await; let hello = y.client_hello(); - let hostname = hello.server_name().clone().unwrap(); - let config = Arc::new(x.matcher(hostname).unwrap()); - let stream = y - .into_stream(config) - .await + let hostname = hello.server_name().unwrap(); + + let raw_certificate = + manager.get_certificate(hostname.clone()).await.unwrap(); + + let cert_chain = CertificateDer::pem_slice_iter( + raw_certificate.cert_data.as_slice(), + ) + .map(|cert| cert.unwrap()) + .collect(); + + let key = PrivateKeyDer::from_pem_slice( + raw_certificate.key_data.as_slice(), + ) + .unwrap(); + + let config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(cert_chain, key) .unwrap(); + let stream = y.into_stream(Arc::new(config)).await.unwrap(); + if let Err(err) = http1::Builder::new() .writev(false) .serve_connection(TokioIo::new(stream), svc_clone) @@ -119,10 +140,7 @@ where error!("Error while trying to serve connection: {err}") } } - Err(e) => { - error!("Error while initiating handshake: {e}"); - return; - } + Err(e) => error!("Error while initiating handshake: {e}"), } } }; diff --git a/src/tls.rs b/src/tls.rs index e2b4f9c..e45f7b1 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,48 +1,66 @@ -use std::{ - error::{self, Error}, - fs::File, - io, - path::{self, Path}, - sync::Arc, +use std::{collections::HashMap, error::Error, sync::Arc}; + +use rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject}; +use tokio::{ + fs::{self, File}, + sync::Mutex, }; -use rustls::{ - ServerConfig, crypto, - pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject}, - sign::CertifiedKey, -}; +use crate::db::BoxyDatabase; -#[derive(Clone)] pub enum TlsOption { NoTls, - Tls(FileTls), + Tls(Mutex), } -#[derive(Clone)] -pub struct FileTls { + +pub struct TlsManager { pub certs_path: String, + pub certificates: HashMap, + pub database: Arc>, } -impl FileTls { - pub fn matcher(&self, hostname: &str) -> Result> { +impl RawCertificate { + pub fn new(cert_data: Vec, key_data: Vec) -> Self { + Self { + cert_data, + key_data, + } + } +} + +impl TlsManager { + pub async fn get_certificate( + &mut self, + hostname: &str, + ) -> Result<&RawCertificate, Box> { + if self.certificates.contains_key(hostname) { + return Ok(self.certificates.get(hostname).unwrap()); + } + let path_to_pem = safe_path::scoped_join(self.certs_path.clone(), format!("{hostname}.pem"))?; let path_to_key = safe_path::scoped_join(self.certs_path.clone(), format!("{hostname}.key"))?; - let certfile = File::open(path_to_pem)?; - let mut cert_reader = io::BufReader::new(certfile); - let certs = rustls_pemfile::certs(&mut cert_reader) - .map(|x| x.unwrap()) - .collect(); + let cert_file = fs::read(path_to_pem).await.unwrap(); + let key_file = fs::read(path_to_key).await.unwrap(); - let keyfile = File::open(path_to_key)?; - let mut key_reader = io::BufReader::new(keyfile); - let key = rustls_pemfile::private_key(&mut key_reader).map(|key| key.unwrap())?; + self.certificates.insert( + hostname.to_string(), + RawCertificate::new(cert_file, key_file), + ); - Ok(ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(certs, key) - .map_err(|e| e)?) + // fucking borrow checker + Ok(self.certificates.get(hostname).unwrap()) + } +} + +impl TlsManager { + pub async fn new(path: String) -> Self { + Self { + certs_path: path, + certificates: HashMap::new(), + } } }