Skip to content

Commit

Permalink
Activate FFI implementation of symmetric Eigendecomposition.
Browse files Browse the repository at this point in the history
These kernels support shape polymorphism in all dimensions, and no GPU is required during lowering. The kernels have been included in jaxlib for more than 3 weeks, so we don't need to include any forward compatibility checks.

PiperOrigin-RevId: 682415506
  • Loading branch information
dfm authored and Google-ML-Automation committed Oct 4, 2024
1 parent 18f48bd commit 67f24df
Show file tree
Hide file tree
Showing 7 changed files with 445 additions and 243 deletions.
8 changes: 3 additions & 5 deletions jax/_src/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,11 +947,6 @@ def _check_lowering(lowering) -> None:
"__gpu$xla.gpu.triton", # Pallas call on GPU
# cholesky on CPU
"lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
# eigh on CPU
"lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd",
# eigh on GPU
"cusolver_syevj", "cusolver_syevd",
"hipsolver_syevj", "hipsolver_syevd",
# eigh on TPU
"Eigh",
# eig on CPU
Expand All @@ -969,9 +964,12 @@ def _check_lowering(lowering) -> None:
# lu on GPU
"cu_lu_pivots_to_permutation", "cusolver_getrf_ffi",
"hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi",
"cu_lu_pivots_to_permutation", "cusolver_getrf_ffi",
# qr on GPU
"cusolver_geqrf_ffi", "cusolver_orgqr_ffi",
"hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi",
# eigh on GPU
"cusolver_syevd_ffi", "hipsolver_syevd_ffi",
# svd on GPU
# lu on TPU
"LuDecomposition",
Expand Down

Large diffs are not rendered by default.

56 changes: 30 additions & 26 deletions jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index):
if isinstance(operand, ShapedArray):
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
raise ValueError(
"Argument to symmetric eigendecomposition must have shape [..., n, n],"
"Argument to symmetric eigendecomposition must have shape [..., n, n], "
"got shape {}".format(operand.shape))

batch_dims = operand.shape[:-2]
Expand All @@ -894,33 +894,39 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index):


def _eigh_cpu_gpu_lowering(
syevd_impl, ctx, operand, *, lower, sort_eigenvalues, subset_by_index,
platform=None
ctx, operand, *, lower, sort_eigenvalues, subset_by_index,
target_name_prefix: str
):
del sort_eigenvalues # The CPU/GPU implementations always sort.
operand_aval, = ctx.avals_in
v_aval, w_aval = ctx.avals_out
n = operand_aval.shape[-1]
batch_dims = operand_aval.shape[:-2]

# The eigh implementation on CPU and GPU uses lapack helper routines to
# find the size of the workspace based on the non-batch dimensions.
# Therefore, we cannot yet support dynamic non-batch dimensions.
if not is_constant_shape(operand_aval.shape[-2:]):
raise NotImplementedError(
"Shape polymorphism for native lowering for eigh is implemented "
f"only for the batch dimensions: {operand_aval.shape}")

if not (subset_by_index is None or subset_by_index == (0, n)):
raise NotImplementedError("subset_by_index not implemented for CPU and GPU")
raise NotImplementedError("subset_by_index not supported on CPU and GPU")
batch_dims = operand_aval.shape[:-2]
nb = len(batch_dims)
layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
result_layouts = [layout, tuple(range(nb, -1, -1)),
tuple(range(nb - 1, -1, -1))]
if target_name_prefix == "cpu":
dtype = operand_aval.dtype
prefix = "he" if dtypes.issubdtype(dtype, np.complexfloating) else "sy"
target_name = lapack.prepare_lapack_call(f"{prefix}evd_ffi",
operand_aval.dtype)
kwargs = {
"mode": np.uint8(ord("V")),
"uplo": np.uint8(ord("L" if lower else "U")),
}
else:
target_name = f"{target_name_prefix}solver_syevd_ffi"
kwargs = {"lower": lower, "algorithm": np.uint8(0)}

