From d61f506da254b919f93e571a84247319de78f526 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 9 Feb 2025 22:13:28 -0800 Subject: [PATCH] feat: `simp +arith` normalizes coefficient in linear integer polynomials (#7015) This PR makes sure `simp +arith` normalizes coefficients in linear integer polynomials. There is still one todo: tightening the bound of inequalities. --- src/Init/Data/Int/Linear.lean | 148 +++++++++++++++++- .../Meta/Tactic/LinearArith/Int/Simp.lean | 72 +++++++-- tests/lean/run/simp_int_arith.lean | 47 ++++++ 3 files changed, 254 insertions(+), 13 deletions(-) diff --git a/src/Init/Data/Int/Linear.lean b/src/Init/Data/Int/Linear.lean index 88ef1377508b..0c87944b3904 100644 --- a/src/Init/Data/Int/Linear.lean +++ b/src/Init/Data/Int/Linear.lean @@ -8,6 +8,7 @@ import Init.ByCases import Init.Data.Prod import Init.Data.Int.Lemmas import Init.Data.Int.LemmasAux +import Init.Data.Int.DivModLemmas import Init.Data.RArray namespace Int.Linear @@ -97,10 +98,34 @@ def PolyCnstr.denote (ctx : Context) : PolyCnstr → Prop | .eq p => p.denote ctx = 0 | .le p => p.denote ctx ≤ 0 +def Poly.div (k : Int) : Poly → Poly + | .num k' => .num (k'/k) + | .add k' x p => .add (k'/k) x (div k p) + +def Poly.divAll (k : Int) : Poly → Bool + | .num k' => (k'/k)*k == k' + | .add k' _ p => (k'/k)*k == k' && divAll k p + +def Poly.divCoeffs (k : Int) : Poly → Bool + | .num _ => true + | .add k' _ p => (k'/k)*k == k' && divCoeffs k p + +def Poly.getConst : Poly → Int + | .num k => k + | .add _ _ p => getConst p + def PolyCnstr.norm : PolyCnstr → PolyCnstr | .eq p => .eq p.norm | .le p => .le p.norm +def PolyCnstr.divAll (k : Int) : PolyCnstr → Bool + | .eq p => p.divAll k + | .le p => p.divAll k + +def PolyCnstr.div (k : Int) : PolyCnstr → PolyCnstr + | .eq p => .eq <| p.div k + | .le p => .le <| p.div k + inductive ExprCnstr where | eq (p₁ p₂ : Expr) | le (p₁ p₂ : Expr) @@ -114,6 +139,10 @@ def ExprCnstr.toPoly : ExprCnstr → PolyCnstr | .eq e₁ e₂ => .eq (e₁.sub e₂).toPoly.norm | .le e₁ e₂ => .le (e₁.sub e₂).toPoly.norm +-- Certificate for normalizing the coefficients of a constraint +def divBy (e e' : ExprCnstr) (k : Int) : Bool := + k > 0 && e.toPoly.divAll k && e'.toPoly == e.toPoly.div k + attribute [local simp] Int.add_comm Int.add_assoc Int.add_left_comm Int.add_mul Int.mul_add attribute [local simp] Poly.insert Poly.denote Poly.norm Poly.addConst @@ -143,7 +172,30 @@ private theorem sub_fold (a b : Int) : a.sub b = a - b := rfl private theorem neg_fold (a : Int) : a.neg = -a := rfl attribute [local simp] sub_fold neg_fold -attribute [local simp] ExprCnstr.denote ExprCnstr.toPoly PolyCnstr.denote Expr.denote + +attribute [local simp] Poly.div Poly.divAll PolyCnstr.denote + +theorem Poly.denote_div_eq_of_divAll (ctx : Context) (p : Poly) (k : Int) : p.divAll k → (p.div k).denote ctx * k = p.denote ctx := by + induction p with + | num _ => simp + | add k' v p ih => + simp; intro h₁ h₂ + have ih := ih h₂ + simp [ih] + apply congrArg (denote ctx p + ·) + rw [Int.mul_right_comm, h₁] + +attribute [local simp] Poly.divCoeffs Poly.getConst + +theorem Poly.denote_div_eq_of_divCoeffs (ctx : Context) (p : Poly) (k : Int) : p.divCoeffs k → (p.div k).denote ctx * k + p.getConst % k = p.denote ctx := by + induction p with + | num k' => simp; rw [Int.add_comm, Int.mul_comm, Int.ediv_add_emod] + | add k' v p ih => + simp; intro h₁ h₂ + rw [← ih h₂] + rw [Int.mul_right_comm, h₁, Int.add_assoc] + +attribute [local simp] ExprCnstr.denote ExprCnstr.toPoly Expr.denote theorem Expr.denote_toPoly'_go (ctx : Context) (e : Expr) : (toPoly'.go k e p).denote ctx = k * e.denote ctx + p.denote ctx := by @@ -172,7 +224,7 @@ theorem Expr.denote_toPoly'_go (ctx : Context) (e : Expr) : theorem Expr.denote_toPoly (ctx : Context) (e : Expr) : e.toPoly.denote ctx = e.denote ctx := by simp [toPoly, toPoly', Expr.denote_toPoly'_go] -attribute [local simp] Expr.denote_toPoly +attribute [local simp] Expr.denote_toPoly PolyCnstr.denote theorem ExprCnstr.denote_toPoly (ctx : Context) (c : ExprCnstr) : c.toPoly.denote ctx = c.denote ctx := by cases c <;> simp @@ -228,6 +280,61 @@ theorem ExprCnstr.eq_of_toPoly_eq_const (ctx : Context) (x : Var) (k : Int) (c : rw [h]; simp rw [Int.add_comm, ← Int.sub_eq_add_neg, Int.sub_eq_zero] +private theorem mul_eq_zero_iff_eq_zero (a b : Int) : b ≠ 0 → (a * b = 0 ↔ a = 0) := by + intro h + constructor + · intro h' + cases Int.mul_eq_zero.mp h' + · assumption + · contradiction + · intro; simp [*] + +private theorem eq_mul_le_zero {a b : Int} : 0 < b → (a ≤ 0 ↔ a * b ≤ 0) := by + intro h + have : 0 = 0 * b := by simp + constructor + · intro h' + rw [this] + apply Int.mul_le_mul h' <;> try simp + apply Int.le_of_lt h + · intro h' + rw [this] at h' + exact Int.le_of_mul_le_mul_right h' h + +attribute [local simp] PolyCnstr.divAll PolyCnstr.div + +theorem ExprCnstr.eq_of_toPoly_eq_of_divBy' (ctx : Context) (e e' : ExprCnstr) (p : PolyCnstr) (k : Int) : k > 0 → p.divAll k → e.toPoly = p → e'.toPoly = p.div k → e.denote ctx = e'.denote ctx := by + intro h₀ h₁ h₂ h₃ + have hz : k ≠ 0 := by intro h; simp [h] at h₀ + cases p <;> simp at h₁ + next p => + replace h₁ := Poly.denote_div_eq_of_divAll ctx p k h₁ + replace h₂ := congrArg (PolyCnstr.denote ctx) h₂ + simp only [PolyCnstr.denote.eq_1, ← h₁] at h₂ + replace h₃ := congrArg (PolyCnstr.denote ctx) h₃ + simp only [PolyCnstr.denote.eq_1, PolyCnstr.div] at h₃ + rw [mul_eq_zero_iff_eq_zero _ _ hz] at h₂ + have := Eq.trans h₂ h₃.symm + rw [denote_toPoly, denote_toPoly] at this + exact this + next p => + -- TODO: this is correct but we can simplify `p ≤ 0` if `p.divCoeffs k` and `p.getConst % k > 0`. Here, we are simplifying only the case `p.getConst % k = 0` + replace h₁ := Poly.denote_div_eq_of_divAll ctx p k h₁ + replace h₂ := congrArg (PolyCnstr.denote ctx) h₂ + simp only [PolyCnstr.denote.eq_2, ← h₁] at h₂ + replace h₃ := congrArg (PolyCnstr.denote ctx) h₃ + simp only [PolyCnstr.denote.eq_2, PolyCnstr.div] at h₃ + rw [eq_mul_le_zero h₀] at h₃ + have := Eq.trans h₂ h₃.symm + rw [denote_toPoly, denote_toPoly] at this + exact this + +theorem ExprCnstr.eq_of_toPoly_eq_of_divBy (ctx : Context) (e e' : ExprCnstr) (k : Int) : divBy e e' k → e.denote ctx = e'.denote ctx := by + intro h + simp only [divBy, Bool.and_eq_true, bne_iff_ne, ne_eq, beq_iff_eq, decide_eq_true_eq] at h + have ⟨⟨h₁, h₂⟩, h₃⟩ := h + exact ExprCnstr.eq_of_toPoly_eq_of_divBy' ctx e e' e.toPoly k h₁ h₂ rfl h₃ + def PolyCnstr.isUnsat : PolyCnstr → Bool | .eq (.num k) => k != 0 | .eq _ => false @@ -243,6 +350,43 @@ theorem ExprCnstr.eq_false_of_isUnsat (ctx : Context) (c : ExprCnstr) (h : c.toP rw [ExprCnstr.denote_toPoly] at this assumption +def PolyCnstr.isUnsatCoeff (k : Int) : PolyCnstr → Bool + | .eq p => p.divCoeffs k && k > 0 && p.getConst % k > 0 + | .le _ => false + +private theorem contra {a b k : Int} (h₀ : 0 < k) (h₁ : 0 < b) (h₂ : b < k) (h₃ : a*k + b = 0) : False := by + have : b = -a*k := by + rw [← Int.neg_eq_of_add_eq_zero h₃, Int.neg_mul] + rw [this] at h₁ h₂ + conv at h₂ => rhs; rw [← Int.one_mul k] + have high := Int.lt_of_mul_lt_mul_right h₂ (Int.le_of_lt h₀) + rw [← Int.zero_mul k] at h₁ + have low := Int.lt_of_mul_lt_mul_right h₁ (Int.le_of_lt h₀) + replace low : 1 ≤ -a := low + have : (1 : Int) < 1 := Int.lt_of_le_of_lt low high + contradiction + +private theorem PolyCnstr.eq_false (ctx : Context) (p : Poly) (k : Int) : p.divCoeffs k → k > 0 → p.getConst % k > 0 → (PolyCnstr.eq p).denote ctx = False := by + simp + intro h₁ h₂ h₃ h + have hnz : k ≠ 0 := by intro h; rw [h] at h₂; contradiction + have := Poly.denote_div_eq_of_divCoeffs ctx p k h₁ + rw [h] at this + have low := h₃ + have high := Int.emod_lt_of_pos p.getConst h₂ + exact contra h₂ low high this + +theorem ExprCnstr.eq_false_of_isUnsat_coeff (ctx : Context) (c : ExprCnstr) (k : Int) : c.toPoly.isUnsatCoeff k → c.denote ctx = False := by + intro h + cases c <;> simp [toPoly, PolyCnstr.isUnsatCoeff] at h + next e₁ e₂ => + have ⟨⟨h₁, h₂⟩, h₃⟩ := h + have := PolyCnstr.eq_false ctx _ _ h₁ h₂ h₃ + simp at this + simp + intro he + simp [he] at this + def PolyCnstr.isValid : PolyCnstr → Bool | .eq (.num k) => k == 0 | .eq _ => false diff --git a/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean b/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean index e8642fdf383c..eb55be295483 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean @@ -7,6 +7,40 @@ prelude import Lean.Meta.Tactic.LinearArith.Basic import Lean.Meta.Tactic.LinearArith.Int.Basic +def Int.Linear.Poly.gcdAll : Poly → Nat + | .num k => k.natAbs + | .add k _ p => go k.natAbs p +where + go (k : Nat) (p : Poly) : Nat := + if k == 1 then k + else match p with + | .num k' => Nat.gcd k k'.natAbs + | .add k' _ p => go (Nat.gcd k k'.natAbs) p + +def Int.Linear.PolyCnstr.gcdAll : PolyCnstr → Nat + | .eq p => p.gcdAll + | .le p => p.gcdAll + +def Int.Linear.Poly.gcdCoeffs : Poly → Nat + | .num _ => 1 + | .add k _ p => go k.natAbs p +where + go (k : Nat) (p : Poly) : Nat := + if k == 1 then k + else match p with + | .num _ => k + | .add k' _ p => go (Nat.gcd k k'.natAbs) p + +def Int.Linear.PolyCnstr.gcdCoeffs : PolyCnstr → Nat + | .eq p | .le p => p.gcdCoeffs + +def Int.Linear.PolyCnstr.isEq : PolyCnstr → Bool + | .eq _ => true + | .le _ => false + +def Int.Linear.PolyCnstr.getConst : PolyCnstr → Int + | .eq p | .le p => p.getConst + namespace Lean.Meta.Linear.Int def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do @@ -16,28 +50,44 @@ def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do let p := c.toPoly if p.isUnsat then let r := mkConst ``False - let p := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_false_of_isUnsat) (toContextExpr atoms) (toExpr c) reflBoolTrue - return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) + let h := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_false_of_isUnsat) (toContextExpr atoms) (toExpr c) reflBoolTrue + return some (r, ← mkExpectedTypeHint h (← mkEq lhs r)) else if p.isValid then let r := mkConst ``True - let p := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_true_of_isValid) (toContextExpr atoms) (toExpr c) reflBoolTrue - return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) + let h := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_true_of_isValid) (toContextExpr atoms) (toExpr c) reflBoolTrue + return some (r, ← mkExpectedTypeHint h (← mkEq lhs r)) else let c' : LinearCnstr := p.toExprCnstr if c != c' then match p with | .eq (.add 1 x (.add (-1) y (.num 0))) => let r := mkIntEq atoms[x]! atoms[y]! - let p := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_var) (toContextExpr atoms) (toExpr x) (toExpr y) (toExpr c) reflBoolTrue - return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) + let h := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_var) (toContextExpr atoms) (toExpr x) (toExpr y) (toExpr c) reflBoolTrue + return some (r, ← mkExpectedTypeHint h (← mkEq lhs r)) | .eq (.add 1 x (.num k)) => let r := mkIntEq atoms[x]! (toExpr (-k)) - let p := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_const) (toContextExpr atoms) (toExpr x) (toExpr (-k)) (toExpr c) reflBoolTrue - return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) + let h := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_const) (toContextExpr atoms) (toExpr x) (toExpr (-k)) (toExpr c) reflBoolTrue + return some (r, ← mkExpectedTypeHint h (← mkEq lhs r)) | _ => - let r ← c'.toArith atoms - let p := mkApp4 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq) (toContextExpr atoms) (toExpr c) (toExpr c') reflBoolTrue - return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) + let defaultK := do + let r ← c'.toArith atoms + let h := mkApp4 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq) (toContextExpr atoms) (toExpr c) (toExpr c') reflBoolTrue + return some (r, ← mkExpectedTypeHint h (← mkEq lhs r)) + let k := p.gcdCoeffs + if k == 1 then + defaultK + else if p.getConst % k == 0 then + let c' : LinearCnstr := (p.div k).toExprCnstr + let r ← c'.toArith atoms + let h := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_of_divBy) (toContextExpr atoms) (toExpr c) (toExpr c') (toExpr (Int.ofNat k)) reflBoolTrue + return some (r, ← mkExpectedTypeHint h (← mkEq lhs r)) + else if p.isEq then + let r := mkConst ``False + let h := mkApp4 (mkConst ``Int.Linear.ExprCnstr.eq_false_of_isUnsat_coeff) (toContextExpr atoms) (toExpr c) (toExpr (Int.ofNat k)) reflBoolTrue + return some (r, ← mkExpectedTypeHint h (← mkEq lhs r)) + else + -- TODO: tight the bound + defaultK else return none diff --git a/tests/lean/run/simp_int_arith.lean b/tests/lean/run/simp_int_arith.lean index 4078df421f70..14b14fe1e72f 100644 --- a/tests/lean/run/simp_int_arith.lean +++ b/tests/lean/run/simp_int_arith.lean @@ -188,3 +188,50 @@ fun x y z f => -/ #guard_msgs (info) in #print ex₂ + +example (x y : Int) (h : False) : 2*x = x + y := by + simp +arith only + guard_target = x = y + contradiction + +example (x y : Int) (h : 2*x + 2*y = 4) : x + y = 2 := by + simp +arith only at h + guard_hyp h : x + y + -2 = 0 + simp +arith + assumption + +example (x y : Int) (h : 6*x + 3*y = 9) : 2*x + y = 3 := by + simp +arith only at h + guard_hyp h : 2*x + y + -3 = 0 + simp +arith + assumption + +example (x y : Int) (h : 2*x - 2*y ≤ 4) : x - y ≤ 2 := by + simp +arith only at h + guard_hyp h : x + -1*y + -2 ≤ 0 + simp +arith + assumption + +example (x y : Int) (h : -6*x + 3*y = -9) : - 2*x = -3 - y := by + simp +arith only at h + guard_hyp h : -2*x + y + 3 = 0 + simp +arith + assumption + +example (x y : Int) (h : 3*x + 6*y = 2) : False := by + simp +arith only at h + +example (x : Int) (h : 3*x = 1) : False := by + simp +arith only at h + +example (x : Int) (h : 2*x = 1) : False := by + simp +arith only at h + +example (x : Int) (h : x + x = 1) : False := by + simp +arith only at h + +example (x y : Int) (h : x + x + x = 1 + 2*y + x) : False := by + simp +arith only at h + +example (x : Int) (h : -x - x = 1) : False := by + simp +arith only at h