diff --git a/src/pyscipopt/expr.pxi b/src/pyscipopt/expr.pxi index 2fc56f5cb..7806686db 100644 --- a/src/pyscipopt/expr.pxi +++ b/src/pyscipopt/expr.pxi @@ -146,7 +146,7 @@ def buildGenExprObj(expr): GenExprs = np.empty(expr.shape, dtype=object) for idx in np.ndindex(expr.shape): GenExprs[idx] = buildGenExprObj(expr[idx]) - return GenExprs + return GenExprs.view(MatrixExpr) else: assert isinstance(expr, GenExpr) @@ -223,6 +223,9 @@ cdef class Expr: return self def __mul__(self, other): + if isinstance(other, MatrixExpr): + return other * self + if _is_number(other): f = float(other) return Expr({v:f*c for v,c in self.terms.items()}) @@ -420,6 +423,9 @@ cdef class GenExpr: return UnaryExpr(Operator.fabs, self) def __add__(self, other): + if isinstance(other, MatrixExpr): + return other + self + left = buildGenExprObj(self) right = buildGenExprObj(other) ans = SumExpr() @@ -475,6 +481,9 @@ cdef class GenExpr: # return self def __mul__(self, other): + if isinstance(other, MatrixExpr): + return other * self + left = buildGenExprObj(self) right = buildGenExprObj(other) ans = ProdExpr() @@ -537,7 +546,7 @@ cdef class GenExpr: def __truediv__(self,other): divisor = buildGenExprObj(other) # we can't divide by 0 - if divisor.getOp() == Operator.const and divisor.number == 0.0: + if isinstance(divisor, GenExpr) and divisor.getOp() == Operator.const and divisor.number == 0.0: raise ZeroDivisionError("cannot divide by 0") return self * divisor**(-1) diff --git a/tests/test_matrix_variable.py b/tests/test_matrix_variable.py index 0308bb694..233ceb346 100644 --- a/tests/test_matrix_variable.py +++ b/tests/test_matrix_variable.py @@ -1,8 +1,10 @@ +import operator import pdb import pprint import pytest from pyscipopt import Model, Variable, log, exp, cos, sin, sqrt from pyscipopt import Expr, MatrixExpr, MatrixVariable, MatrixExprCons, MatrixConstraint, ExprCons +from pyscipopt.scip import GenExpr from time import time import numpy as np @@ -392,3 +394,22 @@ def test_matrix_cons_indicator(): assert m.getVal(is_equal).sum() == 2 assert (m.getVal(x) == m.getVal(y)).all().all() assert (m.getVal(x) == np.array([[5, 5, 5], [5, 5, 5]])).all().all() + + +_binop_model = Model() + +def var(): + return _binop_model.addVar() + +def genexpr(): + return _binop_model.addVar() ** 0.6 + +def matvar(): + return _binop_model.addMatrixVar((1,)) + +@pytest.mark.parametrize("right", [var(), genexpr(), matvar()], ids=["var", "genexpr", "matvar"]) +@pytest.mark.parametrize("left", [var(), genexpr(), matvar()], ids=["var", "genexpr", "matvar"]) +@pytest.mark.parametrize("op", [operator.add, operator.sub, operator.mul, operator.truediv]) +def test_binop(op, left, right): + res = op(left, right) + assert isinstance(res, (Expr, GenExpr, MatrixExpr))