From 17c802bab8c2efec04d819583f8106f39aa5548a Mon Sep 17 00:00:00 2001 From: hex Date: Tue, 29 Jul 2025 20:38:43 +0200 Subject: [PATCH] feat: error handling --- src/db.rs | 2 +- src/main.rs | 6 +- src/server.rs | 21 +++++-- src/services/api.rs | 119 +++++++++++++++++++++++++++---------- src/services/controller.rs | 38 +++++++----- src/services/proxy.rs | 18 +++++- 6 files changed, 145 insertions(+), 59 deletions(-) diff --git a/src/db.rs b/src/db.rs index 0fdccba..734ab00 100644 --- a/src/db.rs +++ b/src/db.rs @@ -65,7 +65,7 @@ impl Endpoint { &mut self, db: &mut BoxyDatabase, hostname: String, - ) -> Result<(), Box> { + ) -> Result<(), tokio_postgres::Error> { let tx = db.client.transaction().await?; let endpoint_id: i32 = tx diff --git a/src/main.rs b/src/main.rs index 8fd2d2c..666d87a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -68,11 +68,7 @@ async fn main() -> Result<(), Box> { let database_shared = Arc::new(Mutex::new(Box::leak(database))); - let api_svc = ApiService { - database: database_shared.clone(), - config: config.clone(), - _address: None, - }; + let api_svc = ApiService::new(database_shared.clone(), config.clone()).await; let svc = ControllerService { database: database_shared, diff --git a/src/server.rs b/src/server.rs index 0eb7c63..52cad1d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,10 +2,7 @@ use std::{any::type_name_of_val, error::Error}; use http_body_util::{Either, Full}; use hyper::{ - Request, Response, - body::{Body, Bytes, Incoming}, - server::conn::http1, - service::{HttpService, Service}, + body::{Body, Bytes, Incoming}, server::conn::http1, service::{HttpService, Service}, Request, Response, StatusCode }; use hyper_util::rt::TokioIo; use log::{error, info}; @@ -28,6 +25,22 @@ pub trait TcpIntercept { fn stream(&mut self, stream: &TcpStream); } +pub async fn default_response() -> GeneralResponse { + Response::builder() + .status(404) + .body(GeneralBody::Right(Full::from(Bytes::from( + "That route doesn't exist.", + )))) + .unwrap() +} + +pub async fn custom_resp(e: StatusCode, m: &'static str) -> GeneralResponse { + Response::builder() + .status(e) + .body(GeneralBody::Right(Full::from(Bytes::from(m)))) + .unwrap() +} + impl Server where S: TcpIntercept, diff --git a/src/services/api.rs b/src/services/api.rs index cab3cac..a6ece5d 100644 --- a/src/services/api.rs +++ b/src/services/api.rs @@ -7,13 +7,13 @@ use hyper::{ body::{Bytes, Incoming}, service::Service, }; -use log::{debug, warn}; +use log::{debug, error, warn}; use tokio::{net::TcpStream, sync::Mutex}; use crate::{ config::{Client, Config}, db::{BoxyDatabase, Endpoint}, - server::{GeneralBody, GeneralResponse, TcpIntercept}, + server::{custom_resp, default_response, GeneralBody, GeneralResponse, TcpIntercept}, }; #[derive(Debug, Clone)] @@ -23,21 +23,6 @@ pub struct ApiService { pub _address: Option, } -async fn default_response() -> GeneralResponse { - Response::builder() - .status(404) - .body(GeneralBody::Right(Full::from(Bytes::from( - "That route doesn't exist.", - )))) - .unwrap() -} - -async fn custom_resp(e: StatusCode, m: &'static str) -> GeneralResponse { - Response::builder() - .status(e) - .body(GeneralBody::Right(Full::from(Bytes::from(m)))) - .unwrap() -} impl TcpIntercept for ApiService { fn stream(&mut self, stream: &TcpStream) { @@ -45,6 +30,16 @@ impl TcpIntercept for ApiService { } } +impl ApiService { + pub async fn new(database: Arc>, config: Config) -> Self { + Self { + database, + config, + _address: None, + } + } +} + impl Service> for ApiService { type Response = GeneralResponse; type Error = hyper::Error; @@ -53,7 +48,7 @@ impl Service> for ApiService { fn call(&self, req: Request) -> Self::Future { let database = self.database.clone(); let config = self.config.clone(); - let address = self._address.clone().unwrap(); + let address = self._address.unwrap(); Box::pin(async move { match *req.method() { @@ -61,19 +56,55 @@ impl Service> for ApiService { "/register" => { debug!("new api register request from {}", address); - let encoded_header = req - .headers() - .get(hyper::header::AUTHORIZATION) - .unwrap() + let encoded_header = + match req.headers().get(hyper::header::AUTHORIZATION) { + None => { + error!( + "Authorization header not given for request from {address}", + ); + + return Ok(custom_resp( + StatusCode::BAD_REQUEST, + "Invalid credentials.", + ) + .await); + } + Some(x) => x, + } .to_str() .unwrap(); debug!("authorization header: {}", encoded_header); - let auth_string = String::from_utf8( - BASE64_STANDARD.decode(&encoded_header[6..]).unwrap(), - ) - .unwrap(); + let auth_bytes = match BASE64_STANDARD.decode(&encoded_header[6..]) { + Ok(x) => x, + Err(e) => { + error!( + "Error while decoding authorization header from {address}: {e}", + ); + + return Ok(custom_resp( + StatusCode::BAD_REQUEST, + "Invalid base64 string given.", + ) + .await); + } + }; + + let auth_string = match String::from_utf8(auth_bytes) { + Ok(x) => x, + Err(e) => { + error!( + "Error while decoding authorization header from {address}: {e}", + ); + + return Ok(custom_resp( + StatusCode::BAD_REQUEST, + "Invalid UTF-8 in authentication string.", + ) + .await); + } + }; debug!("decoded auth string: {}", auth_string); @@ -90,7 +121,7 @@ impl Service> for ApiService { .await); } - let body = String::from_utf8( + let body = match String::from_utf8( req.collect() .await .unwrap() @@ -98,16 +129,40 @@ impl Service> for ApiService { .iter() .cloned() .collect::>(), - ) - .unwrap(); - let json = json::parse(body.as_str()).unwrap(); + ) { + Ok(x) => x, + Err(e) => { + error!( + "Error while inferring UTF-8 string from {address}'s request body: {e}", + ); + + return Ok(custom_resp( + StatusCode::BAD_REQUEST, + "Invalid UTF-8 in body.", + ) + .await); + } + }; + + let json = match json::parse(body.as_str()) { + Ok(x) => x, + Err(e) => { + error!("Error while parsing JSON body from {address}: {e}",); + + return Ok(custom_resp( + StatusCode::BAD_REQUEST, + "Invalid JSON in body.", + ) + .await); + } + }; debug!("body: {}", body); let mut endpoint = Endpoint::new( None, address, - json["port"].as_u16().unwrap(), + json["port"].as_u16().unwrap_or(8080), json["callback"].as_str().unwrap_or("/").to_string(), ) .await; @@ -120,7 +175,7 @@ impl Service> for ApiService { .await .unwrap(); - Ok(custom_resp(StatusCode::OK, "").await) + Ok(custom_resp(StatusCode::OK, "Success").await) } _ => Ok(default_response().await), }, diff --git a/src/services/controller.rs b/src/services/controller.rs index 7665ac7..2ae7ccc 100644 --- a/src/services/controller.rs +++ b/src/services/controller.rs @@ -1,15 +1,12 @@ use std::{pin::Pin, sync::Arc}; -use hyper::{ - Request, - body::Incoming, - service::Service, -}; +use hyper::{Request, StatusCode, body::Incoming, service::Service}; +use log::error; use tokio::sync::Mutex; use crate::{ db::{BoxyDatabase, Endpoint}, - server::{GeneralResponse, TcpIntercept}, + server::{GeneralResponse, TcpIntercept, custom_resp}, }; use super::proxy::ProxyService; @@ -31,19 +28,32 @@ impl Service> for ControllerService { fn call(&self, req: Request) -> Self::Future { let database = self.database.clone(); Box::pin(async move { - let hostname = req - .headers() - .get(hyper::header::HOST) - .unwrap() - .to_str() - .unwrap() - .to_string(); + let hostname = match req.headers().get(hyper::header::HOST) { + Some(x) => x, + None => { + error!("No host header given."); + + return Ok(custom_resp(StatusCode::BAD_REQUEST, "No host header given.").await); + } + } + .to_str() + .unwrap() + .to_string(); let endpoints = Endpoint::get_by_hostname(*database.lock().await, hostname.clone()) .await .unwrap(); - let endpoint = endpoints.first().unwrap(); + let endpoint = match endpoints.first() { + Some(x) => x, + None => { + error!("No endpoint found for request."); + + return Ok( + custom_resp(StatusCode::NOT_FOUND, "No endpoint found for host.").await, + ); + } + }; let proxy = ProxyService { address: format!("{}:{}", endpoint.address.clone(), endpoint.port), diff --git a/src/services/proxy.rs b/src/services/proxy.rs index e4641e1..47b0a5e 100644 --- a/src/services/proxy.rs +++ b/src/services/proxy.rs @@ -1,11 +1,11 @@ use std::pin::Pin; -use hyper::{Request, body::Incoming, service::Service}; +use hyper::{Request, StatusCode, body::Incoming, service::Service}; use hyper_util::rt::TokioIo; use log::error; use tokio::net::TcpStream; -use crate::server::{GeneralResponse, to_general_response}; +use crate::server::{GeneralResponse, custom_resp, to_general_response}; #[derive(Debug, Clone)] pub struct ProxyService { @@ -22,7 +22,19 @@ impl Service> for ProxyService { let address = self.address.clone(); let hostname = self.hostname.clone(); Box::pin(async move { - let stream = TcpStream::connect(address).await.unwrap(); + let stream = match TcpStream::connect(address).await { + Ok(x) => x, + Err(e) => { + error!("Could not open connection to endpoint: {e}"); + + return Ok(custom_resp( + StatusCode::BAD_GATEWAY, + "Unable to open connection to endpoint.", + ) + .await); + } + }; + let io = TokioIo::new(stream); let (mut sender, conn) = hyper::client::conn::http1::Builder::new()