Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: remove nonterminal simps in UnionFind #868

Merged
merged 2 commits into from
Jul 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 68 additions & 36 deletions Batteries/Data/UnionFind/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,17 @@ theorem lt_of_parentD : parentD arr i ≠ i → i < arr.size :=

theorem parentD_set {arr : Array UFNode} {x v i} :
parentD (arr.set x v) i = if x.1 = i then v.parent else parentD arr i := by
rw [parentD]; simp [Array.get_eq_getElem, parentD]
split <;> [split <;> simp [Array.get_set, *]; split <;> [(subst i; cases ‹¬_› x.2); rfl]]
rw [parentD]; simp only [Array.size_set, Array.get_eq_getElem, parentD]
split
· split <;> simp_all
· split <;> [(subst i; cases ‹¬_› x.2); rfl]

theorem rankD_set {arr : Array UFNode} {x v i} :
rankD (arr.set x v) i = if x.1 = i then v.rank else rankD arr i := by
rw [rankD]; simp [Array.get_eq_getElem, rankD]
split <;> [split <;> simp [Array.get_set, *]; split <;> [(subst i; cases ‹¬_› x.2); rfl]]
rw [rankD]; simp only [Array.size_set, Array.get_eq_getElem, rankD]
split
· split <;> simp_all
· split <;> [(subst i; cases ‹¬_› x.2); rfl]

end UnionFind

Expand Down Expand Up @@ -146,7 +150,7 @@ theorem rank'_lt_rankMax (self : UnionFind) (i : Fin self.size) :
let rec go : ∀ {l} {x : UFNode}, x ∈ l → x.rank ≤ List.foldr (max ·.rank) 0 l
| a::l, _, List.Mem.head _ => by dsimp; apply Nat.le_max_left
| a::l, _, .tail _ h => by dsimp; exact Nat.le_trans (go h) (Nat.le_max_right ..)
simp [rankMax, Array.foldr_eq_foldr_data]
simp only [Array.get_eq_getElem, rankMax, Array.foldr_eq_foldr_data]
exact Nat.lt_succ.2 <| go (self.arr.data.get_mem i.1 i.2)

theorem rankD_lt_rankMax (self : UnionFind) (i : Nat) :
Expand All @@ -156,11 +160,11 @@ theorem rankD_lt_rankMax (self : UnionFind) (i : Nat) :
theorem lt_rankMax (self : UnionFind) (i : Nat) : self.rank i < self.rankMax := rankD_lt_rankMax ..

theorem push_rankD (arr : Array UFNode) : rankD (arr.push ⟨arr.size, 0⟩) i = rankD arr i := by
simp [rankD, Array.get_eq_getElem, Array.get_push]
simp only [rankD, Array.size_push, Array.get_eq_getElem, Array.get_push, dite_eq_ite]
split <;> split <;> first | simp | cases ‹¬_› (Nat.lt_succ_of_lt ‹_›)

