From 31238ac21e239fa44ce44d94b480b4173017aaa9 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 16 Oct 2024 18:08:25 +0100 Subject: [PATCH] [export] Add support for serialization for some custom PyTree nodes See the added documentation for `jax._src.export.register_pytree_node_serialization` and `jax._src.export.register_namedtuple_serialization`. Serialization of PyTree nodes is needed to serialize the `in_tree` and `out_tree` fields of `Exported` functions (not to serialize actual instances of the custom types). When writing this I have looked at how TensorFlow handles namedtuple. It does so transparently, without requiring the user to register a serialization handler for the namedtuple type. But this has the disadvantage that on deserializaton a fresh distinct namedtuple type is created for each input and output type of the serialized function. This means that calling the deserialized function will return outputs of different types than then function that was serialized. This can be confusing. The Python pickle mode does a bit better: it attempts to look up the namedtuple type as a module attribute in the deserializing code, importing automatically the module whose name was saved during serialization. This is too much magic for my taste, as it can result in strange import errors. Hence I added an explicit step for the user to say how they want the namedtuple to be serialized and deserialized. Since I wanted to also add support for `collections.OrderedDict`, which users are asking for, I added more general support for PyTree custom nodes. Note that this registration mechanism works in conjunction with the PyTree custom node registration mechanism. The burden is on the user to decide how to serialize and deserialize the custom auxdata that the PyTree custom registration mechanism uses. Not all custom types will be serializable, but many commonly used ones, e.g., dataclasses, can now be inputs and outputs of the serialized functions. --- jax/_src/export/_export.py | 159 ++++++++++++++++++++- jax/_src/export/serialization.fbs | 5 +- jax/_src/export/serialization.py | 45 ++++-- jax/_src/export/serialization_generated.py | 68 +++++++-- jax/experimental/export/__init__.py | 4 +- jax/export.py | 4 + tests/export_test.py | 120 ++++++++++++++++ 7 files changed, 381 insertions(+), 24 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 774953ed9b97..857466a919b7 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -17,13 +17,15 @@ from __future__ import annotations +import collections from collections.abc import Callable, Sequence import copy import dataclasses import functools import itertools +import json import re -from typing import Any, Union, cast +from typing import Any, TypeVar, Union, cast import warnings from absl import logging @@ -344,6 +346,161 @@ def deserialize(blob: bytearray) -> Exported: return deserialize(blob) +T = TypeVar("T") +PyTreeAuxData = Any # alias for tree_util._AuxData + +# A function to serialize the PyTree node AuxData (returned by the +# `flatten_func` registered by `tree_util.register_pytree_node`). +_SerializeAuxData = Callable[[PyTreeAuxData], bytes] +# A function to deserialize the AuxData, and produce something ready for +# `_BuildFromChildren` below. +_DeserializeAuxData = Callable[[bytes], PyTreeAuxData] +# A function to materialize a T given a deserialized AuxData and children. +# This is similar in scope with the `unflatten_func` +_BuildFromChildren = Callable[[PyTreeAuxData, Sequence[Any]], Any] + + +custom_pytree_node_registry_serialization: dict[ + type, + tuple[str, _SerializeAuxData]] = {} + + +custom_pytree_node_registry_deserialization: dict[ + str, + tuple[type, _DeserializeAuxData, _BuildFromChildren]] = {} + + +def _is_namedtuple(nodetype: type) -> bool: + return (issubclass(nodetype, tuple) and + hasattr(nodetype, "_fields") and + isinstance(nodetype._fields, Sequence) and + all(isinstance(f, str) for f in nodetype._fields)) + +def register_pytree_node_serialization( + nodetype: type[T], + *, + serialized_name: str, + serialize_auxdata: _SerializeAuxData, + deserialize_auxdata: _DeserializeAuxData, + from_children: _BuildFromChildren | None = None +) -> type[T]: + """Registers a custom PyTree node for serialization and deserialization. + + You must use this function before you can serialize and deserialize PyTree + nodes for the types not supported natively. We serialize PyTree nodes for + the `in_tree` and `out_tree` fields of `Exported`, which are part of the + exported function's calling convention. + + This function must be called after calling + `jax.tree_util.register_pytree_node` (except for `collections.namedtuple`, + which do not require a call to `register_pytree_node`). + + Args: + nodetype: the type whose PyTree nodes we want to serialize. Subsequent + registrations for the same type will overwrite old registrations. + serialized_name: a string that will be present in the serialization and + will be used to look up the registration during deserialization. + serialize_auxdata: serialize the PyTree auxdata (returned by the + `flatten_func` argument to `jax.tree_util.register_pytree_node`.). + deserialize_auxdata: deserialize the auxdata that was serialized by the + `serialize_auxdata`. + from_children: if present, this is a function that takes that result of + `deserialize_auxdata` along with some children and creates an instance + of `nodetype`. This is similar to the `unflatten_func` passed to + `jax.tree_util.register_pytree_node`. If not present, we look up + and use the `unflatten_func`. This is needed for `collections.namedtuple`, + which does not have a `register_pytree_node`, but it can be useful to + override that function. Note that the result of `from_children` is + only used with `jax.tree_util.tree_structure` to construct a proper + PyTree node, it is not used to construct the outputs of the serialized + function. + + Returns: the same type passed as `nodetype`, so that this function can + be used as a class decorator. + + """ + if nodetype in custom_pytree_node_registry_serialization: + raise ValueError( + f"Duplicate serialization registration for type `{nodetype}`. " + "Previous registration was with serialized_name " + f"`{custom_pytree_node_registry_serialization[nodetype][0]}`.") + if serialized_name in custom_pytree_node_registry_deserialization: + raise ValueError( + "Duplicate serialization registration for " + f"serialized_name `{serialized_name}`. " + "Previous registration was for type " + f"`{custom_pytree_node_registry_deserialization[serialized_name][0]}`.") + if from_children is None: + if nodetype not in tree_util._registry: + raise ValueError( + f"If `from_children` is not present, you must call first" + f"`jax.tree_util.register_pytree_node` for `{nodetype}`") + from_children = tree_util._registry[nodetype].from_iter + + custom_pytree_node_registry_serialization[nodetype] = ( + serialized_name, serialize_auxdata) + custom_pytree_node_registry_deserialization[serialized_name] = ( + nodetype, deserialize_auxdata, from_children) + return nodetype + + +def register_namedtuple_serialization( + nodetype: type[T], + *, + serialized_name: str) -> type[T]: + """Registers a namedtuple for serialization and deserialization. + + JAX has native PyTree support for `collections.namedtuple`, and does not + require a call to `jax.tree_util.register_pytree_node`. However, if you + want to serialize functions that have inputs of outputs of a + namedtuple type, you must register that type for serialization. + + Args: + nodetype: the type whose PyTree nodes we want to serialize. Subsequent + registrations for the same type will overwrite old registrations. + On deserialization, this type must have the same set of keys that + were present during serialization. + serialized_name: a string that will be present in the serialization and + will be used to look up the registration during deserialization. + + Returns: the same type passed as `nodetype`, so that this function can + be used as a class decorator. +""" + if not _is_namedtuple(nodetype): + raise ValueError("Use `jax.export.register_pytree_node_serialization` for " + "types other than `collections.namedtuple`.") + + def serialize_auxdata(aux: PyTreeAuxData) -> bytes: + # Store the serialized keys in the serialized auxdata + del aux + return json.dumps(nodetype._fields).encode("utf-8") + + def deserialize_auxdata(b: bytes) -> PyTreeAuxData: + return json.loads(b.decode("utf-8")) + + def from_children(aux: PyTreeAuxData, children: Sequence[Any]) -> T: + # Use our own "from_children" because namedtuples do not have a pytree + # registration. + ser_keys = cast(Sequence[str], aux) + assert len(ser_keys) == len(children) + return nodetype(** dict(zip(ser_keys, children))) + + return register_pytree_node_serialization(nodetype, + serialized_name=serialized_name, + serialize_auxdata=serialize_auxdata, + deserialize_auxdata=deserialize_auxdata, + from_children=from_children) + + +# collections.OrderedDict is registered as a pytree node with auxdata being +# `tuple(x.keys())`. +register_pytree_node_serialization( + collections.OrderedDict, + serialized_name="collections.OrderedDict", + serialize_auxdata=lambda keys: json.dumps(keys).encode("utf-8"), + deserialize_auxdata=lambda b: json.loads(b.decode("utf-8"))) + + def default_export_platform() -> str: """Retrieves the default export platform. diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index 758950adaa8e..b72d0134cf1f 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -28,12 +28,15 @@ enum PyTreeDefKind: byte { tuple = 2, list = 3, dict = 4, + custom = 5, } table PyTreeDef { kind: PyTreeDefKind; children: [PyTreeDef]; - children_names: [string]; // only for "dict" + children_names: [string]; // only for "kind==dict" + custom_name: string; // only for "kind==custom" + custom_auxdata: [byte]; // only for "kind==custom" } enum AbstractValueKind: byte { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index a47b095e4450..0849d9f6acbe 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -45,6 +45,8 @@ # even if the change is backwards compatible. # Version 1, Nov 2023, first version. # Version 2, Dec 16th, 2023, adds the f0 dtype. +# Version 3, October 16th, 2024, adds serialization for namedtuple and custom types +# This version is backwards compatible with Version 2. _SERIALIZATION_VERSION = 2 def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray: @@ -152,13 +154,13 @@ def _serialize_array( def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported: serialization_version = exp.SerializationVersion() - if serialization_version != _SERIALIZATION_VERSION: + if serialization_version not in [2, 3]: raise NotImplementedError( f"deserialize unsupported version {serialization_version}" ) fun_name = exp.FunctionName().decode("utf-8") - _, in_tree = tree_util.tree_flatten( + in_tree = tree_util.tree_structure( _deserialize_pytreedef_to_pytree(exp.InTree()) ) scope = shape_poly.SymbolicScope(()) # TODO: serialize the constraints @@ -166,7 +168,7 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported: in_avals = _deserialize_tuple( exp.InAvalsLength, exp.InAvals, deser_aval ) - _, out_tree = tree_util.tree_flatten( + out_tree = tree_util.tree_structure( _deserialize_pytreedef_to_pytree(exp.OutTree()) ) out_avals = _deserialize_tuple( @@ -246,23 +248,35 @@ def _serialize_pytreedef( children_vector_offset = _serialize_array( builder, _serialize_pytreedef, children ) + custom_name = None + custom_auxdata = None + node_type = node_data and node_data[0] if node_data is None: # leaf kind = ser_flatbuf.PyTreeDefKind.leaf - elif node_data[0] is type(None): + elif node_type is type(None): kind = ser_flatbuf.PyTreeDefKind.none - elif node_data[0] is tuple: + elif node_type is tuple: kind = ser_flatbuf.PyTreeDefKind.tuple - elif node_data[0] is list: + elif node_type is list: kind = ser_flatbuf.PyTreeDefKind.list - elif node_data[0] is dict: + elif node_type is dict: kind = ser_flatbuf.PyTreeDefKind.dict assert len(node_data[1]) == len(children) children_names_vector_offset = _serialize_array( builder, lambda b, s: b.CreateString(s), node_data[1] ) + elif node_type in _export.custom_pytree_node_registry_serialization: + kind = ser_flatbuf.PyTreeDefKind.custom + serialized_name, serialize_auxdata = _export.custom_pytree_node_registry_serialization[node_type] + custom_name = builder.CreateString(serialized_name) + custom_auxdata = builder.CreateByteVector(serialize_auxdata(node_data[1])) else: - raise NotImplementedError(f"serializing PyTreeDef {node_data}") + raise ValueError( + "Cannot serialize PyTreeDef containing an " + f"unregistered type `{node_type}`. " + "Use `export.register_pytree_node_serialization` or " + "`export.register_namedtuple_serialization`.") ser_flatbuf.PyTreeDefStart(builder) ser_flatbuf.PyTreeDefAddKind(builder, kind) @@ -270,6 +284,10 @@ def _serialize_pytreedef( ser_flatbuf.PyTreeDefAddChildren(builder, children_vector_offset) if children_names_vector_offset: ser_flatbuf.PyTreeDefAddChildrenNames(builder, children_names_vector_offset) + if custom_name is not None: + ser_flatbuf.PyTreeDefAddCustomName(builder, custom_name) + if custom_auxdata is not None: + ser_flatbuf.PyTreeDefAddCustomAuxdata(builder, custom_auxdata) return ser_flatbuf.PyTreeDefEnd(builder) @@ -294,6 +312,17 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): assert p.ChildrenNamesLength() == nr_children keys = [p.ChildrenNames(i).decode("utf-8") for i in range(nr_children)] return dict(zip(keys, children)) + elif kind == ser_flatbuf.PyTreeDefKind.custom: + serialized_name = p.CustomName().decode("utf-8") + if serialized_name not in _export.custom_pytree_node_registry_deserialization: + raise ValueError( + "Cannot deserialize a PyTreeDef containing an " + f"unregistered type `{serialized_name}`. " + "Use `export.register_pytree_node_serialization` or " + "`export.register_namedtuple_serialization`.") + nodetype, deserialize_auxdata, from_iter = _export.custom_pytree_node_registry_deserialization[serialized_name] + auxdata = deserialize_auxdata(p.CustomAuxdataAsNumpy().tobytes()) + return from_iter(auxdata, children) else: assert False, kind diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index a872d03a9fdd..18dd2c3cbab1 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -21,20 +21,21 @@ from flatbuffers.compat import import_numpy np = import_numpy() -class PyTreeDefKind: +class PyTreeDefKind(object): leaf = 0 none = 1 tuple = 2 list = 3 dict = 4 + custom = 5 -class AbstractValueKind: +class AbstractValueKind(object): shapedArray = 0 abstractToken = 1 -class DType: +class DType(object): bool = 0 i8 = 1 i16 = 2 @@ -60,18 +61,18 @@ class DType: f0 = 22 -class ShardingKind: +class ShardingKind(object): unspecified = 0 hlo_sharding = 1 -class DisabledSafetyCheckKind: +class DisabledSafetyCheckKind(object): platform = 0 custom_call = 1 shape_assertions = 2 -class PyTreeDef: +class PyTreeDef(object): __slots__ = ['_tab'] @classmethod @@ -140,8 +141,42 @@ def ChildrenNamesIsNone(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) return o == 0 + # PyTreeDef + def CustomName(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # PyTreeDef + def CustomAuxdata(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # PyTreeDef + def CustomAuxdataAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o) + return 0 + + # PyTreeDef + def CustomAuxdataLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # PyTreeDef + def CustomAuxdataIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + return o == 0 + def PyTreeDefStart(builder): - builder.StartObject(3) + builder.StartObject(5) def PyTreeDefAddKind(builder, kind): builder.PrependInt8Slot(0, kind, 0) @@ -158,12 +193,21 @@ def PyTreeDefAddChildrenNames(builder, childrenNames): def PyTreeDefStartChildrenNamesVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def PyTreeDefAddCustomName(builder, customName): + builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(customName), 0) + +def PyTreeDefAddCustomAuxdata(builder, customAuxdata): + builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(customAuxdata), 0) + +def PyTreeDefStartCustomAuxdataVector(builder, numElems): + return builder.StartVector(1, numElems, 1) + def PyTreeDefEnd(builder): return builder.EndObject() -class AbstractValue: +class AbstractValue(object): __slots__ = ['_tab'] @classmethod @@ -235,7 +279,7 @@ def AbstractValueEnd(builder): -class Sharding: +class Sharding(object): __slots__ = ['_tab'] @classmethod @@ -304,7 +348,7 @@ def ShardingEnd(builder): -class Effect: +class Effect(object): __slots__ = ['_tab'] @classmethod @@ -340,7 +384,7 @@ def EffectEnd(builder): -class DisabledSafetyCheck: +class DisabledSafetyCheck(object): __slots__ = ['_tab'] @classmethod @@ -386,7 +430,7 @@ def DisabledSafetyCheckEnd(builder): -class Exported: +class Exported(object): __slots__ = ['_tab'] @classmethod diff --git a/jax/experimental/export/__init__.py b/jax/experimental/export/__init__.py index b67354bb4248..d49aa296328a 100644 --- a/jax/experimental/export/__init__.py +++ b/jax/experimental/export/__init__.py @@ -35,8 +35,8 @@ "call": (_deprecation_message, _src_export.call), "call_exported": (_deprecation_message, _src_export.call_exported), "default_lowering_platform": (_deprecation_message, _src_export.default_lowering_platform), - "minimum_supported_serialization_version" : (_deprecation_message, _src_export.minimum_supported_calling_convention_version), - "maximum_supported_serialization_version" : (_deprecation_message, _src_export.maximum_supported_calling_convention_version), + "minimum_supported_serialization_version": (_deprecation_message, _src_export.minimum_supported_calling_convention_version), + "maximum_supported_serialization_version": (_deprecation_message, _src_export.maximum_supported_calling_convention_version), "serialize": (_deprecation_message, _src_serialization.serialize), "deserialize": (_deprecation_message, _src_serialization.deserialize), diff --git a/jax/export.py b/jax/export.py index 13186f886f43..d9559909f158 100644 --- a/jax/export.py +++ b/jax/export.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. __all__ = ["DisabledSafetyCheck", "Exported", "export", "deserialize", + "register_pytree_node_serialization", + "register_namedtuple_serialization", "maximum_supported_calling_convention_version", "minimum_supported_calling_convention_version", "default_export_platform", @@ -23,6 +25,8 @@ Exported, export, deserialize, + register_pytree_node_serialization, + register_namedtuple_serialization, maximum_supported_calling_convention_version, minimum_supported_calling_convention_version, default_export_platform) diff --git a/tests/export_test.py b/tests/export_test.py index 0d946d84d22b..fd6bef11ee43 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -13,11 +13,13 @@ # limitations under the License. from __future__ import annotations +import collections from collections.abc import Callable, Sequence import contextlib import dataclasses import functools import logging +import json import math import re import unittest @@ -32,6 +34,7 @@ from jax.sharding import NamedSharding from jax.sharding import Mesh from jax.sharding import PartitionSpec as P +from jax import tree_util from jax._src import config from jax._src import core @@ -311,6 +314,123 @@ def f(a_b_pair, a, b): self.assertAllClose(f((a, b), a=a, b=b), exp_f.call((a, b), a=a, b=b)) + def test_pytree_namedtuple(self): + T = collections.namedtuple("SomeType", ("a", "b", "c")) + export.register_namedtuple_serialization( + T, + serialized_name="test_pytree_namedtuple.SomeType", + ) + x = T(a=1, b=2, c=3) + + def f(x): + return (x, x) # return 2 copies, to check that types are shared + + exp = export.export(jax.jit(f))(x) + res = exp.call(x) + self.assertEqual(tree_util.tree_structure(res), + tree_util.tree_structure((x, x))) + self.assertEqual(type(res[0]), type(x)) + self.assertEqual(type(res[1]), type(x)) + ser = exp.serialize() + exp2 = export.deserialize(ser) + self.assertEqual(exp2.in_tree, exp.in_tree) + self.assertEqual(exp2.out_tree, exp.out_tree) + res2 = exp2.call(x) + self.assertEqual(tree_util.tree_structure(res2), + tree_util.tree_structure(res)) + + def test_pytree_namedtuple_error(self): + T = collections.namedtuple("SomeType", ("a", "b")) + x = T(a=1, b=2) + with self.assertRaisesRegex( + ValueError, + "Cannot serialize .* unregistered type .*SomeType"): + export.export(jax.jit(lambda x: x))(x).serialize() + + with self.assertRaisesRegex( + ValueError, + "If `from_children` is not present.* must call.*register_pytree_node" + ): + export.register_pytree_node_serialization( + T, + serialized_name="test_pytree_namedtuple.SomeType_V2", + serialize_auxdata=lambda x: b"", + deserialize_auxdata=lambda b: None + ) + + with self.assertRaisesRegex(ValueError, + "Use .*register_pytree_node_serialization"): + export.register_namedtuple_serialization(str, serialized_name="n/a") + + export.register_namedtuple_serialization( + T, + serialized_name="test_pytree_namedtuple_error.SomeType", + ) + + with self.assertRaisesRegex( + ValueError, + "Duplicate serialization registration .*test_pytree_namedtuple_error.SomeType" + ): + export.register_namedtuple_serialization( + T, + serialized_name="test_pytree_namedtuple_error.OtherType", + ) + + with self.assertRaisesRegex( + ValueError, + "Duplicate serialization registration for serialized_name.*test_pytree_namedtuple_error.SomeType" + ): + export.register_namedtuple_serialization( + collections.namedtuple("SomeOtherType", ("a", "b")), + serialized_name="test_pytree_namedtuple_error.SomeType", + ) + + def test_pytree_custom_types(self): + x1 = collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]) + + @tree_util.register_pytree_node_class + class CustomType: + def __init__(self, a: int, b: CustomType | None, string: str): + self.a = a + self.b = b + self.string = string + + def tree_flatten(self): + return ((self.a, self.b), self.string) + + @classmethod + def tree_unflatten(cls, aux_data, children): + string = aux_data + return cls(*children, string) + + export.register_pytree_node_serialization( + CustomType, + serialized_name="test_pytree_custom_types.CustomType", + serialize_auxdata=lambda aux: aux.encode("utf-8"), + deserialize_auxdata=lambda b: b.decode("utf-8") + ) + x2 = CustomType(4, 5, "foo") + + def f(x1, x2): + return (x1, x2, x1, x2) # return 2 copies, to check that types are shared + + exp = export.export(jax.jit(f))(x1, x2) + res = exp.call(x1, x2) + self.assertEqual(tree_util.tree_structure(res), + tree_util.tree_structure(((x1, x2, x1, x2)))) + self.assertEqual(type(res[0]), type(x1)) + self.assertEqual(type(res[1]), type(x2)) + self.assertEqual(type(res[2]), type(x1)) + self.assertEqual(type(res[3]), type(x2)) + ser = exp.serialize() + exp2 = export.deserialize(ser) + self.assertEqual(exp2.in_tree, exp.in_tree) + self.assertEqual(exp2.out_tree, exp.out_tree) + res2 = exp2.call(x1, x2) + self.assertEqual(tree_util.tree_structure(res2), + tree_util.tree_structure(res)) + + def test_error_wrong_intree(self): def f(a_b_pair, *, c): return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c