Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[array api] return NamedTuple from np.linalg APIs #19347

Merged
merged 1 commit into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 41 additions & 16 deletions jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy as np
import textwrap
import operator
from typing import Literal, cast, overload
from typing import Literal, NamedTuple, cast, overload

import jax
from jax import jit, custom_jvp
Expand All @@ -35,6 +35,27 @@
from jax._src.typing import ArrayLike, Array


class EighResult(NamedTuple):
eigenvalues: jax.Array
eigenvectors: jax.Array


class QRResult(NamedTuple):
Q: jax.Array
R: jax.Array


class SlogdetResult(NamedTuple):
sign: jax.Array
logabsdet: jax.Array


class SVDResult(NamedTuple):
U: jax.Array
S: jax.Array
Vh: jax.Array


def _H(x: ArrayLike) -> Array:
return ufuncs.conjugate(jnp.matrix_transpose(x))

Expand All @@ -51,10 +72,10 @@ def cholesky(a: ArrayLike) -> Array:

@overload
def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[True],
hermitian: bool = False) -> tuple[Array, Array, Array]: ...
hermitian: bool = False) -> SVDResult: ...
@overload
def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[True],
hermitian: bool = False) -> tuple[Array, Array, Array]: ...
hermitian: bool = False) -> SVDResult: ...
@overload
def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False],
hermitian: bool = False) -> Array: ...
Expand All @@ -63,12 +84,12 @@ def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False],
hermitian: bool = False) -> Array: ...
@overload
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False) -> Array | tuple[Array, Array, Array]: ...
hermitian: bool = False) -> Array | SVDResult: ...

@_wraps(np.linalg.svd)
@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False) -> Array | tuple[Array, Array, Array]:
hermitian: bool = False) -> Array | SVDResult:
check_arraylike("jnp.linalg.svd", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
if hermitian:
Expand All @@ -83,11 +104,15 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
sign = lax.rev(sign, dimensions=[s.ndim - 1])
u = jnp.take_along_axis(w, idxs[..., None, :], axis=-1)
vh = _H(u * sign[..., None, :].astype(u.dtype))
return u, s, vh
return SVDResult(u, s, vh)
else:
return lax.rev(lax.sort(s, dimension=-1), dimensions=[s.ndim-1])

return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
if compute_uv:
u, s, vh = lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=True)
return SVDResult(u, s, vh)
else:
return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=False)


@_wraps(np.linalg.matrix_power)
Expand Down Expand Up @@ -195,7 +220,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]:
LU decomposition if ``None``.
"""))
@partial(jit, static_argnames=('method',))
def slogdet(a: ArrayLike, *, method: str | None = None) -> tuple[Array, Array]:
def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult:
check_arraylike("jnp.linalg.slogdet", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
a_shape = jnp.shape(a)
Expand All @@ -204,9 +229,9 @@ def slogdet(a: ArrayLike, *, method: str | None = None) -> tuple[Array, Array]:
raise ValueError(msg.format(a_shape))

if method is None or method == "lu":
return _slogdet_lu(a)
return SlogdetResult(*_slogdet_lu(a))
elif method == "qr":
return _slogdet_qr(a)
return SlogdetResult(*_slogdet_qr(a))
else:
raise ValueError(f"Unknown slogdet method '{method}'. Supported methods "
"are 'lu' (`None`), and 'qr'.")
Expand Down Expand Up @@ -385,7 +410,7 @@ def eigvals(a: ArrayLike) -> Array:
@_wraps(np.linalg.eigh)
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
def eigh(a: ArrayLike, UPLO: str | None = None,
symmetrize_input: bool = True) -> tuple[Array, Array]:
symmetrize_input: bool = True) -> EighResult:
check_arraylike("jnp.linalg.eigh", a)
if UPLO is None or UPLO == "L":
lower = True
Expand All @@ -397,7 +422,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None,

a, = promote_dtypes_inexact(jnp.asarray(a))
v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input)
return w, v
return EighResult(w, v)


@_wraps(np.linalg.eigvalsh)
Expand Down Expand Up @@ -581,16 +606,16 @@ def norm(x: ArrayLike, ord: int | str | None = None,
@overload
def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ...
@overload
def qr(a: ArrayLike, mode: str = "reduced") -> Array | tuple[Array, Array]: ...
def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: ...

@_wraps(np.linalg.qr)
@partial(jit, static_argnames=('mode',))
def qr(a: ArrayLike, mode: str = "reduced") -> Array | tuple[Array, Array]:
def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
check_arraylike("jnp.linalg.qr", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
if mode == "raw":
a, taus = lax_linalg.geqrf(a)
return a.mT, taus
return QRResult(a.mT, taus)
if mode in ("reduced", "r", "full"):
full_matrices = False
elif mode == "complete":
Expand All @@ -600,7 +625,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | tuple[Array, Array]:
q, r = lax_linalg.qr(a, full_matrices=full_matrices)
if mode == "r":
return r
return q, r
return QRResult(q, r)


@_wraps(np.linalg.solve)
Expand Down
32 changes: 4 additions & 28 deletions jax/experimental/array_api/_linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import NamedTuple

import jax
from jax.experimental.array_api._data_type_functions import (
_promote_to_default_dtype,
)

class EighResult(NamedTuple):
eigenvalues: jax.Array
eigenvectors: jax.Array

class QRResult(NamedTuple):
Q: jax.Array
R: jax.Array

class SlogdetResult(NamedTuple):
sign: jax.Array
logabsdet: jax.Array

class SVDResult(NamedTuple):
U: jax.Array
S: jax.Array
Vh: jax.Array

def cholesky(x, /, *, upper=False):
"""
Returns the lower (upper) Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix x.
Expand Down Expand Up @@ -65,8 +45,7 @@ def eigh(x, /):
"""
Returns an eigenvalue decomposition of a complex Hermitian or real symmetric matrix (or a stack of matrices) x.
"""
eigenvalues, eigenvectors = jax.numpy.linalg.eigh(x)
return EighResult(eigenvalues=eigenvalues, eigenvectors=eigenvectors)
return jax.numpy.linalg.eigh(x)

def eigvalsh(x, /):
"""
Expand Down Expand Up @@ -122,15 +101,13 @@ def qr(x, /, *, mode='reduced'):
"""
Returns the QR decomposition of a full column rank matrix (or a stack of matrices).
"""
Q, R = jax.numpy.linalg.qr(x, mode=mode)
return QRResult(Q=Q, R=R)
return jax.numpy.linalg.qr(x, mode=mode)

def slogdet(x, /):
"""
Returns the sign and the natural logarithm of the absolute value of the determinant of a square matrix (or a stack of square matrices) x.
"""
sign, logabsdet = jax.numpy.linalg.slogdet(x)
return SlogdetResult(sign, logabsdet)
return jax.numpy.linalg.slogdet(x)

def solve(x1, x2, /):
"""
Expand All @@ -147,8 +124,7 @@ def svd(x, /, *, full_matrices=True):
"""
Returns a singular value decomposition (SVD) of a matrix (or a stack of matrices) x.
"""
U, S, Vh = jax.numpy.linalg.svd(x, full_matrices=full_matrices)
return SVDResult(U=U, S=S, Vh=Vh)
return jax.numpy.linalg.svd(x, full_matrices=full_matrices)

def svdvals(x, /):
"""
Expand Down