Skip to content

Commit

Permalink
Merge pull request #9 from s-simoncelli/add-macros
Browse files Browse the repository at this point in the history
Add macros
  • Loading branch information
s-simoncelli committed Aug 3, 2024
2 parents 3d22e53 + c467770 commit f800df7
Show file tree
Hide file tree
Showing 8 changed files with 338 additions and 94 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[workspace]
resolver = "2"
members = ["optirustic", "hv-fonseca-et-al-2006-sys", "hv-wfg-sys"]
members = ["optirustic", "optirustic-macros", "hv-fonseca-et-al-2006-sys", "hv-wfg-sys"]
default-members = ["optirustic", "hv-fonseca-et-al-2006-sys", "hv-wfg-sys"]

# Run test with optimisation to speed up tests solving optimisation problems.
Expand Down
12 changes: 12 additions & 0 deletions optirustic-macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "optirustic-macros"
version = "0.1.0"
edition = "2021"

[lib]
proc-macro = true

[dependencies]
syn = { version = "*", features = ["full"] }
quote = "*"
proc-macro2 = "*"
285 changes: 285 additions & 0 deletions optirustic-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::parse::Parser;
use syn::{parse_macro_input, DeriveInput, ItemFn};

/// An attribute macro to repeat a test `n` times until the test passes. The test passes if it does
/// not panic at least once, it fails if it panics `n` times.
#[proc_macro_attribute]
pub fn test_with_retries(attrs: TokenStream, item: TokenStream) -> TokenStream {
let input_fn = parse_macro_input!(item as ItemFn);
let fn_name = &input_fn.sig.ident;
let tries = attrs
.to_string()
.parse::<u8>()
.expect("Attr must be an int");

let expanded = quote! {
#[test]
fn #fn_name() {
#input_fn
for i in 1..=#tries {
let result = std::panic::catch_unwind(|| { #fn_name() });

if result.is_ok() {
return;
}

if i == #tries {
std::panic::resume_unwind(result.unwrap_err());
}
};
}
};
expanded.into()
}

/// Register new fields on a struct that contains algorithm options. This macro adds:
/// - the Serialize, Deserialize, Clone traits to the structure to make it serialisable and
/// de-serialisable.
/// - add the following fields: stopping_condition ([`StoppingConditionType`]), parallel (`bool`)
/// and export_history (`Option<ExportHistory>`).
#[proc_macro_attribute]
pub fn as_algorithm_args(_attrs: TokenStream, input: TokenStream) -> TokenStream {
let mut ast = parse_macro_input!(input as DeriveInput);
match &mut ast.data {
syn::Data::Struct(ref mut struct_data) => {
if let syn::Fields::Named(fields) = &mut struct_data.fields {
fields.named.push(
syn::Field::parse_named
.parse2(quote! {
/// The condition to use when to terminate the algorithm.
pub stopping_condition: StoppingConditionType
})
.expect("Cannot add `stopping_condition` field"),
);
fields.named.push(
syn::Field::parse_named
.parse2(quote! {
/// Whether the objective and constraint evaluation in [`Problem::evaluator`] should run
/// using threads. If the evaluation function takes a long time to run and return the updated
/// values, it is advisable to set this to `true`. This defaults to `true`.
pub parallel: Option<bool>
})
.expect("Cannot add `parallel` field"),
);
fields.named.push(
syn::Field::parse_named
.parse2(quote! {
/// The options to configure the individual's history export. When provided, the algorithm will
/// save objectives, constraints and solutions to a file each time the generation increases by
/// a given step. This is useful to track convergence and inspect an algorithm evolution.
pub export_history: Option<ExportHistory>
})
.expect("Cannot add `export_history` field"),
);
}

let expand = quote! {
use crate::algorithms::{StoppingConditionType, ExportHistory};
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Clone)]
#ast
};
expand.into()
}
_ => unimplemented!("`as_algorithm_args` can only be used on structs"),
}
}

/// This macro adds the following private fields to the struct defining an algorithm:
/// `problem`, `number_of_individuals`, `population`, `generation`,`stopping_condition`,
/// `start_time`, `export_history` and `parallel`.
///
/// It also implements the `Display` trait.
///
#[proc_macro_attribute]
pub fn as_algorithm(_attrs: TokenStream, input: TokenStream) -> TokenStream {
let mut ast = parse_macro_input!(input as DeriveInput);
let name = &ast.ident;

match &mut ast.data {
syn::Data::Struct(ref mut struct_data) => {
if let syn::Fields::Named(fields) = &mut struct_data.fields {
fields.named.push(
syn::Field::parse_named
.parse2(quote! {
/// The problem being solved.
problem: Arc<Problem>
})
.expect("Cannot add `problem` field"),
);
fields.named.push(
syn::Field::parse_named
.parse2(quote! {
/// The number of individuals to use in the population.
number_of_individuals: usize
})
.expect("Cannot add `number_of_individuals` field"),
);
fields.named.push(
syn::Field::parse_named
.parse2(quote! {
/// The population with the solutions.
population: Population
})
.expect("Cannot add `population` field"),
);
fields.named.push(
syn::Field::parse_named
.parse2(quote! {
/// The evolution step.
generation: usize
})
.expect("Cannot add `generation` field"),
);
fields.named.push(
syn::Field::parse_named
.parse2(quote! {
/// The stopping condition.
stopping_condition: StoppingConditionType
})
.expect("Cannot add `stopping_condition` field"),
);
fields.named.push(
syn::Field::parse_named
.parse2(quote! {
/// The time when the algorithm started.
start_time: Instant
})
.expect("Cannot add `start_time` field"),
);
fields.named.push(
syn::Field::parse_named
.parse2(quote! {
/// The configuration struct to export the algorithm history.
export_history: Option<ExportHistory>
})
.expect("Cannot add `export_history` field"),
);
fields.named.push(
syn::Field::parse_named
.parse2(quote! {
/// Whether the evaluation should run using threads
parallel: bool
})
.expect("Cannot add `parallel` field"),
);
}

let expand = quote! {
use std::time::Instant;
use std::sync::Arc;
use crate::core::{Problem, Population};

#ast

impl Display for #name {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(self.name().as_str())
}
}
};
expand.into()
}
_ => unimplemented!("`as_algorithm` can only be used on structs"),
}
}

