Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs_nnx/api_reference/flax.nnx/rnglib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ rnglib
.. autofunction:: split_rngs
.. autofunction:: fork_rngs
.. autofunction:: reseed
.. autofunction:: with_rngs
13 changes: 8 additions & 5 deletions flax/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,14 @@ def static_int_env(varname: str, default: int | None) -> int | None:
)
nnx_graph_mode = bool_flag(
name='nnx_graph_mode',
default=True,
default=False,
help='Whether NNX APIs default to graph-mode (True) or tree-mode (False).',
)
nnx_graph_updates = bool_flag(
name='nnx_graph_updates',
default=True,
help='Whether graph-mode uses dynamic (True) or simple (False) graph traversal.',
)
name='nnx_graph_updates',
default=False,
help=(
'Whether graph-mode uses dynamic (True) or simple (False) graph'
' traversal.'
),
)
5 changes: 3 additions & 2 deletions flax/nnx/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
flatten = functools.partial(_graphlib.flatten, graph=True)
iter_graph = functools.partial(_graphlib.iter_graph, graph=True)
recursive_map = functools.partial(_graphlib.recursive_map, graph=True)
cached_partial = functools.partial(_graphlib.cached_partial, graph=True, graph_updates=True)

# module
view = functools.partial(_module.view, graph=True)
Expand All @@ -45,8 +46,8 @@
iter_children = functools.partial(_module.iter_children, graph=True) # type: ignore[has-type]

# rnglib
split_rngs = functools.partial(_rnglib.split_rngs, graph=True)
fork_rngs = functools.partial(_rnglib.fork_rngs, graph=True)
split_rngs = functools.partial(_rnglib.split_rngs, graph=True, graph_updates=True)
fork_rngs = functools.partial(_rnglib.fork_rngs, graph=True, graph_updates=True)
reseed = functools.partial(_rnglib.reseed, graph=True)
backup_keys = functools.partial(_rnglib.backup_keys, graph=True)

Expand Down
86 changes: 77 additions & 9 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def check_consistent_aliasing2(
value_id = id(value)
node_id_to_variable[value_id] = value
# If prefix is a TreeState (e.g. from nnx.prefix(graph=True)),
# extract the actual prefix value for this variable using local_path.
# extract the actual prefix value for this Variable using local_path.
if isinstance(prefix, TreeState):
prefix_fn = prefix.prefix_fn.value
if not callable(prefix_fn):
Expand Down Expand Up @@ -200,7 +200,7 @@ def broadcast_prefix2(
) -> tuple[list[KeyPath], list[tp.Any]]:
paths: list[KeyPath] = []
leaves: list[tp.Any] = []
num_leaves = lambda t: jax.tree_util.tree_structure(t).num_leaves
num_leaves = lambda t: jax.tree_util.tree_structure(t, is_leaf=is_leaf).num_leaves
def add_leaves(path, x, subtree):
n = num_leaves(subtree)
paths.extend([path] * n)
Expand All @@ -215,10 +215,14 @@ def broadcast_prefix_map(
*rest: tp.Any,
is_leaf: tp.Callable[[tp.Any], bool] | None = None,
) -> tp.Any:
paths, prefix_leaves = broadcast_prefix2(prefix_tree, full_tree, is_leaf=is_leaf)
leaves, treedef = jax.tree_util.tree_flatten(full_tree, is_leaf=is_leaf)
full_prefix_tree = treedef.unflatten(prefix_leaves)
return jax.tree.map_with_path(f, full_prefix_tree, full_tree, *rest, is_leaf=is_leaf)
_, prefix_leaves = broadcast_prefix2(prefix_tree, full_tree, is_leaf=is_leaf)
full_leaves_with_path, treedef = jax.tree.flatten_with_path(full_tree, is_leaf=is_leaf)
rest_flat = [treedef.flatten_up_to(r) for r in rest]
out_leaves = []
for (path, full_leaf), p_leaf, *r_leaves in zip(full_leaves_with_path, prefix_leaves, *rest_flat):
out_leaf = f(path, p_leaf, full_leaf, *r_leaves)
out_leaves.append(out_leaf)
return jax.tree.unflatten(treedef, out_leaves)


class GraphDefState(struct.PyTreeNode):
Expand Down Expand Up @@ -557,6 +561,69 @@ def replace_at(t: tuple, index: int, value: tp.Any) -> tuple:
for i, x in enumerate(t)
)


def slice_at(t: tuple, index: int | None) -> tuple[tp.Any, tuple]:
if index is None:
return None, t
return t[index], t[:index] + t[index + 1 :]


def insert_at(t: tuple, index: int | None, value: tp.Any) -> tuple:
if index is None:
return t
xs = list(t)
xs.insert(index, value)
return tuple(xs)


def find(t: tuple, value: tp.Any) -> int | None:
return next((i for i, x in enumerate(t) if x == value), None)


@jax.tree_util.register_static
@dataclasses.dataclass(frozen=True, slots=True)
class ExtractIndex:
  """Static pytree marker left in place of a leaf removed by ``extract``.

  Registered as a static (leaf-less) pytree node so it survives JAX tree
  transformations untouched; ``insert`` uses it to restore the original leaf.
  """
  # Position of the removed leaf in the list returned by extract().
  index: int


