diff --git a/src/FsMath/SpanMath.fs b/src/FsMath/SpanMath.fs index 6c8c26e..42302a6 100644 --- a/src/FsMath/SpanMath.fs +++ b/src/FsMath/SpanMath.fs @@ -318,8 +318,9 @@ type SpanMath = // outer product ####### - + /// Computes the outer product of two spans. + /// Result[i,j] = u[i] * v[j] static member inline outerProduct<'T when 'T :> Numerics.INumber<'T> and 'T : struct @@ -335,19 +336,36 @@ type SpanMath = let cols = v.Length let data = Array.zeroCreate<'T> (rows * cols) - for i = 0 to rows - 1 do - let ui = u[i] - for j = 0 to cols - 1 do - let vSpan = v - let simdCols = Numerics.Vector<'T>.Count - let simdCount = cols / simdCols - let ceiling = simdCount * simdCols + if Numerics.Vector.IsHardwareAccelerated && cols >= Numerics.Vector<'T>.Count then + // SIMD-accelerated path + let simdWidth = Numerics.Vector<'T>.Count + let simdCount = cols / simdWidth + let scalarStart = simdCount * simdWidth + + // Cast v to SIMD vectors once + let vVec = MemoryMarshal.Cast<'T, Numerics.Vector<'T>>(v) - let vVec = MemoryMarshal.Cast<'T, Numerics.Vector<'T>>(v) + for i = 0 to rows - 1 do + let rowOffset = i * cols + let rowSpan = data.AsSpan(rowOffset, cols) + let rowVec = MemoryMarshal.Cast<'T, Numerics.Vector<'T>>(rowSpan) + // Broadcast u[i] to a SIMD vector + let uBroadcast = Numerics.Vector<'T>(u[i]) + + // Process SIMD chunks for k = 0 to simdCount - 1 do - let vi = Numerics.Vector<'T>(ui) - let res = vi * vVec[k] - res.CopyTo(MemoryMarshal.CreateSpan(&data.[i * cols + k * simdCols], simdCols)) + rowVec[k] <- uBroadcast * vVec[k] + + // Process scalar tail + for j = scalarStart to cols - 1 do + data[rowOffset + j] <- u[i] * v[j] + else + // Scalar fallback + for i = 0 to rows - 1 do + let ui = u[i] + let rowOffset = i * cols + for j = 0 to cols - 1 do + data[rowOffset + j] <- ui * v[j] (rows, cols, data) diff --git a/tests/FsMath.Tests/FsMath.Tests.fsproj b/tests/FsMath.Tests/FsMath.Tests.fsproj index d6a61c2..9938a15 100644 --- a/tests/FsMath.Tests/FsMath.Tests.fsproj +++ b/tests/FsMath.Tests/FsMath.Tests.fsproj @@ -25,6 +25,7 @@ + diff --git a/tests/FsMath.Tests/MatrixOuterProductTests.fs b/tests/FsMath.Tests/MatrixOuterProductTests.fs new file mode 100644 index 0000000..ee64a91 --- /dev/null +++ b/tests/FsMath.Tests/MatrixOuterProductTests.fs @@ -0,0 +1,66 @@ +module MatrixOuterProductTests + +open Xunit +open FsMath + +[] +let ``Outer product produces correct dimensions`` () = + let u = [| 1.0; 2.0; 3.0 |] + let v = [| 4.0; 5.0 |] + let result = Matrix.outerProduct u v + Assert.Equal(3, result.NumRows) + Assert.Equal(2, result.NumCols) + +[] +let ``Outer product computes correct values`` () = + let u = [| 1.0; 2.0; 3.0 |] + let v = [| 4.0; 5.0 |] + let result = Matrix.outerProduct u v + // Expected: [[1*4, 1*5], [2*4, 2*5], [3*4, 3*5]] + // = [[4, 5], [8, 10], [12, 15]] + Assert.Equal(4.0, result.[0, 0]) + Assert.Equal(5.0, result.[0, 1]) + Assert.Equal(8.0, result.[1, 0]) + Assert.Equal(10.0, result.[1, 1]) + Assert.Equal(12.0, result.[2, 0]) + Assert.Equal(15.0, result.[2, 1]) + +[] +let ``Outer product works with single element vectors`` () = + let u = [| 3.0 |] + let v = [| 7.0 |] + let result = Matrix.outerProduct u v + Assert.Equal(1, result.NumRows) + Assert.Equal(1, result.NumCols) + Assert.Equal(21.0, result.[0, 0]) + +[] +let ``Outer product works with larger vectors`` () = + let u = [| 1.0; 2.0; 3.0; 4.0 |] + let v = [| 10.0; 20.0; 30.0 |] + let result = Matrix.outerProduct u v + Assert.Equal(4, result.NumRows) + Assert.Equal(3, result.NumCols) + // Check a few values + Assert.Equal(10.0, result.[0, 0]) // 1 * 10 + Assert.Equal(20.0, result.[0, 1]) // 1 * 20 + Assert.Equal(30.0, result.[0, 2]) // 1 * 30 + Assert.Equal(30.0, result.[2, 0]) // 3 * 10 + Assert.Equal(80.0, result.[3, 1]) // 4 * 20 + Assert.Equal(120.0, result.[3, 2]) // 4 * 30 + +[] +let ``Outer product with SIMD-friendly size`` () = + // Size 16 ensures we use SIMD path on most systems (Vector.Count is usually 4 or 8) + let u = Array.init 10 (fun i -> float (i + 1)) + let v = Array.init 16 (fun i -> float (i + 1)) + let result = Matrix.outerProduct u v + + Assert.Equal(10, result.NumRows) + Assert.Equal(16, result.NumCols) + + // Verify a few values + Assert.Equal(1.0, result.[0, 0]) // 1 * 1 + Assert.Equal(16.0, result.[0, 15]) // 1 * 16 + Assert.Equal(50.0, result.[4, 9]) // 5 * 10 + Assert.Equal(160.0, result.[9, 15]) // 10 * 16