Skip to content

Commit

Permalink
Only rewrite with knowledge up to now
Browse files Browse the repository at this point in the history
  • Loading branch information
sakehl committed Jun 25, 2024
1 parent 003a9ed commit ae2a7e7
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 91 deletions.
62 changes: 32 additions & 30 deletions src/col/vct/col/ast/util/ExpressionEqualityCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -699,38 +699,28 @@ class AnnotationVariableInfoGetter[G]() {
}
}

def getInfo(annotations: Seq[Expr[G]]): AnnotationVariableInfo[G] = {
variableEqualities.clear()
variableValues.clear()
variableSynonyms.clear()
currentSynonymGroup = 0
variableNotZero.clear()
lessThanEqVars.clear()
upperBound.clear()
lowerBound.clear()
usefullConditions.clear()

for (clause <- annotations) { extractEqualities(clause) }

val res = AnnotationVariableInfo[G](
variableEqualities.view.mapValues(_.toList).toMap,
variableValues.toMap,
variableSynonyms.toMap,
Set[Local[G]](),
Map[Local[G], Set[Local[G]]](),
Map[Local[G], BigInt](),
Map[Local[G], BigInt](),
usefullConditions,
)
equalCheck = ExpressionEqualityCheck(Some(res))

for (clause <- annotations) {
if (isSimpleExpr(clause)) {
extractComparisons(clause)
usefullConditions.addOne(clause)
}
def addInfo(annotation: Expr[G]): Unit = {
extractEqualities(annotation)

if (isSimpleExpr(annotation)) {
val res = AnnotationVariableInfo[G](
variableEqualities.view.mapValues(_.toList).toMap,
variableValues.toMap,
variableSynonyms.toMap,
Set[Local[G]](),
Map[Local[G], Set[Local[G]]](),
Map[Local[G], BigInt](),
Map[Local[G], BigInt](),
usefullConditions,
)

equalCheck = ExpressionEqualityCheck(Some(res))
extractComparisons(annotation)
usefullConditions.addOne(annotation)
}
}

def finalInfo(): AnnotationVariableInfo[G] = {
distributeInfo()

AnnotationVariableInfo(
Expand All @@ -745,6 +735,18 @@ class AnnotationVariableInfoGetter[G]() {
)
}

def setupInfo(): Unit = {
variableEqualities.clear()
variableValues.clear()
variableSynonyms.clear()
currentSynonymGroup = 0
variableNotZero.clear()
lessThanEqVars.clear()
upperBound.clear()
lowerBound.clear()
usefullConditions.clear()
}

def distributeInfo(): Unit = {
// First check if expressions have become integers
for ((name, equals) <- variableEqualities) {
Expand Down
122 changes: 61 additions & 61 deletions src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,7 @@ import com.typesafe.scalalogging.LazyLogging
import vct.col.ast._
import vct.col.ast.util.{AnnotationVariableInfoGetter, ExpressionEqualityCheck}
import vct.col.rewrite.util.Comparison
import vct.col.origin.{
ArrayInsufficientPermission,
DiagnosticOrigin,
LabelContext,
Origin,
PanicBlame,
PointerBounds,
PreferredName,
}
import vct.col.origin.{ArrayInsufficientPermission, DiagnosticOrigin, LabelContext, Origin, PanicBlame, PointerBounds, PreferredName}
import vct.col.ref.Ref
import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder}
import vct.col.util.AstBuildHelpers._
Expand Down Expand Up @@ -56,21 +48,43 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]()
private def one: IntegerValue[Pre] = IntegerValue(1)

var equalityChecker: ExpressionEqualityCheck[Pre] = ExpressionEqualityCheck()
var topLevel: Boolean = false
var infoGetter: AnnotationVariableInfoGetter[Pre] = new AnnotationVariableInfoGetter[Pre]()

override def dispatch(e: Expr[Pre]): Expr[Post] = {
e match {
// Consider elements of top level stars and ands also toplevel
case e: Star[Pre] if topLevel =>
val left = dispatch(e.left)
topLevel = true
val right = dispatch(e.right)
topLevel = true
Star(left, right)(e.o)
case e: And[Pre] if topLevel =>
val left = dispatch(e.left)
topLevel = true
val right = dispatch(e.right)
topLevel = true
And(left, right)(e.o)
case e: Forall[Pre] =>
topLevel = false
equalityChecker = ExpressionEqualityCheck(Some(infoGetter.finalInfo()))
mapUnfoldedStar(
e.body,
(b: Expr[Pre]) =>
rewriteBinder(Forall(e.bindings, e.triggers, b)(e.o)),
)
case e: Starall[Pre] =>
topLevel = false
equalityChecker = ExpressionEqualityCheck(Some(infoGetter.finalInfo()))
mapUnfoldedStar(
e.body,
(b: Expr[Pre]) =>
rewriteBinder(Starall(e.bindings, e.triggers, b)(e.blame)(e.o)),
)
case other if topLevel => infoGetter.addInfo(other)
topLevel = false
other.rewriteDefault()
case other => other.rewriteDefault()
}
}
Expand Down Expand Up @@ -106,55 +120,48 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]()
}

override def dispatch(stat: Statement[Pre]): Statement[Post] = {
val e =
stat match {
case Exhale(e) => e
case Inhale(e) => e
stat match {
case Exhale(e) =>
case Inhale(e) =>
case proof: FramedProof[Pre] => return checkFramedProof(proof)
case _ => return stat.rewriteDefault()
}

val conditions = getConditions(e)
val infoGetter = new AnnotationVariableInfoGetter[Pre]()
equalityChecker = ExpressionEqualityCheck(
Some(infoGetter.getInfo(conditions))
)
}
topLevel = true
infoGetter.setupInfo()
val result = stat.rewriteDefault()
topLevel = false
equalityChecker = ExpressionEqualityCheck()
result
}

def checkFramedProof(proof: FramedProof[Pre]): Statement[Post] = {
val conditions_pre = getConditions(proof.pre)
val infoGetter_pre = new AnnotationVariableInfoGetter[Pre]()
equalityChecker = ExpressionEqualityCheck(
Some(infoGetter_pre.getInfo(conditions_pre))
)
topLevel = true
infoGetter.setupInfo()
val pre = dispatch(proof.pre)

val conditions_post = getConditions(proof.post)
val infoGetter_post = new AnnotationVariableInfoGetter[Pre]()
ExpressionEqualityCheck(Some(infoGetter_post.getInfo(conditions_post)))
equalityChecker = ExpressionEqualityCheck()
infoGetter.setupInfo()
val post = dispatch(proof.post)
topLevel = false
equalityChecker = ExpressionEqualityCheck()

val body = dispatch(proof.body)

FramedProof[Post](pre, body, post)(proof.blame)(proof.o)
}

def getConditions(preds: AccountedPredicate[Pre]): Seq[Expr[Pre]] =
preds match {
case UnitAccountedPredicate(pred) => getConditions(pred)
case SplitAccountedPredicate(left, right) =>
getConditions(left) ++ getConditions(right)
}

def getConditions(e: Expr[Pre]): Seq[Expr[Pre]] =
e match {
case And(left, right) => getConditions(left) ++ getConditions(right)
case Star(left, right) => getConditions(left) ++ getConditions(right)
case other => Seq[Expr[Pre]](other)
}
// def getConditions(preds: AccountedPredicate[Pre]): Seq[Expr[Pre]] =
// preds match {
// case UnitAccountedPredicate(pred) => getConditions(pred)
// case SplitAccountedPredicate(left, right) =>
// getConditions(left) ++ getConditions(right)
// }
//
// def getConditions(e: Expr[Pre]): Seq[Expr[Pre]] =
// e match {
// case And(left, right) => getConditions(left) ++ getConditions(right)
// case Star(left, right) => getConditions(left) ++ getConditions(right)
// case other => Seq[Expr[Pre]](other)
// }

override def dispatch(loopContract: LoopContract[Pre]): LoopContract[Post] = {
val loopInvariant: LoopInvariant[Pre] =
Expand All @@ -163,12 +170,10 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]()
case _ => return dispatch(loopContract)
}

