add callback tests
Signed-off-by: Grossberger Lukas (CR/AIR2.2) <Lukas.Grossberger@de.bosch.com>
LGro committed Feb 8, 2024
1 parent 1fcaef7 commit a4590cf
Showing 1 changed file with 66 additions and 11 deletions.
tests/optimization_loops/file_based_distributed_test.py
@@ -3,11 +3,12 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+from collections import defaultdict
+from functools import partial
 from threading import Thread
 
-import parameterspace as ps
-
 from blackboxopt import Evaluation, EvaluationSpecification
+from blackboxopt.optimization_loops import testing
 from blackboxopt.optimization_loops.file_based_distributed import (
     evaluate_specifications,
     run_optimization_loop,
@@ -16,12 +17,12 @@
 
 
 def test_successful_loop(tmpdir):
-    space = ps.ParameterSpace()
-    space.add(ps.ContinuousParameter("p", (-10, 10)))
-    opt = SpaceFilling(space, objectives=[Objective("loss", greater_is_better=False)])
+    opt = SpaceFilling(
+        testing.SPACE, objectives=[Objective("loss", greater_is_better=False)]
+    )
 
-    def eval_func(spec: EvaluationSpecification) -> Evaluation:
-        return spec.create_evaluation({"loss": spec.configuration["p"] ** 2})
+    def eval_func(eval_spec: EvaluationSpecification) -> Evaluation:
+        return eval_spec.create_evaluation({"loss": eval_spec.configuration["p1"] ** 2})
 
     max_evaluations = 3
 
@@ -48,11 +49,11 @@ def eval_func(spec: EvaluationSpecification) -> Evaluation:
 
 
 def test_failed_evaluations(tmpdir):
-    space = ps.ParameterSpace()
-    space.add(ps.ContinuousParameter("p", (-10, 10)))
-    opt = SpaceFilling(space, objectives=[Objective("loss", greater_is_better=False)])
+    opt = SpaceFilling(
+        testing.SPACE, objectives=[Objective("loss", greater_is_better=False)]
+    )
 
-    def eval_func(spec: EvaluationSpecification) -> Evaluation:
+    def eval_func(eval_spec: EvaluationSpecification) -> Evaluation:
         raise ValueError("This is a test error to make the evaluation fail.")
 
     max_evaluations = 3
@@ -78,3 +79,57 @@ def eval_func(spec: EvaluationSpecification) -> Evaluation:
     assert evaluations[1].objectives[opt.objectives[0].name] is None
     assert evaluations[2].objectives[opt.objectives[0].name] is None
     thread.join()
+
+
+def test_callbacks(tmpdir):
+    from_callback = defaultdict(list)
+
+    def callback(e: Evaluation, callback_name: str):
+        from_callback[callback_name].append(e)
+
+    def eval_func(eval_spec: EvaluationSpecification) -> Evaluation:
+        return eval_spec.create_evaluation({"loss": eval_spec.configuration["p1"] ** 2})
+
+    max_evaluations = 3
+    opt = SpaceFilling(
+        testing.SPACE, objectives=[Objective("loss", greater_is_better=False)]
+    )
+    thread = Thread(
+        target=evaluate_specifications,
+        kwargs=dict(
+            target_directory=tmpdir,
+            evaluation_function=eval_func,
+            objectives=opt.objectives,
+            max_evaluations=max_evaluations,
+            pre_evaluation_callback=partial(callback, callback_name="evaluate_pre"),
+            post_evaluation_callback=partial(callback, callback_name="evaluate_post"),
+        ),
+    )
+    thread.start()
+
+    evaluations = run_optimization_loop(
+        optimizer=opt,
+        target_directory=tmpdir,
+        max_evaluations=max_evaluations,
+        pre_evaluation_callback=partial(callback, callback_name="run_loop_pre"),
+        post_evaluation_callback=partial(callback, callback_name="run_loop_post"),
+    )
+
+    # NOTE: These are set comparisons instead of list comparisons because the order
+    # of the evaluations is not guaranteed.
+    assert len(evaluations) == len(from_callback["evaluate_post"])
+    assert set([e.to_json() for e in evaluations]) == set(
+        [e.to_json() for e in from_callback["evaluate_post"]]
+    )
+    assert len(evaluations) == len(from_callback["run_loop_post"])
+    assert set([e.to_json() for e in evaluations]) == set(
+        [e.to_json() for e in from_callback["run_loop_post"]]
+    )
+    assert len(evaluations) == len(from_callback["evaluate_pre"])
+    assert set([e.get_specification().to_json() for e in evaluations]) == set(
+        [es.to_json() for es in from_callback["evaluate_pre"]]
+    )
+    assert len(evaluations) == len(from_callback["run_loop_pre"])
+    assert set([e.get_specification().to_json() for e in evaluations]) == set(
+        [es.to_json() for es in from_callback["run_loop_pre"]]
+    )
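
The assertions in the new test also pin down the callback contract: pre-evaluation callbacks receive the EvaluationSpecification that is about to be evaluated, while post-evaluation callbacks receive the completed Evaluation, on both the worker side (evaluate_specifications) and the loop side (run_optimization_loop). A minimal usage sketch outside of a test could look as follows, assuming the blackboxopt API exactly as shown in the diff; the import paths for SpaceFilling and Objective are not visible in these hunks and are assumptions, as is the ./bbo-shared directory:

from threading import Thread

from blackboxopt import Evaluation, EvaluationSpecification, Objective
from blackboxopt.optimization_loops import testing
from blackboxopt.optimization_loops.file_based_distributed import (
    evaluate_specifications,
    run_optimization_loop,
)
# Assumed import path; not shown in the diff hunks above.
from blackboxopt.optimizers.space_filling import SpaceFilling


def log_finished(evaluation: Evaluation) -> None:
    # Post-evaluation hook: called with each completed Evaluation.
    print("finished evaluation with objectives:", evaluation.objectives)


def eval_func(eval_spec: EvaluationSpecification) -> Evaluation:
    # Same toy objective as in the tests: minimize p1 squared.
    return eval_spec.create_evaluation({"loss": eval_spec.configuration["p1"] ** 2})


opt = SpaceFilling(
    testing.SPACE, objectives=[Objective("loss", greater_is_better=False)]
)

# Worker: consumes specifications from the shared directory and writes back results.
worker = Thread(
    target=evaluate_specifications,
    kwargs=dict(
        target_directory="./bbo-shared",  # illustrative path, not from this commit
        evaluation_function=eval_func,
        objectives=opt.objectives,
        max_evaluations=10,
        post_evaluation_callback=log_finished,
    ),
)
worker.start()

# Loop: proposes specifications into the same directory and collects the results.
evaluations = run_optimization_loop(
    optimizer=opt,
    target_directory="./bbo-shared",
    max_evaluations=10,
    post_evaluation_callback=log_finished,
)
worker.join()

In a real deployment the worker would typically be a separate process or machine sharing the target directory, which is also why the test compares callback results as sets rather than lists: completion order is not deterministic.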
