Skip to content

Commit

Permalink
Added support for pickling jaxtyping annotations.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Feb 13, 2025
1 parent bd84aed commit 6813074
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 23 deletions.
27 changes: 18 additions & 9 deletions jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import copyreg
import enum
import functools as ft
import importlib.util
Expand Down Expand Up @@ -317,6 +318,10 @@ def _check_shape(
assert False


def _pickle_array_annotation(x: type["AbstractArray"]):
return x.dtype.__getitem__, ((x.array_type, x.dim_str),)


@ft.lru_cache(maxsize=None)
def _make_metaclass(base_metaclass):
class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
Expand All @@ -338,6 +343,8 @@ def __eq__(cls, other):
def __hash__(cls):
return id(cls)

copyreg.pickle(MetaAbstractArray, _pickle_array_annotation)

return MetaAbstractArray


Expand All @@ -358,11 +365,15 @@ class for `Float32[Array, "foo"]`.
you can check `issubclass(annotation, jaxtyping.AbstractArray)`.
"""

# This is what it was defined with.
dtype: type["AbstractDtype"]
array_type: Any
dim_str: str

# This is the processed information we need for later typechecking.
dtypes: list[str]
dims: tuple[_AbstractDimOrVariadicDim, ...]
index_variadic: Optional[int]
dim_str: str


_not_made = object()
Expand Down Expand Up @@ -595,8 +606,8 @@ def _make_array_cached(array_type, dim_str, dtypes, name):
return (array_type, name, dtypes, dims, index_variadic, dim_str)


def _make_array(*args, **kwargs):
out = _make_array_cached(*args, **kwargs)
def _make_array(x, dim_str, dtype):
out = _make_array_cached(x, dim_str, dtype.dtypes, dtype.__name__)

if type(out) is tuple:
array_type, name, dtypes, dims, index_variadic, dim_str = out
Expand All @@ -610,11 +621,12 @@ def _make_array(*args, **kwargs):
name,
(AbstractArray,) if array_type is Any else (array_type, AbstractArray),
dict(
dtype=dtype,
array_type=array_type,
dim_str=dim_str,
dtypes=dtypes,
dims=dims,
index_variadic=index_variadic,
dim_str=dim_str,
),
)
if getattr(typing, "GENERATING_DOCUMENTATION", False):
Expand Down Expand Up @@ -654,10 +666,7 @@ def __getitem__(cls, item: tuple[Any, str]):
array_type = bound
del item
if get_origin(array_type) in _union_types:
out = [
_make_array(x, dim_str, cls.dtypes, cls.__name__)
for x in get_args(array_type)
]
out = [_make_array(x, dim_str, cls) for x in get_args(array_type)]
out = tuple(x for x in out if x is not _not_made)
if len(out) == 0:
raise ValueError("Invalid jaxtyping type annotation.")
Expand All @@ -666,7 +675,7 @@ def __getitem__(cls, item: tuple[Any, str]):
else:
out = Union[out]
else:
out = _make_array(array_type, dim_str, cls.dtypes, cls.__name__)
out = _make_array(array_type, dim_str, cls)
if out is _not_made:
raise ValueError("Invalid jaxtyping type annotation.")
return out
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "jaxtyping"
version = "0.2.37"
version = "0.2.38"
description = "Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays."
readme = "README.md"
requires-python =">=3.10"
Expand Down
36 changes: 23 additions & 13 deletions test/test_serialisation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pickle

import cloudpickle
import numpy as np

Expand All @@ -7,19 +9,27 @@
except ImportError:
torch = None

from jaxtyping import AbstractArray, Array, Shaped
from jaxtyping import AbstractArray, Array, Float, Shaped


def test_pickle():
x = cloudpickle.dumps(Shaped[Array, ""])
cloudpickle.loads(x)

y = cloudpickle.dumps(AbstractArray)
cloudpickle.loads(y)

z = cloudpickle.dumps(Shaped[np.ndarray, ""])
cloudpickle.loads(z)

if torch is not None:
w = cloudpickle.dumps(Shaped[torch.Tensor, ""])
cloudpickle.loads(w)
for p in (pickle, cloudpickle):
x = p.dumps(Shaped[Array, ""])
y = p.loads(x)
assert y.dtype is Shaped
assert y.dim_str == ""

x = p.dumps(AbstractArray)
y = p.loads(x)
assert y is AbstractArray

x = p.dumps(Shaped[np.ndarray, "3 4 hi"])
y = p.loads(x)
assert y.dtype is Shaped
assert y.dim_str == "3 4 hi"

if torch is not None:
x = p.dumps(Float[torch.Tensor, "batch length"])
y = p.loads(x)
assert y.dtype is Float
assert y.dim_str == "batch length"

0 comments on commit 6813074

Please sign in to comment.