def extract(
  f: tp.Callable[[jax.tree_util.KeyPath, tp.Any, tp.Any], bool],
  prefix: tp.Any,
  tree: tp.Any,
  *,
  is_leaf: tp.Callable[[tp.Any], bool] | None = None,
) -> tuple[tp.Any, list[tp.Any]]:
  """Pull out of ``tree`` every leaf for which ``f`` returns True.

  ``prefix`` is broadcast against ``tree`` so that ``f`` receives, for each
  leaf, its key path, the matching prefix value, and the leaf itself.
  Selected leaves are replaced in the returned tree by ``ExtractIndex``
  markers pointing into the returned list; ``insert`` reverses the operation.
  """
  pulled: list[tp.Any] = []

  def _maybe_take(path, prefix_leaf, leaf):
    if not f(path, prefix_leaf, leaf):
      return leaf
    pulled.append(leaf)
    return ExtractIndex(len(pulled) - 1)

  broadcast_prefix = jax.tree.broadcast(prefix, tree, is_leaf=is_leaf)
  marked_tree = jax.tree.map_with_path(
    _maybe_take, broadcast_prefix, tree, is_leaf=is_leaf
  )
  return marked_tree, pulled


def insert(
  tree: tp.Any,
  extracted: list[tp.Any],
  is_leaf: tp.Callable[[tp.Any], bool] | None = None,
) -> tp.Any:
  """Inverse of ``extract``: splice extracted values back into ``tree``.

  Every ``ExtractIndex`` leaf is replaced by the corresponding element of
  ``extracted``; all other leaves pass through unchanged.
  """

  def _stop_at(x: tp.Any) -> bool:
    # Markers must always be treated as leaves so the traversal does not
    # descend into them, in addition to whatever the caller treats as a leaf.
    if isinstance(x, ExtractIndex):
      return True
    return is_leaf(x) if is_leaf is not None else False

  def _restore(leaf: tp.Any) -> tp.Any:
    return extracted[leaf.index] if isinstance(leaf, ExtractIndex) else leaf

  return jax.tree.map(_restore, tree, is_leaf=_stop_at)


def updates_and_snapshot(args: A) -> tuple[A, A]:
is_leaf = lambda x: isinstance(x, variablelib.Variable)
leaves, treedef = jax.tree.flatten(args, is_leaf=is_leaf)
Expand Down Expand Up @@ -613,7 +680,8 @@ def check_no_aliases(
f' - {seen_path_str}\n'
f' - {path_str}\n\n'
f'nnx.{fn_name} with graph_updates=False does not support '
'returning input Variables as outputs. '
'Variable aliasing (duplicate inputs, duplicate outputs, or '
'input Variables returned as outputs). '
f'Consider the following options:\n\n'
f'1. Remove the duplicate Variables.\n'
f'2. Create new Variables via nnx.clone() and use those instead.\n'
Expand Down Expand Up @@ -816,9 +884,9 @@ def forward(m1, m2, x):
filters = list(filter_map.keys())

def prefix_fn(path, leaf):
for predicate, value in predicates:
for predicate, _prefix in predicates:
if predicate(path, leaf):
return value
return _prefix
raise ValueError(
f'No filter matched leaf at path {path!r} with value {leaf!r}. '
f'Filters: {filters}'
Expand Down
19 changes: 15 additions & 4 deletions flax/nnx/graphlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def _check_valid_pytree(

@jax.tree_util.register_static
@dataclasses.dataclass(frozen=True, slots=True)


class NoUpdate: ...


Expand All @@ -108,6 +110,8 @@ class NoUpdate: ...

@jax.tree_util.register_static
@dataclasses.dataclass(frozen=True, slots=True)


class Repeated: ...


Expand Down Expand Up @@ -1576,7 +1580,12 @@ def static_cache(static_cache: tp.MutableMapping[tp.Any, StaticCache]):
)


def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args, graph: bool | None = None):
def _cached_partial(
f: tp.Callable[..., tp.Any],
*cached_args,
graph: bool | None = None,
graph_updates: bool | None = None,
):
"""Create a partial from a NNX transformed function alog with some cached input arguments
and reduces the python overhead by caching the traversal of NNX graph nodes. This is useful
for speed up function that are called repeatedly with the same subset of inputs e.g. a
Expand Down Expand Up @@ -1625,10 +1634,12 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args, graph: bool | Non
"""
if graph is None:
graph = set_graph_mode.current_value()
if not graph:
if graph_updates is None:
graph_updates = set_graph_updates.current_value()

if not graph or not graph_updates:
raise ValueError(
'cached_partial is a graph-mode-only API and does not support '
'tree-mode (graph=False).'
'cached_partial is a graph-mode-only API and requires graph_updates=True.'
)
cache: tp.MutableMapping[tp.Any, StaticCache] = PythonRefMap() # type: ignore
original_ref_index: RefMap = RefMap()
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/nn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ def __call__(
# we use split_rngs with splits=1 and squeeze=True to get unique rngs
# every time RNN is called
@nnx.split_rngs(splits=1, only=self.broadcast_rngs, squeeze=True)
@nnx.scan(
@nnx.compat.scan(
in_axes=(state_axes, iteration.Carry, time_axis),
out_axes=(iteration.Carry, (0, time_axis))
if slice_carry
Expand Down
Loading
Loading