Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Quafadas committed Sep 13, 2024
1 parent 686eff5 commit 998d6c5
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 10 deletions.
12 changes: 12 additions & 0 deletions vecxt/src/dimMatCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,19 @@ protected[vecxt] object dimMatCheck:
inline if doCheck then if a.cols != b.rows then throw MatrixDimensionMismatch(a.rows, a.cols, b.rows, b.cols)
end dimMatCheck

protected[vecxt] object dimMatInstantiateCheck:
inline def apply(a: Matrix)(using inline doCheck: BoundsCheck) =
inline if doCheck then
if a.cols * a.rows != a.raw.size
then throw InvalidMatrix(a.rows, a.cols, a.raw.size)
end dimMatInstantiateCheck

case class MatrixDimensionMismatch(aCols: Int, aRows: Int, bCols: Int, bRows: Int)
extends Exception(
s"Matrix dimensions do not match. Matrix A : ($aRows, $aCols), Matrix B : ($bRows, $bCols)"
)

case class InvalidMatrix(cols: Int, rows: Int, data: Int)
extends Exception(
s"Matrix dimensions do not match. Matrix A : ($cols, $rows), is provided with data of length $data"
)
26 changes: 16 additions & 10 deletions vecxt/src/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,21 @@ object Tensors:
type Matrix = Matrix1 & Tensor

object Matrix:
def apply[T <: Tuple2[Int, Int]](a: NArray[Double], b: T)(using ev: TupleOfInts[T] =:= true): Matrix = (a, b)
def apply[T <: Tuple2[Int, Int]](b: T, a: NArray[Double])(using ev: TupleOfInts[T] =:= true): Matrix = (a, b)

def fromRows(a: NArray[NArray[Double]]): Matrix =
inline def apply[T <: Tuple2[Int, Int]](a: NArray[Double], b: T)(using inline boundsCheck: BoundsCheck)(using
ev: TupleOfInts[T] =:= true
): Matrix =
(a, b)
inline def apply[T <: Tuple2[Int, Int]](b: T, a: NArray[Double])(using inline boundsCheck: BoundsCheck)(using
ev: TupleOfInts[T] =:= true
): Matrix =
(a, b)

inline def fromRows(a: NArray[NArray[Double]])(using inline boundsCheck: BoundsCheck): Matrix =
val rows = a.size
val cols = a(0).size

assert(a.forall(_.size == cols))

println("cols :" + cols)
println("rows : " + rows)
val newArr = NArray.ofSize[Double](rows * cols)
var idx = 0
var i = 0
Expand All @@ -53,12 +57,10 @@ object Tensors:
Matrix(newArr, (rows, cols))
end fromRows

def fromColumns(a: NArray[NArray[Double]]): Matrix =
inline def fromColumns(a: NArray[NArray[Double]])(using inline boundsCheck: BoundsCheck): Matrix =
val cols = a.size
val rows = a(0).size
assert(a.forall(_.size == rows))
println("cols :" + cols)
println("rows : " + rows)
val newArr = NArray.ofSize[Double](rows * cols)
var idx = 0
var i = 0
Expand Down Expand Up @@ -117,6 +119,8 @@ object Tensors:
extension (m: Matrix)
inline def :@(b: Matrix)(using inline boundsCheck: BoundsCheck): Matrix = m.matmul(b)

inline def scale(d: Double): Unit = m._1 *= d

inline def rows: Int = m._2._1

inline def cols: Int = m._2._2
Expand Down Expand Up @@ -163,7 +167,9 @@ object Tensors:
end while
i += 1
end while
Matrix(newArr, (m.cols, m.rows))
Matrix(newArr, (m.cols, m.rows))(using
BoundsCheck.DoBoundsCheck.no
) // we already have a valid matrix if we are transposing it, so this check is redundant if this method works as intended.
end transpose
end extension

Expand Down
16 changes: 16 additions & 0 deletions vecxt/test/src/TensorExt.test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,20 @@ class TensorExtensionSuite extends FunSuite:
assertVecEquals(col, NArray[Double](3.0, 4.0))

}

test("That we have to provide ints for dimensions") {
// We should allow non-integer values in the dimensions of the Tuple.
val code = "Matrix((1, 1.0), NArray.ofSize[Double](6))"
val code2 = """Matrix((1, "s"), NArray.ofSize[Double](6))"""
compileErrors(code)
compileErrors(code2)
}

test("matrix scale") {
val array = NArray[Double](1.0, 2.0, 3.0, 4.0)
val matrix = Tensors.Matrix(array, (2, 2))
val col1 = matrix.scale(2)
assertVecEquals(matrix.raw, NArray[Double](2.0, 4.0, 6.0, 8.0))
}

end TensorExtensionSuite

0 comments on commit 998d6c5

Please sign in to comment.