Skip to content

Commit

Permalink
disable optimization of disjointness indexes; treated the same way as…
Browse files Browse the repository at this point in the history
… eloc and inv; things work again
  • Loading branch information
nikswamy committed Dec 2, 2023
1 parent 5ddfb04 commit 935f55a
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 79 deletions.
90 changes: 57 additions & 33 deletions src/3d/InterpreterTarget.fst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type disj_pre =
| Disj_triv : disj_pre
| Disj_pair : eloc -> eloc -> disj_pre
| Disj_conj : disj_pre -> disj_pre -> disj_pre
| Disj_name : A.ident -> list expr -> disj_pre

let rec subst_eloc subst (e:eloc)
: eloc
Expand All @@ -67,12 +68,14 @@ let rec subst_disj_index subst (d:disj_pre)
| Disj_conj d1 d2 ->
Disj_conj (subst_disj_index subst d1)
(subst_disj_index subst d2)
| Disj_name hd args ->
Disj_name hd (List.Tot.map (T.subst_expr subst) args)

let disj_conj d0 d1 =
match d0, d1 with
| Disj_triv, d
| d, Disj_triv -> d
| _, _ -> Disj_conj d0 d1
let disj_conj d0 d1 = Disj_conj d0 d1
// match d0, d1 with
// | Disj_triv, d
// | d, Disj_triv -> d
// | _, _ -> Disj_conj d0 d1

