Skip to content

Commit

Permalink
Add new definition mechanism for (mutrec) tailrec functions
Browse files Browse the repository at this point in the history
The new automation has two parts: the first part proves
that a tail-recursive function exists; the second part
uses new_specification to define such a function.

The given equations must have only curried variable
arguments left of the equality, e.g.

    foo m n = ...

is allowed, but the following is not:

    foo (m, n) = ...

This commit also renames:

    examples/machine-code/hoare-triple/tailrecLib.{sml,sig}
    ->
    examples/machine-code/hoare-triple/mc_tailrecLib.{sml,sig}

Here's an example use of the new definition mechanism:

    val _ = List.map Parse.hide ["foo","bar"];

    val foo_def = tailrec_define "foo_def"
      “(foo m n = if m = (n:num) then bar m (SOME 8) else bar 4 NONE) ∧
       (bar k l = case l of
                  | NONE => k - 6
                  | SOME i =>
                      let (q,r) = ARB i
                      and (t,w,a) = ARB k l in
                        foo (q + r) (t + w + a))”;
  • Loading branch information
myreen authored and mn200 committed Nov 27, 2023
1 parent 77e242f commit a1129a6
Show file tree
Hide file tree
Showing 19 changed files with 252 additions and 16 deletions.
2 changes: 1 addition & 1 deletion examples/machine-code/compiler/compilerLib.sml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ open reg_allocLib;
open prog_armLib prog_ppcLib prog_x86Lib prog_x64Lib;
open wordsTheory wordsLib addressTheory;
open helperLib;
open tailrecLib;
open mc_tailrecLib;
structure Parse = struct
open Parse
val (Type,Term) =
Expand Down
2 changes: 1 addition & 1 deletion examples/machine-code/decompiler/decompilerLib.sml
Original file line number Diff line number Diff line change
Expand Up @@ -1640,7 +1640,7 @@ fun extract_function name th entry exit function_in_out = let
val func_name = name
val tm_option = NONE
val (main_thm,main_def,pre_thm,pre_def) =
tailrecLib.tailrec_define_from_step func_name step_fun tm_option
mc_tailrecLib.tailrec_define_from_step func_name step_fun tm_option
val finalise =
CONV_RULE (REMOVE_TAGS_CONV THENC DEPTH_CONV (LET_EXPAND_POS_CONV))
val main_thm = finalise main_thm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
open HolKernel boolLib bossLib Parse;
open wordsTheory;
open decompilerLib;
open tailrecLib listTheory pred_setTheory arithmeticTheory;
open mc_tailrecLib listTheory pred_setTheory arithmeticTheory;

val decompile_arm = decompile prog_armLib.arm_tools;
val decompile_ppc = decompile prog_ppcLib.ppc_tools;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ val _ = ParseExtras.temp_loose_equality()
open wordsTheory arithmeticTheory wordsLib listTheory pred_setTheory pairTheory;
open combinTheory finite_mapTheory;

open addressTheory tailrecLib tailrecTheory;
open addressTheory mc_tailrecLib tailrecTheory;
open cheney_gcTheory cheney_allocTheory arm_cheney_gcTheory;


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ open decompilerLib prog_armLib;

open wordsTheory arithmeticTheory wordsLib listTheory pred_setTheory pairTheory;
open combinTheory finite_mapTheory addressTheory;
open tailrecLib tailrecTheory;
open mc_tailrecLib tailrecTheory;
open cheney_gcTheory; (* an abstract implementation is imported *)

val decompile_arm = decompile arm_tools;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ val gc_step_def = Define `
let i = i + n + 1 in
(i,j,m)`;

val gc_loop_def = tailrecLib.tailrec_define ``
val gc_loop_def = mc_tailrecLib.tailrec_define ``
gc_loop (i,j,m) = if i = j then (i,m) else
let (i,j,m) = gc_step (i,j,m) in
gc_loop (i,j,m)``;
Expand Down
2 changes: 1 addition & 1 deletion examples/machine-code/garbage-collectors/lisp_gcScript.sml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ open decompilerLib compilerLib prog_armLib;
open wordsTheory arithmeticTheory wordsLib listTheory pred_setTheory pairTheory;
open combinTheory finite_mapTheory addressTheory;

