Skip to content

Commit 8285b29

Browse files
Merge pull request #145 from EMI-Group/state_io
State based IO
2 parents c993546 + 7ed05a3 commit 8285b29

24 files changed

+771
-825
lines changed

src/evox/algorithms/so/pso_variants/pso.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,38 +5,30 @@
55
# Link: https://ieeexplore.ieee.org/document/494215
66
# --------------------------------------------------------------------------------------
77

8-
from functools import partial
8+
from typing import Optional
99

1010
import jax
1111
import jax.numpy as jnp
12-
import copy
1312

13+
from evox import Algorithm, State, dataclass, pytree_field
1414
from evox.utils import *
15-
from evox import Algorithm, State, jit_class
1615

1716

18-
@jit_class
17+
@dataclass
1918
class PSO(Algorithm):
20-
def __init__(
21-
self,
22-
lb,
23-
ub,
24-
pop_size,
25-
inertia_weight=0.6,
26-
cognitive_coefficient=2.5,
27-
social_coefficient=0.8,
28-
mean=None,
29-
stdev=None,
30-
):
31-
self.dim = lb.shape[0]
32-
self.lb = lb
33-
self.ub = ub
34-
self.pop_size = pop_size
35-
self.w = inertia_weight
36-
self.phi_p = cognitive_coefficient
37-
self.phi_g = social_coefficient
38-
self.mean = mean
39-
self.stdev = stdev
19+
dim: jax.Array = pytree_field(static=True, init=False)
20+
lb: jax.Array
21+
ub: jax.Array
22+
pop_size: jax.Array = pytree_field(static=True)
23+
w: jax.Array = pytree_field(default=0.6)
24+
phi_p: jax.Array = pytree_field(default=2.5)
25+
phi_g: jax.Array = pytree_field(default=0.8)
26+
mean: Optional[jax.Array] = pytree_field(default=None)
27+
stdev: Optional[jax.Array] = pytree_field(default=None)
28+
bound_method: str = pytree_field(static=True, default="clip")
29+
30+
def __post_init__(self):
31+
self.set_frozen_attr("dim", self.lb.shape[0])
4032

4133
def setup(self, key):
4234
state_key, init_pop_key, init_v_key = jax.random.split(key, 3)
@@ -95,7 +87,23 @@ def tell(self, state, fitness):
9587
+ self.phi_g * rg * (global_best_location - state.population)
9688
)
9789
population = state.population + velocity
98-
population = jnp.clip(population, self.lb, self.ub)
90+
91+
if self.bound_method == "clip":
92+
population = jnp.clip(population, self.lb, self.ub)
93+
elif self.bound_method == "reflect":
94+
lower_bound_violation = population < self.lb
95+
upper_bound_violation = population > self.ub
96+
97+
population = jnp.where(
98+
lower_bound_violation, 2 * self.lb - population, population
99+
)
100+
population = jnp.where(
101+
upper_bound_violation, 2 * self.ub - population, population
102+
)
103+
104+
velocity = jnp.where(
105+
lower_bound_violation | upper_bound_violation, -velocity, velocity
106+
)
99107

100108
return state.replace(
101109
population=population,

src/evox/core/module.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ def wrapper(self, state: State, *args, **kwargs):
7575
new_state,
7676
)
7777

78-
state = state.replace_by_path(path, new_state)
78+
state = state.replace_by_path(
79+
path, new_state.clear_callbacks()
80+
).prepend_closure(new_state)
7981

