Add shape-based lazy init to LinenToNNX (prev LinenWrapper)
IvyZX committed Jul 22, 2024
1 parent ceacc09 commit 0139c90
Showing 3 changed files with 155 additions and 44 deletions.
flax/nnx/nnx/bridge/__init__.py (5 changes: 3 additions & 2 deletions)
@@ -18,6 +18,7 @@
 from .module import Scope as Scope
 from .module import compact as compact
 from .wrappers import functional as functional
-from .wrappers import LinenWrapper as LinenWrapper
+from .wrappers import LinenToNNX as LinenToNNX
 from .wrappers import Functional as Functional
-from .wrappers import NNXWrapper as NNXWrapper
+from .wrappers import NNXToLinen as NNXToLinen
+from .wrappers import lazy_init as lazy_init
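
At the public API surface this commit is a rename plus one new export: `bridge.LinenWrapper` becomes `bridge.LinenToNNX`, `bridge.NNXWrapper` becomes `bridge.NNXToLinen`, and `bridge.lazy_init` is newly exposed for shape-based initialization. A minimal before/after sketch of caller code (the `nn.Dense(64)` module and input shape are borrowed from the tests below, not part of this file's diff):

import jax
from flax import linen as nn
from flax import nnx
from flax.nnx import bridge

x = jax.numpy.ones((1, 32))

# Before: the wrapper ran Linen init eagerly in its constructor, so
# sample inputs had to be passed at construction time:
#   model = bridge.LinenWrapper(nn.Dense(64), x, rngs=nnx.Rngs(0))

# After: construction is cheap; variables are created on the first
# initializing call, keyed by input shapes (like `linen.Module.init`).
model = bridge.LinenToNNX(nn.Dense(64), rngs=nnx.Rngs(0))
model = bridge.lazy_init(model, x)  # returns the now-initialized module
y = model(x)                        # ordinary apply afterwards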
flax/nnx/nnx/bridge/wrappers.py (109 changes: 74 additions & 35 deletions)
@@ -18,10 +18,14 @@

 from flax import nnx
 from flax import linen
+from flax.nnx.nnx import graph
 from flax.nnx.nnx import variables as variableslib
 from flax.nnx.nnx.module import GraphDef, Module
 from flax.nnx.nnx.rnglib import Rngs
 from flax.nnx.nnx.state import State
+from flax.nnx.nnx.object import Object
+import jax
+from jax import tree_util as jtu

 M = tp.TypeVar('M', bound=Module)

@@ -55,56 +59,91 @@ def _functional_constructor(*args: tp.Any, **kwargs: tp.Any) -> Functional[M]:
   return _functional_constructor


-class LinenWrapper(Module):
+def _set_initializing(module: Module, initializing: bool):
+  for _, value in graph.iter_graph(module):
+    if isinstance(value, Object):
+      value._object__state._initializing = initializing
+
+
+def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs):
+  """Runs an arbitrary nnx.Module method and initializes all its needed state.
+  Used here to trigger initialization of all `LinenToNNX` module variables."""
+  if isinstance(fn, Module):
+    module = fn
+    assert callable(fn)
+  else:
+    assert hasattr(fn, '__self__') and isinstance(fn.__self__, Module), f'{fn = } needs to be a method of an NNX Module.'
+    module = fn.__self__
+  _set_initializing(module, True)
+  try:
+    _ = fn(*args, **kwargs)
+  finally:
+    _set_initializing(module, False)
+  return fn
+
+
+class LinenToNNX(Module):
   def __init__(
     self,
     module: linen.Module,
     *args: tp.Any,
     rngs: tp.Optional[Rngs] = None,
     **kwargs: tp.Any,
   ):
     self.module = module
+    self.rngs = rngs
+    self.linen_collections: set[str] = set()

-    _rngs = (
-      {name: stream.key.raw_value for name, stream in rngs.items()}
-      if rngs
-      else {}
-    )
-    # rename default to params
-    if 'params' not in _rngs and 'default' in _rngs:
-      _rngs['params'] = _rngs['default']
-      del _rngs['default']
-
-    variables = module.init(_rngs, *args, **kwargs)
-
-    self.states = {
-      collection: variableslib.variable_type(collection)(value)
-      for collection, value in variables.items()
-    }
+  def lazy_init(self, *args, **kwargs):
+    return lazy_init(self, *args, **kwargs)

   def __call__(
-    self, *args: Any, rngs: tp.Optional[Rngs] = None, **kwargs: Any
+    self, *args: Any, rngs: tp.Optional[Rngs] = None,
+    method: tp.Callable[..., Any] | str | None = None, **kwargs: Any
   ) -> Any:
-    _rngs = (
-      {name: stream.key.value for name, stream in rngs.items()} if rngs else {}
-    )
-
-    variables = {
-      collection: value.value for collection, value in self.states.items()
-    }
-    out = self.module.apply(variables, *args, rngs=_rngs, **kwargs)
-
+    # Shape-based lazy init of the flax variables
+    if not rngs:
+      rngs = self.rngs
+    if self._object__state.initializing:
+      _rngs = (
+        {name: stream.key.raw_value for name, stream in rngs.items()}
+        if rngs
+        else {}
+      )
+      # rename default to params
+      if 'params' not in _rngs and 'default' in _rngs:
+        _rngs['params'] = _rngs.pop('default')
+      out, variables = self.module.init_with_output(_rngs, *args, method=method, **kwargs)
+      def nn_var_to_nnx_state(kp, v):
+        assert isinstance(kp[0], jtu.DictKey)
+        vtype = variableslib.variable_type(kp[0].key)
+        return vtype(v)
+      for col, tree in jtu.tree_map_with_path(nn_var_to_nnx_state, variables).items():
+        self._setattr(col, tree)
+        self.linen_collections.add(col)
+
+    else:
+      variables = {col: jax.tree.map(lambda v: v.value, getattr(self, col))
+                   for col in self.linen_collections}
+      _rngs = (
+        {name: stream() for name, stream in rngs.items()} if rngs else {}
+      )
+      out = self.module.apply(variables, *args, rngs=_rngs, method=method, **kwargs)

     # Split out the updates if `mutable` is passed into the Flax module
     if kwargs.get('mutable', False) != False:
       out, updates = out
       for collection, value in updates.items():
-        if collection in self.states:
-          self.states[collection] = value
-        else:
-          self.states[collection] = variableslib.variable_type(collection)(
-            value
-          )
+        self._setattr(collection, jax.tree.map(variableslib.variable_type(collection), value))

     return out


-class NNXWrapper(linen.Module): ...
+class NNXToLinen(linen.Module):
+  module: Module
+
+  def setup(self):
+    ...
+
+  def __call__(self, *args, **kwargs):
+    ...
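
In caller terms, `LinenToNNX.__call__` now has two paths: while `lazy_init` holds the `_initializing` flag set (walked over every `Object` via `graph.iter_graph`), the call runs `Module.init_with_output` and attaches each returned Linen collection to the wrapper as NNX variables through `_setattr`; every later call rebuilds the Linen variable dict from those attributes and runs `Module.apply`, writing any `mutable` updates back the same way. A hedged sketch of what that looks like from outside; `nn.BatchNorm(use_running_average=False)` and the shapes are illustrative choices (the tests below use `use_running_average=True`), picked so that both `params` and `batch_stats` collections appear:

import jax
from flax import linen as nn
from flax import nnx
from flax.nnx import bridge

# First (initializing) call path, triggered by bridge.lazy_init.
model = bridge.LinenToNNX(nn.BatchNorm(use_running_average=False), rngs=nnx.Rngs(0))
bridge.lazy_init(model, jax.numpy.ones((2, 4)))

# Each Linen collection is now an attribute holding NNX variables,
# typed via variableslib.variable_type(collection).
state = nnx.state(model)
assert 'params' in state       # scale/bias become nnx.Param
assert 'batch_stats' in state  # mean/var become nnx.BatchStat

# Later calls take the apply path; passing `mutable` makes the wrapper
# split out the updates and write them back with `_setattr`.
y = model(jax.numpy.ones((2, 4)), mutable=['batch_stats'])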
flax/nnx/tests/bridge/wrappers_test.py (85 changes: 78 additions & 7 deletions)
@@ -12,24 +12,95 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import jax
+from absl.testing import absltest

-from flax import linen
+import flax
+from flax import linen as nn
 from flax import nnx
 from flax.nnx import bridge
+import jax
+import jax.numpy as jnp
+import numpy as np


-class TestCompatibility:
+class TestCompatibility(absltest.TestCase):
   def test_functional(self):
     # Functional API for NNX Modules
     functional = bridge.functional(nnx.Linear)(32, 64)
     state = functional.init(rngs=nnx.Rngs(0))
     x = jax.numpy.ones((1, 32))
     y, updates = functional.apply(state)(x)

-  def test_linen_wrapper(self):
+  def test_linen_to_nnx(self):
     ## Wrapper API for Linen Modules
-    linen_module = linen.Dense(features=64)
+    linen_module = nn.Dense(features=64)
     x = jax.numpy.ones((1, 32))
-    module = bridge.LinenWrapper(linen_module, x, rngs=nnx.Rngs(0))  # init
-    y = module(x)  # apply
+    model = bridge.LinenToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x)  # like linen init
+    y = model(x)  # like linen apply
     assert y.shape == (1, 64)

+  def test_linen_to_nnx_submodule(self):
+    class NNXOuter(nnx.Module):
+      def __init__(self, dout: int, *, rngs: nnx.Rngs):
+        self.nn_dense1 = bridge.LinenToNNX(nn.Dense(dout, use_bias=False), rngs=rngs)
+        self.b = nnx.Param(jax.random.uniform(rngs.params(), (1, dout,)))
+        self.batchnorm = bridge.LinenToNNX(nn.BatchNorm(use_running_average=True), rngs=rngs)
+        self.rngs = rngs
+
+      def __call__(self, x):
+        x = self.nn_dense1(x) + self.b
+        return self.batchnorm(x)
+
+    x = jax.random.normal(jax.random.key(0), (2, 4))
+    model = NNXOuter(3, rngs=nnx.Rngs(0))
+    gdef_before_lazy_init, _ = nnx.split(model)
+    bridge.lazy_init(model, x)
+    gdef_full, state = nnx.split(model)
+    assert gdef_before_lazy_init != gdef_full
+    assert 'params' in state.nn_dense1
+    assert 'batch_stats' in state.batchnorm
+    y = model(x)
+    k, b = state.nn_dense1.params.kernel.value, state.b.value
+    np.testing.assert_allclose(y, x @ k + b, rtol=1e-5)
+    assert gdef_full == nnx.graphdef(model)  # static data is stable now
+
+  def test_linen_to_nnx_noncall_method(self):
+    class Foo(nn.Module):
+      @nn.compact
+      def __call__(self, x):
+        b = self.param('b', nn.zeros_init(), (1, 3,))
+        return self.dot(x) + b
+
+      @nn.compact
+      def dot(self, x):
+        w = self.param('w', nn.initializers.lecun_normal(), (4, 3))
+        return x @ w
+
+    x = jax.random.normal(jax.random.key(0), (2, 4))
+    model = bridge.LinenToNNX(Foo(), rngs=nnx.Rngs(0))
+    bridge.lazy_init(model, x, method=model.module.dot)
+    y = model(x, method=model.module.dot)
+    np.testing.assert_allclose(y, x @ nnx.state(model).params.w.value)
+    # lazy_init only initialized param w inside dot(), so calling __call__ should fail
+    with self.assertRaises(flax.errors.ScopeParamNotFoundError):
+      y = model(x)
+
+  def test_linen_to_nnx_mutable(self):
+    class Foo(nn.Module):
+      def setup(self):
+        self.count = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))
+
+      def __call__(self, x):
+        if not self.is_initializing():
+          self.count.value += 1
+        return x
+
+    x = lambda: jnp.zeros((), jnp.int32)
+    model = bridge.LinenToNNX(Foo(), rngs=nnx.Rngs(0)).lazy_init(x)
+    assert nnx.state(model).counter.count.value == 0
+    y = model(x, mutable=True)
+    assert nnx.state(model).counter.count.value == 1
+
+
+if __name__ == '__main__':
+  absltest.main()
