Skip to content

Commit

Permalink
dev: Implicitly convert dataclass to State
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Aug 6, 2024
1 parent b0a5599 commit 32fc3b5
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 14 deletions.
14 changes: 6 additions & 8 deletions src/evox/algorithms/so/pso_variants/cso.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,12 @@ def setup(self, key):
velocity = jnp.zeros((self.pop_size, self.dim))
fitness = jnp.full((self.pop_size,), jnp.inf)

return State(
CSOState(
population=population,
fitness=fitness,
velocity=velocity,
students=jnp.empty((self.pop_size // 2,), dtype=jnp.int32),
key=state_key,
)
return CSOState(
population=population,
fitness=fitness,
velocity=velocity,
students=jnp.empty((self.pop_size // 2,), dtype=jnp.int32),
key=state_key,
)

def init_ask(self, state):
Expand Down
15 changes: 9 additions & 6 deletions src/evox/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,15 @@ def _recursive_init(
if no_state:
return None, node_id
else:
return (
self.setup(key)
._set_state_id_mut(self._node_id)
._set_child_states_mut(child_states),
node_id,
)
self_state = self.setup(key)
if dataclasses.is_dataclass(self_state):
# if the setup method return a dataclass, convert it to State first
self_state = State.from_dataclass(self_state)

self_state._set_state_id_mut(self._node_id)._set_child_states_mut(
child_states
),
return self_state, node_id

def init(self, key: jax.Array = None, no_state: bool = False) -> State:
"""Initialize this module and all submodules
Expand Down
18 changes: 18 additions & 0 deletions src/evox/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,24 @@ def __init__(self, _dataclass=None, /, **kwargs) -> None:
self.__dict__["_child_states"] = State.EMPTY
self.__dict__["_state_id"] = None

@classmethod
def from_dataclass(cls, dataclass) -> Self:
"""Construct a ``State`` from dataclass instance
Example::
>>> from evox import State
>>> from dataclasses import dataclass
>>> @dataclass
>>> class Param:
... x: int
... y: int
...
>>> param = Param(x=1, y=2)
>>> State.from_dataclass(param)
State(Param(x=1, y=2), {})
"""
return cls(_dataclass=dataclass)

def _set_state_dict_mut(self, state_dict: dict) -> Self:
"""Force set child state and return self
Expand Down

0 comments on commit 32fc3b5

Please sign in to comment.