diff --git a/optirustic-macros/src/lib.rs b/optirustic-macros/src/lib.rs index 3628693..669fd26 100644 --- a/optirustic-macros/src/lib.rs +++ b/optirustic-macros/src/lib.rs @@ -1,10 +1,10 @@ use proc_macro::TokenStream; - use quote::quote; -use syn::{parse_macro_input, ItemFn}; +use syn::parse::Parser; +use syn::{parse_macro_input, DeriveInput, ItemFn}; /// An attribute macro to repeat a test `n` times until the test passes. The test passes if it does -/// not panic once, it fails if it panics `n` times. +/// not panic at least once, it fails if it panics `n` times. #[proc_macro_attribute] pub fn test_with_retries(attrs: TokenStream, item: TokenStream) -> TokenStream { let input_fn = parse_macro_input!(item as ItemFn); @@ -33,3 +33,253 @@ pub fn test_with_retries(attrs: TokenStream, item: TokenStream) -> TokenStream { }; expanded.into() } + +/// Register new fields on a struct that contains algorithm options. This macro adds: +/// - the Serialize, Deserialize, Clone traits to the structure to make it serialisable and +/// de-serialisable. +/// - add the following fields: stopping_condition ([`StoppingConditionType`]), parallel (`bool`) +/// and export_history (`Option`). +#[proc_macro_attribute] +pub fn as_algorithm_args(_attrs: TokenStream, input: TokenStream) -> TokenStream { + let mut ast = parse_macro_input!(input as DeriveInput); + match &mut ast.data { + syn::Data::Struct(ref mut struct_data) => { + if let syn::Fields::Named(fields) = &mut struct_data.fields { + fields.named.push( + syn::Field::parse_named + .parse2(quote! { + /// The condition to use when to terminate the algorithm. + pub stopping_condition: StoppingConditionType + }) + .expect("Cannot add `stopping_condition` field"), + ); + fields.named.push( + syn::Field::parse_named + .parse2(quote! { + /// Whether the objective and constraint evaluation in [`Problem::evaluator`] should run + /// using threads. If the evaluation function takes a long time to run and return the updated + /// values, it is advisable to set this to `true`. This defaults to `true`. + pub parallel: Option + }) + .expect("Cannot add `parallel` field"), + ); + fields.named.push( + syn::Field::parse_named + .parse2(quote! { + /// The options to configure the individual's history export. When provided, the algorithm will + /// save objectives, constraints and solutions to a file each time the generation increases by + /// a given step. This is useful to track convergence and inspect an algorithm evolution. + pub export_history: Option + }) + .expect("Cannot add `export_history` field"), + ); + } + + let expand = quote! { + use crate::algorithms::{StoppingConditionType, ExportHistory}; + use serde::{Deserialize, Serialize}; + + #[derive(Serialize, Deserialize, Clone)] + #ast + }; + expand.into() + } + _ => unimplemented!("`as_algorithm_args` can only be used on structs"), + } +} + +/// This macro adds the following private fields to the struct defining an algorithm: +/// `problem`, `number_of_individuals`, `population`, `generation`,`stopping_condition`, +/// `start_time`, `export_history` and `parallel`. +/// +/// It also implements the `Display` trait. +/// +#[proc_macro_attribute] +pub fn as_algorithm(_attrs: TokenStream, input: TokenStream) -> TokenStream { + let mut ast = parse_macro_input!(input as DeriveInput); + let name = &ast.ident; + + match &mut ast.data { + syn::Data::Struct(ref mut struct_data) => { + if let syn::Fields::Named(fields) = &mut struct_data.fields { + fields.named.push( + syn::Field::parse_named + .parse2(quote! { + /// The problem being solved. + problem: Arc + }) + .expect("Cannot add `problem` field"), + ); + fields.named.push( + syn::Field::parse_named + .parse2(quote! { + /// The number of individuals to use in the population. + number_of_individuals: usize + }) + .expect("Cannot add `number_of_individuals` field"), + ); + fields.named.push( + syn::Field::parse_named + .parse2(quote! { + /// The population with the solutions. + population: Population + }) + .expect("Cannot add `population` field"), + ); + fields.named.push( + syn::Field::parse_named + .parse2(quote! { + /// The evolution step. + generation: usize + }) + .expect("Cannot add `generation` field"), + ); + fields.named.push( + syn::Field::parse_named + .parse2(quote! { + /// The stopping condition. + stopping_condition: StoppingConditionType + }) + .expect("Cannot add `stopping_condition` field"), + ); + fields.named.push( + syn::Field::parse_named + .parse2(quote! { + /// The time when the algorithm started. + start_time: Instant + }) + .expect("Cannot add `start_time` field"), + ); + fields.named.push( + syn::Field::parse_named + .parse2(quote! { + /// The configuration struct to export the algorithm history. + export_history: Option + }) + .expect("Cannot add `export_history` field"), + ); + fields.named.push( + syn::Field::parse_named + .parse2(quote! { + /// Whether the evaluation should run using threads + parallel: bool + }) + .expect("Cannot add `parallel` field"), + ); + } + + let expand = quote! { + use std::time::Instant; + use std::sync::Arc; + use crate::core::{Problem, Population}; + + #ast + + impl Display for #name { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(self.name().as_str()) + } + } + }; + expand.into() + } + _ => unimplemented!("`as_algorithm` can only be used on structs"), + } +} + +/// This macro adds common items when the `Algorithm` trait is implemented for a new algorithm +/// struct. This adds the following items: `Algorithm::name()`, `Algorithm::stopping_condition()` +/// `Algorithm::start_time()`, `Algorithm::problem()`, `Algorithm::population()`, +/// `Algorithm::generation()` and `Algorithm::export_history()`. +/// +#[proc_macro_attribute] +pub fn impl_algorithm_trait_items(attrs: TokenStream, input: TokenStream) -> TokenStream { + let mut ast = parse_macro_input!(input as syn::ItemImpl); + let name = if let syn::Type::Path(tp) = &*ast.self_ty { + tp.path.clone() + } else { + unimplemented!("Token not supported") + }; + let arg_type = syn::punctuated::Punctuated::::parse_terminated + .parse(attrs) + .expect("Cannot parse argument type"); + + let mut new_items = vec![ + syn::parse::( + quote!( + fn stopping_condition(&self) -> &StoppingConditionType { + &self.stopping_condition + } + ) + .into(), + ) + .expect("Failed to parse `name` item"), + syn::parse::( + quote!( + fn name(&self) -> String { + stringify!(#name).to_string() + } + ) + .into(), + ) + .expect("Failed to parse `name` item"), + syn::parse::( + quote!( + fn start_time(&self) -> &Instant { + &self.start_time + } + ) + .into(), + ) + .expect("Failed to parse `start_time` item"), + syn::parse::( + quote!( + fn problem(&self) -> Arc { + self.problem.clone() + } + ) + .into(), + ) + .expect("Failed to parse `problem` item"), + syn::parse::( + quote!( + fn population(&self) -> &Population { + &self.population + } + ) + .into(), + ) + .expect("Failed to parse `population` item"), + syn::parse::( + quote!( + fn export_history(&self) -> Option<&ExportHistory> { + self.export_history.as_ref() + } + ) + .into(), + ) + .expect("Failed to parse `export_history` item"), + syn::parse::( + quote!( + fn generation(&self) -> usize { + self.generation + } + ) + .into(), + ) + .expect("Failed to parse `export_history` item"), + syn::parse::( + quote!( + fn algorithm_options(&self) -> &#arg_type { + &self.args + } + ) + .into(), + ) + .expect("Failed to parse `algorithm_options` item"), + ]; + + ast.items.append(&mut new_items); + let expand = quote! { #ast }; + expand.into() +} diff --git a/optirustic/src/algorithms/algorithm.rs b/optirustic/src/algorithms/algorithm.rs index bc24b30..621074d 100644 --- a/optirustic/src/algorithms/algorithm.rs +++ b/optirustic/src/algorithms/algorithm.rs @@ -1,13 +1,13 @@ -use std::{fmt, fs}; use std::fmt::{Debug, Display, Formatter}; use std::path::PathBuf; use std::sync::Arc; use std::time::Instant; +use std::{fmt, fs}; use log::{debug, info}; use rayon::prelude::*; -use serde::{Deserialize, Serialize}; use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; use crate::algorithms::{StoppingCondition, StoppingConditionType}; use crate::core::{Individual, IndividualExport, OError, Population, Problem, ProblemExport}; @@ -88,7 +88,7 @@ impl ExportHistory { /// * `generation_step`: export the algorithm data each time the generation counter in a genetic // algorithm increases by the provided step. /// * `destination`: serialise the algorithm history and export the results to a JSON file in - /// the given folder. + /// the given folder. /// /// returns: `Result` pub fn new(generation_step: usize, destination: &PathBuf) -> Result { @@ -409,7 +409,7 @@ pub trait Algorithm: Display { /// * `problem`: The problem. /// * `name`: The algorithm name. /// * `expected_individuals`: The number of individuals to expect in the file. If this does not - /// match the population size, being used in the algorithm, an error is thrown. + /// match the population size, being used in the algorithm, an error is thrown. /// * `file`: The path to the JSON file exported from this library. /// /// returns: `Result` diff --git a/optirustic/src/algorithms/nsga2.rs b/optirustic/src/algorithms/nsga2.rs index 64fae99..ab98a1f 100644 --- a/optirustic/src/algorithms/nsga2.rs +++ b/optirustic/src/algorithms/nsga2.rs @@ -1,26 +1,21 @@ use std::fmt::{Display, Formatter}; use std::ops::Rem; use std::path::PathBuf; -use std::sync::Arc; -use std::time::Instant; -use log::{debug, info}; -use rand::RngCore; -use serde::{Deserialize, Serialize}; - -use crate::algorithms::{Algorithm, ExportHistory, StoppingConditionType}; +use crate::algorithms::Algorithm; use crate::core::utils::{argsort, get_rng, vector_max, vector_min, Sort}; -use crate::core::{ - Individual, Individuals, IndividualsMut, OError, Population, Problem, VariableValue, -}; +use crate::core::{Individual, Individuals, IndividualsMut, OError, VariableValue}; use crate::operators::{ Crossover, CrowdedComparison, Mutation, PolynomialMutation, PolynomialMutationArgs, Selector, SimulatedBinaryCrossover, SimulatedBinaryCrossoverArgs, TournamentSelector, }; use crate::utils::fast_non_dominated_sort; +use log::{debug, info}; +use optirustic_macros::{as_algorithm, as_algorithm_args, impl_algorithm_trait_items}; +use rand::RngCore; /// Input arguments for the NSGA2 algorithm. -#[derive(Serialize, Deserialize, Clone)] +#[as_algorithm_args] pub struct NSGA2Arg { /// The number of individuals to use in the population. This must be a multiple of `2`. pub number_of_individuals: usize, @@ -34,16 +29,6 @@ pub struct NSGA2Arg { /// divided by the number of real variables in the problem (i.e., each variable will have the /// same probability of being mutated). pub mutation_operator_options: Option, - /// The condition to use when to terminate the algorithm. - pub stopping_condition: StoppingConditionType, - /// Whether the objective and constraint evaluation in [`Problem::evaluator`] should run - /// using threads. If the evaluation function takes a long time to run and return the updated - /// values, it is advisable to set this to `true`. This defaults to `true`. - pub parallel: Option, - /// The options to configure the individual's history export. When provided, the algorithm will - /// save objectives, constraints and solutions to a file each time the generation increases by - /// a given step. This is useful to track convergence and inspect an algorithm evolution. - pub export_history: Option, /// Instead of initialising the population with random variables, see the initial population /// with the variable values from a JSON files exported with this tool. This option lets you /// restart the evolution from a previous generation; you can use any history file (exported @@ -75,13 +60,8 @@ pub struct NSGA2Arg { /// ```rust #[doc = include_str!("../../examples/nsga2_zdt1.rs")] /// ``` +#[as_algorithm] pub struct NSGA2 { - /// The number of individuals to use in the population. - number_of_individuals: usize, - /// The population with the solutions. - population: Population, - /// The problem being solved. - problem: Arc, /// The operator to use to select the individuals for reproduction. selector_operator: TournamentSelector, /// The operator to use to generate a new children by recombining the variables of parent @@ -90,28 +70,12 @@ pub struct NSGA2 { crossover_operator: SimulatedBinaryCrossover, /// The operator to use to mutate the variables of an individual. mutation_operator: PolynomialMutation, - /// The evolution step. - generation: usize, - /// The stopping condition. - stopping_condition: StoppingConditionType, - /// The time when the algorithm started. - start_time: Instant, - /// The configuration struct to export the algorithm history. - export_history: Option, - /// Whether the evaluation should run using threads - parallel: bool, /// The seed to use. rng: Box, /// The algorithm options args: NSGA2Arg, } -impl Display for NSGA2 { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.write_str(self.name().as_str()) - } -} - impl NSGA2 { /// Initialise the NSGA2 algorithm. /// @@ -286,6 +250,7 @@ impl NSGA2 { } /// Implementation of Section IIIC of the paper. +#[impl_algorithm_trait_items(NSGA2Arg)] impl Algorithm for NSGA2 { /// This assesses the initial random population and sets the individual's ranks and crowding /// distance needed in [`self.evolve`]. @@ -408,38 +373,6 @@ impl Algorithm for NSGA2 { self.generation += 1; Ok(()) } - - fn generation(&self) -> usize { - self.generation - } - - fn name(&self) -> String { - "NSGA2".to_string() - } - - fn start_time(&self) -> &Instant { - &self.start_time - } - - fn stopping_condition(&self) -> &StoppingConditionType { - &self.stopping_condition - } - - fn population(&self) -> &Population { - &self.population - } - - fn problem(&self) -> Arc { - self.problem.clone() - } - - fn export_history(&self) -> Option<&ExportHistory> { - self.export_history.as_ref() - } - - fn algorithm_options(&self) -> &NSGA2Arg { - &self.args - } } #[cfg(test)]