
Commit 261e007
Merge pull request #212 from starquakee/evoxtorch-dev-fcc
Brax now supports single-layer vmap (HPO problem)
BillHuang2001 authored Feb 11, 2025
2 parents acb6783 + d5cf36b commit 261e007
Showing 2 changed files with 25 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/evox/operators/crossover/differential_evolution.py
@@ -25,11 +25,11 @@ def DE_differential_sum(

     select_len = num_diff_vectors.unsqueeze(1) * 2 + 1
     rand_indices = torch.randint(0, pop_size, (pop_size, diff_padding_num), device=device)
-    rand_indices = torch.where(rand_indices == index.unsqueeze(1), torch.tensor(pop_size - 1, device=device), rand_indices)
+    rand_indices = torch.where(rand_indices == index.unsqueeze(1), pop_size - 1, rand_indices)

     pop_permute = population[rand_indices]
     mask = torch.arange(diff_padding_num, device=device).unsqueeze(0) < select_len
-    pop_permute_padding = torch.where(mask.unsqueeze(2), pop_permute, torch.zeros_like(pop_permute))
+    pop_permute_padding = torch.where(mask.unsqueeze(2), pop_permute, 0)

     diff_vectors = pop_permute_padding[:, 1:]
     difference_sum = diff_vectors[:, 0::2].sum(dim=1) - diff_vectors[:, 1::2].sum(dim=1)
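Both replaced calls rely on torch.where accepting a plain Python scalar for its input/other operands (supported in recent PyTorch releases), which avoids allocating a tensor just to hold a constant. A minimal stand-alone sketch of the same pattern, with hypothetical sizes:

    import torch

    pop_size, diff_padding_num = 5, 6
    index = torch.arange(pop_size)
    rand_indices = torch.randint(0, pop_size, (pop_size, diff_padding_num))

    # A Python scalar serves as the "other" operand directly,
    # replacing torch.tensor(pop_size - 1, device=device).
    no_self = torch.where(rand_indices == index.unsqueeze(1), pop_size - 1, rand_indices)

    # Likewise, a scalar 0 broadcasts in place of torch.zeros_like(...).
    mask = torch.rand(pop_size, diff_padding_num) > 0.5
    padded = torch.where(mask, no_self.float(), 0)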
32 changes: 23 additions & 9 deletions src/evox/problems/neuroevolution/brax.py
@@ -11,7 +11,7 @@
 from brax import envs
 from brax.io import html, image

-from ...core import Problem, jit_class
+from ...core import Problem, _vmap_fix, jit_class, vmap_impl
 from .utils import get_vmap_model_state_forward


@@ -70,7 +70,8 @@ def __init__(
         The initial key is obtained from `torch.random.get_rng_state()`.

         ## Warning
-        This problem does NOT support HPO wrapper (`problems.hpo_wrapper.HPOProblemWrapper`), i.e., the workflow containing this problem CANNOT be vmapped.
+        This problem does NOT support the HPO wrapper (`problems.hpo_wrapper.HPOProblemWrapper`) out of the box, i.e., the workflow containing this problem CANNOT be vmapped directly.
+        *However*, by setting `pop_size` to the product of the inner and outer population sizes, you can still use this problem in an HPO workflow.

         ## Examples
         >>> from evox import problems
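Concretely, the workaround described in the added warning might look like the sketch below. The class name and every constructor argument except pop_size are illustrative assumptions (the diff itself names only pop_size), so treat the exact signature as hypothetical:

    from evox import problems

    outer_pop_size = 4  # hyperparameter configurations evaluated by the HPO layer
    inner_pop_size = 8  # policies evaluated per configuration

    # Hypothetical constructor call: size the Brax problem for the flattened
    # population so the outer vmap dimension can be folded into pop_size.
    problem = problems.neuroevolution.BraxProblem(
        policy=policy_model,  # assumed to be a torch.nn.Module defined elsewhere
        env_name="swimmer",
        max_episode_length=100,
        num_episodes=3,
        pop_size=outer_pop_size * inner_pop_size,
    )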
@@ -137,13 +138,7 @@ def __init__(
         self.rotate_key = rotate_key
         self.reduce_fn = reduce_fn

-    def evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
-        """Evaluate the final rewards of a population (batch) of model parameters.
-
-        :param pop_params: A dictionary of parameters where each key is a parameter name and each value is a tensor of shape (batch_size, *param_shape) representing the batched parameters of batched models.
-
-        :return: A tensor of shape (batch_size,) containing the reward of each sample in the population.
-        """
+    def _normal_evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
         # Unpack parameters and buffers
         state_params = {self._param_to_state_key_map[key]: value for key, value in pop_params.items()}
         model_state = dict(self._vmap_model_buffers)
@@ -157,6 +152,25 @@ def evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
         # Return
         return rewards

+    def evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
+        """Evaluate the final rewards of a population (batch) of model parameters.
+
+        :param pop_params: A dictionary of parameters where each key is a parameter name and each value is a tensor of shape (batch_size, *param_shape) representing the batched parameters of batched models.
+
+        :return: A tensor of shape (batch_size,) containing the reward of each sample in the population.
+        """
+        return self._normal_evaluate(pop_params)
+
+    @vmap_impl(evaluate)
+    def _vmap_evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
+        _, vmap_dim, vmap_size = _vmap_fix.unwrap_batch_tensor(list(pop_params.values())[0])
+        assert vmap_dim == (0,)
+        vmap_size = vmap_size[0]
+        pop_params = {k: _vmap_fix.unwrap_batch_tensor(v)[0].view(vmap_size * v.size(0), *v.size()[1:]) for k, v in pop_params.items()}
+        flat_rewards = self._normal_evaluate(pop_params)
+        rewards = flat_rewards.view(vmap_size, flat_rewards.size(0) // vmap_size, *flat_rewards.size()[1:])
+        return _vmap_fix.wrap_batch_tensor(rewards, vmap_dim)
+
     def _model_forward(
         self, model_state: Dict[str, torch.Tensor], obs: torch.Tensor, record_trajectory: bool = False
     ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
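The added `_vmap_evaluate` override implements a flatten-then-reshape trick: the outer vmap batch dimension is merged into the population dimension, the flat population is evaluated once by `_normal_evaluate`, and the resulting rewards are split back and re-wrapped as a batched tensor. A minimal stand-alone sketch of the same idea in plain PyTorch, with hypothetical sizes and a dummy reward function standing in for `_normal_evaluate`:

    import torch

    vmap_size, inner_pop, param_dim = 4, 8, 16              # hypothetical sizes
    params = torch.randn(vmap_size, inner_pop, param_dim)   # (outer, inner, *param_shape)

    # Merge the outer vmap dimension into the population dimension ...
    flat_params = params.view(vmap_size * inner_pop, param_dim)

    # ... evaluate once over the flat population (dummy reward here) ...
    flat_rewards = flat_params.sum(dim=-1)

    # ... then split the rewards back into one row per vmap slice.
    rewards = flat_rewards.view(vmap_size, flat_rewards.size(0) // vmap_size)
    assert rewards.shape == (vmap_size, inner_pop)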
