@@ -65,12 +65,14 @@ class ModuleState(statelib.State):
65
65
66
66
67
67
class Scope (Object ):
68
- def __init__ (self , rngs : rnglib .Rngs , mutable : CollectionFilter ):
68
+ def __init__ (self , module : Module , rngs : rnglib .Rngs , mutable : CollectionFilter ):
69
+ self .module = module
69
70
self .rngs = rngs
70
71
self .mutable = mutable
71
72
72
- def copy (self ):
73
- return Scope (self .rngs , self .mutable )
73
+ def copy (self , new_module ):
74
+ # Never copy the module - always fill in a new one
75
+ return Scope (new_module , self .rngs , self .mutable )
74
76
75
77
76
78
class _HasSetup (tp .Protocol ):
@@ -104,7 +106,7 @@ def _bind_module(parent: Module, module: Module) -> Module:
104
106
for _ , value in reversed (list (graph .iter_graph (module ))):
105
107
if isinstance (value , Module ):
106
108
if module .scope is None :
107
- value .scope = parent .scope .copy () # type: ignore[attribute-error]
109
+ value .scope = parent .scope .copy (value ) # type: ignore[attribute-error]
108
110
_maybe_call_setup (value )
109
111
return module
110
112
@@ -280,8 +282,9 @@ def param( # type: ignore[invalid-annotation]
280
282
'Parameters must be initialized in `setup()` or in a method '
281
283
'wrapped in `@compact`'
282
284
)
283
- if hasattr (self , name ):
284
- value = getattr (self , name )
285
+ module = self .scope .module
286
+ if hasattr (module , name ):
287
+ value = getattr (module , name )
285
288
# TODO(cgarciae): implement reservations
286
289
# if self._name_taken(name):
287
290
# raise errors.NameInUseError('param', name, self.__class__.__name__)
@@ -310,10 +313,10 @@ def param( # type: ignore[invalid-annotation]
310
313
311
314
variable = variablelib .Param (value )
312
315
else :
313
- value = init_fn (self .make_rng ('params' ), * init_args , ** init_kwargs )
316
+ value = init_fn (module .make_rng ('params' ), * init_args , ** init_kwargs )
314
317
variable = variablelib .Param (value )
315
318
316
- setattr (self , name , variable )
319
+ setattr (module , name , variable )
317
320
return variable
318
321
319
322
def variable ( # type: ignore[invalid-annotation]
@@ -333,9 +336,10 @@ def variable( # type: ignore[invalid-annotation]
333
336
'Variables must be initialized in `setup()` or in a method '
334
337
'wrapped in `@compact`'
335
338
)
339
+ module = self .scope .module
336
340
337
- if hasattr (self , name ):
338
- value = getattr (self , name )
341
+ if hasattr (module , name ):
342
+ value = getattr (module , name )
339
343
# TODO(cgarciae): implement reservations
340
344
# if self._name_taken(name):
341
345
# raise errors.NameInUseError('param', name, self.__class__.__name__)
@@ -367,7 +371,7 @@ def variable( # type: ignore[invalid-annotation]
367
371
value = init_fn (* init_args , ** init_kwargs )
368
372
variable = variable_type (value )
369
373
370
- setattr (self , name , variable )
374
+ setattr (module , name , variable )
371
375
return variable
372
376
373
377
def _get_variables (self ) -> tp .Mapping :
@@ -474,7 +478,7 @@ def to_variable(value):
474
478
if isinstance (value , Object ):
475
479
value ._object__state ._initializing = _initialize
476
480
if isinstance (value , Module ):
477
- value .scope = Scope (rngs , mutable )
481
+ value .scope = Scope (value , rngs , mutable )
478
482
_maybe_call_setup (value )
479
483
480
484
MODULE_CONTEXT .module_stack .append (
@@ -570,3 +574,14 @@ def _get_unbound_fn(method_or_fn: tp.Callable) -> tp.Callable:
570
574
571
575
return method_or_fn
572
576
577
+
578
+ def share_scope (parent : Module , child : Module , hide_key : str | None = None ):
579
+ """Behaves like `linen.share_scope`, for a pair of parent and child modules.
580
+
581
+ Essentially share all attribute fields of the child with the parent and make
582
+ the parent attributes higher priority, to make sure variable traversal goes
583
+ from the parent first. Ensures checkpoint compatibility.
584
+ """
585
+ child .scope .module = parent .scope .module
586
+ if hide_key is not None :
587
+ parent .set_attr_priority (hide_key , AttrPriority .LOW )
0 commit comments