diff --git a/tests/test_monitors.py b/tests/test_monitors.py index 3840ed122..5ffeafb39 100644 --- a/tests/test_monitors.py +++ b/tests/test_monitors.py @@ -2,7 +2,7 @@ import jax.numpy as jnp import pytest from evox import workflows, algorithms, problems -from evox.monitors import StdSOMonitor, StdMOMonitor, EvalMonitor +from evox.monitors import StdSOMonitor, StdMOMonitor, EvalMonitor, PopMonitor @pytest.mark.parametrize("topk", [1, 2, 4]) @@ -84,3 +84,17 @@ def test_eval_monitor_with_so(full_fit_history, full_sol_history, topk): assert (monitor.get_topk_fitness() == fitness2[-topk:][::-1]).all() assert (monitor.get_best_solution() == pop2[-1]).all() assert (monitor.get_topk_solutions() == pop2[-topk:][::-1]).all() + + +@pytest.mark.parametrize("fitness_only", [True, False]) +def test_pop_monitor(fitness_only): + monitor = PopMonitor(fitness_only=fitness_only) + algorithm = algorithms.CSO(lb=jnp.zeros((5,)), ub=jnp.ones((5,)), pop_size=4) + problem = problems.numerical.Sphere() + workflow = workflows.StdWorkflow(algorithm, problem, monitors=[monitor]) + key = jax.random.PRNGKey(0) + state = workflow.init(key) + state = workflow.step(state) + assert (monitor.get_latest_fitness() == state.get_child_state("algorithm").fitness).all() + if not fitness_only: + assert (monitor.get_latest_population() == state.get_child_state("algorithm").population).all()