From 4a28c65fc1d46d5357755bc44a1e1b58bdd922bd Mon Sep 17 00:00:00 2001 From: Stefano Simonelli <16114781+s-simoncelli@users.noreply.github.com> Date: Fri, 2 Aug 2024 18:00:47 +0100 Subject: [PATCH] Added test_with_retries macro to repeat GA tests. Some tests may fail due to the randomness in the solutions --- Cargo.lock | 10 +++++++++ Cargo.toml | 2 +- optirustic-macros/Cargo.toml | 12 ++++++++++ optirustic-macros/src/lib.rs | 35 ++++++++++++++++++++++++++++++ optirustic/Cargo.toml | 1 + optirustic/src/algorithms/nsga2.rs | 29 ++++++++++++++----------- 6 files changed, 75 insertions(+), 14 deletions(-) create mode 100644 optirustic-macros/Cargo.toml create mode 100644 optirustic-macros/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 4268b82..f854fe0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -724,6 +724,7 @@ dependencies = [ "hv-fonseca-et-al-2006-sys", "hv-wfg-sys", "log", + "optirustic-macros", "ordered-float", "plotters", "rand", @@ -734,6 +735,15 @@ dependencies = [ "thiserror", ] +[[package]] +name = "optirustic-macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "ordered-float" version = "4.2.1" diff --git a/Cargo.toml b/Cargo.toml index 4fbc4bf..59607b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["optirustic", "hv-fonseca-et-al-2006-sys", "hv-wfg-sys"] +members = ["optirustic", "optirustic-macros", "hv-fonseca-et-al-2006-sys", "hv-wfg-sys"] default-members = ["optirustic", "hv-fonseca-et-al-2006-sys", "hv-wfg-sys"] # Run test with optimisation to speed up tests solving optimisation problems. diff --git a/optirustic-macros/Cargo.toml b/optirustic-macros/Cargo.toml new file mode 100644 index 0000000..bb624cf --- /dev/null +++ b/optirustic-macros/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "optirustic-macros" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "*", features = ["full"] } +quote = "*" +proc-macro2 = "*" \ No newline at end of file diff --git a/optirustic-macros/src/lib.rs b/optirustic-macros/src/lib.rs new file mode 100644 index 0000000..3628693 --- /dev/null +++ b/optirustic-macros/src/lib.rs @@ -0,0 +1,35 @@ +use proc_macro::TokenStream; + +use quote::quote; +use syn::{parse_macro_input, ItemFn}; + +/// An attribute macro to repeat a test `n` times until the test passes. The test passes if it does +/// not panic once, it fails if it panics `n` times. +#[proc_macro_attribute] +pub fn test_with_retries(attrs: TokenStream, item: TokenStream) -> TokenStream { + let input_fn = parse_macro_input!(item as ItemFn); + let fn_name = &input_fn.sig.ident; + let tries = attrs + .to_string() + .parse::() + .expect("Attr must be an int"); + + let expanded = quote! { + #[test] + fn #fn_name() { + #input_fn + for i in 1..=#tries { + let result = std::panic::catch_unwind(|| { #fn_name() }); + + if result.is_ok() { + return; + } + + if i == #tries { + std::panic::resume_unwind(result.unwrap_err()); + } + }; + } + }; + expanded.into() +} diff --git a/optirustic/Cargo.toml b/optirustic/Cargo.toml index c507519..4ce1f64 100644 --- a/optirustic/Cargo.toml +++ b/optirustic/Cargo.toml @@ -14,6 +14,7 @@ rayon = "1.10.0" env_logger = "0.11.3" chrono = { version = "0.4.38", features = ["serde"] } ordered-float = "4.2.0" +optirustic-macros = { path = "../optirustic-macros" } hv-fonseca-et-al-2006-sys = { path = "../hv-fonseca-et-al-2006-sys" } hv-wfg-sys = { path = "../hv-wfg-sys" } plotters = { version = "0.3.6", optional = true } diff --git a/optirustic/src/algorithms/nsga2.rs b/optirustic/src/algorithms/nsga2.rs index a47a5b8..64fae99 100644 --- a/optirustic/src/algorithms/nsga2.rs +++ b/optirustic/src/algorithms/nsga2.rs @@ -9,10 +9,10 @@ use rand::RngCore; use serde::{Deserialize, Serialize}; use crate::algorithms::{Algorithm, ExportHistory, StoppingConditionType}; +use crate::core::utils::{argsort, get_rng, vector_max, vector_min, Sort}; use crate::core::{ Individual, Individuals, IndividualsMut, OError, Population, Problem, VariableValue, }; -use crate::core::utils::{argsort, get_rng, Sort, vector_max, vector_min}; use crate::operators::{ Crossover, CrowdedComparison, Mutation, PolynomialMutation, PolynomialMutationArgs, Selector, SimulatedBinaryCrossover, SimulatedBinaryCrossoverArgs, TournamentSelector, @@ -47,7 +47,7 @@ pub struct NSGA2Arg { /// Instead of initialising the population with random variables, see the initial population /// with the variable values from a JSON files exported with this tool. This option lets you /// restart the evolution from a previous generation; you can use any history file (exported - /// when the field `export_history`) or the file exported when the stopping condition was reached. + /// when the field `export_history`) or the file exported when the stopping condition was reached. pub resume_from_file: Option, /// The seed used in the random number generator (RNG). You can specify a seed in case you want /// to try to reproduce results. NSGA2 is a stochastic algorithm that relies on a RNG at @@ -441,13 +441,14 @@ impl Algorithm for NSGA2 { &self.args } } + #[cfg(test)] mod test_sorting { use float_cmp::assert_approx_eq; use crate::algorithms::NSGA2; - use crate::core::{Individuals, ObjectiveDirection, VariableValue}; use crate::core::utils::individuals_from_obj_values_dummy; + use crate::core::{Individuals, ObjectiveDirection, VariableValue}; #[test] /// Test the crowding distance algorithm (not enough points). @@ -669,7 +670,9 @@ mod test_sorting { } #[cfg(test)] mod test_problems { - use crate::algorithms::{Algorithm, MaxGeneration, NSGA2, NSGA2Arg, StoppingConditionType}; + use optirustic_macros::test_with_retries; + + use crate::algorithms::{Algorithm, MaxGeneration, NSGA2Arg, StoppingConditionType, NSGA2}; use crate::core::builtin_problems::{ SCHProblem, ZTD1Problem, ZTD2Problem, ZTD3Problem, ZTD4Problem, }; @@ -677,7 +680,8 @@ mod test_problems { const BOUND_TOL: f64 = 1.0 / 1000.0; const LOOSE_BOUND_TOL: f64 = 0.1; - #[test] + + #[test_with_retries(3)] /// Test problem 1 from Deb et al. (2002). Optional solution x in [0; 2] fn test_sch_problem() { let problem = SCHProblem::create().unwrap(); @@ -703,7 +707,7 @@ mod test_problems { } } - #[test] + #[test_with_retries(3)] /// Test the ZTD1 problem from Deb et al. (2002) with 30 variables. Solution x1 in [0; 1] and /// x2 to x30 = 0. The exact solutions are tested using a strict and loose bounds. fn test_ztd1_problem() { @@ -750,7 +754,7 @@ mod test_problems { } } - #[test] + #[test_with_retries(3)] /// Test the ZTD2 problem from Deb et al. (2002) with 30 variables. Solution x1 in [0; 1] and /// x2 to x30 = 0. The exact solutions are tested using a strict and loose bounds. fn test_ztd2_problem() { @@ -802,7 +806,7 @@ mod test_problems { } } - #[test] + #[test_with_retries(3)] /// Test the ZTD3 problem from Deb et al. (2002) with 30 variables. Solution x1 in [0; 1] and /// x2 to x30 = 0. The exact solutions are tested using a strict and loose bounds. fn test_ztd3_problem() { @@ -854,15 +858,13 @@ mod test_problems { } } - #[test] + #[test_with_retries(3)] /// Test the ZTD4 problem from Deb et al. (2002) with 30 variables. Solution x1 in [0; 1] and /// x2 to x10 = 0. The exact solutions are tested using a strict and loose bounds. fn test_ztd4_problem() { let number_of_individuals: usize = 10; - let problem = ZTD4Problem::create(number_of_individuals).unwrap(); let args = NSGA2Arg { number_of_individuals, - // this may take longer to converge stopping_condition: StoppingConditionType::MaxGeneration(MaxGeneration(3000)), crossover_operator_options: None, mutation_operator_options: None, @@ -871,7 +873,8 @@ mod test_problems { resume_from_file: None, seed: Some(1), }; - let mut algo = NSGA2::new(problem, args).unwrap(); + let problem = ZTD4Problem::create(number_of_individuals).unwrap(); + let mut algo = NSGA2::new(problem, args.clone()).unwrap(); algo.run().unwrap(); let results = algo.get_results(); @@ -908,7 +911,7 @@ mod test_problems { } } - #[test] + #[test_with_retries(3)] /// Test the ZTD6 problem from Deb et al. (2002) with 30 variables. Solution x1 in [0; 1] and /// x2 to x10 = 0. The exact solutions are tested using a strict and loose bounds. fn test_ztd6_problem() {