Skip to content

Commit

Permalink
dev: add evoxvision support to monitors
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Nov 22, 2024
1 parent 6d9b586 commit fa45e82
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/evox/monitors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .eval_monitor import EvalMonitor
from .pop_monitor import PopMonitor
from .evoxvis_monitor import EvoXVisMonitor
from .evoxvision_adapter import EvoXVisionAdapter
47 changes: 39 additions & 8 deletions src/evox/monitors/eval_monitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Tuple, Optional
import warnings
import time

Expand All @@ -8,6 +8,7 @@
from evox import Monitor, dataclass, pytree_field
from evox.vis_tools import plot
from evox.operators import non_dominated_sort
from .evoxvision_adapter import EvoXVisionAdapter, new_exv_metadata


@dataclass
Expand Down Expand Up @@ -56,6 +57,9 @@ class EvalMonitor(Monitor):
timestamp_history: list = pytree_field(
static=True, init=False, default_factory=list
)
evoxvision_adapter: Optional[EvoXVisionAdapter] = pytree_field(
default=None, static=True
)

def hooks(self):
return ["post_ask", "post_eval"]
Expand Down Expand Up @@ -110,20 +114,47 @@ def post_eval(self, state, _workflow_state, fitness):
if self.full_fit_history or self.full_sol_history:
return state.register_callback(
self._record_history,
state.latest_solution if self.full_sol_history else None,
fitness if self.full_fit_history else None,
state.latest_solution,
fitness,
)
else:
return state

def _record_history(self, solution, fitness):
# since history is a list, which doesn't have a static shape
# we need to use register_callback to record the history
if self.full_sol_history:
self.solution_history.append(solution)
if self.full_fit_history:
self.fitness_history.append(fitness)
self.timestamp_history.append(time.time())
if self.evoxvision_adapter:
if not self.evoxvision_adapter.header_written:
# wait for the first two iterations
self.solution_history.append(jax.device_get(solution))
self.fitness_history.append(jax.device_get(fitness))
if len(self.solution_history) >= 2:
metadata = new_exv_metadata(
self.solution_history[0],
self.solution_history[1],
self.fitness_history[0],
self.fitness_history[1],
)
self.evoxvision_adapter.set_metadata(metadata)
self.evoxvision_adapter.write_header()
self.evoxvision_adapter.write(
self.solution_history[0].tobytes(),
self.fitness_history[0].tobytes(),
)
self.evoxvision_adapter.write(
self.solution_history[1].tobytes(),
self.fitness_history[1].tobytes(),
)
self.solution_history = []
self.fitness_history = []
else:
self.evoxvision_adapter.write(solution.tobytes(), fitness.tobytes())
else:
if self.full_sol_history:
self.solution_history.append(jax.device_get(solution))
if self.full_fit_history:
self.fitness_history.append(jax.device_get(fitness))
self.timestamp_history.append(time.time())

def get_latest_fitness(self, state) -> Tuple[jax.Array, EvalMonitorState]:
"""Get the fitness values from the latest iteration."""
Expand Down
45 changes: 37 additions & 8 deletions src/evox/monitors/pop_monitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional
import warnings
import time

Expand All @@ -7,6 +7,7 @@

from evox import Monitor, pytree_field, dataclass
from evox.vis_tools import plot
from .evoxvision_adapter import EvoXVisionAdapter, new_exv_metadata


@dataclass
Expand Down Expand Up @@ -55,6 +56,9 @@ class PopMonitor(Monitor):
timestamp_history: list = pytree_field(
static=True, init=False, default_factory=list
)
evoxvision_adapter: Optional[EvoXVisionAdapter] = pytree_field(
default=None, static=True
)

def hooks(self):
return ["post_step"]
Expand Down Expand Up @@ -90,13 +94,38 @@ def post_step(self, state, workflow_state):
return state.register_callback(self._record_history, population, fitness)

def _record_history(self, population, fitness):
# since history is a list, which doesn't have a static shape
# we need to use register_callback to record the history
if self.full_pop_history:
self.population_history.append(population)
if self.full_fit_history:
self.fitness_history.append(fitness)
self.timestamp_history.append(time.time())
if self.evoxvision_adapter:
if not self.evoxvision_adapter.header_written:
# wait for the first two iterations
self.population_history.append(jax.device_get(population))
self.fitness_history.append(jax.device_get(fitness))
if len(self.population_history) >= 2:
metadata = new_exv_metadata(
self.population_history[0],
self.population_history[1],
self.fitness_history[0],
self.fitness_history[1],
)
self.evoxvision_adapter.set_metadata(metadata)
self.evoxvision_adapter.write_header()
self.evoxvision_adapter.write(
self.population_history[0].tobytes(),
self.fitness_history[0].tobytes(),
)
self.evoxvision_adapter.write(
self.population_history[1].tobytes(),
self.fitness_history[1].tobytes(),
)
self.population_history = []
self.fitness_history = []
else:
self.evoxvision_adapter.write(population.tobytes(), fitness.tobytes())
else:
if self.full_pop_history:
self.population_history.append(population)
if self.full_fit_history:
self.fitness_history.append(fitness)
self.timestamp_history.append(time.time())

def plot(self, state=None, problem_pf=None, **kwargs):
if not self.fitness_history:
Expand Down

0 comments on commit fa45e82

Please sign in to comment.