From bf2036a32ebc8bf757941ecf9c2d4a33d3b3fff5 Mon Sep 17 00:00:00 2001 From: Gergely Nagy Date: Thu, 27 Feb 2025 08:23:33 +0100 Subject: [PATCH] Try to handle errors a bit more gracefully In many cases where iocaine was using `unwrap()`, handle it gracefully, and return a `Result` instead. Signed-off-by: Gergely Nagy --- src/app.rs | 48 ++++++++++++++++---- src/assembled_statistical_sequences.rs | 63 +++++++++++++++++++------- src/tenx_programmer.rs | 21 +++++---- tests/test_app.rs | 5 +- 4 files changed, 99 insertions(+), 38 deletions(-) diff --git a/src/app.rs b/src/app.rs index 27c479c..8e40db9 100644 --- a/src/app.rs +++ b/src/app.rs @@ -6,8 +6,9 @@ use anyhow::Result; use axum::{ extract::{Path, State}, + http::StatusCode, middleware, - response::Html, + response::{Html, IntoResponse, Response}, routing::get, Router, }; @@ -110,20 +111,49 @@ async fn poison( headers: axum::http::HeaderMap, State(iocaine): State, path: Option>, -) -> Html { +) -> std::result::Result, AppError> { let default_host = axum::http::HeaderValue::from_static(""); - let host = headers - .get("host") - .unwrap_or(&default_host) - .to_str() - .unwrap(); + let host = headers.get("host").unwrap_or(&default_host).to_str()?; let path = path.unwrap_or(Path("/".to_string())); - let garbage = AssembledStatisticalSequences::generate(&iocaine, host, &path); + let garbage = AssembledStatisticalSequences::generate(&iocaine, host, &path)?; if iocaine.config.metrics.enable { metrics::counter!("iocaine_garbage_served").increment(garbage.len() as u64); } - Html(garbage) + Ok(Html(garbage)) +} + +pub struct AppError(anyhow::Error); + +impl IntoResponse for AppError { + fn into_response(self) -> Response { + tracing::error!("Internal server error: {}", self.0); + (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response() + } +} + +impl From for AppError { + fn from(e: axum::http::header::ToStrError) -> Self { + Self(e.into()) + } +} + +impl From for AppError { + fn from(e: anyhow::Error) -> Self { + Self(e) + } +} + +impl From for AppError { + fn from(e: std::io::Error) -> Self { + Self(e.into()) + } +} + +impl From for AppError { + fn from(e: metrics_exporter_prometheus::BuildError) -> Self { + Self(e.into()) + } } diff --git a/src/assembled_statistical_sequences.rs b/src/assembled_statistical_sequences.rs index 158a6c4..ea147ef 100644 --- a/src/assembled_statistical_sequences.rs +++ b/src/assembled_statistical_sequences.rs @@ -3,7 +3,8 @@ // // SPDX-License-Identifier: MIT -use handlebars::Handlebars; +use anyhow::Result; +use handlebars::{Handlebars, RenderErrorReason}; use rand::{seq::IndexedRandom, Rng}; use rust_embed::Embed; use serde::Serialize; @@ -37,7 +38,7 @@ struct Page<'a> { pub struct AssembledStatisticalSequences; impl AssembledStatisticalSequences { - pub fn generate(iocaine: &Iocaine, host: &str, path: &str) -> String { + pub fn generate(iocaine: &Iocaine, host: &str, path: &str) -> Result { let initial_seed = &iocaine.config.generator.initial_seed; let static_seed = format!("{}/{}#{}", host, path, initial_seed); let markov_gen = |h: &handlebars::Helper, @@ -46,9 +47,24 @@ impl AssembledStatisticalSequences { _: &mut handlebars::RenderContext, out: &mut dyn handlebars::Output| -> handlebars::HelperResult { - let group = h.param(0).unwrap().value().as_str().unwrap(); - let index = h.param(1).unwrap().value().as_i64().unwrap(); - let words = h.param(2).unwrap().value().as_i64().unwrap(); + let group = h + .param(0) + .ok_or(RenderErrorReason::ParamNotFoundForIndex("group", 0))? + .value() + .as_str() + .ok_or(RenderErrorReason::InvalidParamType("string"))?; + let index = h + .param(1) + .ok_or(RenderErrorReason::ParamNotFoundForIndex("index", 1))? + .value() + .as_i64() + .ok_or(RenderErrorReason::InvalidParamType("i64"))?; + let words = h + .param(2) + .ok_or(RenderErrorReason::ParamNotFoundForIndex("words", 2))? + .value() + .as_i64() + .ok_or(RenderErrorReason::InvalidParamType("i64"))?; let rng = GobbledyGook::for_url(format!("{}://{}/{}", group, &static_seed, index)); let chain = iocaine.chain.generate(rng).take(words as usize); @@ -62,9 +78,24 @@ impl AssembledStatisticalSequences { _: &mut handlebars::RenderContext, out: &mut dyn handlebars::Output| -> handlebars::HelperResult { - let group = h.param(0).unwrap().value().as_str().unwrap(); - let index = h.param(1).unwrap().value().as_i64().unwrap(); - let count = h.param(2).unwrap().value().as_i64().unwrap(); + let group = h + .param(0) + .ok_or(RenderErrorReason::ParamNotFoundForIndex("group", 0))? + .value() + .as_str() + .ok_or(RenderErrorReason::InvalidParamType("string"))?; + let index = h + .param(1) + .ok_or(RenderErrorReason::ParamNotFoundForIndex("index", 1))? + .value() + .as_i64() + .ok_or(RenderErrorReason::InvalidParamType("i64"))?; + let count = h + .param(2) + .ok_or(RenderErrorReason::ParamNotFoundForIndex("count", 2))? + .value() + .as_i64() + .ok_or(RenderErrorReason::InvalidParamType("i64"))?; let mut rng = GobbledyGook::for_url(format!("{}://{}/{}", group, &static_seed, index)); let words = (1..=count) @@ -127,12 +158,9 @@ impl AssembledStatisticalSequences { let mut handlebars = Handlebars::new(); if let Some(dir) = &iocaine.config.templates.directory { handlebars - .register_templates_directory(dir, handlebars::DirectorySourceOptions::default()) - .unwrap(); + .register_templates_directory(dir, handlebars::DirectorySourceOptions::default())?; } else { - handlebars - .register_embed_templates_with_extension::(".hbs") - .unwrap(); + handlebars.register_embed_templates_with_extension::(".hbs")?; } handlebars.register_helper("markov-gen", Box::new(markov_gen)); handlebars.register_helper("href-gen", Box::new(href_gen)); @@ -145,10 +173,11 @@ impl AssembledStatisticalSequences { }; let host_template = &format!("hosts/{}", host); - if handlebars.has_template(host_template) { - handlebars.render(host_template, &data).unwrap() + let rendered = if handlebars.has_template(host_template) { + handlebars.render(host_template, &data)? } else { - handlebars.render("main", &data).unwrap() - } + handlebars.render("main", &data)? + }; + Ok(rendered) } } diff --git a/src/tenx_programmer.rs b/src/tenx_programmer.rs index 4447043..1692b7e 100644 --- a/src/tenx_programmer.rs +++ b/src/tenx_programmer.rs @@ -4,9 +4,10 @@ // SPDX-License-Identifier: MIT use axum::{ + body::Body, extract::{Request, State}, middleware::Next, - response::IntoResponse, + response::Response, routing::get, Router, }; @@ -14,7 +15,7 @@ use metrics_exporter_prometheus::PrometheusBuilder; use std::time::{SystemTime, UNIX_EPOCH}; use crate::{ - app::{shutdown_signal, StatefulIocaine}, + app::{shutdown_signal, AppError, StatefulIocaine}, config::MetricsLabel, }; @@ -22,7 +23,7 @@ pub async fn track_metrics( State(iocaine): State, req: Request, next: Next, -) -> impl IntoResponse { +) -> Result, AppError> { let headers = req.headers().clone(); let response = next.run(req).await; let cfg = &iocaine.config.metrics; @@ -30,7 +31,7 @@ pub async fn track_metrics( let mut labels = Vec::new(); if cfg.labels.contains(&MetricsLabel::Host) { if let Some(host) = headers.get("host") { - let host = host.to_str().unwrap().to_string(); + let host = host.to_str()?.to_string(); labels.push(("host", host)); } } @@ -39,7 +40,7 @@ pub async fn track_metrics( || cfg.labels.contains(&MetricsLabel::UserAgentGroup) { if let Some(ua) = headers.get("user-agent") { - let user_agent = ua.to_str().unwrap().to_string(); + let user_agent = ua.to_str()?.to_string(); if cfg.labels.contains(&MetricsLabel::UserAgent) { labels.push(("user-agent", user_agent.clone())); @@ -59,12 +60,12 @@ pub async fn track_metrics( metrics::counter!("iocaine_requests_total", &labels).increment(1); - response + Ok(response) } -pub async fn start_metrics_server(metrics_bind: String) -> std::result::Result<(), std::io::Error> { +pub async fn start_metrics_server(metrics_bind: String) -> std::result::Result<(), AppError> { let metrics_listener = tokio::net::TcpListener::bind(metrics_bind).await?; - let recorder_handle = PrometheusBuilder::new().install_recorder().unwrap(); + let recorder_handle = PrometheusBuilder::new().install_recorder()?; let app = Router::new().route("/metrics", get(|| async move { recorder_handle.render() })); let ts = SystemTime::now() @@ -73,7 +74,7 @@ pub async fn start_metrics_server(metrics_bind: String) -> std::result::Result<( let labels = [("service", "iocaine".to_string())]; metrics::gauge!("process_start_time_seconds", &labels).set(ts); - axum::serve(metrics_listener, app) + Ok(axum::serve(metrics_listener, app) .with_graceful_shutdown(shutdown_signal()) - .await + .await?) } diff --git a/tests/test_app.rs b/tests/test_app.rs index 763bb57..0a463b5 100644 --- a/tests/test_app.rs +++ b/tests/test_app.rs @@ -34,7 +34,7 @@ fn generate(templates: &str, host: &str, url: &str) -> String { }; let iocaine = Iocaine::new(config).unwrap(); - AssembledStatisticalSequences::generate(&iocaine, host, url) + AssembledStatisticalSequences::generate(&iocaine, host, url).unwrap() } #[test] @@ -95,7 +95,8 @@ fn test_templates_builtin() { &iocaine, "test.example.com", "/builtin-templates/", - ); + ) + .unwrap(); assert!(result.contains("/builtin-templates/")); assert!(result.contains("a href=\"../\""));