From 67f24df74072bcaa5c31c95d4d3ea963521f0add Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey
Date: Fri, 4 Oct 2024 12:37:37 -0700
Subject: [PATCH] Activate FFI implementation of symmetric Eigendecomposition.

These kernels support shape polymorphism in all dimensions and no GPU is
required during lowering. The kernels have been included in jaxlib for more
than 3 weeks so we don't need to include any forward compatibility checks.

PiperOrigin-RevId: 682415506
---
 jax/_src/export/_export.py                    |   7 +-
 .../cuda_eigh_cusolver_syev.py                | 341 +++++++++++++++++-
 jax/_src/lax/linalg.py                        |  56 +--
 jaxlib/gpu_solver.py                          |  79 +---
 jaxlib/lapack.py                              | 114 ------
 tests/export_back_compat_test.py              |  49 ++-
 tests/shape_poly_test.py                      |  41 ++-
 7 files changed, 444 insertions(+), 243 deletions(-)

diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py
index 52c61048a926..bd95f5a0e29c 100644
--- a/jax/_src/export/_export.py
+++ b/jax/_src/export/_export.py
@@ -947,11 +947,6 @@ def _check_lowering(lowering) -> None:
       "__gpu$xla.gpu.triton",  # Pallas call on GPU
       # cholesky on CPU
       "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
-      # eigh on CPU
-      "lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd",
-      # eigh on GPU
-      "cusolver_syevj", "cusolver_syevd",
-      "hipsolver_syevj", "hipsolver_syevd",
       # eigh on TPU
       "Eigh",
       # eig on CPU
@@ -969,9 +964,11 @@ def _check_lowering(lowering) -> None:
       # lu on GPU
       "cu_lu_pivots_to_permutation", "cusolver_getrf_ffi",
       "hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi",
       # qr on GPU
       "cusolver_geqrf_ffi", "cusolver_orgqr_ffi",
       "hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi",
+      # eigh on GPU
+      "cusolver_syevd_ffi", "hipsolver_syevd_ffi",
       # svd on GPU
       # lu on TPU
       "LuDecomposition",
diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py
index 896ecad019e2..56479e82f9d9 100644
--- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py
+++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py
@@ -15,7 +15,7 @@
 # ruff: noqa
 
 import datetime
-from numpy import array, float32
+from numpy import array, float32, complex64
 
 data_2023_03_17=dict(  # Pasted from the test output (see back_compat_test.py module docstring)
@@ -1409,3 +1409,342 @@
     xla_call_module_version=4,
 )  # End paste
 )
+
+data_2024_09_30 = {}
+
+data_2024_09_30["f32"] = dict(
+    testdata_version=1,
+    platform='cuda',
+    custom_call_targets=['cusolver_syevd_ffi'],
+    serialized_date=datetime.date(2024, 9, 30),
+    inputs=(),
+    expected_outputs=(array([[ 0.7941186  , -0.3696443  , -0.40418202 ,  0.26339266 ],
+       [ 0.3696443  ,  0.7941186  ,  0.26339266 ,  0.4041819  ],
+       [-0.054829806, -0.47930413 ,  0.6857606  ,  0.5449713  ],
+       [-0.4793042  ,  0.05482992 , -0.5449712  ,  0.68576056 ]],
+      dtype=float32), array([-3.7082872e+00, -4.0793765e-07,  4.4458108e-07,  3.3708286e+01],
+      dtype=float32)),
+    mlir_module_text=r"""
+#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:27)
+#loc13 = loc("jit()/jit(main)/pjit"(#loc6))
+module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+  func.func public @main() -> (tensor<4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) {
+    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc)
+    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
+    %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<f32> loc(#loc)
+    %0 = stablehlo.iota dim = 0 : tensor<16xf32> loc(#loc8)
+    %1 = stablehlo.reshape %0 : (tensor<16xf32>) -> tensor<4x4xf32> loc(#loc9)
+    %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x4xf32>) -> tensor<4x4xf32> loc(#loc10)
+    %3 = stablehlo.add %1, %2 : tensor<4x4xf32> loc(#loc11)
+    %4 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<4x4xf32> loc(#loc12)
+    %5 = stablehlo.divide %3, %4 : tensor<4x4xf32> loc(#loc12)
+    %6 = call @tril(%5) : (tensor<4x4xf32>) -> tensor<4x4xf32> loc(#loc13)
+    %7:3 = stablehlo.custom_call @cusolver_syevd_ffi(%6) {mhlo.backend_config = {algorithm = 0 : ui8, lower = true}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4xf32>, tensor<i32>) loc(#loc14)
+    %8 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc14)
+    %9 = stablehlo.compare EQ, %7#2, %8, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc14)
+    %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc14)
+    %11 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<4x4xf32> loc(#loc14)
+    %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc14)
+    %13 = stablehlo.select %12, %7#0, %11 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc14)
+    %14 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor<i1>) -> tensor<1xi1> loc(#loc14)
+    %15 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<4xf32> loc(#loc14)
+    %16 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc14)
+    %17 = stablehlo.select %16, %7#1, %15 : tensor<4xi1>, tensor<4xf32> loc(#loc14)
+    return %13, %17 : tensor<4x4xf32>, tensor<4xf32> loc(#loc)
+  } loc(#loc)
+  func.func private @tril(%arg0: tensor<4x4xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit"(#loc6))) -> (tensor<4x4xf32> {mhlo.layout_mode = "default"}) {
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32> loc(#loc)
+    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
+    %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc15)
+    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<4x4xi32> loc(#loc16)
+    %2 = stablehlo.add %0, %1 : tensor<4x4xi32> loc(#loc16)
+    %3 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc15)
+    %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc17)
+    %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<4x4xf32> loc(#loc18)
+    %6 = stablehlo.select %4, %arg0, %5 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc19)
+    return %6 : tensor<4x4xf32> loc(#loc13)
+  } loc(#loc13)
+} loc(#loc)
+#loc = loc(unknown)
+#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:26)
+#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:14)
+#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:34)
+#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:15)
+#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:14)
+#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:11)
+#loc8 = loc("jit()/jit(main)/iota"(#loc1))
+#loc9 = loc("jit()/jit(main)/reshape"(#loc2))
+#loc10 = 
loc("jit()/jit(main)/transpose"(#loc3)) +#loc11 = loc("jit()/jit(main)/add"(#loc4)) +#loc12 = loc("jit()/jit(main)/div"(#loc5)) +#loc14 = loc("jit()/jit(main)/eigh"(#loc7)) +#loc15 = loc("jit()/jit(main)/jit(tril)/iota"(#loc6)) +#loc16 = loc("jit()/jit(main)/jit(tril)/add"(#loc6)) +#loc17 = loc("jit()/jit(main)/jit(tril)/ge"(#loc6)) +#loc18 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim"(#loc6)) +#loc19 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc6)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x01-\x05\x01\x05\x1d\x01\x03\x0b\x03\x1b\x0f\x13\x17\x1b\x1f#'+/37;?\x03\xfb\xb17\x01U\x0f\x07\x0b\x17\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x17\x0b\x17\x13\x0b\x0b\x17\x03]\x0f\x0b\x0b\x0b\x0b\x0f\x0b\x1f\x0f\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x1f\x0f\x0b\x1f\x1fO\x1b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x17/\x0f\x0bO/\x01\x05\x0b\x0f\x033\x17\x0f\x0f\x07\x07\x07\x13\x17\x07\x07\x17\x13\x17\x17\x13\x13\x07\x13\x13\x13\x0f\x17\x13\x13\x13\x02\xae\x06\x1dQS\x1f\x05!\x17\x05J\x047\x1d\x1f\x07\x11\x03\x05\x1d!\x07\x1d#\x07\x1dIK\x03\x07\x15\x17\x19\x0b\x1b\x0b\x05#\x11\x01\x00\x05%\x05'\x05)\x05+\x05-\x05/\x1d'\x07\x051\x1d+\x07\x053\x1d/\x07\x055\x1d35\x057\x17\x05*\x045\x1d9;\x059\x17\x05*\x04\x1d\x1d?A\x05;\x17\x052\x04E\x1dEG\x05=\x17\x052\x04\x1f\x05?\x17\x052\x04\x1d\x03\x03O\x8d\x05A\x05C\x17\x05J\x04\x17\x1f!\x01\x1dE\x1dG\x03\x01\x1dI\x03\x03{\x1dK\x1f\t\t\x00\x00\x00\x00\x13\x0b\x01\t\x07\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\x1d\x03\x05os\r\x05]qWY\x1dM\r\x05]uWY\x1dO\x1dQ\x1dS\r\x03WY#\x1f\x1dU\x1f\x07\t\x00\x00\x00\x00\x13\x0b\x05\x07\x05\x1f\x07\t\x00\x00\xc0\x7f\x1f\x07\t\x00\x00\x00@\x1f\x1b!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x05\x8f\x91\x93\x95\x1dW\x13%\x00\x1dY\x05\x03\x0b\x03\x1d[\x1d]\x05\x01\x03\x03i\x03\x03\xa3\x15\x03\x01\x01\x01\x03\x07i\xa7\xa9\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f+\x01\x07\x01\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\x0f)\x01\x0f)\x01\x17\x1d\x01\t)\x03\x11\x0f)\x05\x11\x11\x17\x13\x1b)\x05\x11\x11\r)\x03\t\x0b\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\x0b)\x03A\x0f!)\x03\t\x15)\x03\x05\x15)\x03\x01\x15)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03\x11\r)\x03\x05\x0b\x04b\x04\x05\x01Q\x03\x13\x01\x07\x04:\x04\x03\x01\t\x0bP\x03\x03\x07\x04\xba\x02\x03/Y\x05B\x03\x05\x03\x07\x05B\x03\x07\x03\t\x05B\x03\t\x03\x07\x07B1\x0b\x03#\x13\x067\x03\x05\x03\x07\x15F=\r\x03\x05\x03\t\r\x06C\x03\x05\x05\t\x0b\x03F\x11\x0f\x03\x05\x03\x05\x17\x06\x11\x03\x05\x05\r\x0f\x19F\t\x11\x03\x05\x03\x11\x1bG\x01M\x13\x07\x05\x11\t\x03\x13\x03F\x01\x0f\x03\t\x03\x03\x0fF\x01\x15\x03-\x05\x19\x1b\x03F\x01\x0f\x03/\x03\x1d\x03F\x01\x0f\x03\x05\x03\x01\x03F\x01\x17\x03\x19\x03\x1f\t\x06\x01\x03\x05\x07#\x15!\x03F\x01\x0f\x031\x03\x1d\x03F\x01\x0f\x03\x11\x03\x01\x03F\x01\x19\x033\x03'\t\x06\x01\x03\x11\x07+\x17)\x11\x04\x03\x05%-\x0bP\t\x1b\x07\x04\x9d\x03\x15+\x03\x0b\t\x00\x05B\x03\x1d\x03\x07\x05B\x03\x07\x03\t\x07B\r\x0b\x03\x13\x03F\x0f\x0f\x03\x13\x03\x05\r\x06\x0f\x03\x13\x05\x07\t\x07B\r\x1f\x03\x13\x0fF%!\x03\x19\x05\x0b\r\x03F)\x0f\x03\x05\x03\x03\t\x06-\x03\x05\x07\x0f\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\xe6\r_'\x03\r\x15\x11\x0f\x0b\t\t\x0b!\x11#;)99EA;WgKMO;\x1b%)9i\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast
_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit\x00jit()/jit(main)/jit(tril)/iota\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota\x00jit()/jit(main)/reshape\x00jit()/jit(main)/transpose\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00mhlo.backend_config\x00jit()/jit(main)/eigh\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00algorithm\x00lower\x00\x00cusolver_syevd_ffi\x00\x08k#\x05;\x01\x0b[kmwy\x03\x87\x03c\x03\x89\x03e\x03\x8b\x03U\x03a\x11\x97\x99\x9b[\x9d\x9f\xa1\xa5\x05g\xab\x03\xad\x03\xaf\x0b_}_a\x7f\x03\x81\x03\x83\x05g\x85",
+    xla_call_module_version=9,
+    nr_devices=1,
+)  # End paste
+
+data_2024_09_30["f64"] = dict(
+    testdata_version=1,
+    platform='cuda',
+    custom_call_targets=['cusolver_syevd_ffi'],
+    serialized_date=datetime.date(2024, 9, 30),
+    inputs=(),
+    expected_outputs=(array([[ 0.7941185704969033 , -0.36964433974346045, -0.4041819665640973 ,
+         0.2633926650306618 ],
+       [ 0.3696443397434605 ,  0.7941185704969035 ,  0.2633926650306616 ,
+         0.4041819665640974 ],
+       [-0.05482989100998295, -0.47930412176342563,  0.6857605696309688 ,
+         0.544971268097533  ],
+       [-0.47930412176342574,  0.05482989100998273, -0.544971268097533  ,
+         0.6857605696309688 ]]), array([-3.7082869338697053e+00,  7.7329581044653176e-17,
+        8.6623770428558249e-16,  3.3708286933869694e+01])),
+    mlir_module_text=r"""
+#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:27)
+#loc13 = loc("jit()/jit(main)/pjit"(#loc6))
+module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+  func.func public @main() -> (tensor<4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) {
+    %cst = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc)
+    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
+    %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<f64> loc(#loc)
+    %0 = stablehlo.iota dim = 0 : tensor<16xf64> loc(#loc8)
+    %1 = stablehlo.reshape %0 : (tensor<16xf64>) -> tensor<4x4xf64> loc(#loc9)
+    %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64> loc(#loc10)
+    %3 = stablehlo.add %1, %2 : tensor<4x4xf64> loc(#loc11)
+    %4 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f64>) -> tensor<4x4xf64> loc(#loc12)
+    %5 = stablehlo.divide %3, %4 : tensor<4x4xf64> loc(#loc12)
+    %6 = call @tril(%5) : (tensor<4x4xf64>) -> tensor<4x4xf64> loc(#loc13)
+    %7:3 = stablehlo.custom_call @cusolver_syevd_ffi(%6) {mhlo.backend_config = {algorithm = 0 : ui8, lower = true}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor<4xf64>, tensor<i32>) loc(#loc14)
+    %8 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc14)
+    %9 = stablehlo.compare EQ, %7#2, %8, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc14)
+    %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc14)
+    %11 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<4x4xf64> loc(#loc14)
+    %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc14)
+    %13 = stablehlo.select %12, %7#0, %11 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc14)
+    %14 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor<i1>) -> tensor<1xi1> loc(#loc14)
+    %15 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<4xf64> loc(#loc14)
+    %16 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc14)
+    %17 = stablehlo.select %16, %7#1, %15 : tensor<4xi1>, tensor<4xf64> loc(#loc14)
+    return %13, %17 : tensor<4x4xf64>, tensor<4xf64> loc(#loc)
+  } loc(#loc)
+  func.func private @tril(%arg0: tensor<4x4xf64> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit"(#loc6))) -> (tensor<4x4xf64> {mhlo.layout_mode = "default"}) {
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64> loc(#loc)
+    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
+    %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc15)
+    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<4x4xi32> loc(#loc16)
+    %2 = stablehlo.add %0, %1 : tensor<4x4xi32> loc(#loc16)
+    %3 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc15)
+    %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc17)
+    %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<4x4xf64> loc(#loc18)
+    %6 = stablehlo.select %4, %arg0, %5 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc19)
+    return %6 : tensor<4x4xf64> loc(#loc13)
+  } loc(#loc13)
+} loc(#loc)
+#loc = loc(unknown)
+#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:26)
+#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:14)
+#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:34)
+#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:15)
+#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:14)
+#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:11)
+#loc8 = loc("jit()/jit(main)/iota"(#loc1))
+#loc9 = loc("jit()/jit(main)/reshape"(#loc2))
+#loc10 = loc("jit()/jit(main)/transpose"(#loc3))
+#loc11 = loc("jit()/jit(main)/add"(#loc4))
+#loc12 = loc("jit()/jit(main)/div"(#loc5))
+#loc14 = loc("jit()/jit(main)/eigh"(#loc7))
+#loc15 = loc("jit()/jit(main)/jit(tril)/iota"(#loc6))
+#loc16 = loc("jit()/jit(main)/jit(tril)/add"(#loc6))
+#loc17 = loc("jit()/jit(main)/jit(tril)/ge"(#loc6))
+#loc18 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim"(#loc6))
+#loc19 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc6))
+""",
+    
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x01-\x05\x01\x05\x1d\x01\x03\x0b\x03\x1b\x0f\x13\x17\x1b\x1f#'+/37;?\x03\xfb\xb17\x01U\x0f\x07\x0b\x17\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x17\x0b\x17\x13\x0b\x0b\x17\x03]\x0f\x0b\x0b\x0b\x0b\x0f\x0b\x1f\x0f\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b/\x0f\x0b//O\x1b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x17/\x0f\x0bO/\x01\x05\x0b\x0f\x033\x17\x0f\x0f\x07\x07\x07\x13\x17\x07\x07\x17\x13\x17\x17\x13\x13\x07\x13\x13\x13\x0f\x17\x13\x13\x13\x02\xde\x06\x1dQS\x1f\x05!\x17\x05J\x047\x1d\x1f\x07\x11\x03\x05\x1d!\x07\x1d#\x07\x1dIK\x03\x07\x15\x17\x19\x0b\x1b\x0b\x05#\x11\x01\x00\x05%\x05'\x05)\x05+\x05-\x05/\x1d'\x07\x051\x1d+\x07\x053\x1d/\x07\x055\x1d35\x057\x17\x05*\x045\x1d9;\x059\x17\x05*\x04\x1d\x1d?A\x05;\x17\x052\x04E\x1dEG\x05=\x17\x052\x04\x1f\x05?\x17\x052\x04\x1d\x03\x03O\x8d\x05A\x05C\x17\x05J\x04\x17\x1f!\x01\x1dE\x1dG\x03\x01\x1dI\x03\x03{\x1dK\x1f\t\t\x00\x00\x00\x00\x13\x0b\x01\t\x07\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\x1d\x03\x05os\r\x05]qWY\x1dM\r\x05]uWY\x1dO\x1dQ\x1dS\r\x03WY#\x1f\x1dU\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x13\x0b\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00@\x1f\x1b!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x05\x8f\x91\x93\x95\x1dW\x13%\x00\x1dY\x05\x03\x0b\x03\x1d[\x1d]\x05\x01\x03\x03i\x03\x03\xa3\x15\x03\x01\x01\x01\x03\x07i\xa7\xa9\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f+\x01\x07\x01\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\x0f)\x01\x0f)\x01\x17\x1d\x01\x0b)\x03\x11\x0f)\x05\x11\x11\x17\x13\x1b)\x05\x11\x11\r)\x03\t\x0b\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\x0b)\x03A\x0f!)\x03\t\x15)\x03\x05\x15)\x03\x01\x15)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03\x11\r)\x03\x05\x0b\x04b\x04\x05\x01Q\x03\x13\x01\x07\x04:\x04\x03\x01\t\x0bP\x03\x03\x07\x04\xba\x02\x03/Y\x05B\x03\x05\x03\x07\x05B\x03\x07\x03\t\x05B\x03\t\x03\x07\x07B1\x0b\x03#\x13\x067\x03\x05\x03\x07\x15F=\r\x03\x05\x03\t\r\x06C\x03\x05\x05\t\x0b\x03F\x11\x0f\x03\x05\x03\x05\x17\x06\x11\x03\x05\x05\r\x0f\x19F\t\x11\x03\x05\x03\x11\x1bG\x01M\x13\x07\x05\x11\t\x03\x13\x03F\x01\x0f\x03\t\x03\x03\x0fF\x01\x15\x03-\x05\x19\x1b\x03F\x01\x0f\x03/\x03\x1d\x03F\x01\x0f\x03\x05\x03\x01\x03F\x01\x17\x03\x19\x03\x1f\t\x06\x01\x03\x05\x07#\x15!\x03F\x01\x0f\x031\x03\x1d\x03F\x01\x0f\x03\x11\x03\x01\x03F\x01\x19\x033\x03'\t\x06\x01\x03\x11\x07+\x17)\x11\x04\x03\x05%-\x0bP\t\x1b\x07\x04\x9d\x03\x15+\x03\x0b\t\x00\x05B\x03\x1d\x03\x07\x05B\x03\x07\x03\t\x07B\r\x0b\x03\x13\x03F\x0f\x0f\x03\x13\x03\x05\r\x06\x0f\x03\x13\x05\x07\t\x07B\r\x1f\x03\x13\x0fF%!\x03\x19\x05\x0b\r\x03F)\x0f\x03\x05\x03\x03\t\x06-\x03\x05\x07\x0f\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\xe6\r_'\x03\r\x15\x11\x0f\x0b\t\t\x0b!\x11#;)99EA;WgKMO;\x1b%)9i\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit\x00jit()/jit(main)/jit(tril)/iota\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(ma
in)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota\x00jit()/jit(main)/reshape\x00jit()/jit(main)/transpose\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00mhlo.backend_config\x00jit()/jit(main)/eigh\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00algorithm\x00lower\x00\x00cusolver_syevd_ffi\x00\x08k#\x05;\x01\x0b[kmwy\x03\x87\x03c\x03\x89\x03e\x03\x8b\x03U\x03a\x11\x97\x99\x9b[\x9d\x9f\xa1\xa5\x05g\xab\x03\xad\x03\xaf\x0b_}_a\x7f\x03\x81\x03\x83\x05g\x85",
+    xla_call_module_version=9,
+    nr_devices=1,
+)  # End paste
+
+data_2024_09_30["c64"] = dict(
+    testdata_version=1,
+    platform='cuda',
+    custom_call_targets=['cusolver_syevd_ffi'],
+    serialized_date=datetime.date(2024, 9, 30),
+    inputs=(),
+    expected_outputs=(array([[ 0.79411864 +0.j,  0.3696443  +0.j,  0.40418214 +0.j,
+        -0.26339263 +0.j],
+       [ 0.3696443  +0.j, -0.7941186  +0.j, -0.26339272 +0.j,
+        -0.40418193 +0.j],
+       [-0.054829765+0.j,  0.47930422 +0.j, -0.6857606  +0.j,
+        -0.5449713  +0.j],
+       [-0.47930422 +0.j, -0.054829985+0.j,  0.5449712  +0.j,
+        -0.6857606  +0.j]], dtype=complex64), array([-3.7082872e+00, -2.9983883e-07,  3.5983098e-07,  3.3708286e+01],
+      dtype=float32)),
+    mlir_module_text=r"""
+#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:27)
+#loc18 = loc("jit()/jit(main)/pjit"(#loc7))
+module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+  func.func public @main() -> (tensor<4x4xcomplex<f32>> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) {
+    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc)
+    %cst_0 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>> loc(#loc)
+    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
+    %cst_1 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<complex<f32>> loc(#loc)
+    %0 = stablehlo.iota dim = 0 : tensor<16xcomplex<f32>> loc(#loc9)
+    %1 = stablehlo.reshape %0 : (tensor<16xcomplex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc10)
+    %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x4xcomplex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc11)
+    %3 = stablehlo.real %2 : (tensor<4x4xcomplex<f32>>) -> tensor<4x4xf32> loc(#loc12)
+    %4 = stablehlo.imag %2 : (tensor<4x4xcomplex<f32>>) -> tensor<4x4xf32> loc(#loc13)
+    %5 = stablehlo.negate %4 : tensor<4x4xf32> loc(#loc14)
+    %6 = stablehlo.complex %3, %5 : tensor<4x4xcomplex<f32>> loc(#loc15)
+    %7 = stablehlo.add %1, %6 : tensor<4x4xcomplex<f32>> loc(#loc16)
+    %8 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<complex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc17)
+    %9 = stablehlo.divide %7, %8 : tensor<4x4xcomplex<f32>> loc(#loc17)
+    %10 = call @tril(%9) : (tensor<4x4xcomplex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc18)
+    %11:3 = stablehlo.custom_call @cusolver_syevd_ffi(%10) {mhlo.backend_config = {algorithm = 0 : ui8, lower = true}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex<f32>>) -> (tensor<4x4xcomplex<f32>>, tensor<4xf32>, tensor<i32>) loc(#loc19)
+    %12 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc19)
+    %13 = stablehlo.compare EQ, %11#2, %12, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc19)
+    %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc19)
+    %15 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<complex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc19)
+    %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc19)
+    %17 = stablehlo.select %16, %11#0, %15 : tensor<4x4xi1>, tensor<4x4xcomplex<f32>> loc(#loc19)
+    %18 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i1>) -> tensor<1xi1> loc(#loc19)
+    %19 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<4xf32> loc(#loc19)
+    %20 = stablehlo.broadcast_in_dim %18, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc19)
+    %21 = stablehlo.select %20, %11#1, %19 : tensor<4xi1>, tensor<4xf32> loc(#loc19)
+    return %17, %21 : tensor<4x4xcomplex<f32>>, tensor<4xf32> loc(#loc)
+  } loc(#loc)
+  func.func private @tril(%arg0: tensor<4x4xcomplex<f32>> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit"(#loc7))) -> (tensor<4x4xcomplex<f32>> {mhlo.layout_mode = "default"}) {
+    %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>> loc(#loc)
+    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
+    %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc20)
+    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<4x4xi32> loc(#loc21)
+    %2 = stablehlo.add %0, %1 : tensor<4x4xi32> loc(#loc21)
+    %3 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc20)
+    %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc22)
+    %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc23)
+    %6 = stablehlo.select %4, %arg0, %5 : tensor<4x4xi1>, tensor<4x4xcomplex<f32>> loc(#loc24)
+    return %6 : tensor<4x4xcomplex<f32>> loc(#loc18)
+  } loc(#loc18)
+} loc(#loc)
+#loc = loc(unknown)
+#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:26)
+#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:14)
+#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:34)
+#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:25)
+#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:15)
+#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:14)
+#loc8 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:11)
+#loc9 = loc("jit()/jit(main)/iota"(#loc1))
+#loc10 = loc("jit()/jit(main)/reshape"(#loc2))
+#loc11 = loc("jit()/jit(main)/transpose"(#loc3))
+#loc12 = loc("jit()/jit(main)/real"(#loc4))
+#loc13 = loc("jit()/jit(main)/imag"(#loc4))
+#loc14 = loc("jit()/jit(main)/neg"(#loc4))
+#loc15 = loc("jit()/jit(main)/complex"(#loc4))
+#loc16 = loc("jit()/jit(main)/add"(#loc5))
+#loc17 = loc("jit()/jit(main)/div"(#loc6))
+#loc19 = loc("jit()/jit(main)/eigh"(#loc8))
+#loc20 = loc("jit()/jit(main)/jit(tril)/iota"(#loc7))
+#loc21 = loc("jit()/jit(main)/jit(tril)/add"(#loc7))
+#loc22 = loc("jit()/jit(main)/jit(tril)/ge"(#loc7))
+#loc23 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim"(#loc7))
+#loc24 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc7))
+""",
+    
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x015\x05\x01\x05%\x01\x03\x0b\x03#\x0f\x13\x17\x1b\x1f#'+/37;?CGKO\x03*\x02\xc5=\x01g\x0f\x07\x0b\x17\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x17\x13\x0b\x0b\x17\x03_\x0f\x0b\x0b\x0b\x0b\x0f\x0b\x1f\x0f\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b/\x0f\x0b\x1f//O\x1b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x17/\x0f\x0bO/\x01\x05\x0b\x0f\x039\x17\x0f\x0f\x07\x07\x07\x13\x17\x0b\x17\x07\x07\x17\x0f\x13\x17\x17\x13\x13\x07\x13\x13\x13\x0f\x17\x13\x13\x13\x02\x86\x07\x1dce\x1f\x05)\x17\x05J\x047\x1d!\x07\x17\x052\x043\x11\x03\x05\x1d#\x07\x1d%\x07\x1d[]\x03\x07\x17\x19\x1b\r\x1d\r\x05+\x11\x01\x00\x05-\x05/\x051\x053\x055\x057\x1d)\x07\x059\x1d-\x07\x05;\x1d1\x07\x05=\x1d57\x05?\x17\x05*\x045\x1d;=\x05A\x17\x05*\x04\x1d\x1dAC\x05C\x17\x052\x04E\x1dG\x0b\x05E\x1dK\x0b\x05G\x1dO\x0b\x05I\x1dS\x0b\x05K\x1dWY\x05M\x17\x052\x04\x1f\x05O\x17\x052\x04\x1d\x03\x03a\xa1\x05Q\x05S\x17\x05J\x04\x17\x1f'\x01\x1dU\x1dW\x03\x01\x1dY\x03\x03\x8d\x1d[\x1f\t\t\x00\x00\x00\x00\x13\x0b\x01\t\x07\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00##\x03\x05\x81\x85\r\x05o\x83ik\x1d]\r\x05o\x87ik\x1d_\x1da\x1dc\r\x03ik#%\x1de\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x13\x0b\x05\x07\x05\x1f\x1f\t\x00\x00\xc0\x7f\x1f\x07\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x07\x11\x00\x00\x00@\x00\x00\x00\x00\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x05\xa3\xa5\xa7\xa9\x1dg\x13+\x00\x1di\x05\x03\x0b\x03\x1dk\x1dm\x05\x01\x03\x03{\x03\x03\xb7\x15\x03\x01\x01\x01\x03\x07{\xbb\xbd\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f1\x01\x07\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\x15)\x01\x15)\x01\x1b\x1d\x01\t)\x03\x11\x0f)\x05\x11\x11\x1b\x03\x0f)\x05\x11\x11\x0f\x13\x1b)\x05\x11\x11\r)\x01\x0f)\x03\t\x0b\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\x0b)\x03A\x15!)\x03\t\x19)\x03\x05\x19)\x03\x01\x19)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03\x11\r)\x03\x05\x0b\x04\xee\x04\x05\x01Q\x03\x15\x01\x07\x04\xc6\x04\x03\x01\t\x0bP\x03\x03\x07\x04F\x03\x039m\x05B\x03\x05\x03\x1f\x05B\x03\x07\x03\x07\x05B\x03\t\x03\t\x05B\x03\x0b\x03\x07\x07B3\r\x03)\x13\x069\x03\x05\x03\t\x15F?\x0f\x03\x05\x03\x0b\x17\x06E\x03\x17\x03\r\x19\x06I\x03\x17\x03\r\x1b\x06M\x03\x17\x03\x11\x1d\x06Q\x03\x05\x05\x0f\x13\r\x06U\x03\x05\x05\x0b\x15\x03F\x13\x11\x03\x05\x03\x07\x1f\x06\x13\x03\x05\x05\x17\x19!F\t\x13\x03\x05\x03\x1b#G\x01_\x15\x07\x05\x11\t\x03\x1d\x03F\x01\x11\x03\t\x03\x05\x0fF\x01\x17\x033\x05#%\x03F\x01\x11\x035\x03'\x03F\x01\x11\x03\x05\x03\x03\x03F\x01\x19\x03\x1d\x03)\t\x06\x01\x03\x05\x07-\x1f+\x03F\x01\x11\x037\x03'\x03F\x01\x11\x03\x11\x03\x01\x03F\x01\x1b\x039\x031\t\x06\x01\x03\x11\x075!3\x11\x04\x03\x05/7\x0bP\t\x1d\x07\x04\x9d\x03\x15+\x03\x0b\t\x00\x05B\x03\x1f\x03\x07\x05B\x03\t\x03\t\x07B\x0f\r\x03\x13\x03F\x11\x11\x03\x13\x03\x05\r\x06\x11\x03\x13\x05\x07\t\x07B\x0f!\x03\x13\x0fF'#\x03\x1d\x05\x0b\r\x03F+\x11\x03\x05\x03\x03\t\x06/\x03\x05\x07\x0f\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00r\x10o'\x03\r\x15\x11\x0f\x0b\t\t\x0b!\x11#;)99A9;;EA;WgKMO;\x1b%)9i\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1
\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit\x00jit()/jit(main)/jit(tril)/iota\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota\x00jit()/jit(main)/reshape\x00jit()/jit(main)/transpose\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00mhlo.backend_config\x00jit()/jit(main)/eigh\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00algorithm\x00lower\x00\x00cusolver_syevd_ffi\x00\x08o%\x05?\x01\x0bm}\x7f\x89\x8b\x03\x99\x03\x9b\x03u\x03\x9d\x03w\x03\x9f\x03g\x03s\x11\xab\xad\xafm\xb1\xb3\xb5\xb9\x05y\xbf\x03\xc1\x03\xc3\x0bq\x8fqs\x91\x03\x93\x03\x95\x05y\x97",
+    xla_call_module_version=9,
+    nr_devices=1,
+)  # End paste
+
+data_2024_09_30["c128"] = dict(
+    testdata_version=1,
+    platform='cuda',
+    custom_call_targets=['cusolver_syevd_ffi'],
+    serialized_date=datetime.date(2024, 9, 30),
+    inputs=(),
+    expected_outputs=(array([[ 0.7941185704969035 +0.j,  0.3696443397434604 +0.j,
+         0.4041819665640972 +0.j, -0.2633926650306618 +0.j],
+       [ 0.3696443397434601 +0.j, -0.7941185704969035 +0.j,
+        -0.2633926650306616 +0.j, -0.4041819665640975 +0.j],
+       [-0.05482989100998286+0.j,  0.4793041217634256 +0.j,
+        -0.6857605696309689 +0.j, -0.5449712680975332 +0.j],
+       [-0.47930412176342574+0.j, -0.05482989100998264+0.j,
+         0.5449712680975333 +0.j, -0.6857605696309688 +0.j]]), array([-3.7082869338697044e+00,  3.5411017930205070e-16,
+        6.5803628062392796e-16,  3.3708286933869694e+01])),
+    mlir_module_text=r"""
+#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:27)
+#loc18 = loc("jit()/jit(main)/pjit"(#loc7))
+module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+  func.func public @main() -> (tensor<4x4xcomplex<f64>> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) {
+    %cst = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc)
+    %cst_0 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc)
+    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
+    %cst_1 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<complex<f64>> loc(#loc)
+    %0 = stablehlo.iota dim = 0 : tensor<16xcomplex<f64>> loc(#loc9)
+    %1 = stablehlo.reshape %0 : (tensor<16xcomplex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc10)
+    %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x4xcomplex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc11)
+    %3 = stablehlo.real %2 : (tensor<4x4xcomplex<f64>>) -> tensor<4x4xf64> loc(#loc12)
+    %4 = stablehlo.imag %2 : (tensor<4x4xcomplex<f64>>) -> tensor<4x4xf64> loc(#loc13)
+    %5 = stablehlo.negate %4 : tensor<4x4xf64> loc(#loc14)
+    %6 = stablehlo.complex %3, %5 : tensor<4x4xcomplex<f64>> loc(#loc15)
+    %7 = stablehlo.add %1, %6 : tensor<4x4xcomplex<f64>> loc(#loc16)
+    %8 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc17)
+    %9 = stablehlo.divide %7, %8 : tensor<4x4xcomplex<f64>> loc(#loc17)
+    %10 = call @tril(%9) : (tensor<4x4xcomplex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc18)
+    %11:3 = stablehlo.custom_call @cusolver_syevd_ffi(%10) {mhlo.backend_config = {algorithm = 0 : ui8, lower = true}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex<f64>>) -> (tensor<4x4xcomplex<f64>>, tensor<4xf64>, tensor<i32>) loc(#loc19)
+    %12 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc19)
+    %13 = stablehlo.compare EQ, %11#2, %12, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc19)
+    %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc19)
+    %15 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc19)
+    %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc19)
+    %17 = stablehlo.select %16, %11#0, %15 : tensor<4x4xi1>, tensor<4x4xcomplex<f64>> loc(#loc19)
+    %18 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i1>) -> tensor<1xi1> loc(#loc19)
+    %19 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<4xf64> loc(#loc19)
+    %20 = stablehlo.broadcast_in_dim %18, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc19)
+    %21 = stablehlo.select %20, %11#1, %19 : tensor<4xi1>, tensor<4xf64> loc(#loc19)
+    return %17, %21 : tensor<4x4xcomplex<f64>>, tensor<4xf64> loc(#loc)
+  } loc(#loc)
+  func.func private @tril(%arg0: tensor<4x4xcomplex<f64>> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit"(#loc7))) -> (tensor<4x4xcomplex<f64>> {mhlo.layout_mode = "default"}) {
+    %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f64>> loc(#loc)
+    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
+    %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc20)
+    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<4x4xi32> loc(#loc21)
+    %2 = stablehlo.add %0, %1 : tensor<4x4xi32> loc(#loc21)
+    %3 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc20)
+    %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc22)
+    %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc23)
+    %6 = stablehlo.select %4, %arg0, %5 : tensor<4x4xi1>, tensor<4x4xcomplex<f64>> loc(#loc24)
+    return %6 : tensor<4x4xcomplex<f64>> loc(#loc18)
+  } loc(#loc18)
+} loc(#loc)
+#loc = loc(unknown)
+#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:26)
+#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:14)
+#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:34)
+#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:25)
+#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:15)
+#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:14)
+#loc8 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:11)
+#loc9 = loc("jit()/jit(main)/iota"(#loc1))
+#loc10 = loc("jit()/jit(main)/reshape"(#loc2))
+#loc11 = loc("jit()/jit(main)/transpose"(#loc3))
+#loc12 = loc("jit()/jit(main)/real"(#loc4))
+#loc13 = loc("jit()/jit(main)/imag"(#loc4))
+#loc14 = loc("jit()/jit(main)/neg"(#loc4))
+#loc15 = loc("jit()/jit(main)/complex"(#loc4))
+#loc16 = loc("jit()/jit(main)/add"(#loc5))
+#loc17 = loc("jit()/jit(main)/div"(#loc6))
+#loc19 = loc("jit()/jit(main)/eigh"(#loc8))
+#loc20 = loc("jit()/jit(main)/jit(tril)/iota"(#loc7))
+#loc21 = loc("jit()/jit(main)/jit(tril)/add"(#loc7))
+#loc22 = loc("jit()/jit(main)/jit(tril)/ge"(#loc7))
+#loc23 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim"(#loc7))
+#loc24 = 
loc("jit()/jit(main)/jit(tril)/select_n"(#loc7)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x015\x05\x01\x05%\x01\x03\x0b\x03#\x0f\x13\x17\x1b\x1f#'+/37;?CGKO\x03*\x02\xc5=\x01g\x0f\x07\x0b\x17\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x17\x13\x0b\x0b\x17\x03_\x0f\x0b\x0b\x0b\x0b\x0f\x0b\x1f\x0f\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0bO\x0f\x0b/OOO\x1b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x17/\x0f\x0bO/\x01\x05\x0b\x0f\x039\x17\x0f\x0f\x07\x07\x07\x13\x17\x0b\x17\x07\x07\x17\x0f\x13\x17\x17\x13\x13\x07\x13\x13\x13\x0f\x17\x13\x13\x13\x02\xf6\x07\x1dce\x1f\x05)\x17\x05J\x047\x1d!\x07\x17\x052\x043\x11\x03\x05\x1d#\x07\x1d%\x07\x1d[]\x03\x07\x17\x19\x1b\r\x1d\r\x05+\x11\x01\x00\x05-\x05/\x051\x053\x055\x057\x1d)\x07\x059\x1d-\x07\x05;\x1d1\x07\x05=\x1d57\x05?\x17\x05*\x045\x1d;=\x05A\x17\x05*\x04\x1d\x1dAC\x05C\x17\x052\x04E\x1dG\x0b\x05E\x1dK\x0b\x05G\x1dO\x0b\x05I\x1dS\x0b\x05K\x1dWY\x05M\x17\x052\x04\x1f\x05O\x17\x052\x04\x1d\x03\x03a\xa1\x05Q\x05S\x17\x05J\x04\x17\x1f'\x01\x1dU\x1dW\x03\x01\x1dY\x03\x03\x8d\x1d[\x1f\t\t\x00\x00\x00\x00\x13\x0b\x01\t\x07\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00##\x03\x05\x81\x85\r\x05o\x83ik\x1d]\r\x05o\x87ik\x1d_\x1da\x1dc\r\x03ik#%\x1de\x1f\x07!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x13\x0b\x05\x07\x05\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x07!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x07!\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x05\xa3\xa5\xa7\xa9\x1dg\x13+\x00\x1di\x05\x03\x0b\x03\x1dk\x1dm\x05\x01\x03\x03{\x03\x03\xb7\x15\x03\x01\x01\x01\x03\x07{\xbb\xbd\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f1\x01\x07\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\x15)\x01\x15)\x01\x1b\x1d\x01\x0b)\x03\x11\x0f)\x05\x11\x11\x1b\x03\x0f)\x05\x11\x11\x0f\x13\x1b)\x05\x11\x11\r)\x01\x0f)\x03\t\x0b\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\x0b)\x03A\x15!)\x03\t\x19)\x03\x05\x19)\x03\x01\x19)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03\x11\r)\x03\x05\x0b\x04\xee\x04\x05\x01Q\x03\x15\x01\x07\x04\xc6\x04\x03\x01\t\x0bP\x03\x03\x07\x04F\x03\x039m\x05B\x03\x05\x03\x1f\x05B\x03\x07\x03\x07\x05B\x03\t\x03\t\x05B\x03\x0b\x03\x07\x07B3\r\x03)\x13\x069\x03\x05\x03\t\x15F?\x0f\x03\x05\x03\x0b\x17\x06E\x03\x17\x03\r\x19\x06I\x03\x17\x03\r\x1b\x06M\x03\x17\x03\x11\x1d\x06Q\x03\x05\x05\x0f\x13\r\x06U\x03\x05\x05\x0b\x15\x03F\x13\x11\x03\x05\x03\x07\x1f\x06\x13\x03\x05\x05\x17\x19!F\t\x13\x03\x05\x03\x1b#G\x01_\x15\x07\x05\x11\t\x03\x1d\x03F\x01\x11\x03\t\x03\x05\x0fF\x01\x17\x033\x05#%\x03F\x01\x11\x035\x03'\x03F\x01\x11\x03\x05\x03\x03\x03F\x01\x19\x03\x1d\x03)\t\x06\x01\x03\x05\x07-\x1f+\x03F\x01\x11\x037\x03'\x03F\x01\x11\x03\x11\x03\x01\x03F\x01\x1b\x039\x031\t\x06\x01\x03\x11\x075!3\x11\x04\x03\x05/7\x0bP\t\x1d\x07\x04\x9d\x03\x15+\x03\x0b\t\x00\x05B\x03\x1f\x03\x07\x05B\x03\t\x03\t\x07B\x0f\r\x03\x13\x03F\x11\x11\x03\x13\x03\x05\r\x06\x11\x03\x13\x05\x07\t\x07B\x0f!\x03\x13\x0fF'#\x03\x1d\x05\x0b\r\x03F+\x11\x03\x05\x03\x03\t\x06/\x03\x05\x07\x0f\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00r\x10o'\x03\r\x15\x11\x0f\x0b\t\t\x0b!\x11#;)99A9;;EA;WgKMO;\x1b%)9i\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11buil
tin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit\x00jit()/jit(main)/jit(tril)/iota\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota\x00jit()/jit(main)/reshape\x00jit()/jit(main)/transpose\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00mhlo.backend_config\x00jit()/jit(main)/eigh\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00algorithm\x00lower\x00\x00cusolver_syevd_ffi\x00\x08o%\x05?\x01\x0bm}\x7f\x89\x8b\x03\x99\x03\x9b\x03u\x03\x9d\x03w\x03\x9f\x03g\x03s\x11\xab\xad\xafm\xb1\xb3\xb5\xb9\x05y\xbf\x03\xc1\x03\xc3\x0bq\x8fqs\x91\x03\x93\x03\x95\x05y\x97", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index ec2dd91b258a..c7a1a599f298 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -873,7 +873,7 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index): if isinstance(operand, ShapedArray): if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]: raise ValueError( - "Argument to symmetric eigendecomposition must have shape [..., n, n]," + "Argument to symmetric eigendecomposition must have shape [..., n, n], " "got shape {}".format(operand.shape)) batch_dims = operand.shape[:-2] @@ -894,33 +894,39 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index): def _eigh_cpu_gpu_lowering( - syevd_impl, ctx, operand, *, lower, sort_eigenvalues, subset_by_index, - platform=None + ctx, operand, *, lower, sort_eigenvalues, subset_by_index, + target_name_prefix: str ): del sort_eigenvalues # The CPU/GPU implementations always sort. operand_aval, = ctx.avals_in v_aval, w_aval = ctx.avals_out n = operand_aval.shape[-1] - batch_dims = operand_aval.shape[:-2] - - # The eigh implementation on CPU and GPU uses lapack helper routines to - # find the size of the workspace based on the non-batch dimensions. - # Therefore, we cannot yet support dynamic non-batch dimensions. 
- if not is_constant_shape(operand_aval.shape[-2:]): - raise NotImplementedError( - "Shape polymorphism for native lowering for eigh is implemented " - f"only for the batch dimensions: {operand_aval.shape}") - if not (subset_by_index is None or subset_by_index == (0, n)): - raise NotImplementedError("subset_by_index not implemented for CPU and GPU") + raise NotImplementedError("subset_by_index not supported on CPU and GPU") + batch_dims = operand_aval.shape[:-2] + nb = len(batch_dims) + layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1)) + result_layouts = [layout, tuple(range(nb, -1, -1)), + tuple(range(nb - 1, -1, -1))] + if target_name_prefix == "cpu": + dtype = operand_aval.dtype + prefix = "he" if dtypes.issubdtype(dtype, np.complexfloating) else "sy" + target_name = lapack.prepare_lapack_call(f"{prefix}evd_ffi", + operand_aval.dtype) + kwargs = { + "mode": np.uint8(ord("V")), + "uplo": np.uint8(ord("L" if lower else "U")), + } + else: + target_name = f"{target_name_prefix}solver_syevd_ffi" + kwargs = {"lower": lower, "algorithm": np.uint8(0)} - op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - cpu_args = [] - if platform == "cpu": - ctx_args = (ctx,) - cpu_args.extend(ctx_args) - v, w, info = syevd_impl(*cpu_args, operand_aval.dtype, operand, - a_shape_vals=op_shape_vals, lower=lower) + rule = ffi.ffi_lowering(target_name, operand_layouts=[layout], + result_layouts=result_layouts, + operand_output_aliases={0: 0}) + info_aval = ShapedArray(batch_dims, np.dtype(np.int32)) + sub_ctx = ctx.replace(avals_out=[v_aval, w_aval, info_aval]) + v, w, info = rule(sub_ctx, operand, **kwargs) zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED") @@ -1054,17 +1060,15 @@ def _eigh_batching_rule( batching.primitive_batchers[eigh_p] = _eigh_batching_rule mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_hlo, platform='cpu'), + eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='cpu'), platform='cpu') if gpu_solver is not None: mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.cuda_syevd, - platform='cuda'), + eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='cu'), platform='cuda') mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.rocm_syevd, - platform='rocm'), + eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='hip'), platform='rocm') mlir.register_lowering( diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 68a95521ab69..457d9f59d210 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from collections.abc import Sequence from functools import partial import importlib import math @@ -24,9 +23,7 @@ from jaxlib import xla_client -from .hlo_helpers import ( - DimensionSize, ShapeTypePair, mk_result_types_and_shapes, - custom_call, ensure_hlo_s32, hlo_s32, dense_int_array) +from .hlo_helpers import custom_call, dense_int_array try: from .cuda import _blas as _cublas # pytype: disable=import-error @@ -122,80 +119,6 @@ def _csrlsvqr_hlo(platform, gpu_solver, dtype, data, cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver) -def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, *, - a_shape_vals: tuple[DimensionSize, ...], lower=False): - """Symmetric (Hermitian) eigendecomposition.""" - a_type = ir.RankedTensorType(a.type) - assert len(a_shape_vals) >= 2 - m, n = a_shape_vals[-2:] - assert type(m) is int and type(n) is int and m == n, a_shape_vals - batch_dims_vals = a_shape_vals[:-2] - - num_bd = len(batch_dims_vals) - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - - dynamic_batch_dims = any(type(d) != int for d in batch_dims_vals) - if dynamic_batch_dims: - batch_int = -1 # Signals to the kernel that the batch is an operand. - else: - batch_int = math.prod(batch_dims_vals) - - if have_jacobi_solver and n <= 32 and not dynamic_batch_dims: - # We cannot use syevj for dynamic shapes because the workspace size - # depends on the batch size. - kernel = f"{platform}solver_syevj" - lwork, opaque = gpu_solver.build_syevj_descriptor( - np.dtype(dtype), lower, batch_int, n) - else: - kernel = f"{platform}solver_syevd" - lwork, opaque = gpu_solver.build_syevd_descriptor( - np.dtype(dtype), lower, batch_int, n) - # TODO(Ruturaj4): Currently, hipsolverSsyevd sets lwork to 0 if n==0. - # Remove if this behavior changes in then new ROCm release. 
- if n > 0 or platform != "hip": - assert lwork > 0 - - if ir.ComplexType.isinstance(a_type.element_type): - eigvals_type = ir.ComplexType(a_type.element_type).element_type - else: - eigvals_type = a_type.element_type - - i32_type = ir.IntegerType.get_signless(32) - operands = [a] - operand_layouts = [layout] - if dynamic_batch_dims: - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - operands.append(batch_size_val) - operand_layouts.append(()) - - shape_type_pairs: Sequence[ShapeTypePair] = [ - (a_shape_vals, a_type.element_type), - (batch_dims_vals + (n,), eigvals_type), - (batch_dims_vals, i32_type), - ([lwork], a_type.element_type)] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - out = custom_call( - kernel, - result_types=result_types, - operands=operands, - backend_config=opaque, - operand_layouts=operand_layouts, - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - [0], - ], - operand_output_aliases={0: 0}, - result_shapes=result_shapes).results - return out[:3] - -cuda_syevd = partial(_syevd_hlo, "cu", _cusolver, True) -rocm_syevd = partial(_syevd_hlo, "hip", _hipsolver, True) - - def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, full_matrices=True, compute_uv=True): """Singular value decomposition.""" diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index 4acffb9ba16b..9b3acf641db2 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -340,120 +340,6 @@ def gesdd_hlo(ctx, dtype, a: ir.Value, *, full_matrices=True, compute_uv=True, ).results[1:] -# # syevd: Symmetric eigendecomposition - -def syevd_hlo(ctx, dtype, a: ir.Value, - a_shape_vals: tuple[DimensionSize, ...], - lower=False): - a_type = ir.RankedTensorType(a.type) - assert len(a_shape_vals) >= 2 - m, n = a_shape_vals[-2:] - # Non-batch dimensions must be static - assert type(m) is int and type(n) is int and m == n, a_shape_vals - - batch_dims_vals = a_shape_vals[:-2] - num_bd = len(a_shape_vals) - 2 - mode = _enum_to_char_attr(eig.ComputationMode.kComputeEigenvectors) - - i32_type = ir.IntegerType.get_signless(32) - workspace: list[ShapeTypePair] - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - # Hermitian is for complex square matrices, symmetric otherwise. 
- fn_base = "he" if dtype == np.complex64 or dtype == np.complex128 else "sy" - fn_base = prepare_lapack_call(fn_base=fn_base + "evd", dtype=dtype) - if ctx.is_forward_compat(): - fn = fn_base - if dtype == np.float32: - eigvals_type = ir.F32Type.get() - workspace = [ - ([_lapack.syevd_work_size(n)], a_type.element_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - elif dtype == np.float64: - eigvals_type = ir.F64Type.get() - workspace = [ - ([_lapack.syevd_work_size(n)], a_type.element_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - elif dtype == np.complex64: - eigvals_type = ir.F32Type.get() - workspace = [ - ([_lapack.heevd_work_size(n)], a_type.element_type), - ([_lapack.heevd_rwork_size(n)], eigvals_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - elif dtype == np.complex128: - eigvals_type = ir.F64Type.get() - workspace = [ - ([_lapack.heevd_work_size(n)], a_type.element_type), - ([_lapack.heevd_rwork_size(n)], eigvals_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - - scalar_layout = [] - shape_layout = [0] - workspace_layouts = [shape_layout] * len(workspace) - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - - result_types, result_shapes = mk_result_types_and_shapes( - [(a_shape_vals, a_type.element_type), - (batch_dims_vals + (n,), eigvals_type), - (batch_dims_vals, i32_type)] + workspace - ) - - return custom_call( - fn, - result_types=result_types, - operands=[hlo_s32(1 if lower else 0), batch_size_val, ensure_hlo_s32(n), a], - operand_layouts=[scalar_layout] * 3 + [layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - ] + workspace_layouts, - operand_output_aliases={3: 0}, - result_shapes=result_shapes, - ).results[:3] - fn = fn_base + "_ffi" - if dtype == np.float32 or dtype == np.complex64: - eigvals_type = ir.F32Type.get() - elif dtype == np.float64 or dtype == np.complex128: - eigvals_type = ir.F64Type.get() - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - - result_types, result_shapes = mk_result_types_and_shapes([ - (a_shape_vals, a_type.element_type), - (batch_dims_vals + (n,), eigvals_type), - (batch_dims_vals, i32_type), - ]) - - return custom_call( - fn, - result_types=result_types, - operands=[a], - operand_layouts=[layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - ], - operand_output_aliases={0: 0}, - result_shapes=result_shapes, - backend_config={ - "uplo": _matrix_uplo_attr(lower=lower), - "mode": mode, - }, - api_version=4, - ).results - - # # geev: Nonsymmetric eigendecomposition (eig) def geev_hlo(ctx, dtype, input, *, diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index e261e1dfce83..7daa20cb159e 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -134,7 +134,7 @@ def test_custom_call_coverage(self): cuda_lu_pivots_to_permutation.data_2024_08_08, cuda_lu_cusolver_getrf.data_2024_08_19, cuda_qr_cusolver_geqrf.data_2024_09_26, - cuda_eigh_cusolver_syev.data_2023_03_17, + cuda_eigh_cusolver_syev.data_2024_09_30, rocm_qr_hipsolver_geqrf.data_2024_08_05, rocm_eigh_hipsolver_syev.data_2024_08_05, cpu_schur_lapack_gees.data_2023_07_16, @@ -165,7 +165,7 @@ def test_custom_call_coverage(self): "__gpu$xla.gpu.triton", # 
tested in pallas/export_back_compat_pallas_test.py # The following require ROCm to test "hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi", - "hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi", + "hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi", "hipsolver_syevd_ffi", }) not_covered = targets_to_cover.difference(covered_targets) self.assertEmpty(not_covered, @@ -310,16 +310,19 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"): size = 8 operand = CompatTest.eigh_input((size, size), dtype) func = lambda: CompatTest.eigh_harness((8, 8), dtype) - data = self.load_testdata(cpu_eigh_lapack_syev.data_2023_03_17[dtype_name]) rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name] + + info = cpu_eigh_lapack_syev.data_2024_08_19[dtype_name] + data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name]) self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_eigh_results, operand)) - # FFI Kernel test - with config.export_ignore_forward_compatibility(True): - data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_eigh_results, operand)) + + # Legacy custom call test + data = self.load_testdata(cpu_eigh_lapack_syev.data_2023_03_17[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_eigh_results, operand), + expect_current_custom_calls=info["custom_call_targets"]) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_{variant}", @@ -327,17 +330,19 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"): for dtype_name in ("f32", "f64") # We use different custom calls for sizes <= 32 for variant in ["syevj", "syevd"]) - def test_gpu_eigh_solver_syev(self, dtype_name="f32", variant="syevj"): + def test_gpu_eigh_solver_syev_legacy(self, dtype_name="f32", variant="syevj"): if not config.enable_x64.value and dtype_name == "f64": self.skipTest("Test disabled for x32 mode") - if jtu.test_device_matches(["cuda"]): + if jtu.test_device_matches(["rocm"]): + data = self.load_testdata(rocm_eigh_hipsolver_syev.data_2024_08_05[f"{dtype_name}_{variant}"]) + prefix = "hip" + elif jtu.test_device_matches(["cuda"]): if _is_required_cusolver_version_satisfied(11600): # The underlying problem is that this test assumes the workspace size can be # queried from an older version of cuSOLVER and then be used in a newer one. 
self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized") data = self.load_testdata(cuda_eigh_cusolver_syev.data_2023_03_17[f"{dtype_name}_{variant}"]) - elif jtu.test_device_matches(["rocm"]): - data = self.load_testdata(rocm_eigh_hipsolver_syev.data_2024_08_05[f"{dtype_name}_{variant}"]) + prefix = "cu" else: self.skipTest("Unsupported platform") # For lax.linalg.eigh @@ -347,6 +352,26 @@ def test_gpu_eigh_solver_syev(self, dtype_name="f32", variant="syevj"): atol = dict(f32=1e-2, f64=1e-10)[dtype_name] operand = CompatTest.eigh_input((size, size), dtype) func = lambda: CompatTest.eigh_harness((size, size), dtype) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_eigh_results, operand), + expect_current_custom_calls=[f"{prefix}solver_syevd_ffi"]) + + @parameterized.named_parameters( + dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) + for dtype_name in ("f32", "f64", "c64", "c128")) + def test_gpu_eigh_solver_syev(self, dtype_name="f32"): + if not jtu.test_device_matches(["cuda"]): + self.skipTest("Unsupported platform") + if not config.enable_x64.value and dtype_name in ["f64", "c128"]: + self.skipTest("Test disabled for x32 mode") + dtype = dict(f32=np.float32, f64=np.float64, + c64=np.complex64, c128=np.complex128)[dtype_name] + size = 4 + rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] + atol = dict(f32=1e-2, f64=1e-10, c64=1e-2, c128=1e-10)[dtype_name] + operand = CompatTest.eigh_input((size, size), dtype) + data = self.load_testdata(cuda_eigh_cusolver_syev.data_2024_09_30[dtype_name]) + func = lambda: CompatTest.eigh_harness((size, size), dtype) self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_eigh_results, operand)) diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index d61248923efc..7f8aa2524f74 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -2848,6 +2848,34 @@ def test_vmap_error(self): ((2, 3, 4, 5), "b1, b2, m, n"), ] ], + [ + PolyHarness( # pylint: disable=g-complex-comprehension + "eigh", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_poly={poly}_{lower=}", + lambda x, lower: lax.linalg.eigh(x, lower=lower), + arg_descriptors=[RandArg(shape, dtype), StaticArg(lower)], + polymorphic_shapes=[poly]) + for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes() + for lower in [True, False] + for shape, poly in [ + ((4, 4), "n, n"), + ((2, 3, 4, 4), "b1, b2, ..."), + ((2, 3, 4, 4), "b1, b2, n, n"), + ] + ], + [ + PolyHarness( # pylint: disable=g-complex-comprehension + "eigh_shape_error", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_poly={poly}", + lambda x: lax.linalg.eigh(x, symmetrize_input=False), + arg_descriptors=[RandArg(shape, dtype)], + polymorphic_shapes=[poly], + expect_error=(ValueError, "Argument to symmetric eigendecomposition")) + for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes() + for shape, poly in [ + ((4, 5), "m, n"), + ((2, 3, 4, 5), "b1, b2, ..."), + ((2, 3, 4, 5), "b1, b2, m, n"), + ] + ], [ # The random primitive tests, with threefry (both partitionable and # non-partitionable), and unsafe_rbg. 
@@ -3490,13 +3518,6 @@ def test_harness(self, harness: PolyHarness):
     if "nr_fft_lengths_2" in harness.fullname:
       raise unittest.SkipTest("native serialization with shape polymorphism not implemented for fft with non-constant fft_lengths on GPU and TPU")
 
-    if harness.group_name == "vmap_eigh" and jtu.test_device_matches(["gpu"]):
-      # For eigh on GPU with shape polymorphism under native serialization,
-      # we use a different lowering for small matrices.
-      shape = harness.original_harness.params["shape"]
-      if 0 < shape[-1] <= 32:
-        harness.check_result = False
-
     if harness.group_name == "vmap_eigh":
       raise unittest.SkipTest(
           "Should not compare eigendecompositions for equality directly"
@@ -3528,6 +3549,12 @@ def test_harness(self, harness: PolyHarness):
     if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]):
       raise unittest.SkipTest("JAX implements eig only on CPU.")
 
+    if (harness.group_name == "eigh" and
+        not harness.polymorphic_shapes[0].endswith("...") and
+        jtu.test_device_matches(["tpu"])):
+      raise unittest.SkipTest(
+          "Shape polymorphism for Eigh is only supported for batch dimensions on TPU.")
+
     config_flags = harness.override_jax_config_flags
     # Update this here rather than in harness object because vmap_random_gamma is derived
    # from test_harnesses.all_harnesses, which strips override_jax_config_flags.
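
For context, below is a minimal sketch (not part of the patch) of the workflow
this change enables: lowering a shape-polymorphic eigh for CUDA on a machine
without a GPU attached. The symbolic shape string "b, n, n" and the jax.export
calls are illustrative assumptions about the current export API, not code from
this change:

    import jax
    import jax.numpy as jnp
    from jax import export

    # Batch ("b") and matrix ("n") dimensions are all symbolic: with the FFI
    # eigh kernels, the non-batch dimensions no longer need to be static.
    x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, n, n"), jnp.float32)

    # Lowering targets CUDA but needs no GPU at export time; the serialized
    # artifact can be deserialized and run later on a machine with a GPU.
    exp = export.export(jax.jit(jnp.linalg.eigh), platforms=["cuda"])(x_spec)
    blob = exp.serialize()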