Skip to content

Commit

Permalink
Slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
Quafadas committed Sep 14, 2024
1 parent 0591da1 commit e8c7f09
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 2 deletions.
29 changes: 29 additions & 0 deletions site/docs/_docs/matrix.mdoc.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,35 @@ result2.print

val mat3 = mat2.transpose + mat1

```

## Slicing

Index via a `Int`, `NArray[Int]` or a `Range` to slice a matrix. The `::` operator is used to select all elements in a dimension.

```scala mdoc:to-string
import vecxt.Matrix.*
import vecxt.BoundsCheck.DoBoundsCheck.yes
import narr.*
import vecxt.extensions.*

val mat = Matrix.fromRows(
NArray(
NArray[Double](1.0, 2.0, 3.0),
NArray[Double](4.0, 5.0, 6.0),
NArray[Double](7.0, 8.0, 9.0)
)
)
mat(::, ::)
mat(1, ::)

mat(::, 1)

mat(1, 1)

mat(0 to 1, 0 to 1)

mat(NArray.from[Int](Array(0, 2)), 0 to 1)


```
2 changes: 1 addition & 1 deletion vecxt/js/src/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ object extensions:
// end copy
end extension

extension (vec: Float64Array)
extension (vec: NArray[Double])

inline def idxBoolean(index: js.Array[Boolean])(using inline boundsCheck: BoundsCheck.BoundsCheck) =
dimCheck(vec, index)
Expand Down
42 changes: 41 additions & 1 deletion vecxt/src/package.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package vecxt

import narr.*
import javax.print.attribute.standard.MediaSize.NA

object Matrix:

Expand All @@ -25,6 +26,8 @@ object Matrix:

opaque type Matrix = (NArray[Double], Tuple2[Int, Int])

type RangeExtender = Range | Int | NArray[Int] | ::.type

// type Matrix = Matrix1 & Tensor

object Matrix:
Expand Down Expand Up @@ -63,6 +66,12 @@ object Matrix:
Matrix(newArr, (rows, cols))
end fromRows

inline def ones(dim: Tuple2[Int, Int]): Matrix =
val (rows, cols) = dim
val newArr = NArray.fill[Double](rows * cols)(1.0)
Matrix(newArr, dim)(using BoundsCheck.DoBoundsCheck.no)
end ones

inline def zeros(dim: Tuple2[Int, Int]): Matrix =
val (rows, cols) = dim
val newArr = NArray.ofSize[Double](rows * cols)
Expand Down Expand Up @@ -114,10 +123,41 @@ object Matrix:
// end apply
// end StrictMatrix

extension (d: Array[Double]) def print: String = d.mkString("[", ",", "],")
extension [A](d: Array[A]) def print: String = d.mkString("[", ",", "],")

extension (m: Matrix)

private inline def range(r: RangeExtender, max: Int): NArray[Int] = r match
case _: ::.type => NArray.from((0 until max).toArray)
case r: Range => NArray.from(r.toArray)
case l: NArray[Int] => l
case i: Int => NArray(i)

def apply(rowRange: RangeExtender, colRange: RangeExtender): Matrix =
val newRows = range(rowRange, m.rows)
val newCols = range(colRange, m.cols)
val newArr = NArray.ofSize[Double](newCols.size * newRows.size)

var idx = 0

var i = 0
while i < newCols.length do
val oldCol = newCols(i)
val stride = oldCol * m.cols
var j = 0
while j < newRows.length do
val oldRow = newRows(j)
newArr(idx) = m._1(stride + oldRow)
idx += 1
j += 1
end while
i += 1
end while

Matrix(newArr, (newRows.size, newCols.size))(using BoundsCheck.DoBoundsCheck.no)

end apply

inline def raw: NArray[Double] = m._1

/** Zero indexed element retrieval
Expand Down
33 changes: 33 additions & 0 deletions vecxt/test/src/TensorExt.test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ class TensorExtensionSuite extends FunSuite:
assertVecEquals(tensor.raw, NArray[Double](1.0, 0.0, 0.0, 1.0))
}

test("ones") {
val tensor = Matrix.ones((2, 2))
assertVecEquals(tensor.raw, NArray[Double](1.0, 1.0, 1.0, 1.0))
}

test("Matrix creation from nested NArray") {
val nestedArr = NArray(
NArray[Double](1.0, 2.0, 3.5), // col 1
Expand Down Expand Up @@ -205,4 +210,32 @@ class TensorExtensionSuite extends FunSuite:
)
}

test("slice syntax") {
val mat = Matrix.fromRows(
NArray(
NArray[Double](1.0, 2.0, 3.0),
NArray[Double](4.0, 5.0, 6.0),
NArray[Double](7.0, 8.0, 9.0)
)
)
val a = mat(::, ::)
assertVecEquals(a.raw, mat.raw)

val b = mat(1, ::)
assertVecEquals(NArray[Double](4.0, 5.0, 6.0), b.raw)

val c = mat(::, 1)
assertVecEquals(NArray[Double](2.0, 5.0, 8.0), c.raw)

val d = mat(1, 1)
assertVecEquals(NArray[Double](5.0), d.raw)

val e = mat(0 to 1, 0 to 1)
assertVecEquals(NArray[Double](1.0, 4.0, 2.0, 5.0), e.raw)

val f = mat(NArray.from[Int](Array(0, 2)), 0 to 1)
assertVecEquals(NArray[Double](1.0, 7.0, 2.0, 8.0), f.raw)

}

end TensorExtensionSuite

0 comments on commit e8c7f09

Please sign in to comment.