Skip to content

Commit

Permalink
fix for merge
Browse files Browse the repository at this point in the history
  • Loading branch information
kim-em committed Aug 20, 2024
1 parent dbd0028 commit 79479e7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 97 deletions.
30 changes: 0 additions & 30 deletions Batteries/Data/List/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -110,24 +110,6 @@ Unlike `bagInter` this does not preserve multiplicity: `[1, 1].inter [1]` is `[1

instance [BEq α] : Inter (List α) := ⟨List.inter⟩

/--
Split a list at an index.
```
splitAt 2 [a, b, c] = ([a, b], [c])
```
-/
def splitAt (n : Nat) (l : List α) : List α × List α := go l n [] where
/--
Auxiliary for `splitAt`:
`splitAt.go l xs n acc = (acc.reverse ++ take n xs, drop n xs)` if `n < xs.length`,
and `(l, [])` otherwise.
-/
go : List α → Nat → List α → List α × List α
| [], _, _ => (l, []) -- This branch ensures the pointer equality of the result with the input
-- without any runtime branching cost.
| x :: xs, n+1, acc => go xs n (x :: acc)
| xs, _, acc => (acc.reverse, xs)

/--
Split a list at an index. Ensures the left list always has the specified length
by right padding with the provided default element.
Expand Down Expand Up @@ -1171,15 +1153,3 @@ where
loop : List α → List α → List α
| [], r => reverseAux (a :: r) [] -- Note: `reverseAux` is tail recursive.
| b :: l, r => bif p b then reverseAux (a :: r) (b :: l) else loop l (b :: r)

/--
`O(|l| + |r|)`. Merge two lists using `s` as a switch.
-/
def merge (s : α → α → Bool) (l r : List α) : List α :=
loop l r []
where
/-- Inner loop for `List.merge`. Tail recursive. -/
loop : List α → List α → List α → List α
| [], r, t => reverseAux t r
| l, [], t => reverseAux t l
| a::l, b::r, t => bif s a b then loop l (b::r) (a::t) else loop (a::l) r (b::t)
72 changes: 5 additions & 67 deletions Batteries/Data/List/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -189,25 +189,6 @@ theorem get?_set_of_lt' (a : α) {m n} (l : List α) (h : m < length l) :

@[deprecated (since := "2024-05-06")] alias length_removeNth := length_eraseIdx

/-! ### splitAt -/

theorem splitAt_go (n : Nat) (l acc : List α) :
splitAt.go l xs n acc =
if n < xs.length then (acc.reverse ++ xs.take n, xs.drop n) else (l, []) := by
induction xs generalizing n acc with
| nil => simp [splitAt.go]
| cons x xs ih =>
cases n with
| zero => simp [splitAt.go]
| succ n =>
rw [splitAt.go, take_succ_cons, drop_succ_cons, ih n (x :: acc),
reverse_cons, append_assoc, singleton_append, length_cons]
simp only [Nat.succ_lt_succ_iff]

theorem splitAt_eq (n : Nat) (l : List α) : splitAt n l = (l.take n, l.drop n) := by
rw [splitAt, splitAt_go, reverse_nil, nil_append]
split <;> simp_all [take_of_length_le, drop_of_length_le]

/-! ### eraseP -/

@[simp] theorem extractP_eq_find?_eraseP
Expand Down Expand Up @@ -606,47 +587,15 @@ theorem insertP_loop (a : α) (l r : List α) :
induction l with simp [insertP, insertP.loop, cond]
| cons _ _ ih => split <;> simp [insertP_loop, ih]

theorem merge_loop_nil_left (s : α → α → Bool) (r t) :
merge.loop s [] r t = reverseAux t r := by
rw [merge.loop]

/-! ### merge -/

theorem merge_loop_nil_right (s : α → α → Bool) (l t) :
merge.loop s l [] t = reverseAux t l := by
cases l <;> rw [merge.loop]; intro; contradiction

theorem merge_loop (s : α → α → Bool) (l r t) :
merge.loop s l r t = reverseAux t (merge s l r) := by
rw [merge]; generalize hn : l.length + r.length = n
induction n using Nat.recAux generalizing l r t with
| zero =>
rw [eq_nil_of_length_eq_zero (Nat.eq_zero_of_add_eq_zero_left hn)]
rw [eq_nil_of_length_eq_zero (Nat.eq_zero_of_add_eq_zero_right hn)]
simp only [merge.loop, reverseAux]
| succ n ih =>
match l, r with
| [], r => simp only [merge_loop_nil_left]; rfl
| l, [] => simp only [merge_loop_nil_right]; rfl
| a::l, b::r =>
simp only [merge.loop, cond]
split
· have hn : l.length + (b :: r).length = n := by
apply Nat.add_right_cancel (m:=1)
rw [←hn]; simp only [length_cons, Nat.add_succ, Nat.succ_add]
rw [ih _ _ (a::t) hn, ih _ _ [] hn, ih _ _ [a] hn]; rfl
· have hn : (a::l).length + r.length = n := by
apply Nat.add_right_cancel (m:=1)
rw [←hn]; simp only [length_cons, Nat.add_succ, Nat.succ_add]
rw [ih _ _ (b::t) hn, ih _ _ [] hn, ih _ _ [b] hn]; rfl

@[simp] theorem merge_nil (s : α → α → Bool) (l) : merge s l [] = l := merge_loop_nil_right ..

@[simp] theorem nil_merge (s : α → α → Bool) (r) : merge s [] r = r := merge_loop_nil_left ..
@[simp] theorem merge_nil (s : α → α → Bool) (l) : merge s l [] = l := by simp [merge]

@[simp] theorem nil_merge (s : α → α → Bool) (r) : merge s [] r = r := by simp [merge]

theorem cons_merge_cons (s : α → α → Bool) (a b l r) :
merge s (a::l) (b::r) = if s a b then a :: merge s l (b::r) else b :: merge s (a::l) r := by
simp only [merge, merge.loop, cond]; split <;> (next hs => rw [hs, merge_loop]; rfl)
merge s (a::l) (b::r) = if s a b then a :: merge s l (b::r) else b :: merge s (a::l) r := by
simp only [merge]

@[simp] theorem cons_merge_cons_pos (s : α → α → Bool) (l r) (h : s a b) :
merge s (a::l) (b::r) = a :: merge s l (b::r) := by
Expand All @@ -667,17 +616,6 @@ theorem cons_merge_cons (s : α → α → Bool) (a b l r) :
· simp_arith [length_merge s l (b::r)]
· simp_arith [length_merge s (a::l) r]

@[simp]
theorem mem_merge {s : α → α → Bool} : x ∈ merge s l r ↔ x ∈ l ∨ x ∈ r := by
match l, r with
| l, [] => simp
| [], l => simp
| a::l, b::r =>
rw [cons_merge_cons]
split
· simp [mem_merge (l := l) (r := b::r), or_assoc]
· simp [mem_merge (l := a::l) (r := r), or_assoc, or_left_comm]

theorem mem_merge_left (s : α → α → Bool) (h : x ∈ l) : x ∈ merge s l r :=
mem_merge.2 <| .inl h

Expand Down

0 comments on commit 79479e7

Please sign in to comment.