val infoGetter = new AnnotationVariableInfoGetter[Pre]()
val conditions = getConditions(loopInvariant.invariant)
equalityChecker = ExpressionEqualityCheck(
Some(infoGetter.getInfo(conditions))
)
topLevel = true
infoGetter.setupInfo()
val invariant = dispatch(loopInvariant.invariant)
topLevel = false
equalityChecker = ExpressionEqualityCheck()
val decreases = loopInvariant.decreases.map(element => dispatch(element))

Expand All @@ -178,24 +183,19 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]()
override def dispatch(
contract: ApplicableContract[Pre]
): ApplicableContract[Post] = {
val infoGetter = new AnnotationVariableInfoGetter[Pre]()
val reqConditions = getConditions(contract.requires)
val contextConditions = getConditions(contract.contextEverywhere)
val ensureConditions = getConditions(contract.ensures)
equalityChecker = ExpressionEqualityCheck(
Some(infoGetter.getInfo(reqConditions ++ contextConditions))
)

topLevel = true
infoGetter.setupInfo()
val requires = dispatch(contract.requires)
equalityChecker = ExpressionEqualityCheck(
Some(infoGetter.getInfo(ensureConditions ++ contextConditions))
)
equalityChecker = ExpressionEqualityCheck()
infoGetter.setupInfo()
val ensures = dispatch(contract.ensures)
equalityChecker = ExpressionEqualityCheck(
Some(infoGetter.getInfo(contextConditions))
)
val contextEverywhere = dispatch(contract.contextEverywhere)
topLevel = false
equalityChecker = ExpressionEqualityCheck()

// TODO: Is context everywhere al distributed here? If not, we need to do more.
val contextEverywhere = dispatch(contract.contextEverywhere)

val signals = contract.signals.map(element => dispatch(element))
val givenArgs =
variables.collect { contract.givenArgs.foreach(dispatch) }._1
Expand Down
18 changes: 18 additions & 0 deletions test/main/vct/test/integration/features/Quantifiers.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package vct.test.integration.features

import vct.test.integration.helper.VercorsSpec

class Quantifiers extends VercorsSpec {

vercors should verify using silicon in "correct order of quantifier rewrite (issue #1215)" c """
/*@
requires ant1 != NULL && \pointer_length(ant1) == _n_vis;
requires (\forall* int i; 0<=i && i< \pointer_length(ant1); Perm(&ant1[i], write));
requires (\forall int _0; 0 <= _0 && _0 < _n_vis; 0 <= ant1[_0]);
requires _n_vis == 230930;
@*/
int main(int *ant1, int _n_vis) {
}
"""

}

0 comments on commit ae2a7e7

Please sign in to comment.