diff --git a/src/evox/workflows/distributed.py b/src/evox/workflows/distributed.py index 05a4bcb5..dc63b013 100644 --- a/src/evox/workflows/distributed.py +++ b/src/evox/workflows/distributed.py @@ -6,8 +6,16 @@ import jax.numpy as jnp import ray -from evox import Algorithm, Problem, State, Workflow, use_state -from evox.utils import algorithm_has_init_ask, parse_opt_direction +from evox import ( + Algorithm, + Problem, + State, + Workflow, + use_state, + has_init_ask, + has_init_tell, +) +from evox.utils import parse_opt_direction class WorkerWorkflow(Workflow): @@ -36,7 +44,7 @@ def __init__( self.fit_transforms = fit_transforms def setup(self, key): - return State(generation=0) + return State(generation=0, first_step=True) def _get_slice(self, pop_size): slice_per_worker = pop_size // self.num_workers @@ -45,20 +53,23 @@ def _get_slice(self, pop_size): end = start + slice_per_worker + (self.worker_index < remainder) return start, end + def _ask(self, state): + if has_init_ask(self.algorithm) and state.first_step: + ask = self.algorithm.init_ask + else: + ask = self.algorithm.ask + + # candidate: individuals that need to be evaluated (may differ from population) + # Note: num_cands can be different from init_ask() and ask() + cands, state = use_state(ask)(state) + + return cands, state + def step1(self, state: State): if "pre_ask" in self.non_empty_hooks: ray.get(self.monitor_actor.push.remote("pre_ask", state)) - if state.generation == 0: - is_init = algorithm_has_init_ask(self.algorithm, state) - else: - is_init = False - - if is_init: - cand_sol, state = use_state(self.algorithm.init_ask)(state) - else: - cand_sol, state = use_state(self.algorithm.ask)(state) - + cand_sol, state = self._ask(state) if "post_ask" in self.non_empty_hooks: ray.get(self.monitor_actor.push.remote("post_ask", None, cand_sol)) @@ -82,12 +93,17 @@ def step1(self, state: State): return partial_fitness, state - def step2(self, state: State, fitness: List[jax.Array]): - if state.generation == 0: - is_init = algorithm_has_init_ask(self.algorithm, state) + def _tell(self, state, transformed_fitness): + if has_init_tell(self.algorithm) and state.first_step: + tell = self.algorithm.init_tell else: - is_init = False + tell = self.algorithm.tell + state = use_state(tell)(state, transformed_fitness) + + return state + + def step2(self, state: State, fitness: List[jax.Array]): fitness = jnp.concatenate(fitness, axis=0) fitness = fitness * self.opt_direction @@ -112,15 +128,19 @@ def step2(self, state: State, fitness: List[jax.Array]): ) ) - if is_init: - state = use_state(self.algorithm.init_tell)(state, fitness) - else: - state = use_state(self.algorithm.tell)(state, fitness) + state = self._tell(state, fitness) if "post_tell" in self.non_empty_hooks: ray.get(self.monitor_actor.push.remote("post_tell", state)) - - return state.update(generation=state.generation + 1) + + + if has_init_ask(self.algorithm) and state.first_step: + # this ensures that _step() will be re-jitted + state = state.replace(generation=state.generation + 1, first_step=False) + else: + state = state.replace(generation=state.generation + 1) + + return state def valid(self, state: State, metric: str): new_state = use_state(self.problem.valid)(state, metric=metric)