op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
cpu_args = []
if platform == "cpu":
ctx_args = (ctx,)
cpu_args.extend(ctx_args)
v, w, info = syevd_impl(*cpu_args, operand_aval.dtype, operand,
a_shape_vals=op_shape_vals, lower=lower)
rule = ffi.ffi_lowering(target_name, operand_layouts=[layout],
result_layouts=result_layouts,
operand_output_aliases={0: 0})
info_aval = ShapedArray(batch_dims, np.dtype(np.int32))
sub_ctx = ctx.replace(avals_out=[v_aval, w_aval, info_aval])
v, w, info = rule(sub_ctx, operand, **kwargs)

zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
Expand Down Expand Up @@ -1054,17 +1060,15 @@ def _eigh_batching_rule(
batching.primitive_batchers[eigh_p] = _eigh_batching_rule

mlir.register_lowering(
eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_hlo, platform='cpu'),
eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='cpu'),
platform='cpu')

if gpu_solver is not None:
mlir.register_lowering(
eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.cuda_syevd,
platform='cuda'),
eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='cu'),
platform='cuda')
mlir.register_lowering(
eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.rocm_syevd,
platform='rocm'),
eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='hip'),
platform='rocm')

mlir.register_lowering(
Expand Down
79 changes: 1 addition & 78 deletions jaxlib/gpu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Sequence
from functools import partial
import importlib
import math
Expand All @@ -24,9 +23,7 @@

from jaxlib import xla_client

from .hlo_helpers import (
DimensionSize, ShapeTypePair, mk_result_types_and_shapes,
custom_call, ensure_hlo_s32, hlo_s32, dense_int_array)
from .hlo_helpers import custom_call, dense_int_array

try:
from .cuda import _blas as _cublas # pytype: disable=import-error
Expand Down Expand Up @@ -122,80 +119,6 @@ def _csrlsvqr_hlo(platform, gpu_solver, dtype, data,
cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver)


def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, *,
a_shape_vals: tuple[DimensionSize, ...], lower=False):
"""Symmetric (Hermitian) eigendecomposition.

Emits a custom call to either the ``{platform}solver_syevj`` or the
``{platform}solver_syevd`` kernel and returns its first three results.

Args:
  platform: Prefix used to build the kernel name (the partials below pass
    "cu" and "hip").
  gpu_solver: Module providing ``build_syevj_descriptor`` and
    ``build_syevd_descriptor``.
  have_jacobi_solver: Whether the ``syevj`` kernel may be selected.
  dtype: Element dtype; converted with ``np.dtype`` for the descriptor.
  a: Input value; its type is read as an ``ir.RankedTensorType``.
  a_shape_vals: Shape of ``a``, at least two entries. The last two must be
    equal Python ints; leading (batch) entries may be non-int values.
  lower: Forwarded to the descriptor builder (presumably selects which
    triangle of the input is read -- confirm against the kernel).

Returns:
  ``out[:3]``: results with shapes ``a_shape_vals``, ``batch + (n,)`` and
  ``batch`` (the last one is int32). The trailing workspace result is
  dropped.
"""
a_type = ir.RankedTensorType(a.type)
assert len(a_shape_vals) >= 2
m, n = a_shape_vals[-2:]
# The two trailing (matrix) dimensions must be static and square.
assert type(m) is int and type(n) is int and m == n, a_shape_vals
batch_dims_vals = a_shape_vals[:-2]

num_bd = len(batch_dims_vals)
# Layout tuple: the two matrix dimensions first, then the batch dimensions
# in reverse order.
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))

# Any batch dimension that is not a plain Python int is treated as dynamic.
dynamic_batch_dims = any(type(d) != int for d in batch_dims_vals)
if dynamic_batch_dims:
batch_int = -1 # Signals to the kernel that the batch is an operand.
else:
batch_int = math.prod(batch_dims_vals)

