Skip to content

Commit a2a3746

Browse files
authored
Add Boolean overloads: &, |, ^, ~ (#103)
Also: add a dummy `Model` pseudo-object and fix how `evaluate` handles contexts.
1 parent be54c23 commit a2a3746

File tree

3 files changed

+94
-35
lines changed

3 files changed

+94
-35
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
with:
1919
options: "--check --verbose"
2020
src: "cvc5_pythonic_api"
21-
version: "23.7.0"
21+
version: "24.10.0"
2222

2323
- uses: actions/checkout@v2
2424
with:
@@ -56,7 +56,7 @@ jobs:
5656
- name: Build cvc5
5757
run: |
5858
cd cvc5/
59-
./configure.sh production --auto-download --python-bindings --cocoa
59+
./configure.sh production --auto-download --python-bindings --cocoa --gpl
6060
cd build/
6161
make -j${{ env.num_proc }}
6262

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ check:
77
pyright ./cvc5_pythonic_api
88

99
fmt:
10-
black --required-version 23.7.0 ./cvc5_pythonic_api
10+
black --required-version 24 ./cvc5_pythonic_api
1111

1212
check-fmt:
13-
black --check --verbose --required-version 23.7.0 ./cvc5_pythonic_api
13+
black --check --verbose --required-version 24 ./cvc5_pythonic_api
1414

1515
coverage:
1616
coverage run test_unit.py && coverage report && coverage html

cvc5_pythonic_api/cvc5_pythonic.py

Lines changed: 90 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
* Missing features:
5353
* Patterns
5454
* Models for uninterpreted sorts
55+
* The `Model` function
56+
* In our API, this function returns an object whose only method is `evaluate`.
5557
* Pseudo-boolean counting constraints
5658
* AtMost, AtLeast, PbLe, PbGe, PbEq
5759
* HTML integration
@@ -558,9 +560,6 @@ def _ctx_from_ast_arg_list(args, default_ctx=None):
558560
if is_ast(a):
559561
if ctx is None:
560562
ctx = a.ctx
561-
else:
562-
if debugging():
563-
_assert(ctx == a.ctx, "Context mismatch")
564563
if ctx is None:
565564
ctx = default_ctx
566565
return ctx
@@ -1245,8 +1244,6 @@ def If(a, b, c, ctx=None):
12451244
s = BoolSort(ctx)
12461245
a = s.cast(a)
12471246
b, c = _coerce_exprs(b, c, ctx)
1248-
if debugging():
1249-
_assert(a.ctx == b.ctx, "Context mismatch")
12501247
return _to_expr_ref(ctx.solver.mkTerm(Kind.ITE, a.ast, b.ast, c.ast), ctx)
12511248

12521249

@@ -1429,6 +1426,38 @@ def __mul__(self, other):
14291426
return 0
14301427
return If(self, other, 0)
14311428

1429+
def __and__(self, other):
1430+
"""Create the SMT and expression `self & other`.
1431+
1432+
>>> solve(Bool("x") & Bool("y"))
1433+
[x = True, y = True]
1434+
"""
1435+
return And(self, other)
1436+
1437+
def __or__(self, other):
1438+
"""Create the SMT or expression `self | other`.
1439+
1440+
>>> solve(Bool("x") | Bool("y"), Not(Bool("x")))
1441+
[x = False, y = True]
1442+
"""
1443+
return Or(self, other)
1444+
1445+
def __xor__(self, other):
1446+
"""Create the SMT xor expression `self ^ other`.
1447+
1448+
>>> solve(Bool("x") ^ Bool("y"), Not(Bool("x")))
1449+
[x = False, y = True]
1450+
"""
1451+
return Xor(self, other)
1452+
1453+
def __invert__(self):
1454+
"""Create the SMT not expression `~self`.
1455+
1456+
>>> solve(~Bool("x"))
1457+
[x = False]
1458+
"""
1459+
return Not(self)
1460+
14321461

14331462
def is_bool(a):
14341463
"""Return `True` if `a` is an SMT Boolean expression.
@@ -1875,8 +1904,6 @@ def cast(self, val):
18751904
String
18761905
"""
18771906
if is_expr(val):
1878-
if debugging():
1879-
_assert(self.ctx == val.ctx, "Context mismatch")
18801907
val_s = val.sort()
18811908
if self.eq(val_s):
18821909
return val
@@ -2617,8 +2644,6 @@ def cast(self, val):
26172644
failed
26182645
"""
26192646
if is_expr(val):
2620-
if debugging():
2621-
_assert(self.ctx == val.ctx, "Context mismatch")
26222647
val_s = val.sort()
26232648
if self.eq(val_s):
26242649
return val
@@ -4067,8 +4092,6 @@ def cast(self, val):
40674092
'#b00000000000000000000000000001010'
40684093
"""
40694094
if is_expr(val):
4070-
if debugging():
4071-
_assert(self.ctx == val.ctx, "Context mismatch")
40724095
# Idea: use sign_extend if sort of val is a bitvector of smaller size
40734096
return val
40744097
else:
@@ -5494,7 +5517,6 @@ def ArraySort(*sig):
54945517
if debugging():
54955518
for s in sig:
54965519
_assert(is_sort(s), "SMT sort expected")
5497-
_assert(s.ctx == r.ctx, "Context mismatch")
54985520
ctx = d.ctx
54995521
if len(sig) == 2:
55005522
return ArraySortRef(ctx.solver.mkArraySort(d.ast, r.ast), ctx)
@@ -6238,12 +6260,22 @@ def proof(self):
62386260
[a + 2 == 0, a == 0],
62396261
(EQ_RESOLVE: False,
62406262
(ASSUME: a == 0, [a == 0]),
6241-
(MACRO_SR_EQ_INTRO: (a == 0) == False,
6242-
[a == 0, 7, 12],
6243-
(EQ_RESOLVE: a == -2,
6244-
(ASSUME: a + 2 == 0, [a + 2 == 0]),
6245-
(MACRO_SR_EQ_INTRO: (a + 2 == 0) == (a == -2),
6246-
[a + 2 == 0, 7, 12]))))))
6263+
(TRANS: (a == 0) == False,
6264+
(CONG: (a == 0) == (-2 == 0),
6265+
[5],
6266+
(EQ_RESOLVE: a == -2,
6267+
(ASSUME: a + 2 == 0, [a + 2 == 0]),
6268+
(TRANS: (a + 2 == 0) == (a == -2),
6269+
(CONG: (a + 2 == 0) == (2 + a == 0),
6270+
[5],
6271+
(TRUST_THEORY_REWRITE: a + 2 == 2 + a,
6272+
[a + 2 == 2 + a, 3, 7]),
6273+
(REFL: 0 == 0, [0])),
6274+
(TRUST_THEORY_REWRITE: (2 + a == 0) == (a == -2),
6275+
[(2 + a == 0) == (a == -2), 3, 7]))),
6276+
(REFL: 0 == 0, [0])),
6277+
(TRUST_THEORY_REWRITE: (-2 == 0) == False,
6278+
[(-2 == 0) == False, 3, 7])))))
62476279
"""
62486280
p = self.solver.getProof()[0]
62496281
return ProofRef(self, p)
@@ -6789,13 +6821,36 @@ def decls(self):
67896821

67906822

67916823
def evaluate(t):
6792-
"""Evaluates the given term (assuming it is constant!)"""
6824+
"""Evaluates the given term (assuming it is constant!)
6825+
6826+
>>> evaluate(evaluate(BitVecVal(1, 8) + BitVecVal(2, 8)) + BitVecVal(3, 8))
6827+
6
6828+
"""
6829+
if not isinstance(t, ExprRef):
6830+
raise TypeError("Can only evaluate `ExprRef`s")
67936831
s = Solver()
67946832
s.check()
67956833
m = s.model()
67966834
return m[t]
67976835

67986836

6837+
class EmptyModel:
6838+
def evaluate(self, t):
6839+
return evaluate(t)
6840+
6841+
6842+
def Model(ctx=None):
6843+
"""Return an object for evaluating terms.
6844+
6845+
We recommend using the standalone `evaluate` function for this instead,
6846+
but we also provide this function and its return object for z3 compatibility.
6847+
6848+
>>> Model().evaluate(BitVecVal(1, 8) + BitVecVal(2, 8))
6849+
3
6850+
"""
6851+
return EmptyModel()
6852+
6853+
67996854
class ProofRef:
68006855
"""A proof tree where every proof reference corresponds to the
68016856
root step of a proof. The branches of the root step are the
@@ -6857,12 +6912,22 @@ def getChildren(self):
68576912
>>> p
68586913
(EQ_RESOLVE: False,
68596914
(ASSUME: a == 0, [a == 0]),
6860-
(MACRO_SR_EQ_INTRO: (a == 0) == False,
6861-
[a == 0, 7, 12],
6862-
(EQ_RESOLVE: a == -2,
6863-
(ASSUME: a + 2 == 0, [a + 2 == 0]),
6864-
(MACRO_SR_EQ_INTRO: (a + 2 == 0) == (a == -2),
6865-
[a + 2 == 0, 7, 12]))))
6915+
(TRANS: (a == 0) == False,
6916+
(CONG: (a == 0) == (-2 == 0),
6917+
[5],
6918+
(EQ_RESOLVE: a == -2,
6919+
(ASSUME: a + 2 == 0, [a + 2 == 0]),
6920+
(TRANS: (a + 2 == 0) == (a == -2),
6921+
(CONG: (a + 2 == 0) == (2 + a == 0),
6922+
[5],
6923+
(TRUST_THEORY_REWRITE: a + 2 == 2 + a,
6924+
[a + 2 == 2 + a, 3, 7]),
6925+
(REFL: 0 == 0, [0])),
6926+
(TRUST_THEORY_REWRITE: (2 + a == 0) == (a == -2),
6927+
[(2 + a == 0) == (a == -2), 3, 7]))),
6928+
(REFL: 0 == 0, [0])),
6929+
(TRUST_THEORY_REWRITE: (-2 == 0) == False,
6930+
[(-2 == 0) == False, 3, 7])))
68666931
"""
68676932
children = self.proof.getChildren()
68686933
return [ProofRef(self.solver, cp) for cp in children]
@@ -6965,8 +7030,6 @@ def cast(self, val):
69657030
'(fp #b0 #b01111111 #b00000000000000000000000)'
69667031
"""
69677032
if is_expr(val):
6968-
if debugging():
6969-
_assert(self.ctx == val.ctx, "Context mismatch")
69707033
return val
69717034
else:
69727035
return FPVal(val, None, self, self.ctx)
@@ -8633,7 +8696,6 @@ def CreateDatatypes(*ds):
86338696
_assert(
86348697
all([isinstance(d, Datatype) for d in ds]), "Arguments must be Datatypes"
86358698
)
8636-
_assert(all([d.ctx == ds[0].ctx for d in ds]), "Context mismatch")
86378699
_assert(all([d.constructors != [] for d in ds]), "Non-empty Datatypes expected")
86388700
ctx = ds[0].ctx
86398701
s = ctx.solver
@@ -9240,9 +9302,6 @@ def cast(self, val):
92409302
'#f10m31'
92419303
"""
92429304
if is_expr(val):
9243-
if debugging():
9244-
_assert(self.ctx == val.ctx, "Context mismatch")
9245-
# Idea: use sign_extend if sort of val is a bitvector of smaller size
92469305
return val
92479306
else:
92489307
return FiniteFieldVal(val, self)

0 commit comments

Comments
 (0)