diff --git a/examples/algorithms/unification/triangular/first-order/compilation/tcallUnifScript.sml b/examples/algorithms/unification/triangular/first-order/compilation/tcallUnifScript.sml index 507f3b0819..804623c8bb 100644 --- a/examples/algorithms/unification/triangular/first-order/compilation/tcallUnifScript.sml +++ b/examples/algorithms/unification/triangular/first-order/compilation/tcallUnifScript.sml @@ -454,49 +454,12 @@ Theorem kunifywl_thm = REWRITE_RULE [GSYM kunifywl_def] (CONJ unifywl0_NIL $ cj 3 unifywl0) (* now to do guard-elimination *) - -fun findin f t = - if aconv f t then SOME [] - else - case total dest_comb t of - NONE => NONE - | SOME (t1,t2) => - case findin f t1 of - NONE => NONE - | SOME pfx => SOME (pfx @ [t2]) - - -fun tcallify fn_t inty t = - if TypeBase.is_case t then - let val (f, ts) = strip_comb t - val {Thy,Name,Ty} = dest_thy_const f - val f0 = prim_mk_const{Name=Name,Thy=Thy} - val basety = type_of f0 - val (argtys, rngty) = strip_fun basety - val rng_th = match_type rngty (sumSyntax.mk_sum(inty,type_of t)) - val argty_th = match_type (hd argtys) (type_of (hd ts)) - val ft = Term.inst (rng_th @ argty_th) f0 - val ft1 = mk_comb(ft, hd ts) - val ts' = map (tcallify fn_t inty) (tl ts) - in - list_mk_comb(ft1, ts') - end - else - case dest_term t of - CONST _ => sumSyntax.mk_inr(t,inty) - | VAR _ => sumSyntax.mk_inr(t,inty) - | LAMB(vt,bt) => mk_abs(vt, tcallify fn_t inty bt) - | COMB _ => - case findin fn_t t of - NONE => sumSyntax.mk_inr(t,inty) - | SOME args => sumSyntax.mk_inl(pairSyntax.list_mk_pair args, type_of t) - fun tcallify_th th = let val (l,r) = dest_eq (concl th) val (lf, args) = strip_comb l val atup = pairSyntax.list_mk_pair args val inty = type_of atup - val body_t = tcallify lf inty r + val body_t = tailrecLib.mk_sum_term lf inty r in pairSyntax.mk_pabs(atup, body_t) end diff --git a/src/num/theories/cv_compute/tailrecLib.sig b/src/num/theories/cv_compute/tailrecLib.sig index 971d3e0b03..dd0a9a75bd 100644 --- a/src/num/theories/cv_compute/tailrecLib.sig +++ b/src/num/theories/cv_compute/tailrecLib.sig @@ -3,7 +3,17 @@ sig include Abbrev + val mk_sum_term : term -> hol_type -> term -> term + val tailrec_define : string -> term -> thm val prove_tailrec_exists : term -> thm end + +(* [mk_sum_term fnt inty t] generates an abstraction term c that can be an + argument to TAILREC (or TAILCALL) such that ("roughly") + + TAILCALL c fnt x = fnt x + + The argument inty is type of the argument to fnt (x above) +*) diff --git a/src/num/theories/cv_compute/tailrecLib.sml b/src/num/theories/cv_compute/tailrecLib.sml index a3dfdce2c2..60338e6a0c 100644 --- a/src/num/theories/cv_compute/tailrecLib.sml +++ b/src/num/theories/cv_compute/tailrecLib.sml @@ -3,6 +3,9 @@ struct open HolKernel Parse boolLib simpLib boolSimps +fun mk_HOL_ERR f msg = HOL_ERR {origin_structure = "tailrecLib", + origin_function = f, message = msg} + val Cases = BasicProvers.Cases val PairCases = pairLib.PairCases @@ -29,6 +32,36 @@ val TAILREC_def = whileTheory.TAILREC |> CONV_RULE (DEPTH_CONV ETA_CONV) |> REWRITE_RULE [GSYM combinTheory.I_EQ_IDABS]; +fun mk_sum_term fn_t inty tm = + let + fun build_sum t = + if TypeBase.is_case t then + let val (a,b,xs) = TypeBase.dest_case t + val ys = map (apsnd build_sum) xs + in + TypeBase.mk_case (b,ys) + end + else if can pairSyntax.dest_anylet t then + let val (xs,x) = pairSyntax.dest_anylet t + in pairSyntax.mk_anylet(xs,build_sum x) end + else if cvSyntax.is_cv_if tm then + let val (b,x,y) = cvSyntax.dest_cv_if tm + in mk_cond(cvSyntax.mk_c2b b,build_sum x,build_sum y) end + else + let val (f, xs) = strip_comb t + in + if aconv f fn_t then + if null xs then raise mk_HOL_ERR "mk_sum_term" "malformed term" + else + sumSyntax.mk_inl (pairSyntax.list_mk_pair xs, type_of t) + else if is_abs t then + mk_abs (apsnd build_sum (dest_abs t)) + else sumSyntax.mk_inr(t,inty) + end + in + build_sum tm + end + fun prove_simple_tailrec_exists tm = let val (l,r) = dest_eq tm val (f_tm,arg_tm) = dest_comb l @@ -39,26 +72,7 @@ fun prove_simple_tailrec_exists tm = let fun mk_inl x = sumSyntax.mk_inl(x,output_ty) fun mk_inr x = sumSyntax.mk_inr(x,input_ty) (* building the witness *) - fun build_sum tm = - if is_comb tm andalso aconv (rator tm) f_tm then - mk_inl (rand tm) - else if List.all (not o aconv f_tm) (free_vars tm) then - mk_inr tm - else if is_cond tm then let - val (b,x,y) = dest_cond tm - in mk_cond(b,build_sum x,build_sum y) end - else if cvSyntax.is_cv_if tm then let - val (b,x,y) = cvSyntax.dest_cv_if tm - in mk_cond(cvSyntax.mk_c2b b,build_sum x,build_sum y) end - else if can pairSyntax.dest_anylet tm then let - val (xs,x) = pairSyntax.dest_anylet tm - in pairSyntax.mk_anylet(xs,build_sum x) end - else if TypeBase.is_case tm then let - val (a,b,xs) = TypeBase.dest_case tm - val ys = map (fn (x,tm) => (x,build_sum tm)) xs - in TypeBase.mk_case(b,ys) end - else failwith ("Unsupported: " ^ term_to_string tm) - val sum_tm = build_sum r + val sum_tm = mk_sum_term f_tm input_ty r val abs_sum_tm = pairSyntax.mk_pabs(arg_tm,sum_tm) val witness = ISPEC abs_sum_tm whileTheory.TAILREC |> SPEC_ALL |> concl |> dest_eq |> fst |> rator