if have_jacobi_solver and n <= 32 and not dynamic_batch_dims:
# We cannot use syevj for dynamic shapes because the workspace size
# depends on the batch size.
kernel = f"{platform}solver_syevj"
lwork, opaque = gpu_solver.build_syevj_descriptor(
np.dtype(dtype), lower, batch_int, n)
else:
kernel = f"{platform}solver_syevd"
lwork, opaque = gpu_solver.build_syevd_descriptor(
np.dtype(dtype), lower, batch_int, n)
# TODO(Ruturaj4): Currently, hipsolverSsyevd sets lwork to 0 if n==0.
# Remove if this behavior changes in the new ROCm release.
if n > 0 or platform != "hip":
assert lwork > 0

# Eigenvalues use the real component type when the input is complex.
if ir.ComplexType.isinstance(a_type.element_type):
eigvals_type = ir.ComplexType(a_type.element_type).element_type
else:
eigvals_type = a_type.element_type

i32_type = ir.IntegerType.get_signless(32)
operands = [a]
operand_layouts = [layout]
if dynamic_batch_dims:
# Pass the total batch size (product of all batch dimensions, as s32) to
# the kernel as an extra scalar operand.
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
operands.append(batch_size_val)
operand_layouts.append(())

# Result order: output matrix, eigenvalues, per-batch int32 status, and a
# workspace buffer of `lwork` elements.
shape_type_pairs: Sequence[ShapeTypePair] = [
(a_shape_vals, a_type.element_type),
(batch_dims_vals + (n,), eigvals_type),
(batch_dims_vals, i32_type),
([lwork], a_type.element_type)]
result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
out = custom_call(
kernel,
result_types=result_types,
operands=operands,
backend_config=opaque,
operand_layouts=operand_layouts,
result_layouts=[
layout,
tuple(range(num_bd, -1, -1)),
tuple(range(num_bd - 1, -1, -1)),
[0],
],
operand_output_aliases={0: 0},
result_shapes=result_shapes).results
# Drop the workspace result; callers only receive the first three.
return out[:3]

# Platform-specific entry points: bind the kernel-name prefix and solver
# module, with the Jacobi solver enabled for both.
cuda_syevd = partial(_syevd_hlo, "cu", _cusolver, True)
rocm_syevd = partial(_syevd_hlo, "hip", _hipsolver, True)


def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
full_matrices=True, compute_uv=True):
"""Singular value decomposition."""
Expand Down
114 changes: 0 additions & 114 deletions jaxlib/lapack.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,120 +340,6 @@ def gesdd_hlo(ctx, dtype, a: ir.Value, *, full_matrices=True, compute_uv=True,
).results[1:]


# # syevd: Symmetric eigendecomposition

def syevd_hlo(ctx, dtype, a: ir.Value,
a_shape_vals: tuple[DimensionSize, ...],
lower=False):
"""Emits a LAPACK custom call for symmetric/Hermitian eigendecomposition.

Two call forms appear below: one built when ``ctx.is_forward_compat()`` is
true, which passes explicit scalar operands and workspace results, and an
``_ffi`` form built with ``api_version=4`` and a ``backend_config``.
NOTE(review): indentation is not preserved in this view; the
forward-compat branch presumably extends through the first ``return``.

Args:
  ctx: Object providing ``is_forward_compat()``.
  dtype: One of float32, float64, complex64 or complex128.
  a: Input value; its type is read as an ``ir.RankedTensorType``.
  a_shape_vals: Shape of ``a``, at least two entries. The last two must be
    equal Python ints.
  lower: Encoded as an s32 operand (1/0) in the first form and via
    ``_matrix_uplo_attr`` in the ``_ffi`` form.

Returns:
  The custom call results with shapes ``a_shape_vals``, ``batch + (n,)``
  and ``batch`` (the last one is int32); the first form slices off the
  workspace results with ``[:3]``.

Raises:
  NotImplementedError: If ``dtype`` is not one of the four supported dtypes.
"""
a_type = ir.RankedTensorType(a.type)
assert len(a_shape_vals) >= 2
m, n = a_shape_vals[-2:]
# Non-batch dimensions must be static
assert type(m) is int and type(n) is int and m == n, a_shape_vals

batch_dims_vals = a_shape_vals[:-2]
num_bd = len(a_shape_vals) - 2
mode = _enum_to_char_attr(eig.ComputationMode.kComputeEigenvectors)

