Skip to content

Commit

Permalink
chore: restore splitAt pointer equality behaviour (#922)
Browse files Browse the repository at this point in the history
  • Loading branch information
kim-em authored Aug 17, 2024
1 parent 65f464e commit a975dea
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
8 changes: 7 additions & 1 deletion Batteries/Data/List/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,14 @@ splitAt 2 [a, b, c] = ([a, b], [c])
```
-/
def splitAt (n : Nat) (l : List α) : List α × List α := go l n [] where
/-- Auxiliary for `splitAt`: `splitAt.go xs n acc = (acc.reverse ++ take n xs, drop n xs)`. -/
/--
Auxiliary for `splitAt`:
`splitAt.go l xs n acc = (acc.reverse ++ take n xs, drop n xs)` if `n < xs.length`,
and `(l, [])` otherwise.
-/
go : List α → Nat → List α → List α × List α
| [], _, _ => (l, []) -- This branch ensures the pointer equality of the result with the input
-- without any runtime branching cost.
| x :: xs, n+1, acc => go xs n (x :: acc)
| xs, _, acc => (acc.reverse, xs)

Expand Down
9 changes: 6 additions & 3 deletions Batteries/Data/List/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,21 @@ theorem get?_set_of_lt' (a : α) {m n} (l : List α) (h : m < length l) :
/-! ### splitAt -/

theorem splitAt_go (n : Nat) (l acc : List α) :
splitAt.go l n acc = (acc.reverse ++ l.take n, l.drop n) := by
induction l generalizing n acc with
splitAt.go l xs n acc =
if n < xs.length then (acc.reverse ++ xs.take n, xs.drop n) else (l, []) := by
induction xs generalizing n acc with
| nil => simp [splitAt.go]
| cons x xs ih =>
cases n with
| zero => simp [splitAt.go]
| succ n =>
rw [splitAt.go, take_succ_cons, drop_succ_cons, ih n (x :: acc),
reverse_cons, append_assoc, singleton_append]
reverse_cons, append_assoc, singleton_append, length_cons]
simp only [Nat.succ_lt_succ_iff]

theorem splitAt_eq (n : Nat) (l : List α) : splitAt n l = (l.take n, l.drop n) := by
rw [splitAt, splitAt_go, reverse_nil, nil_append]
split <;> simp_all [take_of_length_le, drop_of_length_le]

/-! ### eraseP -/

Expand Down

0 comments on commit a975dea

Please sign in to comment.