Skip to content

Commit 1bd6b07

Browse files
committed
Apply suggestions from code review
1 parent 4dfafeb commit 1bd6b07

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

flax/nnx/module.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -268,23 +268,29 @@ def perturb(
268268
return old_value.value + value
269269

270270
def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]:
271-
"""Recursively iterates over all nested :class:`Module`'s of the current Module, including
272-
the current Module. Alias of ``nnx.iter_modules``.
271+
"""
272+
Warning: this method is method is deprecated; use :func:`iter_modules` instead.
273+
274+
Recursively iterates over all nested :class:`Module`'s of the current Module, including
275+
the current Module. Alias of :func:`iter_modules`.
273276
"""
274277
warnings.warn(
275-
"using the 'm.iter_modules()' method is deprecated; use the 'nnx.iter_modules(m)' function instead.",
278+
"The 'm.iter_modules()' method is deprecated; use the 'nnx.iter_modules(m)' function instead.",
276279
DeprecationWarning,
277280
stacklevel=2,
278281
)
279282
yield from iter_modules(self)
280283

281284
def iter_children(self) -> tp.Iterator[tuple[Key, Module]]:
282-
"""Iterates over all children :class:`Module`'s of the current Module. This
285+
"""
286+
Warning: this method is method is deprecated; use :func:`iter_children` instead.
287+
288+
Iterates over all children :class:`Module`'s of the current Module. This
283289
method is similar to :func:`iter_modules`, except it only iterates over the
284-
immediate children, and does not recurse further down. Alias of ``nnx.iter_children``.
290+
immediate children, and does not recurse further down. Alias of :func:`iter_children`.
285291
"""
286292
warnings.warn(
287-
"using the 'm.iter_children()' method is deprecated; use the 'nnx.iter_children(m)' function instead.",
293+
"The 'm.iter_children()' method is deprecated; use the 'nnx.iter_children(m)' function instead.",
288294
DeprecationWarning,
289295
stacklevel=2,
290296
)
@@ -441,7 +447,7 @@ def iter_modules(module: Module) -> tp.Iterator[tuple[PathParts, Module]]:
441447
"""Recursively iterates over all nested :class:`Module`'s of the given Module, including
442448
the argument.
443449
444-
``iter_modules`` creates a generator that yields the path and the Module instance, where
450+
Specifically, this function creates a generator that yields the path and the Module instance, where
445451
the path is a tuple of strings or integers representing the path to the Module from the
446452
root Module.
447453
@@ -482,7 +488,7 @@ def iter_children(module: Module) -> tp.Iterator[tuple[Key, Module]]:
482488
method is similar to :func:`iter_modules`, except it only iterates over the
483489
immediate children, and does not recurse further down.
484490
485-
``iter_children`` creates a generator that yields the key and the Module instance,
491+
Specifically, this function creates a generator that yields the key and the Module instance,
486492
where the key is a string representing the attribute name of the Module to access
487493
the corresponding child Module.
488494
@@ -503,7 +509,7 @@ def iter_children(module: Module) -> tp.Iterator[tuple[Key, Module]]:
503509
... self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
504510
...
505511
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
506-
>>> for path, module in iter_children(model):
512+
>>> for path, module in nnx.iter_children(model):
507513
... print(path, type(module).__name__)
508514
...
509515
batch_norm BatchNorm

0 commit comments

Comments
 (0)