diff --git a/src/evox/operators/crossover/differential_evolution.py b/src/evox/operators/crossover/differential_evolution.py
index 144454561..7da78094b 100644
--- a/src/evox/operators/crossover/differential_evolution.py
+++ b/src/evox/operators/crossover/differential_evolution.py
@@ -25,11 +25,11 @@ def DE_differential_sum(
 
     select_len = num_diff_vectors.unsqueeze(1) * 2 + 1
     rand_indices = torch.randint(0, pop_size, (pop_size, diff_padding_num), device=device)
-    rand_indices = torch.where(rand_indices == index.unsqueeze(1), torch.tensor(pop_size - 1, device=device), rand_indices)
+    rand_indices = torch.where(rand_indices == index.unsqueeze(1), pop_size - 1, rand_indices)
 
     pop_permute = population[rand_indices]
     mask = torch.arange(diff_padding_num, device=device).unsqueeze(0) < select_len
-    pop_permute_padding = torch.where(mask.unsqueeze(2), pop_permute, torch.zeros_like(pop_permute))
+    pop_permute_padding = torch.where(mask.unsqueeze(2), pop_permute, 0)
 
     diff_vectors = pop_permute_padding[:, 1:]
     difference_sum = diff_vectors[:, 0::2].sum(dim=1) - diff_vectors[:, 1::2].sum(dim=1)
diff --git a/src/evox/problems/neuroevolution/brax.py b/src/evox/problems/neuroevolution/brax.py
index b6658669a..adfbfe75c 100644
--- a/src/evox/problems/neuroevolution/brax.py
+++ b/src/evox/problems/neuroevolution/brax.py
@@ -11,7 +11,7 @@ from brax import envs
 from brax.io import html, image
 
-from ...core import Problem, jit_class
+from ...core import Problem, _vmap_fix, jit_class, vmap_impl
 from .utils import get_vmap_model_state_forward
 
@@ -70,7 +70,8 @@ def __init__(
         The initial key is obtained from `torch.random.get_rng_state()`.
 
         ## Warning
-        This problem does NOT support HPO wrapper (`problems.hpo_wrapper.HPOProblemWrapper`), i.e., the workflow containing this problem CANNOT be vmapped.
+        This problem does NOT support the HPO wrapper (`problems.hpo_wrapper.HPOProblemWrapper`) out of the box, i.e., the workflow containing this problem CANNOT be vmapped directly.
+        *However*, by setting `pop_size` to the product of the inner population size and the outer population size, you can still use this problem in an HPO workflow.
 
         ## Examples
         >>> from evox import problems
@@ -137,13 +138,7 @@ def __init__(
         self.rotate_key = rotate_key
         self.reduce_fn = reduce_fn
 
-    def evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
-        """Evaluate the final rewards of a population (batch) of model parameters.
-
-        :param pop_params: A dictionary of parameters where each key is a parameter name and each value is a tensor of shape (batch_size, *param_shape) representing the batched parameters of batched models.
-
-        :return: A tensor of shape (batch_size,) containing the reward of each sample in the population.
-        """
+    def _normal_evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
         # Unpack parameters and buffers
         state_params = {self._param_to_state_key_map[key]: value for key, value in pop_params.items()}
         model_state = dict(self._vmap_model_buffers)
@@ -157,6 +152,25 @@ def evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
         # Return
         return rewards
 
+    def evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
+        """Evaluate the final rewards of a population (batch) of model parameters.
+
+        :param pop_params: A dictionary of parameters where each key is a parameter name and each value is a tensor of shape (batch_size, *param_shape) representing the batched parameters of batched models.
+
+        :return: A tensor of shape (batch_size,) containing the reward of each sample in the population.
+        """
+        return self._normal_evaluate(pop_params)
+
+    @vmap_impl(evaluate)
+    def _vmap_evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
+        _, vmap_dim, vmap_size = _vmap_fix.unwrap_batch_tensor(list(pop_params.values())[0])
+        assert vmap_dim == (0,)
+        vmap_size = vmap_size[0]
+        pop_params = {k: _vmap_fix.unwrap_batch_tensor(v)[0].view(vmap_size * v.size(0), *v.size()[1:]) for k, v in pop_params.items()}
+        flat_rewards = self._normal_evaluate(pop_params)
+        rewards = flat_rewards.view(vmap_size, flat_rewards.size(0) // vmap_size, *flat_rewards.size()[1:])
+        return _vmap_fix.wrap_batch_tensor(rewards, vmap_dim)
+
     def _model_forward(
         self, model_state: Dict[str, torch.Tensor], obs: torch.Tensor, record_trajectory: bool = False
     ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: