diff --git a/prusti-common/src/vir/to_viper.rs b/prusti-common/src/vir/to_viper.rs index 6436efbfeb8..ecf3e9d5f7e 100644 --- a/prusti-common/src/vir/to_viper.rs +++ b/prusti-common/src/vir/to_viper.rs @@ -174,6 +174,9 @@ impl<'v> ToViper<'v, viper::Stmt<'v>> for Stmt { assert!(!pos.is_default(), "stmt with default pos: {self}"); ast.exhale(expr.to_viper(context, ast), pos.to_viper(context, ast)) } + Stmt::Assume(ref expr, ref pos) => { + ast.assume(expr.to_viper(context, ast), pos.to_viper(context, ast)) + } Stmt::Assert(ref expr, ref pos) => { ast.assert(expr.to_viper(context, ast), pos.to_viper(context, ast)) } diff --git a/prusti-viper/src/encoder/foldunfold/requirements.rs b/prusti-viper/src/encoder/foldunfold/requirements.rs index cd22aaba51c..5e63c7269ce 100644 --- a/prusti-viper/src/encoder/foldunfold/requirements.rs +++ b/prusti-viper/src/encoder/foldunfold/requirements.rs @@ -75,6 +75,10 @@ impl RequiredStmtPermissionsGetter for vir::Stmt { ref expr, ref position, }) + | &vir::Stmt::Assume(vir::Assume { + ref expr, + ref position, + }) | &vir::Stmt::Assert(vir::Assert { ref expr, ref position, diff --git a/prusti-viper/src/encoder/foldunfold/semantics.rs b/prusti-viper/src/encoder/foldunfold/semantics.rs index 85a92ba18a7..e33790aaddb 100644 --- a/prusti-viper/src/encoder/foldunfold/semantics.rs +++ b/prusti-viper/src/encoder/foldunfold/semantics.rs @@ -60,6 +60,7 @@ impl ApplyOnState for vir::Stmt { match self { &vir::Stmt::Comment(_) | &vir::Stmt::Label(_) + | &vir::Stmt::Assume(_) | &vir::Stmt::Assert(_) | &vir::Stmt::Refute(_) | &vir::Stmt::Obtain(_) => {} diff --git a/prusti-viper/src/encoder/procedure_encoder.rs b/prusti-viper/src/encoder/procedure_encoder.rs index aad32703d27..5cbac026b21 100644 --- a/prusti-viper/src/encoder/procedure_encoder.rs +++ b/prusti-viper/src/encoder/procedure_encoder.rs @@ -247,7 +247,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { } let assume_expr = self.encoder.encode_invariant(self.mir, bb, self.proc_def_id, cl_substs)?; - let assume_stmt = vir::Stmt::inhale(assume_expr, vir::Position::default()); + let assume_stmt = vir::Stmt::assume(assume_expr, vir::Position::default()); encoded_statements.push(assume_stmt); return Ok(true); diff --git a/viper-sys/build.rs b/viper-sys/build.rs index 01779ff4991..04665c0df51 100644 --- a/viper-sys/build.rs +++ b/viper-sys/build.rs @@ -180,6 +180,9 @@ fn main() { object_getter!(), method!("pretty", "(Lviper/silver/ast/Node;)Ljava/lang/String;") ]), + java_class!("viper.silver.ast.utility.ImpureAssumeRewriter", vec![ + method!("rewriteAssumes"), + ]), java_class!("viper.silver.ast.AbstractAssign$", vec![ object_getter!(), method!("apply"), @@ -226,6 +229,9 @@ fn main() { java_class!("viper.silver.ast.Assert", vec![ constructor!(), ]), + java_class!("viper.silver.ast.Assume", vec![ + constructor!(), + ]), java_class!("viper.silver.plugin.standard.refute.Refute", vec![ constructor!(), ]), diff --git a/viper/src/ast_factory/statement.rs b/viper/src/ast_factory/statement.rs index a2592f3539c..e6f43fe68fc 100644 --- a/viper/src/ast_factory/statement.rs +++ b/viper/src/ast_factory/statement.rs @@ -123,6 +123,16 @@ impl<'a> AstFactory<'a> { Stmt::new(obj) } + pub fn assume(&self, expr: Expr, pos: Position) -> Stmt<'a> { + let obj = self.jni.unwrap_result(ast::Assume::with(self.env).new( + expr.to_jobject(), + pos.to_jobject(), + self.no_info(), + self.no_trafos(), + )); + Stmt::new(obj) + } + pub fn assert(&self, expr: Expr, pos: Position) -> Stmt<'a> { let obj = self.jni.unwrap_result(ast::Assert::with(self.env).new( expr.to_jobject(), diff --git a/viper/src/ast_utils.rs b/viper/src/ast_utils.rs index 55120e17539..5afdc822429 100644 --- a/viper/src/ast_utils.rs +++ b/viper/src/ast_utils.rs @@ -33,6 +33,16 @@ impl<'a> AstUtils<'a> { .map(|java_vec| self.jni.seq_to_vec(java_vec)) } + pub(crate) fn rewrite_impure_assumes( + &self, + program: Program<'a>, + ) -> Result, JavaException> { + self.jni + .unwrap_or_exception( + silver::ast::utility::ImpureAssumeRewriter::with(self.env).call_rewriteAssumes(program.to_jobject())) + .map(|rewritten_prg_obj| Program::new(rewritten_prg_obj)) + } + #[tracing::instrument(level = "debug", skip_all)] pub fn pretty_print(&self, program: Program<'a>) -> String { let fast_pretty_printer_wrapper = diff --git a/viper/src/verifier.rs b/viper/src/verifier.rs index ea39719ee72..23bfa3b299b 100644 --- a/viper/src/verifier.rs +++ b/viper/src/verifier.rs @@ -173,6 +173,14 @@ impl<'a> Verifier<'a> { self.ast_utils.pretty_print(program) ); + // verification backends rely on Silver first performing this rewrite of the Viper AST + let program = match self.ast_utils.rewrite_impure_assumes(program) { + Ok(prg) => prg, + Err(java_exception) => { + return VerificationResult::JavaException(java_exception); + } + }; + run_timed!("Viper consistency checks", debug, let consistency_errors = match self.ast_utils.check_consistency(program) { Ok(errors) => errors, @@ -194,7 +202,6 @@ impl<'a> Verifier<'a> { .collect(), ); } - let program_option = self.jni.new_option(Some(program.to_jobject())); self.jni.unwrap_result(self.frontend_wrapper.set___program(self.frontend_instance, program_option)); diff --git a/vir/defs/polymorphic/ast/stmt.rs b/vir/defs/polymorphic/ast/stmt.rs index 4aa51d3463f..e8adb068b8f 100644 --- a/vir/defs/polymorphic/ast/stmt.rs +++ b/vir/defs/polymorphic/ast/stmt.rs @@ -20,6 +20,7 @@ pub enum Stmt { Label(Label), Inhale(Inhale), Exhale(Exhale), + Assume(Assume), Assert(Assert), Refute(Refute), /// MethodCall: method_name, args, targets @@ -74,6 +75,7 @@ impl fmt::Display for Stmt { Stmt::Label(label) => label.fmt(f), Stmt::Inhale(inhale) => inhale.fmt(f), Stmt::Exhale(exhale) => exhale.fmt(f), + Stmt::Assume(assume) => assume.fmt(f), Stmt::Assert(assert) => assert.fmt(f), Stmt::Refute(refute) => refute.fmt(f), Stmt::MethodCall(method_call) => method_call.fmt(f), @@ -161,13 +163,19 @@ impl fmt::Display for Exhale { } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -pub struct Assert { +pub struct Assume { pub expr: Expr, pub position: Position, } +impl fmt::Display for Assume { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "assume {}", self.expr) + } +} + #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -pub struct Refute { +pub struct Assert { pub expr: Expr, pub position: Position, } @@ -178,6 +186,12 @@ impl fmt::Display for Assert { } } +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct Refute { + pub expr: Expr, + pub position: Position, +} + impl fmt::Display for Refute { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "refute {}", self.expr) @@ -512,6 +526,10 @@ impl Stmt { Stmt::Exhale(Exhale { expr, position }) } + pub fn assume(expr: Expr, position: Position) -> Self { + Stmt::Assume(Assume { expr, position }) + } + pub fn package_magic_wand( lhs: Expr, rhs: Expr, @@ -598,6 +616,7 @@ pub trait StmtFolder { Stmt::Label(label) => self.fold_label(label), Stmt::Inhale(inhale) => self.fold_inhale(inhale), Stmt::Exhale(exhale) => self.fold_exhale(exhale), + Stmt::Assume(assume) => self.fold_assume(assume), Stmt::Assert(assert) => self.fold_assert(assert), Stmt::Refute(refute) => self.fold_refute(refute), Stmt::MethodCall(method_call) => self.fold_method_call(method_call), @@ -647,6 +666,15 @@ pub trait StmtFolder { }) } + fn fold_assume(&mut self, statement: Assume) -> Stmt { + let Assume { expr, position } = statement; + Stmt::Assume(Assume { + expr: self.fold_expr(expr), + position, + }) + } + + fn fold_assert(&mut self, statement: Assert) -> Stmt { let Assert { expr, position } = statement; Stmt::Assert(Assert { @@ -817,6 +845,7 @@ pub trait FallibleStmtFolder { Stmt::Label(label) => self.fallible_fold_label(label), Stmt::Inhale(inhale) => self.fallible_fold_inhale(inhale), Stmt::Exhale(exhale) => self.fallible_fold_exhale(exhale), + Stmt::Assume(assume) => self.fallible_fold_assume(assume), Stmt::Assert(assert) => self.fallible_fold_assert(assert), Stmt::Refute(refute) => self.fallible_fold_refute(refute), Stmt::MethodCall(method_call) => self.fallible_fold_method_call(method_call), @@ -870,6 +899,14 @@ pub trait FallibleStmtFolder { })) } + fn fallible_fold_assume(&mut self, statement: Assume) -> Result { + let Assume { expr, position } = statement; + Ok(Stmt::Assume(Assume { + expr: self.fallible_fold_expr(expr)?, + position, + })) + } + fn fallible_fold_assert(&mut self, statement: Assert) -> Result { let Assert { expr, position } = statement; Ok(Stmt::Assert(Assert { @@ -1069,6 +1106,7 @@ pub trait StmtWalker { Stmt::Label(label) => self.walk_label(label), Stmt::Inhale(inhale) => self.walk_inhale(inhale), Stmt::Exhale(exhale) => self.walk_exhale(exhale), + Stmt::Assume(assume) => self.walk_assume(assume), Stmt::Assert(assert) => self.walk_assert(assert), Stmt::Refute(refute) => self.walk_refute(refute), Stmt::MethodCall(method_call) => self.walk_method_call(method_call), @@ -1108,6 +1146,11 @@ pub trait StmtWalker { self.walk_expr(expr); } + fn walk_assume(&mut self, statement: &Assume) { + let Assume { expr, .. } = statement; + self.walk_expr(expr); + } + fn walk_assert(&mut self, statement: &Assert) { let Assert { expr, .. } = statement; self.walk_expr(expr); @@ -1228,6 +1271,7 @@ pub trait FallibleStmtWalker { Stmt::Label(label) => self.fallible_walk_label(label), Stmt::Inhale(inhale) => self.fallible_walk_inhale(inhale), Stmt::Exhale(exhale) => self.fallible_walk_exhale(exhale), + Stmt::Assume(assume) => self.fallible_walk_assume(assume), Stmt::Assert(assert) => self.fallible_walk_assert(assert), Stmt::Refute(refute) => self.fallible_walk_refute(refute), Stmt::MethodCall(method_call) => self.fallible_walk_method_call(method_call), @@ -1281,6 +1325,12 @@ pub trait FallibleStmtWalker { Ok(()) } + fn fallible_walk_assume(&mut self, statement: &Assume) -> Result<(), Self::Error> { + let Assume { expr, .. } = statement; + self.fallible_walk_expr(expr)?; + Ok(()) + } + fn fallible_walk_assert(&mut self, statement: &Assert) -> Result<(), Self::Error> { let Assert { expr, .. } = statement; self.fallible_walk_expr(expr)?; diff --git a/vir/src/converter/polymorphic_to_legacy.rs b/vir/src/converter/polymorphic_to_legacy.rs index 042321f698d..d3fde7f7e90 100644 --- a/vir/src/converter/polymorphic_to_legacy.rs +++ b/vir/src/converter/polymorphic_to_legacy.rs @@ -667,6 +667,9 @@ impl From for legacy::Stmt { polymorphic::Stmt::Exhale(exhale) => { legacy::Stmt::Exhale(exhale.expr.into(), exhale.position.into()) } + polymorphic::Stmt::Assume(assume) => { + legacy::Stmt::Assume(assume.expr.into(), assume.position.into()) + } polymorphic::Stmt::Assert(assert) => { legacy::Stmt::Assert(assert.expr.into(), assert.position.into()) } diff --git a/vir/src/converter/type_substitution.rs b/vir/src/converter/type_substitution.rs index 034b1f71574..22ffb30d29a 100644 --- a/vir/src/converter/type_substitution.rs +++ b/vir/src/converter/type_substitution.rs @@ -572,6 +572,7 @@ impl Generic for Stmt { Stmt::Label(label) => Stmt::Label(label.substitute(map)), Stmt::Inhale(inhale) => Stmt::Inhale(inhale.substitute(map)), Stmt::Exhale(exhale) => Stmt::Exhale(exhale.substitute(map)), + Stmt::Assume(assume) => Stmt::Assume(assume.substitute(map)), Stmt::Assert(assert) => Stmt::Assert(assert.substitute(map)), Stmt::Refute(refute) => Stmt::Refute(refute.substitute(map)), Stmt::MethodCall(method_call) => Stmt::MethodCall(method_call.substitute(map)), @@ -639,6 +640,14 @@ impl Generic for Exhale { } } +impl Generic for Assume { + fn substitute(self, map: &FxHashMap) -> Self { + let mut assume = self; + assume.expr = assume.expr.substitute(map); + assume + } +} + impl Generic for Assert { fn substitute(self, map: &FxHashMap) -> Self { let mut assert = self; diff --git a/vir/src/legacy/ast/stmt.rs b/vir/src/legacy/ast/stmt.rs index 9b53a3f9120..fa14516d451 100644 --- a/vir/src/legacy/ast/stmt.rs +++ b/vir/src/legacy/ast/stmt.rs @@ -22,6 +22,7 @@ pub enum Stmt { Label(String), Inhale(Expr, Position), Exhale(Expr, Position), + Assume(Expr, Position), Assert(Expr, Position), Refute(Expr, Position), /// MethodCall: method_name, args, targets @@ -84,6 +85,7 @@ impl Hash for Stmt { Stmt::Label(s) => s.hash(state), Stmt::Inhale(e, p) => (e, p).hash(state), Stmt::Exhale(e, p) => (e, p).hash(state), + Stmt::Assume(e, p) => (e, p).hash(state), Stmt::Assert(e, p) => (e, p).hash(state), Stmt::Refute(e, p) => (e, p).hash(state), Stmt::MethodCall(s, v1, v2) => (s, v1, v2).hash(state), @@ -132,6 +134,9 @@ impl fmt::Display for Stmt { write!(f, "inhale {expr}") } Stmt::Exhale(ref expr, _) => write!(f, "exhale {expr}"), + Stmt::Assume(ref expr, _) => { + write!(f, "assume {expr}") + } Stmt::Assert(ref expr, _) => { write!(f, "assert {expr}") } @@ -303,6 +308,7 @@ impl Stmt { match self { Stmt::Inhale(_, ref p) | Stmt::Exhale(_, ref p) + | Stmt::Assume(_, ref p) | Stmt::Assert(_, ref p) | Stmt::Refute(_, ref p) | Stmt::Fold(_, _, _, _, ref p) @@ -317,6 +323,7 @@ impl Stmt { match self { Stmt::Inhale(_, ref mut p) | Stmt::Exhale(_, ref mut p) + | Stmt::Assume(_, ref mut p) | Stmt::Assert(_, ref mut p) | Stmt::Refute(_, ref mut p) | Stmt::Fold(_, _, _, _, ref mut p) @@ -387,6 +394,7 @@ pub trait StmtFolder { Stmt::Label(s) => self.fold_label(s), Stmt::Inhale(expr, pos) => self.fold_inhale(expr, pos), Stmt::Exhale(e, p) => self.fold_exhale(e, p), + Stmt::Assume(expr, pos) => self.fold_assume(expr, pos), Stmt::Assert(expr, pos) => self.fold_assert(expr, pos), Stmt::Refute(expr, pos) => self.fold_refute(expr, pos), Stmt::MethodCall(s, ve, vv) => self.fold_method_call(s, ve, vv), @@ -425,6 +433,10 @@ pub trait StmtFolder { Stmt::Exhale(self.fold_expr(e), p) } + fn fold_assume(&mut self, expr: Expr, pos: Position) -> Stmt { + Stmt::Assume(self.fold_expr(expr), pos) + } + fn fold_assert(&mut self, expr: Expr, pos: Position) -> Stmt { Stmt::Assert(self.fold_expr(expr), pos) } @@ -540,6 +552,7 @@ pub trait FallibleStmtFolder { Stmt::Label(s) => self.fallible_fold_label(s), Stmt::Inhale(expr, pos) => self.fallible_fold_inhale(expr, pos), Stmt::Exhale(e, p) => self.fallible_fold_exhale(e, p), + Stmt::Assume(expr, pos) => self.fallible_fold_assume(expr, pos), Stmt::Assert(expr, pos) => self.fallible_fold_assert(expr, pos), Stmt::Refute(expr, pos) => self.fallible_fold_refute(expr, pos), Stmt::MethodCall(s, ve, vv) => self.fallible_fold_method_call(s, ve, vv), @@ -580,6 +593,10 @@ pub trait FallibleStmtFolder { Ok(Stmt::Exhale(self.fallible_fold_expr(e)?, p)) } + fn fallible_fold_assume(&mut self, expr: Expr, pos: Position) -> Result { + Ok(Stmt::Assume(self.fallible_fold_expr(expr)?, pos)) + } + fn fallible_fold_assert(&mut self, expr: Expr, pos: Position) -> Result { Ok(Stmt::Assert(self.fallible_fold_expr(expr)?, pos)) } @@ -737,6 +754,7 @@ pub trait StmtWalker { Stmt::Label(s) => self.walk_label(s), Stmt::Inhale(expr, pos) => self.walk_inhale(expr, pos), Stmt::Exhale(e, p) => self.walk_exhale(e, p), + Stmt::Assume(expr, pos) => self.walk_assume(expr, pos), Stmt::Assert(expr, pos) => self.walk_assert(expr, pos), Stmt::Refute(expr, pos) => self.walk_refute(expr, pos), Stmt::MethodCall(s, ve, vv) => self.walk_method_call(s, ve, vv), @@ -771,6 +789,10 @@ pub trait StmtWalker { self.walk_expr(expr); } + fn walk_assume(&mut self, expr: &Expr, _pos: &Position) { + self.walk_expr(expr); + } + fn walk_assert(&mut self, expr: &Expr, _pos: &Position) { self.walk_expr(expr); }