open tailrecLib tailrecTheory;
open mc_tailrecLib tailrecTheory;
open cheney_gcTheory cheney_allocTheory; (* an abstract implementation is imported *)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
signature tailrecLib =
signature mc_tailrecLib =
sig

include Abbrev
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
structure tailrecLib :> tailrecLib =
structure mc_tailrecLib :> mc_tailrecLib =
struct

open HolKernel boolLib bossLib Parse;
Expand Down
2 changes: 1 addition & 1 deletion examples/machine-code/lisp/divideScript.sml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

open HolKernel boolLib bossLib Parse;
open tailrecTheory tailrecLib compilerLib codegen_x86Lib;
open tailrecTheory mc_tailrecLib compilerLib codegen_x86Lib;
open wordsTheory addressTheory wordsLib arithmeticTheory;

open decompilerLib set_sepTheory prog_x86Lib;
Expand Down
2 changes: 1 addition & 1 deletion examples/machine-code/lisp/lisp_equalScript.sml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ open wordsTheory arithmeticTheory wordsLib listTheory pred_setTheory pairTheory;
open combinTheory finite_mapTheory addressTheory;

open decompilerLib compilerLib;
open tailrecLib tailrecTheory cheney_gcTheory cheney_allocTheory;
open mc_tailrecLib tailrecTheory cheney_gcTheory cheney_allocTheory;
open lisp_gcTheory lisp_typeTheory lisp_invTheory;


Expand Down
2 changes: 1 addition & 1 deletion examples/machine-code/lisp/lisp_printScript.sml
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ val (thms,arm_set_return_def,arm_set_return_pre_def) = compile_all ``
let r4 = 5w:word32 in
(r4,r5,r8,dh,h)``;

(* val (arm_print_loop_aux_def,arm_print_loop_aux_pre_def) = tailrecLib.tailrec_define `` *)
(* val (arm_print_loop_aux_def,arm_print_loop_aux_pre_def) = mc_tailrecLib.tailrec_define *)
val (thms,arm_print_loop_aux_def,arm_print_loop_aux_pre_def) = compile_all ``
arm_print_loop_aux (r3:word32,r4:word32,r7:word32,r8:word32,dh:word32 set,h:word32->word32,df:word32 set,f:word32->word8) =
let (r3,r4,r5) = arm_print_exit(r3,r4) in
Expand Down
2 changes: 1 addition & 1 deletion examples/machine-code/multiword/mc_multiwordScript.sml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ open HolKernel Parse boolLib bossLib;
open multiwordTheory helperLib;
open wordsTheory wordsLib addressTheory arithmeticTheory listTheory pairSyntax;
open addressTheory pairTheory set_sepTheory rich_listTheory integerTheory;
local open tailrecLib blastLib intLib in end
local open mc_tailrecLib blastLib intLib in end

val _ = new_theory "mc_multiword";
val _ = ParseExtras.temp_loose_equality()
Expand Down
2 changes: 1 addition & 1 deletion examples/machine-code/x64_compiler/x64_compilerLib.sml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ open prog_x64Lib;
open prog_x64_extraTheory;
open wordsTheory wordsLib addressTheory;
open helperLib;
open tailrecLib;
open mc_tailrecLib;


fun AUTO_ALPHA_CONV () = let
Expand Down
4 changes: 4 additions & 0 deletions src/num/theories/cv_compute/cvSyntax.sig
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ sig
val cv_lt_tm : term;
val cv_if_tm : term;
val cv_eq_tm : term;
val c2b_tm : term;

val mk_cv_pair : term * term -> term;
val mk_cv_num : term -> term;
Expand All @@ -32,6 +33,7 @@ sig
val mk_cv_lt : term * term -> term;
val mk_cv_if : term * term * term -> term;
val mk_cv_eq : term * term -> term;
val mk_c2b : term -> term;

val dest_cv_pair : term -> term * term;
val dest_cv_num : term -> term;
Expand All @@ -46,6 +48,7 @@ sig
val dest_cv_lt : term -> term * term;
val dest_cv_if : term -> term * term * term;
val dest_cv_eq : term -> term * term;
val dest_c2b : term -> term;

