From 3b28d4e180f3ccf3112dc9bda444a28bab846b13 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Jul 2023 23:00:26 -0700 Subject: [PATCH] Refactor Jax compilation cache key generation. This is in preparation for introducing a more robust key-generation algorithm. This refactoring does not introduce any change in behavior. Testing: refactored unit tests and test workload. PiperOrigin-RevId: 551744892 --- jax/BUILD | 11 ++ jax/_src/cache_key.py | 322 ++++++++++++++++++++++++++++++++ jax/_src/compilation_cache.py | 264 +------------------------- tests/BUILD | 13 +- tests/cache_key_test.py | 285 ++++++++++++++++++++++++++++ tests/compilation_cache_test.py | 247 +----------------------- 6 files changed, 638 insertions(+), 504 deletions(-) create mode 100644 jax/_src/cache_key.py create mode 100644 tests/cache_key_test.py diff --git a/jax/BUILD b/jax/BUILD index 7b3bcfd830b0..d09d66c8ce38 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -264,6 +264,7 @@ pytype_strict_library( srcs = ["_src/compilation_cache.py"], visibility = [":internal"] + jax_visibility("compilation_cache"), deps = [ + ":cache_key", ":compilation_cache_interface", ":config", ":gfile_cache", @@ -272,6 +273,16 @@ pytype_strict_library( ] + py_deps("numpy") + py_deps("zstandard"), ) +pytype_strict_library( + name = "cache_key", + srcs = ["_src/cache_key.py"], + visibility = [":internal"] + jax_visibility("compilation_cache"), + deps = [ + ":config", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "compilation_cache_interface", srcs = ["_src/compilation_cache_interface.py"], diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py new file mode 100644 index 000000000000..ea112cb3c483 --- /dev/null +++ b/jax/_src/cache_key.py @@ -0,0 +1,322 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import io +import logging +import os +import sys + +import numpy as np + +from jax._src.config import config +from jax._src.lib import xla_client +from jax._src.lib import xla_extension_version +from jax._src.lib import version_str as jaxlib_version_str +from jax._src.lib import version as jaxlib_version +from jax._src.lib.mlir import ir +from jax._src.lib.mlir import passmanager as pm + +if jaxlib_version < (0, 4, 14): + import jaxlib.mlir.jax as mlir_jax # pytype: disable=import-error + assert mlir_jax is not None # Imported for side effects only. + + +logger = logging.getLogger(__name__) + +_extra_flag_prefixes: list[str] = [] + +def add_flag_prefixes(flag_prefixes: list[str]) -> None: + """Add flag prefixes to include in the cache key. Call prior to get(). + """ + global _extra_flag_prefixes + _extra_flag_prefixes += flag_prefixes + + +def clear_flag_prefixes() -> None: + """Clear flag prefixes added by add_flag_prefixes(). + """ + global _extra_flag_prefixes + _extra_flag_prefixes = [] + + +def get_flag_prefixes() -> list[str]: + """Return flag prefixes added by add_flag_prefixes(). + """ + return _extra_flag_prefixes + + +def get(module: ir.Module, + devices: np.ndarray, + compile_options: xla_client.CompileOptions, + backend: xla_client.Client, + compression_algorithm: str = "zstandard") -> str: + """Creates a hashed string to use as a key to the compilation cache. + + Creates a cache key that is a hex-encoded string of a unique hash based on + the arguments. The hex-encoded string is 64 characters long.
+ + Args: + module: the input program + devices: an array of accelerator devices that the program will run on + compile_options: options passed to the XLA compiler + backend: description of the platform (e.g., TPU version) + compression_algorithm: a string representing the compression algorithm used + for the executable before persisting in the cache + + Typical return value example: + '14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf' + """ + entries = [ + ("computation", lambda hash_obj: _hash_computation(hash_obj, module)), + ("devices", lambda hash_obj: _hash_devices(hash_obj, devices)), + ("compile_options", + lambda hash_obj: _hash_compile_options(hash_obj, compile_options)), + ("jax_lib version", + lambda hash_obj: hash_obj.update( + bytes(jaxlib_version_str.encode("utf-8")))), + ("the backend", lambda hash_obj: _hash_platform(hash_obj, backend)), + ("XLA flags", + lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes())), + ("compression", + lambda hash_obj: _hash_string(hash_obj, compression_algorithm)), + ] + + hash_obj = hashlib.sha256() + for name, hashfn in entries: + hashfn(hash_obj) + _log_cache_key_hash(hash_obj, name, hashfn) + return hash_obj.digest().hex() + + +def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn): + if logger.isEnabledFor(logging.DEBUG): + # Log the hash of just this entry + fresh_hash_obj = hashlib.sha256() + hashfn(fresh_hash_obj) + logger.debug( + "get_cache_key hash of serialized %s: %s", + last_serialized, + fresh_hash_obj.digest().hex(), + ) + # Log the cumulative hash + logger.debug( + "get_cache_key hash after serializing %s: %s", + last_serialized, + hash_obj.digest().hex(), + ) + + +def _serialize_ir(m: ir.Module) -> bytes: + output = io.BytesIO() + m.operation.write_bytecode(file=output) + return output.getvalue() + + +def _canonicalize_ir(m_original: ir.Module) -> bytes: + with m_original.context: + m = m_original.operation.clone() + if jaxlib_version < (0, 4, 14): + passes = 
pm.PassManager.parse( + "builtin.module(func.func(jax-strip-locations))" + ) + else: + passes = pm.PassManager.parse( + "builtin.module(strip-debuginfo)" + ) + passes.run(m.operation) + return _serialize_ir(m) + + +def _hash_computation(hash_obj, module): + if config.jax_compilation_cache_include_metadata_in_key: + canonical_ir = _serialize_ir(module) + else: + canonical_ir = _canonicalize_ir(module) + hash_obj.update(canonical_ir) + + +def _hash_devices(hash_obj, devices: np.ndarray) -> None: + for device in devices.flat: + _hash_string(hash_obj, device.device_kind) + + +def _hash_compile_options(hash_obj, compile_options_obj): + expected_num_compile_options = 12 + # Ignore private and built-in methods. These can unexpectedly change and lead + # to false positives, e.g. when different Python versions include different + # built-ins. + num_actual_options = len( + [x for x in dir(compile_options_obj) if not x.startswith("_")] + ) + assert num_actual_options == expected_num_compile_options, ( + "Unexpected number of CompileOption fields: " + f"{num_actual_options}. This likely: means that an extra " + "field was added, and this function needs to be updated." 
+ ) + + if compile_options_obj.argument_layouts is not None: + map( + lambda shape: hash_obj.update(shape.to_serialized_proto()), + compile_options_obj.argument_layouts, + ) + _hash_int(hash_obj, compile_options_obj.parameter_is_tupled_arguments) + _hash_executable_build_options( + hash_obj, compile_options_obj.executable_build_options + ) + _hash_bool(hash_obj, compile_options_obj.tuple_arguments) + _hash_int(hash_obj, compile_options_obj.num_replicas) + _hash_int(hash_obj, compile_options_obj.num_partitions) + _hash_signed_int(hash_obj, compile_options_obj.profile_version) + if compile_options_obj.device_assignment is not None: + hash_obj.update(compile_options_obj.device_assignment.serialize()) + _hash_bool(hash_obj, compile_options_obj.compile_portable_executable) + _hash_int(hash_obj, len(compile_options_obj.env_option_overrides)) + for kv in compile_options_obj.env_option_overrides: + _hash_string(hash_obj, kv[0]) + if isinstance(kv[1], str): + _hash_string(hash_obj, kv[1]) + elif isinstance(kv[1], bool): + _hash_bool(hash_obj, kv[1]) + elif isinstance(kv[1], int): + _hash_int(hash_obj, kv[1]) + else: + raise RuntimeError("Invalid type: %s" % repr(type(kv[1]))) + + +def _hash_executable_build_options(hash_obj, executable_obj): + if xla_extension_version > 165: + expected_options = 11 + else: + expected_options = 10 + # Ignore private and built-in methods. These can unexpectedly change and lead + # to false positives, e.g. when different Python versions include different + # built-ins. + actual_options = len( + [x for x in dir(executable_obj) if not x.startswith("_")] + ) + assert actual_options == expected_options, ( + "Unexpected number of executable_build_options fields: " + f"{actual_options}, expected: {expected_options}. This likely means " + "that an extra field was added, and this function needs to be updated." 
+ ) + if executable_obj.result_layout is not None: + hash_obj.update(executable_obj.result_layout.to_serialized_proto()) + _hash_int(hash_obj, executable_obj.num_replicas) + _hash_int(hash_obj, executable_obj.num_partitions) + _hash_debug_options(hash_obj, executable_obj.debug_options) + if executable_obj.device_assignment is not None: + hash_obj.update(executable_obj.device_assignment.serialize()) + _hash_bool(hash_obj, executable_obj.use_spmd_partitioning) + _hash_bool(hash_obj, executable_obj.use_auto_spmd_partitioning) + if executable_obj.use_auto_spmd_partitioning: + if executable_obj.auto_spmd_partitioning_mesh_shape is not None: + _hash_int_list(hash_obj, executable_obj.auto_spmd_partitioning_mesh_shape) + if executable_obj.auto_spmd_partitioning_mesh_ids is not None: + _hash_int_list(hash_obj, executable_obj.auto_spmd_partitioning_mesh_ids) + _hash_bool_list( + hash_obj, executable_obj.allow_spmd_sharding_propagation_to_output + ) + if xla_extension_version > 165 and executable_obj.fdo_profile is not None: + _hash_string(hash_obj, executable_obj.fdo_profile) + + +def _hash_debug_options(hash_obj, debug_obj): + _hash_bool(hash_obj, debug_obj.xla_cpu_enable_fast_math) + _hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_infs) + _hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_nans) + _hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_division) + _hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_functions) + _hash_bool(hash_obj, debug_obj.xla_gpu_enable_fast_min_max) + _hash_int(hash_obj, debug_obj.xla_backend_optimization_level) + _hash_bool(hash_obj, debug_obj.xla_cpu_enable_xprof_traceme) + _hash_bool(hash_obj, debug_obj.xla_llvm_disable_expensive_passes) + _hash_bool(hash_obj, debug_obj.xla_test_all_input_layouts) + + +def _hash_platform(hash_obj, backend): + _hash_string(hash_obj, backend.platform) + _hash_string(hash_obj, backend.platform_version) + _hash_string(hash_obj, backend.runtime_type) + + +def _hash_xla_flags(hash_obj, 
extra_flag_prefixes: list[str]): + xla_flags_to_exclude_from_cache_key = [ + "--xla_dump_compress_protos", + "--xla_dump_module_metadata", + "--xla_dump_max_hlo_modules", + "--xla_dump_include_timestamp", + "--xla_dump_hlo_pass_re", + "--xla_dump_hlo_module_re", + "--xla_dump_hlo_snapshots", + "--xla_dump_fusion_visualization", + "--xla_dump_hlo_as_url", + "--xla_dump_hlo_as_proto", + "--xla_dump_hlo_as_text", + "--xla_dump_to", + "--xla_force_host_platform_device_count", + "--xla_dump_disable_metadata", + "--xla_dump_hlo_pipeline_re", + "--xla_tpu_sdc_checker_streamz_metric", + "--xla_tpu_sdc_checker_enable_sdc_event_callbacks", + ] + + xla_flags = [] + + xla_flags_env_var = os.getenv("XLA_FLAGS") + if xla_flags_env_var: + xla_flags.extend(xla_flags_env_var.split()) + + for arg in sys.argv: + if arg.startswith("--xla") or any( + arg.startswith(p) for p in extra_flag_prefixes + ): + xla_flags.append(arg) + + # N.B. all XLA flags that take an argument must use '=' and not a space + # (e.g. --xla_force_host_platform_device_count=8) (I think). 
+ for flag in xla_flags: + if flag.split("=")[0] in xla_flags_to_exclude_from_cache_key: + logger.debug("Not including XLA flag in cache key: %s", flag) + continue + logger.debug("Including XLA flag in cache key: %s", flag) + _hash_string(hash_obj, flag) + + +def _hash_int(hash_obj, int_var): + hash_obj.update(int_var.to_bytes(8, byteorder="big")) + + +def _hash_signed_int(hash_obj, int_var): + hash_obj.update(int_var.to_bytes(8, byteorder="big", signed=True)) + + +def _hash_bool(hash_obj, bool_var): + hash_obj.update(bool_var.to_bytes(1, byteorder="big")) + + +def _hash_string(hash_obj, str_var): + hash_obj.update(str_var.encode("utf-8").strip()) + + +def _hash_bool_list(hash_obj, bool_list): + for b in bool_list: + _hash_bool(hash_obj, b) + _hash_int(hash_obj, len(bool_list)) + + +def _hash_int_list(hash_obj, int_list): + for i in int_list: + _hash_int(hash_obj, i) + _hash_int(hash_obj, len(int_list)) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 467a02e7954d..a1de154989f0 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -12,11 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import hashlib -import io import logging -import os -import sys from typing import Optional import zlib @@ -28,20 +24,12 @@ except ImportError: zstandard = None -from jax._src.config import config from jax._src import path as pathlib +from jax._src import cache_key from jax._src.compilation_cache_interface import CacheInterface from jax._src.gfile_cache import GFileCache from jax._src.lib import xla_client -from jax._src.lib import xla_extension_version -from jax._src.lib import version_str as jaxlib_version_str -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir -from jax._src.lib.mlir import passmanager as pm - -if jaxlib_version < (0, 4, 14): - import jaxlib.mlir.jax as mlir_jax # pytype: disable=import-error - assert mlir_jax is not None # Imported for side effects only. logger = logging.getLogger(__name__) @@ -126,256 +114,10 @@ def put_executable_and_time( _cache.put(cache_key, executable_and_time) -def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn): - if logger.isEnabledFor(logging.DEBUG): - # Log the hash of just this entry - fresh_hash_obj = hashlib.sha256() - hashfn(fresh_hash_obj) - logger.debug( - "get_cache_key hash of serialized %s: %s", - last_serialized, - fresh_hash_obj.digest().hex(), - ) - # Log the cumulative hash - logger.debug( - "get_cache_key hash after serializing %s: %s", - last_serialized, - hash_obj.digest().hex(), - ) - - def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options, backend) -> str: - """Creates a hashed string to use as a key to the compilation cache. - - get_cache_key takes in the MLIR module and compile_options of a program - and hashes all the components into a unique hash. The hash is returned as a - hex-encoded string that is 256 characters long. 
- - Typical return value example: - '14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf' - """ - entries = [ - ("computation", lambda hash_obj: _hash_computation(hash_obj, module)), - ("devices", lambda hash_obj: _hash_devices(hash_obj, devices)), - ("compile_options", - lambda hash_obj: _hash_compile_options(hash_obj, compile_options)), - ("jax_lib version", - lambda hash_obj: hash_obj.update(bytes(jaxlib_version_str.encode("utf-8"))) - ), - ("the backend", lambda hash_obj: _hash_platform(hash_obj, backend)), - ("XLA flags", _hash_xla_flags), - ("compression", _hash_compression), - ] - - hash_obj = hashlib.sha256() - for name, hashfn in entries: - hashfn(hash_obj) - _log_cache_key_hash(hash_obj, name, hashfn) - return hash_obj.digest().hex() - - -def _serialize_ir(m: ir.Module) -> bytes: - output = io.BytesIO() - m.operation.write_bytecode(file=output) - return output.getvalue() - -def _canonicalize_ir(m_original: ir.Module) -> bytes: - with m_original.context: - m = m_original.operation.clone() - if jaxlib_version < (0, 4, 14): - passes = pm.PassManager.parse( - "builtin.module(func.func(jax-strip-locations))" - ) - else: - passes = pm.PassManager.parse( - "builtin.module(strip-debuginfo)" - ) - passes.run(m.operation) - return _serialize_ir(m) - -def _hash_computation(hash_obj, module): - if config.jax_compilation_cache_include_metadata_in_key: - canonical_ir = _serialize_ir(module) - else: - canonical_ir = _canonicalize_ir(module) - hash_obj.update(canonical_ir) - -def _hash_devices(hash_obj, devices: np.ndarray) -> None: - for device in devices.flat: - _hash_string(hash_obj, device.device_kind) - -def _hash_compile_options(hash_obj, compile_options_obj): - expected_num_compile_options = 12 - # Ignore private and built-in methods. These can unexpectedly change and lead - # to false positives, e.g. when different Python versions include different - # built-ins. 
- num_actual_options = len( - [x for x in dir(compile_options_obj) if not x.startswith("_")] - ) - assert num_actual_options == expected_num_compile_options, ( - "Unexpected number of CompileOption fields: " - f"{num_actual_options}. This likely: means that an extra " - "field was added, and this function needs to be updated." - ) - - if compile_options_obj.argument_layouts is not None: - map( - lambda shape: hash_obj.update(shape.to_serialized_proto()), - compile_options_obj.argument_layouts, - ) - _hash_int(hash_obj, compile_options_obj.parameter_is_tupled_arguments) - _hash_executable_build_options( - hash_obj, compile_options_obj.executable_build_options - ) - _hash_bool(hash_obj, compile_options_obj.tuple_arguments) - _hash_int(hash_obj, compile_options_obj.num_replicas) - _hash_int(hash_obj, compile_options_obj.num_partitions) - _hash_int(hash_obj, compile_options_obj.profile_version) - if compile_options_obj.device_assignment is not None: - hash_obj.update(compile_options_obj.device_assignment.serialize()) - _hash_bool(hash_obj, compile_options_obj.compile_portable_executable) - _hash_int(hash_obj, len(compile_options_obj.env_option_overrides)) - for kv in compile_options_obj.env_option_overrides: - _hash_string(hash_obj, kv[0]) - if isinstance(kv[1], str): - _hash_string(hash_obj, kv[1]) - elif isinstance(kv[1], bool): - _hash_bool(hash_obj, kv[1]) - elif isinstance(kv[1], int): - _hash_int(hash_obj, kv[1]) - else: - raise RuntimeError("Invalid type: %s" % repr(type(kv[1]))) - - -def _hash_executable_build_options(hash_obj, executable_obj): - if xla_extension_version > 165: - expected_options = 11 - else: - expected_options = 10 - # Ignore private and built-in methods. These can unexpectedly change and lead - # to false positives, e.g. when different Python versions include different - # built-ins. 
- actual_options = len( - [x for x in dir(executable_obj) if not x.startswith("_")] - ) - assert actual_options == expected_options, ( - "Unexpected number of executable_build_options fields: " - f"{actual_options}, expected: {expected_options}. This likely means " - "that an extra field was added, and this function needs to be updated." - ) - if executable_obj.result_layout is not None: - hash_obj.update(executable_obj.result_layout.to_serialized_proto()) - _hash_int(hash_obj, executable_obj.num_replicas) - _hash_int(hash_obj, executable_obj.num_partitions) - _hash_debug_options(hash_obj, executable_obj.debug_options) - if executable_obj.device_assignment is not None: - hash_obj.update(executable_obj.device_assignment.serialize()) - _hash_bool(hash_obj, executable_obj.use_spmd_partitioning) - _hash_bool(hash_obj, executable_obj.use_auto_spmd_partitioning) - if executable_obj.use_auto_spmd_partitioning: - if executable_obj.auto_spmd_partitioning_mesh_shape is not None: - _hash_int_list(hash_obj, executable_obj.auto_spmd_partitioning_mesh_shape) - if executable_obj.auto_spmd_partitioning_mesh_ids is not None: - _hash_int_list(hash_obj, executable_obj.auto_spmd_partitioning_mesh_ids) - _hash_bool_list( - hash_obj, executable_obj.allow_spmd_sharding_propagation_to_output - ) - if xla_extension_version > 165 and executable_obj.fdo_profile is not None: - _hash_string(hash_obj, executable_obj.fdo_profile) - - -def _hash_debug_options(hash_obj, debug_obj): - _hash_bool(hash_obj, debug_obj.xla_cpu_enable_fast_math) - _hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_infs) - _hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_nans) - _hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_division) - _hash_bool(hash_obj, debug_obj.xla_cpu_fast_math_honor_functions) - _hash_bool(hash_obj, debug_obj.xla_gpu_enable_fast_min_max) - _hash_int(hash_obj, debug_obj.xla_backend_optimization_level) - _hash_bool(hash_obj, debug_obj.xla_cpu_enable_xprof_traceme) - 
_hash_bool(hash_obj, debug_obj.xla_llvm_disable_expensive_passes) - _hash_bool(hash_obj, debug_obj.xla_test_all_input_layouts) - - -def _hash_platform(hash_obj, backend): - _hash_string(hash_obj, backend.platform) - _hash_string(hash_obj, backend.platform_version) - _hash_string(hash_obj, backend.runtime_type) - - -_xla_flags_to_exclude_from_cache_key = [ - "--xla_dump_compress_protos", - "--xla_dump_module_metadata", - "--xla_dump_max_hlo_modules", - "--xla_dump_include_timestamp", - "--xla_dump_hlo_pass_re", - "--xla_dump_hlo_module_re", - "--xla_dump_hlo_snapshots", - "--xla_dump_fusion_visualization", - "--xla_dump_hlo_as_url", - "--xla_dump_hlo_as_proto", - "--xla_dump_hlo_as_text", - "--xla_dump_to", - "--xla_force_host_platform_device_count", - "--xla_dump_disable_metadata", - "--xla_dump_hlo_pipeline_re", - "--xla_tpu_sdc_checker_streamz_metric", - "--xla_tpu_sdc_checker_enable_sdc_event_callbacks", -] - -extra_flag_prefixes_to_include_in_cache_key: list[str] = [] - - -def _hash_xla_flags(hash_obj): - xla_flags = [] - - xla_flags_env_var = os.getenv("XLA_FLAGS") - if xla_flags_env_var: - xla_flags.extend(xla_flags_env_var.split()) - - for arg in sys.argv: - if arg.startswith("--xla") or any( - arg.startswith(p) for p in extra_flag_prefixes_to_include_in_cache_key - ): - xla_flags.append(arg) - - # N.B. all XLA flags that take an argument must use '=' and not a space - # (e.g. --xla_force_host_platform_device_count=8) (I think). 
- for flag in xla_flags: - if flag.split("=")[0] in _xla_flags_to_exclude_from_cache_key: - logger.debug("Not including XLA flag in cache key: %s", flag) - continue - logger.debug("Including XLA flag in cache key: %s", flag) - _hash_string(hash_obj, flag) - - -def _hash_compression(hash_obj): - _hash_string(hash_obj, "zstandard" if zstandard is not None else "zlib") - - -def _hash_int(hash_obj, int_var): - hash_obj.update(int_var.to_bytes(8, byteorder="big")) - - -def _hash_bool(hash_obj, bool_var): - hash_obj.update(bool_var.to_bytes(1, byteorder="big")) - - -def _hash_string(hash_obj, str_var): - hash_obj.update(str_var.encode("utf-8").strip()) - - -def _hash_bool_list(hash_obj, bool_list): - for b in bool_list: - _hash_bool(hash_obj, b) - _hash_int(hash_obj, len(bool_list)) - - -def _hash_int_list(hash_obj, int_list): - for i in int_list: - _hash_int(hash_obj, i) - _hash_int(hash_obj, len(int_list)) + return cache_key.get(module, devices, compile_options, backend, + "zstandard" if zstandard is not None else "zlib") def is_initialized(): diff --git a/tests/BUILD b/tests/BUILD index 28f60e946034..717e45479fbe 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -921,9 +921,13 @@ jax_test( backend_tags = { "tpu": ["nomsan"], # TODO(b/213388298): this test fails msan. }, - deps = [ - "//jax:compilation_cache_internal", - ], + deps = ["//jax:compilation_cache_internal"], +) + +jax_test( + name = "cache_key_test", + srcs = ["cache_key_test.py"], + deps = ["//jax:cache_key"], ) jax_test( @@ -1144,8 +1148,9 @@ exports_files( [ "api_test.py", "array_test.py", - "pmap_test.py", + "cache_key_test.py", "compilation_cache_test.py", + "pmap_test.py", "pjit_test.py", "python_callback_test.py", "transfer_guard_test.py", diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py new file mode 100644 index 000000000000..fbe7376e8c11 --- /dev/null +++ b/tests/cache_key_test.py @@ -0,0 +1,285 @@ +# Copyright 2023 The JAX Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import os +import random +import sys +import unittest + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import config +from jax import lax +from jax._src import cache_key +from jax._src import test_util as jtu +from jax._src import xla_bridge +from jax._src.config import compilation_cache_include_metadata_in_key +from jax._src.lib import xla_client +from jax._src.lib import xla_extension_version +import numpy as np + + +config.parse_flags_with_absl() +FLAGS = config.FLAGS + + +class CacheKeyTest(jtu.JaxTestCase): + + def test_compile_options(self): + compile_options_not_filled = xla_bridge.get_compile_options( + num_replicas=1, num_partitions=1 + ) + compile_options_filled = self.filled_compile_options() + filled_hash1 = self.get_hashed_value( + cache_key._hash_compile_options, compile_options_filled + ) + filled_hash2 = self.get_hashed_value( + cache_key._hash_compile_options, compile_options_filled + ) + not_filled_hash3 = self.get_hashed_value( + cache_key._hash_compile_options, compile_options_not_filled + ) + self.assertEqual(filled_hash1, filled_hash2) + self.assertNotEqual(filled_hash1, not_filled_hash3) + + def test_executable_build_options(self): + compile_options_not_filled = xla_bridge.get_compile_options( + num_replicas=1, num_partitions=1 + ) + compile_options_filled = self.filled_compile_options() + filled_hash1 = self.get_hashed_value( + 
cache_key._hash_executable_build_options, + compile_options_filled.executable_build_options, + ) + filled_hash2 = self.get_hashed_value( + cache_key._hash_executable_build_options, + compile_options_filled.executable_build_options, + ) + not_filled_hash3 = self.get_hashed_value( + cache_key._hash_executable_build_options, + compile_options_not_filled.executable_build_options, + ) + self.assertEqual(filled_hash1, filled_hash2) + self.assertNotEqual(filled_hash1, not_filled_hash3) + + def test_debug_options(self): + compile_options = xla_bridge.get_compile_options( + num_replicas=1, num_partitions=1 + ) + hash1 = self.get_hashed_value( + cache_key._hash_debug_options, + compile_options.executable_build_options.debug_options, + ) + hash2 = self.get_hashed_value( + cache_key._hash_debug_options, + compile_options.executable_build_options.debug_options, + ) + self.assertEqual(hash1, hash2) + new_debug_options = self.create_new_debug_options( + compile_options.executable_build_options.debug_options + ) + hash3 = self.get_hashed_value( + cache_key._hash_debug_options, new_debug_options + ) + self.assertNotEqual(hash1, hash3) + + def test_hash_platform(self): + hash1 = self.get_hashed_value( + cache_key._hash_platform, xla_bridge.get_backend() + ) + hash2 = self.get_hashed_value( + cache_key._hash_platform, xla_bridge.get_backend() + ) + self.assertEqual(hash1, hash2) + if xla_bridge.get_backend().platform != "cpu": + cpu_backend = xla_bridge.get_backend("cpu") + hash3 = self.get_hashed_value(cache_key._hash_platform, cpu_backend) + self.assertNotEqual(hash1, hash3) + + def test_hash_int(self): + hash1 = self.get_hashed_value(cache_key._hash_int, 90) + hash2 = self.get_hashed_value(cache_key._hash_int, 8) + hash3 = self.get_hashed_value(cache_key._hash_int, 8) + self.assertEqual(hash2, hash3) + self.assertNotEqual(hash1, hash2) + + def test_hash_signed_int(self): + hash1 = self.get_hashed_value(cache_key._hash_signed_int, 90) + hash2 = 
self.get_hashed_value(cache_key._hash_signed_int, -90) + hash3 = self.get_hashed_value(cache_key._hash_signed_int, -8) + hash4 = self.get_hashed_value(cache_key._hash_signed_int, -8) + self.assertEqual(hash3, hash4) + self.assertNotEqual(hash1, hash2) + self.assertNotEqual(hash1, hash3) + + def test_hash_bool(self): + hash1 = self.get_hashed_value(cache_key._hash_bool, False) + hash2 = self.get_hashed_value(cache_key._hash_bool, True) + hash3 = self.get_hashed_value(cache_key._hash_bool, True) + self.assertEqual(hash2, hash3) + self.assertNotEqual(hash1, hash2) + + def test_hash_string(self): + hash1 = self.get_hashed_value(cache_key._hash_string, "foo") + hash2 = self.get_hashed_value(cache_key._hash_string, "bar") + hash3 = self.get_hashed_value(cache_key._hash_string, "bar") + self.assertEqual(hash2, hash3) + self.assertNotEqual(hash1, hash2) + + def test_same_key(self): + computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() + devices = np.array([[jax.local_devices()[0]]]) + compile_options = xla_bridge.get_compile_options( + num_replicas=1, num_partitions=1 + ) + backend = xla_bridge.get_backend() + self.assertEqual( + cache_key.get(computation, devices, compile_options, backend), + cache_key.get(computation, devices, compile_options, backend), + ) + + def test_different_key(self): + computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() + devices = np.array([[jax.local_devices()[0]]]) + compile_options_not_filled = xla_bridge.get_compile_options( + num_replicas=1, num_partitions=1 + ) + compile_options_filled = self.filled_compile_options() + backend = xla_bridge.get_backend() + self.assertNotEqual( + cache_key.get( + computation, devices, compile_options_not_filled, backend + ), + cache_key.get(computation, devices, compile_options_filled, backend), + ) + + def test_different_computations(self): + computation1 = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() + computation2 = jax.jit(lambda x, y: x * y).lower(2, 
2).compiler_ir() + devices = np.array([[jax.local_devices()[0]]]) + compile_options = xla_bridge.get_compile_options( + num_replicas=1, num_partitions=1 + ) + backend = xla_bridge.get_backend() + self.assertNotEqual( + cache_key.get(computation1, devices, compile_options, backend), + cache_key.get(computation2, devices, compile_options, backend), + ) + + @parameterized.parameters([False, True]) + def test_identical_computations_different_metadata(self, include_metadata): + f = lambda x, y: lax.mul(lax.add(x, y), 2) + g = lambda x, y: lax.mul(lax.add(x, y), 2) + assert id(f) != id(g) + computation1 = jax.jit(f).lower(1, 1).compiler_ir() + computation2 = jax.jit(g).lower(2, 3).compiler_ir() + devices = np.array([[jax.local_devices()[0]]]) + compile_options = xla_bridge.get_compile_options( + num_replicas=1, num_partitions=1 + ) + backend = xla_bridge.get_backend() + with compilation_cache_include_metadata_in_key(include_metadata): + key1 = cache_key.get(computation1, devices, compile_options, backend) + key2 = cache_key.get(computation2, devices, compile_options, backend) + self.assertEqual(include_metadata, key1 != key2) + + def test_xla_flags(self): + if jtu.is_device_tpu_v4(): + raise unittest.SkipTest("TODO(b/240151176)") + + computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() + devices = np.array([[jax.local_devices()[0]]]) + compile_options = xla_bridge.get_compile_options( + num_replicas=1, num_partitions=1 + ) + backend = xla_bridge.get_backend() + + orig_xla_flags = os.getenv("XLA_FLAGS") + orig_argv = sys.argv + try: + os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0" + key1 = cache_key.get(computation, devices, compile_options, backend) + os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=1" + key2 = cache_key.get(computation, devices, compile_options, backend) + self.assertNotEqual(key1, key2) + + os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0" + key3 = cache_key.get(computation, devices, compile_options, backend) + 
self.assertEqual(key1, key3) + + # Test flag in _xla_flags_to_exclude_from_cache_key + os.environ["XLA_FLAGS"] = ( + "--xla_gpu_autotune_level=0 --xla_force_host_platform_device_count=8" + ) + key4 = cache_key.get(computation, devices, compile_options, backend) + self.assertEqual(key1, key4) + + # Test flags given on command line + del os.environ["XLA_FLAGS"] + sys.argv.append("--xla_gpu_autotune_level=0") + key5 = cache_key.get(computation, devices, compile_options, backend) + self.assertEqual(key1, key5) + sys.argv.append("--xla_force_host_platform_device_count=8") + self.assertEqual(key1, key5) + + finally: + if orig_xla_flags is not None: + os.environ["XLA_FLAGS"] = orig_xla_flags + elif os.getenv("XLA_FLAGS") is not None: + del os.environ["XLA_FLAGS"] + sys.argv = orig_argv + + def create_new_debug_options(self, debug_options_obj): + debug_options_obj.xla_cpu_enable_fast_math = False + debug_options_obj.xla_cpu_fast_math_honor_infs = False + debug_options_obj.xla_cpu_fast_math_honor_nans = False + debug_options_obj.xla_cpu_fast_math_honor_division = False + debug_options_obj.xla_cpu_fast_math_honor_functions = False + debug_options_obj.xla_gpu_enable_fast_min_max = False + debug_options_obj.xla_backend_optimization_level = random.randint(0, 10) + debug_options_obj.xla_cpu_enable_xprof_traceme = False + debug_options_obj.xla_llvm_disable_expensive_passes = False + debug_options_obj.xla_test_all_input_layouts = False + return debug_options_obj + + def filled_compile_options(self): + compile_options = xla_client.CompileOptions() + compile_options.num_replicas = 1 + compile_options.num_partitions = 1 + shape = xla_client.Shape.array_shape(np.dtype(np.float32), [2]) + shape_array = [shape, shape] + compile_options.argument_layouts = shape_array + compile_options.executable_build_options.result_layout = shape + + device_assignment = xla_client.DeviceAssignment.create( + np.ndarray(shape=(2, 2)) + ) + compile_options.device_assignment = device_assignment + 
compile_options.executable_build_options.device_assignment = ( + device_assignment + ) + if xla_extension_version > 165: + compile_options.executable_build_options.fdo_profile = b"test_profile" + return compile_options + + def get_hashed_value(self, hash_function, hash_function_input): + hash_obj = hashlib.sha256() + hash_function(hash_obj, hash_function_input) + return hash_obj.digest().hex() + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 57f109ae3018..8b2f9415234c 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -13,36 +13,31 @@ # limitations under the License. from functools import partial -import hashlib import math import os -import random -import sys import tempfile -import unittest -from unittest import SkipTest, mock +from unittest import mock +from unittest import SkipTest import warnings from absl.testing import absltest -from absl.testing import parameterized import jax from jax import config -from jax import jit, lax, pmap +from jax import jit +from jax import lax +from jax import pmap from jax._src import compilation_cache as cc from jax._src import test_util as jtu from jax._src import xla_bridge -from jax._src.config import ( - compilation_cache_include_metadata_in_key, - persistent_cache_min_compile_time_secs, - raise_persistent_cache_errors, -) +from jax._src.config import persistent_cache_min_compile_time_secs +from jax._src.config import raise_persistent_cache_errors from jax._src.lib import xla_client -from jax._src.lib import xla_extension_version from jax.experimental.maps import xmap from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P import numpy as np + config.parse_flags_with_absl() FLAGS = config.FLAGS @@ -76,194 +71,6 @@ def tearDown(self): cc.reset_cache() super().tearDown() - def test_compile_options(self): - compile_options_not_filled = 
xla_bridge.get_compile_options( - num_replicas=1, num_partitions=1 - ) - compile_options_filled = self.filled_compile_options() - filled_hash1 = self.get_hashed_value( - cc._hash_compile_options, compile_options_filled - ) - filled_hash2 = self.get_hashed_value( - cc._hash_compile_options, compile_options_filled - ) - not_filled_hash3 = self.get_hashed_value( - cc._hash_compile_options, compile_options_not_filled - ) - self.assertEqual(filled_hash1, filled_hash2) - self.assertNotEqual(filled_hash1, not_filled_hash3) - - def test_executable_build_options(self): - compile_options_not_filled = xla_bridge.get_compile_options( - num_replicas=1, num_partitions=1 - ) - compile_options_filled = self.filled_compile_options() - filled_hash1 = self.get_hashed_value( - cc._hash_executable_build_options, - compile_options_filled.executable_build_options, - ) - filled_hash2 = self.get_hashed_value( - cc._hash_executable_build_options, - compile_options_filled.executable_build_options, - ) - not_filled_hash3 = self.get_hashed_value( - cc._hash_executable_build_options, - compile_options_not_filled.executable_build_options, - ) - self.assertEqual(filled_hash1, filled_hash2) - self.assertNotEqual(filled_hash1, not_filled_hash3) - - def test_debug_options(self): - compile_options = xla_bridge.get_compile_options( - num_replicas=1, num_partitions=1 - ) - hash1 = self.get_hashed_value( - cc._hash_debug_options, - compile_options.executable_build_options.debug_options, - ) - hash2 = self.get_hashed_value( - cc._hash_debug_options, - compile_options.executable_build_options.debug_options, - ) - self.assertEqual(hash1, hash2) - new_debug_options = self.create_new_debug_options( - compile_options.executable_build_options.debug_options - ) - hash3 = self.get_hashed_value(cc._hash_debug_options, new_debug_options) - self.assertNotEqual(hash1, hash3) - - def test_hash_platform(self): - hash1 = self.get_hashed_value(cc._hash_platform, xla_bridge.get_backend()) - hash2 = 
self.get_hashed_value(cc._hash_platform, xla_bridge.get_backend()) - self.assertEqual(hash1, hash2) - if xla_bridge.get_backend().platform != "cpu": - cpu_backend = xla_bridge.get_backend("cpu") - hash3 = self.get_hashed_value(cc._hash_platform, cpu_backend) - self.assertNotEqual(hash1, hash3) - - def test_hash_int(self): - hash1 = self.get_hashed_value(cc._hash_int, 90) - hash2 = self.get_hashed_value(cc._hash_int, 8) - hash3 = self.get_hashed_value(cc._hash_int, 8) - self.assertEqual(hash2, hash3) - self.assertNotEqual(hash1, hash2) - - def test_hash_bool(self): - hash1 = self.get_hashed_value(cc._hash_bool, False) - hash2 = self.get_hashed_value(cc._hash_bool, True) - hash3 = self.get_hashed_value(cc._hash_bool, True) - self.assertEqual(hash2, hash3) - self.assertNotEqual(hash1, hash2) - - def test_hash_string(self): - hash1 = self.get_hashed_value(cc._hash_string, "foo") - hash2 = self.get_hashed_value(cc._hash_string, "bar") - hash3 = self.get_hashed_value(cc._hash_string, "bar") - self.assertEqual(hash2, hash3) - self.assertNotEqual(hash1, hash2) - - def test_same_hash_key(self): - computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() - devices = np.array([[jax.local_devices()[0]]]) - compile_options = xla_bridge.get_compile_options( - num_replicas=1, num_partitions=1 - ) - backend = xla_bridge.get_backend() - self.assertEqual( - cc.get_cache_key(computation, devices, compile_options, backend), - cc.get_cache_key(computation, devices, compile_options, backend), - ) - - def test_different_hash_key(self): - computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() - devices = np.array([[jax.local_devices()[0]]]) - compile_options_not_filled = xla_bridge.get_compile_options( - num_replicas=1, num_partitions=1 - ) - compile_options_filled = self.filled_compile_options() - backend = xla_bridge.get_backend() - self.assertNotEqual( - cc.get_cache_key(computation, devices, compile_options_not_filled, - backend), - cc.get_cache_key(computation, 
devices, compile_options_filled, backend), - ) - - def test_different_computations(self): - computation1 = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() - computation2 = jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir() - devices = np.array([[jax.local_devices()[0]]]) - compile_options = xla_bridge.get_compile_options( - num_replicas=1, num_partitions=1 - ) - backend = xla_bridge.get_backend() - self.assertNotEqual( - cc.get_cache_key(computation1, devices, compile_options, backend), - cc.get_cache_key(computation2, devices, compile_options, backend), - ) - - @parameterized.parameters([False, True]) - def test_identical_computations_different_metadata(self, include_metadata): - f = lambda x, y: lax.mul(lax.add(x, y), 2) - g = lambda x, y: lax.mul(lax.add(x, y), 2) - assert id(f) != id(g) - computation1 = jax.jit(f).lower(1, 1).compiler_ir() - computation2 = jax.jit(g).lower(2, 3).compiler_ir() - devices = np.array([[jax.local_devices()[0]]]) - compile_options = xla_bridge.get_compile_options( - num_replicas=1, num_partitions=1 - ) - backend = xla_bridge.get_backend() - with compilation_cache_include_metadata_in_key(include_metadata): - key1 = cc.get_cache_key(computation1, devices, compile_options, backend) - key2 = cc.get_cache_key(computation2, devices, compile_options, backend) - self.assertEqual(include_metadata, key1 != key2) - - def test_xla_flags(self): - if jtu.is_device_tpu_v4(): - raise unittest.SkipTest("TODO(b/240151176)") - - computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() - devices = np.array([[jax.local_devices()[0]]]) - compile_options = xla_bridge.get_compile_options( - num_replicas=1, num_partitions=1 - ) - backend = xla_bridge.get_backend() - - orig_xla_flags = os.getenv("XLA_FLAGS") - orig_argv = sys.argv - try: - os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0" - key1 = cc.get_cache_key(computation, devices, compile_options, backend) - os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=1" - key2 = 
cc.get_cache_key(computation, devices, compile_options, backend) - self.assertNotEqual(key1, key2) - - os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0" - key3 = cc.get_cache_key(computation, devices, compile_options, backend) - self.assertEqual(key1, key3) - - # Test flag in _xla_flags_to_exclude_from_cache_key - os.environ["XLA_FLAGS"] = ( - "--xla_gpu_autotune_level=0 --xla_force_host_platform_device_count=8" - ) - key4 = cc.get_cache_key(computation, devices, compile_options, backend) - self.assertEqual(key1, key4) - - # Test flags given on command line - del os.environ["XLA_FLAGS"] - sys.argv.append("--xla_gpu_autotune_level=0") - key5 = cc.get_cache_key(computation, devices, compile_options, backend) - self.assertEqual(key1, key5) - sys.argv.append("--xla_force_host_platform_device_count=8") - self.assertEqual(key1, key5) - - finally: - if orig_xla_flags is not None: - os.environ["XLA_FLAGS"] = orig_xla_flags - elif os.getenv("XLA_FLAGS") is not None: - del os.environ["XLA_FLAGS"] - sys.argv = orig_argv - def test_get_no_executable(self): with tempfile.TemporaryDirectory() as tmpdir: cc.initialize_cache(tmpdir) @@ -457,44 +264,6 @@ def test_min_compile_time(self): files_in_cache = len(os.listdir(tmpdir)) self.assertEqual(files_in_cache, 1) - def create_new_debug_options(self, debug_options_obj): - debug_options_obj.xla_cpu_enable_fast_math = False - debug_options_obj.xla_cpu_fast_math_honor_infs = False - debug_options_obj.xla_cpu_fast_math_honor_nans = False - debug_options_obj.xla_cpu_fast_math_honor_division = False - debug_options_obj.xla_cpu_fast_math_honor_functions = False - debug_options_obj.xla_gpu_enable_fast_min_max = False - debug_options_obj.xla_backend_optimization_level = random.randint(0, 10) - debug_options_obj.xla_cpu_enable_xprof_traceme = False - debug_options_obj.xla_llvm_disable_expensive_passes = False - debug_options_obj.xla_test_all_input_layouts = False - return debug_options_obj - - def filled_compile_options(self): - 
compile_options = xla_client.CompileOptions() - compile_options.num_replicas = 1 - compile_options.num_partitions = 1 - shape = xla_client.Shape.array_shape(np.dtype(np.float32), [2]) - shape_array = [shape, shape] - compile_options.argument_layouts = shape_array - compile_options.executable_build_options.result_layout = shape - - device_assignment = xla_client.DeviceAssignment.create( - np.ndarray(shape=(2, 2)) - ) - compile_options.device_assignment = device_assignment - compile_options.executable_build_options.device_assignment = ( - device_assignment - ) - if xla_extension_version > 165: - compile_options.executable_build_options.fdo_profile = b"test_profile" - return compile_options - - def get_hashed_value(self, hash_function, hash_function_input): - hash_obj = hashlib.sha256() - hash_function(hash_obj, hash_function_input) - return hash_obj.digest().hex() - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())