Skip to content

Commit 0a43091

Browse files
IvyZXFlax Authors
authored and
Flax Authors
committed
[bridge module] Add bridge.share_scope for layer-sublayer pairs.
PiperOrigin-RevId: 738189137
1 parent fa0f3e8 commit 0a43091

File tree

3 files changed

+50
-12
lines changed

3 files changed

+50
-12
lines changed

flax/nnx/bridge/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .module import compact as compact
2828
from .module import current_context as current_context
2929
from .module import current_module as current_module
30+
from .module import share_scope as share_scope
3031
from .interop import nnx_in_bridge_mdl as nnx_in_bridge_mdl
3132
from .interop import linen_in_bridge_mdl as linen_in_bridge_mdl
3233
from flax.nnx.nn import initializers as initializers

flax/nnx/bridge/module.py

+27-12
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,14 @@ class ModuleState(statelib.State):
6565

6666

6767
class Scope(Object):
68-
def __init__(self, rngs: rnglib.Rngs, mutable: CollectionFilter):
68+
def __init__(self, module: Module, rngs: rnglib.Rngs, mutable: CollectionFilter):
69+
self.module = module
6970
self.rngs = rngs
7071
self.mutable = mutable
7172

72-
def copy(self):
73-
return Scope(self.rngs, self.mutable)
73+
def copy(self, new_module):
74+
# Never copy the module - always fill in a new one
75+
return Scope(new_module, self.rngs, self.mutable)
7476

7577

7678
class _HasSetup(tp.Protocol):
@@ -104,7 +106,7 @@ def _bind_module(parent: Module, module: Module) -> Module:
104106
for _, value in reversed(list(graph.iter_graph(module))):
105107
if isinstance(value, Module):
106108
if module.scope is None:
107-
value.scope = parent.scope.copy() # type: ignore[attribute-error]
109+
value.scope = parent.scope.copy(value) # type: ignore[attribute-error]
108110
_maybe_call_setup(value)
109111
return module
110112

@@ -280,8 +282,9 @@ def param( # type: ignore[invalid-annotation]
280282
'Parameters must be initialized in `setup()` or in a method '
281283
'wrapped in `@compact`'
282284
)
283-
if hasattr(self, name):
284-
value = getattr(self, name)
285+
module = self.scope.module
286+
if hasattr(module, name):
287+
value = getattr(module, name)
285288
# TODO(cgarciae): implement reservations
286289
# if self._name_taken(name):
287290
# raise errors.NameInUseError('param', name, self.__class__.__name__)
@@ -310,10 +313,10 @@ def param( # type: ignore[invalid-annotation]
310313

311314
variable = variablelib.Param(value)
312315
else:
313-
value = init_fn(self.make_rng('params'), *init_args, **init_kwargs)
316+
value = init_fn(module.make_rng('params'), *init_args, **init_kwargs)
314317
variable = variablelib.Param(value)
315318

316-
setattr(self, name, variable)
319+
setattr(module, name, variable)
317320
return variable
318321

319322
def variable( # type: ignore[invalid-annotation]
@@ -333,9 +336,10 @@ def variable( # type: ignore[invalid-annotation]
333336
'Variables must be initialized in `setup()` or in a method '
334337
'wrapped in `@compact`'
335338
)
339+
module = self.scope.module
336340

337-
if hasattr(self, name):
338-
value = getattr(self, name)
341+
if hasattr(module, name):
342+
value = getattr(module, name)
339343
# TODO(cgarciae): implement reservations
340344
# if self._name_taken(name):
341345
# raise errors.NameInUseError('param', name, self.__class__.__name__)
@@ -367,7 +371,7 @@ def variable( # type: ignore[invalid-annotation]
367371
value = init_fn(*init_args, **init_kwargs)
368372
variable = variable_type(value)
369373

370-
setattr(self, name, variable)
374+
setattr(module, name, variable)
371375
return variable
372376

373377
def _get_variables(self) -> tp.Mapping:
@@ -474,7 +478,7 @@ def to_variable(value):
474478
if isinstance(value, Object):
475479
value._object__state._initializing = _initialize
476480
if isinstance(value, Module):
477-
value.scope = Scope(rngs, mutable)
481+
value.scope = Scope(value, rngs, mutable)
478482
_maybe_call_setup(value)
479483

480484
MODULE_CONTEXT.module_stack.append(
@@ -570,3 +574,14 @@ def _get_unbound_fn(method_or_fn: tp.Callable) -> tp.Callable:
570574

571575
return method_or_fn
572576

577+
578+
def share_scope(parent: Module, child: Module, hide_key: str | None = None):
579+
"""Behaves like `linen.share_scope`, for a pair of parent and child modules.
580+
581+
Essentially share all attribute fields of the child with the parent and make
582+
the parent attributes higher priority, to make sure variable traversal goes
583+
from the parent first. Ensures checkpoint compatibility.
584+
"""
585+
child.scope.module = parent.scope.module
586+
if hide_key is not None:
587+
parent.set_attr_priority(hide_key, AttrPriority.LOW)

tests/nnx/bridge/module_test.py

+22
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,28 @@ def __call__(self, x):
542542
params = model.init(jax.random.key(0), x)['params']
543543
self.assertSameElements([f'layer_{i}' for i in range(3)], params.keys())
544544

545+
def test_share_scope(self):
546+
class Dense(bridge.Module):
547+
dout: int
548+
@bridge.compact
549+
def __call__(self, x):
550+
return x @ self.param('w', nn.initializers.normal(),
551+
(x.shape[-1], self.dout))
552+
553+
class Top(bridge.Module):
554+
def setup(self):
555+
self.a = Dense(4)
556+
bridge.module.share_scope(self, self.a, 'a')
557+
558+
def __call__(self, x):
559+
return self.a(x)
560+
561+
model = Top()
562+
x = jnp.ones((4, 32))
563+
params = model.init(jax.random.key(0), x)['params']
564+
self.assertSameElements(['w'], params.keys()) # 'a' doesn't exist
565+
566+
545567

546568
if __name__ == '__main__':
547569
absltest.main()

0 commit comments

Comments
 (0)