diff --git a/prusti-common/src/vir/optimizations/methods/quantifier_fixer.rs b/prusti-common/src/vir/optimizations/methods/quantifier_fixer.rs index ca07e056406..9a7d3399732 100644 --- a/prusti-common/src/vir/optimizations/methods/quantifier_fixer.rs +++ b/prusti-common/src/vir/optimizations/methods/quantifier_fixer.rs @@ -13,8 +13,8 @@ use std::{collections::HashMap, mem}; /// /// 1. Replace all `old(...)` inside `forall ..` with `let tmp == (old(..)) in forall ..`. /// 2. Pull out all `unfolding ... in` that are inside `forall` to outside of `forall`. -/// 3. Replace all arithmetic expressions inside `forall` that do not depend on bound variables -/// with `let tmp == (...) in forall ..`. +/// 3. Replace all arithmetic and conditional expressions inside `forall` that +/// do not depend on bound variables with `let tmp == (...) in forall ..`. /// /// Note: this seems to be required to workaround some Silicon incompleteness. pub fn fix_quantifiers(cfg: vir::CfgMethod) -> vir::CfgMethod { @@ -191,6 +191,46 @@ impl<'a> vir::ExprFolder for Replacer<'a> { original_expr } } + + fn fold_cond( + &mut self, + vir::Cond { + guard, + then_expr, + else_expr, + position, + }: vir::Cond, + ) -> vir::Expr { + let contains_bounded = self + .bound_vars + .iter() + .any(|v| guard.find(v) || then_expr.find(v) || else_expr.find(v)); + if contains_bounded { + // Do not extract conditional branches into let-vars: it's possible that + // the "then" branch is well-defined only when `guard` is true, or + // vice-versa (i.e the "else" branch is only defined when `guard` is + // false). For example, the expression: + // `x >= 0 ? sqrt(x) + 1 : 1` + // is well-defined, but + // `let (t == sqrt(x) + 1) in x >= 0 ? t :1` + // is not. (assuming sqrt(x) is defined only for x >= 0) + vir::Expr::Cond(vir::Cond { + guard: self.fold_boxed(guard), + then_expr, + else_expr, + position, + }) + } else { + let original_expr = vir::Expr::Cond(vir::Cond { + guard, + then_expr, + else_expr, + position, + }); + self.replace_expr(original_expr, position) + } + } + fn fold_bin_op( &mut self, vir::BinOp { diff --git a/prusti-tests/tests/verify_overflow/pass/quantifiers/conditionals.rs b/prusti-tests/tests/verify_overflow/pass/quantifiers/conditionals.rs new file mode 100644 index 00000000000..2de479ba4cd --- /dev/null +++ b/prusti-tests/tests/verify_overflow/pass/quantifiers/conditionals.rs @@ -0,0 +1,40 @@ +use prusti_contracts::*; + +struct AccountID(u32); + +struct Bank { + +} + +impl Bank { + + #[pure] + #[trusted] + fn balance_of(&self, acct_id: &AccountID) -> u32 { + 0 + } + + #[trusted] + #[requires(amt >= 0 ==> u32::MAX - self.balance_of(acct_id) >= (amt as u32))] + #[ensures( + forall(|acct_id2: &AccountID| + self.balance_of(acct_id2) == + if(acct_id === acct_id2 && amt >= 0) { + old(self.balance_of(acct_id)) + (amt as u32) + } else { + 0 + } + ) + )] + fn adjust_amount(&mut self, acct_id: &AccountID, amt: i32) { + } + +} + + +#[requires(amt < 0 && bank.balance_of(to) >= (0 - amt) as u32)] +fn go(bank: &mut Bank, to: &AccountID, amt: i32) { + bank.adjust_amount(to, amt); +} + +pub fn main(){}