diff --git a/plugins/qcheck-stm/src/stm_of_ir.ml b/plugins/qcheck-stm/src/stm_of_ir.ml index 7fc187bb..32541725 100644 --- a/plugins/qcheck-stm/src/stm_of_ir.ml +++ b/plugins/qcheck-stm/src/stm_of_ir.ml @@ -371,7 +371,9 @@ let next_state_case state config state_ident nb_models value = let translate_checks = translate_checks config state value state_ident in - let* checks = map translate_checks value.postcond.checks in + let* checks = + map (fun t -> translate_checks t.term) value.postcond.checks + in match checks with | [] -> ok (idx, new_state) | _ -> @@ -434,12 +436,21 @@ let postcond_case config state invariants idx state_ident new_state_ident value let open Reserr in let translate_postcond t = subst_term state ~gos_t:value.sut_var ~old_t:(Some state_ident) ~new_lz:true - ~new_t:(Some new_state_ident) t + ~new_t:(Some new_state_ident) t.term >>= ocaml_of_term config and translate_invariants id t = subst_term state ~gos_t:id ~old_t:None ~new_t:(Some new_state_ident) ~new_lz:true t >>= ocaml_of_term config + and wrap_check t e = + let term = estring t.text + and cmd = Fmt.str "%a" Ident.pp value.id |> estring + and l = + Fmt.str "%a" Location.print t.Ir.term.Gospel.Tterm.t_loc |> estring + in + pexp_ifthenelse e + (pexp_construct (lident "[]") None) + (Some [%expr [ ([%e cmd], [%e term], [%e l]) ]]) in let idx = List.sort Int.compare idx in let lhs0 = mk_cmd_pattern value in @@ -490,13 +501,14 @@ let postcond_case config state invariants idx state_ident new_state_ident value in aux idx value.postcond.normal in - let* postcond = map translate_postcond normal + let* postcond = map (fun t -> wrap_check t <$> translate_postcond t) normal and* invariants = Option.fold ~none:(ok []) - ~some:(fun (id, xs) -> map (translate_invariants id) xs) + ~some:(fun (id, xs) -> + map (fun t -> wrap_check t <$> translate_invariants id t.term) xs) invariants in - list_and (postcond @ invariants) |> ok + list_concat (postcond @ invariants) |> ok in let res, pat_ret = match value.ret with @@ -515,7 +527,7 @@ let postcond_case config state invariants idx state_ident new_state_ident value case ~lhs:(ppat_construct (lident "Ok") (Some pat_ret)) ~guard:None ~rhs in let* cases_error = - Fun.flip ( @ ) [ case ~lhs:ppat_any ~guard:None ~rhs:(ebool false) ] + Fun.flip ( @ ) [ case ~lhs:ppat_any ~guard:None ~rhs:(elist []) ] <$> map (fun (x, p, t) -> let lhs = @@ -524,7 +536,7 @@ let postcond_case config state invariants idx state_ident new_state_ident value (Option.map Ortac_core.Ocaml_of_gospel.pattern p) in let lhs = ppat_construct (lident "Error") (Some lhs) in - let* rhs = translate_postcond t in + let* rhs = wrap_check t <$> translate_postcond t in case ~lhs ~guard:None ~rhs |> ok) value.postcond.exceptional in @@ -533,22 +545,31 @@ let postcond_case config state invariants idx state_ident new_state_ident value in let* rhs = let translate_checks = translate_checks config state value state_ident in - let* checks = map translate_checks value.postcond.checks in + let* checks = + map + (fun t -> wrap_check t <$> translate_checks t.term) + value.postcond.checks + in match checks with | [] -> ok rhs | _ -> let inv_arg = ppat_construct (lident "Invalid_argument") (Some ppat_any) in - pexp_ifthenelse (list_and checks) rhs - (Some - (pexp_match res - [ - case - ~lhs:(ppat_construct (lident "Error") (Some inv_arg)) - ~guard:None ~rhs:(ebool true); - case ~lhs:ppat_any ~guard:None ~rhs:(ebool false); - ])) + let validate_inv_arg = + pexp_match res + [ + case + ~lhs:(ppat_construct (lident "Error") (Some inv_arg)) + ~guard:None ~rhs:(elist []); + case ~lhs:ppat_any ~guard:None ~rhs:(list_concat checks); + ] + in + pexp_match (list_concat checks) + [ + case ~lhs:(ppat_construct (lident "[]") None) ~guard:None ~rhs; + case ~lhs:ppat_any ~guard:None ~rhs:validate_inv_arg; + ] |> ok in ok (case ~lhs ~guard:None ~rhs) @@ -576,7 +597,7 @@ let postcond config idx ir = let new_state_ident = Ident.create ~loc:Location.none new_state_name in let open Reserr in let* cases = - (Fun.flip ( @ )) [ case ~lhs:ppat_any ~guard:None ~rhs:(ebool true) ] + (Fun.flip ( @ )) [ case ~lhs:ppat_any ~guard:None ~rhs:(elist []) ] <$> map (fun v -> postcond_case config ir.state ir.invariants (List.assoc v.id idx) @@ -584,10 +605,14 @@ let postcond config idx ir = ir.values in let body = - pexp_match (pexp_tuple [ evar cmd_name; evar res_name ]) cases - |> new_state_let - in - let pat = pvar "postcond" in + pexp_open + Ast_helper.(Opn.mk (Mod.ident (lident "Spec"))) + (pexp_open + Ast_helper.(Opn.mk (Mod.ident (lident "STM"))) + (pexp_match (pexp_tuple [ evar cmd_name; evar res_name ]) cases + |> new_state_let)) + in + let pat = pvar "ortac_postcond" in let expr = efun [ @@ -599,6 +624,14 @@ let postcond config idx ir = in pstr_value Nonrecursive [ value_binding ~pat ~expr ] |> ok +let dummy_postcond = + let expr = + efun + [ (Nolabel, ppat_any); (Nolabel, ppat_any); (Nolabel, ppat_any) ] + (ebool true) + and pat = pvar "postcond" in + pstr_value Nonrecursive [ value_binding ~pat ~expr ] + let cmd_constructor value = let name = String.capitalize_ascii value.id.Ident.id_str |> noloc in let args = @@ -757,7 +790,8 @@ let wrapped_init_state config ir = let* invariants = list_and <$> Option.fold ~none:(ok []) - ~some:(fun (id, xs) -> map (translate_invariants id) xs) + ~some:(fun (id, xs) -> + map (fun t -> translate_invariants id t.term) xs) ir.invariants in let msg = @@ -861,8 +895,7 @@ let stm include_ config ir = let open_mod m = pstr_open Ast_helper.(Opn.mk (Mod.ident (lident m))) in let spec_expr = pmod_structure - ([ open_mod "STM"; warn ] - @ incl + ((open_mod "STM" :: incl) @ [ sut; cmd; @@ -874,7 +907,7 @@ let stm include_ config ir = arb_cmd; next_state; precond; - postcond; + dummy_postcond; run; ]) in @@ -905,6 +938,6 @@ let stm include_ config ir = ])] in ok - ([ open_mod module_name ] - @ ghost_functions - @ [ stm_spec; tests; wrapped_init_state; agree_prop; call_tests ]) + ((warn :: open_mod module_name :: ghost_functions) + @ [ stm_spec; tests; wrapped_init_state; postcond; agree_prop; call_tests ] + )