Skip to content

Commit

Permalink
test: add unit test for popmonitor
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Apr 23, 2024
1 parent b35dce8 commit 24d9d2e
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion tests/test_monitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp
import pytest
from evox import workflows, algorithms, problems
from evox.monitors import StdSOMonitor, StdMOMonitor, EvalMonitor
from evox.monitors import StdSOMonitor, StdMOMonitor, EvalMonitor, PopMonitor


@pytest.mark.parametrize("topk", [1, 2, 4])
Expand Down Expand Up @@ -84,3 +84,17 @@ def test_eval_monitor_with_so(full_fit_history, full_sol_history, topk):
assert (monitor.get_topk_fitness() == fitness2[-topk:][::-1]).all()
assert (monitor.get_best_solution() == pop2[-1]).all()
assert (monitor.get_topk_solutions() == pop2[-topk:][::-1]).all()


@pytest.mark.parametrize("fitness_only", [True, False])
def test_pop_monitor(fitness_only):
monitor = PopMonitor(fitness_only=fitness_only)
algorithm = algorithms.CSO(lb=jnp.zeros((5,)), ub=jnp.ones((5,)), pop_size=4)
problem = problems.numerical.Sphere()
workflow = workflows.StdWorkflow(algorithm, problem, monitors=[monitor])
key = jax.random.PRNGKey(0)
state = workflow.init(key)
state = workflow.step(state)
assert (monitor.get_latest_fitness() == state.get_child_state("algorithm").fitness).all()
if not fitness_only:
assert (monitor.get_latest_population() == state.get_child_state("algorithm").population).all()

0 comments on commit 24d9d2e

Please sign in to comment.