Skip to content

Commit

Permalink
[export] Add support for serialization for some custom PyTree nodes
Browse files Browse the repository at this point in the history
See the added documentation for `jax._src.export.register_pytree_node_serialization`
and `jax._src.export.register_namedtuple_serialization`.

Serialization of PyTree nodes is needed to serialize the `in_tree` and
`out_tree` fields of `Exported` functions (not to serialize actual instances
of the custom types).

When writing this I have looked at how TensorFlow handles namedtuple. It does
so transparently, without requiring the user to register a serialization
handler for the namedtuple type. But this has the disadvantage that on
deserializaton a fresh distinct namedtuple type is created for
each input and output type of the serialized function. This means that
calling the deserialized function will return outputs of different types
than then function that was serialized. This can be confusing.

The Python pickle mode does a bit better: it attempts to look up the
namedtuple type as a module attribute in the deserializing code,
importing automatically the module whose name was saved during serialization.
This is too much magic for my taste, as it can result in strange import errors.

Hence I added an explicit step for the user to say how they want
the namedtuple to be serialized and deserialized.

Since I wanted to also add support for `collections.OrderedDict`, which
users are asking for, I added more general support for PyTree custom nodes.
Note that this registration mechanism works in conjunction with the
PyTree custom node registration mechanism. The burden is on the
user to decide how to serialize and deserialize the custom auxdata that
the PyTree custom registration mechanism uses. Not all custom types
will be serializable, but many commonly used ones, e.g., dataclasses,
can now be inputs and outputs of the serialized functions.
  • Loading branch information
gnecula committed Oct 19, 2024
1 parent bb271aa commit 31238ac
Show file tree
Hide file tree
Showing 7 changed files with 381 additions and 24 deletions.
159 changes: 158 additions & 1 deletion jax/_src/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@

from __future__ import annotations

import collections
from collections.abc import Callable, Sequence
import copy
import dataclasses
import functools
import itertools
import json
import re
from typing import Any, Union, cast
from typing import Any, TypeVar, Union, cast
import warnings

from absl import logging
Expand Down Expand Up @@ -344,6 +346,161 @@ def deserialize(blob: bytearray) -> Exported:
return deserialize(blob)


T = TypeVar("T")
PyTreeAuxData = Any # alias for tree_util._AuxData

# A function to serialize the PyTree node AuxData (returned by the
# `flatten_func` registered by `tree_util.register_pytree_node`).
_SerializeAuxData = Callable[[PyTreeAuxData], bytes]
# A function to deserialize the AuxData, and produce something ready for
# `_BuildFromChildren` below.
_DeserializeAuxData = Callable[[bytes], PyTreeAuxData]
# A function to materialize a T given a deserialized AuxData and children.
# This is similar in scope with the `unflatten_func`
_BuildFromChildren = Callable[[PyTreeAuxData, Sequence[Any]], Any]


custom_pytree_node_registry_serialization: dict[
type,
tuple[str, _SerializeAuxData]] = {}


custom_pytree_node_registry_deserialization: dict[
str,
tuple[type, _DeserializeAuxData, _BuildFromChildren]] = {}


def _is_namedtuple(nodetype: type) -> bool:
return (issubclass(nodetype, tuple) and
hasattr(nodetype, "_fields") and
isinstance(nodetype._fields, Sequence) and
all(isinstance(f, str) for f in nodetype._fields))