theorem push_parentD (arr : Array UFNode) : parentD (arr.push ⟨arr.size, 0⟩) i = parentD arr i := by
simp [parentD, Array.get_eq_getElem, Array.get_push]
simp only [parentD, Array.size_push, Array.get_eq_getElem, Array.get_push, dite_eq_ite]
split <;> split <;> try simp
· exact Nat.le_antisymm (Nat.ge_of_not_lt ‹_›) (Nat.le_of_lt_succ ‹_›)
· cases ‹¬_› (Nat.lt_succ_of_lt ‹_›)
Expand All @@ -169,9 +173,9 @@ theorem push_parentD (arr : Array UFNode) : parentD (arr.push ⟨arr.size, 0⟩)
def push (self : UnionFind) : UnionFind where
arr := self.arr.push ⟨self.arr.size, 0⟩
parentD_lt {i} := by
simp [push_parentD]; simp [parentD]
simp only [Array.size_push, push_parentD]; simp only [parentD, Array.get_eq_getElem]
split <;> [exact fun _ => Nat.lt_succ_of_lt (self.parent'_lt _); exact id]
rankD_lt := by simp [push_parentD, push_rankD]; exact self.rank_lt
rankD_lt := by simp only [push_parentD, ne_eq, push_rankD]; exact self.rank_lt

/-- Root of a union-find node. -/
def root (self : UnionFind) (x : Fin self.size) : Fin self.size :=
Expand Down Expand Up @@ -205,18 +209,23 @@ termination_by self.rankMax - self.rank x

theorem parent_rootD (self : UnionFind) (x : Nat) :
self.parent (self.rootD x) = self.rootD x := by
rw [rootD]; split <;>
[simp [parentD, parent_root, -Array.get_eq_getElem]; simp [parentD_of_not_lt, *]]
rw [rootD]
split
· simp [parentD, parent_root, -Array.get_eq_getElem]
· simp [parentD_of_not_lt, *]

@[nolint unusedHavesSuffices]
theorem rootD_parent (self : UnionFind) (x : Nat) : self.rootD (self.parent x) = self.rootD x := by
simp [rootD, parent_lt]; split <;> simp [parentD, parentD_of_not_lt, *, -Array.get_eq_getElem]
(conv => rhs; rw [root]); split
· rw [root, dif_pos] <;> simp [*, -Array.get_eq_getElem]
· simp
simp only [rootD, Array.data_length, parent_lt]
split
· simp only [parentD, ↓reduceDIte, *]
(conv => rhs; rw [root]); split
· rw [root, dif_pos] <;> simp_all
· simp
· simp only [not_false_eq_true, parentD_of_not_lt, *]

theorem rootD_lt {self : UnionFind} {x : Nat} : self.rootD x < self.size ↔ x < self.size := by
simp [rootD]; split <;> simp [*]
simp only [rootD, Array.data_length]; split <;> simp [*]

@[nolint unusedHavesSuffices]
theorem rootD_eq_self {self : UnionFind} {x : Nat} : self.rootD x = x ↔ self.parent x = x := by
Expand Down Expand Up @@ -273,7 +282,9 @@ termination_by self.rankMax - self.rank x
@[nolint unusedHavesSuffices]
theorem findAux_root {self : UnionFind} {x : Fin self.size} :
(findAux self x).root = self.root x := by
rw [findAux, root]; simp; split <;> simp
rw [findAux, root]
simp only [Array.data_length, Array.get_eq_getElem, dite_eq_ite]
split <;> simp only
have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›)
exact findAux_root
termination_by self.rankMax - self.rank x
Expand All @@ -286,7 +297,7 @@ theorem findAux_s {self : UnionFind} {x : Fin self.size} :
rw [show self.rootD _ = (self.findAux ⟨_, self.parent'_lt x⟩).root from _]
· rw [findAux]; split <;> rfl
· rw [← rootD_parent, parent, parentD_eq]
simp [findAux_root, rootD]
simp only [rootD, Array.get_eq_getElem, Array.data_length, findAux_root]
apply dif_pos
exact parent'_lt ..

Expand All @@ -299,7 +310,8 @@ theorem rankD_findAux {self : UnionFind} {x : Fin self.size} :
rw [rankD_eq' (by simp [FindAux.size_eq, h]), Array.get_modify (by rwa [FindAux.size_eq])]
split <;> simp [← rankD_eq, rankD_findAux (x := ⟨_, self.parent'_lt x⟩), -Array.get_eq_getElem]
else
simp [rank, rankD]; rw [dif_neg (by rwa [FindAux.size_eq]), dif_neg h]
simp only [rankD, Array.data_length, Array.get_eq_getElem, rank]
rw [dif_neg (by rwa [FindAux.size_eq]), dif_neg h]
termination_by self.rankMax - self.rank x

theorem parentD_findAux {self : UnionFind} {x : Fin self.size} :
Expand All @@ -311,7 +323,7 @@ theorem parentD_findAux {self : UnionFind} {x : Fin self.size} :
· next h =>
rw [parentD]; split <;> rename_i h'
· rw [Array.get_modify (by simpa using h')]
simp [@eq_comm _ i, -Array.get_eq_getElem]
simp only [Array.data_length, @eq_comm _ i]
split <;> simp [← parentD_eq, -Array.get_eq_getElem]
· rw [if_neg (mt (by rintro rfl; simp [FindAux.size_eq]) h')]
rw [parentD, dif_neg]; simpa using h'
Expand All @@ -330,9 +342,11 @@ theorem parentD_findAux_lt {self : UnionFind} {x : Fin self.size} (h : i < self.
if h' : (self.arr.get x).parent = x then
rw [findAux_s, if_pos h']; apply self.parentD_lt h
else
rw [parentD_findAux]; split <;> [simp [rootD_lt]; skip]
have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›)
apply parentD_findAux_lt h
rw [parentD_findAux]
split
· simp [rootD_lt]
· have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›)
apply parentD_findAux_lt h
termination_by self.rankMax - self.rank x

theorem parentD_findAux_or (self : UnionFind) (x : Fin self.size) (i) :
Expand All @@ -341,10 +355,12 @@ theorem parentD_findAux_or (self : UnionFind) (x : Fin self.size) (i) :
if h' : (self.arr.get x).parent = x then
rw [findAux_s, if_pos h']; exact .inr rfl
else
rw [parentD_findAux]; split <;> [simp [*]; skip]
have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›)
exact (parentD_findAux_or self ⟨_, self.parent'_lt x⟩ i).imp_left <| .imp_right fun h => by
simp only [h, ← parentD_eq, rootD_parent, Array.data_length]
rw [parentD_findAux]
split
· simp [*]
· have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›)
exact (parentD_findAux_or self ⟨_, self.parent'_lt x⟩ i).imp_left <| .imp_right fun h => by
simp only [h, ← parentD_eq, rootD_parent, Array.data_length]
termination_by self.rankMax - self.rank x

theorem lt_rankD_findAux {self : UnionFind} {x : Fin self.size} :
Expand All @@ -365,7 +381,9 @@ def find (self : UnionFind) (x : Fin self.size) :
let r := self.findAux x
{ 1.arr := r.s
2.1.val := r.root
1.parentD_lt := fun h => by simp [FindAux.size_eq] at *; exact parentD_findAux_lt h
1.parentD_lt := fun h => by
simp only [Array.data_length, FindAux.size_eq] at *
exact parentD_findAux_lt h
1.rankD_lt := fun h => by rw [rankD_findAux, rankD_findAux]; exact lt_rankD_findAux h
2.1.isLt := show _ < r.s.size by rw [r.size_eq]; exact r.root.2
2.2 := by simp [size, r.size_eq] }
Expand Down Expand Up @@ -398,7 +416,8 @@ def findD (self : UnionFind) (x : Nat) : UnionFind × Nat :=

@[simp] theorem find_parent_1 (self : UnionFind) (x : Fin self.size) :
(self.find x).1.parent x = self.rootD x := by
simp [find, parent]; rw [parentD_findAux, if_pos rfl]
simp only [parent, Array.data_length, find]
rw [parentD_findAux, if_pos rfl]

theorem find_parent_or (self : UnionFind) (x : Fin self.size) (i) :
(self.find x).1.parent i = self.rootD i ∧ self.rootD i = self.rootD x ∨
Expand Down Expand Up @@ -449,7 +468,8 @@ theorem setParentBump_rankD_lt {arr : Array UFNode} {x y : Fin arr.size}
simp [hP, hR, -Array.get_eq_getElem] at *; split <;> rename_i h₁ <;> [simp [← h₁]; skip] <;>
split <;> rename_i h₂ <;> intro h
· simp [h₂] at h
· simp [rankD_eq]; split <;> rename_i h₃
· simp only [rankD_eq, Array.get_eq_getElem]
split <;> rename_i h₃
· rw [← h₃]; apply Nat.lt_succ_self
· exact Nat.lt_of_le_of_ne H h₃
· cases h₂.1
Expand All @@ -469,13 +489,16 @@ theorem setParent_rankD_lt {arr : Array UFNode} {x y : Fin arr.size}
(by simp [rankD_set, Nat.ne_of_lt h, rankD_eq, -Array.get_eq_getElem])

@[simp] theorem linkAux_size : (linkAux self x y).size = self.size := by
simp [linkAux]; split <;> [rfl; split] <;> [skip; split] <;> simp
simp only [linkAux, Array.get_eq_getElem]
split <;> [rfl; split] <;> [skip; split] <;> simp

/-- Link a union-find node to a root node. -/
def link (self : UnionFind) (x y : Fin self.size) (yroot : self.parent y = y) : UnionFind where
arr := linkAux self.arr x y
parentD_lt h := by
simp at *; simp [linkAux]; split <;> [skip; split <;> [skip; split]]
simp only [Array.data_length, linkAux_size] at *
simp only [linkAux, Array.get_eq_getElem]
split <;> [skip; split <;> [skip; split]]
· exact self.parentD_lt h
· rw [parentD_set]; split <;> [exact x.2; exact self.parentD_lt h]
· rw [parentD_set]; split
Expand All @@ -484,12 +507,21 @@ def link (self : UnionFind) (x y : Fin self.size) (yroot : self.parent y = y) :
· rw [parentD_set]; split <;> [exact y.2; exact self.parentD_lt h]
rankD_lt := by
rw [parent, parentD_eq] at yroot
simp [linkAux]; split <;> [skip; split <;> [skip; split]]
simp only [linkAux, Array.get_eq_getElem, ne_eq]
split <;> [skip; split <;> [skip; split]]
· exact self.rankD_lt
· exact setParent_rankD_lt ‹_› self.rankD_lt
· refine setParentBump_rankD_lt (.inr yroot) (Nat.le_of_eq ‹_›) self.rankD_lt
(by simp [parentD_set]; rintro rfl; simp [*, parentD_eq]) fun {i} => ?_
simp [rankD_set]; split <;> simp [*]; rintro rfl; simp [rankD_eq, *]
· refine setParentBump_rankD_lt (.inr yroot) (Nat.le_of_eq ‹_›) self.rankD_lt (by
simp only [parentD_set, ite_eq_right_iff]
rintro rfl
simp [*, parentD_eq]) fun {i} => ?_
simp only [rankD_set, Fin.eta, Array.get_eq_getElem]
split
· simp_all
· simp_all only [Array.get_eq_getElem, Array.data_length, Nat.lt_irrefl, not_false_eq_true,
and_true, ite_false, ite_eq_right_iff]
rintro rfl
simp [rankD_eq, *]
· exact setParent_rankD_lt (Nat.lt_of_le_of_ne (Nat.not_lt.1 ‹_›) ‹_›) self.rankD_lt

@[inherit_doc link]
Expand Down