diff --git a/src/col/vct/col/ast/util/ExpressionEqualityCheck.scala b/src/col/vct/col/ast/util/ExpressionEqualityCheck.scala index ab806c1ca8..97ef481a94 100644 --- a/src/col/vct/col/ast/util/ExpressionEqualityCheck.scala +++ b/src/col/vct/col/ast/util/ExpressionEqualityCheck.scala @@ -699,38 +699,28 @@ class AnnotationVariableInfoGetter[G]() { } } - def getInfo(annotations: Seq[Expr[G]]): AnnotationVariableInfo[G] = { - variableEqualities.clear() - variableValues.clear() - variableSynonyms.clear() - currentSynonymGroup = 0 - variableNotZero.clear() - lessThanEqVars.clear() - upperBound.clear() - lowerBound.clear() - usefullConditions.clear() - - for (clause <- annotations) { extractEqualities(clause) } - - val res = AnnotationVariableInfo[G]( - variableEqualities.view.mapValues(_.toList).toMap, - variableValues.toMap, - variableSynonyms.toMap, - Set[Local[G]](), - Map[Local[G], Set[Local[G]]](), - Map[Local[G], BigInt](), - Map[Local[G], BigInt](), - usefullConditions, - ) - equalCheck = ExpressionEqualityCheck(Some(res)) - - for (clause <- annotations) { - if (isSimpleExpr(clause)) { - extractComparisons(clause) - usefullConditions.addOne(clause) - } + def addInfo(annotation: Expr[G]): Unit = { + extractEqualities(annotation) + + if (isSimpleExpr(annotation)) { + val res = AnnotationVariableInfo[G]( + variableEqualities.view.mapValues(_.toList).toMap, + variableValues.toMap, + variableSynonyms.toMap, + Set[Local[G]](), + Map[Local[G], Set[Local[G]]](), + Map[Local[G], BigInt](), + Map[Local[G], BigInt](), + usefullConditions, + ) + + equalCheck = ExpressionEqualityCheck(Some(res)) + extractComparisons(annotation) + usefullConditions.addOne(annotation) } + } + def finalInfo(): AnnotationVariableInfo[G] = { distributeInfo() AnnotationVariableInfo( @@ -745,6 +735,18 @@ class AnnotationVariableInfoGetter[G]() { ) } + def setupInfo(): Unit = { + variableEqualities.clear() + variableValues.clear() + variableSynonyms.clear() + currentSynonymGroup = 0 + variableNotZero.clear() + lessThanEqVars.clear() + upperBound.clear() + lowerBound.clear() + usefullConditions.clear() + } + def distributeInfo(): Unit = { // First check if expressions have become integers for ((name, equals) <- variableEqualities) { diff --git a/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala b/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala index 519118dc75..b6bedd526c 100644 --- a/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala +++ b/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala @@ -4,15 +4,7 @@ import com.typesafe.scalalogging.LazyLogging import vct.col.ast._ import vct.col.ast.util.{AnnotationVariableInfoGetter, ExpressionEqualityCheck} import vct.col.rewrite.util.Comparison -import vct.col.origin.{ - ArrayInsufficientPermission, - DiagnosticOrigin, - LabelContext, - Origin, - PanicBlame, - PointerBounds, - PreferredName, -} +import vct.col.origin.{ArrayInsufficientPermission, DiagnosticOrigin, LabelContext, Origin, PanicBlame, PointerBounds, PreferredName} import vct.col.ref.Ref import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} import vct.col.util.AstBuildHelpers._ @@ -56,21 +48,43 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() private def one: IntegerValue[Pre] = IntegerValue(1) var equalityChecker: ExpressionEqualityCheck[Pre] = ExpressionEqualityCheck() + var topLevel: Boolean = false + var infoGetter: AnnotationVariableInfoGetter[Pre] = new AnnotationVariableInfoGetter[Pre]() override def dispatch(e: Expr[Pre]): Expr[Post] = { e match { + // Consider elements of top level stars and ands also toplevel + case e: Star[Pre] if topLevel => + val left = dispatch(e.left) + topLevel = true + val right = dispatch(e.right) + topLevel = true + Star(left, right)(e.o) + case e: And[Pre] if topLevel => + val left = dispatch(e.left) + topLevel = true + val right = dispatch(e.right) + topLevel = true + And(left, right)(e.o) case e: Forall[Pre] => + topLevel = false + equalityChecker = ExpressionEqualityCheck(Some(infoGetter.finalInfo())) mapUnfoldedStar( e.body, (b: Expr[Pre]) => rewriteBinder(Forall(e.bindings, e.triggers, b)(e.o)), ) case e: Starall[Pre] => + topLevel = false + equalityChecker = ExpressionEqualityCheck(Some(infoGetter.finalInfo())) mapUnfoldedStar( e.body, (b: Expr[Pre]) => rewriteBinder(Starall(e.bindings, e.triggers, b)(e.blame)(e.o)), ) + case other if topLevel => infoGetter.addInfo(other) + topLevel = false + other.rewriteDefault() case other => other.rewriteDefault() } } @@ -106,55 +120,48 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() } override def dispatch(stat: Statement[Pre]): Statement[Post] = { - val e = - stat match { - case Exhale(e) => e - case Inhale(e) => e + stat match { + case Exhale(e) => + case Inhale(e) => case proof: FramedProof[Pre] => return checkFramedProof(proof) case _ => return stat.rewriteDefault() - } - - val conditions = getConditions(e) - val infoGetter = new AnnotationVariableInfoGetter[Pre]() - equalityChecker = ExpressionEqualityCheck( - Some(infoGetter.getInfo(conditions)) - ) + } + topLevel = true + infoGetter.setupInfo() val result = stat.rewriteDefault() + topLevel = false equalityChecker = ExpressionEqualityCheck() result } def checkFramedProof(proof: FramedProof[Pre]): Statement[Post] = { - val conditions_pre = getConditions(proof.pre) - val infoGetter_pre = new AnnotationVariableInfoGetter[Pre]() - equalityChecker = ExpressionEqualityCheck( - Some(infoGetter_pre.getInfo(conditions_pre)) - ) + topLevel = true + infoGetter.setupInfo() val pre = dispatch(proof.pre) - - val conditions_post = getConditions(proof.post) - val infoGetter_post = new AnnotationVariableInfoGetter[Pre]() - ExpressionEqualityCheck(Some(infoGetter_post.getInfo(conditions_post))) + equalityChecker = ExpressionEqualityCheck() + infoGetter.setupInfo() val post = dispatch(proof.post) + topLevel = false equalityChecker = ExpressionEqualityCheck() + val body = dispatch(proof.body) FramedProof[Post](pre, body, post)(proof.blame)(proof.o) } - def getConditions(preds: AccountedPredicate[Pre]): Seq[Expr[Pre]] = - preds match { - case UnitAccountedPredicate(pred) => getConditions(pred) - case SplitAccountedPredicate(left, right) => - getConditions(left) ++ getConditions(right) - } - - def getConditions(e: Expr[Pre]): Seq[Expr[Pre]] = - e match { - case And(left, right) => getConditions(left) ++ getConditions(right) - case Star(left, right) => getConditions(left) ++ getConditions(right) - case other => Seq[Expr[Pre]](other) - } +// def getConditions(preds: AccountedPredicate[Pre]): Seq[Expr[Pre]] = +// preds match { +// case UnitAccountedPredicate(pred) => getConditions(pred) +// case SplitAccountedPredicate(left, right) => +// getConditions(left) ++ getConditions(right) +// } +// +// def getConditions(e: Expr[Pre]): Seq[Expr[Pre]] = +// e match { +// case And(left, right) => getConditions(left) ++ getConditions(right) +// case Star(left, right) => getConditions(left) ++ getConditions(right) +// case other => Seq[Expr[Pre]](other) +// } override def dispatch(loopContract: LoopContract[Pre]): LoopContract[Post] = { val loopInvariant: LoopInvariant[Pre] = @@ -163,12 +170,10 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() case _ => return dispatch(loopContract) } - val infoGetter = new AnnotationVariableInfoGetter[Pre]() - val conditions = getConditions(loopInvariant.invariant) - equalityChecker = ExpressionEqualityCheck( - Some(infoGetter.getInfo(conditions)) - ) + topLevel = true + infoGetter.setupInfo() val invariant = dispatch(loopInvariant.invariant) + topLevel = false equalityChecker = ExpressionEqualityCheck() val decreases = loopInvariant.decreases.map(element => dispatch(element)) @@ -178,24 +183,19 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() override def dispatch( contract: ApplicableContract[Pre] ): ApplicableContract[Post] = { - val infoGetter = new AnnotationVariableInfoGetter[Pre]() - val reqConditions = getConditions(contract.requires) - val contextConditions = getConditions(contract.contextEverywhere) - val ensureConditions = getConditions(contract.ensures) - equalityChecker = ExpressionEqualityCheck( - Some(infoGetter.getInfo(reqConditions ++ contextConditions)) - ) + + topLevel = true + infoGetter.setupInfo() val requires = dispatch(contract.requires) - equalityChecker = ExpressionEqualityCheck( - Some(infoGetter.getInfo(ensureConditions ++ contextConditions)) - ) + equalityChecker = ExpressionEqualityCheck() + infoGetter.setupInfo() val ensures = dispatch(contract.ensures) - equalityChecker = ExpressionEqualityCheck( - Some(infoGetter.getInfo(contextConditions)) - ) - val contextEverywhere = dispatch(contract.contextEverywhere) + topLevel = false equalityChecker = ExpressionEqualityCheck() + // TODO: Is context everywhere al distributed here? If not, we need to do more. + val contextEverywhere = dispatch(contract.contextEverywhere) + val signals = contract.signals.map(element => dispatch(element)) val givenArgs = variables.collect { contract.givenArgs.foreach(dispatch) }._1 diff --git a/test/main/vct/test/integration/features/Quantifiers.scala b/test/main/vct/test/integration/features/Quantifiers.scala new file mode 100644 index 0000000000..e4a3632da5 --- /dev/null +++ b/test/main/vct/test/integration/features/Quantifiers.scala @@ -0,0 +1,18 @@ +package vct.test.integration.features + +import vct.test.integration.helper.VercorsSpec + +class Quantifiers extends VercorsSpec { + + vercors should verify using silicon in "correct order of quantifier rewrite (issue #1215)" c """ +/*@ + requires ant1 != NULL && \pointer_length(ant1) == _n_vis; + requires (\forall* int i; 0<=i && i< \pointer_length(ant1); Perm(&ant1[i], write)); + requires (\forall int _0; 0 <= _0 && _0 < _n_vis; 0 <= ant1[_0]); + requires _n_vis == 230930; +@*/ + int main(int *ant1, int _n_vis) { + } + """ + +}