From f305cb5a85f7bad21010144e7a30a3f1a3e23cdf Mon Sep 17 00:00:00 2001 From: hex Date: Sat, 2 Aug 2025 20:39:13 +0200 Subject: [PATCH 1/2] feat: health checker --- src/health.rs | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 src/health.rs diff --git a/src/health.rs b/src/health.rs new file mode 100644 index 0000000..83cfbc3 --- /dev/null +++ b/src/health.rs @@ -0,0 +1,60 @@ +use http_body_util::Empty; +use hyper::{Request, body::Bytes}; +use hyper_util::rt::TokioIo; +use log::error; +use tokio::net::TcpStream; + +use crate::db::{BoxyDatabase, Endpoint}; + +pub async fn check(db: &mut BoxyDatabase) { + let endpoints = Endpoint::get_all(db).await.unwrap(); + + for endpoint in endpoints { + let address = format!("{}:{}", endpoint.address, endpoint.port); + + let url = format!("http://{}{}", address, endpoint.callback) + .parse::() + .unwrap(); + + let stream = match TcpStream::connect(address).await { + Ok(x) => x, + Err(e) => { + error!("Could not reach endpoint {}: {e}", endpoint.id.unwrap()); + + endpoint.delete(db).await.unwrap(); + + continue; + } + }; + + let io = TokioIo::new(stream); + + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap(); + + tokio::task::spawn(async move { + if let Err(err) = conn.await { + println!("Connection failed: {:?}", err); + } + }); + + let req = Request::builder() + .uri(url) + .body(Empty::::new()) + .unwrap(); + + let res = match sender.send_request(req).await { + Ok(x) => x, + Err(e) => { + error!("Could not reach endpoint {}: {e}", endpoint.id.unwrap()); + + endpoint.delete(db).await.unwrap(); + + continue; + } + }; + + if !res.status().is_success() { + endpoint.delete(db).await.unwrap(); + } + } +} From 638b0376d827ce3a1c7bc423dba915c6124dc628 Mon Sep 17 00:00:00 2001 From: hex Date: Sat, 2 Aug 2025 20:40:44 +0200 Subject: [PATCH 2/2] feat: new matcher system instead of manually 'match'ing in servcie --- Cargo.lock | 2 + Cargo.toml | 2 + examples/example-server/src/main.rs | 21 +++- src/db.rs | 123 ++++++++++++++---- src/main.rs | 31 ++++- src/matchers.rs | 1 + src/matchers/api.rs | 180 +++++++++++++++++++++++++++ src/routes.rs | 1 + src/routes/api.rs | 157 +++++++++++++++++++++++ src/server.rs | 20 ++- src/services.rs | 2 +- src/services/api.rs | 186 ---------------------------- src/services/controller.rs | 14 ++- src/services/matcher.rs | 129 +++++++++++++++++++ src/services/proxy.rs | 2 +- 15 files changed, 646 insertions(+), 225 deletions(-) create mode 100644 src/matchers.rs create mode 100644 src/matchers/api.rs create mode 100644 src/routes.rs create mode 100644 src/routes/api.rs delete mode 100644 src/services/api.rs create mode 100644 src/services/matcher.rs diff --git a/Cargo.lock b/Cargo.lock index cff25dd..d81bd87 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -129,9 +129,11 @@ version = "0.1.0" dependencies = [ "ansi_colours", "anyhow", + "async-trait", "base64", "bcrypt", "colour", + "http", "http-body-util", "hyper", "hyper-util", diff --git a/Cargo.toml b/Cargo.toml index e6bf1d9..229a6f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,4 +23,6 @@ string-builder = "0.2.0" json = "0.12.4" ansi_colours = "1.2.3" colour = "2.1.0" +async-trait = "0.1.88" +http = "1.3.1" diff --git a/examples/example-server/src/main.rs b/examples/example-server/src/main.rs index b19cced..52638b7 100644 --- a/examples/example-server/src/main.rs +++ b/examples/example-server/src/main.rs @@ -24,7 +24,7 @@ fn rocket() -> _ { // We define the port of the server running locally and the hostname we want to route to it. let body = json!({ "port": 8000, - "hostname": "localhost:8005", + "hosts": ["localhost:8005"], }); // Send it to Boxy's API @@ -35,7 +35,24 @@ fn rocket() -> _ { .send() .unwrap(); - println!("{}", res.text().unwrap()); + let id = res.text().unwrap(); + + println!("{}", id); + + let body2 = json!({}); + + // Send it to Boxy's API + let res2 = client + .delete(format!( + "http://{}:{}/endpoint/{}", + BOXY_ADDRESS, BOXY_PORT, id + )) + .basic_auth(CLIENT_NAME, Some(CLIENT_SECRET)) + .json(&body2) + .send() + .unwrap(); + + println!("{}", res2.text().unwrap()); build().mount("/", routes![index]) } diff --git a/src/db.rs b/src/db.rs index 734ab00..e0eb17d 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,9 +1,8 @@ -use std::{ - error::Error, - net::IpAddr, -}; +use std::{error::Error, net::IpAddr}; -use tokio_postgres::Client; +use json::JsonValue; +use log::warn; +use tokio_postgres::{Client, Row, Statement}; const ENDPOINT_TABLE: &str = "endpoints"; const HOSTS_RELATION_TABLE: &str = "hosts"; @@ -14,14 +13,14 @@ pub struct BoxyDatabase { } pub struct Endpoint { - pub id: Option, + pub id: Option, pub address: IpAddr, pub port: u16, pub callback: String, } impl Endpoint { - pub async fn new(id: Option, address: IpAddr, port: u16, callback: String) -> Self { + pub async fn new(id: Option, address: IpAddr, port: u16, callback: String) -> Self { Self { id, address, @@ -48,23 +47,97 @@ impl Endpoint { .await?; for row in rows { - result.push(Self { - id: row.get("id"), - address: row.get("address"), - port: row.get::<&str, i32>("port") as u16, - callback: row.get("callback"), - }); + result.push(row.into()); } Ok(result) } + pub async fn get_all(db: &BoxyDatabase) -> Result, Box> { + let mut result: Vec = Vec::new(); + + let rows = db + .client + .query(format!("SELECT * FROM {ENDPOINT_TABLE}").as_str(), &[]) + .await?; + + for row in rows { + result.push(row.into()); + } + + Ok(result) + } + pub async fn get_by_id(db: &BoxyDatabase, id: u32) -> Result, Box> { + let endpoint = db + .client + .query_one( + format!("SELECT * FROM {ENDPOINT_TABLE} WHERE id = $1").as_str(), + &[&(id as i32)], + ) + .await?; + + Ok(Some(endpoint.into())) + } +} + +impl From for Endpoint { + fn from(value: Row) -> Self { + Self { + id: Some(value.get::<&str, i32>("id") as u32), + address: value.get("address"), + port: value.get::<&str, i32>("port") as u16, + callback: value.get("callback"), + } + } } impl Endpoint { + pub async fn host( + &self, + db: &mut BoxyDatabase, + hostnames: Vec, + ) -> Result<(), tokio_postgres::Error> { + let tx = db.client.transaction().await?; + + let statement = tx + .prepare( + format!("INSERT INTO {HOSTS_RELATION_TABLE} (endpoint_id,hostname) VALUES ($1,$2)") + .as_str(), + ) + .await?; + + for host in hostnames { + tx.execute(&statement, &[&(self.id.unwrap() as i32), &host]) + .await?; + } + + tx.commit().await?; + + Ok(()) + } + pub async fn delete(self, db: &mut BoxyDatabase) -> Result<(), tokio_postgres::Error> { + let tx = db.client.transaction().await?; + + let id = self.id.unwrap() as i32; + + tx.execute( + format!("DELETE FROM {ENDPOINT_TABLE} where id = $1").as_str(), + &[&id], + ) + .await?; + + tx.commit().await?; + + warn!( + "Removed endpoint with ID: {}, address: {}:{}", + id, self.address, self.port + ); + + Ok(()) + } pub async fn register( &mut self, db: &mut BoxyDatabase, - hostname: String, + hostnames: Vec, ) -> Result<(), tokio_postgres::Error> { let tx = db.client.transaction().await?; @@ -81,16 +154,20 @@ impl Endpoint { .await? .get("id"); - tx.execute( - format!("INSERT INTO {HOSTS_RELATION_TABLE} (endpoint_id,hostname) VALUES ($1,$2)") - .as_str(), - &[&endpoint_id, &hostname], - ) - .await?; + let statement = tx + .prepare( + format!("INSERT INTO {HOSTS_RELATION_TABLE} (endpoint_id,hostname) VALUES ($1,$2)") + .as_str(), + ) + .await?; + + for host in hostnames { + tx.execute(&statement, &[&endpoint_id, &host]).await?; + } tx.commit().await?; - self.id = Some(endpoint_id); + self.id = Some(endpoint_id as u32); Ok(()) } @@ -118,8 +195,8 @@ impl BoxyDatabase { format!( "CREATE TABLE IF NOT EXISTS {HOSTS_RELATION_TABLE} ( - id SERIAL PRIMARY KEY, - endpoint_id int REFERENCES {HOSTS_RELATION_TABLE}(id), + id serial PRIMARY KEY, + endpoint_id int REFERENCES {ENDPOINT_TABLE}(id) ON DELETE CASCADE, hostname text ) " diff --git a/src/main.rs b/src/main.rs index 666d87a..671a19f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,17 +1,27 @@ mod config; mod db; +mod health; +mod matchers; +mod routes; mod server; mod services; -use std::{env, process::exit, sync::Arc}; +use std::{env, process::exit, sync::Arc, time::Duration}; use bcrypt::DEFAULT_COST; use config::Config; use db::BoxyDatabase; use log::{debug, error, info}; +use matchers::api::ApiMatcher; use server::Server; -use services::{api::ApiService, controller::ControllerService}; -use tokio::sync::Mutex; +use services::{ + controller::ControllerService, + matcher::{Matcher, MatcherService}, +}; +use tokio::{ + sync::Mutex, + time::{self, interval}, +}; use tokio_postgres::NoTls; const VERSION: &str = "v0.1a"; @@ -68,13 +78,13 @@ async fn main() -> Result<(), Box> { let database_shared = Arc::new(Mutex::new(Box::leak(database))); - let api_svc = ApiService::new(database_shared.clone(), config.clone()).await; + let api_matcher = ApiMatcher::new(database_shared.clone(), config.clone()).await; let svc = ControllerService { - database: database_shared, + database: database_shared.clone(), }; - let api_server = Server::new(api_svc, (config.api.listen, config.api.port)) + let api_server = Server::new(api_matcher.service(), (config.api.listen, config.api.port)) .await .unwrap(); @@ -87,6 +97,15 @@ async fn main() -> Result<(), Box> { api_server.handle().await; }); + tokio::task::spawn(async move { + let mut interval = time::interval(Duration::from_secs(30)); + + loop { + health::check(*database_shared.lock().await).await; + interval.tick().await; + } + }); + // We don't put this on a separate thread because we'd be wasting the main thread. info!("Starting proxy server..."); proxy_server.handle().await; diff --git a/src/matchers.rs b/src/matchers.rs new file mode 100644 index 0000000..e5fdf85 --- /dev/null +++ b/src/matchers.rs @@ -0,0 +1 @@ +pub mod api; diff --git a/src/matchers/api.rs b/src/matchers/api.rs new file mode 100644 index 0000000..c8c108a --- /dev/null +++ b/src/matchers/api.rs @@ -0,0 +1,180 @@ +use std::{net::IpAddr, pin::Pin, sync::Arc}; + +use async_trait::async_trait; +use base64::{Engine, prelude::BASE64_STANDARD}; +use http::request::Parts; +use http_body_util::BodyExt; +use hyper::{ + Method, Request, StatusCode, + body::{self, Incoming}, + service::Service, +}; +use json::JsonValue; +use log::{debug, error, warn}; +use tokio::{net::TcpStream, sync::Mutex}; + +use crate::{ + config::{Client, Config}, + db::{BoxyDatabase, Endpoint}, + routes::api::{AddHost, RegisterEndpoint, RemoveHost}, + server::{GeneralResponse, TcpIntercept, custom_resp, default_response, json_to_vec}, + services::matcher::{Matcher, MatcherService}, +}; + +#[derive(Debug, Clone)] +pub struct ApiMatcher { + pub database: Arc>, + pub config: Config, + pub _address: Option, + pub body: Option, +} + +impl ApiMatcher { + pub async fn new(database: Arc>, config: Config) -> Self { + Self { + database, + config, + _address: None, + body: None, + } + } +} + +#[async_trait] +impl Matcher for ApiMatcher { + async fn unimatch( + &mut self, + req: &Request, + ) -> (bool, Option>) { + let address = self._address.unwrap(); + + let encoded_header = match req.headers().get(hyper::header::AUTHORIZATION) { + None => { + error!("Authorization header not given for request from {address}",); + + return ( + false, + Some(Ok(custom_resp( + StatusCode::BAD_REQUEST, + "Invalid credentials.".to_string(), + ) + .await)), + ); + } + Some(x) => x, + } + .to_str() + .unwrap(); + + debug!("authorization header: {}", encoded_header); + + let auth_bytes = match BASE64_STANDARD.decode(&encoded_header[6..]) { + Ok(x) => x, + Err(e) => { + error!("Error while decoding authorization header from {address}: {e}",); + + return ( + false, + Some(Ok(custom_resp( + StatusCode::BAD_REQUEST, + "Invalid base64 string given.".to_string(), + ) + .await)), + ); + } + }; + + let auth_string = match String::from_utf8(auth_bytes) { + Ok(x) => x, + Err(e) => { + error!("Error while decoding authorization header from {address}: {e}",); + + return ( + false, + Some(Ok(custom_resp( + StatusCode::BAD_REQUEST, + "Invalid UTF-8 in body.".to_string(), + ) + .await)), + ); + } + }; + + debug!("decoded auth string: {}", auth_string); + + if !Client::verify(auth_string.clone(), self.config.clone()).await { + warn!( + "Authentication for string {} from {} failed.", + auth_string, address + ); + + return ( + false, + Some(Ok(custom_resp( + StatusCode::UNAUTHORIZED, + "Invalid credentials.".to_string(), + ) + .await)), + ); + } + + return (true, None); + } + + fn retrieve(&self) -> Vec + Sync + Send>> { + vec![ + Arc::new(RegisterEndpoint {}), + Arc::new(AddHost {}), + Arc::new(RemoveHost {}), + ] + } + + fn stream(&mut self, stream: &TcpStream) { + self._address = Some(stream.peer_addr().unwrap().ip()); + } + + async fn body(&mut self, body: Incoming) -> Option> { + let address = self._address.unwrap(); + + let body_string = match String::from_utf8( + body.collect() + .await + .unwrap() + .to_bytes() + .iter() + .cloned() + .collect::>() + .clone(), + ) { + Ok(x) => x, + Err(e) => { + error!("Error while inferring UTF-8 string from {address}'s request body: {e}",); + + return Some(Ok(custom_resp( + StatusCode::BAD_REQUEST, + "Invalid UTF-8 in body.".to_string(), + ) + .await)); + } + }; + + debug!("body: {}", body_string); + + let json = match json::parse(body_string.as_str()) { + Ok(x) => x, + Err(e) => { + error!("Error while parsing JSON body from {address}: {e}",); + + return Some(Ok(custom_resp( + StatusCode::BAD_REQUEST, + "Invalid JSON in body.".to_string(), + ) + .await)); + } + }; + + self.body = Some(json); + + None + } +} diff --git a/src/routes.rs b/src/routes.rs new file mode 100644 index 0000000..e5fdf85 --- /dev/null +++ b/src/routes.rs @@ -0,0 +1 @@ +pub mod api; diff --git a/src/routes/api.rs b/src/routes/api.rs new file mode 100644 index 0000000..d137093 --- /dev/null +++ b/src/routes/api.rs @@ -0,0 +1,157 @@ +use async_trait::async_trait; +use base64::{Engine, prelude::BASE64_STANDARD}; +use http::request::Parts; +use http_body_util::BodyExt; +use hyper::{Method, StatusCode}; +use log::{debug, error, warn}; + +use crate::{ + config::Client, + db::Endpoint, + matchers::api::ApiMatcher, + server::{custom_resp, json_to_vec}, + services::matcher::Route, +}; + +pub struct AddHost {} + +#[async_trait] +impl Route for AddHost { + fn matcher(&self, _: &ApiMatcher, req: &hyper::Request) -> bool { + req.uri().path().starts_with("/endpoint/") && req.method() == Method::POST + } + + async fn call( + &self, + m: &ApiMatcher, + parts: Parts, + ) -> Result { + let database = m.database.clone(); + let address = m._address.unwrap(); + let body = m.body.clone().unwrap(); + + let endpoint_id: u32 = parts.uri.path().replace("/endpoint/", "").parse().unwrap(); + + if !body["hosts"].is_array() { + error!("Hosts parameter is not an array.",); + + return Ok(custom_resp( + StatusCode::BAD_REQUEST, + "Hosts parameter is not an array.".to_string(), + ) + .await); + } + + let endpoint = match Endpoint::get_by_id(*database.lock().await, endpoint_id) + .await + .unwrap() + { + Some(x) => x, + None => { + error!("No endpoint found by id {endpoint_id} from {address}",); + + return Ok(custom_resp( + StatusCode::NOT_FOUND, + "No endpoint by that ID.".to_string(), + ) + .await); + } + }; + + let hosts = json_to_vec(body["hosts"].clone()).await.unwrap(); + + endpoint.host(*database.lock().await, hosts).await.unwrap(); + + Ok(custom_resp(StatusCode::OK, "Success".to_string()).await) + } +} + +pub struct RegisterEndpoint {} + +#[async_trait] +impl Route for RegisterEndpoint { + fn matcher(&self, _: &ApiMatcher, req: &hyper::Request) -> bool { + req.uri().path() == "/register" && req.method() == Method::POST + } + + async fn call( + &self, + m: &ApiMatcher, + _: Parts, + ) -> Result { + let address = m._address.unwrap(); + let database = m.database.clone(); + let body = m.body.clone().unwrap(); + + let mut endpoint = Endpoint::new( + None, + address, + body["port"].as_u16().unwrap_or(8080), + body["callback"].as_str().unwrap_or("/").to_string(), + ) + .await; + + if !body["hosts"].is_array() { + error!("Hosts parameter is not an array.",); + + return Ok(custom_resp( + StatusCode::BAD_REQUEST, + "Hosts parameter is not an array.".to_string(), + ) + .await); + }; + + let hosts = json_to_vec(body["hosts"].clone()).await.unwrap(); + + endpoint + .register(*database.lock().await, hosts) + .await + .unwrap(); + + let endpoint_id = endpoint.id.unwrap().to_string(); + + let response = custom_resp(StatusCode::OK, endpoint_id).await; + + Ok(response) + } +} + +pub struct RemoveHost {} + +#[async_trait] +impl Route for RemoveHost { + fn matcher(&self, _: &ApiMatcher, req: &hyper::Request) -> bool { + req.uri().path().starts_with("/endpoint/") && req.method() == Method::DELETE + } + + async fn call( + &self, + m: &ApiMatcher, + parts: Parts, + ) -> Result { + let database = m.database.clone(); + let address = m._address.unwrap(); + + let endpoint_id: u32 = parts.uri.path().replace("/endpoint/", "").parse().unwrap(); + + let endpoint = match Endpoint::get_by_id(*database.lock().await, endpoint_id) + .await + .unwrap() + { + Some(x) => x, + None => { + error!("No endpoint found by id {endpoint_id} from {address}",); + + return Ok(custom_resp( + StatusCode::NOT_FOUND, + "No endpoint by that ID.".to_string(), + ) + .await); + } + }; + + endpoint.delete(*database.lock().await).await.unwrap(); + + Ok(custom_resp(StatusCode::OK, "Success".to_string()).await) + } +} diff --git a/src/server.rs b/src/server.rs index 52cad1d..662bdec 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,9 +2,13 @@ use std::{any::type_name_of_val, error::Error}; use http_body_util::{Either, Full}; use hyper::{ - body::{Body, Bytes, Incoming}, server::conn::http1, service::{HttpService, Service}, Request, Response, StatusCode + Request, Response, StatusCode, + body::{Body, Bytes, Incoming}, + server::conn::http1, + service::{HttpService, Service}, }; use hyper_util::rt::TokioIo; +use json::JsonValue; use log::{error, info}; use tokio::net::{TcpListener, TcpStream}; @@ -34,13 +38,25 @@ pub async fn default_response() -> GeneralResponse { .unwrap() } -pub async fn custom_resp(e: StatusCode, m: &'static str) -> GeneralResponse { +pub async fn custom_resp(e: StatusCode, m: String) -> GeneralResponse { Response::builder() .status(e) .body(GeneralBody::Right(Full::from(Bytes::from(m)))) .unwrap() } +pub async fn json_to_vec(v: JsonValue) -> Option> { + if let JsonValue::Array(arr) = v { + Some( + arr.into_iter() + .map(|val| val.as_str().unwrap().to_string()) + .collect(), + ) + } else { + None + } +} + impl Server where S: TcpIntercept, diff --git a/src/services.rs b/src/services.rs index 168dbf8..381d0dc 100644 --- a/src/services.rs +++ b/src/services.rs @@ -1,3 +1,3 @@ -pub mod api; pub mod controller; pub mod proxy; +pub mod matcher; diff --git a/src/services/api.rs b/src/services/api.rs deleted file mode 100644 index ee86312..0000000 --- a/src/services/api.rs +++ /dev/null @@ -1,186 +0,0 @@ -use std::{net::IpAddr, pin::Pin, sync::Arc}; - -use base64::{Engine, prelude::BASE64_STANDARD}; -use http_body_util::BodyExt; -use hyper::{ - Method, Request, StatusCode, - body::Incoming, - service::Service, -}; -use log::{debug, error, warn}; -use tokio::{net::TcpStream, sync::Mutex}; - -use crate::{ - config::{Client, Config}, - db::{BoxyDatabase, Endpoint}, - server::{custom_resp, default_response, GeneralResponse, TcpIntercept}, -}; - -#[derive(Debug, Clone)] -pub struct ApiService { - pub database: Arc>, - pub config: Config, - pub _address: Option, -} - - -impl TcpIntercept for ApiService { - fn stream(&mut self, stream: &TcpStream) { - self._address = Some(stream.peer_addr().unwrap().ip()); - } -} - -impl ApiService { - pub async fn new(database: Arc>, config: Config) -> Self { - Self { - database, - config, - _address: None, - } - } -} - -impl Service> for ApiService { - type Response = GeneralResponse; - type Error = hyper::Error; - type Future = Pin> + Send>>; - - fn call(&self, req: Request) -> Self::Future { - let database = self.database.clone(); - let config = self.config.clone(); - let address = self._address.unwrap(); - - Box::pin(async move { - match *req.method() { - Method::POST => match req.uri().path() { - "/register" => { - debug!("new api register request from {}", address); - - let encoded_header = - match req.headers().get(hyper::header::AUTHORIZATION) { - None => { - error!( - "Authorization header not given for request from {address}", - ); - - return Ok(custom_resp( - StatusCode::BAD_REQUEST, - "Invalid credentials.", - ) - .await); - } - Some(x) => x, - } - .to_str() - .unwrap(); - - debug!("authorization header: {}", encoded_header); - - let auth_bytes = match BASE64_STANDARD.decode(&encoded_header[6..]) { - Ok(x) => x, - Err(e) => { - error!( - "Error while decoding authorization header from {address}: {e}", - ); - - return Ok(custom_resp( - StatusCode::BAD_REQUEST, - "Invalid base64 string given.", - ) - .await); - } - }; - - let auth_string = match String::from_utf8(auth_bytes) { - Ok(x) => x, - Err(e) => { - error!( - "Error while decoding authorization header from {address}: {e}", - ); - - return Ok(custom_resp( - StatusCode::BAD_REQUEST, - "Invalid UTF-8 in authentication string.", - ) - .await); - } - }; - - debug!("decoded auth string: {}", auth_string); - - if !Client::verify(auth_string.clone(), config).await { - warn!( - "Authentication for string {} from {} failed.", - auth_string, address - ); - - return Ok(custom_resp( - StatusCode::UNAUTHORIZED, - "Invalid credentials.", - ) - .await); - } - - let body = match String::from_utf8( - req.collect() - .await - .unwrap() - .to_bytes() - .iter() - .cloned() - .collect::>(), - ) { - Ok(x) => x, - Err(e) => { - error!( - "Error while inferring UTF-8 string from {address}'s request body: {e}", - ); - - return Ok(custom_resp( - StatusCode::BAD_REQUEST, - "Invalid UTF-8 in body.", - ) - .await); - } - }; - - let json = match json::parse(body.as_str()) { - Ok(x) => x, - Err(e) => { - error!("Error while parsing JSON body from {address}: {e}",); - - return Ok(custom_resp( - StatusCode::BAD_REQUEST, - "Invalid JSON in body.", - ) - .await); - } - }; - - debug!("body: {}", body); - - let mut endpoint = Endpoint::new( - None, - address, - json["port"].as_u16().unwrap_or(8080), - json["callback"].as_str().unwrap_or("/").to_string(), - ) - .await; - - endpoint - .register( - *database.lock().await, - json["hostname"].as_str().unwrap().to_string(), - ) - .await - .unwrap(); - - Ok(custom_resp(StatusCode::OK, "Success").await) - } - _ => Ok(default_response().await), - }, - _ => Ok(default_response().await), - } - }) - } -} diff --git a/src/services/controller.rs b/src/services/controller.rs index 2ae7ccc..628f432 100644 --- a/src/services/controller.rs +++ b/src/services/controller.rs @@ -33,7 +33,11 @@ impl Service> for ControllerService { None => { error!("No host header given."); - return Ok(custom_resp(StatusCode::BAD_REQUEST, "No host header given.").await); + return Ok(custom_resp( + StatusCode::BAD_REQUEST, + "No host header given.".to_string(), + ) + .await); } } .to_str() @@ -49,9 +53,11 @@ impl Service> for ControllerService { None => { error!("No endpoint found for request."); - return Ok( - custom_resp(StatusCode::NOT_FOUND, "No endpoint found for host.").await, - ); + return Ok(custom_resp( + StatusCode::NOT_FOUND, + "No endpoint found for host.".to_string(), + ) + .await); } }; diff --git a/src/services/matcher.rs b/src/services/matcher.rs new file mode 100644 index 0000000..fbcb074 --- /dev/null +++ b/src/services/matcher.rs @@ -0,0 +1,129 @@ +use std::{net::IpAddr, pin::Pin, sync::Arc}; + +use async_trait::async_trait; +use base64::{Engine, prelude::BASE64_STANDARD}; +use http::request::Parts; +use http_body_util::BodyExt; +use hyper::{ + Method, Request, StatusCode, + body::{self, Incoming}, + service::Service, +}; +use log::{debug, error, warn}; +use tokio::{net::TcpStream, sync::Mutex}; + +use crate::{ + config::{Client, Config}, + db::{BoxyDatabase, Endpoint}, + server::{GeneralResponse, TcpIntercept, custom_resp, default_response, json_to_vec}, +}; + +// The routes itself +#[async_trait] +pub trait Route +where + T: Matcher, +{ + fn matcher(&self, m: &T, req: &Request) -> bool; + async fn call(&self, m: &T, parts: Parts) -> Result; +} + +// Matcher, essentially just a router that contains routes and some other features +#[async_trait] +pub trait Matcher: Clone + Send + Sync + 'static { + // Essentially a kind of "middleware", a universal matcher. If it doesn't match, it won't + // route. + async fn unimatch( + &mut self, + req: &Request, + ) -> (bool, Option>); + + // Return list of routes associated with self matcher + fn retrieve(&self) -> Vec + Sync + Send>>; + + // Wrap self into matcher service + fn service(self) -> MatcherService { + MatcherService::new(self) + } + + // Do something with TCP stream + fn stream(&mut self, stream: &TcpStream) {} + + // Body parser - made universal for api server cause lazy + async fn body(&mut self, body: Incoming) -> Option> { + None + } +} + +// Wrapper service, wraps matcher into a service +#[derive(Clone)] +pub struct MatcherService +where + T: Matcher, +{ + inner: T, +} + +impl MatcherService +where + T: Matcher, +{ + pub fn new(inner: T) -> Self { + Self { inner } + } +} + +impl Service> for MatcherService +where + T: Matcher, +{ + type Response = GeneralResponse; + type Error = hyper::Error; + type Future = Pin> + Send>>; + + fn call(&self, req: Request) -> Self::Future { + let mut matcher = self.inner.clone(); + + Box::pin(async move { + let unimatched = matcher.unimatch(&req).await; + + if !unimatched.0 { + match unimatched.1 { + Some(x) => { + return x; + } + None => { + return Ok(custom_resp( + StatusCode::NOT_FOUND, + "Could not match route".to_string(), + ) + .await); + } + } + } + + for r in matcher.retrieve() { + if r.matcher(&matcher, &req) { + let (parts, body) = req.into_parts(); + + if let Some(resp) = matcher.body(body).await { + return resp; + } + + return r.call(&matcher, parts).await; + } + } + + Ok(default_response().await) + }) + } +} + +impl TcpIntercept for MatcherService +where + T: Matcher, +{ + fn stream(&mut self, stream: &TcpStream) { + self.inner.stream(stream); + } +} diff --git a/src/services/proxy.rs b/src/services/proxy.rs index 47b0a5e..0c59f3f 100644 --- a/src/services/proxy.rs +++ b/src/services/proxy.rs @@ -29,7 +29,7 @@ impl Service> for ProxyService { return Ok(custom_resp( StatusCode::BAD_GATEWAY, - "Unable to open connection to endpoint.", + "Unable to open connection to endpoint.".to_string(), ) .await); }