diff --git a/cvxpy/atoms/affine/cumsum.py b/cvxpy/atoms/affine/cumsum.py index 28cc04b46e..276b082bd1 100644 --- a/cvxpy/atoms/affine/cumsum.py +++ b/cvxpy/atoms/affine/cumsum.py @@ -40,32 +40,18 @@ def get_diff_mat(dim: int, axis: int) -> sp.csc_matrix: Returns ------- - SciPy CSC matrix + sp.csc_matrix A square matrix representing first order difference. """ - # Construct a sparse matrix representation. - val_arr = [] - row_arr = [] - col_arr = [] - for i in range(dim): - val_arr.append(1.) - row_arr.append(i) - col_arr.append(i) - if i > 0: - val_arr.append(-1.) - row_arr.append(i) - col_arr.append(i-1) - - mat = sp.csc_matrix((val_arr, (row_arr, col_arr)), - (dim, dim)) - if axis == 0: - return mat - else: - return mat.T + mat = sp.diags([np.ones(dim), -np.ones(dim - 1)], [0, -1], + shape=(dim, dim), + format='csc') + return mat if axis == 0 else mat.T class cumsum(AffAtom, AxisAtom): - """Cumulative sum. + """ + Cumulative sum of the elements of an expression. Attributes ---------- @@ -79,13 +65,13 @@ def __init__(self, expr: Expression, axis: int = 0) -> None: @AffAtom.numpy_numeric def numeric(self, values): - """Convolve the two values. + """ + Returns the cumulative product of elements of an expression over an axis. """ return np.cumsum(values[0], axis=self.axis) def shape_from_args(self) -> Tuple[int, ...]: - """The same as the input. - """ + """The same as the input.""" return self.args[0].shape def _grad(self, values): @@ -99,12 +85,8 @@ def _grad(self, values): Returns: A list of SciPy CSC sparse matrices or None. """ - # TODO inefficient dim = values[0].shape[self.axis] - mat = np.zeros((dim, dim)) - for i in range(dim): - for j in range(i+1): - mat[i, j] = 1 + mat = sp.tril(np.ones((dim, dim))) var = Variable(self.args[0].shape) if self.axis == 0: grad = MulExpression(mat, var)._grad(values)[1] @@ -113,8 +95,7 @@ def _grad(self, values): return [grad] def get_data(self): - """Returns the axis being summed. - """ + """Returns the axis being summed.""" return [self.axis] def graph_implementation( diff --git a/cvxpy/tests/test_atoms.py b/cvxpy/tests/test_atoms.py index 1d654d8533..313f06779e 100644 --- a/cvxpy/tests/test_atoms.py +++ b/cvxpy/tests/test_atoms.py @@ -1033,6 +1033,18 @@ def test_conv(self) -> None: with pytest.raises(cp.DPPError): problem.solve(enforce_dpp=True) + def test_cumsum(self) -> None: + for axis in [0, 1]: + x = cp.Variable((4, 3)) + expr = cp.cumsum(x, axis=axis) + x_val = np.arange(12).reshape((4, 3)) + target = np.cumsum(x_val, axis=axis) + + prob = cp.Problem(cp.Minimize(0), [x == x_val]) + prob.solve() + + assert np.allclose(expr.value, target) + def test_kron_expr(self) -> None: """Test the kron atom. """