Skip to content

Commit

Permalink
feat: add map and mapM for scalar array types (#902)
Browse files Browse the repository at this point in the history
  • Loading branch information
fgdorais authored Sep 9, 2024
1 parent 869f2ad commit afe9c5c
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 0 deletions.
1 change: 1 addition & 0 deletions Batteries.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import Batteries.Data.ByteSubarray
import Batteries.Data.Char
import Batteries.Data.DList
import Batteries.Data.Fin
import Batteries.Data.FloatArray
import Batteries.Data.HashMap
import Batteries.Data.Int
import Batteries.Data.LazyList
Expand Down
35 changes: 35 additions & 0 deletions Batteries/Data/ByteArray.lean
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,38 @@ where
(ofFnAux.go f i acc).data = Array.ofFn.go f i acc.data := by
rw [ofFnAux.go, Array.ofFn.go]; split; rw [data_ofFnAux f (i+1), data_push]; rfl
termination_by n - i

/-! ### map/mapM -/

/--
Unsafe optimized implementation of `mapM`.
This function is unsafe because it relies on the implementation limit that the size of an array is
always less than `USize.size`.
-/
@[inline]
unsafe def mapMUnsafe [Monad m] (a : ByteArray) (f : UInt8 → m UInt8) : m ByteArray :=
loop a 0 a.usize
where
/-- Inner loop for `mapMUnsafe`. -/
@[specialize]
loop (a : ByteArray) (k s : USize) := do
if k < a.usize then
let x := a.uget k lcProof
let y ← f x
let a := a.uset k y lcProof
loop a (k+1) s
else pure a

/-- `mapM f a` applies the monadic function `f` to each element of the array. -/
@[implemented_by mapMUnsafe]
def mapM [Monad m] (a : ByteArray) (f : UInt8 → m UInt8) : m ByteArray := do
let mut r := a
for i in [0:r.size] do
r := r.set! i (← f r[i]!)
return r

/-- `map f a` applies the function `f` to each element of the array. -/
@[inline]
def map (a : ByteArray) (f : UInt8 → UInt8) : ByteArray :=
mapM (m:=Id) a f
40 changes: 40 additions & 0 deletions Batteries/Data/FloatArray.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/-
Copyright (c) 2024 François G. Dorais. All rights reserved.
Released under Apache 2. license as described in the file LICENSE.
Authors: François G. Dorais
-/

namespace FloatArray

/--
Unsafe optimized implementation of `mapM`.
This function is unsafe because it relies on the implementation limit that the size of an array is
always less than `USize.size`.
-/
@[inline]
unsafe def mapMUnsafe [Monad m] (a : FloatArray) (f : Float → m Float) : m FloatArray :=
loop a 0 a.usize
where
/-- Inner loop for `mapMUnsafe`. -/
@[specialize]
loop (a : FloatArray) (k s : USize) := do
if k < s then
let x := a.uget k lcProof
let y ← f x
let a := a.uset k y lcProof
loop a (k+1) s
else pure a

/-- `mapM f a` applies the monadic function `f` to each element of the array. -/
@[implemented_by mapMUnsafe]
def mapM [Monad m] (a : FloatArray) (f : Float → m Float) : m FloatArray := do
let mut r := a
for i in [0:r.size] do
r := r.set! i (← f r[i]!)
return r

/-- `map f a` applies the function `f` to each element of the array. -/
@[inline]
def map (a : FloatArray) (f : Float → Float) : FloatArray :=
mapM (m:=Id) a f

0 comments on commit afe9c5c

Please sign in to comment.