From 8c288ea54b690b104e1f90467d6ea504a553effe Mon Sep 17 00:00:00 2001 From: Simon Parten Date: Mon, 16 Sep 2024 14:54:26 +0200 Subject: [PATCH] add comparison operators --- vecxt/js-native/package.scala | 8 ++++++++ vecxt/js/src/package.scala | 12 ++++++++++++ vecxt/jvm/src/package.scala | 15 ++++++++++++++- vecxt/native/src/package.scala | 12 ++++++++++++ vecxt/src/package.scala | 15 ++++++++++++++- vecxt/test/src/MatrixBoolean.test.scala | 23 +++++++++++++++++++++++ 6 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 vecxt/js-native/package.scala diff --git a/vecxt/js-native/package.scala b/vecxt/js-native/package.scala new file mode 100644 index 0000000..1736abe --- /dev/null +++ b/vecxt/js-native/package.scala @@ -0,0 +1,8 @@ +// package vecxt + +// import vecxt.MatrixStuff.Matrix +// import scala.reflect.ClassTag + +// extension (m: Matrix[Double]) + +// end extension diff --git a/vecxt/js/src/package.scala b/vecxt/js/src/package.scala index 8e17052..d8fbedb 100644 --- a/vecxt/js/src/package.scala +++ b/vecxt/js/src/package.scala @@ -249,6 +249,18 @@ object extensions: inline def norm: Double = blas.dnrm2(vec.length, vec, 1) + inline def lt(num: Double): NArray[Boolean] = + logicalIdx((a, b) => a < b, num) + + inline def gt(num: Double): NArray[Boolean] = + logicalIdx((a, b) => a > b, num) + + inline def lte(num: Double): NArray[Boolean] = + logicalIdx((a, b) => a <= b, num) + + inline def gte(num: Double): NArray[Boolean] = + logicalIdx((a, b) => a >= b, num) + inline def -(vec2: Float64Array)(using inline boundsCheck: BoundsCheck.BoundsCheck): Float64Array = dimCheck(vec, vec2) vec.nativeSlice().tap(_ -= vec2) diff --git a/vecxt/jvm/src/package.scala b/vecxt/jvm/src/package.scala index b8f815e..775c562 100644 --- a/vecxt/jvm/src/package.scala +++ b/vecxt/jvm/src/package.scala @@ -18,6 +18,7 @@ package vecxt import dev.ludovic.netlib.blas.JavaBLAS.getInstance as blas import scala.util.chaining.* import vecxt.MatrixStuff.* +import narr.* object extensions: @@ -68,7 +69,7 @@ object extensions: end || end extension - extension (vec: Array[Double]) + extension (vec: NArray[Double]) inline def idxBoolean(index: Array[Boolean])(using inline boundsCheck: BoundsCheck) = dimCheck(vec, index) @@ -242,6 +243,18 @@ object extensions: vec.clone.tap(_ /= d) end / + inline def lt(num: Double): Array[Boolean] = + logicalIdx((a, b) => a < b, num) + + inline def gt(num: Double): Array[Boolean] = + logicalIdx((a, b) => a > b, num) + + inline def lte(num: Double): Array[Boolean] = + logicalIdx((a, b) => a <= b, num) + + inline def gte(num: Double): Array[Boolean] = + logicalIdx((a, b) => a >= b, num) + inline def <(num: Double): Array[Boolean] = logicalIdx((a, b) => a < b, num) diff --git a/vecxt/native/src/package.scala b/vecxt/native/src/package.scala index 99aac82..d154bbb 100644 --- a/vecxt/native/src/package.scala +++ b/vecxt/native/src/package.scala @@ -218,6 +218,18 @@ object extensions: blas.cblas_ddot(vec.length, vec.at(0), 1, v1.at(0), 1) end dot + inline def lt(num: Double): Array[Boolean] = + logicalIdx((a, b) => a < b, num) + + inline def gt(num: Double): Array[Boolean] = + logicalIdx((a, b) => a > b, num) + + inline def lte(num: Double): Array[Boolean] = + logicalIdx((a, b) => a <= b, num) + + inline def gte(num: Double): Array[Boolean] = + logicalIdx((a, b) => a >= b, num) + inline def -(vec2: Array[Double])(using inline boundsCheck: BoundsCheck.BoundsCheck) = dimCheck(vec, vec2) vec.clone.tap(_ -= vec2) diff --git a/vecxt/src/package.scala b/vecxt/src/package.scala index 5d7ff81..1547369 100644 --- a/vecxt/src/package.scala +++ b/vecxt/src/package.scala @@ -180,7 +180,7 @@ object MatrixStuff: inline def raw: NArray[A] = m._1 - inline def shape: String = s"${m.rows} x ${m.cols}" + inline def shape: String = s"${m.rows} x ${m.cols}" end extension extension (m: Matrix[Double]) @@ -194,6 +194,19 @@ object MatrixStuff: inline def *=(d: Double): Unit = m._1.multInPlace(d) inline def @@(b: Matrix[Double])(using inline boundsCheck: BoundsCheck): Matrix[Double] = m.matmul(b) + + inline def >=( d: Double) : Matrix[Boolean] = + Matrix[(Int, Int), Boolean](m.raw.gte(d), m._2)(using BoundsCheck.DoBoundsCheck.no) + + inline def >( d: Double) : Matrix[Boolean] = + Matrix[(Int, Int), Boolean](m.raw.gt(d), m._2)(using BoundsCheck.DoBoundsCheck.no) + + inline def <=( d: Double) : Matrix[Boolean]= + Matrix[(Int, Int), Boolean](m.raw.lte(d), m._2)(using BoundsCheck.DoBoundsCheck.no) + + inline def <( d: Double) : Matrix[Boolean]= + Matrix[(Int, Int), Boolean](m.raw.lt(d), m._2)(using BoundsCheck.DoBoundsCheck.no) + end extension extension [@specialized(Double, Boolean) A](m: Matrix[A]) diff --git a/vecxt/test/src/MatrixBoolean.test.scala b/vecxt/test/src/MatrixBoolean.test.scala index 0dda2e0..ee9cee7 100644 --- a/vecxt/test/src/MatrixBoolean.test.scala +++ b/vecxt/test/src/MatrixBoolean.test.scala @@ -48,4 +48,27 @@ class MatrixBooleanSuite extends FunSuite: } + test("logicals") { + val mat = Matrix.fromRows[Double]( + NArray( + NArray[Double](1.0, 2.0), + NArray[Double](3.0, 4.0), + ) + ) + + val lt = mat < 2.0 + assertVecEquals(NArray(true, false, false, false), lt.raw) + + val gt = mat > 2.0 + assertVecEquals(NArray(false, true, false, true), gt.raw) + + val gte = mat >= 2.0 + assertVecEquals(NArray(false, true, true, true), gte.raw) + + val lte = mat <= 2.0 + assertVecEquals(NArray(true, false, true, false), lte.raw) + + } + + end MatrixBooleanSuite