Skip to content

Commit

Permalink
Export KeyPath and related types to jax.tree_util
Browse files Browse the repository at this point in the history
These types lie on the APIs in `jax.tree_util`, so it makes sense to export them.

PiperOrigin-RevId: 657987755
  • Loading branch information
Google-ML-Automation authored and jax authors committed Jul 31, 2024
1 parent 9dba6eb commit a207fe9
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/jax.tree_util.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ List of Functions
treedef_children
treedef_is_leaf
treedef_tuple
KeyEntry
KeyPath
keystr

Legacy APIs
Expand Down
2 changes: 2 additions & 0 deletions jax/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
DictKey as DictKey,
FlattenedIndexKey as FlattenedIndexKey,
GetAttrKey as GetAttrKey,
KeyEntry as KeyEntry,
KeyPath as KeyPath,
Partial as Partial,
PyTreeDef as PyTreeDef,
SequenceKey as SequenceKey,
Expand Down
12 changes: 8 additions & 4 deletions tests/package_structure_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@ def _mod(module_name: str, *, include: Sequence[str] = (), exclude: Sequence[str


class PackageStructureTest(jtu.JaxTestCase):

@parameterized.parameters([
# TODO(jakevdp): expand test to other public modules.
_mod("jax.errors"),
_mod("jax.nn.initializers"),
_mod("jax.tree_util", exclude=['PyTreeDef', 'default_registry']),
# TODO(jakevdp): expand test to other public modules.
_mod("jax.errors"),
_mod("jax.nn.initializers"),
_mod(
"jax.tree_util",
exclude=["PyTreeDef", "default_registry", "KeyEntry", "KeyPath"],
),
])
def test_exported_names_match_module(self, module_name, include, exclude):
"""Test that all public exports have __module__ set correctly."""
Expand Down

0 comments on commit a207fe9

Please sign in to comment.