-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathsmt2-printer.sml
320 lines (286 loc) · 9.45 KB
/
smt2-printer.sml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
(* Printer from verification conditions (VCs) to SMT-LIB 2 scripts.
   The entry point is to_smt2: it emits a fixed prelude followed by one
   (push 1) ... (check-sat) ... (pop 1) section per VC, so each VC is
   checked independently within a single solver script. *)
structure SMT2Printer = struct
open Util
open Expr
open UVar
open Normalize
open VC
infixr 0 $
infixr 1 -->

(* Raised for VC constructs that cannot be expressed in SMT-LIB. *)
exception SMTError of string

(* Turn a source identifier into a legal SMT-LIB symbol.  A lone underscore
   is a reserved word in SMT-LIB, so it is replaced by a long, unlikely name;
   primes, which may not appear in simple symbols, become exclamation marks. *)
fun escape s =
    if s = "_" then "__!escaped_from_underscore_for_smt"
    else String.map (fn c => if c = #"'" then #"!" else c) s

(* Symbol for the n-th evar: two exclamation marks followed by n.
   Not referenced elsewhere in this structure. *)
fun evar_name n = "!!" ^ str_int n

(* SMT-LIB operator for the index binary operators that map directly onto an
   SMT function symbol.  MaxI, MinI and IApp need a dedicated encoding and are
   handled in print_i (as is BoundedMinusI); TimeApp has no encoding.  Reaching
   any of the raising cases here indicates an internal error. *)
fun print_idx_bin_op opr =
    case opr of
        AddI => "+"
      | BoundedMinusI => "-"
      | MultI => "*"
      | EqI => "="
      | AndI => "and"
      | ExpNI => "exp_i_i"
      | LtI => "<"
      | GeI => ">="
      | MaxI => raise Impossible "print_idx_bin_op ()"
      | MinI => raise Impossible "print_idx_bin_op ()"
      | TimeApp => raise Impossible "print_idx_bin_op ()"
      | IApp => raise Impossible "print_idx_bin_op ()"

(* Render an index term as an SMT-LIB term.
   ctx lists the SMT names of the variables in scope, innermost first, so the
   de Bruijn variable ID (n, _) prints as the n-th element of ctx.
   Raises SMTError for unbound or qualified variables, index abstractions and
   unresolved unification variables. *)
fun print_i ctx i =
    case i of
        VarI id =>
        (case id of
             ID (n, _) =>
             (List.nth (ctx, n) handle Subscript => raise SMTError $ "Unbound variable: " ^ str_int n)
           | QID _ =>
             raise SMTError $ "Unbound variable: " ^ LongId.str_raw_long_id str_int id
        )
      | IConst (c, _) =>
        (case c of
             ICNat n => str_int n
           | ICTime x => TimeType.toString x
           | ICBool b => str_bool b
           | ICTT => "TT"
           | ICAdmit => "TT"
        )
      | UnOpI (opr, i, _) =>
        (case opr of
             ToReal => sprintf "(to_real $)" [print_i ctx i]
           | Log2 =>
             (* log2 is an uninterpreted function declared in the prelude *)
             sprintf "(log2 $)" [print_i ctx i]
           | Ceil => sprintf "(ceil $)" [print_i ctx i]
           | Floor => sprintf "(floor $)" [print_i ctx i]
           | B2n => sprintf "(b2i $)" [print_i ctx i]
           | Neg => sprintf "(not $)" [print_i ctx i]
           | IUDiv n => sprintf "(/ $ $)" [print_i ctx i, str_int n]
           | IUExp s => sprintf "(^ $ $)" [print_i ctx i, s]
        )
      | BinOpI (opr, i1, i2) =>
        (case opr of
             MaxI =>
             let
               fun max a b =
                   sprintf "(ite (>= $ $) $ $)" [a, b, a, b]
             in
               max (print_i ctx i1) (print_i ctx i2)
             end
           | MinI =>
             let
               fun min a b =
                   sprintf "(ite (<= $ $) $ $)" [a, b, a, b]
             in
               min (print_i ctx i1) (print_i ctx i2)
             end
           | BoundedMinusI =>
             (* truncated subtraction: a - b, or 0 when a < b *)
             let
               fun bounded_minus a b =
                   sprintf "(ite (< $ $) 0 (- $ $))" [a, b, a, b]
             in
               bounded_minus (print_i ctx i1) (print_i ctx i2)
             end
           | IApp =>
             (* flatten a curried application into a single SMT application
                (f a1 ... an) *)
             let
               val (f, is) = collect_IApp i1
               val is = f :: is
               val is = is @ [i2]
             in
               sprintf "($)" [join " " $ map (print_i ctx) is]
             end
           | _ =>
             sprintf "($ $ $)" [print_idx_bin_op opr, print_i ctx i1, print_i ctx i2]
        )
      | Ite (i1, i2, i3, _) => sprintf "(ite $ $ $)" [print_i ctx i1, print_i ctx i2, print_i ctx i3]
      | IAbs _ => raise SMTError "can't handle abstraction"
      | UVarI (x, _) =>
        (case !x of
             Refined i => print_i ctx i
           | Fresh _ => raise SMTError "index contains uvar")

(* Logical negation of an already-rendered SMT formula. *)
fun negate s = sprintf "(not $)" [s]

(* SMT-LIB sort for a base sort: Nat is modeled as Int and Time as Real.
   The nonnegativity lost by this mapping is restored by conv_base_sort. *)
fun print_base_sort b =
    case b of
        UnitSort => "Unit"
      | BoolSort => "Bool"
      | Nat => "Int"
      | Time => "Real"

(* SMT-LIB sort for a first-order basic sort; arrow sorts and unresolved
   unification variables raise SMTError. *)
fun print_bsort bsort =
    case bsort of
        Base b => print_base_sort b
      | BSArrow _ => raise SMTError "can't handle higher-order sorts"
      | UVarBS x =>
        (case !x of
             Refined b => print_bsort b
           | Fresh _ => raise SMTError "bsort contains uvar")

local
  (* Pick an SMT name for a new binder that differs from every name already
     in ctx.  Reusing a name in scope would make a de Bruijn reference to the
     outer variable print as the same symbol and be captured by the new
     binder; for hypotheses it would also repeat a declaration.  The original
     name is kept when it is free; otherwise a !k suffix is appended. *)
  fun fresh_name ctx name =
      let
        fun in_ctx x = List.exists (fn y => y = x) ctx
        fun loop k =
            let
              val candidate = name ^ "!" ^ str_int k
            in
              if in_ctx candidate then loop (k + 1) else candidate
            end
      in
        if in_ctx name then loop 1 else name
      end
in

(* Render a proposition as an SMT-LIB formula; ctx is as for print_i.
   Raises SMTError on big-O predicates and on existential quantifiers, which
   are deliberately not handed to the solver. *)
fun print_p ctx p =
    let
      fun str_conn opr =
          case opr of
              And => "and"
            | Or => "or"
            | Imply => "=>"
            | Iff => "="
      fun str_pred opr =
          case opr of
              EqP => "="
            | LeP => "<="
            | LtP => "<"
            | GeP => ">="
            | GtP => ">"
            | BigO => raise SMTError "can't handle big-O"
      fun f p =
          case p of
              PTrueFalse (b, _) => str_bool b
            | Not (p, _) => negate (f p)
            | BinConn (opr, p1, p2) => sprintf "($ $ $)" [str_conn opr, f p1, f p2]
            (* | BinPred (BigO, i1, i2) => sprintf "(bigO $ $)" [print_i ctx i1, print_i ctx i2] *)
            (* | BinPred (BigO, i1, i2) => "true" *)
            | BinPred (opr, i1, i2) => sprintf "($ $ $)" [str_pred opr, print_i ctx i1, print_i ctx i2]
            | Quan (Exists _, _, Bind _, _) => raise SMTError "Don't trust SMT solver to solve existentials"
            | Quan (q, bs, Bind ((name, _), p), _) =>
              let
                (* rename on clash so outer variables are not captured *)
                val name = fresh_name ctx name
              in
                sprintf "($ (($ $)) $)" [str_quan q, name, print_bsort bs, print_p (name :: ctx) p]
              end
    in
      f p
    end

(* SMT-LIB declaration of a constant x of the given (rendered) sort, written
   as a nullary declare-fun. *)
fun declare_const x sort =
    sprintf "(declare-fun $ () $)" [x, sort]

(* Wrap a rendered formula in an assert command. *)
fun assert s =
    sprintf "(assert $)" [s]

(* Assert a proposition under context ctx. *)
fun assert_p ctx p =
    assert (print_p ctx p)

(* Render one hypothesis and return it with the extended context.
   A variable hypothesis becomes a declaration and pushes its SMT name onto
   ctx; a proposition hypothesis becomes an assert and leaves ctx unchanged. *)
fun print_hyp ctx h =
    case h of
        VarH (name, bs) =>
        let
          (* rename on clash: re-declaring a symbol is an SMT-LIB error *)
          val name = fresh_name ctx name
        in
          case update_bs bs of
              Base b =>
              (declare_const name (print_base_sort b), name :: ctx)
            | BSArrow _ =>
              let
                val (args, ret) = collect_BSArrow bs
              in
                (sprintf "(declare-fun $ ($) $)" [name, join " " $ map print_bsort args, print_bsort ret], name :: ctx)
              end
            | UVarBS _ => raise SMTError "hypothesis contains uvar"
        end
      | PropH p =>
        let
          (* Dropping a hypothesis only weakens the premises, so it is always
             sound to discard one that cannot be expressed in SMT-LIB. *)
          val p = assert (print_p ctx p)
                  handle SMTError _ => ""
        in
          (p, ctx)
        end

end

(* Commands emitted once at the top of the script: an optional model request,
   the Unit datatype, uninterpreted exp_i_i and log2, and definitions of
   floor, ceil and b2i. *)
fun prelude get_ce = [
    (* "(set-logic ALL_SUPPORTED)", *)
    if get_ce then "(set-option :produce-models true)" else "",
    (* "(set-option :produce-proofs true)", *)
    "(declare-datatypes () ((Unit TT)))",
    "(declare-fun exp_i_i (Int Int) Int)",
    (* "(declare-fun exp_i_i (Int) Int)", *)
    "(declare-fun log2 (Real) Real)",
    (* "(assert (forall ((x Real) (y Real))", *)
    (* " (! (=> (and (< 0 x) (< 0 y)) (= (log2 ( * x y)) (+ (log2 x) (log2 y))))", *)
    (* " :pattern ((log2 ( * x y))))))", *)
    (* "(assert (forall ((x Real) (y Real))", *)
    (* " (! (=> (and (< 0 x) (< 0 y)) (= (log2 (/ x y)) (- (log2 x) (log2 y))))", *)
    (* " :pattern ((log2 (/ x y))))))", *)
    (* "(assert (= (log2 1) 0))", *)
    (* "(assert (= (log2 2) 1))", *)
    (* "(assert (forall ((x Real) (y Real)) (=> (and (< 0 x) (< 0 y)) (=> (< x y) (< (log2 x) (log2 y))))))", *)
    (* to_int is floor in SMT-LIB *)
    "(define-fun floor ((x Real)) Int",
    "(to_int x))",
    (* ceil x = - floor (- x).  Note that to_int (x + 0.5) would round to the
       nearest integer instead, e.g. giving 1 rather than 2 for 1.2. *)
    "(define-fun ceil ((x Real)) Int",
    "(- (to_int (- x))))",
    "(define-fun b2i ((b Bool)) Int",
    "(ite b 1 0))",
    (* "(declare-datatypes () ((Fun_1 fn1)))", *)
    (* "(declare-datatypes () ((Fun_2 fn2)))", *)
    (* "(declare-fun app_1 (Fun_1 Int) Real)", *)
    (* "(declare-fun app_2 (Fun_2 Int Int) Real)", *)
    (* "(declare-fun bigO (Fun_2 Fun_2) Bool)", *)
    ""
]

(* Open / close a solver scope so declarations made for one VC do not leak
   into the next. *)
val push = [
    "(push 1)"
]
val pop = [
    "(pop 1)"
]

(* Satisfiability check for the current scope, plus a model dump when
   counterexamples were requested. *)
fun check get_ce = [
    "(check-sat)",
    if get_ce then "(get-model)" else ""
    (* "(get-proof)" *)
    (* "(get-value (n))", *)
]

(* Convert to Z3's types and naming conventions.
   Returns the SMT-side base sort together with an optional constraint on the
   newly bound variable (de Bruijn index 0) restoring what the sort mapping
   loses: Nat and Time values are nonnegative. *)
fun conv_base_sort b =
    case b of
        UnitSort => (UnitSort, NONE)
      | BoolSort => (BoolSort, NONE)
      | Nat => (Nat, SOME (BinPred (LeP, ConstIN (0, dummy), VarI (ID (0, dummy)))))
      | Time => (Time, SOME (BinPred (LeP, ConstIT (TimeType.zero, dummy), VarI (ID (0, dummy)))))

(* conv_base_sort lifted to basic sorts; arrow sorts carry no constraint. *)
fun conv_bsort bsort =
    case bsort of
        Base b =>
        let
          val (b, p) = conv_base_sort b
        in
          (Base b, p)
        end
      | BSArrow _ => (bsort, NONE)
      | UVarBS x =>
        (case !x of
             Refined b => conv_bsort b
           | Fresh _ => raise SMTError "bsort contains uvar")

(* Prepare a proposition for printing: escape binder names and add each
   binder's sort constraint to its body.  The constraint is an implication
   premise under universal-style quantifiers and a conjunct under Exists. *)
fun conv_p p =
    case p of
        Quan (q, bs, Bind ((name, r), p), r_all) =>
        let
          val (bs, guard) = conv_bsort bs
          val body = conv_p p
          val body =
              case (guard, q) of
                  (NONE, _) => body
                (* exists x:nat. P  becomes  exists x:Int. 0 <= x /\ P *)
                | (SOME g, Exists _) => BinConn (And, g, body)
                (* forall x:nat. P  becomes  forall x:Int. 0 <= x -> P *)
                | (SOME g, _) => g --> body
        in
          Quan (q, bs, Bind ((escape name, r), body), r_all)
        end
      | Not (p, r) => Not (conv_p p, r)
      | BinConn (opr, p1, p2) => BinConn (opr, conv_p p1, conv_p p2)
      | BinPred _ => p
      | PTrueFalse _ => p

(* Prepare a hypothesis: a variable hypothesis gets an escaped name and, for
   Nat/Time, is followed by its nonnegativity constraint as a PropH. *)
fun conv_hyp h =
    case h of
        PropH _ => [h]
      | VarH (name, bs) =>
        let
          val (bs, p) = conv_bsort bs
          val hs = [VarH (escape name, bs)]
          val hs = hs @ (case p of SOME p => [PropH p] | _ => [])
        in
          hs
        end

(* Render one VC (hypotheses, goal) as a push/pop-delimited section that
   declares and asserts the hypotheses and asserts the negated goal, so an
   unsat answer means the VC is valid.  The hypothesis list is reversed
   before printing; presumably it is stored most-recent-first (TODO confirm
   against VC). *)
fun print_vc get_ce ((hyps, goal) : vc) =
    let
      val hyps = rev hyps
      val hyps = concatMap conv_hyp hyps
      val goal = conv_p goal
      val lines = push
      val (hyps, ctx) = foldl (fn (h, (hs, ctx)) => let val (h, ctx) = print_hyp ctx h in (h :: hs, ctx) end) ([], []) hyps
      val hyps = rev hyps
      val lines = lines @ hyps
      val lines = lines @ [assert (negate (print_p ctx goal))]
      val lines = lines @ check get_ce
      val lines = lines @ pop
      val lines = lines @ [""]
    in
      lines
    end

(* Complete SMT-LIB 2 script for a list of VCs: the prelude followed by one
   section per VC.  When get_ce is true, the script enables models and asks
   for one after each check-sat. *)
fun to_smt2 get_ce vcs =
    let
      val lines =
          concatMap (print_vc get_ce) vcs
      val lines = prelude get_ce @ lines
      val s = join_lines lines
    in
      s
    end
end
(* open CheckNoUVar *)
(* val vcs = map no_uvar_vc vcs *)
(* handle NoUVarError _ => raise SMTError "VC contains uvar" *)