Skip to content

Commit a970f62

Browse files
Merge pull request #141 from EMI-Group/dev3-lzy
Fix bugs of MOEA/D
2 parents 9527dca + 948a490 commit a970f62

File tree

5 files changed

+102
-112
lines changed

5 files changed

+102
-112
lines changed

src/evox/algorithms/mo/eagmoead.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
non_dominated_sort,
2424
crowding_distance,
2525
)
26-
from evox.operators.sampling import LatinHypercubeSampling
27-
from evox.utils import pairwise_euclidean_dist
26+
from evox.operators.sampling import LatinHypercubeSampling, UniformSampling
27+
from evox.utils import pairwise_euclidean_dist, AggregationFunction
2828

2929

3030
@partial(jax.jit, static_argnums=[1])
@@ -76,6 +76,7 @@ def __init__(
7676
if self.crossover is None:
7777
self.crossover = crossover.SimulatedBinary(type=2)
7878
self.sample = LatinHypercubeSampling(self.pop_size, self.n_objs)
79+
self.aggregate_func = AggregationFunction("weighted_sum")
7980

8081
def setup(self, key):
8182
key, subkey1, subkey2 = jax.random.split(key, 3)
@@ -160,10 +161,10 @@ def tell(self, state, fitness):
160161

161162
def body_fun(i, vals):
162163
population, pop_obj = vals
163-
g_old = jnp.sum(
164-
pop_obj[B[offspring_loc[i], :]] * w[B[offspring_loc[i], :]], axis=1
164+
g_old = self.aggregate_func(
165+
pop_obj[B[offspring_loc[i], :]], w[B[offspring_loc[i], :]]
165166
)
166-
g_new = w[B[offspring_loc[i], :]] @ jnp.transpose(offspring_obj[i])
167+
g_new = self.aggregate_func(offspring_obj[i], w[B[offspring_loc[i], :]])
167168
idx = B[offspring_loc[i]]
168169
g_new = g_new[:, jnp.newaxis]
169170
g_old = g_old[:, jnp.newaxis]

src/evox/algorithms/mo/moead.py

Lines changed: 43 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,6 @@
33
#
44
# Title: MOEA/D: A Multiobjective Evolutionary Algorithm Based on Decomposition
55
# Link: https://ieeexplore.ieee.org/document/4358754
6-
#
7-
# 2. This code has been inspired by PlatEMO.
8-
# More information about PlatEMO can be found at the following URL:
9-
# GitHub Link: https://github.com/BIMK/PlatEMO
106
# --------------------------------------------------------------------------------------
117

128
import math
@@ -17,8 +13,7 @@
1713
from evox import Algorithm, State, jit_class
1814
from evox.operators import crossover, mutation
1915
from evox.operators.sampling import UniformSampling
20-
from evox.utils import pairwise_euclidean_dist
21-
16+
from evox.utils import pairwise_euclidean_dist, AggregationFunction
2217

2318
@jit_class
2419
class MOEAD(Algorithm):
@@ -33,7 +28,7 @@ def __init__(
3328
ub,
3429
n_objs,
3530
pop_size,
36-
type=1,
31+
func_name='pbi',
3732
mutation_op=None,
3833
crossover_op=None,
3934
):
@@ -42,8 +37,8 @@ def __init__(
4237
self.n_objs = n_objs
4338
self.dim = lb.shape[0]
4439
self.pop_size = pop_size
45-
self.type = type
46-
self.T = 0
40+
self.func_name = func_name
41+
self.n_neighbor = 0
4742

4843
self.mutation = mutation_op
4944
self.crossover = crossover_op
@@ -53,54 +48,57 @@ def __init__(
5348
if self.crossover is None:
5449
self.crossover = crossover.SimulatedBinary(type=2)
5550
self.sample = UniformSampling(self.pop_size, self.n_objs)
51+
self.aggregate_func = AggregationFunction(self.func_name)
5652

5753
def setup(self, key):
5854
key, subkey1, subkey2 = jax.random.split(key, 3)
5955
w, _ = self.sample(subkey2)
6056
self.pop_size = w.shape[0]
61-
self.T = int(math.ceil(self.pop_size / 10))
57+
self.n_neighbor = int(math.ceil(self.pop_size / 10))
6258

6359
population = (
6460
jax.random.uniform(subkey1, shape=(self.pop_size, self.dim))
6561
* (self.ub - self.lb)
6662
+ self.lb
6763
)
6864

69-
B = pairwise_euclidean_dist(w, w)
70-
B = jnp.argsort(B, axis=1)
71-
B = B[:, : self.T]
65+
neighbors = pairwise_euclidean_dist(w, w)
66+
neighbors = jnp.argsort(neighbors, axis=1)
67+
neighbors = neighbors[:, : self.n_neighbor]
68+
7269
return State(
7370
population=population,
7471
fitness=jnp.zeros((self.pop_size, self.n_objs)),
7572
next_generation=population,
7673
weight_vector=w,
77-
B=B,
78-
Z=jnp.zeros(shape=self.n_objs),
79-
parent=jnp.zeros((self.pop_size, self.T)).astype(int),
74+
neighbors=neighbors,
75+
z=jnp.zeros(shape=self.n_objs),
8076
key=key,
8177
)
8278

8379
def init_ask(self, state):
8480
return state.population, state
8581

8682
def init_tell(self, state, fitness):
87-
Z = jnp.min(fitness, axis=0)
88-
state = state.update(fitness=fitness, Z=Z)
83+
z = jnp.min(fitness, axis=0)
84+
state = state.update(fitness=fitness, z=z)
8985
return state
9086

9187
def ask(self, state):
9288
key, subkey, sel_key, mut_key = jax.random.split(state.key, 4)
9389
parent = jax.random.permutation(
94-
subkey, state.B, axis=1, independent=True
90+
subkey, state.neighbors, axis=1, independent=True
9591
).astype(int)
92+
9693
population = state.population
9794
selected_p = jnp.r_[population[parent[:, 0]], population[parent[:, 1]]]
9895

9996
crossovered = self.crossover(sel_key, selected_p)
10097
next_generation = self.mutation(mut_key, crossovered)
98+
next_generation = jnp.clip(next_generation, self.lb, self.ub)
10199

102100
return next_generation, state.update(
103-
next_generation=next_generation, parent=parent, key=key
101+
next_generation=next_generation, key=key
104102
)
105103

106104
def tell(self, state, fitness):
@@ -109,84 +107,28 @@ def tell(self, state, fitness):
109107
offspring = state.next_generation
110108
obj = fitness
111109
w = state.weight_vector
112-
Z = state.Z
113-
parent = state.parent
114-
115-
out_vals = (population, pop_obj, Z)
116-
117-
def out_body(i, out_vals):
118-
population, pop_obj, Z = out_vals
119-
ind_p = parent[i]
120-
ind_obj = obj[i]
121-
Z = jnp.minimum(Z, obj[i])
122-
123-
if self.type == 1:
124-
# PBI approach
125-
norm_w = jnp.linalg.norm(w[ind_p], axis=1)
126-
norm_p = jnp.linalg.norm(
127-
pop_obj[ind_p] - jnp.tile(Z, (self.T, 1)), axis=1
128-
)
129-
norm_o = jnp.linalg.norm(ind_obj - Z)
130-
cos_p = (
131-
jnp.sum(
132-
(pop_obj[ind_p] - jnp.tile(Z, (self.T, 1))) * w[ind_p], axis=1
133-
)
134-
/ norm_w
135-
/ norm_p
136-
)
137-
cos_o = (
138-
jnp.sum(jnp.tile(ind_obj - Z, (self.T, 1)) * w[ind_p], axis=1)
139-
/ norm_w
140-
/ norm_o
141-
)
142-
g_old = norm_p * cos_p + 5 * norm_p * jnp.sqrt(1 - cos_p**2)
143-
g_new = norm_o * cos_o + 5 * norm_o * jnp.sqrt(1 - cos_o**2)
144-
if self.type == 2:
145-
# Tchebycheff approach
146-
g_old = jnp.max(
147-
jnp.abs(pop_obj[ind_p] - jnp.tile(Z, (self.T, 1))) * w[ind_p],
148-
axis=1,
149-
)
150-
g_new = jnp.max(
151-
jnp.tile(jnp.abs(ind_obj - Z), (self.T, 1)) * w[ind_p], axis=1
152-
)
153-
if self.type == 3:
154-
# Tchebycheff approach with normalization
155-
z_max = jnp.max(pop_obj, axis=0)
156-
g_old = jnp.max(
157-
jnp.abs(pop_obj[ind_p] - jnp.tile(Z, (self.T, 1)))
158-
/ jnp.tile(z_max - Z, (self.T, 1))
159-
* w[ind_p],
160-
axis=1,
161-
)
162-
g_new = jnp.max(
163-
jnp.tile(jnp.abs(ind_obj - Z), (self.T, 1))
164-
/ jnp.tile(z_max - Z, (self.T, 1))
165-
* w[ind_p],
166-
axis=1,
167-
)
168-
if self.type == 4:
169-
# Modified Tchebycheff approach
170-
g_old = jnp.max(
171-
jnp.abs(pop_obj[ind_p] - jnp.tile(Z, (self.T, 1))) / w[ind_p],
172-
axis=1,
173-
)
174-
g_new = jnp.max(
175-
jnp.tile(jnp.abs(ind_obj - Z), (self.T, 1)) / w[ind_p], axis=1
176-
)
177-
178-
g_new = g_new[:, jnp.newaxis]
179-
g_old = g_old[:, jnp.newaxis]
180-
population = population.at[ind_p].set(
181-
jnp.where(g_old >= g_new, offspring[ind_p], population[ind_p])
182-
)
183-
pop_obj = pop_obj.at[ind_p].set(
184-
jnp.where(g_old >= g_new, obj[ind_p], pop_obj[ind_p])
185-
)
186-
187-
return (population, pop_obj, Z)
188-
189-
population, pop_obj, Z = jax.lax.fori_loop(0, self.pop_size, out_body, out_vals)
190-
191-
state = state.update(population=population, fitness=pop_obj, Z=Z)
192-
return state
110+
111+
z = jnp.minimum(state.z, jnp.min(obj, axis=0))
112+
z_max = jnp.max(pop_obj, axis=0)
113+
neighbors = state.neighbors
114+
115+
def scan_body(carry, x):
116+
population, pop_obj = carry
117+
off_pop, off_obj, indices = x
118+
119+
f_old = self.aggregate_func(pop_obj[indices], w[indices], z, z_max)
120+
f_new = self.aggregate_func(off_obj[jnp.newaxis, :], w[indices], z, z_max)
121+
122+
update_condition = (f_old > f_new)[:, jnp.newaxis]
123+
updated_population = population.at[indices].set(
124+
jnp.where(update_condition, jnp.tile(off_pop, (jnp.shape(indices)[0], 1)), population[indices]))
125+
updated_pop_obj = pop_obj.at[indices].set(
126+
jnp.where(update_condition, jnp.tile(off_obj, (jnp.shape(indices)[0], 1)), pop_obj[indices]))
127+
128+
return (updated_population, updated_pop_obj), None
129+
130+
(population, pop_obj), _ = jax.lax.scan(scan_body, (population, pop_obj), (offspring, obj, neighbors))
131+
132+
133+
state = state.update(population=population, fitness=pop_obj, z=z)
134+
return state

src/evox/algorithms/mo/moeaddra.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from evox import Algorithm, State, jit_class
1818
from evox.operators import crossover, mutation, selection
1919
from evox.operators.sampling import LatinHypercubeSampling
20-
from evox.utils import pairwise_euclidean_dist
20+
from evox.utils import pairwise_euclidean_dist, AggregationFunction
2121

2222

2323
@jit_class
@@ -55,6 +55,7 @@ def __init__(
5555
if self.crossover is None:
5656
self.crossover = crossover.DifferentialEvolve()
5757
self.sample = LatinHypercubeSampling(self.pop_size, self.n_objs)
58+
self.aggergate_func = AggregationFunction("tchebycheff")
5859

5960
def setup(self, key):
6061
key, subkey1, subkey2 = jax.random.split(key, 3)
@@ -155,10 +156,8 @@ def out_body(i, out_vals):
155156
ind_obj = off_obj[i]
156157
Z = jnp.minimum(Z, ind_obj)
157158

158-
g_old = jnp.max(
159-
jnp.abs(pop_obj[p] - jnp.tile(Z, (len(p), 1))) * w[p], axis=1
160-
)
161-
g_new = jnp.max(jnp.abs(jnp.tile(ind_obj - Z, (len(p), 1))) * w[p], axis=1)
159+
g_old = self.aggergate_func(pop_obj[p], w[p], Z)
160+
g_new = self.aggergate_func(ind_obj[jnp.newaxis, :], w[p], Z)
162161

163162
g_new = g_new[:, jnp.newaxis]
164163
g_old = g_old[:, jnp.newaxis]

src/evox/utils/common.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,52 @@ def frames2gif(frames, save_path, duration=0.1):
266266
writer.append_data(formatted_image)
267267

268268
print("Gif saved to: ", save_path)
269+
270+
271+
@jit_class
272+
class AggregationFunction:
273+
"""
274+
Aggregation function: PBI approach, Tchebycheff approach, Tchebycheff approach with normalization,
275+
modified Tchebycheff approach, weighted sum approach.
276+
277+
Args:
278+
function_name (str): name of the aggregation function.
279+
"""
280+
281+
def __init__(self, function_name):
282+
if function_name == "pbi":
283+
self.function = self.pbi
284+
elif function_name == "tchebycheff":
285+
self.function = self.tchebycheff
286+
elif function_name == "tchebycheff_norm":
287+
self.function = self.tchebycheff_norm
288+
elif function_name == "modified_tchebycheff":
289+
self.function = self.modified_tchebycheff
290+
elif function_name == "weighted_sum":
291+
self.function = self.weighted_sum
292+
else:
293+
raise ValueError("Unsupported function")
294+
295+
def pbi(self, f, w, z, *args):
296+
norm_w = jnp.linalg.norm(w, axis=1)
297+
f = f - z
298+
d1 = jnp.sum(f * w, axis=1) / norm_w
299+
d2 = jnp.linalg.norm(
300+
f - (d1[:, jnp.newaxis] * w / norm_w[:, jnp.newaxis]), axis=1
301+
)
302+
return d1 + 5 * d2
303+
304+
def tchebycheff(self, f, w, z, *args):
305+
return jnp.max(jnp.abs(f - z) * w, axis=1)
306+
307+
def tchebycheff_norm(self, f, w, z, z_max, *args):
308+
return jnp.max(jnp.abs(f - z) / (z_max - z) * w, axis=1)
309+
310+
def modified_tchebycheff(self, f, w, z, *args):
311+
return jnp.max(jnp.abs(f - z) / w, axis=1)
312+
313+
def weighted_sum(self, f, w, *args):
314+
return jnp.sum(f * w, axis=1)
315+
316+
def __call__(self, *args, **kwargs):
317+
return self.function(*args, **kwargs)

tests/test_multi_objective_algorithms.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def test_moead():
4646
ub=jnp.full(shape=(N,), fill_value=UB),
4747
n_objs=M,
4848
pop_size=POP_SIZE,
49-
type=1,
5049
)
5150
run_moea(algorithm)
5251

0 commit comments

Comments
 (0)