Skip to content

Commit

Permalink
feat: simp +arith normalizes coefficient in linear integer polynomi…
Browse files Browse the repository at this point in the history
…als (leanprover#7015)

This PR makes sure `simp +arith` normalizes coefficients in linear
integer polynomials. There is still one todo: tightening the bound of
inequalities.
  • Loading branch information
leodemoura authored Feb 10, 2025
1 parent 7f3e170 commit d61f506
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 13 deletions.
148 changes: 146 additions & 2 deletions src/Init/Data/Int/Linear.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Init.ByCases
import Init.Data.Prod
import Init.Data.Int.Lemmas
import Init.Data.Int.LemmasAux
import Init.Data.Int.DivModLemmas
import Init.Data.RArray

namespace Int.Linear
Expand Down Expand Up @@ -97,10 +98,34 @@ def PolyCnstr.denote (ctx : Context) : PolyCnstr → Prop
| .eq p => p.denote ctx = 0
| .le p => p.denote ctx ≤ 0

def Poly.div (k : Int) : Poly → Poly
| .num k' => .num (k'/k)
| .add k' x p => .add (k'/k) x (div k p)

def Poly.divAll (k : Int) : Poly → Bool
| .num k' => (k'/k)*k == k'
| .add k' _ p => (k'/k)*k == k' && divAll k p

def Poly.divCoeffs (k : Int) : Poly → Bool
| .num _ => true
| .add k' _ p => (k'/k)*k == k' && divCoeffs k p

def Poly.getConst : Poly → Int
| .num k => k
| .add _ _ p => getConst p

def PolyCnstr.norm : PolyCnstr → PolyCnstr
| .eq p => .eq p.norm
| .le p => .le p.norm

def PolyCnstr.divAll (k : Int) : PolyCnstr → Bool
| .eq p => p.divAll k
| .le p => p.divAll k

def PolyCnstr.div (k : Int) : PolyCnstr → PolyCnstr
| .eq p => .eq <| p.div k
| .le p => .le <| p.div k

inductive ExprCnstr where
| eq (p₁ p₂ : Expr)
| le (p₁ p₂ : Expr)
Expand All @@ -114,6 +139,10 @@ def ExprCnstr.toPoly : ExprCnstr → PolyCnstr
| .eq e₁ e₂ => .eq (e₁.sub e₂).toPoly.norm
| .le e₁ e₂ => .le (e₁.sub e₂).toPoly.norm

