2
2
3
3
from __future__ import annotations
4
4
5
- import numpy as np
6
- import numpy .typing as npt
5
+ import typing
6
+
7
+ if typing .TYPE_CHECKING :
8
+ import cupy as cp
9
+ import jax .typing as jxt
10
+ import numpy as np
11
+ import numpy .typing as npt
7
12
8
13
9
14
def nnls (
10
- a : npt .NDArray [np .float64 ],
11
- b : npt .NDArray [np .float64 ],
15
+ a : npt .NDArray [np .float64 ] | cp . ndarray | jxt . ArrayLike ,
16
+ b : npt .NDArray [np .float64 ] | cp . ndarray | jxt . ArrayLike ,
12
17
* ,
13
18
tol : float = 0.0 ,
14
19
maxiter : int | None = None ,
15
- ) -> npt .NDArray [np .float64 ]:
20
+ ) -> npt .NDArray [np .float64 ] | cp . ndarray | jxt . ArrayLike :
16
21
"""
17
22
Compute a non-negative least squares solution.
18
23
@@ -27,8 +32,11 @@ def nnls(
27
32
Chemometrics, 11, 393-401.
28
33
29
34
"""
30
- a = np .asanyarray (a )
31
- b = np .asanyarray (b )
35
+ if a .__array_namespace__ () != b .__array_namespace__ ():
36
+ msg = "input arrays should belong to the same array library"
37
+ raise ValueError (msg )
38
+
39
+ xp = a .__array_namespace__ ()
32
40
33
41
if a .ndim != 2 :
34
42
msg = "input `a` is not a matrix"
@@ -45,25 +53,25 @@ def nnls(
45
53
if maxiter is None :
46
54
maxiter = 3 * n
47
55
48
- index = np .arange (n )
49
- p = np .full (n , fill_value = False )
50
- x = np .zeros (n )
56
+ index = xp .arange (n )
57
+ p = xp .full (n , fill_value = False )
58
+ x = xp .zeros (n )
51
59
for _ in range (maxiter ):
52
- if np .all (p ):
60
+ if xp .all (p ):
53
61
break
54
- w = np .dot (b - a @ x , a )
55
- m = index [~ p ][np .argmax (w [~ p ])]
62
+ w = xp .dot (b - a @ x , a )
63
+ m = index [~ p ][xp .argmax (w [~ p ])]
56
64
if w [m ] <= tol :
57
65
break
58
66
p [m ] = True
59
67
while True :
60
68
ap = a [:, p ]
61
- xp = x [p ]
62
- sp = np .linalg .solve (ap .T @ ap , b @ ap )
69
+ x_new = x [p ]
70
+ sp = xp .linalg .solve (ap .T @ ap , b @ ap )
63
71
t = sp <= 0
64
- if not np .any (t ):
72
+ if not xp .any (t ):
65
73
break
66
- alpha = - np .min (xp [t ] / (xp [t ] - sp [t ]))
74
+ alpha = - xp .min (xp [t ] / (x_new [t ] - sp [t ]))
67
75
x [p ] += alpha * (sp - xp )
68
76
p [x <= 0 ] = False
69
77
x [p ] = sp
0 commit comments