Skip to content

Commit

Permalink
Remove dead code after minimum jaxlib version bump to v0.4.36.
Browse files Browse the repository at this point in the history
New minimum xla_extension_version is 299, and the new mlir_api_version is 57.

PiperOrigin-RevId: 704280856
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Dec 9, 2024
1 parent cc258f5 commit 79318a0
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 571 deletions.
9 changes: 2 additions & 7 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe
from jax._src.lib import xla_extension_version
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding, NamedSharding,
Expand Down Expand Up @@ -1169,12 +1168,8 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics):
results.append(
shard_sharded_device_array_slow_path(x, devices, indices, sharding))

if xla_extension_version >= 296:
copy_outs = xc.batched_copy_array_to_devices_with_sharding(
batch_xs, batch_devs, batch_shardings, batch_cs)
else:
copy_outs = xc.batched_copy_array_to_devices_with_sharding( # pytype: disable=missing-parameter
batch_xs, batch_devs, batch_shardings)
copy_outs = xc.batched_copy_array_to_devices_with_sharding(
batch_xs, batch_devs, batch_shardings, batch_cs)
for i, copy_out in safe_zip(batch_indices, copy_outs):
assert results[i] is None
results[i] = copy_out
Expand Down
6 changes: 1 addition & 5 deletions jax/_src/cache_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from jax._src import config
from jax._src.lib import version_str as jaxlib_version_str
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir import passmanager as pm
import numpy as np
Expand Down Expand Up @@ -301,10 +300,7 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj,
debug_options.xla_dump_hlo_as_long_text = False
debug_options.xla_dump_disable_metadata = False
debug_options.xla_dump_hlo_pipeline_re = ""

# "Requires jaxlib 0.4.36+"
if xla_extension_version > 296:
debug_options.xla_gpu_experimental_autotune_cache_mode = 0
debug_options.xla_gpu_experimental_autotune_cache_mode = 0

# Optional way to specify the cuda install path to be used by the compiler.
# This could possibly affect the cuda version compiled with, but this should
Expand Down
6 changes: 2 additions & 4 deletions jax/_src/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from jax._src.interpreters import mlir
from jax._src.lib import version as jaxlib_version
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
import numpy as np

Expand Down Expand Up @@ -192,9 +191,8 @@ def get_compile_options(
assert device_assignment.computation_count() == num_partitions
compile_options.device_assignment = device_assignment

if xla_extension_version >= 294:
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
build_options.memory_fitting_effort = config.memory_fitting_effort.value
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
build_options.memory_fitting_effort = config.memory_fitting_effort.value

if env_options_overrides is not None:
# Some overrides are passed directly on build_options.
Expand Down
Loading

0 comments on commit 79318a0

Please sign in to comment.