diff --git a/src/FsMath/SpanMath.fs b/src/FsMath/SpanMath.fs
index 6c8c26e..42302a6 100644
--- a/src/FsMath/SpanMath.fs
+++ b/src/FsMath/SpanMath.fs
@@ -318,8 +318,9 @@ type SpanMath =
// outer product #######
-
+
/// Computes the outer product of two spans.
+ /// Result[i,j] = u[i] * v[j]
static member inline outerProduct<'T
when 'T :> Numerics.INumber<'T>
and 'T : struct
@@ -335,19 +336,36 @@ type SpanMath =
let cols = v.Length
let data = Array.zeroCreate<'T> (rows * cols)
- for i = 0 to rows - 1 do
- let ui = u[i]
- for j = 0 to cols - 1 do
- let vSpan = v
- let simdCols = Numerics.Vector<'T>.Count
- let simdCount = cols / simdCols
- let ceiling = simdCount * simdCols
+ if Numerics.Vector.IsHardwareAccelerated && cols >= Numerics.Vector<'T>.Count then
+ // SIMD-accelerated path
+ let simdWidth = Numerics.Vector<'T>.Count
+ let simdCount = cols / simdWidth
+ let scalarStart = simdCount * simdWidth
+
+ // Cast v to SIMD vectors once
+ let vVec = MemoryMarshal.Cast<'T, Numerics.Vector<'T>>(v)
- let vVec = MemoryMarshal.Cast<'T, Numerics.Vector<'T>>(v)
+ for i = 0 to rows - 1 do
+ let rowOffset = i * cols
+ let rowSpan = data.AsSpan(rowOffset, cols)
+ let rowVec = MemoryMarshal.Cast<'T, Numerics.Vector<'T>>(rowSpan)
+ // Broadcast u[i] to a SIMD vector
+ let uBroadcast = Numerics.Vector<'T>(u[i])
+
+ // Process SIMD chunks
for k = 0 to simdCount - 1 do
- let vi = Numerics.Vector<'T>(ui)
- let res = vi * vVec[k]
- res.CopyTo(MemoryMarshal.CreateSpan(&data.[i * cols + k * simdCols], simdCols))
+ rowVec[k] <- uBroadcast * vVec[k]
+
+ // Process scalar tail
+ for j = scalarStart to cols - 1 do
+ data[rowOffset + j] <- u[i] * v[j]
+ else
+ // Scalar fallback
+ for i = 0 to rows - 1 do
+ let ui = u[i]
+ let rowOffset = i * cols
+ for j = 0 to cols - 1 do
+ data[rowOffset + j] <- ui * v[j]
(rows, cols, data)
diff --git a/tests/FsMath.Tests/FsMath.Tests.fsproj b/tests/FsMath.Tests/FsMath.Tests.fsproj
index d6a61c2..9938a15 100644
--- a/tests/FsMath.Tests/FsMath.Tests.fsproj
+++ b/tests/FsMath.Tests/FsMath.Tests.fsproj
@@ -25,6 +25,7 @@
+
diff --git a/tests/FsMath.Tests/MatrixOuterProductTests.fs b/tests/FsMath.Tests/MatrixOuterProductTests.fs
new file mode 100644
index 0000000..ee64a91
--- /dev/null
+++ b/tests/FsMath.Tests/MatrixOuterProductTests.fs
@@ -0,0 +1,66 @@
+module MatrixOuterProductTests
+
+open Xunit
+open FsMath
+
+[]
+let ``Outer product produces correct dimensions`` () =
+ let u = [| 1.0; 2.0; 3.0 |]
+ let v = [| 4.0; 5.0 |]
+ let result = Matrix.outerProduct u v
+ Assert.Equal(3, result.NumRows)
+ Assert.Equal(2, result.NumCols)
+
+[]
+let ``Outer product computes correct values`` () =
+ let u = [| 1.0; 2.0; 3.0 |]
+ let v = [| 4.0; 5.0 |]
+ let result = Matrix.outerProduct u v
+ // Expected: [[1*4, 1*5], [2*4, 2*5], [3*4, 3*5]]
+ // = [[4, 5], [8, 10], [12, 15]]
+ Assert.Equal(4.0, result.[0, 0])
+ Assert.Equal(5.0, result.[0, 1])
+ Assert.Equal(8.0, result.[1, 0])
+ Assert.Equal(10.0, result.[1, 1])
+ Assert.Equal(12.0, result.[2, 0])
+ Assert.Equal(15.0, result.[2, 1])
+
+[]
+let ``Outer product works with single element vectors`` () =
+ let u = [| 3.0 |]
+ let v = [| 7.0 |]
+ let result = Matrix.outerProduct u v
+ Assert.Equal(1, result.NumRows)
+ Assert.Equal(1, result.NumCols)
+ Assert.Equal(21.0, result.[0, 0])
+
+[]
+let ``Outer product works with larger vectors`` () =
+ let u = [| 1.0; 2.0; 3.0; 4.0 |]
+ let v = [| 10.0; 20.0; 30.0 |]
+ let result = Matrix.outerProduct u v
+ Assert.Equal(4, result.NumRows)
+ Assert.Equal(3, result.NumCols)
+ // Check a few values
+ Assert.Equal(10.0, result.[0, 0]) // 1 * 10
+ Assert.Equal(20.0, result.[0, 1]) // 1 * 20
+ Assert.Equal(30.0, result.[0, 2]) // 1 * 30
+ Assert.Equal(30.0, result.[2, 0]) // 3 * 10
+ Assert.Equal(80.0, result.[3, 1]) // 4 * 20
+ Assert.Equal(120.0, result.[3, 2]) // 4 * 30
+
+[]
+let ``Outer product with SIMD-friendly size`` () =
+ // Size 16 ensures we use SIMD path on most systems (Vector.Count is usually 4 or 8)
+ let u = Array.init 10 (fun i -> float (i + 1))
+ let v = Array.init 16 (fun i -> float (i + 1))
+ let result = Matrix.outerProduct u v
+
+ Assert.Equal(10, result.NumRows)
+ Assert.Equal(16, result.NumCols)
+
+ // Verify a few values
+ Assert.Equal(1.0, result.[0, 0]) // 1 * 1
+ Assert.Equal(16.0, result.[0, 15]) // 1 * 16
+ Assert.Equal(50.0, result.[4, 9]) // 5 * 10
+ Assert.Equal(160.0, result.[9, 15]) // 10 * 16