From 03b60609ec63d00f0c5630daac92b3cd460c2b69 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Mon, 24 Jun 2024 14:07:32 -0700 Subject: [PATCH] fixing issues with difference in backends, will need to make changes to solving chain as well --- cvxpy/atoms/affine/sum.py | 11 ++++------- cvxpy/cvxcore/python/cppbackend.py | 2 +- cvxpy/lin_ops/lin_utils.py | 4 +--- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/cvxpy/atoms/affine/sum.py b/cvxpy/atoms/affine/sum.py index e7f2599dd9..fe46653b63 100644 --- a/cvxpy/atoms/affine/sum.py +++ b/cvxpy/atoms/affine/sum.py @@ -22,7 +22,6 @@ import cvxpy.interface as intf import cvxpy.lin_ops.lin_op as lo import cvxpy.lin_ops.lin_utils as lu -import cvxpy.settings as s from cvxpy.atoms.affine.affine_atom import AffAtom from cvxpy.atoms.axis_atom import AxisAtom from cvxpy.constraints.constraint import Constraint @@ -84,9 +83,10 @@ def graph_implementation( tuple (LinOp for objective, list of constraints) """ - if s.DEFAULT_CANON_BACKEND == 'CPP': - axis = data[0] - keepdims = data[1] + axis, keepdims = data + if len(arg_objs[0].shape) > 2 or axis not in {None, 0, 1}: + obj = lu.sum_entries(arg_objs[0], shape=shape, axis=axis, keepdims=keepdims) + else: if axis is None: obj = lu.sum_entries(arg_objs[0], shape=shape) elif axis == 1: @@ -103,9 +103,6 @@ def graph_implementation( const_shape = (arg_objs[0].shape[0],) ones = lu.create_const(np.ones(const_shape), const_shape) obj = lu.mul_expr(ones, arg_objs[0], shape) - else: - axis, keepdims = data - obj = lu.sum_entries(arg_objs[0], shape=shape, axis=axis, keepdims=keepdims) return (obj, []) diff --git a/cvxpy/cvxcore/python/cppbackend.py b/cvxpy/cvxcore/python/cppbackend.py index 2f4648afb3..52141ecfb6 100644 --- a/cvxpy/cvxcore/python/cppbackend.py +++ b/cvxpy/cvxcore/python/cppbackend.py @@ -175,7 +175,7 @@ def make_linC_from_linPy(linPy, linPy_to_linC) -> None: linC = cvxcore.LinOp(typ, shape, lin_args_vec) linPy_to_linC[linPy] = linC - if linPy.data is not None: + if linPy.data is not None and linPy.type != "sum_entries": if isinstance(linPy.data, lo.LinOp): linC_data = linPy_to_linC[linPy.data] linC.set_linOp_data(linC_data) diff --git a/cvxpy/lin_ops/lin_utils.py b/cvxpy/lin_ops/lin_utils.py index 5217335ba7..35ce351065 100644 --- a/cvxpy/lin_ops/lin_utils.py +++ b/cvxpy/lin_ops/lin_utils.py @@ -21,7 +21,6 @@ import numpy as np import cvxpy.lin_ops.lin_op as lo -import cvxpy.settings as s import cvxpy.utilities as u from cvxpy.lin_ops.lin_constraints import LinEqConstr, LinLeqConstr @@ -384,8 +383,7 @@ def sum_entries(operator, shape: Tuple[int, ...], axis=None, keepdims=None): LinOp An operator representing the sum. """ - data = None if s.DEFAULT_CANON_BACKEND == "CPP" else [axis, keepdims] - return lo.LinOp(lo.SUM_ENTRIES, shape, [operator], data=data) + return lo.LinOp(lo.SUM_ENTRIES, shape, [operator], data=[axis, keepdims]) def trace(operator):