diff --git a/Batteries/Data/BinaryHeap.lean b/Batteries/Data/BinaryHeap.lean index b36ed97fd3..29a273a9d0 100644 --- a/Batteries/Data/BinaryHeap.lean +++ b/Batteries/Data/BinaryHeap.lean @@ -1,78 +1,76 @@ /- Copyright (c) 2021 Mario Carneiro. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. -Authors: Mario Carneiro +Authors: Mario Carneiro, François G. Dorais -/ +import Batteries.Data.Vector.Basic + namespace Batteries /-- A max-heap data structure. -/ structure BinaryHeap (α) (lt : α → α → Bool) where - /-- Backing array for `BinaryHeap`. -/ + /-- `O(1)`. Get data array for a `BinaryHeap`. -/ arr : Array α namespace BinaryHeap -/-- Core operation for binary heaps, expressed directly on arrays. -Given an array which is a max-heap, push item `i` down to restore the max-heap property. -/ -def heapifyDown (lt : α → α → Bool) (a : Array α) (i : Fin a.size) : - {a' : Array α // a'.size = a.size} := +private def maxChild (lt : α → α → Bool) (a : Vector α sz) (i : Fin sz) : Option (Fin sz) := let left := 2 * i.1 + 1 let right := left + 1 - have left_le : i ≤ left := Nat.le_trans - (by rw [Nat.succ_mul, Nat.one_mul]; exact Nat.le_add_left i i) - (Nat.le_add_right ..) - have right_le : i ≤ right := Nat.le_trans left_le (Nat.le_add_right ..) - have i_le : i ≤ i := Nat.le_refl _ - have j : {j : Fin a.size // i ≤ j} := if h : left < a.size then - if lt (a.get i) (a.get ⟨left, h⟩) then ⟨⟨left, h⟩, left_le⟩ else ⟨i, i_le⟩ else ⟨i, i_le⟩ - have j := if h : right < a.size then - if lt (a.get j) (a.get ⟨right, h⟩) then ⟨⟨right, h⟩, right_le⟩ else j else j - if h : i.1 = j then ⟨a, rfl⟩ else - let a' := a.swap i j - let j' := ⟨j, by rw [a.size_swap i j]; exact j.1.2⟩ - have : a'.size - j < a.size - i := by - rw [a.size_swap i j]; exact Nat.sub_lt_sub_left i.2 <| Nat.lt_of_le_of_ne j.2 h - let ⟨a₂, h₂⟩ := heapifyDown lt a' j' - ⟨a₂, h₂.trans (a.size_swap i j)⟩ -termination_by a.size - i - -@[simp] theorem size_heapifyDown (lt : α → α → Bool) (a : Array α) (i : Fin a.size) : - (heapifyDown lt a i).1.size = a.size := (heapifyDown lt a i).2 + if hleft : left < sz then + if hright : right < sz then + if lt a[left] a[right] then + some ⟨right, hright⟩ + else + some ⟨left, hleft⟩ + else + some ⟨left, hleft⟩ + else none + +/-- Core operation for binary heaps, expressed directly on arrays. +Given an array which is a max-heap, push item `i` down to restore the max-heap property. -/ +def heapifyDown (lt : α → α → Bool) (a : Vector α sz) (i : Fin sz) : + Vector α sz := + match h : maxChild lt a i with + | none => a + | some j => + have : i < j := by + cases i; cases j + simp only [maxChild] at h + split at h + · split at h + · split at h <;> (cases h; simp_arith) + · cases h; simp_arith + · contradiction + if lt a[i] a[j] then + heapifyDown lt (a.swap i j) j + else a +termination_by sz - i /-- Core operation for binary heaps, expressed directly on arrays. Construct a heap from an unsorted array, by heapifying all the elements. -/ -def mkHeap (lt : α → α → Bool) (a : Array α) : {a' : Array α // a'.size = a.size} := - loop (a.size / 2) a (Nat.div_le_self ..) +def mkHeap (lt : α → α → Bool) (a : Vector α sz) : Vector α sz := + loop (sz / 2) a (Nat.div_le_self ..) where /-- Inner loop for `mkHeap`. -/ - loop : (i : Nat) → (a : Array α) → i ≤ a.size → {a' : Array α // a'.size = a.size} - | 0, a, _ => ⟨a, rfl⟩ + loop : (i : Nat) → (a : Vector α sz) → i ≤ sz → Vector α sz + | 0, a, _ => a | i+1, a, h => - let h := Nat.lt_of_succ_le h - let a' := heapifyDown lt a ⟨i, h⟩ - let ⟨a₂, h₂⟩ := loop i a' ((heapifyDown ..).2.symm ▸ Nat.le_of_lt h) - ⟨a₂, h₂.trans a'.2⟩ - -@[simp] theorem size_mkHeap (lt : α → α → Bool) (a : Array α) : - (mkHeap lt a).1.size = a.size := (mkHeap lt a).2 + let a' := heapifyDown lt a ⟨i, Nat.lt_of_succ_le h⟩ + loop i a' (Nat.le_trans (Nat.le_succ _) h) /-- Core operation for binary heaps, expressed directly on arrays. Given an array which is a max-heap, push item `i` up to restore the max-heap property. -/ -def heapifyUp (lt : α → α → Bool) (a : Array α) (i : Fin a.size) : - {a' : Array α // a'.size = a.size} := - if i0 : i.1 = 0 then ⟨a, rfl⟩ else - have : (i.1 - 1) / 2 < i := Nat.lt_of_le_of_lt (Nat.div_le_self ..) <| - Nat.sub_lt (Nat.pos_of_ne_zero i0) Nat.zero_lt_one - let j := ⟨(i.1 - 1) / 2, Nat.lt_trans this i.2⟩ - if lt (a.get j) (a.get i) then - let a' := a.swap i j - let ⟨a₂, h₂⟩ := heapifyUp lt a' ⟨j.1, by rw [a.size_swap i j]; exact j.2⟩ - ⟨a₂, h₂.trans (a.size_swap i j)⟩ - else ⟨a, rfl⟩ - -@[simp] theorem size_heapifyUp (lt : α → α → Bool) (a : Array α) (i : Fin a.size) : - (heapifyUp lt a i).1.size = a.size := (heapifyUp lt a i).2 +def heapifyUp (lt : α → α → Bool) (a : Vector α sz) (i : Fin sz) : + Vector α sz := + match i with + | ⟨0, _⟩ => a + | ⟨i'+1, hi⟩ => + let j := ⟨i'/2, by get_elem_tactic⟩ + if lt a[j] a[i] then + heapifyUp lt (a.swap i j) j + else a /-- `O(1)`. Build a new empty heap. -/ def empty (lt) : BinaryHeap α lt := ⟨#[]⟩ @@ -86,81 +84,91 @@ def singleton (lt) (x : α) : BinaryHeap α lt := ⟨#[x]⟩ /-- `O(1)`. Get the number of elements in a `BinaryHeap`. -/ def size (self : BinaryHeap α lt) : Nat := self.1.size +/-- `O(1)`. Get data vector of a `BinaryHeap`. -/ +def vector (self : BinaryHeap α lt) : Vector α self.size := ⟨self.1, rfl⟩ + /-- `O(1)`. Get an element in the heap by index. -/ def get (self : BinaryHeap α lt) (i : Fin self.size) : α := self.1.get i /-- `O(log n)`. Insert an element into a `BinaryHeap`, preserving the max-heap property. -/ def insert (self : BinaryHeap α lt) (x : α) : BinaryHeap α lt where - arr := let n := self.size; - heapifyUp lt (self.1.push x) ⟨n, by rw [Array.size_push]; apply Nat.lt_succ_self⟩ + arr := heapifyUp lt (self.vector.push x) ⟨_, Nat.lt_succ_self _⟩ |>.toArray @[simp] theorem size_insert (self : BinaryHeap α lt) (x : α) : (self.insert x).size = self.size + 1 := by - simp [insert, size, size_heapifyUp] + simp [size, insert] /-- `O(1)`. Get the maximum element in a `BinaryHeap`. -/ -def max (self : BinaryHeap α lt) : Option α := self.1.get? 0 - -/-- Auxiliary for `popMax`. -/ -def popMaxAux (self : BinaryHeap α lt) : {a' : BinaryHeap α lt // a'.size = self.size - 1} := - match e: self.1.size with - | 0 => ⟨self, by simp [size, e]⟩ - | n+1 => - have h0 := by rw [e]; apply Nat.succ_pos - have hn := by rw [e]; apply Nat.lt_succ_self - if hn0 : 0 < n then - let a := self.1.swap ⟨0, h0⟩ ⟨n, hn⟩ |>.pop - ⟨⟨heapifyDown lt a ⟨0, by rwa [Array.size_pop, Array.size_swap, e]⟩⟩, - by simp [size, a]⟩ - else - ⟨⟨self.1.pop⟩, by simp [size]⟩ +def max (self : BinaryHeap α lt) : Option α := self.1[0]? /-- `O(log n)`. Remove the maximum element from a `BinaryHeap`. Call `max` first to actually retrieve the maximum element. -/ -def popMax (self : BinaryHeap α lt) : BinaryHeap α lt := self.popMaxAux +def popMax (self : BinaryHeap α lt) : BinaryHeap α lt := + if h0 : self.size = 0 then self else + have hs : self.size - 1 < self.size := Nat.pred_lt h0 + have h0 : 0 < self.size := Nat.zero_lt_of_ne_zero h0 + let v := self.vector.swap ⟨_, h0⟩ ⟨_, hs⟩ |>.pop + if h : 0 < self.size - 1 then + ⟨heapifyDown lt v ⟨0, h⟩ |>.toArray⟩ + else + ⟨v.toArray⟩ @[simp] theorem size_popMax (self : BinaryHeap α lt) : - self.popMax.size = self.size - 1 := self.popMaxAux.2 + self.popMax.size = self.size - 1 := by + simp only [popMax, size] + split + · simp_arith [*] + · split <;> simp_arith [*] /-- `O(log n)`. Return and remove the maximum element from a `BinaryHeap`. -/ def extractMax (self : BinaryHeap α lt) : Option α × BinaryHeap α lt := (self.max, self.popMax) -theorem size_pos_of_max {self : BinaryHeap α lt} (e : self.max = some x) : 0 < self.size := - Decidable.of_not_not fun h : ¬ 0 < self.1.size => by simp [BinaryHeap.max, Array.get?, h] at e +theorem size_pos_of_max {self : BinaryHeap α lt} (h : self.max = some x) : 0 < self.size := by + simp only [max, getElem?_def] at h + split at h + · assumption + · contradiction /-- `O(log n)`. Equivalent to `extractMax (self.insert x)`, except that extraction cannot fail. -/ def insertExtractMax (self : BinaryHeap α lt) (x : α) : α × BinaryHeap α lt := - match e: self.max with + match e : self.max with | none => (x, self) | some m => if lt x m then - let a := self.1.set ⟨0, size_pos_of_max e⟩ x - (m, ⟨heapifyDown lt a ⟨0, by simp only [Array.size_set, a]; exact size_pos_of_max e⟩⟩) + let v := self.vector.set ⟨0, size_pos_of_max e⟩ x + (m, ⟨heapifyDown lt v ⟨0, size_pos_of_max e⟩ |>.toArray⟩) else (x, self) /-- `O(log n)`. Equivalent to `(self.max, self.popMax.insert x)`. -/ def replaceMax (self : BinaryHeap α lt) (x : α) : Option α × BinaryHeap α lt := - match e: self.max with - | none => (none, ⟨self.1.push x⟩) + match e : self.max with + | none => (none, ⟨self.vector.push x |>.toArray⟩) | some m => - let a := self.1.set ⟨0, size_pos_of_max e⟩ x - (some m, ⟨heapifyDown lt a ⟨0, by simp only [Array.size_set, a]; exact size_pos_of_max e⟩⟩) + let v := self.vector.set ⟨0, size_pos_of_max e⟩ x + (some m, ⟨heapifyDown lt v ⟨0, size_pos_of_max e⟩ |>.toArray⟩) /-- `O(log n)`. Replace the value at index `i` by `x`. Assumes that `x ≤ self.get i`. -/ def decreaseKey (self : BinaryHeap α lt) (i : Fin self.size) (x : α) : BinaryHeap α lt where - arr := heapifyDown lt (self.1.set i x) ⟨i, by rw [self.1.size_set]; exact i.2⟩ + arr := heapifyDown lt (self.vector.set i x) i |>.toArray /-- `O(log n)`. Replace the value at index `i` by `x`. Assumes that `self.get i ≤ x`. -/ def increaseKey (self : BinaryHeap α lt) (i : Fin self.size) (x : α) : BinaryHeap α lt where - arr := heapifyUp lt (self.1.set i x) ⟨i, by rw [self.1.size_set]; exact i.2⟩ + arr := heapifyUp lt (self.vector.set i x) i |>.toArray end Batteries.BinaryHeap +/-- `O(n)`. Convert an unsorted vector to a `BinaryHeap`. -/ +def Batteries.Vector.toBinaryHeap (lt : α → α → Bool) (v : Vector α n) : + Batteries.BinaryHeap α lt where + arr := BinaryHeap.mkHeap lt v |>.toArray + +open Batteries in /-- `O(n)`. Convert an unsorted array to a `BinaryHeap`. -/ def Array.toBinaryHeap (lt : α → α → Bool) (a : Array α) : Batteries.BinaryHeap α lt where - arr := Batteries.BinaryHeap.mkHeap lt a + arr := BinaryHeap.mkHeap lt ⟨a, rfl⟩ |>.toArray +open Batteries in /-- `O(n log n)`. Sort an array using a `BinaryHeap`. -/ @[specialize] def Array.heapSort (a : Array α) (lt : α → α → Bool) : Array α := loop (a.toBinaryHeap (flip lt)) #[]