3
3
#
4
4
# Title: MOEA/D: A Multiobjective Evolutionary Algorithm Based on Decomposition
5
5
# Link: https://ieeexplore.ieee.org/document/4358754
6
- #
7
- # 2. This code has been inspired by PlatEMO.
8
- # More information about PlatEMO can be found at the following URL:
9
- # GitHub Link: https://github.com/BIMK/PlatEMO
10
6
# --------------------------------------------------------------------------------------
11
7
12
8
import math
17
13
from evox import Algorithm , State , jit_class
18
14
from evox .operators import crossover , mutation
19
15
from evox .operators .sampling import UniformSampling
20
- from evox .utils import pairwise_euclidean_dist
21
-
16
+ from evox .utils import pairwise_euclidean_dist , AggregationFunction
22
17
23
18
@jit_class
24
19
class MOEAD (Algorithm ):
@@ -33,7 +28,7 @@ def __init__(
33
28
ub ,
34
29
n_objs ,
35
30
pop_size ,
36
- type = 1 ,
31
+ func_name = 'pbi' ,
37
32
mutation_op = None ,
38
33
crossover_op = None ,
39
34
):
@@ -42,8 +37,8 @@ def __init__(
42
37
self .n_objs = n_objs
43
38
self .dim = lb .shape [0 ]
44
39
self .pop_size = pop_size
45
- self .type = type
46
- self .T = 0
40
+ self .func_name = func_name
41
+ self .n_neighbor = 0
47
42
48
43
self .mutation = mutation_op
49
44
self .crossover = crossover_op
@@ -53,54 +48,57 @@ def __init__(
53
48
if self .crossover is None :
54
49
self .crossover = crossover .SimulatedBinary (type = 2 )
55
50
self .sample = UniformSampling (self .pop_size , self .n_objs )
51
+ self .aggregate_func = AggregationFunction (self .func_name )
56
52
57
53
def setup (self , key ):
58
54
key , subkey1 , subkey2 = jax .random .split (key , 3 )
59
55
w , _ = self .sample (subkey2 )
60
56
self .pop_size = w .shape [0 ]
61
- self .T = int (math .ceil (self .pop_size / 10 ))
57
+ self .n_neighbor = int (math .ceil (self .pop_size / 10 ))
62
58
63
59
population = (
64
60
jax .random .uniform (subkey1 , shape = (self .pop_size , self .dim ))
65
61
* (self .ub - self .lb )
66
62
+ self .lb
67
63
)
68
64
69
- B = pairwise_euclidean_dist (w , w )
70
- B = jnp .argsort (B , axis = 1 )
71
- B = B [:, : self .T ]
65
+ neighbors = pairwise_euclidean_dist (w , w )
66
+ neighbors = jnp .argsort (neighbors , axis = 1 )
67
+ neighbors = neighbors [:, : self .n_neighbor ]
68
+
72
69
return State (
73
70
population = population ,
74
71
fitness = jnp .zeros ((self .pop_size , self .n_objs )),
75
72
next_generation = population ,
76
73
weight_vector = w ,
77
- B = B ,
78
- Z = jnp .zeros (shape = self .n_objs ),
79
- parent = jnp .zeros ((self .pop_size , self .T )).astype (int ),
74
+ neighbors = neighbors ,
75
+ z = jnp .zeros (shape = self .n_objs ),
80
76
key = key ,
81
77
)
82
78
83
79
def init_ask (self , state ):
84
80
return state .population , state
85
81
86
82
def init_tell (self , state , fitness ):
87
- Z = jnp .min (fitness , axis = 0 )
88
- state = state .update (fitness = fitness , Z = Z )
83
+ z = jnp .min (fitness , axis = 0 )
84
+ state = state .update (fitness = fitness , z = z )
89
85
return state
90
86
91
87
def ask (self , state ):
92
88
key , subkey , sel_key , mut_key = jax .random .split (state .key , 4 )
93
89
parent = jax .random .permutation (
94
- subkey , state .B , axis = 1 , independent = True
90
+ subkey , state .neighbors , axis = 1 , independent = True
95
91
).astype (int )
92
+
96
93
population = state .population
97
94
selected_p = jnp .r_ [population [parent [:, 0 ]], population [parent [:, 1 ]]]
98
95
99
96
crossovered = self .crossover (sel_key , selected_p )
100
97
next_generation = self .mutation (mut_key , crossovered )
98
+ next_generation = jnp .clip (next_generation , self .lb , self .ub )
101
99
102
100
return next_generation , state .update (
103
- next_generation = next_generation , parent = parent , key = key
101
+ next_generation = next_generation , key = key
104
102
)
105
103
106
104
def tell (self , state , fitness ):
@@ -109,84 +107,28 @@ def tell(self, state, fitness):
109
107
offspring = state .next_generation
110
108
obj = fitness
111
109
w = state .weight_vector
112
- Z = state .Z
113
- parent = state .parent
114
-
115
- out_vals = (population , pop_obj , Z )
116
-
117
- def out_body (i , out_vals ):
118
- population , pop_obj , Z = out_vals
119
- ind_p = parent [i ]
120
- ind_obj = obj [i ]
121
- Z = jnp .minimum (Z , obj [i ])
122
-
123
- if self .type == 1 :
124
- # PBI approach
125
- norm_w = jnp .linalg .norm (w [ind_p ], axis = 1 )
126
- norm_p = jnp .linalg .norm (
127
- pop_obj [ind_p ] - jnp .tile (Z , (self .T , 1 )), axis = 1
128
- )
129
- norm_o = jnp .linalg .norm (ind_obj - Z )
130
- cos_p = (
131
- jnp .sum (
132
- (pop_obj [ind_p ] - jnp .tile (Z , (self .T , 1 ))) * w [ind_p ], axis = 1
133
- )
134
- / norm_w
135
- / norm_p
136
- )
137
- cos_o = (
138
- jnp .sum (jnp .tile (ind_obj - Z , (self .T , 1 )) * w [ind_p ], axis = 1 )
139
- / norm_w
140
- / norm_o
141
- )
142
- g_old = norm_p * cos_p + 5 * norm_p * jnp .sqrt (1 - cos_p ** 2 )
143
- g_new = norm_o * cos_o + 5 * norm_o * jnp .sqrt (1 - cos_o ** 2 )
144
- if self .type == 2 :
145
- # Tchebycheff approach
146
- g_old = jnp .max (
147
- jnp .abs (pop_obj [ind_p ] - jnp .tile (Z , (self .T , 1 ))) * w [ind_p ],
148
- axis = 1 ,
149
- )
150
- g_new = jnp .max (
151
- jnp .tile (jnp .abs (ind_obj - Z ), (self .T , 1 )) * w [ind_p ], axis = 1
152
- )
153
- if self .type == 3 :
154
- # Tchebycheff approach with normalization
155
- z_max = jnp .max (pop_obj , axis = 0 )
156
- g_old = jnp .max (
157
- jnp .abs (pop_obj [ind_p ] - jnp .tile (Z , (self .T , 1 )))
158
- / jnp .tile (z_max - Z , (self .T , 1 ))
159
- * w [ind_p ],
160
- axis = 1 ,
161
- )
162
- g_new = jnp .max (
163
- jnp .tile (jnp .abs (ind_obj - Z ), (self .T , 1 ))
164
- / jnp .tile (z_max - Z , (self .T , 1 ))
165
- * w [ind_p ],
166
- axis = 1 ,
167
- )
168
- if self .type == 4 :
169
- # Modified Tchebycheff approach
170
- g_old = jnp .max (
171
- jnp .abs (pop_obj [ind_p ] - jnp .tile (Z , (self .T , 1 ))) / w [ind_p ],
172
- axis = 1 ,
173
- )
174
- g_new = jnp .max (
175
- jnp .tile (jnp .abs (ind_obj - Z ), (self .T , 1 )) / w [ind_p ], axis = 1
176
- )
177
-
178
- g_new = g_new [:, jnp .newaxis ]
179
- g_old = g_old [:, jnp .newaxis ]
180
- population = population .at [ind_p ].set (
181
- jnp .where (g_old >= g_new , offspring [ind_p ], population [ind_p ])
182
- )
183
- pop_obj = pop_obj .at [ind_p ].set (
184
- jnp .where (g_old >= g_new , obj [ind_p ], pop_obj [ind_p ])
185
- )
186
-
187
- return (population , pop_obj , Z )
188
-
189
- population , pop_obj , Z = jax .lax .fori_loop (0 , self .pop_size , out_body , out_vals )
190
-
191
- state = state .update (population = population , fitness = pop_obj , Z = Z )
192
- return state
110
+
111
+ z = jnp .minimum (state .z , jnp .min (obj , axis = 0 ))
112
+ z_max = jnp .max (pop_obj , axis = 0 )
113
+ neighbors = state .neighbors
114
+
115
+ def scan_body (carry , x ):
116
+ population , pop_obj = carry
117
+ off_pop , off_obj , indices = x
118
+
119
+ f_old = self .aggregate_func (pop_obj [indices ], w [indices ], z , z_max )
120
+ f_new = self .aggregate_func (off_obj [jnp .newaxis , :], w [indices ], z , z_max )
121
+
122
+ update_condition = (f_old > f_new )[:, jnp .newaxis ]
123
+ updated_population = population .at [indices ].set (
124
+ jnp .where (update_condition , jnp .tile (off_pop , (jnp .shape (indices )[0 ], 1 )), population [indices ]))
125
+ updated_pop_obj = pop_obj .at [indices ].set (
126
+ jnp .where (update_condition , jnp .tile (off_obj , (jnp .shape (indices )[0 ], 1 )), pop_obj [indices ]))
127
+
128
+ return (updated_population , updated_pop_obj ), None
129
+
130
+ (population , pop_obj ), _ = jax .lax .scan (scan_body , (population , pop_obj ), (offspring , obj , neighbors ))
131
+
132
+
133
+ state = state .update (population = population , fitness = pop_obj , z = z )
134
+ return state
0 commit comments