8082
if aux is None:
8183
return state
@@ -148,6 +150,10 @@ class Stateful:
148150
149151
The ``init`` method will automatically call the ``setup`` of the current module
150152
and recursively call ``setup`` methods of all submodules.
153+
154+
Currently, there are two special metadata that can be used to control the behavior of the module initialization:
155+
- ``stack``: If set to True, the module will be initialized multiple times, and the states will be stacked together.
156+
- ``nested``: If set to True, the a list of modules, that is [module1, module2, ...], will be iterated and initialized.
151157
"""
152158

153159
def __init__(self):
@@ -174,10 +180,16 @@ def setup(self, key: jax.Array) -> State:
174180
return State()
175181

176182
def _recursive_init(
177-
self, key: jax.Array, node_id: int, module_name: str, no_state: bool
183+
self,
184+
key: jax.Array,
185+
node_id: int,
186+
module_name: str,
187+
no_state: bool,
188+
re_init: bool,
178189
) -> Tuple[State, int]:
179-
object.__setattr__(self, "_node_id", node_id)
180-
object.__setattr__(self, "_module_name", module_name)
190+
if not re_init:
191+
object.__setattr__(self, "_node_id", node_id)
192+
object.__setattr__(self, "_module_name", module_name)
181193

182194
if not no_state:
183195
child_states = {}
@@ -197,6 +209,15 @@ def _recursive_init(
197209

198210
if isinstance(attr, Stateful):
199211
submodules.append(SubmoduleInfo(field.name, attr, field.metadata))
212+
213+
# handle "nested" field
214+
if field.metadata.get("nested", False):
215+
for idx, nested_module in enumerate(attr):
216+
submodules.append(
217+
SubmoduleInfo(
218+
field.name + str(idx), nested_module, field.metadata
219+
)
220+
)
200221
else:
201222
for attr_name in vars(self):
202223
attr = getattr(self, attr_name)
@@ -211,24 +232,27 @@ def _recursive_init(
211232
else:
212233
key, subkey = jax.random.split(key)
213234

214-
# handle "StackAnnotation"
235+
# handle "Stack"
215236
# attr should be a list, or tuple of modules
216237
if metadata.get("stack", False):
217238
num_copies = len(attr)
218239
subkeys = jax.random.split(subkey, num_copies)
219240
current_node_id = node_id
220-
_, node_id = attr._recursive_init(None, node_id + 1, attr_name, True)
241+
_, node_id = attr._recursive_init(
242+
None, node_id + 1, attr_name, True, re_init
243+
)
221244
submodule_state, _node_id = jax.vmap(
222245
partial(
223246
Stateful._recursive_init,
224247
node_id=current_node_id + 1,
225248
module_name=attr_name,
226249
no_state=no_state,
250+
re_init=re_init,
227251
)
228252
)(attr, subkeys)
229253
else:
230254
submodule_state, node_id = attr._recursive_init(
231-
subkey, node_id + 1, attr_name, no_state
255+
subkey, node_id + 1, attr_name, no_state, re_init
232256
)
233257

234258
if not no_state:
@@ -246,10 +270,12 @@ def _recursive_init(
246270

247271
self_state._set_state_id_mut(self._node_id)._set_child_states_mut(
248272
child_states
249-
),
273+
)
250274
return self_state, node_id
251275

252-
def init(self, key: jax.Array = None, no_state: bool = False) -> State:
276+
def init(
277+
self, key: jax.Array = None, no_state: bool = False, re_init: bool = False
278+
) -> State:
253279
"""Initialize this module and all submodules
254280
255281
This method should not be overwritten.
@@ -264,9 +290,33 @@ def init(self, key: jax.Array = None, no_state: bool = False) -> State:
264290
State
265291
The state of this module and all submodules combined.
266292
"""
267-
state, _node_id = self._recursive_init(key, 0, None, no_state)
293+
state, _node_id = self._recursive_init(key, 0, None, no_state, re_init)
268294
return state
269295

296+
def parallel_init(
297+
self, key: jax.Array, num_copies: int, no_state: bool = False
298+
) -> Tuple[State, int]:
299+
"""Initialize multiple copies of this module in parallel
300+
301+
This method should not be overwritten.
302+
303+
Parameters
304+
----------
305+
key
306+
A PRNGKey.
307+
num_copies
308+
The number of copies to be initialized
309+
no_state
310+
Whether to skip the state initialization
311+
312+
Returns
313+
-------
314+
Tuple[State, int]
315+
The state of this module and all submodules combined, and the last node_id
316+
"""
317+
subkeys = jax.random.split(key, num_copies)
318+
return jax.vmap(self.init, in_axes=(0, None))(subkeys, no_state)
319+
270320
@classmethod
271321
def stack(cls, stateful_objs, axis=0):
272322
for obj in stateful_objs:

src/evox/core/monitor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
class Monitor:
1+
from .module import *
2+
3+
4+
class Monitor(Stateful):
25
"""Monitor base class.
36
Monitors are used to monitor the evolutionary process.
47
They contains a set of callbacks,

src/evox/core/pytree_dataclass.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from jax.tree_util import register_pytree_node
1+
import copy
22
import dataclasses
33
from typing import Annotated, Any, Callable, Optional, Tuple, TypeVar, get_type_hints
44

5-
from typing_extensions import (
6-
dataclass_transform, # pytype: disable=not-supported-yet
7-
)
5+
from jax.tree_util import register_pytree_node
6+
from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet
87

98
from .distributed import ShardingType
109

@@ -19,10 +18,26 @@ def pytree_field(
1918
return dataclasses.field(**kwargs)
2019

2120

21+
def _dataclass_set_frozen_attr(self, key, value):
22+
object.__setattr__(self, key, value)
23+
24+
25+
def _dataclass_replace(self, **kwargs):
26+
"""Add a replace method to dataclasses.
27+
It's different from dataclasses.replace in that it doesn't call the __init__,
28+
instead it copies the object and sets the new values.
29+
"""
30+
new_obj = copy.copy(self)
31+
for key, value in kwargs.items():
32+
object.__setattr__(new_obj, key, value)
33+
return new_obj
34+
35+
2236
def dataclass(cls, *args, **kwargs):
2337
"""
2438
A dataclass decorator that also registers the dataclass as a pytree node.
2539
"""
40+
kwargs = {"unsafe_hash": False, "eq": False, **kwargs}
2641
cls = dataclasses.dataclass(cls, *args, **kwargs)
2742

2843
field_info = []
@@ -78,7 +93,8 @@ def unflatten(aux_data, children):
7893
register_pytree_node(cls, flatten, unflatten)
7994

8095
# Add a method to set frozen attributes after init
81-
cls.set_frozen_attr = lambda self, key, value: object.__setattr__(self, key, value)
96+
cls.set_frozen_attr = _dataclass_set_frozen_attr
97+
cls.replace = _dataclass_replace
8298
return cls
8399

84100

0 commit comments

Comments
 (0)