diff --git a/solver/src/main/java/org/chocosolver/solver/constraints/IIntConstraintFactory.java b/solver/src/main/java/org/chocosolver/solver/constraints/IIntConstraintFactory.java index e857536d7b..fc844e9b25 100644 --- a/solver/src/main/java/org/chocosolver/solver/constraints/IIntConstraintFactory.java +++ b/solver/src/main/java/org/chocosolver/solver/constraints/IIntConstraintFactory.java @@ -736,7 +736,9 @@ default Constraint mod(IntVar X, IntVar Y, IntVar Z) { */ @SuppressWarnings("SuspiciousNameCombination") default Constraint times(IntVar X, IntVar Y, IntVar Z) { - if (Y.isInstantiated()) { + if (X == Y) { + return square(Z, X); + } else if (Y.isInstantiated()) { return times(X, Y.getValue(), Z); } else if (X.isInstantiated()) { return times(Y, X.getValue(), Z); diff --git a/solver/src/main/java/org/chocosolver/solver/constraints/binary/PropSquare.java b/solver/src/main/java/org/chocosolver/solver/constraints/binary/PropSquare.java index 1daf550ad1..2fc68b2a08 100644 --- a/solver/src/main/java/org/chocosolver/solver/constraints/binary/PropSquare.java +++ b/solver/src/main/java/org/chocosolver/solver/constraints/binary/PropSquare.java @@ -47,32 +47,17 @@ public void propagate(int evtmask) throws ContradictionException { do { setBounds(); } while (updateHolesinX() | updateHolesinY()); - if (vars[1].isInstantiated()) { - vars[0].instantiateTo(sqr(vars[1].getValue()), this); - } } - @Override public ESat isEntailed() { if (vars[0].getUB() < 0) { return ESat.FALSE; - } else if (vars[0].isInstantiated()) { - if (vars[1].isInstantiated()) { - return ESat.eval(vars[0].getValue() == sqr(vars[1].getValue())); - } else if (vars[1].getDomainSize() == 2 && - vars[1].contains(-floor_sqrt(vars[0].getValue())) && - vars[1].contains(-floor_sqrt(vars[0].getValue()))) { - return ESat.TRUE; - } else if (!vars[1].contains(floor_sqrt(vars[0].getValue())) && - !vars[1].contains(-floor_sqrt(vars[0].getValue()))) { - return ESat.FALSE; - } else { - return ESat.UNDEFINED; - } - } else { - return ESat.UNDEFINED; } + if (isCompletelyInstantiated()) { + return ESat.eval(vars[0].getValue() == sqr(vars[1].getValue())); + } + return ESat.UNDEFINED; } @Override @@ -89,14 +74,16 @@ private void setBounds() throws ContradictionException { } private static int floor_sqrt(int n) { - if (n < 0) + if (n < 0) { return 0; + } return (int) Math.floor(Math.sqrt(n)); } private static int ceil_sqrt(int n) { - if (n < 0) + if (n < 0) { return 0; + } return (int) Math.ceil(Math.sqrt(n)); } @@ -117,56 +104,69 @@ protected void updateLowerBoundofX() throws ContradictionException { protected void updateUpperBoundofX() throws ContradictionException { vars[0].updateUpperBound(Math.max(sqr(vars[1].getLB()), sqr(vars[1].getUB())), this); - } protected boolean updateHolesinX() throws ContradictionException { - // remove intervals to deal with consecutive value removal and upper bound modification - if (bothEnum) { + if (!vars[0].hasEnumeratedDomain()) { + return false; + } + boolean impact = false; + if (model.getSolver().getNodeCount() == 0) { // only at root node propagation + // check perfect squares once and for all int ub = vars[0].getUB(); vrms.clear(); vrms.setOffset(vars[0].getLB()); for (int value = vars[0].getLB(); value <= ub; value = vars[0].nextValue(value)) { - if (!(MathUtils.isPerfectSquare(value) && - (vars[1].contains(floor_sqrt(value)) || vars[1].contains(-floor_sqrt(value))))) { + if (!MathUtils.isPerfectSquare(value)) { vrms.add(value); } } - return vars[0].removeValues(vrms, this); - } else if (vars[0].hasEnumeratedDomain()) { - int value = vars[0].getLB(); - int nlb = value - 1; - while (nlb == value - 1) { - if (!vars[1].contains(floor_sqrt(value)) && !vars[1].contains(-floor_sqrt(value))) { - nlb = value; - } - value = vars[0].nextValue(value); - } - boolean filter = vars[0].updateLowerBound(nlb, this); + impact = vars[0].removeValues(vrms, this); + } - value = vars[0].getUB(); - int nub = value + 1; - while (nub == value + 1) { - if (!vars[1].contains(floor_sqrt(value)) && !vars[1].contains(-floor_sqrt(value))) { - nub = value; + // remove intervals to deal with consecutive value removal and upper bound modification + if (bothEnum) { + int ub = vars[0].getUB(); + vrms.clear(); + vrms.setOffset(vars[0].getLB()); + for (int value = vars[0].getLB(); value <= ub; value = vars[0].nextValue(value)) { + int sqrt = floor_sqrt(value); + if (!vars[1].contains(sqrt) && !vars[1].contains(-sqrt)) { + vrms.add(value); } - value = vars[0].previousValue(value); } - return filter | vars[0].updateUpperBound(nub, this); + impact |= vars[0].removeValues(vrms, this); } - return false; + return impact; } protected boolean updateLowerBoundofY() throws ContradictionException { - return vars[1].updateLowerBound(-ceil_sqrt(vars[0].getUB()), this); + if (vars[1].getLB() >= 0) { + return vars[1].updateLowerBound(ceil_sqrt(vars[0].getLB()), this); + } else { + return vars[1].updateLowerBound(-floor_sqrt(vars[0].getUB()), this); + } } protected boolean updateUpperBoundofY() throws ContradictionException { - return vars[1].updateUpperBound(floor_sqrt(vars[0].getUB()), this); + if (vars[1].getUB() < 0) { + return vars[1].updateUpperBound(-ceil_sqrt(vars[0].getLB()), this); + } else { + return vars[1].updateUpperBound(floor_sqrt(vars[0].getUB()), this); + } } protected boolean updateHolesinY() throws ContradictionException { - // remove intervals to deal with consecutive value removal and upper bound modification + if (!vars[1].hasEnumeratedDomain()) { + return false; + } + boolean impact = false; + // remove interval around 0 based on X LB + int val = ceil_sqrt(vars[0].getLB()) - 1; + if (val >= 0) { + impact = vars[1].removeInterval(-val, val, this); + } + // remove values based on X domain if (bothEnum) { int ub = vars[1].getUB(); vrms.clear(); @@ -176,24 +176,8 @@ protected boolean updateHolesinY() throws ContradictionException { vrms.add(value); } } - return vars[1].removeValues(vrms, this); - } else if (vars[1].hasEnumeratedDomain()) { - int lb = vars[1].getLB(); - int ub = vars[1].getUB(); - while (!vars[0].contains(sqr(lb))) { - lb = vars[1].nextValue(lb); - if (lb > ub) break; - } - boolean filter = vars[1].updateLowerBound(lb, this); - - while (!vars[0].contains(sqr(ub))) { - ub = vars[1].nextValue(ub); - if (ub < lb) break; - } - return filter | vars[1].updateUpperBound(ub, this); + impact |= vars[1].removeValues(vrms, this); } - return false; + return impact; } - - } diff --git a/solver/src/test/java/org/chocosolver/solver/constraints/binary/SquareTest.java b/solver/src/test/java/org/chocosolver/solver/constraints/binary/SquareTest.java index e5db51c6e1..41b9a3713e 100644 --- a/solver/src/test/java/org/chocosolver/solver/constraints/binary/SquareTest.java +++ b/solver/src/test/java/org/chocosolver/solver/constraints/binary/SquareTest.java @@ -9,20 +9,163 @@ */ package org.chocosolver.solver.constraints.binary; +import org.chocosolver.solver.Cause; import org.chocosolver.solver.Model; +import org.chocosolver.solver.Solver; import org.chocosolver.solver.exception.ContradictionException; import org.chocosolver.solver.search.strategy.Search; import org.chocosolver.solver.variables.IntVar; +import org.chocosolver.util.tools.ArrayUtils; import org.testng.Assert; import org.testng.annotations.Test; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; /** * @author Jean-Guillaume Fages */ public class SquareTest { + @Test(groups = "1s", timeOut = 60000) + public void testInstZero() { + Model m = new Model(); + IntVar x = m.intVar(0); + IntVar y = m.intVar(-5, 5, false); + IntVar z = m.intVar(-5, 5, true); + m.square(x, y).post(); + m.square(x, z).post(); + + try { + m.getSolver().propagate(); + } catch (ContradictionException ex) { + fail(); + } + assertTrue(y.isInstantiatedTo(0)); + assertTrue(z.isInstantiatedTo(0)); + } + + @Test(groups = "1s", timeOut = 60000) + public void testInst() { + Model m = new Model(); + IntVar x1 = m.intVar(9); + IntVar y1 = m.intVar(-5, 5, false); + IntVar z1 = m.intVar(-5, 5, true); + m.square(x1, y1).post(); + m.square(x1, z1).post(); + + IntVar x2 = m.intVar(-50, 50, false); + IntVar y2 = m.intVar(-50, 50, true); + IntVar z2 = m.intVar(-3); + m.square(x2, z2).post(); + m.square(y2, z2).post(); + + try { + m.getSolver().propagate(); + } catch (ContradictionException ex) { + fail(); + } + + assertEquals(y1.getLB(), -3); + assertEquals(y1.getUB(), 3); + assertEquals(y1.getDomainSize(), 2); + assertEquals(z1.getLB(), -3); + assertEquals(z1.getUB(), 3); + assertEquals(z1.getDomainSize(), 7); + + assertTrue(x2.isInstantiatedTo(9)); + assertTrue(y2.isInstantiatedTo(9)); + } + + @Test(groups = "1s", timeOut = 60000) + public void testPropBounds() { + Model m = new Model(); + IntVar x = m.intVar(0, 50, false); + IntVar y = m.intVar(3, 5, false); + m.square(x, y).post(); + m.arithm(x, "!=", 16).post(); + + try { + m.getSolver().propagate(); + } catch (ContradictionException ex) { + fail(); + } + assertEquals(x.getLB(), 9); + assertEquals(x.getUB(), 25); + assertEquals(y.getLB(), 3); + assertEquals(y.getUB(), 5); + + try { + y.updateLowerBound(4, Cause.Null); + m.getSolver().propagate(); + } catch (ContradictionException ex) { + fail(); + } + assertTrue(x.isInstantiatedTo(25)); + assertTrue(y.isInstantiatedTo(5)); + } + + @Test(groups = "10s", timeOut = 60000) + public void testBigBoundBound() { + Model m = new Model(); + int n = 6; + IntVar[] x = m.intVarArray(n, -n, n, true); + IntVar[] x2 = m.intVarArray(n, -n*n,n*n, true); + for (int i=0;i