@@ -75,7 +75,9 @@ def wrapper(self, state: State, *args, **kwargs):
75
75
new_state ,
76
76
)
77
77
78
- state = state .replace_by_path (path , new_state )
78
+ state = state .replace_by_path (
79
+ path , new_state .clear_callbacks ()
80
+ ).prepend_closure (new_state )
79
81
80
82
if aux is None :
81
83
return state
@@ -148,6 +150,10 @@ class Stateful:
148
150
149
151
The ``init`` method will automatically call the ``setup`` of the current module
150
152
and recursively call ``setup`` methods of all submodules.
153
+
154
+ Currently, two special metadata keys can be used to control the behavior of module initialization:
155
+ - ``stack``: If set to True, the module will be initialized multiple times, and the states will be stacked together.
156
+ - ``nested``: If set to True, the field is treated as a list of modules, that is [module1, module2, ...], and each module in the list will be iterated over and initialized.
151
157
"""
152
158
153
159
def __init__ (self ):
@@ -174,10 +180,16 @@ def setup(self, key: jax.Array) -> State:
174
180
return State ()
175
181
176
182
def _recursive_init (
177
- self , key : jax .Array , node_id : int , module_name : str , no_state : bool
183
+ self ,
184
+ key : jax .Array ,
185
+ node_id : int ,
186
+ module_name : str ,
187
+ no_state : bool ,
188
+ re_init : bool ,
178
189
) -> Tuple [State , int ]:
179
- object .__setattr__ (self , "_node_id" , node_id )
180
- object .__setattr__ (self , "_module_name" , module_name )
190
+ if not re_init :
191
+ object .__setattr__ (self , "_node_id" , node_id )
192
+ object .__setattr__ (self , "_module_name" , module_name )
181
193
182
194
if not no_state :
183
195
child_states = {}
@@ -197,6 +209,15 @@ def _recursive_init(
197
209
198
210
if isinstance (attr , Stateful ):
199
211
submodules .append (SubmoduleInfo (field .name , attr , field .metadata ))
212
+
213
+ # handle "nested" field
214
+ if field .metadata .get ("nested" , False ):
215
+ for idx , nested_module in enumerate (attr ):
216
+ submodules .append (
217
+ SubmoduleInfo (
218
+ field .name + str (idx ), nested_module , field .metadata
219
+ )
220
+ )
200
221
else :
201
222
for attr_name in vars (self ):
202
223
attr = getattr (self , attr_name )
@@ -211,24 +232,27 @@ def _recursive_init(
211
232
else :
212
233
key , subkey = jax .random .split (key )
213
234
214
- # handle "StackAnnotation "
235
+ # handle "Stack "
215
236
# attr should be a list, or tuple of modules
216
237
if metadata .get ("stack" , False ):
217
238
num_copies = len (attr )
218
239
subkeys = jax .random .split (subkey , num_copies )
219
240
current_node_id = node_id
220
- _ , node_id = attr ._recursive_init (None , node_id + 1 , attr_name , True )
241
+ _ , node_id = attr ._recursive_init (
242
+ None , node_id + 1 , attr_name , True , re_init
243
+ )
221
244
submodule_state , _node_id = jax .vmap (
222
245
partial (
223
246
Stateful ._recursive_init ,
224
247
node_id = current_node_id + 1 ,
225
248
module_name = attr_name ,
226
249
no_state = no_state ,
250
+ re_init = re_init ,
227
251
)
228
252
)(attr , subkeys )
229
253
else :
230
254
submodule_state , node_id = attr ._recursive_init (
231
- subkey , node_id + 1 , attr_name , no_state
255
+ subkey , node_id + 1 , attr_name , no_state , re_init
232
256
)
233
257
234
258
if not no_state :
@@ -246,10 +270,12 @@ def _recursive_init(
246
270
247
271
self_state ._set_state_id_mut (self ._node_id )._set_child_states_mut (
248
272
child_states
249
- ),
273
+ )
250
274
return self_state , node_id
251
275
252
- def init (self , key : jax .Array = None , no_state : bool = False ) -> State :
276
+ def init (
277
+ self , key : jax .Array = None , no_state : bool = False , re_init : bool = False
278
+ ) -> State :
253
279
"""Initialize this module and all submodules
254
280
255
281
This method should not be overwritten.
@@ -264,9 +290,33 @@ def init(self, key: jax.Array = None, no_state: bool = False) -> State:
264
290
State
265
291
The state of this module and all submodules combined.
266
292
"""
267
- state , _node_id = self ._recursive_init (key , 0 , None , no_state )
293
+ state , _node_id = self ._recursive_init (key , 0 , None , no_state , re_init )
268
294
return state
269
295
296
def parallel_init(
    self, key: jax.Array, num_copies: int, no_state: bool = False
) -> "State":
    """Initialize multiple copies of this module in parallel.

    This method should not be overwritten.

    Parameters
    ----------
    key
        A PRNGKey. It is split into ``num_copies`` subkeys, one per copy.
    num_copies
        The number of copies to be initialized.
    no_state
        Whether to skip the state initialization. Forwarded unchanged to
        ``init`` for every copy.

    Returns
    -------
    State
        The result of ``init``, batched by ``jax.vmap`` along a new leading
        axis of size ``num_copies``. Note that, like ``init``, this returns
        only the state; no node id is returned.
    """
    subkeys = jax.random.split(key, num_copies)
    # Map over the subkeys only (axis 0); ``no_state`` uses ``in_axes=None``
    # so every copy receives the very same flag.
    return jax.vmap(self.init, in_axes=(0, None))(subkeys, no_state)
319
+
270
320
@classmethod
271
321
def stack (cls , stateful_objs , axis = 0 ):
272
322
for obj in stateful_objs :
0 commit comments