From 44357fe7c81ea345c052eb6005e36978b1bbfe31 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Mon, 9 Sep 2024 16:55:50 +0100 Subject: [PATCH] Fix JAX 0.4.31 compatibility, with `sharding` argument in `convert_element_type`. --- jax_scalify/core/__init__.py | 2 +- jax_scalify/core/typing.py | 2 ++ jax_scalify/lax/scaled_ops_common.py | 5 ++++- pyproject.toml | 2 +- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/jax_scalify/core/__init__.py b/jax_scalify/core/__init__.py index d82e7f5..3fafb56 100644 --- a/jax_scalify/core/__init__.py +++ b/jax_scalify/core/__init__.py @@ -24,5 +24,5 @@ scalify, ) from .pow2 import Pow2RoundMode, pow2_decompose, pow2_round, pow2_round_down, pow2_round_up # noqa: F401 -from .typing import Array, ArrayTypes, get_numpy_api # noqa: F401 +from .typing import Array, ArrayTypes, Sharding, get_numpy_api # noqa: F401 from .utils import safe_div, safe_reciprocal # noqa: F401 diff --git a/jax_scalify/core/typing.py b/jax_scalify/core/typing.py index 8291577..be0e289 100644 --- a/jax_scalify/core/typing.py +++ b/jax_scalify/core/typing.py @@ -9,11 +9,13 @@ # Type aliasing. To be compatible with JAX 0.3 as well. try: from jax import Array + from jax.sharding import Sharding ArrayTypes: Tuple[Any, ...] = (Array,) except ImportError: from jaxlib.xla_extension import DeviceArray as Array + Sharding = Any # Older version of JAX <0.4 ArrayTypes = (Array, jax.interpreters.partial_eval.DynamicJaxprTracer) diff --git a/jax_scalify/lax/scaled_ops_common.py b/jax_scalify/lax/scaled_ops_common.py index 17850f2..45a9bab 100644 --- a/jax_scalify/lax/scaled_ops_common.py +++ b/jax_scalify/lax/scaled_ops_common.py @@ -13,6 +13,7 @@ DTypeLike, ScaledArray, Shape, + Sharding, as_scaled_array, get_scale_dtype, is_static_anyscale, @@ -76,7 +77,9 @@ def scaled_broadcast_in_dim(A: ScaledArray, shape: Shape, broadcast_dimensions: @core.register_scaled_lax_op -def scaled_convert_element_type(A: ScaledArray, new_dtype: DTypeLike, weak_type: bool = False) -> ScaledArray: +def scaled_convert_element_type( + A: ScaledArray, new_dtype: DTypeLike, weak_type: bool = False, sharding: Sharding | None = None +) -> ScaledArray: # NOTE: by default, no rescaling done before casting. # Choice of adding an optional rescaling op before is up to the user (and which strategy to use). # NOTE bis: scale not casted as well by default! diff --git a/pyproject.toml b/pyproject.toml index 9016657..485508f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ ] dependencies = [ "chex>=0.1.6", - "jax>=0.3.16,<0.4.31", + "jax>=0.3.16", "jaxlib>=0.3.15", "ml_dtypes", "numpy>=1.22.4"