Skip to content

Commit

Permalink
chore: Array cleanup (leanprover#5782)
Browse files Browse the repository at this point in the history
More cleanup of Array API. More to come.
  • Loading branch information
kim-em authored Oct 21, 2024
1 parent 4f18c29 commit 8151ac7
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 76 deletions.
118 changes: 87 additions & 31 deletions src/Init/Data/Array/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Init.Data.Nat.Lemmas
import Init.Data.List.Impl
import Init.Data.List.Monadic
import Init.Data.List.Range
import Init.Data.List.Nat.TakeDrop
import Init.Data.Array.Mem
import Init.TacticsExtra

Expand Down Expand Up @@ -43,21 +44,32 @@ theorem getElem?_eq_getElem?_toList (a : Array α) (i : Nat) : a[i]? = a.toList[
rw [getElem?_eq]
split <;> simp_all

theorem get_push_lt (a : Array α) (x : α) (i : Nat) (h : i < a.size) :
theorem getElem_push_lt (a : Array α) (x : α) (i : Nat) (h : i < a.size) :
have : i < (a.push x).size := by simp [*, Nat.lt_succ_of_le, Nat.le_of_lt]
(a.push x)[i] = a[i] := by
simp only [push, getElem_eq_getElem_toList, List.concat_eq_append, List.getElem_append_left, h]

@[simp] theorem get_push_eq (a : Array α) (x : α) : (a.push x)[a.size] = x := by
@[simp] theorem getElem_push_eq (a : Array α) (x : α) : (a.push x)[a.size] = x := by
simp only [push, getElem_eq_getElem_toList, List.concat_eq_append]
rw [List.getElem_append_right] <;> simp [getElem_eq_getElem_toList, Nat.zero_lt_one]

theorem get_push (a : Array α) (x : α) (i : Nat) (h : i < (a.push x).size) :
theorem getElem_push (a : Array α) (x : α) (i : Nat) (h : i < (a.push x).size) :
(a.push x)[i] = if h : i < a.size then a[i] else x := by
by_cases h' : i < a.size
· simp [get_push_lt, h']
· simp [getElem_push_lt, h']
· simp at h
simp [get_push_lt, Nat.le_antisymm (Nat.le_of_lt_succ h) (Nat.ge_of_not_lt h')]
simp [getElem_push_lt, Nat.le_antisymm (Nat.le_of_lt_succ h) (Nat.ge_of_not_lt h')]

@[deprecated getElem_push (since := "2024-10-21")] abbrev get_push := @getElem_push
@[deprecated getElem_push_lt (since := "2024-10-21")] abbrev get_push_lt := @getElem_push_lt
@[deprecated getElem_push_eq (since := "2024-10-21")] abbrev get_push_eq := @getElem_push_eq

@[simp] theorem get!_eq_getElem! [Inhabited α] (a : Array α) (i : Nat) : a.get! i = a[i]! := by
simp [getElem!_def, get!, getD]
split <;> rename_i h
· simp [getElem?_eq_getElem h]
rfl
· simp [getElem?_eq_none_iff.2 (by simpa using h)]

end Array

Expand All @@ -81,6 +93,9 @@ We prefer to pull `List.toArray` outwards.

@[simp] theorem getElem?_toArray {a : List α} {i : Nat} : a.toArray[i]? = a[i]? := rfl

@[simp] theorem getElem!_toArray [Inhabited α] {a : List α} {i : Nat} :
a.toArray[i]! = a[i]! := rfl

@[simp] theorem push_toArray (l : List α) (a : α) : l.toArray.push a = (l ++ [a]).toArray := by
apply ext'
simp
Expand All @@ -90,6 +105,14 @@ We prefer to pull `List.toArray` outwards.
funext a
simp

@[simp] theorem isEmpty_toArray (l : List α) : l.toArray.isEmpty = l.isEmpty := by
cases l <;> simp

@[simp] theorem toArray_singleton (a : α) : (List.singleton a).toArray = singleton a := rfl

@[simp] theorem back_toArray [Inhabited α] (l : List α) : l.toArray.back = l.getLast! := by
simp only [back, size_toArray, Array.get!_eq_getElem!, getElem!_toArray, getLast!_eq_getElem!]

theorem foldrM_toArray [Monad m] (f : α → β → m β) (init : β) (l : List α) :
l.toArray.foldrM f init = l.foldrM f init := by
rw [foldrM_eq_reverse_foldlM_toList]
Expand Down Expand Up @@ -248,7 +271,7 @@ theorem size_uset (a : Array α) (v i h) : (uset a i v h).size = a.size := by si
@[simp] theorem get_eq_getElem (a : Array α) (i : Fin _) : a.get i = a[i.1] := rfl

theorem getElem?_lt
(a : Array α) {i : Nat} (h : i < a.size) : a[i]? = some (a[i]) := dif_pos h
(a : Array α) {i : Nat} (h : i < a.size) : a[i]? = some a[i] := dif_pos h

theorem getElem?_ge
(a : Array α) {i : Nat} (h : i ≥ a.size) : a[i]? = none := dif_neg (Nat.not_lt_of_le h)
Expand All @@ -271,8 +294,10 @@ theorem getD_get? (a : Array α) (i : Nat) (d : α) :

theorem get!_eq_getD [Inhabited α] (a : Array α) : a.get! n = a.getD n default := rfl

@[simp] theorem get!_eq_getElem? [Inhabited α] (a : Array α) (i : Nat) : a.get! i = (a.get? i).getD default := by
by_cases p : i < a.size <;> simp [getD_get?, get!_eq_getD, p]
@[simp] theorem get!_eq_getElem? [Inhabited α] (a : Array α) (i : Nat) :
a.get! i = (a.get? i).getD default := by
by_cases p : i < a.size <;>
simp only [get!_eq_getD, getD_eq_get?, getD_get?, p, get?_eq_getElem?]

/-! # set -/

Expand Down Expand Up @@ -352,8 +377,8 @@ theorem getElem_ofFn_go (f : Fin n → α) (i) {acc k}
simp only [dif_pos hin]
rw [getElem_ofFn_go f (i+1) _ hin (by simp [*]) (fun j hj => ?hacc)]
cases (Nat.lt_or_eq_of_le <| Nat.le_of_lt_succ (by simpa using hj)) with
| inl hj => simp [get_push, hj, hacc j hj]
| inr hj => simp [get_push, *]
| inl hj => simp [getElem_push, hj, hacc j hj]
| inr hj => simp [getElem_push, *]
else
simp [hin, hacc k (Nat.lt_of_lt_of_le hki (Nat.le_of_not_lt (hi ▸ hin)))]
termination_by n - i
Expand Down Expand Up @@ -421,7 +446,7 @@ theorem lt_of_getElem {x : α} {a : Array α} {idx : Nat} {hidx : idx < a.size}
idx < a.size :=
hidx

theorem getElem_mem {l : Array α} {i : Nat} (h : i < l.size) : l[i] ∈ l := by
@[simp] theorem getElem_mem {l : Array α} {i : Nat} (h : i < l.size) : l[i] ∈ l := by
erw [Array.mem_def, getElem_eq_getElem_toList]
apply List.get_mem

Expand All @@ -430,55 +455,65 @@ theorem getElem_fin_eq_getElem_toList (a : Array α) (i : Fin a.size) : a[i] = a
@[simp] theorem ugetElem_eq_getElem (a : Array α) {i : USize} (h : i.toNat < a.size) :
a[i] = a[i.toNat] := rfl

theorem get?_len_le (a : Array α) (i : Nat) (h : a.size ≤ i) : a[i]? = none := by
theorem getElem?_size_le (a : Array α) (i : Nat) (h : a.size ≤ i) : a[i]? = none := by
simp [getElem?_neg, h]

@[deprecated getElem?_size_le (since := "2024-10-21")] abbrev get?_len_le := @getElem?_size_le

theorem getElem_mem_toList (a : Array α) (h : i < a.size) : a[i] ∈ a.toList := by
simp only [getElem_eq_getElem_toList, List.getElem_mem]

theorem get?_eq_get?_toList (a : Array α) (i : Nat) : a.get? i = a.toList.get? i := by
simp [getElem?_eq_getElem?_toList]

theorem get!_eq_get? [Inhabited α] (a : Array α) : a.get! n = (a.get? n).getD default := by
simp [get!_eq_getD]
simp only [get!_eq_getElem?, get?_eq_getElem?]

theorem getElem?_eq_some_iff {as : Array α} : as[n]? = some a ↔ ∃ h : n < as.size, as[n] = a := by
cases as
simp [List.getElem?_eq_some_iff]

@[simp] theorem back_eq_back? [Inhabited α] (a : Array α) : a.back = a.back?.getD default := by
simp [back, back?]
simp only [back, get!_eq_getElem?, get?_eq_getElem?, back?]

@[simp] theorem back?_push (a : Array α) : (a.push x).back? = some x := by
simp [back?, getElem?_eq_getElem?_toList]

theorem back_push [Inhabited α] (a : Array α) : (a.push x).back = x := by simp

theorem get?_push_lt (a : Array α) (x : α) (i : Nat) (h : i < a.size) :
theorem getElem?_push_lt (a : Array α) (x : α) (i : Nat) (h : i < a.size) :
(a.push x)[i]? = some a[i] := by
rw [getElem?_pos, get_push_lt]
rw [getElem?_pos, getElem_push_lt]

@[deprecated getElem?_push_lt (since := "2024-10-21")] abbrev get?_push_lt := @getElem?_push_lt

theorem getElem?_push_eq (a : Array α) (x : α) : (a.push x)[a.size]? = some x := by
rw [getElem?_pos, getElem_push_eq]

theorem get?_push_eq (a : Array α) (x : α) : (a.push x)[a.size]? = some x := by
rw [getElem?_pos, get_push_eq]
@[deprecated getElem?_push_eq (since := "2024-10-21")] abbrev get?_push_eq := @getElem?_push_eq

theorem get?_push {a : Array α} : (a.push x)[i]? = if i = a.size then some x else a[i]? := by
theorem getElem?_push {a : Array α} : (a.push x)[i]? = if i = a.size then some x else a[i]? := by
match Nat.lt_trichotomy i a.size with
| Or.inl g =>
have h1 : i < a.size + 1 := by omega
have h2 : i ≠ a.size := by omega
simp [getElem?_def, size_push, g, h1, h2, get_push_lt]
simp [getElem?_def, size_push, g, h1, h2, getElem_push_lt]
| Or.inr (Or.inl heq) =>
simp [heq, getElem?_pos, get_push_eq]
simp [heq, getElem?_pos, getElem_push_eq]
| Or.inr (Or.inr g) =>
simp only [getElem?_def, size_push]
have h1 : ¬ (i < a.size) := by omega
have h2 : ¬ (i < a.size + 1) := by omega
have h3 : i ≠ a.size := by omega
simp [h1, h2, h3]

@[simp] theorem get?_size {a : Array α} : a[a.size]? = none := by
@[deprecated getElem?_push (since := "2024-10-21")] abbrev get?_push := @getElem?_push

@[simp] theorem getElem?_size {a : Array α} : a[a.size]? = none := by
simp only [getElem?_def, Nat.lt_irrefl, dite_false]

@[deprecated getElem?_size (since := "2024-10-21")] abbrev get?_size := @getElem?_size

@[simp] theorem toList_set (a : Array α) (i v) : (a.set i v).toList = a.toList.set i.1 v := rfl

theorem get_set_eq (a : Array α) (i : Fin a.size) (v : α) :
Expand Down Expand Up @@ -528,6 +563,9 @@ theorem getElem?_swap (a : Array α) (i j : Fin a.size) (k : Nat) : (a.swap i j)
@[simp] theorem swapAt_def (a : Array α) (i : Fin a.size) (v : α) :
a.swapAt i v = (a[i.1], a.set i v) := rfl

@[simp] theorem size_swapAt (a : Array α) (i : Fin a.size) (v : α) :
(a.swapAt i v).2.size = a.size := by simp [swapAt_def]

@[simp]
theorem swapAt!_def (a : Array α) (i : Nat) (v : α) (h : i < a.size) :
a.swapAt! i v = (a[i], a.set ⟨i, h⟩ v) := by simp [swapAt!, h]
Expand Down Expand Up @@ -560,11 +598,11 @@ theorem eq_push_pop_back_of_size_ne_zero [Inhabited α] {as : Array α} (h : as.
· simp [Nat.sub_add_cancel (Nat.zero_lt_of_ne_zero h)]
· intros i h h'
if hlt : i < as.pop.size then
rw [get_push_lt (h:=hlt), getElem_pop]
rw [getElem_push_lt (h:=hlt), getElem_pop]
else
have heq : i = as.pop.size :=
Nat.le_antisymm (size_pop .. ▸ Nat.le_pred_of_lt h) (Nat.le_of_not_gt hlt)
cases heq; rw [get_push_eq, back, ←size_pop, get!_eq_getD, getD, dif_pos h]; rfl
cases heq; rw [getElem_push_eq, back, ←size_pop, get!_eq_getD, getD, dif_pos h]; rfl

theorem eq_push_of_size_ne_zero {as : Array α} (h : as.size ≠ 0) :
∃ (bs : Array α) (c : α), as = bs.push c :=
Expand Down Expand Up @@ -773,9 +811,9 @@ theorem map_induction (as : Array α) (f : α → β) (motive : Nat → Prop) (h
· intro j h
simp at h ⊢
by_cases h' : j < size b
· rw [get_push]
· rw [getElem_push]
simp_all
· rw [get_push, dif_neg h']
· rw [getElem_push, dif_neg h']
simp only [show j = i by omega]
exact (hs _ m).1

Expand All @@ -800,7 +838,7 @@ theorem map_spec (as : Array α) (f : α → β) (p : Fin as.size → β → Pro
(as.push x).map f = (as.map f).push (f x) := by
ext
· simp
· simp only [getElem_map, get_push, size_map]
· simp only [getElem_map, getElem_push, size_map]
split <;> rfl

@[simp] theorem map_pop {f : α → β} {as : Array α} :
Expand Down Expand Up @@ -831,6 +869,11 @@ theorem getElem_modify_of_ne {as : Array α} {i : Nat} (h : i ≠ j)
(as.modify i f)[j] = as[j]'(by simpa using hj) := by
simp [getElem_modify hj, h]

theorem getElem?_modify {as : Array α} {i : Nat} {f : α → α} {j : Nat} :
(as.modify i f)[j]? = if i = j then as[j]?.map f else as[j]? := by
simp only [getElem?_def, size_modify, getElem_modify, Option.map_dif]
split <;> split <;> rfl

/-! ### filter -/

@[simp] theorem toList_filter (p : α → Bool) (l : Array α) :
Expand Down Expand Up @@ -892,7 +935,7 @@ theorem filterMap_congr {as bs : Array α} (h : as = bs)

theorem size_empty : (#[] : Array α).size = 0 := rfl

theorem toList_empty : (#[] : Array α).toList = [] := rfl
@[simp] theorem toList_empty : (#[] : Array α).toList = [] := rfl

/-! ### append -/

Expand Down Expand Up @@ -1050,7 +1093,7 @@ theorem getElem_extract_loop_ge (as bs : Array α) (size start : Nat) (hge : i
have h₂ : bs.size < (extract.loop as size (start+1) (bs.push as[start])).size := by
rw [size_extract_loop]; apply Nat.lt_of_lt_of_le h₁; exact Nat.le_add_right ..
have h : (extract.loop as size (start + 1) (push bs as[start]))[bs.size] = as[start] := by
rw [getElem_extract_loop_lt as (bs.push as[start]) size (start+1) h₁ h₂, get_push_eq]
rw [getElem_extract_loop_lt as (bs.push as[start]) size (start+1) h₁ h₂, getElem_push_eq]
rw [h]; congr; rw [Nat.add_sub_cancel]
else
have hge : bs.size + 1 ≤ i := Nat.lt_of_le_of_ne hge hi
Expand All @@ -1077,6 +1120,14 @@ theorem getElem?_extract {as : Array α} {start stop : Nat} :
· omega
· rfl

@[simp] theorem toList_extract (as : Array α) (start stop : Nat) :
(as.extract start stop).toList = (as.toList.drop start).take (stop - start) := by
apply List.ext_getElem
· simp only [length_toList, size_extract, List.length_take, List.length_drop]
omega
· intros n h₁ h₂
simp

@[simp] theorem extract_all (as : Array α) : as.extract 0 as.size = as := by
apply ext
· rw [size_extract, Nat.min_self, Nat.sub_zero]
Expand Down Expand Up @@ -1246,7 +1297,7 @@ open Fin
· assumption

theorem getElem_swap' (a : Array α) (i j : Fin a.size) (k : Nat) (hk : k < a.size) :
(a.swap i j)[k]'(by simp_all) = if k = i then a[j] else if k = j then a[i] else a[k] := by
(a.swap i j)[k]'(by simp_all) = if k = i then a[j] else if k = j then a[i] else a[k] := by
split
· simp_all only [getElem_swap_left]
· split <;> simp_all
Expand All @@ -1256,7 +1307,7 @@ theorem getElem_swap (a : Array α) (i j : Fin a.size) (k : Nat) (hk : k < (a.sw
apply getElem_swap'

@[simp] theorem swap_swap (a : Array α) {i j : Fin a.size} :
(a.swap i j).swap ⟨i.1, (a.size_swap ..).symm ▸i.2⟩ ⟨j.1, (a.size_swap ..).symm ▸j.2⟩ = a := by
(a.swap i j).swap ⟨i.1, (a.size_swap ..).symm ▸ i.2⟩ ⟨j.1, (a.size_swap ..).symm ▸ j.2⟩ = a := by
apply ext
· simp only [size_swap]
· intros
Expand Down Expand Up @@ -1419,6 +1470,11 @@ theorem filterMap_toArray (f : α → Option β) (l : List α) :
apply ext'
simp

@[simp] theorem toArray_extract (l : List α) (start stop : Nat) :
l.toArray.extract start stop = ((l.drop start).take (stop - start)).toArray := by
apply ext'
simp

end List

/-! ### Deprecations -/
Expand Down
2 changes: 1 addition & 1 deletion src/Init/Data/Array/MapIdx.lean
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ theorem mapFinIdx_induction (as : Array α) (f : Fin as.size → α → β)
| succ i ih =>
apply @ih (bs.push (f ⟨j, by omega⟩ as[j])) (j + 1) (by omega) (by simp; omega)
· intro i i_lt h'
rw [get_push]
rw [getElem_push]
split
· apply h₂
· simp only [size_push] at h'
Expand Down
18 changes: 15 additions & 3 deletions src/Init/Data/List/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1047,9 +1047,6 @@ theorem get_cons_length (x : α) (xs : List α) (n : Nat) (h : n = xs.length) :

@[simp] theorem getLast?_singleton (a : α) : getLast? [a] = a := rfl

theorem getLast!_of_getLast? [Inhabited α] : ∀ {l : List α}, getLast? l = some a → getLast! l = a
| _ :: _, rfl => rfl

theorem getLast?_eq_getLast : ∀ l h, @getLast? α l = some (getLast l h)
| [], h => nomatch h rfl
| _ :: _, _ => rfl
Expand Down Expand Up @@ -1083,6 +1080,21 @@ theorem getLast?_concat (l : List α) : getLast? (l ++ [a]) = some a := by
theorem getLastD_concat (a b l) : @getLastD α (l ++ [b]) a = b := by
rw [getLastD_eq_getLast?, getLast?_concat]; rfl

/-! ### getLast! -/

@[simp] theorem getLast!_nil [Inhabited α] : ([] : List α).getLast! = default := rfl

theorem getLast!_of_getLast? [Inhabited α] : ∀ {l : List α}, getLast? l = some a → getLast! l = a
| _ :: _, rfl => rfl

theorem getLast!_eq_getElem! [Inhabited α] {l : List α} : l.getLast! = l[l.length - 1]! := by
cases l with
| nil => simp
| cons _ _ =>
apply getLast!_of_getLast?
rw [getElem!_pos, getElem_cons_length (h := by simp)]
rfl

/-! ## Head and tail -/

/-! ### head -/
Expand Down
10 changes: 5 additions & 5 deletions src/Std/Sat/AIG/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ theorem Cache.get?_property {decls : Array (Decl α)} {idx : Nat} (c : Cache α
induction hcache generalizing decl with
| empty => simp at hfound
| push_id wf ih =>
rw [Array.get_push]
rw [Array.getElem_push]
split
· apply ih
simp [hfound]
Expand All @@ -140,7 +140,7 @@ theorem Cache.get?_property {decls : Array (Decl α)} {idx : Nat} (c : Cache α
assumption
| push_cache wf ih =>
rename_i decl'
rw [Array.get_push]
rw [Array.getElem_push]
split
· simp only [HashMap.getElem?_insert] at hfound
match heq : decl == decl' with
Expand Down Expand Up @@ -464,7 +464,7 @@ def mkGate (aig : AIG α) (input : GateInput aig) : Entrypoint α :=
let cache := aig.cache.noUpdate
have invariant := by
intro i lhs' rhs' linv' rinv' h1 h2
simp only [Array.get_push] at h2
simp only [Array.getElem_push] at h2
split at h2
· apply aig.invariant <;> assumption
· injections
Expand All @@ -483,7 +483,7 @@ def mkAtom (aig : AIG α) (n : α) : Entrypoint α :=
let cache := aig.cache.noUpdate
have invariant := by
intro i lhs rhs linv rinv h1 h2
simp only [Array.get_push] at h2
simp only [Array.getElem_push] at h2
split at h2
· apply aig.invariant <;> assumption
· contradiction
Expand All @@ -499,7 +499,7 @@ def mkConst (aig : AIG α) (val : Bool) : Entrypoint α :=
let cache := aig.cache.noUpdate
have invariant := by
intro i lhs rhs linv rinv h1 h2
simp only [Array.get_push] at h2
simp only [Array.getElem_push] at h2
split at h2
· apply aig.invariant <;> assumption
· contradiction
Expand Down
Loading

0 comments on commit 8151ac7

Please sign in to comment.