Skip to content

Commit 6644a01

Browse files
committed
Add bind operator to vector and matrix using join
1 parent 29f98d6 commit 6644a01

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

src/FSharpPlus.TypeLevel/Data/Matrix.fs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,17 @@ module Vector =
251251

252252
let inline apply (f: Vector<'a -> 'b, 'n>) (v: Vector<'a, 'n>) : Vector<'b, 'n> = map2 id f v
253253

254+
/// <description>
255+
/// Converts the vector of vectors to a square matrix and returns its diagonal.
256+
/// </description>
257+
/// <seealso href="https://stackoverflow.com/questions/5802628/monad-instance-of-a-number-parameterised-vector" />
254258
[<MethodImpl(MethodImplOptions.AggressiveInlining)>]
255259
let join (vv: Vector<Vector<'a, 'n>, 'n>): Vector<'a, 'n> =
256260
{ Items = Array.init (Array.length vv.Items) (fun i -> vv.Items.[i].Items.[i]) }
257261

262+
let inline bind (f: 'a -> Vector<'b, 'n>) (v: Vector<'a, 'n>) : Vector<'b, 'n> =
263+
v |> map f |> join
264+
258265
let inline norm (v: Vector< ^a, ^n >) : ^a =
259266
v |> toArray |> Array.sumBy (fun x -> x * x) |> sqrt
260267
let inline maximumNorm (v: Vector< ^a, ^n >) : ^a =
@@ -333,12 +340,18 @@ module Matrix =
333340

334341
let inline apply (f: Matrix<'a -> 'b, 'm, 'n>) (m: Matrix<'a, 'm, 'n>) : Matrix<'b, 'm, 'n> = map2 id f m
335342

343+
/// <description>
344+
/// Converts the matrix of matrices to a 3D cube matrix and returns its diagonal.
345+
/// </description>
346+
/// <seealso href="https://stackoverflow.com/questions/5802628/monad-instance-of-a-number-parameterised-vector" />
336347
[<MethodImpl(MethodImplOptions.AggressiveInlining)>]
337348
let join (m: Matrix<Matrix<'a, 'm, 'n>, 'm, 'n>) : Matrix<'a, 'm, 'n> =
338349
{ Items =
339350
Array2D.init (Array2D.length1 m.Items) (Array2D.length2 m.Items)
340351
(fun i j -> m.Items.[i, j].Items.[i, j] ) }
341352

353+
let inline bind (f: 'a -> Matrix<'b, 'm, 'n>) (m: Matrix<'a, 'm, 'n>) : Matrix<'b, 'm, 'n> = m |> map f |> join
354+
342355
let inline rowLength (_: Matrix<'a, 'm, 'n>) : 'm = Singleton<'m>
343356
let inline colLength (_: Matrix<'a, 'm, 'n>) : 'n = Singleton<'n>
344357
let inline rowLength' (_: Matrix<'a, ^m, 'n>) : int = RuntimeValue (Singleton< ^m >)
@@ -586,6 +599,7 @@ type Matrix<'Item, 'Row, 'Column> with
586599
static member inline ( <*> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.apply f x
587600
static member inline ( <.> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.apply f x
588601
static member inline Join (x: Matrix<Matrix<'x, 'm, 'n>, 'm, 'n>) = Matrix.join x
602+
static member inline ( >>= ) (x: Matrix<'x, 'm, 'n>, f: 'x -> Matrix<'y, 'm, 'n>) = Matrix.bind f x
589603
static member inline get_Zero () : Matrix<'a, 'm, 'n> = Matrix.zero
590604
static member inline ( + ) (m1, m2) = Matrix.map2 (+) m1 m2
591605
static member inline ( - ) (m1, m2) = Matrix.map2 (-) m1 m2
@@ -621,6 +635,7 @@ type Vector<'Item, 'Length> with
621635
static member inline ( <*> ) (f: Vector<'x -> 'y, 'n>, x: Vector<'x, 'n>) : Vector<'y, 'n> = Vector.apply f x
622636
static member inline ( <.> ) (f: Vector<'x -> 'y, 'n>, x: Vector<'x, 'n>) : Vector<'y, 'n> = Vector.apply f x
623637
static member inline Join (x: Vector<Vector<'x, 'n>, 'n>) : Vector<'x, 'n> = Vector.join x
638+
static member inline ( >>= ) (x: Vector<'x, 'n>, f: 'x -> Vector<'y, 'n>) = Vector.bind f x
624639

625640
[<EditorBrowsable(EditorBrowsableState.Never)>]
626641
static member inline Zip (x, y) = Vector.zip x y

0 commit comments

Comments
 (0)