dev: when using multi devices, place sharding constraint on the fitness value
BillHuang2001 committed Nov 25, 2024
1 parent 1d435b3 commit df6d561
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/evox/workflows/std_workflow.py
@@ -1,6 +1,6 @@
 from collections.abc import Callable, Sequence
 import dataclasses
-from typing import NamedTuple, Optional, Union, Any
+from typing import NamedTuple, Optional, Union
 import warnings
 
 import jax
@@ -39,7 +39,6 @@ class StdWorkflowState:
 
 class MultiDeviceConfig(NamedTuple):
     devices: list[jax.Device]
-    sharding: Any
     axis_name: str
 
 
@@ -160,7 +159,10 @@ def _step(self, state):
             # when using multi devices
             # force the candidates to be sharded along the first axis
             cands = jax.lax.with_sharding_constraint(
-                cands, self.multi_device_config.sharding
+                cands,
+                ShardingType.SHARED_FIRST_DIM.get_sharding(
+                    self.multi_device_config.devices
+                ),
             )
 
         state = self._post_ask_hook(state, cands)
@@ -171,6 +173,17 @@ def _step(self, state):
 
         state = self._pre_eval_hook(state, transformed_cands)
         fitness, state = self._evaluate(state, transformed_cands)
+
+        if self.multi_device_config:
+            # when using multi devices
+            # force the fitness to be replicated
+            fitness = jax.lax.with_sharding_constraint(
+                fitness,
+                ShardingType.REPLICATED.get_sharding(
+                    self.multi_device_config.devices
+                ),
+            )
+
         state = self._post_eval_hook(state, fitness)
 
         transformed_fitness = fitness
@@ -357,7 +370,6 @@ def enable_multi_devices(
         """
         self.multi_device_config = MultiDeviceConfig(
             devices=devices,
-            sharding=ShardingType.SHARED_FIRST_DIM.get_sharding(devices),
             axis_name=POP_AXIS_NAME,
         )
         if not devices:
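
Note: the sketch below is a minimal, standalone approximation of the constraint pattern this commit applies, written against plain JAX sharding APIs rather than evox internals. ShardingType.SHARED_FIRST_DIM and ShardingType.REPLICATED are evox helpers; the NamedSharding objects here are assumed analogues of what they produce. The idea is that candidates are split along the population axis so each device evaluates its own slice, while the fitness is replicated so every device holds the full vector for subsequent steps such as hooks and selection.

    # Minimal sketch (not the evox implementation) of the sharding pattern,
    # using standard JAX APIs only.
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    devices = jax.devices()
    mesh = Mesh(devices, axis_names=("pop",))

    # Assumed analogue of ShardingType.SHARED_FIRST_DIM.get_sharding(devices):
    # shard axis 0 (the population axis) across devices.
    shard_first_dim = NamedSharding(mesh, P("pop"))
    # Assumed analogue of ShardingType.REPLICATED.get_sharding(devices):
    # keep a full copy on every device.
    replicated = NamedSharding(mesh, P())

    @jax.jit
    def step(cands):
        # Candidates: each device holds one slice of the population.
        cands = jax.lax.with_sharding_constraint(cands, shard_first_dim)
        fitness = jnp.sum(cands**2, axis=1)  # stand-in for the real evaluator
        # Fitness: replicate so every device sees the whole vector.
        return jax.lax.with_sharding_constraint(fitness, replicated)

    fitness = step(jnp.ones((len(devices) * 4, 8)))

Without the replication constraint, the compiler is free to leave the fitness sharded like the candidates, which is why the commit pins it explicitly after _evaluate.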
