diff --git a/cvc5_pythonic_api/cvc5_pythonic.py b/cvc5_pythonic_api/cvc5_pythonic.py index 35ee9ad..8810692 100644 --- a/cvc5_pythonic_api/cvc5_pythonic.py +++ b/cvc5_pythonic_api/cvc5_pythonic.py @@ -6212,6 +6212,42 @@ def model(self): """ return ModelRef(self) + def proof(self): + """Return a proof for the last `check()`. + + This function raises an exception if + a proof is not available (e.g., last `check()` does not return unsat). + + >>> s = Solver() + >>> s.set('produce-proofs','true') + >>> a = Int('a') + >>> s.add(a + 2 == 0) + >>> s.check() + sat + >>> try: + ... s.proof() + ... except RuntimeError: + ... print("failed to get proof (last `check()` must have returned unsat)") + failed to get proof (last `check()` must have returned unsat) + >>> s.add(a == 0) + >>> s.check() + unsat + >>> s.proof() + (SCOPE: Not(And(a + 2 == 0, a == 0)), + (SCOPE: Not(And(a + 2 == 0, a == 0)), + [a + 2 == 0, a == 0], + (EQ_RESOLVE: False, + (ASSUME: a == 0, [a == 0]), + (MACRO_SR_EQ_INTRO: (a == 0) == False, + [a == 0, 7, 12], + (EQ_RESOLVE: a == -2, + (ASSUME: a + 2 == 0, [a + 2 == 0]), + (MACRO_SR_EQ_INTRO: (a + 2 == 0) == (a == -2), + [a + 2 == 0, 7, 12])))))) + """ + p = self.solver.getProof()[0] + return ProofRef(self, p) + def assertions(self): """Return an AST vector containing all added constraints. @@ -6760,6 +6796,98 @@ def evaluate(t): return m[t] +class ProofRef: + """A proof tree where every proof reference corresponds to the + root step of a proof. The branches of the root step are the + premises of the step.""" + + def __init__(self, solver, proof): + self.proof = proof + self.solver = solver + + def __del__(self): + if self.solver is not None: + self.solver = None + + def __repr__(self): + return obj_to_string(self) + + def getRule(self): + """Returns the proof rule used by the root step of the proof. + + >>> s = Solver() + >>> s.set('produce-proofs','true') + >>> a = Int('a') + >>> s.add(a + 2 == 0, a == 0) + >>> s.check() + unsat + >>> p = s.proof() + >>> p.getRule() + + """ + return self.proof.getRule() + + def getResult(self): + """Returns the conclusion of the root step of the proof. + + >>> s = Solver() + >>> s.set('produce-proofs','true') + >>> a = Int('a') + >>> s.add(a + 2 == 0, a == 0) + >>> s.check() + unsat + >>> p = s.proof() + >>> p.getResult() + Not(And(a + 2 == 0, a == 0)) + """ + return _to_expr_ref(self.proof.getResult(), Context(self.solver)) + + def getChildren(self): + """Returns the premises, i.e., proofs themselvels, of the root step of + the proof. + + >>> s = Solver() + >>> s.set('produce-proofs','true') + >>> a = Int('a') + >>> s.add(a + 2 == 0, a == 0) + >>> s.check() + unsat + >>> p = s.proof() + >>> p = p.getChildren()[0].getChildren()[0] + >>> p + (EQ_RESOLVE: False, + (ASSUME: a == 0, [a == 0]), + (MACRO_SR_EQ_INTRO: (a == 0) == False, + [a == 0, 7, 12], + (EQ_RESOLVE: a == -2, + (ASSUME: a + 2 == 0, [a + 2 == 0]), + (MACRO_SR_EQ_INTRO: (a + 2 == 0) == (a == -2), + [a + 2 == 0, 7, 12])))) + """ + children = self.proof.getChildren() + return [ProofRef(self.solver, cp) for cp in children] + + def getArguments(self): + """Returns the arguments of the root step of the proof as a list of + expressions. + + >>> s = Solver() + >>> s.set('produce-proofs','true') + >>> a = Int('a') + >>> s.add(a + 2 == 0, a == 0) + >>> s.check() + unsat + >>> p = s.proof() + >>> p.getArguments() + [] + >>> p = p.getChildren()[0] + >>> p.getArguments() + [a + 2 == 0, a == 0] + """ + args = self.proof.getArguments() + return [_to_expr_ref(a, Context(self.solver)) for a in args] + + def simplify(a): """Simplify the expression `a`. diff --git a/cvc5_pythonic_api/cvc5_pythonic_printer.py b/cvc5_pythonic_api/cvc5_pythonic_printer.py index d198c36..fab4f0c 100644 --- a/cvc5_pythonic_api/cvc5_pythonic_printer.py +++ b/cvc5_pythonic_api/cvc5_pythonic_printer.py @@ -1318,6 +1318,27 @@ def pp_model(self, m): break return seq3(r, "[", "]") + def pp_proof(self, p, d): + if d > self.max_depth: + return self.pp_ellipses() + r = [] + rule = str(p.getRule())[10:] + result = p.getResult() + childrenProofs = p.getChildren() + args = p.getArguments() + result_pp = self.pp_expr(result, 0, []) + r.append( + compose(to_format("{}: ".format(rule)), indent(_len(rule) + 2, result_pp)) + ) + if args: + r_args = [] + for arg in args: + r_args.append(self.pp_expr(arg, 0, [])) + r.append(seq3(r_args, "[", "]")) + for cPf in childrenProofs: + r.append(self.pp_proof(cPf, d + 1)) + return seq3(r) + def pp_func_entry(self, e): num = e.num_args() if num > 1: @@ -1377,6 +1398,8 @@ def main(self, a): return self.pp_seq(a.assertions(), 0, []) elif isinstance(a, cvc.ModelRef): return self.pp_model(a) + elif isinstance(a, cvc.ProofRef): + return self.pp_proof(a, 0) elif isinstance(a, list) or isinstance(a, tuple): return self.pp_list(a) else: diff --git a/test/pgm_outputs/proof.py.out b/test/pgm_outputs/proof.py.out new file mode 100644 index 0000000..3f65111 --- /dev/null +++ b/test/pgm_outputs/proof.py.out @@ -0,0 +1 @@ +unsat diff --git a/test/pgms/proof.py b/test/pgms/proof.py new file mode 100644 index 0000000..cd97ee5 --- /dev/null +++ b/test/pgms/proof.py @@ -0,0 +1,34 @@ +from cvc5 import ProofRule +from cvc5_pythonic_api import * + +def collect_initial_assumptions(proof): + # the initial assumptions are all the arguments of the initial + # SCOPE applications in the proof + proof_assumptions = [] + while (proof.getRule() == ProofRule.SCOPE): + proof_assumptions += proof.getArguments() + proof = proof.getChildren()[0] + return proof_assumptions + +def validate_proof_assumptions(assertions, proof_assumptions): + # checks that the assumptions in the produced proof match the + # assertions in the problem + return sum([c in assertions for c in proof_assumptions]) == len(proof_assumptions) + + +p1, p2, p3 = Bools('p1 p2 p3') +x, y = Ints('x y') +s = Solver() +s.set('produce-proofs','true') +assertions = [p1, p2, p3, Implies(p1, x > 0), Implies(p2, y > x), Implies(p2, y < 1), Implies(p3, y > -3)] + +for a in assertions: + s.add(a) + +print(s.check()) + +proof = s.proof() + +proof_assumptions = collect_initial_assumptions(proof) + +assert validate_proof_assumptions(assertions, proof_assumptions)