Skip to content

Commit

Permalink
Add foundation for faster z3 model importing
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaslindner committed Aug 29, 2024
1 parent 87157d0 commit ce10d5c
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 17 deletions.
23 changes: 23 additions & 0 deletions src/shared/examples/test-z3_wrapper.sml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,29 @@ in
end;
end;

(*
val use_holsmt = true;
val name = "simple addition";
val query =
``
((x:word32) + y = 10w)
``;
val name = "simple contradiction";
val query =
``
((x:word32) + y = 10w) /\
((x:word32) + y = 11w)
``;
val use_holsmt = false;
open bslSyntax;
val name = "simple addition bir";
val query = beq (bplus (bden (bvarimm32 "x"), bden (bvarimm32 "y")), bconstii 32 10);
val name = "simple contradiction bir";
val query = band (beq (bplus (bden (bvarimm32 "x"), bden (bvarimm32 "y")), bconstii 32 10),
beq (bplus (bden (bvarimm32 "x"), bden (bvarimm32 "y")), bconstii 32 11));
*)
val _ = List.map (fn (name, query) =>
let
val _ = print ("\n\n=============== >>> RUNNING TEST CASE '" ^ name ^ "'\n");
Expand Down
19 changes: 18 additions & 1 deletion src/shared/smt/bir_smtLib.sml
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,28 @@ fun bir_smt_set_trace use_holsmt =
else
(fn _ => ());

(* TODO: should not be operating on word expressions in this library, just bir expressions *)
fun bir_smt_get_model use_holsmt =
if use_holsmt then
Z3_SAT_modelLib.Z3_GET_SAT_MODEL
else
raise ERR "bir_smt_get_model" "not implemented";
let
open holba_z3Lib;
open bir_smtlibLib;
in
(fn bexp =>
let
val _ = if type_of bexp = bir_expSyntax.bir_exp_t_ty then () else
raise ERR "bir_smt_get_model" "need a bir expression";
val exst = export_bexp bexp exst_empty;
val q = querysmt_mk_q (exst_to_querysmt exst);
val (res, model) = querysmt_getmodel q;
val _ = if res = BirSmtSat then () else
raise ERR "bir_smt_get_model" "unsatisfiable";
in
smtmodel_to_wordfmap model
end)
end;

(* ======================================= *)

Expand Down
15 changes: 14 additions & 1 deletion src/shared/smt/bir_smtlibLib.sml
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,20 @@ BExp_Store (BExp_Den (BVar "fr_269_MEM" (BType_Mem Bit32 Bit8)))
exst
end);

(* TODO: add a model importer *)
local
(* TODO: need to add conversion from word to bir: values are constant bitvector/imm or constant array/memory *)
(* TODO: also need variable name conversion, holv_ are going to be words, birv_ would be bir constant expressions *)
fun modellines_to_pairs [] acc = acc
| modellines_to_pairs [_] _ = raise ERR "modellines_to_pairs" "the returned model does not have an even number of lines"
| modellines_to_pairs (vname::holterm::lines) acc =
modellines_to_pairs lines ((vname, Parse.Term [QUOTE holterm])::acc);
open wordsSyntax;
open finite_mapSyntax;
in
fun smtmodel_to_wordfmap model =
rev (modellines_to_pairs model []);
(*fun smtmodel_to_bexp model = ;*)
end

end (* local *)

Expand Down
91 changes: 89 additions & 2 deletions src/shared/smt/holba_z3Lib.sml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,21 @@ val z3bin = "/home/andreas/data/hol/HolBA_opt/z3-4.8.4/bin/z3";
fun openz3 z3bin =
(Unix.execute (z3bin, ["-in"])) : (TextIO.instream, TextIO.outstream) Unix.proc;

(*
val z3wrap = "/home/andreas/data/hol/HolBA_symbexec/src/shared/smt/z3_wrapper.py";
val prelude_path = "/home/andreas/data/hol/HolBA_symbexec/src/shared/smt/holba_z3Lib_prelude.z3";
*)
fun openz3wrap z3wrap prelude_path =
(Unix.execute (z3wrap, [prelude_path, "loop"])) : (TextIO.instream, TextIO.outstream) Unix.proc;

fun endmeexit p = Unix.fromStatus (Unix.reap p);

fun get_streams p = Unix.streamsOf p;

