Skip to content

Commit 44959be

Browse files
committed
gh-405: array API support for glass.core.algorithm
1 parent 72a037d commit 44959be

File tree

1 file changed

+25
-17
lines changed

1 file changed

+25
-17
lines changed

glass/core/algorithm.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,22 @@
22

33
from __future__ import annotations
44

5-
import numpy as np
6-
import numpy.typing as npt
5+
import typing
6+
7+
if typing.TYPE_CHECKING:
8+
import cupy as cp
9+
import jax.typing as jxt
10+
import numpy as np
11+
import numpy.typing as npt
712

813

914
def nnls(
10-
a: npt.NDArray[np.float64],
11-
b: npt.NDArray[np.float64],
15+
a: npt.NDArray[np.float64] | cp.ndarray | jxt.ArrayLike,
16+
b: npt.NDArray[np.float64] | cp.ndarray | jxt.ArrayLike,
1217
*,
1318
tol: float = 0.0,
1419
maxiter: int | None = None,
15-
) -> npt.NDArray[np.float64]:
20+
) -> npt.NDArray[np.float64] | cp.ndarray | jxt.ArrayLike:
1621
"""
1722
Compute a non-negative least squares solution.
1823
@@ -27,8 +32,11 @@ def nnls(
2732
Chemometrics, 11, 393-401.
2833
2934
"""
30-
a = np.asanyarray(a)
31-
b = np.asanyarray(b)
35+
if a.__array_namespace__() != b.__array_namespace__():
36+
msg = "input arrays should belong to the same array library"
37+
raise ValueError(msg)
38+
39+
xp = a.__array_namespace__()
3240

3341
if a.ndim != 2:
3442
msg = "input `a` is not a matrix"
@@ -45,25 +53,25 @@ def nnls(
4553
if maxiter is None:
4654
maxiter = 3 * n
4755

48-
index = np.arange(n)
49-
p = np.full(n, fill_value=False)
50-
x = np.zeros(n)
56+
index = xp.arange(n)
57+
p = xp.full(n, fill_value=False)
58+
x = xp.zeros(n)
5159
for _ in range(maxiter):
52-
if np.all(p):
60+
if xp.all(p):
5361
break
54-
w = np.dot(b - a @ x, a)
55-
m = index[~p][np.argmax(w[~p])]
62+
w = xp.dot(b - a @ x, a)
63+
m = index[~p][xp.argmax(w[~p])]
5664
if w[m] <= tol:
5765
break
5866
p[m] = True
5967
while True:
6068
ap = a[:, p]
61-
xp = x[p]
62-
sp = np.linalg.solve(ap.T @ ap, b @ ap)
69+
x_new = x[p]
70+
sp = xp.linalg.solve(ap.T @ ap, b @ ap)
6371
t = sp <= 0
64-
if not np.any(t):
72+
if not xp.any(t):
6573
break
66-
alpha = -np.min(xp[t] / (xp[t] - sp[t]))
74+
alpha = -xp.min(xp[t] / (x_new[t] - sp[t]))
6775
x[p] += alpha * (sp - xp)
6876
p[x <= 0] = False
6977
x[p] = sp

0 commit comments

Comments
 (0)