i32_type = ir.IntegerType.get_signless(32)
workspace: list[ShapeTypePair]
# Layout tuple: the two matrix dimensions first, then the batch dimensions
# in reverse order.
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
# Hermitian is for complex square matrices, symmetric otherwise.
fn_base = "he" if dtype == np.complex64 or dtype == np.complex128 else "sy"
fn_base = prepare_lapack_call(fn_base=fn_base + "evd", dtype=dtype)
if ctx.is_forward_compat():
# This form uses the name from prepare_lapack_call without the "_ffi"
# suffix and declares its workspace buffers as extra results.
fn = fn_base
if dtype == np.float32:
eigvals_type = ir.F32Type.get()
workspace = [
([_lapack.syevd_work_size(n)], a_type.element_type),
([_lapack.syevd_iwork_size(n)], i32_type),
]
elif dtype == np.float64:
eigvals_type = ir.F64Type.get()
workspace = [
([_lapack.syevd_work_size(n)], a_type.element_type),
([_lapack.syevd_iwork_size(n)], i32_type),
]
elif dtype == np.complex64:
# The complex cases declare an additional real-typed workspace entry.
eigvals_type = ir.F32Type.get()
workspace = [
([_lapack.heevd_work_size(n)], a_type.element_type),
([_lapack.heevd_rwork_size(n)], eigvals_type),
([_lapack.syevd_iwork_size(n)], i32_type),
]
elif dtype == np.complex128:
eigvals_type = ir.F64Type.get()
workspace = [
([_lapack.heevd_work_size(n)], a_type.element_type),
([_lapack.heevd_rwork_size(n)], eigvals_type),
([_lapack.syevd_iwork_size(n)], i32_type),
]
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")

# Total batch size: product of all batch dimensions as an s32 value.
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))

scalar_layout = []
shape_layout = [0]
workspace_layouts = [shape_layout] * len(workspace)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))

result_types, result_shapes = mk_result_types_and_shapes(
[(a_shape_vals, a_type.element_type),
(batch_dims_vals + (n,), eigvals_type),
(batch_dims_vals, i32_type)] + workspace
)

# Operands are three scalars (lower flag, batch size, n) followed by the
# matrix; the alias entry pairs operand index 3 with result index 0. The
# workspace results are sliced off by `[:3]`.
return custom_call(
fn,
result_types=result_types,
operands=[hlo_s32(1 if lower else 0), batch_size_val, ensure_hlo_s32(n), a],
operand_layouts=[scalar_layout] * 3 + [layout],
result_layouts=[
layout,
tuple(range(num_bd, -1, -1)),
tuple(range(num_bd - 1, -1, -1)),
] + workspace_layouts,
operand_output_aliases={3: 0},
result_shapes=result_shapes,
).results[:3]
# "_ffi" form: the matrix is the only operand, no workspace results are
# declared, and uplo/mode travel in backend_config.
fn = fn_base + "_ffi"
if dtype == np.float32 or dtype == np.complex64:
eigvals_type = ir.F32Type.get()
elif dtype == np.float64 or dtype == np.complex128:
eigvals_type = ir.F64Type.get()
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")

result_types, result_shapes = mk_result_types_and_shapes([
(a_shape_vals, a_type.element_type),
(batch_dims_vals + (n,), eigvals_type),
(batch_dims_vals, i32_type),
])

return custom_call(
fn,
result_types=result_types,
operands=[a],
operand_layouts=[layout],
result_layouts=[
layout,
tuple(range(num_bd, -1, -1)),
tuple(range(num_bd - 1, -1, -1)),
],
operand_output_aliases={0: 0},
result_shapes=result_shapes,
backend_config={
"uplo": _matrix_uplo_attr(lower=lower),
"mode": mode,
},
api_version=4,
).results


# # geev: Nonsymmetric eigendecomposition (eig)

def geev_hlo(ctx, dtype, input, *,
Expand Down
Loading

0 comments on commit 67f24df

Please sign in to comment.