val z3proc_bin_o = ref (NONE : string option);
val z3proc_o = ref (NONE : ((TextIO.instream, TextIO.outstream) Unix.proc) option);
val prelude_z3 = read_from_file (holpathdb.subst_pathvars "$(HOLBADIR)/src/shared/smt/holba_z3Lib_prelude.z3");
val prelude_z3_path = holpathdb.subst_pathvars "$(HOLBADIR)/src/shared/smt/holba_z3Lib_prelude.z3";
val prelude_z3 = read_from_file prelude_z3_path;
val prelude_z3_n = prelude_z3 ^ "\n";
val use_stack = true;
val debug_print = false;
Expand Down Expand Up @@ -59,6 +67,25 @@ fun get_z3proc z3bin =
p
end;

val z3wrapproc_o = ref (NONE : ((TextIO.instream, TextIO.outstream) Unix.proc) option);
fun get_z3wrapproc () =
let
val z3wrapproc_ = !z3wrapproc_o;
val p = if isSome z3wrapproc_ then valOf z3wrapproc_ else
let
val z3wrap = case OS.Process.getEnv "HOL4_Z3_WRAPPED_EXECUTABLE" of
SOME x => x
| NONE => raise ERR "get_z3wrapproc" "variable HOL4_Z3_WRAPPED_EXECUTABLE not defined";
val _ = if not debug_print then () else
print ("starting: " ^ z3wrap ^ "\n");
val p = openz3wrap z3wrap prelude_z3_path;
in (z3wrapproc_o := SOME p; p) end;
in
p
end;

(* =========================================================== *)

fun inputLines_until m ins acc =
let
val line_o = TextIO.inputLine ins;
Expand Down Expand Up @@ -102,6 +129,34 @@ fun sendreceive_query z3bin q =
in
out_lines
end;

fun sendreceive_wrap_query q =
let
val p = get_z3wrapproc ();
val (s_in,s_out) = get_streams p;

