Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jnp.linalg.solve: deprecate batched 1D solves when b.ndim > 1 #19674

Merged
merged 1 commit into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ Remember to align the itemized text with the first line of an item within a list
such as {func}`~jax.random.bits`, {func}`~jax.random.gamma`, and others, is deprecated
and will emit a `FutureWarning`. Use `jax.vmap` for explicit batching.
* {func}`jax.lax.tie_in` is deprecated: it has been a no-op since JAX v0.2.0.
* {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D
solves with `b.ndim > 1`. In the future these will be treated as batched 2D
solves.

## jaxlib 0.4.24

Expand Down
13 changes: 10 additions & 3 deletions jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from collections.abc import Sequence
from functools import partial
import warnings

import numpy as np
import textwrap
Expand Down Expand Up @@ -635,9 +636,15 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
def solve(a: ArrayLike, b: ArrayLike) -> Array:
check_arraylike("jnp.linalg.solve", a, b)
a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
# TODO(jakevdp): this condition matches the broadcasting behavior in numpy < 2.0.
# For the array API specification, we would check only if b.ndim == 1.
if b.ndim == 1 or a.ndim == b.ndim + 1:

if b.ndim == 1:
signature = "(m,m),(m)->(m)"
elif a.ndim == b.ndim + 1:
# Deprecation warning added 2024-02-06
warnings.warn("jnp.linalg.solve: batched 1D solves with b.ndim > 1 are deprecated, "
"and in the future will be treated as a batched 2D solve. "
"Use solve(a, b[..., None])[..., 0] to avoid this warning.",
category=FutureWarning)
signature = "(m,m),(m)->(m)"
else:
signature = "(m,m),(m,n)->(m,n)"
Expand Down
2 changes: 2 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,7 @@ def testSolve(self, lhs_shape, rhs_shape, dtype):
lhs_shape=[(2, 2), (2, 2, 2), (2, 2, 2, 2), (2, 2, 2, 2, 2)],
rhs_shape=[(2,), (2, 2), (2, 2, 2), (2, 2, 2, 2)]
)
@jtu.ignore_warning(category=FutureWarning, message="jnp.linalg.solve: batched")
def testSolveBroadcasting(self, lhs_shape, rhs_shape):
# Batched solve can involve some ambiguities; this test checks
# that we match NumPy's convention in all cases.
Expand Down Expand Up @@ -1196,6 +1197,7 @@ def test(x):
self.assertAllClose(xc, grad_test_jc(xc))

@jtu.skip_on_flag("jax_skip_slow_tests", True)
@jtu.ignore_warning(category=FutureWarning, message="jnp.linalg.solve: batched")
def testIssue1151(self):
rng = self.rng()
A = jnp.array(rng.randn(100, 3, 3), dtype=jnp.float32)
Expand Down
Loading