Skip to content

Commit

Permalink
feat: change Array.set to take a Nat and a tactic provided bound (#5988)
Browse files Browse the repository at this point in the history
This PR changes the signature of `Array.set` to take a `Nat`, and a
tactic-provided bound, rather than a `Fin`.

Corresponding changes (but without the auto-param) for `Array.get` will
arrive shortly, after which I'll go more pervasively through the Array
API.
  • Loading branch information
kim-em authored Nov 11, 2024
1 parent 456e6d2 commit 258d372
Show file tree
Hide file tree
Showing 27 changed files with 156 additions and 144 deletions.
1 change: 1 addition & 0 deletions src/Init.lean
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ import Init.Omega
import Init.MacroTrace
import Init.Grind
import Init.While
import Init.Syntax
1 change: 1 addition & 0 deletions src/Init/Data/Array.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ import Init.Data.Array.TakeDrop
import Init.Data.Array.Bootstrap
import Init.Data.Array.GetLit
import Init.Data.Array.MapIdx
import Init.Data.Array.Set
12 changes: 7 additions & 5 deletions src/Init/Data/Array/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import Init.Data.Repr
import Init.Data.ToString.Basic
import Init.GetElem
import Init.Data.List.ToArray
import Init.Data.Array.Set
universe u v w

/-! ### Array literal syntax -/
Expand All @@ -29,7 +30,8 @@ namespace Array

/-! ### Preliminary theorems -/

@[simp] theorem size_set (a : Array α) (i : Fin a.size) (v : α) : (set a i v).size = a.size :=
@[simp] theorem size_set (a : Array α) (i : Nat) (v : α) (h : i < a.size) :
(set a i v h).size = a.size :=
List.length_set ..

@[simp] theorem size_push (a : Array α) (v : α) : (push a v).size = a.size + 1 :=
Expand Down Expand Up @@ -141,7 +143,7 @@ def uget (a : @& Array α) (i : USize) (h : i.toNat < a.size) : α :=
`fset` may be slightly slower than `uset`. -/
@[extern "lean_array_uset"]
def uset (a : Array α) (i : USize) (v : α) (h : i.toNat < a.size) : Array α :=
a.set i.toNat, h⟩ v
a.set i.toNat v h

@[extern "lean_array_pop"]
def pop (a : Array α) : Array α where
Expand All @@ -167,10 +169,10 @@ def swap (a : Array α) (i j : @& Fin a.size) : Array α :=
let v₁ := a.get i
let v₂ := a.get j
let a' := a.set i v₂
a'.set (size_set a i v₂ ▸ j) v₁
a'.set j v₁ (Nat.lt_of_lt_of_eq j.isLt (size_set a i v₂ _).symm)

@[simp] theorem size_swap (a : Array α) (i j : Fin a.size) : (a.swap i j).size = a.size := by
show ((a.set i (a.get j)).set (size_set a i _ ▸ j) (a.get i)).size = a.size
show ((a.set i (a.get j)).set j (a.get i) (Nat.lt_of_lt_of_eq j.isLt (size_set a i (a.get j) _).symm)).size = a.size
rw [size_set, size_set]

/--
Expand Down Expand Up @@ -278,7 +280,7 @@ unsafe def modifyMUnsafe [Monad m] (a : Array α) (i : Nat) (f : α → m α) :
-- of the element type, and that it is valid to store `box(0)` in any array.
let a' := a.set idx (unsafeCast ())
let v ← f v
pure <| a'.set (size_set a .. ▸ idx) v
pure <| a'.set idx v (Nat.lt_of_lt_of_eq h (size_set a ..).symm)
else
pure a

Expand Down
2 changes: 1 addition & 1 deletion src/Init/Data/Array/BasicAux.lean
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ where
if ptrEq a b then
go (i+1) as
else
go (i+1) (as.set ⟨i, h⟩ b)
go (i+1) (as.set i b h)
else
return as

Expand Down
90 changes: 38 additions & 52 deletions src/Init/Data/Array/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -483,25 +483,26 @@ theorem get!_eq_getD [Inhabited α] (a : Array α) : a.get! n = a.getD n default

/-! # set -/

@[simp] theorem getElem_set_eq (a : Array α) (i : Fin a.size) (v : α) {j : Nat}
(eq : i.val = j) (p : j < (a.set i v).size) :
@[simp] theorem getElem_set_eq (a : Array α) (i : Nat) (h : i < a.size) (v : α) {j : Nat}
(eq : i = j) (p : j < (a.set i v).size) :
(a.set i v)[j]'p = v := by
simp [set, getElem_eq_getElem_toList, ←eq]

@[simp] theorem getElem_set_ne (a : Array α) (i : Fin a.size) (v : α) {j : Nat} (pj : j < (a.set i v).size)
(h : i.val ≠ j) : (a.set i v)[j]'pj = a[j]'(size_set a i v ▸ pj) := by
@[simp] theorem getElem_set_ne (a : Array α) (i : Nat) (h' : i < a.size) (v : α) {j : Nat}
(pj : j < (a.set i v).size) (h : i ≠ j) :
(a.set i v)[j]'pj = a[j]'(size_set a i v _ ▸ pj) := by
simp only [set, getElem_eq_getElem_toList, List.getElem_set_ne h]

theorem getElem_set (a : Array α) (i : Fin a.size) (v : α) (j : Nat)
theorem getElem_set (a : Array α) (i : Nat) (h' : i < a.size) (v : α) (j : Nat)
(h : j < (a.set i v).size) :
(a.set i v)[j]'h = if i = j then v else a[j]'(size_set a i v ▸ h) := by
by_cases p : i.1 = j <;> simp [p]
(a.set i v)[j]'h = if i = j then v else a[j]'(size_set a i v _ ▸ h) := by
by_cases p : i = j <;> simp [p]

@[simp] theorem getElem?_set_eq (a : Array α) (i : Fin a.size) (v : α) :
(a.set i v)[i.1]? = v := by simp [getElem?_lt, i.2]
@[simp] theorem getElem?_set_eq (a : Array α) (i : Nat) (h : i < a.size) (v : α) :
(a.set i v)[i]? = v := by simp [getElem?_lt, h]

@[simp] theorem getElem?_set_ne (a : Array α) (i : Fin a.size) {j : Nat} (v : α)
(ne : i.val ≠ j) : (a.set i v)[j]? = a[j]? := by
@[simp] theorem getElem?_set_ne (a : Array α) (i : Nat) (h : i < a.size) {j : Nat} (v : α)
(ne : i ≠ j) : (a.set i v)[j]? = a[j]? := by
by_cases h : j < a.size <;> simp [getElem?_lt, getElem?_ge, Nat.ge_of_not_lt, ne, h]

/-! # setD -/
Expand All @@ -518,7 +519,7 @@ theorem getElem_set (a : Array α) (i : Fin a.size) (v : α) (j : Nat)
@[simp] theorem getElem_setD_eq (a : Array α) {i : Nat} (v : α) (h : _) :
(setD a i v)[i]'h = v := by
simp at h
simp only [setD, h, dite_true, getElem_set, ite_true]
simp only [setD, h, ↓reduceDIte, getElem_set_eq]

@[simp]
theorem getElem?_setD_eq (a : Array α) {i : Nat} (p : i < a.size) (v : α) : (a.setD i v)[i]? = some v := by
Expand Down Expand Up @@ -693,43 +694,43 @@ theorem getElem?_push {a : Array α} : (a.push x)[i]? = if i = a.size then some

@[deprecated getElem?_size (since := "2024-10-21")] abbrev get?_size := @getElem?_size

@[simp] theorem toList_set (a : Array α) (i v) : (a.set i v).toList = a.toList.set i.1 v := rfl
@[simp] theorem toList_set (a : Array α) (i v h) : (a.set i v).toList = a.toList.set i v := rfl

theorem get_set_eq (a : Array α) (i : Fin a.size) (v : α) :
(a.set i v)[i.1] = v := by
theorem get_set_eq (a : Array α) (i : Nat) (v : α) (h : i < a.size) :
(a.set i v h)[i]'(by simp [h]) = v := by
simp only [set, getElem_eq_getElem_toList, List.getElem_set_self]

theorem get?_set_eq (a : Array α) (i : Fin a.size) (v : α) :
(a.set i v)[i.1]? = v := by simp [getElem?_pos, i.2]
theorem get?_set_eq (a : Array α) (i : Nat) (v : α) (h : i < a.size) :
(a.set i v)[i]? = v := by simp [getElem?_pos, h]

@[simp] theorem get?_set_ne (a : Array α) (i : Fin a.size) {j : Nat} (v : α)
(h : i.1 ≠ j) : (a.set i v)[j]? = a[j]? := by
@[simp] theorem get?_set_ne (a : Array α) (i : Nat) (h' : i < a.size) {j : Nat} (v : α)
(h : i ≠ j) : (a.set i v)[j]? = a[j]? := by
by_cases j < a.size <;> simp [getElem?_pos, getElem?_neg, *]

theorem get?_set (a : Array α) (i : Fin a.size) (j : Nat) (v : α) :
(a.set i v)[j]? = if i.1 = j then some v else a[j]? := by
if h : i.1 = j then subst j; simp [*] else simp [*]
theorem get?_set (a : Array α) (i : Nat) (h : i < a.size) (j : Nat) (v : α) :
(a.set i v)[j]? = if i = j then some v else a[j]? := by
if h : i = j then subst j; simp [*] else simp [*]

theorem get_set (a : Array α) (i : Fin a.size) (j : Nat) (hj : j < a.size) (v : α) :
theorem get_set (a : Array α) (i : Nat) (hi : i < a.size) (j : Nat) (hj : j < a.size) (v : α) :
(a.set i v)[j]'(by simp [*]) = if i = j then v else a[j] := by
if h : i.1 = j then subst j; simp [*] else simp [*]
if h : i = j then subst j; simp [*] else simp [*]

@[simp] theorem get_set_ne (a : Array α) (i : Fin a.size) {j : Nat} (v : α) (hj : j < a.size)
(h : i.1 ≠ j) : (a.set i v)[j]'(by simp [*]) = a[j] := by
@[simp] theorem get_set_ne (a : Array α) (i : Nat) (hi : i < a.size) {j : Nat} (v : α) (hj : j < a.size)
(h : i ≠ j) : (a.set i v)[j]'(by simp [*]) = a[j] := by
simp only [set, getElem_eq_getElem_toList, List.getElem_set_ne h]

theorem getElem_setD (a : Array α) (i : Nat) (v : α) (h : i < (setD a i v).size) :
(setD a i v)[i] = v := by
simp at h
simp only [setD, h, dite_true, get_set, ite_true]
simp only [setD, h, ↓reduceDIte, getElem_set_eq]

theorem set_set (a : Array α) (i : Fin a.size) (v v' : α) :
(a.set i v).set ⟨i, by simp [i.2]⟩ v' = a.set i v' := by simp [set, List.set_set]
theorem set_set (a : Array α) (i : Nat) (h) (v v' : α) :
(a.set i v h).set i v' (by simp [h]) = a.set i v' := by simp [set, List.set_set]

private theorem fin_cast_val (e : n = n') (i : Fin n) : e ▸ i = ⟨i.1, e ▸ i.2⟩ := by cases e; rfl

theorem swap_def (a : Array α) (i j : Fin a.size) :
a.swap i j = (a.set i (a.get j)).set ⟨j.1, by simp [j.2]⟩ (a.get i) := by
a.swap i j = (a.set i (a.get j)).set j (a.get i) := by
simp [swap, fin_cast_val]

@[simp] theorem toList_swap (a : Array α) (i j : Fin a.size) :
Expand All @@ -747,7 +748,7 @@ theorem getElem?_swap (a : Array α) (i j : Fin a.size) (k : Nat) : (a.swap i j)

@[simp]
theorem swapAt!_def (a : Array α) (i : Nat) (v : α) (h : i < a.size) :
a.swapAt! i v = (a[i], a.set ⟨i, h⟩ v) := by simp [swapAt!, h]
a.swapAt! i v = (a[i], a.set i v) := by simp [swapAt!, h]

@[simp] theorem size_swapAt! (a : Array α) (i : Nat) (v : α) :
(a.swapAt! i v).2.size = a.size := by
Expand Down Expand Up @@ -1112,7 +1113,7 @@ theorem getElem_modify {as : Array α} {x i} (h : i < (as.modify x f).size) :
(as.modify x f)[i] = if x = i then f (as[i]'(by simpa using h)) else as[i]'(by simpa using h) := by
simp only [modify, modifyM, get_eq_getElem, Id.run, Id.pure_eq]
split
· simp only [Id.bind_eq, get_set _ _ _ (by simpa using h)]; split <;> simp [*]
· simp only [Id.bind_eq, get_set _ _ _ _ (by simpa using h)]; split <;> simp [*]
· rw [if_neg (mt (by rintro rfl; exact h) (by simp_all))]

@[simp] theorem toList_modify (as : Array α) (f : α → α) :
Expand Down Expand Up @@ -1541,30 +1542,15 @@ instance [DecidableEq α] (a : α) (as : Array α) : Decidable (a ∈ as) :=

open Fin

@[simp] theorem getElem_swap_right (a : Array α) {i j : Fin a.size} : (a.swap i j)[j.val] = a[i] :=
by simp only [swap, fin_cast_val, get_eq_getElem, getElem_set_eq, getElem_fin]
@[simp] theorem getElem_swap_right (a : Array α) {i j : Fin a.size} : (a.swap i j)[j.1] = a[i] := by
simp [swap_def, getElem_set]

@[simp] theorem getElem_swap_left (a : Array α) {i j : Fin a.size} : (a.swap i j)[i.val] = a[j] :=
if he : ((Array.size_set _ _ _).symm ▸ j).val = i.val then by
simp only [←he, fin_cast_val, getElem_swap_right, getElem_fin]
else by
apply Eq.trans
· apply Array.get_set_ne
· simp only [size_set, Fin.isLt]
· assumption
· simp [get_set_ne]
@[simp] theorem getElem_swap_left (a : Array α) {i j : Fin a.size} : (a.swap i j)[i.1] = a[j] := by
simp +contextual [swap_def, getElem_set]

@[simp] theorem getElem_swap_of_ne (a : Array α) {i j : Fin a.size} (hp : p < a.size)
(hi : p ≠ i) (hj : p ≠ j) : (a.swap i j)[p]'(a.size_swap .. |>.symm ▸ hp) = a[p] := by
apply Eq.trans
· have : ((a.size_set i (a.get j)).symm ▸ j).val = j.val := by simp only [fin_cast_val]
apply Array.get_set_ne
· simp only [this]
apply Ne.symm
· assumption
· apply Array.get_set_ne
· apply Ne.symm
· assumption
simp [swap_def, getElem_set, hi.symm, hj.symm]

theorem getElem_swap' (a : Array α) (i j : Fin a.size) (k : Nat) (hk : k < a.size) :
(a.swap i j)[k]'(by simp_all) = if k = i then a[j] else if k = j then a[i] else a[k] := by
Expand Down
39 changes: 39 additions & 0 deletions src/Init/Data/Array/Set.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura, Mario Carneiro
-/
prelude
import Init.Tactics


/--
Set an element in an array, using a proof that the index is in bounds.
(This proof can usually be omitted, and will be synthesized automatically.)
This will perform the update destructively provided that `a` has a reference
count of 1 when called.
-/
@[extern "lean_array_fset"]
def Array.set (a : Array α) (i : @& Nat) (v : α) (h : i < a.size := by get_elem_tactic) :
Array α where
toList := a.toList.set i v

/--
Set an element in an array, or do nothing if the index is out of bounds.
This will perform the update destructively provided that `a` has a reference
count of 1 when called.
-/
@[inline] def Array.setD (a : Array α) (i : Nat) (v : α) : Array α :=
dite (LT.lt i a.size) (fun h => a.set i v h) (fun _ => a)

/--
Set an element in an array, or panic if the index is out of bounds.
This will perform the update destructively provided that `a` has a reference
count of 1 when called.
-/
@[extern "lean_array_set"]
def Array.set! (a : Array α) (i : @& Nat) (v : α) : Array α :=
Array.setD a i v
2 changes: 1 addition & 1 deletion src/Init/Data/ByteArray/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def set! : ByteArray → (@& Nat) → UInt8 → ByteArray

@[extern "lean_byte_array_fset"]
def set : (a : ByteArray) → (@& Fin a.size) → UInt8 → ByteArray
| ⟨bs⟩, i, b => ⟨bs.set i b
| ⟨bs⟩, i, b => ⟨bs.set i.1 b i.2

@[extern "lean_byte_array_uset"]
def uset : (a : ByteArray) → (i : USize) → UInt8 → i.toNat < a.size → ByteArray
Expand Down
2 changes: 1 addition & 1 deletion src/Init/Data/FloatArray/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def uset : (a : FloatArray) → (i : USize) → Float → i.toNat < a.size → F

@[extern "lean_float_array_fset"]
def set : (ds : FloatArray) → (@& Fin ds.size) → Float → FloatArray
| ⟨ds⟩, i, d => ⟨ds.set i d
| ⟨ds⟩, i, d => ⟨ds.set i.1 d i.2

@[extern "lean_float_array_set"]
def set! : FloatArray → (@& Nat) → Float → FloatArray
Expand Down
3 changes: 2 additions & 1 deletion src/Init/Meta.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Additional goodies for writing macros
-/
prelude
import Init.MetaTypes
import Init.Syntax
import Init.Data.Array.GetLit
import Init.Data.Option.BasicAux

Expand Down Expand Up @@ -442,7 +443,7 @@ def unsetTrailing (stx : Syntax) : Syntax :=
if h : i < a.size then
let v := a[i]
match f v with
| some v => some <| a.set ⟨i, h⟩ v
| some v => some <| a.set i v h
| none => updateFirst a f (i+1)
else
none
Expand Down
61 changes: 7 additions & 54 deletions src/Init/Prelude.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2688,35 +2688,6 @@ def Array.mkArray7 {α : Type u} (a₁ a₂ a₃ a₄ a₅ a₆ a₇ : α) : Arr
def Array.mkArray8 {α : Type u} (a₁ a₂ a₃ a₄ a₅ a₆ a₇ a₈ : α) : Array α :=
((((((((mkEmpty 8).push a₁).push a₂).push a₃).push a₄).push a₅).push a₆).push a₇).push a₈

/--
Set an element in an array without bounds checks, using a `Fin` index.
This will perform the update destructively provided that `a` has a reference
count of 1 when called.
-/
@[extern "lean_array_fset"]
def Array.set (a : Array α) (i : @& Fin a.size) (v : α) : Array α where
toList := a.toList.set i.val v

/--
Set an element in an array, or do nothing if the index is out of bounds.
This will perform the update destructively provided that `a` has a reference
count of 1 when called.
-/
@[inline] def Array.setD (a : Array α) (i : Nat) (v : α) : Array α :=
dite (LT.lt i a.size) (fun h => a.set ⟨i, h⟩ v) (fun _ => a)

/--
Set an element in an array, or panic if the index is out of bounds.
This will perform the update destructively provided that `a` has a reference
count of 1 when called.
-/
@[extern "lean_array_set"]
def Array.set! (a : Array α) (i : @& Nat) (v : α) : Array α :=
Array.setD a i v

/-- Slower `Array.append` used in quotations. -/
protected def Array.appendCore {α : Type u} (as : Array α) (bs : Array α) : Array α :=
let rec loop (i : Nat) (j : Nat) (as : Array α) : Array α :=
Expand Down Expand Up @@ -3637,6 +3608,13 @@ def appendCore : Name → Name → Name

end Name

/-- The default maximum recursion depth. This is adjustable using the `maxRecDepth` option. -/
def defaultMaxRecDepth := 512

/-- The message to display on stack overflow. -/
def maxRecDepthErrorMessage : String :=
"maximum recursion depth has been reached\nuse `set_option maxRecDepth <num>` to increase limit\nuse `set_option diagnostics true` to get diagnostic information"

/-! # Syntax -/

/-- Source information of tokens. -/
Expand Down Expand Up @@ -3969,24 +3947,6 @@ def getId : Syntax → Name
| ident _ _ val _ => val
| _ => Name.anonymous

/--
Updates the argument list without changing the node kind.
Does nothing for non-`node` nodes.
-/
def setArgs (stx : Syntax) (args : Array Syntax) : Syntax :=
match stx with
| node info k _ => node info k args
| stx => stx

/--
Updates the `i`'th argument of the syntax.
Does nothing for non-`node` nodes, or if `i` is out of bounds of the node list.
-/
def setArg (stx : Syntax) (i : Nat) (arg : Syntax) : Syntax :=
match stx with
| node info k args => node info k (args.setD i arg)
| stx => stx

/-- Retrieve the left-most node or leaf's info in the Syntax tree. -/
partial def getHeadInfo? : Syntax → Option SourceInfo
| atom info _ => some info
Expand Down Expand Up @@ -4423,13 +4383,6 @@ main module and current macro scope.
bind getCurrMacroScope fun scp =>
pure (Lean.addMacroScope mainModule n scp)

/-- The default maximum recursion depth. This is adjustable using the `maxRecDepth` option. -/
def defaultMaxRecDepth := 512

/-- The message to display on stack overflow. -/
def maxRecDepthErrorMessage : String :=
"maximum recursion depth has been reached\nuse `set_option maxRecDepth <num>` to increase limit\nuse `set_option diagnostics true` to get diagnostic information"

namespace Syntax

/-- Is this syntax a null `node`? -/
Expand Down
Loading

0 comments on commit 258d372

Please sign in to comment.