Skip to content

Commit e3ed0a5

Browse files
committed
Fixes gathering information from AccountedPred & Scales
1 parent 0e97812 commit e3ed0a5

File tree

1 file changed

+26
-36
lines changed

1 file changed

+26
-36
lines changed

src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala

Lines changed: 26 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -149,19 +149,14 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]()
149149
FramedProof[Post](pre, body, post)(proof.blame)(proof.o)
150150
}
151151

152-
// def getConditions(preds: AccountedPredicate[Pre]): Seq[Expr[Pre]] =
153-
// preds match {
154-
// case UnitAccountedPredicate(pred) => getConditions(pred)
155-
// case SplitAccountedPredicate(left, right) =>
156-
// getConditions(left) ++ getConditions(right)
157-
// }
158-
//
159-
// def getConditions(e: Expr[Pre]): Seq[Expr[Pre]] =
160-
// e match {
161-
// case And(left, right) => getConditions(left) ++ getConditions(right)
162-
// case Star(left, right) => getConditions(left) ++ getConditions(right)
163-
// case other => Seq[Expr[Pre]](other)
164-
// }
152+
override def dispatch(p: AccountedPredicate[Pre]) : AccountedPredicate[Post] = {
153+
p match {
154+
case u@UnitAccountedPredicate(pred) =>
155+
topLevel = true
156+
u.rewriteDefault()
157+
case s@SplitAccountedPredicate(left, right) => s.rewriteDefault()
158+
}
159+
}
165160

166161
override def dispatch(loopContract: LoopContract[Pre]): LoopContract[Post] = {
167162
val loopInvariant: LoopInvariant[Pre] =
@@ -186,15 +181,18 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]()
186181

187182
topLevel = true
188183
infoGetter.setupInfo()
184+
val contextEverywhere = dispatch(contract.contextEverywhere)
185+
val oldInfo = infoGetter
186+
187+
// Reuse information from context everywhere
189188
val requires = dispatch(contract.requires)
190189
equalityChecker = ExpressionEqualityCheck()
191-
infoGetter.setupInfo()
190+
191+
// Again reuse information from context everywhere
192+
infoGetter = oldInfo
192193
val ensures = dispatch(contract.ensures)
193-
topLevel = false
194194
equalityChecker = ExpressionEqualityCheck()
195-
196-
// TODO: Is context everywhere al distributed here? If not, we need to do more.
197-
val contextEverywhere = dispatch(contract.contextEverywhere)
195+
topLevel = false
198196

199197
val signals = contract.signals.map(element => dispatch(element))
200198
val givenArgs =
@@ -286,23 +284,29 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]()
286284
var newBinder = false
287285

288286
def setData(): Unit = {
289-
val allConditions = unfoldBody(Seq())
287+
val allConditions = unfoldBody(Seq(), Seq())
290288
// Split bounds that are independent of any binding variables
291289
val (newIndependentConditions, potentialBounds) = allConditions
292290
.partition(indepOf(bindings, _))
293291
independentConditions.addAll(newIndependentConditions)
294292
getBounds(potentialBounds)
295293
}
296294

297-
def unfoldBody(prevConditions: Seq[Expr[Pre]]): Seq[Expr[Pre]] = {
295+
def unfoldBody(prevConditions: Seq[Expr[Pre]], scales: Seq[Expr[Pre] => Expr[Pre]]): Seq[Expr[Pre]] = {
298296
val (allConditions, mainBody) = unfoldImplies[Pre](body)
299297
val newConditions = prevConditions ++ allConditions
300298
val (newVars, secondBody) =
301299
mainBody match {
302300
case Forall(newVars, _, secondBody) => (newVars, secondBody)
303301
case Starall(newVars, _, secondBody) => (newVars, secondBody)
302+
// Strip Scales
303+
case s@Scale(scale, res) =>
304+
val newScales = scales :+ ((r: Expr[Pre]) => Scale(scale, r)(s.o))
305+
body = res
306+
return unfoldBody(newConditions, newScales)
304307
case _ =>
305-
body = mainBody
308+
// Re-aply scales from right to left
309+
body = scales.foldRight(mainBody)((s, b) => s(b))
306310
return newConditions
307311
}
308312

@@ -316,7 +320,7 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]()
316320

317321
body = secondBody
318322

319-
unfoldBody(newConditions)
323+
unfoldBody(newConditions, scales)
320324
}
321325

322326
def containsOtherBinders(e: Expr[Pre]): Boolean = {
@@ -421,18 +425,6 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]()
421425
}
422426
}
423427

424-
def testPairs[A](
425-
xs: Iterable[A],
426-
ys: Iterable[A],
427-
f: (A, A) => Boolean,
428-
): Boolean = {
429-
for (x <- xs)
430-
for (y <- ys)
431-
if (f(x, y))
432-
return true
433-
false
434-
}
435-
436428
/** We check if there now any binding variables which resolve to just a
437429
* single value, which happens if it has equal lower and upper bounds. E.g.
438430
* forall(int i,j; i == 0 && i <= j && j < 5; xs[j+i]) ==> forall(int j; 0
@@ -870,7 +862,6 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]()
870862
* additionally add base_{i-1} / a_{i-1} < n_{i-1} (derived from (x_{i-1}
871863
* < xmin_i + n_{i-1})
872864
*/
873-
// TODO ABOVE
874865
def check_vars_list(
875866
vars: List[Variable[Pre]]
876867
): Option[SubstituteForall] = {
@@ -989,7 +980,6 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]()
989980
Seq(PointerSubscript(newGen(arrayIndex.array), xNewVar)(
990981
triggerBlame
991982
))
992-
// Seq(PointerAdd(newGen(arrayIndex.array), xNewVar)(triggerBlame))
993983
)
994984
}
995985

0 commit comments

Comments
 (0)