-
Notifications
You must be signed in to change notification settings - Fork 100
fix de.py selection bug and add ode.py #164
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
sses7757
merged 13 commits into
EMI-Group:evoxtorch-main
from
starquakee:evoxtorch-dev-fcc
Jan 11, 2025
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
4837273
fix de.py selection bug and add ode.py
starquakee e85bc60
black de.py and ode.py
starquakee d3232ab
docstring is added for de.py and ode.py, as well as TODO which can be…
starquakee 916ea98
merge upstream/evoxtorch-main to resolve conflicts and fix docstring …
starquakee 211355d
ruff format de.py and ode.py
starquakee 4284c69
Merge branch 'evoxtorch-main' of https://github.com/EMI-Group/evox in…
starquakee 3af18d9
Merge branch 'evoxtorch-main' of https://github.com/EMI-Group/evox in…
starquakee 4d84996
Merge branch 'evoxtorch-main' of https://github.com/EMI-Group/evox in…
starquakee 2ded7b5
add ode in __init__.py
starquakee 33ef291
Merge branch 'evoxtorch-main' of https://github.com/EMI-Group/evox in…
starquakee ffba734
add ode in __init__.py
starquakee 9c05b12
ruff format de.py and ode.py
starquakee eb710a8
add ode in init file
starquakee File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
__all__ = ["DE"] | ||
__all__ = ["DE", "ODE"] | ||
|
||
|
||
from .de import DE | ||
from .ode import ODE |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
from typing import Literal | ||
|
||
import torch | ||
|
||
from ...core import Algorithm, Mutable, Parameter, jit_class | ||
from ...utils import clamp | ||
|
||
|
||
@jit_class | ||
class ODE(Algorithm): | ||
""" | ||
Opposition-based Differential Evolution (ODE) algorithm for optimization. | ||
|
||
## Class Methods | ||
|
||
* `__init__`: Initializes the ODE algorithm with the given parameters, including population size, bounds, mutation strategy, and other hyperparameters. | ||
* `init_step`: Performs the initial evaluation of the population's fitness and proceeds to the first optimization step. | ||
* `step`: Executes a single optimization step of the ODE algorithm, involving mutation, crossover, selection, and opposition-based mechanisms. | ||
|
||
Note that the `evaluate` method is not defined in this class. It is expected to be provided by the `Problem` class or another external component. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
pop_size: int, | ||
lb: torch.Tensor, | ||
ub: torch.Tensor, | ||
base_vector: Literal["best", "rand"] = "rand", | ||
num_difference_vectors: int = 1, | ||
differential_weight: float | torch.Tensor = 0.5, | ||
cross_probability: float = 0.9, | ||
mean: torch.Tensor | None = None, | ||
stdev: torch.Tensor | None = None, | ||
device: torch.device | None = None, | ||
): | ||
""" | ||
Initialize the Opposition-based Differential Evolution (ODE) algorithm with the given parameters. | ||
|
||
:param pop_size: The size of the population. | ||
:param lb: The lower bounds of the particle positions. Must be a 1D tensor. | ||
:param ub: The upper bounds of the particle positions. Must be a 1D tensor. | ||
:param base_vector: The base vector type used in mutation. Either "best" or "rand". Defaults to "rand". | ||
:param num_difference_vectors: The number of difference vectors used in mutation. Must be at least 1 and less than half of the population size. Defaults to 1. | ||
:param differential_weight: The differential weight(s) (F) applied to difference vectors. Can be a float or a tensor of shape [num_difference_vectors]. Defaults to 0.5. | ||
:param cross_probability: The crossover probability (CR). Must be in (0, 1]. Defaults to 0.9. | ||
:param mean: The mean for initializing the population with a normal distribution. Must be provided with `stdev` if used. Defaults to None. | ||
:param stdev: The standard deviation for initializing the population with a normal distribution. Must be provided with `mean` if used. Defaults to None. | ||
:param device: The device to use for tensor computations. Defaults to None. | ||
""" | ||
super().__init__() | ||
device = torch.get_default_device() if device is None else device | ||
|
||
# Validate input parameters | ||
assert pop_size >= 4 | ||
assert 0 < cross_probability <= 1 | ||
assert 1 <= num_difference_vectors < pop_size // 2 | ||
assert base_vector in ["rand", "best"] | ||
assert lb.shape == ub.shape and lb.ndim == 1 and ub.ndim == 1 and lb.dtype == ub.dtype | ||
|
||
# Initialize parameters | ||
self.pop_size = pop_size | ||
self.dim = lb.shape[0] | ||
self.best_vector = base_vector == "best" | ||
self.num_difference_vectors = num_difference_vectors | ||
|
||
# Validate and set differential weight | ||
if num_difference_vectors == 1: | ||
assert isinstance(differential_weight, float) | ||
else: | ||
assert isinstance(differential_weight, torch.Tensor) and differential_weight.shape == torch.Size( | ||
[num_difference_vectors] | ||
) | ||
self.differential_weight = Parameter(differential_weight, device=device) | ||
self.cross_probability = Parameter(cross_probability, device=device) | ||
|
||
# Move bounds to the specified device and add batch dimension | ||
lb = lb[None, :].to(device=device) | ||
ub = ub[None, :].to(device=device) | ||
self.lb = lb | ||
self.ub = ub | ||
|
||
# Initialize population | ||
if mean is not None and stdev is not None: | ||
# Initialize population using a normal distribution | ||
population = mean + stdev * torch.randn(self.pop_size, self.dim, device=device) | ||
population = clamp(population, lb=self.lb, ub=self.ub) | ||
else: | ||
# Initialize population uniformly within bounds | ||
population = torch.rand(self.pop_size, self.dim, device=device) | ||
population = population * (self.ub - self.lb) + self.lb | ||
|
||
# Mutable attributes to store population and fitness | ||
self.population = Mutable(population) | ||
self.fitness = Mutable(torch.empty(self.pop_size, device=device).fill_(float("inf"))) | ||
|
||
def init_step(self): | ||
""" | ||
Perform the initial evaluation of the population's fitness and proceed to the first optimization step. | ||
|
||
This method evaluates the fitness of the initial population and then calls the `step` method to perform the first optimization iteration. | ||
""" | ||
self.fitness = self.evaluate(self.population) | ||
self.step() | ||
|
||
def step(self): | ||
""" | ||
Execute a single optimization step of the ODE algorithm. | ||
|
||
This involves the following sub-steps: | ||
1. Mutation: Generate mutant vectors based on the specified base vector strategy (`best` or `rand`) and the number of difference vectors. | ||
2. Crossover: Perform crossover between the current population and the mutant vectors based on the crossover probability. | ||
3. Selection: Evaluate the fitness of the new population and select the better individuals between the current and new populations. | ||
4. Opposition-Based Mechanism: Generate opposition-based population, evaluate their fitness, and perform selection to potentially replace current individuals with their opposites if they are better. | ||
|
||
The method ensures that all new population vectors are clamped within the specified bounds. | ||
""" | ||
device = self.population.device | ||
num_vec = self.num_difference_vectors * 2 + (0 if self.best_vector else 1) | ||
random_choices = [] | ||
|
||
# Mutation: Generate random permutations for selecting vectors | ||
# TODO: Currently allows replacement for different vectors, which is not equivalent to the original implementation | ||
# TODO: Consider changing to an implementation based on reservoir sampling (e.g., https://github.com/LeviViana/torch_sampling) in the future | ||
for _ in range(num_vec): | ||
random_choices.append(torch.randperm(self.pop_size, device=device)) | ||
|
||
# Determine the base vector | ||
if self.best_vector: | ||
# Use the best individual as the base vector | ||
best_index = torch.argmin(self.fitness) | ||
base_vector = self.population[best_index][None, :] | ||
start_index = 0 | ||
else: | ||
# Use randomly selected individuals as base vectors | ||
base_vector = self.population[random_choices[0]] | ||
start_index = 1 | ||
|
||
# Generate difference vectors by subtracting randomly chosen population vectors | ||
difference_vector = torch.stack( | ||
[ | ||
self.population[random_choices[i]] - self.population[random_choices[i + 1]] | ||
for i in range(start_index, num_vec - 1, 2) | ||
] | ||
).sum(dim=0) | ||
|
||
# Create mutant vectors by adding weighted difference vectors to the base vector | ||
new_population = base_vector + self.differential_weight * difference_vector | ||
|
||
# Crossover: Determine which dimensions to crossover based on the crossover probability | ||
cross_prob = torch.rand(self.pop_size, self.dim, device=device) | ||
random_dim = torch.randint(0, self.dim, (self.pop_size, 1), device=device) | ||
mask = cross_prob < self.cross_probability | ||
mask = mask.scatter(dim=1, index=random_dim, value=1) | ||
new_population = torch.where(mask, new_population, self.population) | ||
|
||
# Ensure new population is within bounds | ||
new_population = clamp(new_population, self.lb, self.ub) | ||
|
||
# Selection: Evaluate fitness of the new population and select the better individuals | ||
new_fitness = self.evaluate(new_population) | ||
compare = new_fitness < self.fitness | ||
self.population = torch.where(compare[:, None], new_population, self.population) | ||
self.fitness = torch.where(compare, new_fitness, self.fitness) | ||
|
||
# Opposition-Based Population: Generate opposite solutions | ||
opposition_population = self.lb + self.ub - self.population | ||
|
||
# Opposition-Based Selection: Evaluate fitness of the opposition population | ||
opposition_fitness = self.evaluate(opposition_population) | ||
compare_opposition = opposition_fitness < self.fitness | ||
|
||
# Replace individuals with their opposites if the opposites are better | ||
updated_population = torch.where(compare_opposition[:, None], opposition_population, self.population) | ||
updated_fitness = torch.where(compare_opposition, opposition_fitness, self.fitness) | ||
|
||
# Update population and fitness with opposition-based selections | ||
self.population = updated_population | ||
self.fitness = updated_fitness |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.