Skip to content

Commit

Permalink
fix: error in pop monitor
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Apr 23, 2024
1 parent 2092e7b commit b35dce8
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions src/evox/monitors/pop_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import jax
import jax.numpy as jnp
from jax.experimental import io_callback
from jax.sharding import SingleDeviceSharding

from evox import Monitor
from evox.vis_tools import plot
Expand Down Expand Up @@ -37,15 +39,11 @@ def __init__(
self,
population_name="population",
fitness_name="fitness",
to_host=False,
fitness_only=False,
):
super().__init__()
self.population_name = population_name
self.fitness_name = fitness_name
self.to_host = to_host
if to_host:
self.host = jax.devices("cpu")[0]
self.population_history = []
self.fitness_history = []
self.fitness_only = fitness_only
Expand All @@ -54,17 +52,26 @@ def hooks(self):
return ["post_step"]

def post_step(self, state):
monitor_device = SingleDeviceSharding(jax.devices()[0])
if not self.fitness_only:
population = getattr(
state.get_child_state("algorithm"), self.population_name
)
if self.to_host:
population = jax.device_put(population, self.host)
self.population_history.append(population)
else:
population = None

fitness = getattr(state.get_child_state("algorithm"), self.fitness_name)
if self.to_host:
fitness = jax.device_put(fitness, self.host)
io_callback(
self._record,
None,
population,
fitness,
sharding=monitor_device,
)

def _record(self, population, fitness):
if population is not None:
self.population_history.append(population)
self.fitness_history.append(fitness)

def plot(self, problem_pf=None, **kwargs):
Expand All @@ -86,6 +93,12 @@ def plot(self, problem_pf=None, **kwargs):
else:
warnings.warn("Not supported yet.")

def get_latest_fitness(self):
return self.fitness_history[-1]

def get_latest_population(self):
return self.population_history[-1]

def get_population_history(self):
return self.population_history

Expand Down

0 comments on commit b35dce8

Please sign in to comment.