diff --git a/docs/jax.tree_util.rst b/docs/jax.tree_util.rst index 35bce340d4de..73fd1f376e9f 100644 --- a/docs/jax.tree_util.rst +++ b/docs/jax.tree_util.rst @@ -26,6 +26,8 @@ List of Functions treedef_children treedef_is_leaf treedef_tuple + KeyEntry + KeyPath keystr Legacy APIs diff --git a/jax/tree_util.py b/jax/tree_util.py index 800086c220ac..b4854c7dfbf1 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -42,6 +42,8 @@ DictKey as DictKey, FlattenedIndexKey as FlattenedIndexKey, GetAttrKey as GetAttrKey, + KeyEntry as KeyEntry, + KeyPath as KeyPath, Partial as Partial, PyTreeDef as PyTreeDef, SequenceKey as SequenceKey, diff --git a/tests/package_structure_test.py b/tests/package_structure_test.py index 448f8f1e7f93..e9944ec084af 100644 --- a/tests/package_structure_test.py +++ b/tests/package_structure_test.py @@ -28,11 +28,15 @@ def _mod(module_name: str, *, include: Sequence[str] = (), exclude: Sequence[str class PackageStructureTest(jtu.JaxTestCase): + @parameterized.parameters([ - # TODO(jakevdp): expand test to other public modules. - _mod("jax.errors"), - _mod("jax.nn.initializers"), - _mod("jax.tree_util", exclude=['PyTreeDef', 'default_registry']), + # TODO(jakevdp): expand test to other public modules. + _mod("jax.errors"), + _mod("jax.nn.initializers"), + _mod( + "jax.tree_util", + exclude=["PyTreeDef", "default_registry", "KeyEntry", "KeyPath"], + ), ]) def test_exported_names_match_module(self, module_name, include, exclude): """Test that all public exports have __module__ set correctly."""