val q_fixed = String.concat (List.map (fn c => if c = #"\n" then "\\n" else str c) (String.explode q));
val _ = if not debug_print then () else
(print "sending: "; print q_fixed; print "\n");

val timer = holba_miscLib.timer_start 0;
val z3wrap_done_marker = "z3_wrapper query done";
val () = TextIO.output (s_out, q_fixed ^ "\n");
val out_lines = inputLines_until (z3wrap_done_marker ^ "\n") s_in [];
val _ = if debug_print then holba_miscLib.timer_stop
(fn delta_s => print (" wrapped query took " ^ delta_s ^ "\n")) timer else ();

val _ = if not debug_print then () else
(map print out_lines; print "\n\n");
in
out_lines
end;
(*
val q = "(declare-const x (_ BitVec 8))\n(assert (= x #xFF))\n";
val q = "(declare-const x (_ BitVec 8))\n(assert (= x #xAA))\n(assert (= x #xFF))\n";
sendreceive_wrap_query q;
*)
(* =========================================================== *)

datatype bir_smt_result =
Expand Down Expand Up @@ -134,6 +189,9 @@ fun sendreceive_query z3bin q =
out_lines
end;

fun querysmt_prepare_getmodel z3bin_o =
querysmt_raw z3bin_o NONE "(set-option :model.compact false)\n";

(*
querysmt_raw NONE NONE "(simplify ((_ extract 3 2) #xFC))";
Expand All @@ -155,6 +213,24 @@ querysmt_raw NONE NONE "(display (_ bv20 16))"
print "\n============================\n";
raise ERR "querysmt_parse_checksat" "unknown output from z3");

fun querysmt_parse_getmodel out_lines =
if hd out_lines = "sat\n" then
let
val model_lines = tl out_lines;
val model_lines_fix = map (fn line => if (hd o rev o explode) line = #"\n" then (implode o rev o tl o rev o explode) line else line) model_lines;
in
(BirSmtSat, model_lines_fix)
end
else if hd out_lines = "unsat\n" then
(BirSmtUnsat, [])
else if hd out_lines = "unknown\n" then
(BirSmtUnknown, [])
else
(print "\n============================\n";
map print out_lines;
print "\n============================\n";
raise ERR "querysmt_parse_getmodel" "unknown output from z3");

(* https://rise4fun.com/z3/tutorial *)
(*
val q = "(declare-const a Int)\n" ^
Expand Down Expand Up @@ -182,13 +258,23 @@ querysmt_raw NONE NONE "(display (_ bv20 16))"
val q = "(check-sat)\n";
val result = querysmt_parse_checksat (querysmt_raw NONE NONE q);
val result = (querysmt_raw NONE NONE (q^"(get-model)\n"));
*)

fun querysmt_checksat_gen z3bin_o timeout_o q =
querysmt_parse_checksat (querysmt_raw z3bin_o timeout_o (q ^ "(check-sat)\n"));
val querysmt_checksat = querysmt_checksat_gen NONE;

(* TODO: add querysmt_getmodel *)
fun querysmt_getmodel q =
querysmt_parse_getmodel (sendreceive_wrap_query q);

(*
val q = "(declare-const x (_ BitVec 8))\n(assert (= x #xFF))\n";
val q = "(declare-const x (_ BitVec 8))\n(assert (= x #xAA))\n(assert (= x #xFF))\n";
querysmt_checksat NONE q
querysmt_getmodel q
*)

(* ------------------------------------------------------------------------ *)

Expand Down Expand Up @@ -387,6 +473,7 @@ fun gen_smt_store_as_funcall valm valad valv opparam =
[("(= x #xFF)", SMTTY_Bool), ("(= x #xAA)", SMTTY_Bool)]);
querysmt_checksat NONE q
querysmt_getmodel q
*)

end (* local *)
Expand Down
79 changes: 66 additions & 13 deletions src/shared/smt/z3_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,10 @@ def strip_z3_name(x):
return len(x.split('_', maxsplit=1)) > 1 and x.split('_', maxsplit=1)[1] or x.split('_', maxsplit=1)[0]

# create list of string pairs from model: (varname, holterm)
def model_to_list(model):
def model_to_list(model, strip_names):
# map to pair (model variables (stripped hol name), variables value) and filter auxiliary assignments
assigns_pre = filter (lambda x: not ("!" in x[0]), map(lambda x: (strip_z3_name(str(x.name())), model[x]), model))
stripfun = strip_z3_name if strip_names else (lambda x: x)
assigns_pre = filter (lambda x: not ("!" in x[0]), map(lambda x: (stripfun(str(x.name())), model[x]), model))

# partition, sort individually, put together again
assign_ast = []
Expand All @@ -234,21 +235,79 @@ def model_to_list(model):
# return the collected hol assignments
return sml_list

def print_model_for_holba(model, strip_names = True):
hol_list = model_to_list(model, strip_names)

for (varname, term) in hol_list:
print(varname)
print(term)
#print("on stdout: {}".format(line), file=sys.stderr)

def send_query(s):
r = s.check()
model = []
if r == sat:
model = s.model()
return (r, model)

s = Solver()

# from z3_wrapper import *
# load_prelude("holba_z3Lib_prelude.z3")
# q = "(declare-const x (_ BitVec 8))\n(assert (= x #xFF))\n"
# preluded_query(q)
def load_prelude(filename):
with open(filename, "r") as f:
pre = f.read()
s.from_string(pre)
s.push()

def preluded_query(q):
s.from_string(q)
(r, model) = send_query(s)
s.pop()
s.push()
if r == unsat:
print("unsat")
elif r == unknown:
print("unknown")
else:
print("sat")
print_model_for_holba(model, strip_names = False)

# python3 z3_wrapper.py holba_z3Lib_prelude.z3 loop
# script entry point
def main():
use_files = len(sys.argv) > 1
s = Solver()
use_files = False
preluded_loop = False
if len(sys.argv) > 1:
filename = sys.argv[1]
if len(sys.argv) > 2:
preluded_loop = True
else:
use_files = True

if preluded_loop:
load_prelude(filename)
while True:
#print("waiting for input", file=sys.stderr)
q = sys.stdin.readline().replace("\\n", "\n")
#print("sending input to query", file=sys.stderr)
preluded_query(q)
print("z3_wrapper query done", flush=True)

exit(-1)

do_debug = False
if do_debug:
debug_input(s)
elif use_files:
s.from_file(sys.argv[1])
s.from_file(filename)
else:
stdin = "\n".join(sys.stdin.readlines())
s.from_string(stdin)

r = s.check()
(r, model) = send_query(s)
if r == unsat:
print("unsat")
exit(0)
Expand All @@ -261,13 +320,7 @@ def main():
print("sat")
#print(s.model(), file=sys.stderr)

model = s.model()
hol_list = model_to_list(model)

for (varname, term) in hol_list:
print(varname)
print(term)
#print("on stdout: {}".format(line), file=sys.stderr)
print_model_for_holba(model)


if __name__ == '__main__':
Expand Down

0 comments on commit ce10d5c

Please sign in to comment.