diff --git a/prusti-interface/src/environment/mod.rs b/prusti-interface/src/environment/mod.rs index f0cf90c3349..e5c90a96987 100644 --- a/prusti-interface/src/environment/mod.rs +++ b/prusti-interface/src/environment/mod.rs @@ -9,7 +9,7 @@ use prusti_rustc_interface::middle::mir; use prusti_rustc_interface::hir::hir_id::HirId; use prusti_rustc_interface::hir::def_id::{DefId, LocalDefId}; -use prusti_rustc_interface::middle::ty::{self, Binder, BoundConstness, ImplPolarity, TraitPredicate, TraitRef, TyCtxt}; +use prusti_rustc_interface::middle::ty::{self, Binder, BoundConstness, ImplPolarity, TraitPredicate, TraitRef, TyCtxt, TypeSuperFoldable}; use prusti_rustc_interface::middle::ty::subst::{Subst, SubstsRef}; use prusti_rustc_interface::trait_selection::infer::{InferCtxtExt, TyCtxtInferExt}; use prusti_rustc_interface::trait_selection::traits::{ImplSource, Obligation, ObligationCause, SelectionContext}; @@ -55,6 +55,11 @@ struct CachedBody<'tcx> { monomorphised_bodies: HashMap, Rc>>, /// Cached borrowck information. borrowck_facts: Rc, + /// Copies of the MIR body with the given substs applied, called from the + /// given caller. This also allows for associated types to be correctly + /// normalised. + /// TODO: merge more nicely with monomorphised_bodies? + monomorphised_bodies_with_caller: HashMap<(SubstsRef<'tcx>, LocalDefId), Rc>>, } struct CachedExternalBody<'tcx> { @@ -303,12 +308,56 @@ impl<'tcx> Environment<'tcx> { base_body: Rc::new(body), monomorphised_bodies: HashMap::new(), borrowck_facts: Rc::new(facts), + monomorphised_bodies_with_caller: HashMap::new(), } }); body .monomorphised_bodies .entry(substs) - .or_insert_with(|| ty::EarlyBinder(body.base_body.clone()).subst(self.tcx, substs)) + .or_insert_with(|| { + let param_env = self.tcx.param_env(def_id); + let substituted = ty::EarlyBinder(body.base_body.clone()).subst(self.tcx, substs); + self.resolve_assoc_types(substituted.clone(), param_env) + }) + .clone() + } + + pub fn local_mir_with_caller( + &self, + def_id: LocalDefId, + caller_def_id: LocalDefId, + substs: SubstsRef<'tcx>, + ) -> Rc> { + // TODO: duplication ... + let mut bodies = self.bodies.borrow_mut(); + let body = bodies.entry(def_id) + .or_insert_with(|| { + // SAFETY: This is safe because we are feeding in the same `tcx` + // that was used to store the data. + let body_with_facts = unsafe { + self::mir_storage::retrieve_mir_body(self.tcx, def_id) + }; + let body = body_with_facts.body; + let facts = BorrowckFacts { + input_facts: RefCell::new(Some(body_with_facts.input_facts)), + output_facts: body_with_facts.output_facts, + location_table: RefCell::new(Some(body_with_facts.location_table)), + }; + + CachedBody { + base_body: Rc::new(body), + monomorphised_bodies: HashMap::new(), + borrowck_facts: Rc::new(facts), + monomorphised_bodies_with_caller: HashMap::new(), + } + }); + body + .monomorphised_bodies_with_caller + .entry((substs, caller_def_id)) + .or_insert_with(|| { + let param_env = self.tcx.param_env(caller_def_id); + self.tcx.subst_and_normalize_erasing_regions(substs, param_env, body.base_body.clone()) + }) .clone() } @@ -538,10 +587,35 @@ impl<'tcx> Environment<'tcx> { // Normalize the type to account for associated types let ty = self.resolve_assoc_types(ty, param_env); let ty = self.tcx.erase_late_bound_regions(ty); + + // Associated type was not normalised but it might still have a + // `Copy` bound declared on it. + // TODO: implement this more generally in `type_implements_trait` and + // `type_implements_trait_with_trait_substs`. + if let ty::TyKind::Projection(proj) = ty.kind() { + let item_bounds = self.tcx.bound_item_bounds(proj.item_def_id); + if item_bounds.0.iter().any(|predicate| { + if let ty::PredicateKind::Trait(ty::TraitPredicate { + trait_ref: ty::TraitRef { def_id: trait_def_id, .. }, + polarity: ty::ImplPolarity::Positive, + .. + }) = predicate.kind().skip_binder() { + trait_def_id == self.tcx.lang_items() + .copy_trait() + .unwrap() + } else { + false + } + }) { + return true; + } + } + ty.is_copy_modulo_regions(self.tcx.at(prusti_rustc_interface::span::DUMMY_SP), param_env) } /// Checks whether the given type implements the trait with the given DefId. + /// The trait should have no type parameters. pub fn type_implements_trait(&self, ty: ty::Ty<'tcx>, trait_def_id: DefId, param_env: ty::ParamEnv<'tcx>) -> bool { self.type_implements_trait_with_trait_substs(ty, trait_def_id, ty::List::empty(), param_env) } @@ -583,22 +657,51 @@ impl<'tcx> Environment<'tcx> { /// Normalizes associated types in foldable types, /// i.e. this resolves projection types ([ty::TyKind::Projection]s) - /// **Important:** Regions while be erased during this process - pub fn resolve_assoc_types + std::fmt::Debug + Copy>(&self, normalizable: T, param_env: ty::ParamEnv<'tcx>) -> T { - let norm_res = self.tcx.try_normalize_erasing_regions( - param_env, - normalizable - ); + pub fn resolve_assoc_types + std::fmt::Debug>(&self, normalizable: T, param_env: ty::ParamEnv<'tcx>) -> T { + struct Normalizer<'a, 'tcx> { + tcx: &'a ty::TyCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, + } + impl<'a, 'tcx> ty::fold::TypeFolder<'tcx> for Normalizer<'a, 'tcx> { + fn tcx(&self) -> ty::TyCtxt<'tcx> { + *self.tcx + } - match norm_res { - Ok(normalized) => { - debug!("Normalized {:?}: {:?}", normalizable, normalized); - normalized - }, - Err(err) => { - debug!("Error while resolving associated types for {:?}: {:?}", normalizable, err); - normalizable + fn fold_mir_const(&mut self, c: mir::ConstantKind<'tcx>) -> mir::ConstantKind<'tcx> { + // rustc by default panics when we execute this TypeFolder on a mir::* type, + // but we want to resolve associated types when we retrieve a local mir::Body + c + } + + fn fold_ty(&mut self, ty: ty::Ty<'tcx>) -> ty::Ty<'tcx> { + match ty.kind() { + ty::TyKind::Projection(_) => { + let normalized = self.tcx.infer_ctxt().enter(|infcx| { + use prusti_rustc_interface::trait_selection::traits::{fully_normalize, FulfillmentContext}; + + let normalization_result = fully_normalize( + &infcx, + FulfillmentContext::new(), + ObligationCause::dummy(), + self.param_env, + ty + ); + + match normalization_result { + Ok(res) => res, + Err(errors) => { + debug!("Error while resolving associated types: {:?}", errors); + ty + } + } + }); + normalized.super_fold_with(self) + } + _ => ty.super_fold_with(self) + } } } + + normalizable.fold_with(&mut Normalizer {tcx: &self.tcx, param_env}) } } diff --git a/prusti-interface/src/environment/procedure.rs b/prusti-interface/src/environment/procedure.rs index 6fb94729264..e4df7acba13 100644 --- a/prusti-interface/src/environment/procedure.rs +++ b/prusti-interface/src/environment/procedure.rs @@ -40,7 +40,7 @@ impl<'tcx> Procedure<'tcx> { trace!("Encoding procedure {:?}", proc_def_id); let tcx = env.tcx(); // TOOD(tymap): add substs to procedure? check usages - let mir = env.local_mir(proc_def_id.expect_local(), env.identity_substs(proc_def_id)); + let mir = env.local_mir_with_caller(proc_def_id.expect_local(), proc_def_id.expect_local(), env.identity_substs(proc_def_id)); let real_edges = RealEdges::new(&mir); let reachable_basic_blocks = build_reachable_basic_blocks(&mir, &real_edges); let nonspec_basic_blocks = build_nonspec_basic_blocks(&mir, &real_edges, &tcx); diff --git a/prusti-tests/tests/verify_overflow/pass/generic/associated-copy.rs b/prusti-tests/tests/verify_overflow/pass/generic/associated-copy.rs new file mode 100644 index 00000000000..163f2ad139b --- /dev/null +++ b/prusti-tests/tests/verify_overflow/pass/generic/associated-copy.rs @@ -0,0 +1,15 @@ +use prusti_contracts::*; + +trait Trait { + type Assoc: Copy; + + #[pure] fn output_copy(&self) -> Self::Assoc; + #[pure] fn input_copy(&self, x: Self::Assoc) -> bool; +} + +#[requires(x.output_copy() === y)] +#[requires(x.input_copy(z))] +fn test>(x: &mut T, y: U, z: U) {} + +#[trusted] +fn main() {} diff --git a/prusti-tests/tests/verify_overflow/pass/generic/resolution.rs b/prusti-tests/tests/verify_overflow/pass/generic/resolution.rs new file mode 100644 index 00000000000..befff4df0ce --- /dev/null +++ b/prusti-tests/tests/verify_overflow/pass/generic/resolution.rs @@ -0,0 +1,231 @@ +use prusti_contracts::*; + +// Test generics and type parameter resolution at call sites. There are various +// aspects/features that can complicate a method call resolution: +// +// 1. type parameters in the signature of the method +// 2. lifetime parameters in the signature of the method +// 3. type parameters in the impl block containing the method (= its struct) +// 4. lifetime parameters in the impl block containing the method (= its struct) +// 5. the method belongs to a trait +// 6. associated types +// +// Test also calls from/to pure functions. + +type TupleIntInt = (i32, i32); +type TupleIntUsize = (i32, usize); +type TupleUsizeInt = (usize, i32); + +trait Valid1 { #[pure] fn valid1(&self) -> bool; } +trait Valid2 { #[pure] fn valid2(&self) -> bool; } + +#[refine_trait_spec] impl Valid1 for i32 { + #[pure] fn valid1(&self) -> bool { + *self == 3 + } +} +#[refine_trait_spec] impl Valid2 for i32 { + #[pure] fn valid2(&self) -> bool { + *self == 7 + } +} + +#[refine_trait_spec] impl Valid1 for TupleIntInt { + #[pure] fn valid1(&self) -> bool { + let valid = (8, 9); + *self == valid + } +} +#[refine_trait_spec] impl Valid2 for TupleIntInt { + #[pure] fn valid2(&self) -> bool { + let valid = (4, 5); + *self == valid + } +} + +#[refine_trait_spec] impl Valid1 for TupleIntUsize { + #[pure] fn valid1(&self) -> bool { + let valid = (42, 43); + *self == valid + } +} +#[refine_trait_spec] impl Valid1 for TupleUsizeInt { + #[pure] fn valid1(&self) -> bool { + let valid = (33, 44); + *self == valid + } +} + +#[refine_trait_spec] impl Valid1 for bool { + #[pure] fn valid1(&self) -> bool { + *self + } +} +#[refine_trait_spec] impl Valid2 for bool { + #[pure] fn valid2(&self) -> bool { + !*self + } +} + +#[trusted] +#[requires(a.valid1() && b.valid2())] +#[ensures(result.valid1())] +fn fn1(a: A, b: B) -> C { unimplemented!() } + +#[trusted] +#[pure] +#[requires(a.valid1() && b.valid2())] +#[ensures(result.valid1())] +fn pure_fn1(a: A, b: B) -> C { unimplemented!() } + +#[requires( + pure_fn1::(3, (4, 5)) +)] +fn test_fn1() { + assert!(fn1::(3, (4, 5))); + assert!(pure_fn1::(3, (4, 5))); +} + +#[trusted] +#[requires(a.valid2() && b.valid1())] +#[ensures(result.valid2())] +fn fn2<'a, 'b, A: Valid2, B: Valid1, C: Valid2>(a: &'a A, b: &'b B) -> C { unimplemented!() } + +#[trusted] +#[pure] +#[requires(a.valid2() && b.valid1())] +#[ensures(result.valid2())] +fn pure_fn2<'a, 'b, A: Valid2, B: Valid1, C: Valid2 + Copy>(a: &'a A, b: &'b B) -> C { unimplemented!() } + +// TODO: fold/unfold error when the precondition is enabled +/* +#[requires({ + let a = 7; + let b = (8, 9); + pure_fn2::(&a, &b); +})] +*/ +fn test_fn2() { + let a = 7; + let b = (8, 9); + assert!(!fn2::(&a, &b)); + assert!(!pure_fn2::(&a, &b)); +} + +struct X1(A, B); +impl X1 { + #[trusted] + #[requires(a.valid1() && b.valid2())] + #[ensures(result.valid1())] + fn fn3<'a, 'b, C: Valid1>(&self, a: &'a A, b: &'b B) -> C { unimplemented!() } + + #[trusted] + #[pure] + #[requires(a.valid1() && b.valid2())] + #[ensures(result.valid1())] + fn pure_fn3<'a, 'b, C: Valid1 + Copy>(&self, a: &'a A, b: &'b B) -> C { unimplemented!() } +} + +fn test_fn3() { + let a = 3; + let b = (4, 5); + let x = X1::(0, (0, 0)); + assert!(x.fn3::(&a, &b)); + assert!(x.pure_fn3::(&a, &b)); +} + +// Using `&'a A` or `&'b B` directly in `X2` fails because Prusti tries to +// encode the reference-typed fields `X2.0` resp. `X2.1`. +struct X2<'a, 'b, A, B>(std::marker::PhantomData<&'a A>, std::marker::PhantomData<&'b B>); +impl<'a, 'b, A: Valid2, B: Valid1> X2<'a, 'b, A, B> { + #[trusted] + #[requires(a.valid2() && b.valid1())] + #[ensures(result.valid2())] + fn fn4(&self, a: &'a A, b: &'b B) -> C { unimplemented!() } + + #[trusted] + #[pure] + #[requires(a.valid2() && b.valid1())] + #[ensures(result.valid2())] + fn pure_fn4(&self, a: &'a A, b: &'b B) -> C { unimplemented!() } +} + +fn test_fn4<'a, 'b>(x: X2<'a, 'b, i32, (i32, i32)>) { + let a = 7; + let b = (8, 9); + assert!(!x.fn4::(&a, &b)); + assert!(!x.pure_fn4::(&a, &b)); +} + +trait T1 { + #[requires(a.valid1() && b.valid2())] + #[ensures(result.valid1())] + fn fn5(&self, a: A, b: B) -> C; + + #[pure] + #[requires(a.valid1() && b.valid2())] + #[ensures(result.valid1())] + fn pure_fn5(&self, a: A, b: B) -> C; +} + +struct X3 {} +impl T1 for X3 { + #[trusted] + fn fn5(&self, a: A, b: B) -> C { unimplemented!() } + + #[trusted] + #[pure] + fn pure_fn5(&self, a: A, b: B) -> C { unimplemented!() } +} + +fn test_fn5(t: T, x: X3) { + assert!(t.fn5::(3, (4, 5))); + assert!(t.pure_fn5::(3, (4, 5))); + assert!(x.fn5::(3, (4, 5))); + assert!(x.pure_fn5::(3, (4, 5))); +} + +trait T2 { + type AT1: Valid1 + Copy; + type AT2: Valid1 + Copy; + + // TODO: the c.valid1() && d.valid1() constraint causes a panic; + // we probably need to use the ParamEnv-respecting `local_mir` + // throughout the codebase + + #[requires(a.valid2() && b.valid1() + // && c.valid1() && d.valid1() + )] + #[ensures(result.valid2())] + fn fn6(&self, a: A, b: B, c: Self::AT1, d: Self::AT2) -> C; + + #[pure] + #[requires(a.valid2() && b.valid1() + // && c.valid1() && d.valid1() + )] + #[ensures(result.valid2())] + fn pure_fn6(&self, a: A, b: B, c: Self::AT1, d: Self::AT2) -> C; +} + +struct X4 {} +impl T2 for X4 { + type AT1 = (i32, usize); + type AT2 = (usize, i32); + + #[trusted] + fn fn6(&self, a: A, b: B, c: (i32, usize), d: (usize, i32)) -> C { unimplemented!() } + + #[trusted] + #[pure] + fn pure_fn6(&self, a: A, b: B, c: (i32, usize), d: (usize, i32)) -> C { unimplemented!() } +} + +fn test_fn6>(t: T, x: X4) { + assert!(!t.fn6::(7, (8, 9), (42, 43), (33, 44))); + assert!(!t.pure_fn6::(7, (8, 9), (42, 43), (33, 44))); + assert!(!x.fn6::(7, (8, 9), (42, 43), (33, 44))); + assert!(!x.pure_fn6::(7, (8, 9), (42, 43), (33, 44))); +} + +#[trusted] +fn main() {} diff --git a/prusti-tests/tests/verify_overflow/pass/generic/trait-lifetime.rs b/prusti-tests/tests/verify_overflow/pass/generic/trait-lifetime.rs new file mode 100644 index 00000000000..0ab7071122e --- /dev/null +++ b/prusti-tests/tests/verify_overflow/pass/generic/trait-lifetime.rs @@ -0,0 +1,12 @@ +use prusti_contracts::*; + +trait Trait<'a> { + type AT; +} + +fn test<'a, T: Trait<'a>>(_t: T) -> T::AT { + unimplemented!() +} + +#[trusted] +fn main() {} diff --git a/prusti-tests/tests/verify_partial/fail/issues/issue-729-1.rs b/prusti-tests/tests/verify_partial/fail/issues/issue-729-1.rs index acb0729ee92..97b6e1c5e5b 100644 --- a/prusti-tests/tests/verify_partial/fail/issues/issue-729-1.rs +++ b/prusti-tests/tests/verify_partial/fail/issues/issue-729-1.rs @@ -1,7 +1,7 @@ // FIXME: remove this compile flag when the new encoder is finished // compile-flags: -Puse_new_encoder=false -// error-pattern: Precondition of function snap$__$TY$__Snap$struct$m_A$ might not hold +// error-pattern: Precondition of function snap$__$TY$__Snap$struct$m_A$struct$m_A$Snap$struct$m_A might not hold // FIXME: https://github.com/viperproject/prusti-dev/issues/729 #![allow(unused_comparisons)] use prusti_contracts::*; diff --git a/prusti-viper/src/encoder/encoder.rs b/prusti-viper/src/encoder/encoder.rs index a80b1b6040e..c4727692856 100644 --- a/prusti-viper/src/encoder/encoder.rs +++ b/prusti-viper/src/encoder/encoder.rs @@ -712,7 +712,7 @@ impl<'v, 'tcx> Encoder<'v, 'tcx> { // TODO: Make sure that this encoded function does not end up in // the Viper file because that would be unsound. let identity_substs = self.env().identity_substs(proc_def_id); - if let Err(error) = self.encode_pure_function_def(proc_def_id, identity_substs) { + if let Err(error) = self.encode_pure_function_def(proc_def_id, proc_def_id, identity_substs) { self.register_encoding_error(error); debug!("Error encoding function: {:?}", proc_def_id); // Skip encoding the function as a method. diff --git a/prusti-viper/src/encoder/high/types/fields.rs b/prusti-viper/src/encoder/high/types/fields.rs index 5807986ab91..d7f7c40f385 100644 --- a/prusti-viper/src/encoder/high/types/fields.rs +++ b/prusti-viper/src/encoder/high/types/fields.rs @@ -32,7 +32,8 @@ pub(crate) fn create_value_field(ty: vir::Type) -> EncodingResult vir::FieldDecl::new("val_ref", 0usize, ty), + | vir::Type::TypeVar(_) + | vir::Type::Projection(_) => vir::FieldDecl::new("val_ref", 0usize, ty), vir::Type::Reference(vir::ty::Reference { target_type, .. }) => { vir::FieldDecl::new("val_ref", 0usize, (*target_type).clone()) @@ -49,7 +50,6 @@ pub(crate) fn create_value_field(ty: vir::Type) -> EncodingResult { return Err(EncodingError::unsupported(format!( "{} type is not supported", diff --git a/prusti-viper/src/encoder/mir/pure/pure_functions/encoder.rs b/prusti-viper/src/encoder/mir/pure/pure_functions/encoder.rs index 6ce44a6738c..716ac54199c 100644 --- a/prusti-viper/src/encoder/mir/pure/pure_functions/encoder.rs +++ b/prusti-viper/src/encoder/mir/pure/pure_functions/encoder.rs @@ -103,12 +103,12 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> { let span = encoder.get_spec_span(proc_def_id); // TODO: move this to a signatures module - use prusti_rustc_interface::middle::ty::subst::Subst; - let sig = ty::EarlyBinder(encoder.env().tcx().fn_sig(proc_def_id)) - .subst(encoder.env().tcx(), substs); - let sig = encoder - .env() - .resolve_assoc_types(sig, encoder.env().tcx().param_env(proc_def_id)); + let sig = encoder.env().tcx().fn_sig(proc_def_id); + let sig = encoder.env().tcx().subst_and_normalize_erasing_regions( + substs, + encoder.env().tcx().param_env(parent_def_id), + sig, + ); PureFunctionEncoder { encoder, @@ -123,10 +123,11 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> { } pub fn encode_function(&mut self) -> SpannedEncodingResult { - let mir = self - .encoder - .env() - .local_mir(self.proc_def_id.expect_local(), self.substs); + let mir = self.encoder.env().local_mir_with_caller( + self.proc_def_id.expect_local(), + self.parent_def_id.expect_local(), + self.substs, + ); let interpreter = PureFunctionBackwardInterpreter::new( self.encoder, &mir, @@ -263,7 +264,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> { if config::check_overflows() { debug_assert!(self.encoder.env().type_is_copy( self.sig.output(), - self.encoder.env().tcx().param_env(self.proc_def_id) + self.encoder.env().tcx().param_env(self.parent_def_id) )); let mut return_bounds: Vec<_> = self .encoder @@ -282,7 +283,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> { debug_assert!(self .encoder .env() - .type_is_copy(typ, self.encoder.env().tcx().param_env(self.proc_def_id))); + .type_is_copy(typ, self.encoder.env().tcx().param_env(self.parent_def_id))); let mut bounds = self .encoder .encode_type_bounds(&vir::Expr::local(formal_arg.clone()), typ.skip_binder()); @@ -528,7 +529,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> { let ty = self.sig.output(); // Return an error for unsupported return types - let param_env = self.encoder.env().tcx().param_env(self.proc_def_id); + let param_env = self.encoder.env().tcx().param_env(self.parent_def_id); if !self.encoder.env().type_is_copy(ty, param_env) { return Err(SpannedEncodingError::incorrect( "return type of pure function does not implement Copy", @@ -555,7 +556,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> { let var_name = format!("{:?}", local); let var_span = self.get_local_span(local); - let param_env = self.encoder.env().tcx().param_env(self.proc_def_id); + let param_env = self.encoder.env().tcx().param_env(self.parent_def_id); if !self.encoder.env().type_is_copy(local_ty, param_env) { return Err(SpannedEncodingError::incorrect( "pure function parameters must be Copy", diff --git a/prusti-viper/src/encoder/mir/pure/pure_functions/interface.rs b/prusti-viper/src/encoder/mir/pure/pure_functions/interface.rs index 193cd6b3365..8646f053093 100644 --- a/prusti-viper/src/encoder/mir/pure/pure_functions/interface.rs +++ b/prusti-viper/src/encoder/mir/pure/pure_functions/interface.rs @@ -3,25 +3,54 @@ use super::encoder::{FunctionCallInfo, FunctionCallInfoHigh, PureFunctionEncoder}; use crate::encoder::{ errors::{SpannedEncodingResult, WithSpan}, - mir::{generics::MirGenericsEncoderInterface, specifications::SpecificationsInterface}, + mir::specifications::SpecificationsInterface, snapshot::interface::SnapshotEncoderInterface, stub_function_encoder::StubFunctionEncoder, }; use log::{debug, trace}; use prusti_common::config; use prusti_interface::data::ProcedureDefId; -use prusti_rustc_interface::middle::ty::subst::SubstsRef; +use prusti_rustc_interface::middle::{ty, ty::subst::SubstsRef}; use rustc_hash::{FxHashMap, FxHashSet}; use prusti_interface::specs::typed::ProcedureSpecificationKind; use std::cell::RefCell; use vir_crate::{common::identifier::WithIdentifier, high as vir_high, polymorphic as vir_poly}; -/// Key of stored call infos, consisting of the DefId of the called function -/// and (the VIR encoding of) the type substitutions applied to it. This means -/// that each generic variant of a pure function will be encoded as a separate -/// pure function in Viper. -type Key = (ProcedureDefId, Vec); +/// Key of stored call infos, consisting of the DefId of the called function, +/// the normalised type substitutions, and the function signature after type +/// substitution and normalisation. The second and third components are stored +/// to account for different monomorphisations resulting from the function +/// being called from callers (with different parameter environments). Each +/// variant of a pure function will be encoded as a separate Viper function. +type Key<'tcx> = ( + ProcedureDefId, + SubstsRef<'tcx>, + &'tcx ty::List>, +); + +/// Compute the key for the given call. +fn compute_key<'v, 'tcx: 'v>( + encoder: &crate::encoder::encoder::Encoder<'v, 'tcx>, + proc_def_id: ProcedureDefId, + caller_def_id: ProcedureDefId, + substs: SubstsRef<'tcx>, +) -> SpannedEncodingResult> { + let tcx = encoder.env().tcx(); + let sig = if tcx.is_closure(proc_def_id) { + substs.as_closure().sig() + } else { + tcx.fn_sig(proc_def_id) + }; + let param_env = tcx.param_env(caller_def_id); + let sig = tcx.subst_and_normalize_erasing_regions(substs, param_env, sig); + let substs = tcx.subst_and_normalize_erasing_regions( + substs, + param_env, + encoder.env().identity_substs(proc_def_id), + ); + Ok((proc_def_id, substs, sig.inputs_and_output().skip_binder())) +} type FunctionConstructor<'v, 'tcx> = Box< dyn FnOnce( @@ -52,22 +81,22 @@ pub(crate) enum PureEncodingContext { #[derive(Default)] pub(crate) struct PureFunctionEncoderState<'v, 'tcx: 'v> { - bodies_high: RefCell>, - bodies_poly: RefCell>, + bodies_high: RefCell, vir_high::Expression>>, + bodies_poly: RefCell, vir_poly::Expr>>, /// Information necessary to encode a function call. FIXME: Remove this one /// and have only call_infos_high. - call_infos_poly: RefCell>, + call_infos_poly: RefCell, FunctionCallInfo>>, /// Information necessary to encode a function call. - call_infos_high: RefCell>, + call_infos_high: RefCell, FunctionCallInfoHigh>>, /// Pure functions whose encoding started (and potentially already /// finished). This is used to break recursion. - pure_functions_encoding_started: RefCell>, + pure_functions_encoding_started: RefCell>>, // A mapping from the function identifier to an information needed to encode // that function. function_descriptions: RefCell>>, /// Mapping from keys on MIR level to function identifiers on VIR level. - function_identifiers: RefCell>, + function_identifiers: RefCell, vir_poly::FunctionIdentifier>>, /// Encoded functions. The encoding of these functions was triggered by the /// definition collector requesting their definition. functions: RefCell>>, @@ -79,6 +108,7 @@ pub(crate) struct PureFunctionEncoderState<'v, 'tcx: 'v> { #[derive(Clone, Debug)] pub(crate) struct FunctionDescription<'tcx> { proc_def_id: ProcedureDefId, + parent_def_id: ProcedureDefId, substs: SubstsRef<'tcx>, } @@ -102,6 +132,7 @@ pub(crate) trait PureFunctionEncoderInterface<'v, 'tcx> { fn encode_pure_function_def( &self, proc_def_id: ProcedureDefId, + parent_def_id: ProcedureDefId, substs: SubstsRef<'tcx>, ) -> SpannedEncodingResult<()>; @@ -159,11 +190,7 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx> parent_def_id: ProcedureDefId, substs: SubstsRef<'tcx>, ) -> SpannedEncodingResult { - let mir_span = self.env().tcx().def_span(proc_def_id); - let substs_key = self - .encode_generic_arguments_high(proc_def_id, substs) - .with_span(mir_span)?; - let key = (proc_def_id, substs_key); + let key = compute_key(self, proc_def_id, parent_def_id, substs)?; if !self .pure_function_encoder_state .bodies_high @@ -186,7 +213,7 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx> .pure_function_encoder_state .bodies_high .borrow_mut() - .insert(key.clone(), body) + .insert(key, body) .is_none()); } Ok(self.pure_function_encoder_state.bodies_high.borrow()[&key].clone()) @@ -200,11 +227,7 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx> parent_def_id: ProcedureDefId, substs: SubstsRef<'tcx>, ) -> SpannedEncodingResult { - let mir_span = self.env().tcx().def_span(proc_def_id); - let substs_key = self - .encode_generic_arguments_high(proc_def_id, substs) - .with_span(mir_span)?; - let key = (proc_def_id, substs_key); + let key = compute_key(self, proc_def_id, parent_def_id, substs)?; if !self .pure_function_encoder_state @@ -227,7 +250,7 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx> self.pure_function_encoder_state .bodies_poly .borrow_mut() - .insert(key.clone(), body); + .insert(key, body); } Ok(self.pure_function_encoder_state.bodies_poly.borrow()[&key].clone()) } @@ -235,6 +258,7 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx> fn encode_pure_function_def( &self, proc_def_id: ProcedureDefId, + parent_def_id: ProcedureDefId, substs: SubstsRef<'tcx>, ) -> SpannedEncodingResult<()> { trace!("[enter] encode_pure_function_def({:?})", proc_def_id); @@ -245,10 +269,7 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx> ); let mir_span = self.env().tcx().def_span(proc_def_id); - let substs_key = self - .encode_generic_arguments_high(proc_def_id, substs) - .with_span(mir_span)?; - let key = (proc_def_id, substs_key); + let key = compute_key(self, proc_def_id, parent_def_id, substs)?; if !self .pure_function_encoder_state @@ -266,13 +287,13 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx> self.pure_function_encoder_state .pure_functions_encoding_started .borrow_mut() - .insert(key.clone()); + .insert(key); let mut pure_function_encoder = PureFunctionEncoder::new( self, proc_def_id, PureEncodingContext::Code, - proc_def_id, + parent_def_id, substs, ); @@ -329,6 +350,7 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx> Ok(self.insert_function(function)) })( ); + match maybe_identifier { Ok(identifier) => { self.pure_function_encoder_state @@ -372,6 +394,7 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx> drop(function_descriptions); self.encode_pure_function_def( function_description.proc_def_id, + function_description.parent_def_id, function_description.substs, )?; } else { @@ -392,17 +415,13 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx> proc_def_id ); - let mir_span = self.env().tcx().def_span(proc_def_id); - let substs_key = self - .encode_generic_arguments_high(proc_def_id, substs) - .with_span(mir_span)?; - let key = (proc_def_id, substs_key); + let key = compute_key(self, proc_def_id, parent_def_id, substs)?; let mut call_infos = self .pure_function_encoder_state .call_infos_poly .borrow_mut(); - if !call_infos.contains_key(&key) { + if let std::collections::hash_map::Entry::Vacant(e) = call_infos.entry(key) { // Compute information necessary to encode the function call and // memoize it. let pure_function_encoder = PureFunctionEncoder::new( @@ -430,10 +449,11 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx> .entry(function_identifier) .or_insert(FunctionDescription { proc_def_id, + parent_def_id, substs, }); - call_infos.insert(key.clone(), function_call_info); + e.insert(function_call_info); } let function_call_info = &call_infos[&key]; @@ -455,17 +475,13 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx> proc_def_id ); - let mir_span = self.env().tcx().def_span(proc_def_id); - let substs_key = self - .encode_generic_arguments_high(proc_def_id, substs) - .with_span(mir_span)?; - let key = (proc_def_id, substs_key); + let key = compute_key(self, proc_def_id, parent_def_id, substs)?; let mut call_infos = self .pure_function_encoder_state .call_infos_high .borrow_mut(); - if !call_infos.contains_key(&key) { + if let std::collections::hash_map::Entry::Vacant(e) = call_infos.entry(key) { // Compute information necessary to encode the function call and // memoize it. let function_call_info = super::new_encoder::encode_function_call_info( @@ -493,7 +509,7 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx> let _ = self.encode_pure_function_use(proc_def_id, parent_def_id, substs)?; } - call_infos.insert(key.clone(), function_call_info); + e.insert(function_call_info); } let function_call_info = &call_infos[&key]; diff --git a/prusti-viper/src/encoder/mir/pure/pure_functions/new_encoder.rs b/prusti-viper/src/encoder/mir/pure/pure_functions/new_encoder.rs index 5f5992f6c96..79c237d4a68 100644 --- a/prusti-viper/src/encoder/mir/pure/pure_functions/new_encoder.rs +++ b/prusti-viper/src/encoder/mir/pure/pure_functions/new_encoder.rs @@ -309,7 +309,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureEncoder<'p, 'v, 'tcx> { let ty = self.sig.output(); let span = self.get_return_span(); - let param_env = self.encoder.env().tcx().param_env(self.proc_def_id); + let param_env = self.encoder.env().tcx().param_env(self.parent_def_id); if !self.encoder.env().type_is_copy(ty, param_env) { return Err(SpannedEncodingError::incorrect( "return type of pure function does not implement Copy", @@ -398,7 +398,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureEncoder<'p, 'v, 'tcx> { /// Encodes a VIR local with the original MIR type. fn encode_mir_local(&self, local: mir::Local) -> SpannedEncodingResult { let ty = self.get_local_ty(local); - let param_env = self.encoder.env().tcx().param_env(self.proc_def_id); + let param_env = self.encoder.env().tcx().param_env(self.parent_def_id); if !self.encoder.env().type_is_copy(ty, param_env) { return Err(SpannedEncodingError::incorrect( "pure function parameters must be Copy", diff --git a/vir/defs/polymorphic/ast/function.rs b/vir/defs/polymorphic/ast/function.rs index b48b1ba17ee..6b29f7bf368 100644 --- a/vir/defs/polymorphic/ast/function.rs +++ b/vir/defs/polymorphic/ast/function.rs @@ -100,14 +100,10 @@ impl Function { pub fn compute_identifier( name: &str, type_arguments: &[Type], - _formal_args: &[LocalVar], - _return_type: &Type, + formal_args: &[LocalVar], + return_type: &Type, ) -> String { let mut identifier = name.to_string(); - // Include the signature of the function in the function name - if !type_arguments.is_empty() { - identifier.push_str("__$TY$__"); - } fn type_name(typ: &Type) -> String { match typ { Type::Int => "$int$".to_string(), @@ -126,10 +122,18 @@ pub fn compute_identifier( ), } } + identifier.push_str("__$TY$__"); + // Include the type parameters of the function in the function name for arg in type_arguments { identifier.push_str(&type_name(arg)); identifier.push('$'); } + // Include the signature of the function in the function name + for arg in formal_args { + identifier.push_str(&type_name(&arg.typ)); + identifier.push('$'); + } + identifier.push_str(&type_name(return_type)); identifier }