Skip to content

Commit

Permalink
refactor: structurally recursive List.ofFn (#784)
Browse files Browse the repository at this point in the history
* refactor: structurally recursive List.ofFn

This used to be defined via `Array.ofFn`
but `Array.ofFn.go` is defined by well-founded recursion (slow to reduce)
and used `Array.push` (quadratic complexity on lists). Since mathlib relies on
reducing `List.ofFn`, use a structurally recursive definition here.

* Update Batteries/Data/List/Basic.lean

Co-authored-by: Kim Morrison <kim@tqft.net>

* refactor: add ofFnTR

---------

Co-authored-by: Kim Morrison <kim@tqft.net>
Co-authored-by: Mario Carneiro <di.gama@gmail.com>
  • Loading branch information
3 people authored May 10, 2024
1 parent 24f2da1 commit 231202b
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion Batteries/Data/List/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,28 @@ def sigmaTR {σ : α → Type _} (l₁ : List α) (l₂ : ∀ a, List (σ a)) :
ofFn f = [f 0, f 1, ... , f (n - 1)]
```
-/
def ofFn {n} (f : Fin n → α) : List α := (Array.ofFn f).data
def ofFn {n} (f : Fin n → α) : List α := go n 0 rfl where
/-- Auxiliary for `List.ofFn`. `ofFn.go f i j _ = [f j, ..., f (n - 1)]`. -/
-- This used to be defined via `Array.ofFn` but mathlib relies on reducing it,
-- so we use a structurally recursive definition here.
go : (i j : Nat) → (h : i + j = n) → List α
| 0, _, _ => []
| i+1, j, h => f ⟨j, by omega⟩ :: go i (j+1) (Nat.add_right_comm .. ▸ h :)

/-- Tail-recursive version of `ofFn`. -/
@[inline] def ofFnTR {n} (f : Fin n → α) : List α := go n (Nat.le_refl _) [] where
/-- Auxiliary for `List.ofFnTR`. `ofFnTR.go f i _ acc = f 0 :: ... :: f (i - 1) :: acc`. -/
go : (i : Nat) → (h : i ≤ n) → List α → List α
| 0, _, acc => acc
| i+1, h, acc => go i (Nat.le_of_lt h) (f ⟨i, h⟩ :: acc)

@[csimp] theorem ofFn_eq_ofFnTR : @ofFn = @ofFnTR := by
funext α n f; simp [ofFnTR]
let rec go (i j h h') : ofFnTR.go f j h' (ofFn.go f i j h) = ofFn f := by
unfold ofFnTR.go; split
· subst h; rfl
· next l j h' => exact go (i+1) j ((Nat.succ_add ..).trans h) (Nat.le_of_lt h')
exact (go 0 n (Nat.zero_add _) (Nat.le_refl _)).symm

/-- `ofFnNthVal f i` returns `some (f i)` if `i < n` and `none` otherwise. -/
def ofFnNthVal {n} (f : Fin n → α) (i : Nat) : Option α :=
Expand Down

0 comments on commit 231202b

Please sign in to comment.