From 68130742faf12a19255f85269a54bedb2e902e23 Mon Sep 17 00:00:00 2001
From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
Date: Thu, 13 Feb 2025 22:53:53 +0100
Subject: [PATCH] Added support for pickling jaxtyping annotations.

---
 jaxtyping/_array_types.py  | 27 ++++++++++++++++++---------
 pyproject.toml             |  2 +-
 test/test_serialisation.py | 36 +++++++++++++++++++++++-------------
 3 files changed, 42 insertions(+), 23 deletions(-)

diff --git a/jaxtyping/_array_types.py b/jaxtyping/_array_types.py
index b2f727a..67747a9 100644
--- a/jaxtyping/_array_types.py
+++ b/jaxtyping/_array_types.py
@@ -17,6 +17,7 @@
 # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
+import copyreg
 import enum
 import functools as ft
 import importlib.util
@@ -317,6 +318,10 @@ def _check_shape(
     assert False
 
 
+def _pickle_array_annotation(x: type["AbstractArray"]):
+    return x.dtype.__getitem__, ((x.array_type, x.dim_str),)
+
+
 @ft.lru_cache(maxsize=None)
 def _make_metaclass(base_metaclass):
     class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
@@ -338,6 +343,8 @@ def __eq__(cls, other):
         def __hash__(cls):
             return id(cls)
 
+    copyreg.pickle(MetaAbstractArray, _pickle_array_annotation)
+
     return MetaAbstractArray
 
 
@@ -358,11 +365,15 @@ class for `Float32[Array, "foo"]`.
     you can check `issubclass(annotation, jaxtyping.AbstractArray)`.
     """
 
+    # This is what it was defined with.
+    dtype: type["AbstractDtype"]
     array_type: Any
+    dim_str: str
+
+    # This is the processed information we need for later typechecking.
     dtypes: list[str]
     dims: tuple[_AbstractDimOrVariadicDim, ...]
     index_variadic: Optional[int]
-    dim_str: str
 
 
 _not_made = object()
@@ -595,8 +606,8 @@ def _make_array_cached(array_type, dim_str, dtypes, name):
     return (array_type, name, dtypes, dims, index_variadic, dim_str)
 
 
-def _make_array(*args, **kwargs):
-    out = _make_array_cached(*args, **kwargs)
+def _make_array(x, dim_str, dtype):
+    out = _make_array_cached(x, dim_str, dtype.dtypes, dtype.__name__)
 
     if type(out) is tuple:
         array_type, name, dtypes, dims, index_variadic, dim_str = out
@@ -610,11 +621,12 @@ def _make_array(*args, **kwargs):
             name,
             (AbstractArray,) if array_type is Any else (array_type, AbstractArray),
             dict(
+                dtype=dtype,
                 array_type=array_type,
+                dim_str=dim_str,
                 dtypes=dtypes,
                 dims=dims,
                 index_variadic=index_variadic,
-                dim_str=dim_str,
             ),
         )
         if getattr(typing, "GENERATING_DOCUMENTATION", False):
@@ -654,10 +666,7 @@ def __getitem__(cls, item: tuple[Any, str]):
                 array_type = bound
         del item
         if get_origin(array_type) in _union_types:
-            out = [
-                _make_array(x, dim_str, cls.dtypes, cls.__name__)
-                for x in get_args(array_type)
-            ]
+            out = [_make_array(x, dim_str, cls) for x in get_args(array_type)]
             out = tuple(x for x in out if x is not _not_made)
             if len(out) == 0:
                 raise ValueError("Invalid jaxtyping type annotation.")
@@ -666,7 +675,7 @@ def __getitem__(cls, item: tuple[Any, str]):
             else:
                 out = Union[out]
         else:
-            out = _make_array(array_type, dim_str, cls.dtypes, cls.__name__)
+            out = _make_array(array_type, dim_str, cls)
             if out is _not_made:
                 raise ValueError("Invalid jaxtyping type annotation.")
         return out
diff --git a/pyproject.toml b/pyproject.toml
index a7a2f8b..848178a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "jaxtyping"
-version = "0.2.37"
+version = "0.2.38"
 description = "Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays."
 readme = "README.md"
 requires-python =">=3.10"
diff --git a/test/test_serialisation.py b/test/test_serialisation.py
index 02453b7..6295e0e 100644
--- a/test/test_serialisation.py
+++ b/test/test_serialisation.py
@@ -1,3 +1,5 @@
+import pickle
+
 import cloudpickle
 import numpy as np
 
@@ -7,19 +9,27 @@ except ImportError:
     torch = None
 
 
-from jaxtyping import AbstractArray, Array, Shaped
+from jaxtyping import AbstractArray, Array, Float, Shaped
 
 
 def test_pickle():
-    x = cloudpickle.dumps(Shaped[Array, ""])
-    cloudpickle.loads(x)
-
-    y = cloudpickle.dumps(AbstractArray)
-    cloudpickle.loads(y)
-
-    z = cloudpickle.dumps(Shaped[np.ndarray, ""])
-    cloudpickle.loads(z)
-
-    if torch is not None:
-        w = cloudpickle.dumps(Shaped[torch.Tensor, ""])
-        cloudpickle.loads(w)
+    for p in (pickle, cloudpickle):
+        x = p.dumps(Shaped[Array, ""])
+        y = p.loads(x)
+        assert y.dtype is Shaped
+        assert y.dim_str == ""
+
+        x = p.dumps(AbstractArray)
+        y = p.loads(x)
+        assert y is AbstractArray
+
+        x = p.dumps(Shaped[np.ndarray, "3 4 hi"])
+        y = p.loads(x)
+        assert y.dtype is Shaped
+        assert y.dim_str == "3 4 hi"
+
+        if torch is not None:
+            x = p.dumps(Float[torch.Tensor, "batch length"])
+            y = p.loads(x)
+            assert y.dtype is Float
+            assert y.dim_str == "batch length"