noeq
type on_success =
Expand All @@ -84,9 +87,19 @@ let typ_indexes = inv & eloc & disj_pre & on_success
let typ_indexes_nil = Inv_true, Eloc_none, Disj_triv, On_success false
let typ_indexes_union (i, e, d, b) (i', e', d', b') =
Inv_conj i i', Eloc_union e e', disj_conj d d', On_success_union b b'
let typ_indexes_name_except_disj hd args =
let typ_indexes_union_l_bias (i, e, d, b) (i', e', d', b') =
if not (Disj_triv? d)
then failwith "Unexpected disjunctive index in typ_indexes_union_l_bias"
else Inv_conj i i', Eloc_union e e', d', On_success_union b b'
let typ_indexes_union_r_bias (i, e, d, b) (i', e', d', b') =
if not (Disj_triv? d')
then failwith "Unexpected disjunctive index in typ_indexes_union_r_bias"
else Inv_conj i i', Eloc_union e e', d, On_success_union b b'

let typ_indexes_name hd args =
Inv_name hd args,
Eloc_name hd args,
Disj_name hd args,
On_success_named hd args

let env = H.t A.ident' type_decl
Expand Down Expand Up @@ -271,23 +284,24 @@ let rec typ_indexes_of_parser (en:env) (p:T.parser)
| Some td -> td
| _ -> failwith (Printf.sprintf "Type decl not found for %s" (A.ident_to_string hd))
in
let _, _, disj_index_p, _ = td.typ_indexes in
let subst =
match T.mk_subst td.name.td_params args with
| None ->
failwith (Printf.sprintf "Unexpected number of arguments to type %s" (A.ident_to_string td.name.td_name))
| Some s -> s
in
let disj_index = subst_disj_index subst disj_index_p in
let i, e, r = typ_indexes_name_except_disj hd (filter_args_for_inv args td) in
i, e, disj_index, r
// let _, _, disj_index_p, _ = td.typ_indexes in
// let subst =
// match T.mk_subst td.name.td_params args with
// | None ->
// failwith (Printf.sprintf "Unexpected number of arguments to type %s" (A.ident_to_string td.name.td_name))
// | Some s -> s
// in
// let disj_index = subst_disj_index subst disj_index_p in
typ_indexes_name hd (filter_args_for_inv args td)
end

| T.Parse_if_else _ p q
| T.Parse_pair _ p q
| T.Parse_pair _ p q ->
typ_indexes_union (typ_indexes_of_parser p) (typ_indexes_of_parser q)

| T.Parse_dep_pair _ p (_, q)
| T.Parse_dep_pair_with_refinement _ p _ (_, q) ->
typ_indexes_union (typ_indexes_of_parser p) (typ_indexes_of_parser q)
typ_indexes_union_l_bias (typ_indexes_of_parser p) (typ_indexes_of_parser q)

| T.Parse_weaken_left p _
| T.Parse_weaken_right p _
Expand All @@ -300,18 +314,20 @@ let rec typ_indexes_of_parser (en:env) (p:T.parser)

| T.Parse_dep_pair_with_action p (_, a) (_, q)
| T.Parse_dep_pair_with_refinement_and_action _ p _ (_, a) (_, q) ->
typ_indexes_union (typ_indexes_of_parser p)
(typ_indexes_union (typ_indexes_of_action a) (typ_indexes_of_parser q))
typ_indexes_union_l_bias (typ_indexes_of_parser p)
(typ_indexes_union_l_bias (typ_indexes_of_action a) (typ_indexes_of_parser q))

| T.Parse_with_action _ p a ->
typ_indexes_union_r_bias (typ_indexes_of_parser p) (typ_indexes_of_action a)

| T.Parse_with_action _ p a
| T.Parse_with_dep_action _ p (_, a) ->
typ_indexes_union (typ_indexes_of_parser p) (typ_indexes_of_action a)
typ_indexes_union_l_bias (typ_indexes_of_parser p) (typ_indexes_of_action a)

| T.Parse_string p _ ->
typ_indexes_nil

| T.Parse_refinement_with_action n p f (_, a) ->
typ_indexes_union (typ_indexes_of_parser p) (typ_indexes_of_action a)
typ_indexes_union_l_bias (typ_indexes_of_parser p) (typ_indexes_of_action a)

| T.Parse_with_probe p _ _ dest ->
let i, l, d, s = typ_indexes_of_parser p in
Expand Down Expand Up @@ -773,12 +789,12 @@ let rec print_eloc mname (e:eloc)
let rec print_disj mname (d:disj_pre)
: ML string
= match d with
| Disj_triv -> "None"
| Disj_pair i j -> Printf.sprintf "(Some (A.disjoint %s %s))" (print_eloc mname i) (print_eloc mname j)
| Disj_triv -> "disj_none"
| Disj_pair i j -> Printf.sprintf "(A.disjoint %s %s)" (print_eloc mname i) (print_eloc mname j)
| Disj_conj i j -> Printf.sprintf "(join_disj %s %s)" (print_disj mname i) (print_disj mname j)
// | Disj_name hd args -> Printf.sprintf "(%s %s)" (print_derived_name mname "disj" hd) (print_args mname args)
| Disj_name hd args -> Printf.sprintf "(%s %s)" (print_derived_name mname "disj" hd) (print_args mname args)

let print_td_iface is_entrypoint mname root_name binders args typ_indexes_binders typ_indexes_args disj_index ar pk_wk pk_nz =
let print_td_iface is_entrypoint mname root_name binders args typ_indexes_binders typ_indexes_args ar pk_wk pk_nz =
let kind_t =
Printf.sprintf "[@@noextract_to \"krml\"]\n\
inline_for_extraction\n\
Expand All @@ -795,6 +811,13 @@ let print_td_iface is_entrypoint mname root_name binders args typ_indexes_binder
root_name
typ_indexes_binders
in
let disj_t =
Printf.sprintf "[@@noextract_to \"krml\"]\n\
noextract\n\
val disj_%s %s : disj_index"
root_name
typ_indexes_binders
in
let eloc_t =
Printf.sprintf "[@@noextract_to \"krml\"]\n\
noextract\n\
Expand All @@ -805,12 +828,12 @@ let print_td_iface is_entrypoint mname root_name binders args typ_indexes_binder
let def'_t =
Printf.sprintf "[@@noextract_to \"krml\"]\n\
noextract\n\
val def'_%s %s: typ kind_%s (inv_%s %s) (%s) (eloc_%s %s) %b"
val def'_%s %s: typ kind_%s (inv_%s %s) (disj_%s %s) (eloc_%s %s) %b"
root_name
binders
root_name
root_name typ_indexes_args
disj_index //root_name typ_indexes_args
root_name typ_indexes_args
root_name typ_indexes_args
ar
in
Expand All @@ -829,7 +852,7 @@ let print_td_iface is_entrypoint mname root_name binders args typ_indexes_binder
binders
root_name args
in
String.concat "\n\n" [kind_t; inv_t; eloc_t; def'_t; validator_t; dtyp_t]
String.concat "\n\n" [kind_t; inv_t; disj_t; eloc_t; def'_t; validator_t; dtyp_t]

let print_binders mname binders =
List.map (print_param mname) binders |>
Expand Down Expand Up @@ -906,14 +929,15 @@ let print_binding mname (td:type_decl)
(T.print_kind mname k)]
in
let print_inv_or_eloc_or_disj = print_inv_or_eloc_or_disj mname tdn root_name binders in
let typ_indexes_of_binding, disj_index, fv_binders, fv_args =
let typ_indexes_of_binding, fv_binders, fv_args =
let inv, eloc, disj, _ = td.typ_indexes in
let fvs1 = free_vars_of_inv inv in
let fvs2 = free_vars_of_disj disj in
let fvs3 = free_vars_of_eloc eloc in
let s0, _, _ = print_inv_or_eloc_or_disj "inv" None "A.slice_inv" (print_inv mname inv) (fvs1@fvs2@fvs3) in
let s1, _, _ = print_inv_or_eloc_or_disj "disj" None "disj_index" (print_disj mname disj) (fvs1@fvs2@fvs3) in
let s2, fvb, fva = print_inv_or_eloc_or_disj "eloc" None "A.eloc" (print_eloc mname eloc) (fvs1@fvs2@fvs3) in
s0 ^ s2, print_disj mname disj, fvb, fva
s0 ^ s1 ^ s2, fvb, fva
in
let def' =
OS.format
Expand Down Expand Up @@ -1035,7 +1059,7 @@ let print_binding mname (td:type_decl)
let iface =
print_td_iface td.name.td_entrypoint
mname root_name binders args
fv_binders fv_args disj_index td.allow_reading
fv_binders fv_args td.allow_reading
weak_kind k.pk_nz
in
impl, iface
Expand Down
Loading

0 comments on commit 935f55a

Please sign in to comment.