-
Notifications
You must be signed in to change notification settings - Fork 660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add shape-based lazy init to LinenToNNX
(prev LinenWrapper
)
#4081
Conversation
dd5936c
to
2c8963d
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #4081 +/- ##
======================================
Coverage 0.00% 0.00%
======================================
Files 106 108 +2
Lines 13582 14045 +463
======================================
- Misses 13582 14045 +463 ☔ View full report in Codecov by Sentry. |
flax/nnx/nnx/bridge/wrappers.py
Outdated
_rngs['params'] = _rngs['default'] | ||
del _rngs['default'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_rngs['params'] = _rngs['default'] | |
del _rngs['default'] | |
_rngs['params'] = _rngs.pop('default') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now that I think about it, we have to make Rngs
implement MutableMapping
for either of these to work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The _rngs
here is a dict instead of Rngs
class, so this already works.
flax/nnx/nnx/bridge/wrappers.py
Outdated
"""To trigger init of all `LinenToNNX` module variables and return a wholesome state.""" | ||
assert callable(module) | ||
_ = module(*args, **kwargs) | ||
return nnx.split(module) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this function we should leverage the fact that Object._object__state._initializing
still exists and set if via something like
def _set_initializing(initializing: bool):
for _, value in graph.iter_graph(module):
if isinstance(value, Object):
value._object__state._initializing = initializing
and use the value of _initializing
to choose between init
and apply
when calling the Linen Modules.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added _set_initializing
to the LinenToNNX wrapper.
Note that we can't do check on top level modules' ._object__state._initializing
because the top level module might be a pure NNX module with ._object__state._initializing
always False.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd do something like this:
def _set_initializing(module, initializing: bool):
for _, value in graph.iter_graph(module):
if isinstance(value, Object):
value._object__state._initializing = initializing
def shaped_init(module: Module, *args, **kwargs):
"""To trigger init of all `LinenToNNX` module variables and return a wholesome state."""
module = graph.clone(module) # create a copy
_set_initializing(module, True)
assert callable(module)
try:
_ = module(*args, **kwargs)
finally:
_set_initializing(module, False)
return nnx.split(module)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, also renamed to lazy_init
as discussed offline.
flax/nnx/nnx/bridge/wrappers.py
Outdated
# Shape-based lazy init of the flax variables | ||
if not rngs: | ||
rngs = self.rngs | ||
if not hasattr(self, 'states'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use self._object__state.initializing
instead, see above.
if not hasattr(self, 'states'): | |
if self._object__state.initializing: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
rngs = self.rngs | ||
if self._object__state.initializing: | ||
_rngs = ( | ||
{name: stream.key.raw_value for name, stream in rngs.items()} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to generate new keys so Linen Modules get new RNG state every time.
{name: stream.key.raw_value for name, stream in rngs.items()} | |
{name: stream() for name, stream in rngs.items()} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
flax/nnx/nnx/bridge/wrappers.py
Outdated
if 'params' not in _rngs and 'default' in _rngs: | ||
_rngs['params'] = _rngs.pop('default') | ||
|
||
variables = self.module.init(_rngs, *args, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we could use init_with_output
to avoid calling forward twice
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
e9fa1e0
to
3f33ad8
Compare
LinenToNNX
andNNXToLinen
to minimize confusion.LinenToNNX
to__call__
to realize lazy init. This allows it to be a submodule of an NNX module, which doesn't have input args during initialization.nnx.shaped_init
to do a dry run of__call__
and initialize the whole state & full graphdef.LinenToNNX
nested & closer to NNX, aka. eachVariableState
is created for every jax Array, not every collection.