Skip to content

Commit

Permalink
add comparison operators
Browse files Browse the repository at this point in the history
  • Loading branch information
Quafadas committed Sep 16, 2024
1 parent 76447e5 commit 8c288ea
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 2 deletions.
8 changes: 8 additions & 0 deletions vecxt/js-native/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// package vecxt

// import vecxt.MatrixStuff.Matrix
// import scala.reflect.ClassTag

// extension (m: Matrix[Double])

// end extension
12 changes: 12 additions & 0 deletions vecxt/js/src/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,18 @@ object extensions:

inline def norm: Double = blas.dnrm2(vec.length, vec, 1)

inline def lt(num: Double): NArray[Boolean] =
logicalIdx((a, b) => a < b, num)

inline def gt(num: Double): NArray[Boolean] =
logicalIdx((a, b) => a > b, num)

inline def lte(num: Double): NArray[Boolean] =
logicalIdx((a, b) => a <= b, num)

inline def gte(num: Double): NArray[Boolean] =
logicalIdx((a, b) => a >= b, num)

inline def -(vec2: Float64Array)(using inline boundsCheck: BoundsCheck.BoundsCheck): Float64Array =
dimCheck(vec, vec2)
vec.nativeSlice().tap(_ -= vec2)
Expand Down
15 changes: 14 additions & 1 deletion vecxt/jvm/src/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package vecxt
import dev.ludovic.netlib.blas.JavaBLAS.getInstance as blas
import scala.util.chaining.*
import vecxt.MatrixStuff.*
import narr.*

object extensions:

Expand Down Expand Up @@ -68,7 +69,7 @@ object extensions:
end ||
end extension

extension (vec: Array[Double])
extension (vec: NArray[Double])

inline def idxBoolean(index: Array[Boolean])(using inline boundsCheck: BoundsCheck) =
dimCheck(vec, index)
Expand Down Expand Up @@ -242,6 +243,18 @@ object extensions:
vec.clone.tap(_ /= d)
end /

inline def lt(num: Double): Array[Boolean] =
logicalIdx((a, b) => a < b, num)

inline def gt(num: Double): Array[Boolean] =
logicalIdx((a, b) => a > b, num)

inline def lte(num: Double): Array[Boolean] =
logicalIdx((a, b) => a <= b, num)

inline def gte(num: Double): Array[Boolean] =
logicalIdx((a, b) => a >= b, num)

inline def <(num: Double): Array[Boolean] =
logicalIdx((a, b) => a < b, num)

Expand Down
12 changes: 12 additions & 0 deletions vecxt/native/src/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,18 @@ object extensions:
blas.cblas_ddot(vec.length, vec.at(0), 1, v1.at(0), 1)
end dot

inline def lt(num: Double): Array[Boolean] =
logicalIdx((a, b) => a < b, num)

inline def gt(num: Double): Array[Boolean] =
logicalIdx((a, b) => a > b, num)

inline def lte(num: Double): Array[Boolean] =
logicalIdx((a, b) => a <= b, num)

inline def gte(num: Double): Array[Boolean] =
logicalIdx((a, b) => a >= b, num)

inline def -(vec2: Array[Double])(using inline boundsCheck: BoundsCheck.BoundsCheck) =
dimCheck(vec, vec2)
vec.clone.tap(_ -= vec2)
Expand Down
15 changes: 14 additions & 1 deletion vecxt/src/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ object MatrixStuff:

inline def raw: NArray[A] = m._1

inline def shape: String = s"${m.rows} x ${m.cols}"
inline def shape: String = s"${m.rows} x ${m.cols}"
end extension

extension (m: Matrix[Double])
Expand All @@ -194,6 +194,19 @@ object MatrixStuff:
inline def *=(d: Double): Unit = m._1.multInPlace(d)

inline def @@(b: Matrix[Double])(using inline boundsCheck: BoundsCheck): Matrix[Double] = m.matmul(b)

inline def >=( d: Double) : Matrix[Boolean] =
Matrix[(Int, Int), Boolean](m.raw.gte(d), m._2)(using BoundsCheck.DoBoundsCheck.no)

inline def >( d: Double) : Matrix[Boolean] =
Matrix[(Int, Int), Boolean](m.raw.gt(d), m._2)(using BoundsCheck.DoBoundsCheck.no)

inline def <=( d: Double) : Matrix[Boolean]=
Matrix[(Int, Int), Boolean](m.raw.lte(d), m._2)(using BoundsCheck.DoBoundsCheck.no)

inline def <( d: Double) : Matrix[Boolean]=
Matrix[(Int, Int), Boolean](m.raw.lt(d), m._2)(using BoundsCheck.DoBoundsCheck.no)

end extension

extension [@specialized(Double, Boolean) A](m: Matrix[A])
Expand Down
23 changes: 23 additions & 0 deletions vecxt/test/src/MatrixBoolean.test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,27 @@ class MatrixBooleanSuite extends FunSuite:

}

test("logicals") {
val mat = Matrix.fromRows[Double](
NArray(
NArray[Double](1.0, 2.0),
NArray[Double](3.0, 4.0),
)
)

val lt = mat < 2.0
assertVecEquals(NArray(true, false, false, false), lt.raw)

val gt = mat > 2.0
assertVecEquals(NArray(false, true, false, true), gt.raw)

val gte = mat >= 2.0
assertVecEquals(NArray(false, true, true, true), gte.raw)

val lte = mat <= 2.0
assertVecEquals(NArray(true, false, true, false), lte.raw)

}


end MatrixBooleanSuite

0 comments on commit 8c288ea

Please sign in to comment.