diff --git a/vecxt/src/package.scala b/vecxt/src/package.scala index 0dea622..a408939 100644 --- a/vecxt/src/package.scala +++ b/vecxt/src/package.scala @@ -68,6 +68,8 @@ object Tensors: end extension extension (m: Matrix) + inline def :@(b: Matrix)(using inline boundsCheck: BoundsCheck): Matrix = m.matmul(b) + inline def rows: Int = m._2._1 inline def cols: Int = m._2._2 diff --git a/vecxt/test/src/TensorExt.test.scala b/vecxt/test/src/TensorExt.test.scala index 44a51f5..74d309f 100644 --- a/vecxt/test/src/TensorExt.test.scala +++ b/vecxt/test/src/TensorExt.test.scala @@ -11,8 +11,12 @@ class TensorExtensionSuite extends FunSuite: test("Matrix multiplication") { val mat1 = Matrix(NArray(0.0, 0.0, 1.0, 0.0), (2, 2)) val mat2 = Matrix(NArray(0.0, 1.0, 0.0, 0.0), (2, 2)) + + val result2 = mat1 :@ mat2 + val result = mat1.matmul(mat2) assertVecEquals(result.raw, NArray(1.0, 0.0, 0.0, 0.0)) + assertVecEquals(result2.raw, NArray(1.0, 0.0, 0.0, 0.0)) } test("Matrix multiplication2") {