From a207fe9b7756cb9c816a33883c1f07c28ae328b6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 31 Jul 2024 06:40:53 -0700 Subject: [PATCH] Export `KeyPath` and related types to `jax.tree_util` These types lie on the APIs in `jax.tree_util`, so it makes sense to export them. PiperOrigin-RevId: 657987755 --- docs/jax.tree_util.rst | 2 ++ jax/tree_util.py | 2 ++ tests/package_structure_test.py | 12 ++++++++---- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/docs/jax.tree_util.rst b/docs/jax.tree_util.rst index 35bce340d4de..73fd1f376e9f 100644 --- a/docs/jax.tree_util.rst +++ b/docs/jax.tree_util.rst @@ -26,6 +26,8 @@ List of Functions treedef_children treedef_is_leaf treedef_tuple + KeyEntry + KeyPath keystr Legacy APIs diff --git a/jax/tree_util.py b/jax/tree_util.py index 800086c220ac..b4854c7dfbf1 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -42,6 +42,8 @@ DictKey as DictKey, FlattenedIndexKey as FlattenedIndexKey, GetAttrKey as GetAttrKey, + KeyEntry as KeyEntry, + KeyPath as KeyPath, Partial as Partial, PyTreeDef as PyTreeDef, SequenceKey as SequenceKey, diff --git a/tests/package_structure_test.py b/tests/package_structure_test.py index 448f8f1e7f93..e9944ec084af 100644 --- a/tests/package_structure_test.py +++ b/tests/package_structure_test.py @@ -28,11 +28,15 @@ def _mod(module_name: str, *, include: Sequence[str] = (), exclude: Sequence[str class PackageStructureTest(jtu.JaxTestCase): + @parameterized.parameters([ - # TODO(jakevdp): expand test to other public modules. - _mod("jax.errors"), - _mod("jax.nn.initializers"), - _mod("jax.tree_util", exclude=['PyTreeDef', 'default_registry']), + # TODO(jakevdp): expand test to other public modules. + _mod("jax.errors"), + _mod("jax.nn.initializers"), + _mod( + "jax.tree_util", + exclude=["PyTreeDef", "default_registry", "KeyEntry", "KeyPath"], + ), ]) def test_exported_names_match_module(self, module_name, include, exclude): """Test that all public exports have __module__ set correctly."""