diff --git a/src/Init/Data/List/Sort/Basic.lean b/src/Init/Data/List/Sort/Basic.lean index 05ba40c1d073..7fbf4cc7e102 100644 --- a/src/Init/Data/List/Sort/Basic.lean +++ b/src/Init/Data/List/Sort/Basic.lean @@ -22,17 +22,18 @@ namespace List This version is not tail-recursive, but it is replaced at runtime by `mergeTR` using a `@[csimp]` lemma. -/ -def merge (le : α → α → Bool) : List α → List α → List α +def merge (xs ys : List α) (le : α → α → Bool := by exact fun a b => a ≤ b) : List α := + match xs, ys with | [], ys => ys | xs, [] => xs | x :: xs, y :: ys => if le x y then - x :: merge le xs (y :: ys) + x :: merge xs (y :: ys) le else - y :: merge le (x :: xs) ys + y :: merge (x :: xs) ys le -@[simp] theorem nil_merge (ys : List α) : merge le [] ys = ys := by simp [merge] -@[simp] theorem merge_right (xs : List α) : merge le xs [] = xs := by +@[simp] theorem nil_merge (ys : List α) : merge [] ys le = ys := by simp [merge] +@[simp] theorem merge_right (xs : List α) : merge xs [] le = xs := by induction xs with | nil => simp [merge] | cons x xs ih => simp [merge, ih] @@ -45,6 +46,7 @@ def splitInTwo (l : { l : List α // l.length = n }) : let r := splitAt ((n+1)/2) l.1 (⟨r.1, by simp [r, splitAt_eq, l.2]; omega⟩, ⟨r.2, by simp [r, splitAt_eq, l.2]; omega⟩) +set_option linter.unusedVariables false in /-- Simplified implementation of stable merge sort. @@ -56,16 +58,15 @@ It is replaced at runtime in the compiler by `mergeSortTR₂` using a `@[csimp]` Because we want the sort to be stable, it is essential that we split the list in two contiguous sublists. -/ -def mergeSort (le : α → α → Bool) : List α → List α - | [] => [] - | [a] => [a] - | a :: b :: xs => +def mergeSort : ∀ (xs : List α) (le : α → α → Bool := by exact fun a b => a ≤ b), List α + | [], _ => [] + | [a], _ => [a] + | a :: b :: xs, le => let lr := splitInTwo ⟨a :: b :: xs, rfl⟩ have := by simpa using lr.2.2 have := by simpa using lr.1.2 - merge le (mergeSort le lr.1) (mergeSort le lr.2) -termination_by l => l.length - + merge (mergeSort lr.1 le) (mergeSort lr.2 le) le +termination_by xs => xs.length /-- Given an ordering relation `le : α → α → Bool`, diff --git a/src/Init/Data/List/Sort/Impl.lean b/src/Init/Data/List/Sort/Impl.lean index 7f49a622b787..95bdd3f7ba53 100644 --- a/src/Init/Data/List/Sort/Impl.lean +++ b/src/Init/Data/List/Sort/Impl.lean @@ -38,7 +38,7 @@ namespace List.MergeSort.Internal /-- `O(min |l| |r|)`. Merge two lists using `le` as a switch. -/ -def mergeTR (le : α → α → Bool) (l₁ l₂ : List α) : List α := +def mergeTR (l₁ l₂ : List α) (le : α → α → Bool) : List α := go l₁ l₂ [] where go : List α → List α → List α → List α | [], l₂, acc => reverseAux acc l₂ @@ -49,7 +49,7 @@ where go : List α → List α → List α → List α else go (x :: xs) ys (y :: acc) -theorem mergeTR_go_eq : mergeTR.go le l₁ l₂ acc = acc.reverse ++ merge le l₁ l₂ := by +theorem mergeTR_go_eq : mergeTR.go le l₁ l₂ acc = acc.reverse ++ merge l₁ l₂ le := by induction l₁ generalizing l₂ acc with | nil => simp [mergeTR.go, merge, reverseAux_eq] | cons x l₁ ih₁ => @@ -97,14 +97,14 @@ This version uses the tail-recurive `mergeTR` function as a subroutine. This is not the final version we use at runtime, as `mergeSortTR₂` is faster. This definition is useful as an intermediate step in proving the `@[csimp]` lemma for `mergeSortTR₂`. -/ -def mergeSortTR (le : α → α → Bool) (l : List α) : List α := +def mergeSortTR (l : List α) (le : α → α → Bool := by exact fun a b => a ≤ b) : List α := run ⟨l, rfl⟩ where run : {n : Nat} → { l : List α // l.length = n } → List α | 0, ⟨[], _⟩ => [] | 1, ⟨[a], _⟩ => [a] | n+2, xs => let (l, r) := splitInTwo xs - mergeTR le (run l) (run r) + mergeTR (run l) (run r) le /-- Split a list in two equal parts, reversing the first part. @@ -130,7 +130,7 @@ Faster version of `mergeSortTR`, which avoids unnecessary list reversals. -- Per the benchmark in `tests/bench/mergeSort/` -- (which averages over 4 use cases: already sorted lists, reverse sorted lists, almost sorted lists, and random lists), -- for lists of length 10^6, `mergeSortTR₂` is about 20% faster than `mergeSortTR`. -def mergeSortTR₂ (le : α → α → Bool) (l : List α) : List α := +def mergeSortTR₂ (l : List α) (le : α → α → Bool := by exact fun a b => a ≤ b) : List α := run ⟨l, rfl⟩ where run : {n : Nat} → { l : List α // l.length = n } → List α @@ -138,13 +138,13 @@ where | 1, ⟨[a], _⟩ => [a] | n+2, xs => let (l, r) := splitRevInTwo xs - mergeTR le (run' l) (run r) + mergeTR (run' l) (run r) le run' : {n : Nat} → { l : List α // l.length = n } → List α | 0, ⟨[], _⟩ => [] | 1, ⟨[a], _⟩ => [a] | n+2, xs => let (l, r) := splitRevInTwo' xs - mergeTR le (run' r) (run l) + mergeTR (run' r) (run l) le theorem splitRevInTwo'_fst (l : { l : List α // l.length = n }) : (splitRevInTwo' l).1 = ⟨(splitInTwo ⟨l.1.reverse, by simpa using l.2⟩).2.1, by have := l.2; simp; omega⟩ := by @@ -166,7 +166,7 @@ theorem splitRevInTwo_snd (l : { l : List α // l.length = n }) : (splitRevInTwo l).2 = ⟨(splitInTwo l).2.1, by have := l.2; simp; omega⟩ := by simp only [splitRevInTwo, splitRevAt_eq, reverse_take, splitInTwo_snd] -theorem mergeSortTR_run_eq_mergeSort : {n : Nat} → (l : { l : List α // l.length = n }) → mergeSortTR.run le l = mergeSort le l.1 +theorem mergeSortTR_run_eq_mergeSort : {n : Nat} → (l : { l : List α // l.length = n }) → mergeSortTR.run le l = mergeSort l.1 le | 0, ⟨[], _⟩ | 1, ⟨[a], _⟩ => by simp [mergeSortTR.run, mergeSort] | n+2, ⟨a :: b :: l, h⟩ => by @@ -183,7 +183,7 @@ theorem mergeSort_eq_mergeSortTR : @mergeSort = @mergeSortTR := by -- This mutual block is unfortunately quite slow to elaborate. set_option maxHeartbeats 400000 in mutual -theorem mergeSortTR₂_run_eq_mergeSort : {n : Nat} → (l : { l : List α // l.length = n }) → mergeSortTR₂.run le l = mergeSort le l.1 +theorem mergeSortTR₂_run_eq_mergeSort : {n : Nat} → (l : { l : List α // l.length = n }) → mergeSortTR₂.run le l = mergeSort l.1 le | 0, ⟨[], _⟩ | 1, ⟨[a], _⟩ => by simp [mergeSortTR₂.run, mergeSort] | n+2, ⟨a :: b :: l, h⟩ => by @@ -195,7 +195,7 @@ theorem mergeSortTR₂_run_eq_mergeSort : {n : Nat} → (l : { l : List α // l. rw [reverse_reverse] termination_by n => n -theorem mergeSortTR₂_run'_eq_mergeSort : {n : Nat} → (l : { l : List α // l.length = n }) → (w : l' = l.1.reverse) → mergeSortTR₂.run' le l = mergeSort le l' +theorem mergeSortTR₂_run'_eq_mergeSort : {n : Nat} → (l : { l : List α // l.length = n }) → (w : l' = l.1.reverse) → mergeSortTR₂.run' le l = mergeSort l' le | 0, ⟨[], _⟩, w | 1, ⟨[a], _⟩, w => by simp_all [mergeSortTR₂.run', mergeSort] | n+2, ⟨a :: b :: l, h⟩, w => by diff --git a/src/Init/Data/List/Sort/Lemmas.lean b/src/Init/Data/List/Sort/Lemmas.lean index 6d223988e6f1..3a49aa603fca 100644 --- a/src/Init/Data/List/Sort/Lemmas.lean +++ b/src/Init/Data/List/Sort/Lemmas.lean @@ -24,10 +24,6 @@ import Init.Data.Bool namespace List -- We enable this instance locally so we can write `Pairwise le` instead of `Pairwise (le · ·)` everywhere. -attribute [local instance] boolRelToRel - -variable {le : α → α → Bool} - /-! ### splitInTwo -/ @[simp] theorem splitInTwo_fst (l : { l : List α // l.length = n }) : @@ -89,6 +85,8 @@ theorem splitInTwo_fst_le_splitInTwo_snd {l : { l : List α // l.length = n }} ( /-! ### enumLE -/ +variable {le : α → α → Bool} + theorem enumLE_trans (trans : ∀ a b c, le a b → le b c → le a c) (a b c : Nat × α) : enumLE le a b → enumLE le b c → enumLE le a c := by simp only [enumLE] @@ -129,7 +127,7 @@ theorem enumLE_total (total : ∀ a b, !le a b → le b a) /-! ### merge -/ theorem merge_stable : ∀ (xs ys) (_ : ∀ x y, x ∈ xs → y ∈ ys → x.1 ≤ y.1), - (merge (enumLE le) xs ys).map (·.2) = merge le (xs.map (·.2)) (ys.map (·.2)) + (merge xs ys (enumLE le)).map (·.2) = merge (xs.map (·.2)) (ys.map (·.2)) le | [], ys, _ => by simp [merge] | xs, [], _ => by simp [merge] | (i, x) :: xs, (j, y) :: ys, h => by @@ -147,7 +145,7 @@ theorem merge_stable : ∀ (xs ys) (_ : ∀ x y, x ∈ xs → y ∈ ys → x.1 The elements of `merge le xs ys` are exactly the elements of `xs` and `ys`. -/ -- We subsequently prove that `mergeSort_perm : merge le xs ys ~ xs ++ ys`. -theorem mem_merge {a : α} {xs ys : List α} : a ∈ merge le xs ys ↔ a ∈ xs ∨ a ∈ ys := by +theorem mem_merge {a : α} {xs ys : List α} : a ∈ merge xs ys le ↔ a ∈ xs ∨ a ∈ ys := by induction xs generalizing ys with | nil => simp [merge] | cons x xs ih => @@ -161,6 +159,8 @@ theorem mem_merge {a : α} {xs ys : List α} : a ∈ merge le xs ys ↔ a ∈ xs apply or_congr_left simp only [or_comm (a := a = y), or_assoc] +attribute [local instance] boolRelToRel + /-- If the ordering relation `le` is transitive and total (i.e. `le a b ∨ le b a` for all `a, b`) then the `merge` of two sorted lists is sorted. @@ -168,7 +168,7 @@ then the `merge` of two sorted lists is sorted. theorem sorted_merge (trans : ∀ (a b c : α), le a b → le b c → le a c) (total : ∀ (a b : α), !le a b → le b a) - (l₁ l₂ : List α) (h₁ : l₁.Pairwise le) (h₂ : l₂.Pairwise le) : (merge le l₁ l₂).Pairwise le := by + (l₁ l₂ : List α) (h₁ : l₁.Pairwise le) (h₂ : l₂.Pairwise le) : (merge l₁ l₂ le).Pairwise le := by induction l₁ generalizing l₂ with | nil => simpa only [merge] | cons x l₁ ih₁ => @@ -195,7 +195,7 @@ theorem sorted_merge · exact ih₂ h₂.tail theorem merge_of_le : ∀ {xs ys : List α} (_ : ∀ a b, a ∈ xs → b ∈ ys → le a b), - merge le xs ys = xs ++ ys + merge xs ys le = xs ++ ys | [], ys, _ | xs, [], _ => by simp [merge] | x :: xs, y :: ys, h => by @@ -206,7 +206,7 @@ theorem merge_of_le : ∀ {xs ys : List α} (_ : ∀ a b, a ∈ xs → b ∈ ys · exact h x y (mem_cons_self _ _) (mem_cons_self _ _) variable (le) in -theorem merge_perm_append : ∀ {xs ys : List α}, merge le xs ys ~ xs ++ ys +theorem merge_perm_append : ∀ {xs ys : List α}, merge xs ys le ~ xs ++ ys | [], ys => by simp [merge] | xs, [] => by simp [merge] | x :: xs, y :: ys => by @@ -222,24 +222,23 @@ theorem merge_perm_append : ∀ {xs ys : List α}, merge le xs ys ~ xs ++ ys @[simp] theorem mergeSort_singleton (a : α) : [a].mergeSort r = [a] := by rw [List.mergeSort] -variable (le) in -theorem mergeSort_perm : ∀ (l : List α), mergeSort le l ~ l - | [] => by simp [mergeSort] - | [a] => by simp [mergeSort] - | a :: b :: xs => by +theorem mergeSort_perm : ∀ (l : List α) (le), mergeSort l le ~ l + | [], _ => by simp [mergeSort] + | [a], _ => by simp [mergeSort] + | a :: b :: xs, le => by simp only [mergeSort] have : (splitInTwo ⟨a :: b :: xs, rfl⟩).1.1.length < xs.length + 1 + 1 := by simp [splitInTwo_fst]; omega have : (splitInTwo ⟨a :: b :: xs, rfl⟩).2.1.length < xs.length + 1 + 1 := by simp [splitInTwo_snd]; omega exact (merge_perm_append le).trans - (((mergeSort_perm _).append (mergeSort_perm _)).trans + (((mergeSort_perm _ _).append (mergeSort_perm _ _)).trans (Perm.of_eq (splitInTwo_fst_append_splitInTwo_snd _))) termination_by l => l.length -@[simp] theorem mergeSort_length (l : List α) : (mergeSort le l).length = l.length := - (mergeSort_perm le l).length_eq +@[simp] theorem mergeSort_length (l : List α) : (mergeSort l le).length = l.length := + (mergeSort_perm l le).length_eq -@[simp] theorem mem_mergeSort {a : α} {l : List α} : a ∈ mergeSort le l ↔ a ∈ l := - (mergeSort_perm le l).mem_iff +@[simp] theorem mem_mergeSort {a : α} {l : List α} : a ∈ mergeSort l le ↔ a ∈ l := + (mergeSort_perm l le).mem_iff /-- The result of `mergeSort` is sorted, @@ -251,7 +250,7 @@ The comparison function need not be irreflexive, i.e. `le a b` and `le b a` is a theorem sorted_mergeSort (trans : ∀ (a b c : α), le a b → le b c → le a c) (total : ∀ (a b : α), !le a b → le b a) : - (l : List α) → (mergeSort le l).Pairwise le + (l : List α) → (mergeSort l le).Pairwise le | [] => by simp [mergeSort] | [a] => by simp [mergeSort] | a :: b :: xs => by @@ -268,7 +267,7 @@ termination_by l => l.length /-- If the input list is already sorted, then `mergeSort` does not change the list. -/ -theorem mergeSort_of_sorted : ∀ {l : List α} (_ : Pairwise le l), mergeSort le l = l +theorem mergeSort_of_sorted : ∀ {l : List α} (_ : Pairwise le l), mergeSort l le = l | [], _ => by simp [mergeSort] | [a], _ => by simp [mergeSort] | a :: b :: xs, h => by @@ -294,10 +293,10 @@ See also: * `pair_sublist_mergeSort`: if `[a, b] <+ l` and `le a b`, then `[a, b] <+ mergeSort le l`) -/ theorem mergeSort_enum {l : List α} : - (mergeSort (enumLE le) (l.enum)).map (·.2) = mergeSort le l := + (mergeSort (l.enum) (enumLE le)).map (·.2) = mergeSort l le := go 0 l where go : ∀ (i : Nat) (l : List α), - (mergeSort (enumLE le) (l.enumFrom i)).map (·.2) = mergeSort le l + (mergeSort (l.enumFrom i) (enumLE le)).map (·.2) = mergeSort l le | _, [] | _, [a] => by simp [mergeSort] | _, a :: b :: xs => by @@ -320,24 +319,24 @@ theorem mergeSort_cons {le : α → α → Bool} (trans : ∀ (a b c : α), le a b → le b c → le a c) (total : ∀ (a b : α), !le a b → le b a) (a : α) (l : List α) : - ∃ l₁ l₂, mergeSort le (a :: l) = l₁ ++ a :: l₂ ∧ mergeSort le l = l₁ ++ l₂ ∧ + ∃ l₁ l₂, mergeSort (a :: l) le = l₁ ++ a :: l₂ ∧ mergeSort l le = l₁ ++ l₂ ∧ ∀ b, b ∈ l₁ → !le a b := by rw [← mergeSort_enum] rw [enum_cons] have nd : Nodup ((a :: l).enum.map (·.1)) := by rw [enum_map_fst]; exact nodup_range _ - have m₁ : (0, a) ∈ mergeSort (enumLE le) ((a :: l).enum) := + have m₁ : (0, a) ∈ mergeSort ((a :: l).enum) (enumLE le) := mem_mergeSort.mpr (mem_cons_self _ _) obtain ⟨l₁, l₂, h⟩ := append_of_mem m₁ have s := sorted_mergeSort (enumLE_trans trans) (enumLE_total total) ((a :: l).enum) rw [h] at s - have p := mergeSort_perm (enumLE le) ((a :: l).enum) + have p := mergeSort_perm ((a :: l).enum) (enumLE le) rw [h] at p refine ⟨l₁.map (·.2), l₂.map (·.2), ?_, ?_, ?_⟩ · simpa using congrArg (·.map (·.2)) h · rw [← mergeSort_enum.go 1, ← map_append] congr 1 - have q : mergeSort (enumLE le) (enumFrom 1 l) ~ l₁ ++ l₂ := - (mergeSort_perm (enumLE le) (enumFrom 1 l)).trans + have q : mergeSort (enumFrom 1 l) (enumLE le) ~ l₁ ++ l₂ := + (mergeSort_perm (enumFrom 1 l) (enumLE le)).trans (p.symm.trans perm_middle).cons_inv apply Perm.eq_of_sorted (le := enumLE le) · rintro ⟨i, a⟩ ⟨j, b⟩ ha hb @@ -379,7 +378,7 @@ theorem sublist_mergeSort (trans : ∀ (a b c : α), le a b → le b c → le a c) (total : ∀ (a b : α), !le a b → le b a) : ∀ {c : List α} (_ : c.Pairwise le) (_ : c <+ l), - c <+ mergeSort le l + c <+ mergeSort l le | _, _, .slnil => nil_sublist _ | c, hc, @Sublist.cons _ _ l a h => by obtain ⟨l₁, l₂, h₁, h₂, -⟩ := mergeSort_cons trans total a l @@ -409,7 +408,7 @@ then `[a, b]` is still a sublist of `mergeSort le l`. theorem pair_sublist_mergeSort (trans : ∀ (a b c : α), le a b → le b c → le a c) (total : ∀ (a b : α), !le a b → le b a) - (hab : le a b) (h : [a, b] <+ l) : [a, b] <+ mergeSort le l := + (hab : le a b) (h : [a, b] <+ l) : [a, b] <+ mergeSort l le := sublist_mergeSort trans total (pairwise_pair.mpr hab) h @[deprecated (since := "2024-09-02")] abbrev mergeSort_stable_pair := @pair_sublist_mergeSort diff --git a/tests/lean/run/mergeSort.lean b/tests/lean/run/mergeSort.lean index f43d47f118e9..954bfa156ae1 100644 --- a/tests/lean/run/mergeSort.lean +++ b/tests/lean/run/mergeSort.lean @@ -1,25 +1,62 @@ open List MergeSort Internal +-- If we omit the comparator, it is filled by the autoparam `fun a b => a ≤ b` unseal mergeSort merge in -example : mergeSort (· ≤ ·) [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5] = [1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9] := +example : mergeSort [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5] = [1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9] := rfl unseal mergeSort merge in -example : mergeSort (fun x y => x/10 ≤ y/10) [3, 100 + 1, 4, 100 + 1, 5, 100 + 9, 2, 10 + 6, 5, 10 + 3, 5] = [3, 4, 5, 2, 5, 5, 16, 13, 101, 101, 109] := +example : mergeSort [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5] (· ≤ ·) = [1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9] := + rfl + +unseal mergeSort merge in +example : mergeSort [3, 100 + 1, 4, 100 + 1, 5, 100 + 9, 2, 10 + 6, 5, 10 + 3, 5] (fun x y => x/10 ≤ y/10) = [3, 4, 5, 2, 5, 5, 16, 13, 101, 101, 109] := rfl unseal mergeSortTR.run mergeTR.go in -example : mergeSortTR (· ≤ ·) [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5] = [1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9] := +example : mergeSortTR [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5] = [1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9] := rfl unseal mergeSortTR.run mergeTR.go in -example : mergeSortTR (fun x y => x/10 ≤ y/10) [3, 100 + 1, 4, 100 + 1, 5, 100 + 9, 2, 10 + 6, 5, 10 + 3, 5] = [3, 4, 5, 2, 5, 5, 16, 13, 101, 101, 109] := +example : mergeSortTR [3, 100 + 1, 4, 100 + 1, 5, 100 + 9, 2, 10 + 6, 5, 10 + 3, 5] (fun x y => x/10 ≤ y/10) = [3, 4, 5, 2, 5, 5, 16, 13, 101, 101, 109] := rfl unseal mergeSortTR₂.run mergeTR.go in -example : mergeSortTR₂ (· ≤ ·) [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5] = [1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9] := +example : mergeSortTR₂ [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5] = [1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9] := rfl unseal mergeSortTR₂.run mergeTR.go in -example : mergeSortTR₂ (fun x y => x/10 ≤ y/10) [3, 100 + 1, 4, 100 + 1, 5, 100 + 9, 2, 10 + 6, 5, 10 + 3, 5] = [3, 4, 5, 2, 5, 5, 16, 13, 101, 101, 109] := +example : mergeSortTR₂ [3, 100 + 1, 4, 100 + 1, 5, 100 + 9, 2, 10 + 6, 5, 10 + 3, 5] (fun x y => x/10 ≤ y/10) = [3, 4, 5, 2, 5, 5, 16, 13, 101, 101, 109] := rfl + +/-! +# Behaviour of mergeSort when the comparator is not provided, but typeclasses are missing. +-/ + +inductive NoLE +| mk : NoLE + +/-- +error: failed to synthesize + LE NoLE +Additional diagnostic information may be available using the `set_option diagnostics true` command. +-/ +#guard_msgs in +example : mergeSort [NoLE.mk] = [NoLE.mk] := sorry + +inductive UndecidableLE +| mk : UndecidableLE + +instance : LE UndecidableLE where + le := fun _ _ => true + +/-- +error: type mismatch + a ≤ b +has type + Prop : Type +but is expected to have type + Bool : Type +-/ +#guard_msgs in +example : mergeSort [UndecidableLE.mk] = [UndecidableLE.mk] := sorry