val is_cv_pair : term -> bool;
val is_cv_num : term -> bool;
Expand All @@ -60,5 +63,6 @@ sig
val is_cv_lt : term -> bool;
val is_cv_if : term -> bool;
val is_cv_eq : term -> bool;
val is_c2b : term -> bool;

end (* signature *)
4 changes: 4 additions & 0 deletions src/num/theories/cv_compute/cvSyntax.sml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ struct
val cv_lt_tm = prim_mk_const {Name="cv_lt", Thy="cv"};
val cv_if_tm = prim_mk_const {Name="cv_if", Thy="cv"};
val cv_eq_tm = prim_mk_const {Name="cv_eq", Thy="cv"};
val c2b_tm = prim_mk_const {Name="c2b", Thy="cv"};

(* -------------------------------------------------------------------------
* Constructors
Expand All @@ -48,6 +49,7 @@ struct
val mk_cv_lt = mk_binop cv_lt_tm;
val mk_cv_if = mk_triop cv_if_tm;
val mk_cv_eq = mk_binop cv_eq_tm;
val mk_c2b = mk_monop c2b_tm;

(* -------------------------------------------------------------------------
* Destructors
Expand Down Expand Up @@ -91,6 +93,7 @@ struct
val dest_cv_lt = dest_binop cv_lt_tm;
val dest_cv_if = dest_triop cv_if_tm;
val dest_cv_eq = dest_binop cv_eq_tm;
val dest_c2b = dest_monop c2b_tm;

(* -------------------------------------------------------------------------
* Recognizers
Expand All @@ -109,5 +112,6 @@ struct
val is_cv_lt = can (dest_binop cv_lt_tm);
val is_cv_if = can (dest_triop cv_if_tm);
val is_cv_eq = can (dest_binop cv_eq_tm);
val is_c2b = can (dest_monop c2b_tm);

end (* struct *)
9 changes: 9 additions & 0 deletions src/num/theories/cv_compute/tailrecLib.sig
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
signature tailrecLib =
sig

include Abbrev

val tailrec_define : string -> term -> thm
val prove_tailrec_exists : term -> thm

end
196 changes: 196 additions & 0 deletions src/num/theories/cv_compute/tailrecLib.sml
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
structure tailrecLib :> tailrecLib =
struct

open HolKernel Parse boolLib bossLib;

(*----------------------------------------------------------------------*
Miscellaneous helper functions
*----------------------------------------------------------------------*)

fun list_dest_conj tm =
if is_conj tm then let
val (x,y) = dest_conj tm
in list_dest_conj x @ list_dest_conj y end
else [tm];

fun list_dest_exists tm = let
val (v,y) = dest_exists tm
val (vs,t) = list_dest_exists y
in (v::vs,t) end
handle HOL_ERR _ => ([],tm);

fun list_mk_pair_case pat r =
if not (pairSyntax.is_pair pat) then (pat,r) else let
val v = genvar (type_of pat)
val (x1,rest_pat) = pairSyntax.dest_pair pat
val (y1,r1) = list_mk_pair_case rest_pat r
val new_pat = pairSyntax.mk_pair(x1,y1)
in (v,TypeBase.mk_case(v,[(new_pat,r1)])) end

fun auto_prove goal_tm (tac:tactic) = snd (tac ([],goal_tm)) [];

(*----------------------------------------------------------------------*
Function for proving that non-mutually recursive tail-recursive
functions exist. The input function can only take one argument.
*----------------------------------------------------------------------*)

val TAILREC_def = whileTheory.TAILREC
|> CONV_RULE (DEPTH_CONV ETA_CONV)
|> REWRITE_RULE [GSYM combinTheory.I_EQ_IDABS];

fun prove_simple_tailrec_exists tm = let
val (l,r) = dest_eq tm
val (f_tm,arg_tm) = dest_comb l
val arg_tms = if is_var arg_tm then [arg_tm] else free_vars arg_tm
val goal_tm = mk_exists(f_tm,list_mk_forall(arg_tms,tm))
val input_ty = type_of arg_tm
val output_ty = type_of r
fun mk_inl x = sumSyntax.mk_inl(x,output_ty)
fun mk_inr x = sumSyntax.mk_inr(x,input_ty)
(* building the witness *)
fun build_sum tm =
if is_comb tm andalso aconv (rator tm) f_tm then
mk_inl (rand tm)
else if List.all (not o aconv f_tm) (free_vars tm) then
mk_inr tm
else if is_cond tm then let
val (b,x,y) = dest_cond tm
in mk_cond(b,build_sum x,build_sum y) end
else if cvSyntax.is_cv_if tm then let
val (b,x,y) = cvSyntax.dest_cv_if tm
in mk_cond(cvSyntax.mk_c2b b,build_sum x,build_sum y) end
else if can pairSyntax.dest_anylet tm then let
val (xs,x) = pairSyntax.dest_anylet tm
in pairSyntax.mk_anylet(xs,build_sum x) end
else if TypeBase.is_case tm then let
val (a,b,xs) = TypeBase.dest_case tm
val ys = map (fn (x,tm) => (x,build_sum tm)) xs
in TypeBase.mk_case(b,ys) end
else failwith ("Unsupported: " ^ term_to_string tm)
val sum_tm = build_sum r
val abs_sum_tm = pairSyntax.mk_pabs(arg_tm,sum_tm)
val witness = ISPEC abs_sum_tm whileTheory.TAILREC |> SPEC_ALL
|> concl |> dest_eq |> fst |> rator
fun sum_case_exp tm = tm |> rator |> rator |> rand
val sum_case_exp_conv = RATOR_CONV o RATOR_CONV o RAND_CONV
(* tactic to solve goal *)
fun tailrec_tac (assum_tms,goal_tm) =
if (goal_tm |> dest_eq |> fst |> sum_case_exp |> sumSyntax.is_inr) then
REWRITE_TAC [sumTheory.sum_case_def,combinTheory.I_THM]
(assum_tms,goal_tm)
else if (goal_tm |> dest_eq |> fst |> sum_case_exp |> sumSyntax.is_inl) then
REWRITE_TAC [sumTheory.sum_case_def,combinTheory.I_THM]
(assum_tms,goal_tm)
else if cvSyntax.is_cv_if (rand goal_tm) then
(CONV_TAC (RAND_CONV (REWR_CONV cvTheory.cv_if)) THEN tailrec_tac)
(assum_tms,goal_tm)
else if can pairSyntax.dest_anylet (goal_tm |> rand) then let
val xs = pairSyntax.dest_anylet (goal_tm |> rand) |> fst
val vs = xs |> map (fn (x,y) => (y,genvar (type_of y)))
val specs = foldl (fn (x,t) => SPEC_TAC x THEN t) ALL_TAC vs
val gens = foldr (fn ((_,x),t) =>
if can pairSyntax.dest_prod (type_of x) then PairCases THEN t
else gen_tac THEN t) ALL_TAC vs
fun expand_lets 0 = ALL_CONV
| expand_lets 1 = (REWR_CONV LET_THM THENC PairRules.PBETA_CONV)
| expand_lets n = ((RATOR_CONV o RAND_CONV) (expand_lets (n-1))
THENC expand_lets 1)
val exp_conv = expand_lets (length vs)
val exp_both_conv = RAND_CONV exp_conv THENC
(RATOR_CONV o RAND_CONV o sum_case_exp_conv) exp_conv
in (specs THEN gens THEN CONV_TAC exp_both_conv)
(assum_tms,goal_tm) end
else if TypeBase.is_case (rand goal_tm) then let
val (a,b,xs) = TypeBase.dest_case (rand goal_tm)
val ty = type_of b
val new_v = genvar ty
val case_def = TypeBase.case_def_of ty
in (SPEC_TAC (b,new_v) THEN Cases
THEN PURE_ONCE_REWRITE_TAC [case_def]
THEN CONV_TAC (DEPTH_CONV BETA_CONV))
(assum_tms,goal_tm) end
else NO_TAC (assum_tms,goal_tm);
(* prove main theorem *)
val tac = exists_tac witness
THEN rpt gen_tac
THEN CONV_TAC ((RATOR_CONV o RAND_CONV) (REWR_CONV TAILREC_def))
THEN SPEC_TAC (witness, genvar (type_of witness))
THEN gen_tac
THEN PURE_REWRITE_TAC [boolTheory.literal_case_DEF]
THEN CONV_TAC (DEPTH_CONV PairRules.PBETA_CONV)
THEN rpt (tailrec_tac \\ rpt conj_tac)
val lemma = auto_prove goal_tm tac;
in lemma end;

(*----------------------------------------------------------------------*
Function for proving that mutually recursive tail-recursive
functions exist. One equation per function. The arguments on the
LHS of each equation must be variables.
*----------------------------------------------------------------------*)

fun prove_tailrec_exists def_tm = let
val defs = list_dest_conj def_tm
(* build the goal to prove *)
fun extract_def def_tm = let
val (l,r) = dest_eq def_tm
val (f_tm,args) = strip_comb l
val _ = List.all is_var args orelse failwith "bad input"
in (f_tm,list_mk_forall(args,def_tm)) end
val xs = map extract_def defs
val goal_tm = list_mk_exists(map fst xs, list_mk_conj(map snd xs))
(* build one function *)
val fs_args = map (strip_comb o fst o dest_eq) defs
val tuples = map (fn (f,args) => pairSyntax.list_mk_pair args) fs_args
val rhs_list = map (snd o dest_eq) defs
fun build_sum_ty [] = fail()
| build_sum_ty [(tm,r)] =
if is_var tm then (type_of tm,[tm],tm,r) else let
val (v,rhs_tm) = list_mk_pair_case tm r
in (type_of tm,[tm],v,rhs_tm) end
| build_sum_ty ((tm,r)::tms) = let
val (ty1,calls1,v1,r1) = build_sum_ty [(tm,r)]
val (ty2,calls2,v2,r2) = build_sum_ty tms
val ty = sumSyntax.mk_sum(ty1,ty2)
val v = genvar ty
val pat1 = sumSyntax.mk_inl(v1,ty2)
val pat2 = sumSyntax.mk_inr(v2,ty1)
in (ty,
map (fn x => sumSyntax.mk_inl(x,ty2)) calls1 @
map (fn x => sumSyntax.mk_inr(x,ty1)) calls2,
v,
TypeBase.mk_case(v,[(pat1,r1),(pat2,r2)])) end
val (input_ty, call_tms, arg_tm, rhs_tm) = build_sum_ty (zip tuples rhs_list)
val output_ty = defs |> hd |> dest_eq |> snd |> type_of
val combined_var_tm = mk_var("combined", input_ty --> output_ty)
val lhs_tm = mk_comb(combined_var_tm, arg_tm)
val calls =
map2 (fn (f,args) => fn call_tm =>
(f, list_mk_abs(args,mk_comb(combined_var_tm,call_tm))))
fs_args call_tms
val fixed_rhs_tm =
rhs_tm |> subst (map (fn (x,y) => x |-> y) calls)
|> QCONV (DEPTH_CONV BETA_CONV) |> concl |> rand
val combined_eq = mk_eq(lhs_tm, fixed_rhs_tm)
val combined_th = prove_simple_tailrec_exists combined_eq
(* prove defining theorem *)
val exists = foldr (fn (x,t) => EXISTS_TAC x THEN t) ALL_TAC (map snd calls)
val tac =
strip_assume_tac combined_th THEN exists
THEN CONV_TAC (DEPTH_CONV BETA_CONV)
THEN rpt conj_tac THEN rpt gen_tac
THEN pop_assum (fn th => CONV_TAC ((RATOR_CONV o RAND_CONV) (REWR_CONV th)))
THEN SIMP_TAC bool_ss [sumTheory.sum_case_def,pairTheory.pair_case_def]
val lemma = auto_prove goal_tm tac
in lemma end

(*----------------------------------------------------------------------*
Defines tail-recursive functions based on the existance proofs
that the above function can prove. Same restrictions apply.
*----------------------------------------------------------------------*)

fun tailrec_define name def_tm = let
val lemma = prove_tailrec_exists def_tm
val names = lemma |> concl |> list_dest_exists |> fst |> map (fst o dest_var)
in new_specification(name,names,lemma) end

end
Loading

0 comments on commit a1129a6

Please sign in to comment.