Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JAX 0.4.31 compatibility, with sharding argument in `convert_el… #134

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jax_scalify/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@
scalify,
)
from .pow2 import Pow2RoundMode, pow2_decompose, pow2_round, pow2_round_down, pow2_round_up # noqa: F401
from .typing import Array, ArrayTypes, get_numpy_api # noqa: F401
from .typing import Array, ArrayTypes, Sharding, get_numpy_api # noqa: F401
from .utils import safe_div, safe_reciprocal # noqa: F401
2 changes: 2 additions & 0 deletions jax_scalify/core/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
# Type aliasing. To be compatible with JAX 0.3 as well.
try:
from jax import Array
from jax.sharding import Sharding

ArrayTypes: Tuple[Any, ...] = (Array,)
except ImportError:
from jaxlib.xla_extension import DeviceArray as Array

Sharding = Any
# Older version of JAX <0.4
ArrayTypes = (Array, jax.interpreters.partial_eval.DynamicJaxprTracer)

Expand Down
5 changes: 4 additions & 1 deletion jax_scalify/lax/scaled_ops_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DTypeLike,
ScaledArray,
Shape,
Sharding,
as_scaled_array,
get_scale_dtype,
is_static_anyscale,
Expand Down Expand Up @@ -76,7 +77,9 @@ def scaled_broadcast_in_dim(A: ScaledArray, shape: Shape, broadcast_dimensions:


@core.register_scaled_lax_op
def scaled_convert_element_type(A: ScaledArray, new_dtype: DTypeLike, weak_type: bool = False) -> ScaledArray:
def scaled_convert_element_type(
A: ScaledArray, new_dtype: DTypeLike, weak_type: bool = False, sharding: Sharding | None = None
) -> ScaledArray:
# NOTE: by default, no rescaling done before casting.
# Choice of adding an optional rescaling op before is up to the user (and which strategy to use).
# NOTE bis: scale not casted as well by default!
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers = [
]
dependencies = [
"chex>=0.1.6",
"jax>=0.3.16,<0.4.31",
"jax>=0.3.16",
"jaxlib>=0.3.15",
"ml_dtypes",
"numpy>=1.22.4"
Expand Down
Loading