Skip to content

Commit

Permalink
transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
Quafadas committed Sep 12, 2024
1 parent 4b7b98b commit 479dde3
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
23 changes: 23 additions & 0 deletions vecxt/src/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ object Tensors:

extension (m: Matrix)
inline def rows: Int = m._2._1

inline def cols: Int = m._2._2

inline def row(i: Int): NArray[Double] =
val start = i * m.cols
val end = (i + 1) * m.cols
Expand All @@ -83,6 +85,27 @@ object Tensors:
end while
result
end row

inline def print: String =
val arrArr = for i <- 0 until m.rows yield m.row(i).mkString(" ")
arrArr.mkString("\n")
end print

inline def col(i: Int): NArray[Double] = NArray.tabulate(m.rows)(j => m._1(j * m.cols + i))

inline def transpose: Matrix =
val newArr = NArray.ofSize[Double](m._1.length)
var i = 0
while i < m.cols do
var j = 0
while j < m.rows do
newArr(i * m.rows + j) = m._1(j * m.cols + i)
j += 1
end while
i += 1
end while
Matrix(newArr, (m.cols, m.rows))
end transpose
end extension

end Tensors
16 changes: 16 additions & 0 deletions vecxt/test/src/TensorExt.test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@ class TensorExtensionSuite extends FunSuite:
assertVecEquals(result.raw, NArray(58.0, 139.0, 64.0, 154.0))
}

test("Matrix transpose") {
val originalArray = NArray[Double](1, 2, 3, 4, 5, 6)
val matrix = Matrix(originalArray, (2, 3))

println(matrix.print)

val transposedMatrix = matrix.transpose

val expectedArray = NArray[Double](1, 4, 2, 5, 3, 6)
val expectedMatrix = Matrix(expectedArray, (3, 2))

assertEquals(transposedMatrix.raw.toList, expectedMatrix.raw.toList)
assertEquals(transposedMatrix.rows, expectedMatrix.rows)
assertEquals(transposedMatrix.cols, expectedMatrix.cols)
}

test("Tensor raw array retrieval") {
val vec = Vector(NArray[Double](1.0, 2.0, 3.0))
assertVecEquals(vec.raw, NArray(1.0, 2.0, 3.0))
Expand Down

0 comments on commit 479dde3

Please sign in to comment.