Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise an error when postconditions of pure functions contain old() expressions #1474

Merged
merged 7 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use prusti_contracts::*;

struct MyWrapper(u32);

impl MyWrapper {
#[pure]
#[ensures(old(self.0) == self.0)]
fn unwrap(&self) -> u32 { //~ ERROR old expressions should not appear in the postconditions of pure functions
self.0
}
}

fn test(x: &MyWrapper) -> u32 {
// Following error is due to stub encoding of invalid spec for function `unwrap()`
x.unwrap() //~ ERROR precondition of pure function call might not hold
}

fn main() { }
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use prusti_contracts::*;

#[extern_spec]
impl<T> std::option::Option<T> {
#[pure] // <=== Error triggered by this
#[requires(self.is_some())]
#[ensures(old(self) === Some(result))]
pub fn unwrap(self) -> T; //~ ERROR old expressions should not appear in the postconditions of pure functions

#[pure]
#[ensures(result == matches!(self, Some(_)))]
pub const fn is_some(&self) -> bool;
}

#[pure]
#[requires(x.is_some())]
fn test(x: Option<i32>) -> i32 {
// Following error is due to stub encoding of invalid external spec for function `unwrap()`
x.unwrap() //~ ERROR precondition of pure function call might not hold
}

fn main() { }
46 changes: 46 additions & 0 deletions prusti-viper/src/encoder/interface.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use crate::encoder::{
errors::{SpannedEncodingResult, WithSpan},
snapshot::interface::SnapshotEncoderInterface,
Encoder,
};

use prusti_rustc_interface::{
middle::{mir, ty, ty::Binder},
span::Span,
};

use vir_crate::polymorphic as vir_poly;

pub(crate) trait PureFunctionFormalArgsEncoderInterface<'p, 'v: 'p, 'tcx: 'v> {
fn encoder(&self) -> &'p Encoder<'v, 'tcx>;

fn check_type(
&self,
var_span: Span,
ty: Binder<'tcx, ty::Ty<'tcx>>,
) -> SpannedEncodingResult<()>;

fn get_span(&self, local: mir::Local) -> Span;

fn encode_formal_args(
&self,
sig: ty::PolyFnSig<'tcx>,
) -> SpannedEncodingResult<Vec<vir_poly::LocalVar>> {
let mut formal_args = vec![];
for local_idx in 0..sig.skip_binder().inputs().len() {
let local_ty = sig.input(local_idx);
let local = mir::Local::from_usize(local_idx + 1);
let var_name = format!("{local:?}");
let var_span = self.get_span(local);

self.check_type(var_span, local_ty)?;

let var_type = self
.encoder()
.encode_snapshot_type(local_ty.skip_binder())
.with_span(var_span)?;
formal_args.push(vir_poly::LocalVar::new(var_name, var_type))
}
Ok(formal_args)
}
}
75 changes: 43 additions & 32 deletions prusti-viper/src/encoder/mir/pure/pure_functions/encoder_poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
use crate::encoder::{
errors::{ErrorCtxt, SpannedEncodingError, SpannedEncodingResult, WithSpan},
high::{generics::HighGenericsEncoderInterface, types::HighTypeEncoderInterface},
interface::PureFunctionFormalArgsEncoderInterface,
mir::{
contracts::{ContractsEncoderInterface, ProcedureContract},
pure::{
Expand Down Expand Up @@ -50,7 +51,7 @@ pub(super) struct PureFunctionEncoder<'p, 'v: 'p, 'tcx: 'v> {
/// Span of the function declaration.
span: Span,
/// Signature of the function to be encoded.
sig: ty::PolyFnSig<'tcx>,
pub(crate) sig: ty::PolyFnSig<'tcx>,
/// Spans of MIR locals, when encoding a local pure function.
local_spans: Option<Vec<Span>>,
}
Expand Down Expand Up @@ -137,6 +138,38 @@ fn encode_mir<'p, 'v: 'p, 'tcx: 'v>(
Ok(body_expr)
}

impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionFormalArgsEncoderInterface<'p, 'v, 'tcx>
for PureFunctionEncoder<'p, 'v, 'tcx>
{
fn encoder(&self) -> &'p Encoder<'v, 'tcx> {
self.encoder
}

fn check_type(
&self,
var_span: Span,
ty: ty::Binder<'tcx, ty::Ty<'tcx>>,
) -> SpannedEncodingResult<()> {
if !self
.encoder
.env()
.query
.type_is_copy(ty, self.parent_def_id)
{
Err(SpannedEncodingError::incorrect(
"pure function parameters must be Copy",
var_span,
))
} else {
Ok(())
}
}

fn get_span(&self, local: mir::Local) -> Span {
self.get_local_span(local)
}
}

impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> {
#[tracing::instrument(
name = "PureFunctionEncoder::new",
Expand Down Expand Up @@ -314,7 +347,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> {
let mut precondition = vec![type_precondition, func_precondition];
let mut postcondition = vec![self.encode_postcondition_expr(&contract)?];

let formal_args = self.encode_formal_args()?;
let formal_args = self.encode_formal_args(self.sig)?;
let return_type = self.encode_function_return_type()?;

let res_value_range_pos = self.encoder.error_manager().register_error(
Expand Down Expand Up @@ -545,6 +578,13 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> {
.replace_place(&encoded_return.into(), &pure_fn_return_variable.into())
.set_default_pos(postcondition_pos);

if post.has_old_expression() {
return Err(SpannedEncodingError::incorrect(
"old expressions should not appear in the postconditions of pure functions",
self.span,
));
}

Ok(post)
}

Expand Down Expand Up @@ -620,40 +660,11 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> {
.with_span(self.span)
}

fn encode_formal_args(&self) -> SpannedEncodingResult<Vec<vir::LocalVar>> {
let mut formal_args = vec![];
for local_idx in 0..self.sig.skip_binder().inputs().len() {
let local_ty = self.sig.input(local_idx);
let local = prusti_rustc_interface::middle::mir::Local::from_usize(local_idx + 1);
let var_name = format!("{local:?}");
let var_span = self.get_local_span(local);

if !self
.encoder
.env()
.query
.type_is_copy(local_ty, self.parent_def_id)
{
return Err(SpannedEncodingError::incorrect(
"pure function parameters must be Copy",
var_span,
));
}

let var_type = self
.encoder
.encode_snapshot_type(local_ty.skip_binder())
.with_span(var_span)?;
formal_args.push(vir::LocalVar::new(var_name, var_type))
}
Ok(formal_args)
}

pub fn encode_function_call_info(&self) -> SpannedEncodingResult<FunctionCallInfo> {
Ok(FunctionCallInfo {
name: self.encode_function_name(),
type_arguments: self.encode_type_arguments()?,
formal_args: self.encode_formal_args()?,
formal_args: self.encode_formal_args(self.sig)?,
return_type: self.encode_function_return_type()?,
})
}
Expand Down
35 changes: 26 additions & 9 deletions prusti-viper/src/encoder/mir/pure/pure_functions/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,10 +331,11 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx>
substs,
);

let is_bodyless = self.is_trusted(proc_def_id, Some(substs))
|| !self.env().query.has_body(proc_def_id);

let maybe_identifier: SpannedEncodingResult<vir_poly::FunctionIdentifier> = (|| {
let proc_kind = self.get_proc_kind(proc_def_id, Some(substs));
let is_bodyless = self.is_trusted(proc_def_id, Some(substs))
|| !self.env().query.has_body(proc_def_id);
let mut function = if is_bodyless {
pure_function_encoder.encode_bodyless_function()?
} else {
Expand Down Expand Up @@ -393,13 +394,29 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx>
Err(error) => {
self.register_encoding_error(error);
debug!("Error encoding pure function: {:?}", proc_def_id);
let body = self
.env()
.body
.get_pure_fn_body(proc_def_id, substs, parent_def_id);
// TODO(tymap): does stub encoder need substs?
let stub_encoder = StubFunctionEncoder::new(self, proc_def_id, &body, substs);
let function = stub_encoder.encode_function()?;
let function = if !is_bodyless {
let pure_fn_body =
self.env()
.body
.get_pure_fn_body(proc_def_id, substs, parent_def_id);
let encoder = StubFunctionEncoder::new(
self,
proc_def_id,
Some(&pure_fn_body),
substs,
pure_function_encoder.sig,
);
encoder.encode_function()?
} else {
let encoder = StubFunctionEncoder::new(
self,
proc_def_id,
None,
substs,
pure_function_encoder.sig,
);
encoder.encode_function()?
};
self.log_vir_program_before_viper(function.to_string());
let identifier = self.insert_function(function);
self.pure_function_encoder_state
Expand Down
1 change: 1 addition & 0 deletions prusti-viper/src/encoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod encoder;
mod errors;
mod foldunfold;
mod initialisation;
mod interface;
mod loop_encoder;
mod mir_encoder;
mod mir_successor;
Expand Down
67 changes: 42 additions & 25 deletions prusti-viper/src/encoder/stub_function_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,64 +7,82 @@
use crate::encoder::{
errors::{SpannedEncodingResult, WithSpan},
high::generics::HighGenericsEncoderInterface,
mir_encoder::{MirEncoder, PlaceEncoder},
interface::PureFunctionFormalArgsEncoderInterface,
snapshot::interface::SnapshotEncoderInterface,
Encoder,
};
use log::debug;
use prusti_rustc_interface::{
hir::def_id::DefId,
middle::{mir, ty::GenericArgsRef},
middle::{
mir, ty,
ty::{Binder, GenericArgsRef},
},
span::Span,
};
use vir_crate::polymorphic as vir;

use super::mir::specifications::SpecificationsInterface;

pub struct StubFunctionEncoder<'p, 'v: 'p, 'tcx: 'v> {
encoder: &'p Encoder<'v, 'tcx>,
mir: &'p mir::Body<'tcx>,
mir_encoder: MirEncoder<'p, 'v, 'tcx>,
mir: Option<&'p mir::Body<'tcx>>,
proc_def_id: DefId,
substs: GenericArgsRef<'tcx>,
sig: ty::PolyFnSig<'tcx>,
}

impl<'p, 'v, 'tcx> PureFunctionFormalArgsEncoderInterface<'p, 'v, 'tcx>
for StubFunctionEncoder<'p, 'v, 'tcx>
{
fn check_type(&self, _span: Span, _ty: Binder<ty::Ty<'tcx>>) -> SpannedEncodingResult<()> {
Ok(())
}

fn encoder(&self) -> &'p Encoder<'v, 'tcx> {
self.encoder
}

fn get_span(&self, _local: mir::Local) -> Span {
self.encoder.get_spec_span(self.proc_def_id)
}
}

impl<'p, 'v: 'p, 'tcx: 'v> StubFunctionEncoder<'p, 'v, 'tcx> {
#[tracing::instrument(name = "StubFunctionEncoder::new", level = "trace", skip(encoder, mir))]
pub fn new(
encoder: &'p Encoder<'v, 'tcx>,
proc_def_id: DefId,
mir: &'p mir::Body<'tcx>,
mir: Option<&'p mir::Body<'tcx>>,
substs: GenericArgsRef<'tcx>,
sig: ty::PolyFnSig<'tcx>,
) -> Self {
StubFunctionEncoder {
encoder,
mir,
mir_encoder: MirEncoder::new(encoder, mir, proc_def_id),
proc_def_id,
substs,
sig,
}
}

fn default_span(&self) -> Span {
self.mir
.map(|m| m.span)
.unwrap_or_else(|| self.encoder.get_spec_span(self.proc_def_id))
}

#[tracing::instrument(level = "debug", skip(self))]
pub fn encode_function(&self) -> SpannedEncodingResult<vir::Function> {
let function_name = self.encode_function_name();
debug!("Encode stub function {}", function_name);

let formal_args: Vec<_> = self
.mir
.args_iter()
.map(|local| {
let var_name = self.mir_encoder.encode_local_var_name(local);
let mir_type = self.mir_encoder.get_local_ty(local);
self.encoder
.encode_snapshot_type(mir_type)
.map(|var_type| vir::LocalVar::new(var_name, var_type))
})
.collect::<Result<_, _>>()
.with_span(self.mir.span)?;
let formal_args = self.encode_formal_args(self.sig)?;

let type_arguments = self
.encoder
.encode_generic_arguments(self.proc_def_id, self.substs)
.with_span(self.mir.span)?;
.with_span(self.default_span())?;

let return_type = self.encode_function_return_type()?;

Expand All @@ -74,8 +92,6 @@ impl<'p, 'v: 'p, 'tcx: 'v> StubFunctionEncoder<'p, 'v, 'tcx> {
formal_args,
return_type,
pres: vec![false.into()],
// Note: Silicon is currently unsound when declaring a function that ensures `false`
// See: https://github.com/viperproject/silicon/issues/376
posts: vec![],
body: None,
};
Expand All @@ -94,9 +110,10 @@ impl<'p, 'v: 'p, 'tcx: 'v> StubFunctionEncoder<'p, 'v, 'tcx> {
}

pub fn encode_function_return_type(&self) -> SpannedEncodingResult<vir::Type> {
let ty = self.mir.return_ty();
let return_local = mir::Place::return_place().as_local().unwrap();
let span = self.mir_encoder.get_local_span(return_local);
self.encoder.encode_snapshot_type(ty).with_span(span)
let ty = self.sig.output();

self.encoder
.encode_snapshot_type(ty.skip_binder())
.with_span(self.encoder.get_spec_span(self.proc_def_id))
}
}
Loading
Loading