Skip to content

Commit

Permalink
symbolic matrix multplication
Browse files Browse the repository at this point in the history
  • Loading branch information
Quafadas committed Sep 13, 2024
1 parent 479dde3 commit 8535a18
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
2 changes: 2 additions & 0 deletions vecxt/src/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ object Tensors:
end extension

extension (m: Matrix)
inline def :@(b: Matrix)(using inline boundsCheck: BoundsCheck): Matrix = m.matmul(b)

inline def rows: Int = m._2._1

inline def cols: Int = m._2._2
Expand Down
4 changes: 4 additions & 0 deletions vecxt/test/src/TensorExt.test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ class TensorExtensionSuite extends FunSuite:
test("Matrix multiplication") {
val mat1 = Matrix(NArray(0.0, 0.0, 1.0, 0.0), (2, 2))
val mat2 = Matrix(NArray(0.0, 1.0, 0.0, 0.0), (2, 2))

val result2 = mat1 :@ mat2

val result = mat1.matmul(mat2)
assertVecEquals(result.raw, NArray(1.0, 0.0, 0.0, 0.0))
assertVecEquals(result2.raw, NArray(1.0, 0.0, 0.0, 0.0))
}

test("Matrix multiplication2") {
Expand Down

0 comments on commit 8535a18

Please sign in to comment.