Add shape-based lazy init to LinenToNNX (prev LinenWrapper) #4081

Merged
merged 1 commit on Jul 22, 2024
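
A minimal usage sketch of the new shape-based lazy init, adapted from the tests in this PR (the Dense module and shapes are just illustrative):

import jax
from flax import linen as nn
from flax import nnx
from flax.nnx import bridge

x = jax.numpy.ones((1, 32))
# Wrap a Linen module; no parameters exist yet.
model = bridge.LinenToNNX(nn.Dense(features=64), rngs=nnx.Rngs(0))
# Shape-based lazy init: runs Linen init under the hood and stores the
# resulting collections (e.g. 'params') as NNX variables on the wrapper.
model.lazy_init(x)   # like linen init
y = model(x)         # like linen apply
assert y.shape == (1, 64)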
Changes from all commits
5 changes: 3 additions & 2 deletions flax/nnx/nnx/bridge/__init__.py
@@ -18,6 +18,7 @@
from .module import Scope as Scope
from .module import compact as compact
from .wrappers import functional as functional
from .wrappers import LinenWrapper as LinenWrapper
from .wrappers import LinenToNNX as LinenToNNX
from .wrappers import Functional as Functional
from .wrappers import NNXWrapper as NNXWrapper
from .wrappers import NNXToLinen as NNXToLinen
from .wrappers import lazy_init as lazy_init
109 changes: 74 additions & 35 deletions flax/nnx/nnx/bridge/wrappers.py
@@ -18,10 +18,14 @@

from flax import nnx
from flax import linen
from flax.nnx.nnx import graph
from flax.nnx.nnx import variables as variableslib
from flax.nnx.nnx.module import GraphDef, Module
from flax.nnx.nnx.rnglib import Rngs
from flax.nnx.nnx.state import State
from flax.nnx.nnx.object import Object
import jax
from jax import tree_util as jtu

M = tp.TypeVar('M', bound=Module)

@@ -55,56 +59,91 @@ def _functional_constructor(*args: tp.Any, **kwargs: tp.Any) -> Functional[M]:
return _functional_constructor


class LinenWrapper(Module):
def _set_initializing(module: Module, initializing: bool):
for _, value in graph.iter_graph(module):
if isinstance(value, Object):
value._object__state._initializing = initializing


def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs):
"""To run through an arbitrary nnx.Module method and initialize all its needed state.

Here used to trigger initialization of all `LinenToNNX` module variables."""
if isinstance(fn, Module):
module = fn
assert callable(fn)
else:
assert hasattr(fn, '__self__') and isinstance(fn.__self__, Module), f'{fn = } needs to be a method of an NNX Module.'
module = fn.__self__
_set_initializing(module, True)
try:
_ = fn(*args, **kwargs)
finally:
_set_initializing(module, False)
return fn


class LinenToNNX(Module):
def __init__(
self,
module: linen.Module,
*args: tp.Any,
rngs: tp.Optional[Rngs] = None,
**kwargs: tp.Any,
):
self.module = module
self.rngs = rngs
self.linen_collections: set[str] = set()

_rngs = (
{name: stream.key.raw_value for name, stream in rngs.items()}
if rngs
else {}
)
# rename default to params
if 'params' not in _rngs and 'default' in _rngs:
_rngs['params'] = _rngs['default']
del _rngs['default']

variables = module.init(_rngs, *args, **kwargs)

self.states = {
collection: variableslib.variable_type(collection)(value)
for collection, value in variables.items()
}
def lazy_init(self, *args, **kwargs):
return lazy_init(self, *args, **kwargs)

def __call__(
self, *args: Any, rngs: tp.Optional[Rngs] = None, **kwargs: Any
self, *args: Any, rngs: tp.Optional[Rngs] = None,
method: tp.Callable[..., Any] | str | None = None, **kwargs: Any
) -> Any:
_rngs = (
{name: stream.key.value for name, stream in rngs.items()} if rngs else {}
)

variables = {
collection: value.value for collection, value in self.states.items()
}
out = self.module.apply(variables, *args, rngs=_rngs, **kwargs)

# Shape-based lazy init of the flax variables
if not rngs:
rngs = self.rngs
if self._object__state.initializing:
_rngs = (
{name: stream.key.raw_value for name, stream in rngs.items()}
Collaborator

We need to generate new keys so Linen Modules get new RNG state every time.

Suggested change
{name: stream.key.raw_value for name, stream in rngs.items()}
{name: stream() for name, stream in rngs.items()}

Collaborator Author

Done.

if rngs
else {}
)
# rename default to params
if 'params' not in _rngs and 'default' in _rngs:
_rngs['params'] = _rngs.pop('default')
out, variables = self.module.init_with_output(_rngs, *args, method=method, **kwargs)
def nn_var_to_nnx_state(kp, v):
assert isinstance(kp[0], jtu.DictKey)
vtype = variableslib.variable_type(kp[0].key)
return vtype(v)
for col, tree in jtu.tree_map_with_path(nn_var_to_nnx_state, variables).items():
self._setattr(col, tree)
self.linen_collections.add(col)

else:
variables = {col: jax.tree.map(lambda v: v.value, getattr(self, col))
for col in self.linen_collections}
_rngs = (
{name: stream() for name, stream in rngs.items()} if rngs else {}
)
out = self.module.apply(variables, *args, rngs=_rngs, method=method, **kwargs)

# Split out the updates if `mutable` is passed into the Flax module
if kwargs.get('mutable', False) != False:
out, updates = out
for collection, value in updates.items():
if collection in self.states:
self.states[collection] = value
else:
self.states[collection] = variableslib.variable_type(collection)(
value
)
self._setattr(collection, jax.tree.map(variableslib.variable_type(collection), value))

return out


class NNXWrapper(linen.Module): ...
class NNXToLinen(linen.Module):
module: Module

def setup(self):
...

def __call__(self, *args, **kwargs):
...
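
A small sketch of the RNG behavior discussed in the review above: on every apply the wrapper draws fresh keys via stream(), so stochastic Linen modules see new RNG state on each call. The Dropout module and seeds here are illustrative:

import jax.numpy as jnp
from flax import linen as nn
from flax import nnx
from flax.nnx import bridge

x = jnp.ones((4, 8))
model = bridge.LinenToNNX(
    nn.Dropout(rate=0.5, deterministic=False),
    rngs=nnx.Rngs(params=0, dropout=1),
)
bridge.lazy_init(model, x)  # Dropout has no params; this just records state
# Each call pulls a new key from the 'dropout' stream, so the masks
# sampled for y1 and y2 are (almost surely) different.
y1 = model(x)
y2 = model(x)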
85 changes: 78 additions & 7 deletions flax/nnx/tests/bridge/wrappers_test.py
@@ -12,24 +12,95 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
from absl.testing import absltest

from flax import linen
import flax
from flax import linen as nn
from flax import nnx
from flax.nnx import bridge
import jax
import jax.numpy as jnp
import numpy as np


class TestCompatibility:
class TestCompatibility(absltest.TestCase):
def test_functional(self):
# Functional API for NNX Modules
functional = bridge.functional(nnx.Linear)(32, 64)
state = functional.init(rngs=nnx.Rngs(0))
x = jax.numpy.ones((1, 32))
y, updates = functional.apply(state)(x)

def test_linen_wrapper(self):
def test_linen_to_nnx(self):
## Wrapper API for Linen Modules
linen_module = linen.Dense(features=64)
linen_module = nn.Dense(features=64)
x = jax.numpy.ones((1, 32))
module = bridge.LinenWrapper(linen_module, x, rngs=nnx.Rngs(0)) # init
y = module(x) # apply
model = bridge.LinenToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x) # like linen init
y = model(x) # like linen apply
assert y.shape == (1, 64)

def test_linen_to_nnx_submodule(self):
class NNXOuter(nnx.Module):
def __init__(self, dout: int, *, rngs: nnx.Rngs):
self.nn_dense1 = bridge.LinenToNNX(nn.Dense(dout, use_bias=False), rngs=rngs)
self.b = nnx.Param(jax.random.uniform(rngs.params(), (1, dout,)))
self.batchnorm = bridge.LinenToNNX(nn.BatchNorm(use_running_average=True), rngs=rngs)
self.rngs = rngs

def __call__(self, x):
x = self.nn_dense1(x) + self.b
return self.batchnorm(x)

x = jax.random.normal(jax.random.key(0), (2, 4))
model = NNXOuter(3, rngs=nnx.Rngs(0))
gdef_before_lazy_init, _ = nnx.split(model)
bridge.lazy_init(model, x)
gdef_full, state = nnx.split(model)
assert gdef_before_lazy_init != gdef_full
assert 'params' in state.nn_dense1
assert 'batch_stats' in state.batchnorm
y = model(x)
k, b = state.nn_dense1.params.kernel.value, state.b.value
np.testing.assert_allclose(y, x @ k + b, rtol=1e-5)
assert gdef_full == nnx.graphdef(model) # static data is stable now

def test_linen_to_nnx_noncall_method(self):
class Foo(nn.Module):
@nn.compact
def __call__(self, x):
b = self.param('b', nn.zeros_init(), (1, 3,))
return self.dot(x) + b

@nn.compact
def dot(self, x):
w = self.param('w', nn.initializers.lecun_normal(), (4, 3))
return x @ w

x = jax.random.normal(jax.random.key(0), (2, 4))
model = bridge.LinenToNNX(Foo(), rngs=nnx.Rngs(0))
bridge.lazy_init(model, x, method=model.module.dot)
y = model(x, method=model.module.dot)
np.testing.assert_allclose(y, x @ nnx.state(model).params.w.value)
# lazy_init only initialized param w inside dot(), so calling __call__ should fail
with self.assertRaises(flax.errors.ScopeParamNotFoundError):
y = model(x)

def test_linen_to_nnx_mutable(self):
class Foo(nn.Module):
def setup(self):
self.count = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))

def __call__(self, x):
if not self.is_initializing():
self.count.value += 1
return x

x = lambda: jnp.zeros((), jnp.int32)
model = bridge.LinenToNNX(Foo(), rngs=nnx.Rngs(0)).lazy_init(x)
assert nnx.state(model).counter.count.value == 0
y = model(x, mutable=True)
assert nnx.state(model).counter.count.value == 1


if __name__ == '__main__':
absltest.main()