/// This macro adds common items when the `Algorithm` trait is implemented for a new algorithm
/// struct. This adds the following items: `Algorithm::name()`, `Algorithm::stopping_condition()`
/// `Algorithm::start_time()`, `Algorithm::problem()`, `Algorithm::population()`,
/// `Algorithm::generation()` and `Algorithm::export_history()`.
///
#[proc_macro_attribute]
pub fn impl_algorithm_trait_items(attrs: TokenStream, input: TokenStream) -> TokenStream {
let mut ast = parse_macro_input!(input as syn::ItemImpl);
let name = if let syn::Type::Path(tp) = &*ast.self_ty {
tp.path.clone()
} else {
unimplemented!("Token not supported")
};
let arg_type = syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated
.parse(attrs)
.expect("Cannot parse argument type");

let mut new_items = vec![
syn::parse::<syn::ImplItem>(
quote!(
fn stopping_condition(&self) -> &StoppingConditionType {
&self.stopping_condition
}
)
.into(),
)
.expect("Failed to parse `name` item"),
syn::parse::<syn::ImplItem>(
quote!(
fn name(&self) -> String {
stringify!(#name).to_string()
}
)
.into(),
)
.expect("Failed to parse `name` item"),
syn::parse::<syn::ImplItem>(
quote!(
fn start_time(&self) -> &Instant {
&self.start_time
}
)
.into(),
)
.expect("Failed to parse `start_time` item"),
syn::parse::<syn::ImplItem>(
quote!(
fn problem(&self) -> Arc<Problem> {
self.problem.clone()
}
)
.into(),
)
.expect("Failed to parse `problem` item"),
syn::parse::<syn::ImplItem>(
quote!(
fn population(&self) -> &Population {
&self.population
}
)
.into(),
)
.expect("Failed to parse `population` item"),
syn::parse::<syn::ImplItem>(
quote!(
fn export_history(&self) -> Option<&ExportHistory> {
self.export_history.as_ref()
}
)
.into(),
)
.expect("Failed to parse `export_history` item"),
syn::parse::<syn::ImplItem>(
quote!(
fn generation(&self) -> usize {
self.generation
}
)
.into(),
)
.expect("Failed to parse `export_history` item"),
syn::parse::<syn::ImplItem>(
quote!(
fn algorithm_options(&self) -> &#arg_type {
&self.args
}
)
.into(),
)
.expect("Failed to parse `algorithm_options` item"),
];

ast.items.append(&mut new_items);
let expand = quote! { #ast };
expand.into()
}
1 change: 1 addition & 0 deletions optirustic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ rayon = "1.10.0"
env_logger = "0.11.3"
chrono = { version = "0.4.38", features = ["serde"] }
ordered-float = "4.2.0"
optirustic-macros = { path = "../optirustic-macros" }
hv-fonseca-et-al-2006-sys = { path = "../hv-fonseca-et-al-2006-sys" }
hv-wfg-sys = { path = "../hv-wfg-sys" }
plotters = { version = "0.3.6", optional = true }
Expand Down
8 changes: 4 additions & 4 deletions optirustic/src/algorithms/algorithm.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::{fmt, fs};
use std::fmt::{Debug, Display, Formatter};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use std::{fmt, fs};

use log::{debug, info};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};

use crate::algorithms::{StoppingCondition, StoppingConditionType};
use crate::core::{Individual, IndividualExport, OError, Population, Problem, ProblemExport};
Expand Down Expand Up @@ -88,7 +88,7 @@ impl ExportHistory {
/// * `generation_step`: export the algorithm data each time the generation counter in a genetic
// algorithm increases by the provided step.
/// * `destination`: serialise the algorithm history and export the results to a JSON file in
/// the given folder.
/// the given folder.
///
/// returns: `Result<ExportHistory, OError>`
pub fn new(generation_step: usize, destination: &PathBuf) -> Result<Self, OError> {
Expand Down Expand Up @@ -409,7 +409,7 @@ pub trait Algorithm<AlgorithmOptions: Serialize + DeserializeOwned>: Display {
/// * `problem`: The problem.
/// * `name`: The algorithm name.
/// * `expected_individuals`: The number of individuals to expect in the file. If this does not
/// match the population size, being used in the algorithm, an error is thrown.
/// match the population size, being used in the algorithm, an error is thrown.
/// * `file`: The path to the JSON file exported from this library.
///
/// returns: `Result<Population, OError>`
Expand Down
Loading

0 comments on commit f800df7

Please sign in to comment.