Skip to content

Commit

Permalink
fixing issues with difference in backends, will need to make changes …
Browse files Browse the repository at this point in the history
…to solving chain as well
  • Loading branch information
William Zijie Zhang authored and William Zijie Zhang committed Jun 24, 2024
1 parent 89c73ba commit 03b6060
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 11 deletions.
11 changes: 4 additions & 7 deletions cvxpy/atoms/affine/sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import cvxpy.interface as intf
import cvxpy.lin_ops.lin_op as lo
import cvxpy.lin_ops.lin_utils as lu
import cvxpy.settings as s
from cvxpy.atoms.affine.affine_atom import AffAtom
from cvxpy.atoms.axis_atom import AxisAtom
from cvxpy.constraints.constraint import Constraint
Expand Down Expand Up @@ -84,9 +83,10 @@ def graph_implementation(
tuple
(LinOp for objective, list of constraints)
"""
if s.DEFAULT_CANON_BACKEND == 'CPP':
axis = data[0]
keepdims = data[1]
axis, keepdims = data
if len(arg_objs[0].shape) > 2 or axis not in {None, 0, 1}:
obj = lu.sum_entries(arg_objs[0], shape=shape, axis=axis, keepdims=keepdims)
else:
if axis is None:
obj = lu.sum_entries(arg_objs[0], shape=shape)
elif axis == 1:
Expand All @@ -103,9 +103,6 @@ def graph_implementation(
const_shape = (arg_objs[0].shape[0],)
ones = lu.create_const(np.ones(const_shape), const_shape)
obj = lu.mul_expr(ones, arg_objs[0], shape)
else:
axis, keepdims = data
obj = lu.sum_entries(arg_objs[0], shape=shape, axis=axis, keepdims=keepdims)
return (obj, [])


Expand Down
2 changes: 1 addition & 1 deletion cvxpy/cvxcore/python/cppbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def make_linC_from_linPy(linPy, linPy_to_linC) -> None:
linC = cvxcore.LinOp(typ, shape, lin_args_vec)
linPy_to_linC[linPy] = linC

if linPy.data is not None:
if linPy.data is not None and linPy.type != "sum_entries":
if isinstance(linPy.data, lo.LinOp):
linC_data = linPy_to_linC[linPy.data]
linC.set_linOp_data(linC_data)
Expand Down
4 changes: 1 addition & 3 deletions cvxpy/lin_ops/lin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import numpy as np

import cvxpy.lin_ops.lin_op as lo
import cvxpy.settings as s
import cvxpy.utilities as u
from cvxpy.lin_ops.lin_constraints import LinEqConstr, LinLeqConstr

Expand Down Expand Up @@ -384,8 +383,7 @@ def sum_entries(operator, shape: Tuple[int, ...], axis=None, keepdims=None):
LinOp
An operator representing the sum.
"""
data = None if s.DEFAULT_CANON_BACKEND == "CPP" else [axis, keepdims]
return lo.LinOp(lo.SUM_ENTRIES, shape, [operator], data=data)
return lo.LinOp(lo.SUM_ENTRIES, shape, [operator], data=[axis, keepdims])


def trace(operator):
Expand Down

0 comments on commit 03b6060

Please sign in to comment.