diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py
index 774953ed9b97..857466a919b7 100644
--- a/jax/_src/export/_export.py
+++ b/jax/_src/export/_export.py
@@ -17,13 +17,15 @@
 
 from __future__ import annotations
 
+import collections
 from collections.abc import Callable, Sequence
 import copy
 import dataclasses
 import functools
 import itertools
+import json
 import re
-from typing import Any, Union, cast
+from typing import Any, TypeVar, Union, cast
 import warnings
 
 from absl import logging
@@ -344,6 +346,161 @@ def deserialize(blob: bytearray) -> Exported:
   return deserialize(blob)
 
 
+T = TypeVar("T")
+PyTreeAuxData = Any  # alias for tree_util._AuxData
+
+# A function to serialize the PyTree node AuxData (returned by the
+# `flatten_func` registered by `tree_util.register_pytree_node`).
+_SerializeAuxData = Callable[[PyTreeAuxData], bytes]
+# A function to deserialize the AuxData, and produce something ready for
+# `_BuildFromChildren` below.
+_DeserializeAuxData = Callable[[bytes], PyTreeAuxData]
+# A function to materialize a T given a deserialized AuxData and children.
+# This is similar in scope to the `unflatten_func`.
+_BuildFromChildren = Callable[[PyTreeAuxData, Sequence[Any]], Any]
+
+
+custom_pytree_node_registry_serialization: dict[
+    type,
+    tuple[str, _SerializeAuxData]] = {}
+
+
+custom_pytree_node_registry_deserialization: dict[
+    str,
+    tuple[type, _DeserializeAuxData, _BuildFromChildren]] = {}
+
+
+def _is_namedtuple(nodetype: type) -> bool:
+  return (issubclass(nodetype, tuple) and
+          hasattr(nodetype, "_fields") and
+          isinstance(nodetype._fields, Sequence) and
+          all(isinstance(f, str) for f in nodetype._fields))
+
+def register_pytree_node_serialization(
+    nodetype: type[T],
+    *,
+    serialized_name: str,
+    serialize_auxdata: _SerializeAuxData,
+    deserialize_auxdata: _DeserializeAuxData,
+    from_children: _BuildFromChildren | None = None
+) -> type[T]:
+  """Registers a custom PyTree node for serialization and deserialization.
+
+  You must use this function before you can serialize and deserialize PyTree
+  nodes for the types not supported natively. We serialize PyTree nodes for
+  the `in_tree` and `out_tree` fields of `Exported`, which are part of the
+  exported function's calling convention.
+
+  This function must be called after calling
+  `jax.tree_util.register_pytree_node` (except for `collections.namedtuple`,
+  which does not require a call to `register_pytree_node`).
+
+  Args:
+    nodetype: the type whose PyTree nodes we want to serialize. It is an
+      error to register more than one serialization for the same type.
+    serialized_name: a unique string that will be present in the serialization
+      and will be used to look up the registration during deserialization.
+    serialize_auxdata: serialize the PyTree auxdata (returned by the
+      `flatten_func` argument to `jax.tree_util.register_pytree_node`).
+    deserialize_auxdata: deserialize the auxdata that was serialized by the
+      `serialize_auxdata`.
+    from_children: if present, this is a function that takes the result of
+      `deserialize_auxdata` along with some children and creates an instance
+      of `nodetype`. This is similar to the `unflatten_func` passed to
+      `jax.tree_util.register_pytree_node`. If not present, we look up
+      and use the `unflatten_func`. This is needed for `collections.namedtuple`,
+      which does not have a `register_pytree_node`, but it can be useful to
+      override that function. Note that the result of `from_children` is
+      only used with `jax.tree_util.tree_structure` to construct a proper
+      PyTree node; it is not used to construct the outputs of the serialized
+      function.
+
+  Returns: the same type passed as `nodetype`, so that this function can
+    be used as a class decorator.
+
+  """
+  if nodetype in custom_pytree_node_registry_serialization:
+    raise ValueError(
+        f"Duplicate serialization registration for type `{nodetype}`. "
+        "Previous registration was with serialized_name "
+        f"`{custom_pytree_node_registry_serialization[nodetype][0]}`.")
+  if serialized_name in custom_pytree_node_registry_deserialization:
+    raise ValueError(
+        "Duplicate serialization registration for "
+        f"serialized_name `{serialized_name}`. "
+        "Previous registration was for type "
+        f"`{custom_pytree_node_registry_deserialization[serialized_name][0]}`.")
+  if from_children is None:
+    if nodetype not in tree_util._registry:
+      raise ValueError(
+          "If `from_children` is not present, you must call first "
+          f"`jax.tree_util.register_pytree_node` for `{nodetype}`")
+    from_children = tree_util._registry[nodetype].from_iter
+
+  custom_pytree_node_registry_serialization[nodetype] = (
+      serialized_name, serialize_auxdata)
+  custom_pytree_node_registry_deserialization[serialized_name] = (
+      nodetype, deserialize_auxdata, from_children)
+  return nodetype
+
+
+def register_namedtuple_serialization(
+    nodetype: type[T],
+    *,
+    serialized_name: str) -> type[T]:
+  """Registers a namedtuple for serialization and deserialization.
+
+  JAX has native PyTree support for `collections.namedtuple`, and does not
+  require a call to `jax.tree_util.register_pytree_node`. However, if you
+  want to serialize functions that have inputs or outputs of a
+  namedtuple type, you must register that type for serialization.
+
+  Args:
+    nodetype: the type whose PyTree nodes we want to serialize. It is an
+      error to register more than one serialization for the same type.
+      On deserialization, this type must have the same set of keys that
+      were present during serialization.
+    serialized_name: a unique string that will be present in the serialization
+      and will be used to look up the registration during deserialization.
+
+  Returns: the same type passed as `nodetype`, so that this function can
+    be used as a class decorator.
+  """
+  if not _is_namedtuple(nodetype):
+    raise ValueError("Use `jax.export.register_pytree_node_serialization` for "
+                     "types other than `collections.namedtuple`.")
+
+  def serialize_auxdata(aux: PyTreeAuxData) -> bytes:
+    # Ignore the auxdata; store the field names in the serialized auxdata.
+    del aux
+    return json.dumps(nodetype._fields).encode("utf-8")
+
+  def deserialize_auxdata(b: bytes) -> PyTreeAuxData:
+    return json.loads(b.decode("utf-8"))
+
+  def from_children(aux: PyTreeAuxData, children: Sequence[Any]) -> T:
+    # Use our own "from_children" because namedtuples do not have a pytree
+    # registration.
+    ser_keys = cast(Sequence[str], aux)
+    assert len(ser_keys) == len(children)
+    return nodetype(**dict(zip(ser_keys, children)))
+
+  return register_pytree_node_serialization(nodetype,
+                                            serialized_name=serialized_name,
+                                            serialize_auxdata=serialize_auxdata,
+                                            deserialize_auxdata=deserialize_auxdata,
+                                            from_children=from_children)
+
+
+# collections.OrderedDict is registered as a pytree node with auxdata being
+# `tuple(x.keys())`.
+register_pytree_node_serialization(
+    collections.OrderedDict,
+    serialized_name="collections.OrderedDict",
+    serialize_auxdata=lambda keys: json.dumps(keys).encode("utf-8"),
+    deserialize_auxdata=lambda b: json.loads(b.decode("utf-8")))
+
+
 def default_export_platform() -> str:
   """Retrieves the default export platform.
 
diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index 758950adaa8e..b72d0134cf1f 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -28,12 +28,15 @@ enum PyTreeDefKind: byte { tuple = 2, list = 3, dict = 4, + custom = 5, } table PyTreeDef { kind: PyTreeDefKind; children: [PyTreeDef]; - children_names: [string]; // only for "dict" + children_names: [string]; // only for "kind==dict" + custom_name: string; // only for "kind==custom" + custom_auxdata: [byte]; // only for "kind==custom" } enum AbstractValueKind: byte { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index a47b095e4450..0849d9f6acbe 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -45,6 +45,8 @@ # even if the change is backwards compatible. # Version 1, Nov 2023, first version. # Version 2, Dec 16th, 2023, adds the f0 dtype. +# Version 3, October 16th, 2024, adds serialization for namedtuple and custom types +# This version is backwards compatible with Version 2. 
_SERIALIZATION_VERSION = 2 def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray: @@ -152,13 +154,13 @@ def _serialize_array( def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported: serialization_version = exp.SerializationVersion() - if serialization_version != _SERIALIZATION_VERSION: + if serialization_version not in [2, 3]: raise NotImplementedError( f"deserialize unsupported version {serialization_version}" ) fun_name = exp.FunctionName().decode("utf-8") - _, in_tree = tree_util.tree_flatten( + in_tree = tree_util.tree_structure( _deserialize_pytreedef_to_pytree(exp.InTree()) ) scope = shape_poly.SymbolicScope(()) # TODO: serialize the constraints @@ -166,7 +168,7 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported: in_avals = _deserialize_tuple( exp.InAvalsLength, exp.InAvals, deser_aval ) - _, out_tree = tree_util.tree_flatten( + out_tree = tree_util.tree_structure( _deserialize_pytreedef_to_pytree(exp.OutTree()) ) out_avals = _deserialize_tuple( @@ -246,23 +248,35 @@ def _serialize_pytreedef( children_vector_offset = _serialize_array( builder, _serialize_pytreedef, children ) + custom_name = None + custom_auxdata = None + node_type = node_data and node_data[0] if node_data is None: # leaf kind = ser_flatbuf.PyTreeDefKind.leaf - elif node_data[0] is type(None): + elif node_type is type(None): kind = ser_flatbuf.PyTreeDefKind.none - elif node_data[0] is tuple: + elif node_type is tuple: kind = ser_flatbuf.PyTreeDefKind.tuple - elif node_data[0] is list: + elif node_type is list: kind = ser_flatbuf.PyTreeDefKind.list - elif node_data[0] is dict: + elif node_type is dict: kind = ser_flatbuf.PyTreeDefKind.dict assert len(node_data[1]) == len(children) children_names_vector_offset = _serialize_array( builder, lambda b, s: b.CreateString(s), node_data[1] ) + elif node_type in _export.custom_pytree_node_registry_serialization: + kind = ser_flatbuf.PyTreeDefKind.custom + serialized_name, serialize_auxdata = 
_export.custom_pytree_node_registry_serialization[node_type] + custom_name = builder.CreateString(serialized_name) + custom_auxdata = builder.CreateByteVector(serialize_auxdata(node_data[1])) else: - raise NotImplementedError(f"serializing PyTreeDef {node_data}") + raise ValueError( + "Cannot serialize PyTreeDef containing an " + f"unregistered type `{node_type}`. " + "Use `export.register_pytree_node_serialization` or " + "`export.register_namedtuple_serialization`.") ser_flatbuf.PyTreeDefStart(builder) ser_flatbuf.PyTreeDefAddKind(builder, kind) @@ -270,6 +284,10 @@ def _serialize_pytreedef( ser_flatbuf.PyTreeDefAddChildren(builder, children_vector_offset) if children_names_vector_offset: ser_flatbuf.PyTreeDefAddChildrenNames(builder, children_names_vector_offset) + if custom_name is not None: + ser_flatbuf.PyTreeDefAddCustomName(builder, custom_name) + if custom_auxdata is not None: + ser_flatbuf.PyTreeDefAddCustomAuxdata(builder, custom_auxdata) return ser_flatbuf.PyTreeDefEnd(builder) @@ -294,6 +312,17 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): assert p.ChildrenNamesLength() == nr_children keys = [p.ChildrenNames(i).decode("utf-8") for i in range(nr_children)] return dict(zip(keys, children)) + elif kind == ser_flatbuf.PyTreeDefKind.custom: + serialized_name = p.CustomName().decode("utf-8") + if serialized_name not in _export.custom_pytree_node_registry_deserialization: + raise ValueError( + "Cannot deserialize a PyTreeDef containing an " + f"unregistered type `{serialized_name}`. 
" + "Use `export.register_pytree_node_serialization` or " + "`export.register_namedtuple_serialization`.") + nodetype, deserialize_auxdata, from_iter = _export.custom_pytree_node_registry_deserialization[serialized_name] + auxdata = deserialize_auxdata(p.CustomAuxdataAsNumpy().tobytes()) + return from_iter(auxdata, children) else: assert False, kind diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index a872d03a9fdd..18dd2c3cbab1 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -21,20 +21,21 @@ from flatbuffers.compat import import_numpy np = import_numpy() -class PyTreeDefKind: +class PyTreeDefKind(object): leaf = 0 none = 1 tuple = 2 list = 3 dict = 4 + custom = 5 -class AbstractValueKind: +class AbstractValueKind(object): shapedArray = 0 abstractToken = 1 -class DType: +class DType(object): bool = 0 i8 = 1 i16 = 2 @@ -60,18 +61,18 @@ class DType: f0 = 22 -class ShardingKind: +class ShardingKind(object): unspecified = 0 hlo_sharding = 1 -class DisabledSafetyCheckKind: +class DisabledSafetyCheckKind(object): platform = 0 custom_call = 1 shape_assertions = 2 -class PyTreeDef: +class PyTreeDef(object): __slots__ = ['_tab'] @classmethod @@ -140,8 +141,42 @@ def ChildrenNamesIsNone(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) return o == 0 + # PyTreeDef + def CustomName(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # PyTreeDef + def CustomAuxdata(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # PyTreeDef + def CustomAuxdataAsNumpy(self): + o = 
flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o) + return 0 + + # PyTreeDef + def CustomAuxdataLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # PyTreeDef + def CustomAuxdataIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + return o == 0 + def PyTreeDefStart(builder): - builder.StartObject(3) + builder.StartObject(5) def PyTreeDefAddKind(builder, kind): builder.PrependInt8Slot(0, kind, 0) @@ -158,12 +193,21 @@ def PyTreeDefAddChildrenNames(builder, childrenNames): def PyTreeDefStartChildrenNamesVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def PyTreeDefAddCustomName(builder, customName): + builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(customName), 0) + +def PyTreeDefAddCustomAuxdata(builder, customAuxdata): + builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(customAuxdata), 0) + +def PyTreeDefStartCustomAuxdataVector(builder, numElems): + return builder.StartVector(1, numElems, 1) + def PyTreeDefEnd(builder): return builder.EndObject() -class AbstractValue: +class AbstractValue(object): __slots__ = ['_tab'] @classmethod @@ -235,7 +279,7 @@ def AbstractValueEnd(builder): -class Sharding: +class Sharding(object): __slots__ = ['_tab'] @classmethod @@ -304,7 +348,7 @@ def ShardingEnd(builder): -class Effect: +class Effect(object): __slots__ = ['_tab'] @classmethod @@ -340,7 +384,7 @@ def EffectEnd(builder): -class DisabledSafetyCheck: +class DisabledSafetyCheck(object): __slots__ = ['_tab'] @classmethod @@ -386,7 +430,7 @@ def DisabledSafetyCheckEnd(builder): -class Exported: +class Exported(object): __slots__ = ['_tab'] @classmethod diff --git a/jax/experimental/export/__init__.py b/jax/experimental/export/__init__.py index 
b67354bb4248..d49aa296328a 100644 --- a/jax/experimental/export/__init__.py +++ b/jax/experimental/export/__init__.py @@ -35,8 +35,8 @@ "call": (_deprecation_message, _src_export.call), "call_exported": (_deprecation_message, _src_export.call_exported), "default_lowering_platform": (_deprecation_message, _src_export.default_lowering_platform), - "minimum_supported_serialization_version" : (_deprecation_message, _src_export.minimum_supported_calling_convention_version), - "maximum_supported_serialization_version" : (_deprecation_message, _src_export.maximum_supported_calling_convention_version), + "minimum_supported_serialization_version": (_deprecation_message, _src_export.minimum_supported_calling_convention_version), + "maximum_supported_serialization_version": (_deprecation_message, _src_export.maximum_supported_calling_convention_version), "serialize": (_deprecation_message, _src_serialization.serialize), "deserialize": (_deprecation_message, _src_serialization.deserialize), diff --git a/jax/export.py b/jax/export.py index 13186f886f43..d9559909f158 100644 --- a/jax/export.py +++ b/jax/export.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. __all__ = ["DisabledSafetyCheck", "Exported", "export", "deserialize", + "register_pytree_node_serialization", + "register_namedtuple_serialization", "maximum_supported_calling_convention_version", "minimum_supported_calling_convention_version", "default_export_platform", @@ -23,6 +25,8 @@ Exported, export, deserialize, + register_pytree_node_serialization, + register_namedtuple_serialization, maximum_supported_calling_convention_version, minimum_supported_calling_convention_version, default_export_platform) diff --git a/tests/export_test.py b/tests/export_test.py index 0d946d84d22b..fd6bef11ee43 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -13,11 +13,13 @@ # limitations under the License. 
from __future__ import annotations +import collections from collections.abc import Callable, Sequence import contextlib import dataclasses import functools import logging +import json import math import re import unittest @@ -32,6 +34,7 @@ from jax.sharding import NamedSharding from jax.sharding import Mesh from jax.sharding import PartitionSpec as P +from jax import tree_util from jax._src import config from jax._src import core @@ -311,6 +314,123 @@ def f(a_b_pair, a, b): self.assertAllClose(f((a, b), a=a, b=b), exp_f.call((a, b), a=a, b=b)) + def test_pytree_namedtuple(self): + T = collections.namedtuple("SomeType", ("a", "b", "c")) + export.register_namedtuple_serialization( + T, + serialized_name="test_pytree_namedtuple.SomeType", + ) + x = T(a=1, b=2, c=3) + + def f(x): + return (x, x) # return 2 copies, to check that types are shared + + exp = export.export(jax.jit(f))(x) + res = exp.call(x) + self.assertEqual(tree_util.tree_structure(res), + tree_util.tree_structure((x, x))) + self.assertEqual(type(res[0]), type(x)) + self.assertEqual(type(res[1]), type(x)) + ser = exp.serialize() + exp2 = export.deserialize(ser) + self.assertEqual(exp2.in_tree, exp.in_tree) + self.assertEqual(exp2.out_tree, exp.out_tree) + res2 = exp2.call(x) + self.assertEqual(tree_util.tree_structure(res2), + tree_util.tree_structure(res)) + + def test_pytree_namedtuple_error(self): + T = collections.namedtuple("SomeType", ("a", "b")) + x = T(a=1, b=2) + with self.assertRaisesRegex( + ValueError, + "Cannot serialize .* unregistered type .*SomeType"): + export.export(jax.jit(lambda x: x))(x).serialize() + + with self.assertRaisesRegex( + ValueError, + "If `from_children` is not present.* must call.*register_pytree_node" + ): + export.register_pytree_node_serialization( + T, + serialized_name="test_pytree_namedtuple.SomeType_V2", + serialize_auxdata=lambda x: b"", + deserialize_auxdata=lambda b: None + ) + + with self.assertRaisesRegex(ValueError, + "Use 
.*register_pytree_node_serialization"): + export.register_namedtuple_serialization(str, serialized_name="n/a") + + export.register_namedtuple_serialization( + T, + serialized_name="test_pytree_namedtuple_error.SomeType", + ) + + with self.assertRaisesRegex( + ValueError, + "Duplicate serialization registration .*test_pytree_namedtuple_error.SomeType" + ): + export.register_namedtuple_serialization( + T, + serialized_name="test_pytree_namedtuple_error.OtherType", + ) + + with self.assertRaisesRegex( + ValueError, + "Duplicate serialization registration for serialized_name.*test_pytree_namedtuple_error.SomeType" + ): + export.register_namedtuple_serialization( + collections.namedtuple("SomeOtherType", ("a", "b")), + serialized_name="test_pytree_namedtuple_error.SomeType", + ) + + def test_pytree_custom_types(self): + x1 = collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]) + + @tree_util.register_pytree_node_class + class CustomType: + def __init__(self, a: int, b: CustomType | None, string: str): + self.a = a + self.b = b + self.string = string + + def tree_flatten(self): + return ((self.a, self.b), self.string) + + @classmethod + def tree_unflatten(cls, aux_data, children): + string = aux_data + return cls(*children, string) + + export.register_pytree_node_serialization( + CustomType, + serialized_name="test_pytree_custom_types.CustomType", + serialize_auxdata=lambda aux: aux.encode("utf-8"), + deserialize_auxdata=lambda b: b.decode("utf-8") + ) + x2 = CustomType(4, 5, "foo") + + def f(x1, x2): + return (x1, x2, x1, x2) # return 2 copies, to check that types are shared + + exp = export.export(jax.jit(f))(x1, x2) + res = exp.call(x1, x2) + self.assertEqual(tree_util.tree_structure(res), + tree_util.tree_structure(((x1, x2, x1, x2)))) + self.assertEqual(type(res[0]), type(x1)) + self.assertEqual(type(res[1]), type(x2)) + self.assertEqual(type(res[2]), type(x1)) + self.assertEqual(type(res[3]), type(x2)) + ser = exp.serialize() + exp2 = 
export.deserialize(ser) + self.assertEqual(exp2.in_tree, exp.in_tree) + self.assertEqual(exp2.out_tree, exp.out_tree) + res2 = exp2.call(x1, x2) + self.assertEqual(tree_util.tree_structure(res2), + tree_util.tree_structure(res)) + + def test_error_wrong_intree(self): def f(a_b_pair, *, c): return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c