From 479dde385222e6dafb50d143c3f52482973df857 Mon Sep 17 00:00:00 2001 From: Simon Parten Date: Thu, 12 Sep 2024 23:17:49 +0200 Subject: [PATCH] transpose --- vecxt/src/package.scala | 23 +++++++++++++++++++++++ vecxt/test/src/TensorExt.test.scala | 16 ++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/vecxt/src/package.scala b/vecxt/src/package.scala index 79db1c4..0dea622 100644 --- a/vecxt/src/package.scala +++ b/vecxt/src/package.scala @@ -69,7 +69,9 @@ object Tensors: extension (m: Matrix) inline def rows: Int = m._2._1 + inline def cols: Int = m._2._2 + inline def row(i: Int): NArray[Double] = val start = i * m.cols val end = (i + 1) * m.cols @@ -83,6 +85,27 @@ object Tensors: end while result end row + + inline def print: String = + val arrArr = for i <- 0 until m.rows yield m.row(i).mkString(" ") + arrArr.mkString("\n") + end print + inline def col(i: Int): NArray[Double] = NArray.tabulate(m.rows)(j => m._1(j * m.cols + i)) + + inline def transpose: Matrix = + val newArr = NArray.ofSize[Double](m._1.length) + var i = 0 + while i < m.cols do + var j = 0 + while j < m.rows do + newArr(i * m.rows + j) = m._1(j * m.cols + i) + j += 1 + end while + i += 1 + end while + Matrix(newArr, (m.cols, m.rows)) + end transpose end extension + end Tensors diff --git a/vecxt/test/src/TensorExt.test.scala b/vecxt/test/src/TensorExt.test.scala index aae3876..44a51f5 100644 --- a/vecxt/test/src/TensorExt.test.scala +++ b/vecxt/test/src/TensorExt.test.scala @@ -24,6 +24,22 @@ class TensorExtensionSuite extends FunSuite: assertVecEquals(result.raw, NArray(58.0, 139.0, 64.0, 154.0)) } + test("Matrix transpose") { + val originalArray = NArray[Double](1, 2, 3, 4, 5, 6) + val matrix = Matrix(originalArray, (2, 3)) + + println(matrix.print) + + val transposedMatrix = matrix.transpose + + val expectedArray = NArray[Double](1, 4, 2, 5, 3, 6) + val expectedMatrix = Matrix(expectedArray, (3, 2)) + + assertEquals(transposedMatrix.raw.toList, expectedMatrix.raw.toList) + assertEquals(transposedMatrix.rows, expectedMatrix.rows) + assertEquals(transposedMatrix.cols, expectedMatrix.cols) + } + test("Tensor raw array retrieval") { val vec = Vector(NArray[Double](1.0, 2.0, 3.0)) assertVecEquals(vec.raw, NArray(1.0, 2.0, 3.0))