def register_pytree_node_serialization(
nodetype: type[T],
*,
serialized_name: str,
serialize_auxdata: _SerializeAuxData,
deserialize_auxdata: _DeserializeAuxData,
from_children: _BuildFromChildren | None = None
) -> type[T]:
"""Registers a custom PyTree node for serialization and deserialization.
You must use this function before you can serialize and deserialize PyTree
nodes for the types not supported natively. We serialize PyTree nodes for
the `in_tree` and `out_tree` fields of `Exported`, which are part of the
exported function's calling convention.
This function must be called after calling
`jax.tree_util.register_pytree_node` (except for `collections.namedtuple`,
which do not require a call to `register_pytree_node`).
Args:
nodetype: the type whose PyTree nodes we want to serialize. Subsequent
registrations for the same type will overwrite old registrations.
serialized_name: a string that will be present in the serialization and
will be used to look up the registration during deserialization.
serialize_auxdata: serialize the PyTree auxdata (returned by the
`flatten_func` argument to `jax.tree_util.register_pytree_node`.).
deserialize_auxdata: deserialize the auxdata that was serialized by the
`serialize_auxdata`.
from_children: if present, this is a function that takes that result of
`deserialize_auxdata` along with some children and creates an instance
of `nodetype`. This is similar to the `unflatten_func` passed to
`jax.tree_util.register_pytree_node`. If not present, we look up
and use the `unflatten_func`. This is needed for `collections.namedtuple`,
which does not have a `register_pytree_node`, but it can be useful to
override that function. Note that the result of `from_children` is
only used with `jax.tree_util.tree_structure` to construct a proper
PyTree node, it is not used to construct the outputs of the serialized
function.
Returns: the same type passed as `nodetype`, so that this function can
be used as a class decorator.
"""
if nodetype in custom_pytree_node_registry_serialization:
raise ValueError(
f"Duplicate serialization registration for type `{nodetype}`. "
"Previous registration was with serialized_name "
f"`{custom_pytree_node_registry_serialization[nodetype][0]}`.")
if serialized_name in custom_pytree_node_registry_deserialization:
raise ValueError(
"Duplicate serialization registration for "
f"serialized_name `{serialized_name}`. "
"Previous registration was for type "
f"`{custom_pytree_node_registry_deserialization[serialized_name][0]}`.")
if from_children is None:
if nodetype not in tree_util._registry:
raise ValueError(
f"If `from_children` is not present, you must call first"
f"`jax.tree_util.register_pytree_node` for `{nodetype}`")
from_children = tree_util._registry[nodetype].from_iter

custom_pytree_node_registry_serialization[nodetype] = (
serialized_name, serialize_auxdata)
custom_pytree_node_registry_deserialization[serialized_name] = (
nodetype, deserialize_auxdata, from_children)
return nodetype


def register_namedtuple_serialization(
nodetype: type[T],
*,
serialized_name: str) -> type[T]:
"""Registers a namedtuple for serialization and deserialization.
JAX has native PyTree support for `collections.namedtuple`, and does not
require a call to `jax.tree_util.register_pytree_node`. However, if you
want to serialize functions that have inputs of outputs of a
namedtuple type, you must register that type for serialization.
Args:
nodetype: the type whose PyTree nodes we want to serialize. Subsequent
registrations for the same type will overwrite old registrations.
On deserialization, this type must have the same set of keys that
were present during serialization.
serialized_name: a string that will be present in the serialization and
will be used to look up the registration during deserialization.
Returns: the same type passed as `nodetype`, so that this function can
be used as a class decorator.
"""
if not _is_namedtuple(nodetype):
raise ValueError("Use `jax.export.register_pytree_node_serialization` for "
"types other than `collections.namedtuple`.")

def serialize_auxdata(aux: PyTreeAuxData) -> bytes:
# Store the serialized keys in the serialized auxdata
del aux
return json.dumps(nodetype._fields).encode("utf-8")

def deserialize_auxdata(b: bytes) -> PyTreeAuxData:
return json.loads(b.decode("utf-8"))

def from_children(aux: PyTreeAuxData, children: Sequence[Any]) -> T:
# Use our own "from_children" because namedtuples do not have a pytree
# registration.
ser_keys = cast(Sequence[str], aux)
assert len(ser_keys) == len(children)
return nodetype(** dict(zip(ser_keys, children)))

return register_pytree_node_serialization(nodetype,
serialized_name=serialized_name,
serialize_auxdata=serialize_auxdata,
deserialize_auxdata=deserialize_auxdata,
from_children=from_children)


# collections.OrderedDict is registered as a pytree node with auxdata being
# `tuple(x.keys())`.
register_pytree_node_serialization(
collections.OrderedDict,
serialized_name="collections.OrderedDict",
serialize_auxdata=lambda keys: json.dumps(keys).encode("utf-8"),
deserialize_auxdata=lambda b: json.loads(b.decode("utf-8")))


