Skip to content

Commit bb9b915

Browse files
cannorinwallymathieu
authored andcommitted
Monad instance for Vector and Matrix (#607)
1 parent b6a094a commit bb9b915

File tree

4 files changed

+131
-53
lines changed

4 files changed

+131
-53
lines changed

src/FSharpPlus.TypeLevel/Data/Matrix.fs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,17 @@ module Vector =
251251

252252
let inline apply (f: Vector<'a -> 'b, 'n>) (v: Vector<'a, 'n>) : Vector<'b, 'n> = map2 id f v
253253

254+
/// <description>
255+
/// Converts the vector of vectors to a square matrix and returns its diagonal.
256+
/// </description>
257+
/// <seealso href="https://stackoverflow.com/questions/5802628/monad-instance-of-a-number-parameterised-vector" />
258+
[<MethodImpl(MethodImplOptions.AggressiveInlining)>]
259+
let join (vv: Vector<Vector<'a, 'n>, 'n>): Vector<'a, 'n> =
260+
{ Items = Array.init (Array.length vv.Items) (fun i -> vv.Items.[i].Items.[i]) }
261+
262+
let inline bind (f: 'a -> Vector<'b, 'n>) (v: Vector<'a, 'n>) : Vector<'b, 'n> =
263+
v |> map f |> join
264+
254265
let inline norm (v: Vector< ^a, ^n >) : ^a =
255266
v |> toArray |> Array.sumBy (fun x -> x * x) |> sqrt
256267
let inline maximumNorm (v: Vector< ^a, ^n >) : ^a =
@@ -327,6 +338,20 @@ module Matrix =
327338
for j = 0 to Array2D.length2 m1.Items - 1 do
328339
f i j m1.Items.[i, j] m2.Items.[i, j]
329340

341+
let inline apply (f: Matrix<'a -> 'b, 'm, 'n>) (m: Matrix<'a, 'm, 'n>) : Matrix<'b, 'm, 'n> = map2 id f m
342+
343+
/// <description>
344+
/// Converts the matrix of matrices to a 3D cube matrix and returns its diagonal.
345+
/// </description>
346+
/// <seealso href="https://stackoverflow.com/questions/5802628/monad-instance-of-a-number-parameterised-vector" />
347+
[<MethodImpl(MethodImplOptions.AggressiveInlining)>]
348+
let join (m: Matrix<Matrix<'a, 'm, 'n>, 'm, 'n>) : Matrix<'a, 'm, 'n> =
349+
{ Items =
350+
Array2D.init (Array2D.length1 m.Items) (Array2D.length2 m.Items)
351+
(fun i j -> m.Items.[i, j].Items.[i, j] ) }
352+
353+
let inline bind (f: 'a -> Matrix<'b, 'm, 'n>) (m: Matrix<'a, 'm, 'n>) : Matrix<'b, 'm, 'n> = m |> map f |> join
354+
330355
let inline rowLength (_: Matrix<'a, 'm, 'n>) : 'm = Singleton<'m>
331356
let inline colLength (_: Matrix<'a, 'm, 'n>) : 'n = Singleton<'n>
332357
let inline rowLength' (_: Matrix<'a, ^m, 'n>) : int = RuntimeValue (Singleton< ^m >)
@@ -571,8 +596,10 @@ type Matrix<'Item, 'Row, 'Column> with
571596

572597
static member inline Return (x: 'x) : Matrix<'x, 'm, 'n> = Matrix.replicate Singleton Singleton x
573598
static member inline Pure (x: 'x) : Matrix<'x, 'm, 'n> = Matrix.replicate Singleton Singleton x
574-
static member inline ( <*> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.map2 id f x
575-
static member inline ( <.> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.map2 id f x
599+
static member inline ( <*> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.apply f x
600+
static member inline ( <.> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.apply f x
601+
static member inline Join (x: Matrix<Matrix<'x, 'm, 'n>, 'm, 'n>) = Matrix.join x
602+
static member inline ( >>= ) (x: Matrix<'x, 'm, 'n>, f: 'x -> Matrix<'y, 'm, 'n>) = Matrix.bind f x
576603
static member inline get_Zero () : Matrix<'a, 'm, 'n> = Matrix.zero
577604
static member inline ( + ) (m1, m2) = Matrix.map2 (+) m1 m2
578605
static member inline ( - ) (m1, m2) = Matrix.map2 (-) m1 m2
@@ -607,6 +634,8 @@ type Vector<'Item, 'Length> with
607634
static member inline Pure (x: 'x) : Vector<'x, 'n> = Vector.replicate Singleton x
608635
static member inline ( <*> ) (f: Vector<'x -> 'y, 'n>, x: Vector<'x, 'n>) : Vector<'y, 'n> = Vector.apply f x
609636
static member inline ( <.> ) (f: Vector<'x -> 'y, 'n>, x: Vector<'x, 'n>) : Vector<'y, 'n> = Vector.apply f x
637+
static member inline Join (x: Vector<Vector<'x, 'n>, 'n>) : Vector<'x, 'n> = Vector.join x
638+
static member inline ( >>= ) (x: Vector<'x, 'n>, f: 'x -> Vector<'y, 'n>) = Vector.bind f x
610639

611640
[<EditorBrowsable(EditorBrowsableState.Never)>]
612641
static member inline Zip (x, y) = Vector.zip x y

tests/FSharpPlus.Tests/FSharpPlus.Tests.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
<Compile Include="Extensions.fs" />
3737
<Compile Include="BifoldableTests.fs" />
3838
<Compile Include="Compatibility.fs" />
39+
<Compile Include="Matrix.fs" />
3940
<Compile Include="TypeLevel.fs" />
4041
<Content Include="App.config" Condition=" '$(TargetFramework)' == 'net462'" />
4142
</ItemGroup>

tests/FSharpPlus.Tests/Matrix.fs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
namespace FSharpPlus.Tests
2+
3+
open System
4+
open NUnit.Framework
5+
open Helpers
6+
7+
open FSharpPlus
8+
open FSharpPlus.Data
9+
open FSharpPlus.TypeLevel
10+
11+
module VectorTests =
12+
[<Test>]
13+
let constructorAndDeconstructorWorks() =
14+
let v1 = vector (1,2,3,4,5)
15+
let v2 = vector (1,2,3,4,5,6,7,8,9,0,1,2,3,4,5)
16+
let (Vector(_,_,_,_,_)) = v1
17+
let (Vector(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)) = v2
18+
()
19+
20+
[<Test>]
21+
let applicativeWorks() =
22+
let v = vector ((fun i -> i + 1), (fun i -> i * 2))
23+
let u = vector (2, 3)
24+
let vu = v <*> u
25+
NUnit.Framework.Assert.IsInstanceOf<Option<Vector<int,S<S<Z>>>>> (Some vu)
26+
CollectionAssert.AreEqual ([|3; 6|], Vector.toArray vu)
27+
28+
[<Test>]
29+
let satisfiesApplicativeLaws() =
30+
let u = vector ((fun i -> i - 1), (fun i -> i * 2))
31+
let v = vector ((fun i -> i + 1), (fun i -> i * 3))
32+
let w = vector (1, 1)
33+
34+
areEqual (result id <*> v) v
35+
areEqual (result (<<) <*> u <*> v <*> w) (u <*> (v <*> w))
36+
areEqual (result 2) ((result (fun i -> i + 1) : Vector<int -> int, S<S<Z>>>) <*> result 1)
37+
areEqual (u <*> result 1) (result ((|>) 1) <*> u)
38+
39+
[<Test>]
40+
let satisfiesMonadLaws() =
41+
let k = fun (a: int) -> vector (a - 1, a * 2)
42+
let h = fun (a: int) -> vector (a + 1, a * 3)
43+
let m = vector (1, 2)
44+
45+
areEqual (result 2 >>= k) (k 2)
46+
areEqual (m >>= result) m
47+
areEqual (m >>= (fun x -> k x >>= h)) ((m >>= k) >>= h)
48+
49+
module MatrixTests =
50+
[<Test>]
51+
let constructorAndDeconstructorWorks() =
52+
let m1 =
53+
matrix (
54+
(1,0,0,0),
55+
(0,1,0,0),
56+
(0,0,1,0)
57+
)
58+
let m2 =
59+
matrix (
60+
(1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0),
61+
(0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0),
62+
(0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0),
63+
(0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0),
64+
(0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0),
65+
(0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0),
66+
(0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0),
67+
(0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0)
68+
)
69+
let (Matrix(_x1,_x2,_x3)) = m1
70+
let (Matrix(_y1: int*int*int*int*int*int*int*int*int*int*int*int*int*int*int*int,_y2,_y3,_y4,_y5,_y6,_y7,_y8)) = m2
71+
()
72+
73+
[<Test>]
74+
let satisfiesApplicativeLaws() =
75+
let u = matrix (
76+
((fun i -> i - 1), (fun i -> i * 2)),
77+
((fun i -> i + 1), (fun i -> i * 3))
78+
)
79+
let v = matrix (
80+
((fun i -> i - 2), (fun i -> i * 5)),
81+
((fun i -> i + 2), (fun i -> i * 7))
82+
)
83+
let w = matrix ((1, 1), (1, 2))
84+
85+
areEqual (result id <*> v) v
86+
areEqual (result (<<) <*> u <*> v <*> w) (u <*> (v <*> w))
87+
areEqual ((result (fun i -> i + 1) : Matrix<int -> int, S<S<Z>>, S<S<Z>>>) <*> result 1) (result 2)
88+
areEqual (u <*> result 1) (result ((|>) 1) <*> u)
89+
90+
[<Test>]
91+
let satisfiesMonadLaws() =
92+
let k = fun (a: int) -> matrix ((a - 1, a * 2), (a + 1, a * 3))
93+
let h = fun (a: int) -> matrix ((a - 2, a * 5), (a + 2, a * 7))
94+
let m = matrix ((1, 1), (1, 2))
95+
96+
areEqual (result 2 >>= k) (k 2)
97+
areEqual (m >>= result) m
98+
areEqual (m >>= (fun x -> k x >>= h)) ((m >>= k) >>= h)

tests/FSharpPlus.Tests/TypeLevel.fs

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -150,38 +150,8 @@ module NatTests =
150150
Assert (g2 =^ S(S(S(S(S(S Z))))))
151151

152152

153-
open FSharpPlus.Data
154-
155-
module MatrixTests =
156-
[<Test>]
157-
let matrixTests =
158-
let v1 = vector (1,2,3,4,5)
159-
let v2 = vector (1,2,3,4,5,6,7,8,9,0,1,2,3,4,5)
160-
let (Vector(_,_,_,_,_)) = v1
161-
let (Vector(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)) = v2
162-
163-
let m1 =
164-
matrix (
165-
(1,0,0,0),
166-
(0,1,0,0),
167-
(0,0,1,0)
168-
)
169-
let m2 =
170-
matrix (
171-
(1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0),
172-
(0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0),
173-
(0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0),
174-
(0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0),
175-
(0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0),
176-
(0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0),
177-
(0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0),
178-
(0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0)
179-
)
180-
let (Matrix(_x1,_x2,_x3)) = m1
181-
let (Matrix(_y1: int*int*int*int*int*int*int*int*int*int*int*int*int*int*int*int,_y2,_y3,_y4,_y5,_y6,_y7,_y8)) = m2
182-
()
183-
184153
open Helpers
154+
open FSharpPlus.Data
185155

186156
module TypeProviderTests =
187157
type ``0`` = TypeNat<0>
@@ -206,23 +176,3 @@ module TypeProviderTests =
206176
Assert (Matrix.colLength row1 =^ (Z |> S |> S |> S))
207177
areEqual 5 (Matrix.get Z (S Z) row1)
208178
areEqual [3; 6; 9] (Vector.toList col2)
209-
210-
module TestFunctors1 =
211-
[<Test>]
212-
let applicativeOperatorWorks() =
213-
let v = vector ((fun i -> i + 1), (fun i -> i * 2))
214-
let u = vector (2, 3)
215-
let vu = v <*> u
216-
NUnit.Framework.Assert.IsInstanceOf<Option<Vector<int,S<S<Z>>>>> (Some vu)
217-
CollectionAssert.AreEqual ([|3; 6|], Vector.toArray vu)
218-
219-
module TestFunctors2 =
220-
open FSharpPlus
221-
222-
[<Test>]
223-
let applicativeWorksWithoutSubsumption() =
224-
let v = vector ((fun i -> i + 1), (fun i -> i * 2))
225-
let u = vector (2, 3)
226-
let vu = v <*> u
227-
NUnit.Framework.Assert.IsInstanceOf<Option<Vector<int,S<S<Z>>>>> (Some vu)
228-
CollectionAssert.AreEqual ([|3; 6|], Vector.toArray vu)

0 commit comments

Comments
 (0)