-- Certificate for normalizing the coefficients of a constraint
def divBy (e e' : ExprCnstr) (k : Int) : Bool :=
k > 0 && e.toPoly.divAll k && e'.toPoly == e.toPoly.div k

attribute [local simp] Int.add_comm Int.add_assoc Int.add_left_comm Int.add_mul Int.mul_add
attribute [local simp] Poly.insert Poly.denote Poly.norm Poly.addConst

Expand Down Expand Up @@ -143,7 +172,30 @@ private theorem sub_fold (a b : Int) : a.sub b = a - b := rfl
private theorem neg_fold (a : Int) : a.neg = -a := rfl

attribute [local simp] sub_fold neg_fold
attribute [local simp] ExprCnstr.denote ExprCnstr.toPoly PolyCnstr.denote Expr.denote

attribute [local simp] Poly.div Poly.divAll PolyCnstr.denote

theorem Poly.denote_div_eq_of_divAll (ctx : Context) (p : Poly) (k : Int) : p.divAll k → (p.div k).denote ctx * k = p.denote ctx := by
induction p with
| num _ => simp
| add k' v p ih =>
simp; intro h₁ h₂
have ih := ih h₂
simp [ih]
apply congrArg (denote ctx p + ·)
rw [Int.mul_right_comm, h₁]

attribute [local simp] Poly.divCoeffs Poly.getConst

theorem Poly.denote_div_eq_of_divCoeffs (ctx : Context) (p : Poly) (k : Int) : p.divCoeffs k → (p.div k).denote ctx * k + p.getConst % k = p.denote ctx := by
induction p with
| num k' => simp; rw [Int.add_comm, Int.mul_comm, Int.ediv_add_emod]
| add k' v p ih =>
simp; intro h₁ h₂
rw [← ih h₂]
rw [Int.mul_right_comm, h₁, Int.add_assoc]

attribute [local simp] ExprCnstr.denote ExprCnstr.toPoly Expr.denote

theorem Expr.denote_toPoly'_go (ctx : Context) (e : Expr) :
(toPoly'.go k e p).denote ctx = k * e.denote ctx + p.denote ctx := by
Expand Down Expand Up @@ -172,7 +224,7 @@ theorem Expr.denote_toPoly'_go (ctx : Context) (e : Expr) :
theorem Expr.denote_toPoly (ctx : Context) (e : Expr) : e.toPoly.denote ctx = e.denote ctx := by
simp [toPoly, toPoly', Expr.denote_toPoly'_go]

attribute [local simp] Expr.denote_toPoly
attribute [local simp] Expr.denote_toPoly PolyCnstr.denote

theorem ExprCnstr.denote_toPoly (ctx : Context) (c : ExprCnstr) : c.toPoly.denote ctx = c.denote ctx := by
cases c <;> simp
Expand Down Expand Up @@ -228,6 +280,61 @@ theorem ExprCnstr.eq_of_toPoly_eq_const (ctx : Context) (x : Var) (k : Int) (c :
rw [h]; simp
rw [Int.add_comm, ← Int.sub_eq_add_neg, Int.sub_eq_zero]

private theorem mul_eq_zero_iff_eq_zero (a b : Int) : b ≠ 0 → (a * b = 0 ↔ a = 0) := by
intro h
constructor
· intro h'
cases Int.mul_eq_zero.mp h'
· assumption
· contradiction
· intro; simp [*]

private theorem eq_mul_le_zero {a b : Int} : 0 < b → (a ≤ 0 ↔ a * b ≤ 0) := by
intro h
have : 0 = 0 * b := by simp
constructor
· intro h'
rw [this]
apply Int.mul_le_mul h' <;> try simp
apply Int.le_of_lt h
· intro h'
rw [this] at h'
exact Int.le_of_mul_le_mul_right h' h

attribute [local simp] PolyCnstr.divAll PolyCnstr.div

theorem ExprCnstr.eq_of_toPoly_eq_of_divBy' (ctx : Context) (e e' : ExprCnstr) (p : PolyCnstr) (k : Int) : k > 0 → p.divAll k → e.toPoly = p → e'.toPoly = p.div k → e.denote ctx = e'.denote ctx := by
intro h₀ h₁ h₂ h₃
have hz : k ≠ 0 := by intro h; simp [h] at h₀
cases p <;> simp at h₁
next p =>
replace h₁ := Poly.denote_div_eq_of_divAll ctx p k h₁
replace h₂ := congrArg (PolyCnstr.denote ctx) h₂
simp only [PolyCnstr.denote.eq_1, ← h₁] at h₂
replace h₃ := congrArg (PolyCnstr.denote ctx) h₃
simp only [PolyCnstr.denote.eq_1, PolyCnstr.div] at h₃
rw [mul_eq_zero_iff_eq_zero _ _ hz] at h₂
have := Eq.trans h₂ h₃.symm
rw [denote_toPoly, denote_toPoly] at this
exact this
next p =>
-- TODO: this is correct but we can simplify `p ≤ 0` if `p.divCoeffs k` and `p.getConst % k > 0`. Here, we are simplifying only the case `p.getConst % k = 0`
replace h₁ := Poly.denote_div_eq_of_divAll ctx p k h₁
replace h₂ := congrArg (PolyCnstr.denote ctx) h₂
simp only [PolyCnstr.denote.eq_2, ← h₁] at h₂
replace h₃ := congrArg (PolyCnstr.denote ctx) h₃
simp only [PolyCnstr.denote.eq_2, PolyCnstr.div] at h₃
rw [eq_mul_le_zero h₀] at h₃
have := Eq.trans h₂ h₃.symm
rw [denote_toPoly, denote_toPoly] at this
exact this

theorem ExprCnstr.eq_of_toPoly_eq_of_divBy (ctx : Context) (e e' : ExprCnstr) (k : Int) : divBy e e' k → e.denote ctx = e'.denote ctx := by
intro h
simp only [divBy, Bool.and_eq_true, bne_iff_ne, ne_eq, beq_iff_eq, decide_eq_true_eq] at h
have ⟨⟨h₁, h₂⟩, h₃⟩ := h
exact ExprCnstr.eq_of_toPoly_eq_of_divBy' ctx e e' e.toPoly k h₁ h₂ rfl h₃

def PolyCnstr.isUnsat : PolyCnstr → Bool
| .eq (.num k) => k != 0
| .eq _ => false
Expand All @@ -243,6 +350,43 @@ theorem ExprCnstr.eq_false_of_isUnsat (ctx : Context) (c : ExprCnstr) (h : c.toP
rw [ExprCnstr.denote_toPoly] at this
assumption

def PolyCnstr.isUnsatCoeff (k : Int) : PolyCnstr → Bool
| .eq p => p.divCoeffs k && k > 0 && p.getConst % k > 0
| .le _ => false

private theorem contra {a b k : Int} (h₀ : 0 < k) (h₁ : 0 < b) (h₂ : b < k) (h₃ : a*k + b = 0) : False := by
have : b = -a*k := by
rw [← Int.neg_eq_of_add_eq_zero h₃, Int.neg_mul]
rw [this] at h₁ h₂
conv at h₂ => rhs; rw [← Int.one_mul k]
have high := Int.lt_of_mul_lt_mul_right h₂ (Int.le_of_lt h₀)
rw [← Int.zero_mul k] at h₁
have low := Int.lt_of_mul_lt_mul_right h₁ (Int.le_of_lt h₀)
replace low : 1 ≤ -a := low
have : (1 : Int) < 1 := Int.lt_of_le_of_lt low high
contradiction

private theorem PolyCnstr.eq_false (ctx : Context) (p : Poly) (k : Int) : p.divCoeffs k → k > 0 → p.getConst % k > 0 → (PolyCnstr.eq p).denote ctx = False := by
simp
intro h₁ h₂ h₃ h
have hnz : k ≠ 0 := by intro h; rw [h] at h₂; contradiction
have := Poly.denote_div_eq_of_divCoeffs ctx p k h₁
rw [h] at this
have low := h₃
have high := Int.emod_lt_of_pos p.getConst h₂
exact contra h₂ low high this

theorem ExprCnstr.eq_false_of_isUnsat_coeff (ctx : Context) (c : ExprCnstr) (k : Int) : c.toPoly.isUnsatCoeff k → c.denote ctx = False := by
intro h
cases c <;> simp [toPoly, PolyCnstr.isUnsatCoeff] at h
next e₁ e₂ =>
have ⟨⟨h₁, h₂⟩, h₃⟩ := h
have := PolyCnstr.eq_false ctx _ _ h₁ h₂ h₃
simp at this
simp
intro he
simp [he] at this

def PolyCnstr.isValid : PolyCnstr → Bool
| .eq (.num k) => k == 0
| .eq _ => false
Expand Down
72 changes: 61 additions & 11 deletions src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,40 @@ prelude
import Lean.Meta.Tactic.LinearArith.Basic
import Lean.Meta.Tactic.LinearArith.Int.Basic

def Int.Linear.Poly.gcdAll : Poly → Nat
| .num k => k.natAbs
| .add k _ p => go k.natAbs p
where
go (k : Nat) (p : Poly) : Nat :=
if k == 1 then k
else match p with
| .num k' => Nat.gcd k k'.natAbs
| .add k' _ p => go (Nat.gcd k k'.natAbs) p

def Int.Linear.PolyCnstr.gcdAll : PolyCnstr → Nat
| .eq p => p.gcdAll
| .le p => p.gcdAll

def Int.Linear.Poly.gcdCoeffs : Poly → Nat
| .num _ => 1
| .add k _ p => go k.natAbs p
where
go (k : Nat) (p : Poly) : Nat :=
if k == 1 then k
else match p with
| .num _ => k
| .add k' _ p => go (Nat.gcd k k'.natAbs) p

def Int.Linear.PolyCnstr.gcdCoeffs : PolyCnstr → Nat
| .eq p | .le p => p.gcdCoeffs

def Int.Linear.PolyCnstr.isEq : PolyCnstr → Bool
| .eq _ => true
| .le _ => false

def Int.Linear.PolyCnstr.getConst : PolyCnstr → Int
| .eq p | .le p => p.getConst

namespace Lean.Meta.Linear.Int

def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
Expand All @@ -16,28 +50,44 @@ def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let p := c.toPoly
if p.isUnsat then
let r := mkConst ``False
let p := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_false_of_isUnsat) (toContextExpr atoms) (toExpr c) reflBoolTrue
return some (r, ← mkExpectedTypeHint p (← mkEq lhs r))
let h := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_false_of_isUnsat) (toContextExpr atoms) (toExpr c) reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
else if p.isValid then
let r := mkConst ``True
let p := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_true_of_isValid) (toContextExpr atoms) (toExpr c) reflBoolTrue
return some (r, ← mkExpectedTypeHint p (← mkEq lhs r))
let h := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_true_of_isValid) (toContextExpr atoms) (toExpr c) reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
else
let c' : LinearCnstr := p.toExprCnstr
if c != c' then
match p with
| .eq (.add 1 x (.add (-1) y (.num 0))) =>
let r := mkIntEq atoms[x]! atoms[y]!
let p := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_var) (toContextExpr atoms) (toExpr x) (toExpr y) (toExpr c) reflBoolTrue
return some (r, ← mkExpectedTypeHint p (← mkEq lhs r))
let h := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_var) (toContextExpr atoms) (toExpr x) (toExpr y) (toExpr c) reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
| .eq (.add 1 x (.num k)) =>
let r := mkIntEq atoms[x]! (toExpr (-k))
let p := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_const) (toContextExpr atoms) (toExpr x) (toExpr (-k)) (toExpr c) reflBoolTrue
return some (r, ← mkExpectedTypeHint p (← mkEq lhs r))
let h := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_const) (toContextExpr atoms) (toExpr x) (toExpr (-k)) (toExpr c) reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
| _ =>
let r ← c'.toArith atoms
let p := mkApp4 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq) (toContextExpr atoms) (toExpr c) (toExpr c') reflBoolTrue
return some (r, ← mkExpectedTypeHint p (← mkEq lhs r))
let defaultK := do
let r ← c'.toArith atoms
let h := mkApp4 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq) (toContextExpr atoms) (toExpr c) (toExpr c') reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
let k := p.gcdCoeffs
if k == 1 then
defaultK
else if p.getConst % k == 0 then
let c' : LinearCnstr := (p.div k).toExprCnstr
let r ← c'.toArith atoms
let h := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_of_divBy) (toContextExpr atoms) (toExpr c) (toExpr c') (toExpr (Int.ofNat k)) reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
else if p.isEq then
let r := mkConst ``False
let h := mkApp4 (mkConst ``Int.Linear.ExprCnstr.eq_false_of_isUnsat_coeff) (toContextExpr atoms) (toExpr c) (toExpr (Int.ofNat k)) reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
else
-- TODO: tight the bound
defaultK
else
return none

Expand Down
47 changes: 47 additions & 0 deletions tests/lean/run/simp_int_arith.lean
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,50 @@ fun x y z f =>
-/
#guard_msgs (info) in
#print ex₂

example (x y : Int) (h : False) : 2*x = x + y := by
simp +arith only
guard_target = x = y
contradiction

example (x y : Int) (h : 2*x + 2*y = 4) : x + y = 2 := by
simp +arith only at h
guard_hyp h : x + y + -2 = 0
simp +arith
assumption

example (x y : Int) (h : 6*x + 3*y = 9) : 2*x + y = 3 := by
simp +arith only at h
guard_hyp h : 2*x + y + -3 = 0
simp +arith
assumption

example (x y : Int) (h : 2*x - 2*y ≤ 4) : x - y ≤ 2 := by
simp +arith only at h
guard_hyp h : x + -1*y + -20
simp +arith
assumption

example (x y : Int) (h : -6*x + 3*y = -9) : - 2*x = -3 - y := by
simp +arith only at h
guard_hyp h : -2*x + y + 3 = 0
simp +arith
assumption

example (x y : Int) (h : 3*x + 6*y = 2) : False := by
simp +arith only at h

example (x : Int) (h : 3*x = 1) : False := by
simp +arith only at h

example (x : Int) (h : 2*x = 1) : False := by
simp +arith only at h

example (x : Int) (h : x + x = 1) : False := by
simp +arith only at h

example (x y : Int) (h : x + x + x = 1 + 2*y + x) : False := by
simp +arith only at h

example (x : Int) (h : -x - x = 1) : False := by
simp +arith only at h

0 comments on commit d61f506

Please sign in to comment.