def default_export_platform() -> str:
"""Retrieves the default export platform.
Expand Down
5 changes: 4 additions & 1 deletion jax/_src/export/serialization.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@ enum PyTreeDefKind: byte {
tuple = 2,
list = 3,
dict = 4,
custom = 5,
}

table PyTreeDef {
kind: PyTreeDefKind;
children: [PyTreeDef];
children_names: [string]; // only for "dict"
children_names: [string]; // only for "kind==dict"
custom_name: string; // only for "kind==custom"
custom_auxdata: [byte]; // only for "kind==custom"
}

enum AbstractValueKind: byte {
Expand Down
45 changes: 37 additions & 8 deletions jax/_src/export/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
# even if the change is backwards compatible.
# Version 1, Nov 2023, first version.
# Version 2, Dec 16th, 2023, adds the f0 dtype.
# Version 3, October 16th, 2024, adds serialization for namedtuple and custom types
# This version is backwards compatible with Version 2.
_SERIALIZATION_VERSION = 2

def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
Expand Down Expand Up @@ -152,21 +154,21 @@ def _serialize_array(

def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
serialization_version = exp.SerializationVersion()
if serialization_version != _SERIALIZATION_VERSION:
if serialization_version not in [2, 3]:
raise NotImplementedError(
f"deserialize unsupported version {serialization_version}"
)

fun_name = exp.FunctionName().decode("utf-8")
_, in_tree = tree_util.tree_flatten(
in_tree = tree_util.tree_structure(
_deserialize_pytreedef_to_pytree(exp.InTree())
)
scope = shape_poly.SymbolicScope(()) # TODO: serialize the constraints
deser_aval = partial(_deserialize_aval, scope=scope)
in_avals = _deserialize_tuple(
exp.InAvalsLength, exp.InAvals, deser_aval
)
_, out_tree = tree_util.tree_flatten(
out_tree = tree_util.tree_structure(
_deserialize_pytreedef_to_pytree(exp.OutTree())
)
out_avals = _deserialize_tuple(
Expand Down Expand Up @@ -246,30 +248,46 @@ def _serialize_pytreedef(
children_vector_offset = _serialize_array(
builder, _serialize_pytreedef, children
)
custom_name = None
custom_auxdata = None
node_type = node_data and node_data[0]

if node_data is None: # leaf
kind = ser_flatbuf.PyTreeDefKind.leaf
elif node_data[0] is type(None):
elif node_type is type(None):
kind = ser_flatbuf.PyTreeDefKind.none
elif node_data[0] is tuple:
elif node_type is tuple:
kind = ser_flatbuf.PyTreeDefKind.tuple
elif node_data[0] is list:
elif node_type is list:
kind = ser_flatbuf.PyTreeDefKind.list
elif node_data[0] is dict:
elif node_type is dict:
kind = ser_flatbuf.PyTreeDefKind.dict
assert len(node_data[1]) == len(children)
children_names_vector_offset = _serialize_array(
builder, lambda b, s: b.CreateString(s), node_data[1]
)
elif node_type in _export.custom_pytree_node_registry_serialization:
kind = ser_flatbuf.PyTreeDefKind.custom
serialized_name, serialize_auxdata = _export.custom_pytree_node_registry_serialization[node_type]
custom_name = builder.CreateString(serialized_name)
custom_auxdata = builder.CreateByteVector(serialize_auxdata(node_data[1]))
else:
raise NotImplementedError(f"serializing PyTreeDef {node_data}")
raise ValueError(
"Cannot serialize PyTreeDef containing an "
f"unregistered type `{node_type}`. "
"Use `export.register_pytree_node_serialization` or "
"`export.register_namedtuple_serialization`.")

ser_flatbuf.PyTreeDefStart(builder)
ser_flatbuf.PyTreeDefAddKind(builder, kind)
if children_vector_offset:
ser_flatbuf.PyTreeDefAddChildren(builder, children_vector_offset)
if children_names_vector_offset:
ser_flatbuf.PyTreeDefAddChildrenNames(builder, children_names_vector_offset)
if custom_name is not None:
ser_flatbuf.PyTreeDefAddCustomName(builder, custom_name)
if custom_auxdata is not None:
ser_flatbuf.PyTreeDefAddCustomAuxdata(builder, custom_auxdata)
return ser_flatbuf.PyTreeDefEnd(builder)


Expand All @@ -294,6 +312,17 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
assert p.ChildrenNamesLength() == nr_children
keys = [p.ChildrenNames(i).decode("utf-8") for i in range(nr_children)]
return dict(zip(keys, children))
elif kind == ser_flatbuf.PyTreeDefKind.custom:
serialized_name = p.CustomName().decode("utf-8")
if serialized_name not in _export.custom_pytree_node_registry_deserialization:
raise ValueError(
"Cannot deserialize a PyTreeDef containing an "
f"unregistered type `{serialized_name}`. "
"Use `export.register_pytree_node_serialization` or "
"`export.register_namedtuple_serialization`.")
nodetype, deserialize_auxdata, from_iter = _export.custom_pytree_node_registry_deserialization[serialized_name]
auxdata = deserialize_auxdata(p.CustomAuxdataAsNumpy().tobytes())
return from_iter(auxdata, children)
else:
assert False, kind

Expand Down
Loading

0 comments on commit 31238ac

Please sign in to comment.