Skip to content

Commit

Permalink
refactor: clean up public API around Array.eraseIdx (#3676)
Browse files Browse the repository at this point in the history
- Removes the public definitions `Array.eraseIdxAux` and
`Array.eraseIdxSzAux` which were implementation details.
- Motivation: `Array.eraseIdxAux` and `Array.eraseIdxSzAux` were clearly
not intended to remain public, but simply making them private would make
it inconvenient to unfold them when writing proofs in Std.
- Adds documentation comments to the public `Array.eraseIdx`-related
definitions which remain.
- Removes `Array.eraseIdx'` which was just `Array.feraseIdx` wrapped in
a subtype and adds `Array.size_feraseIdx` to prove the subtype property
as a standalone theorem.

Co-Authored-By: Daniel Windham <daniel@atlascomputing.org>
  • Loading branch information
timotree3 and tenedor authored Mar 17, 2024
1 parent 9ee10aa commit 8e96d7b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 23 deletions.
47 changes: 26 additions & 21 deletions src/Init/Data/Array/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -727,33 +727,38 @@ def takeWhile (p : α → Bool) (as : Array α) : Array α :=
termination_by as.size - i
go 0 #[]

def eraseIdxAux (i : Nat) (a : Array α) : Array α :=
if h : i < a.size then
let idx : Fin a.size := ⟨i, h⟩;
let idx1 : Fin a.size := ⟨i - 1, by exact Nat.lt_of_le_of_lt (Nat.pred_le i) h⟩;
let a' := a.swap idx idx1
eraseIdxAux (i+1) a'
/-- Remove the element at a given index from an array without bounds checks, using a `Fin` index.
This function takes worst case O(n) time because
it has to backshift all elements at positions greater than `i`.-/
def feraseIdx (a : Array α) (i : Fin a.size) : Array α :=
if h : i.val + 1 < a.size then
let a' := a.swap ⟨i.val + 1, h⟩ i
let i' : Fin a'.size := ⟨i.val + 1, by simp [a', h]⟩
have : a'.size - i' < a.size - i := by
simp [a', Nat.sub_succ_lt_self _ _ i.isLt]
a'.feraseIdx i'
else
a.pop
termination_by a.size - i
termination_by a.size - i.val

def feraseIdx (a : Array α) (i : Fin a.size) : Array α :=
eraseIdxAux (i.val + 1) a
derive_functional_induction feraseIdx

def eraseIdx (a : Array α) (i : Nat) : Array α :=
if i < a.size then eraseIdxAux (i+1) a else a
theorem size_feraseIdx (a : Array α) (i : Fin a.size) : (a.feraseIdx i).size = a.size - 1 := by
induction a, i using feraseIdx.induct with
| @case1 a i h a' _ _ ih =>
unfold feraseIdx
simp [h, a', ih]
| case2 a i h =>
unfold feraseIdx
simp [h]

def eraseIdxSzAux (a : Array α) (i : Nat) (r : Array α) (heq : r.size = a.size) : { r : Array α // r.size = a.size - 1 } :=
if h : i < r.size then
let idx : Fin r.size := ⟨i, h⟩;
let idx1 : Fin r.size := ⟨i - 1, by exact Nat.lt_of_le_of_lt (Nat.pred_le i) h⟩;
eraseIdxSzAux a (i+1) (r.swap idx idx1) ((size_swap r idx idx1).trans heq)
else
⟨r.pop, (size_pop r).trans (heq ▸ rfl)⟩
termination_by r.size - i
/-- Remove the element at a given index from an array, or do nothing if the index is out of bounds.
def eraseIdx' (a : Array α) (i : Fin a.size) : { r : Array α // r.size = a.size - 1 } :=
eraseIdxSzAux a (i.val + 1) a rfl
This function takes worst case O(n) time because
it has to backshift all elements at positions greater than `i`.-/
def eraseIdx (a : Array α) (i : Nat) : Array α :=
if h : i < a.size then a.feraseIdx ⟨i, h⟩ else a

def erase [BEq α] (as : Array α) (a : α) : Array α :=
match as.indexOf? a with
Expand Down
6 changes: 4 additions & 2 deletions src/Lean/Data/PersistentHashMap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,10 @@ partial def eraseAux [BEq α] : Node α β → USize → α → Node α β × Bo
| n@(Node.collision keys vals heq), _, k =>
match keys.indexOf? k with
| some idx =>
let ⟨keys', keq⟩ := keys.eraseIdx' idx
let ⟨vals', veq⟩ := vals.eraseIdx' (Eq.ndrec idx heq)
let keys' := keys.feraseIdx idx
have keq := keys.size_feraseIdx idx
let vals' := vals.feraseIdx (Eq.ndrec idx heq)
have veq := vals.size_feraseIdx (Eq.ndrec idx heq)
have : keys.size - 1 = vals.size - 1 := by rw [heq]
(Node.collision keys' vals' (keq.trans (this.trans veq.symm)), true)
| none => (n, false)
Expand Down

0 comments on commit 8e96d7b

Please sign in to comment.