From 42d90e6093f46c858144e64485a5263da7db45a1 Mon Sep 17 00:00:00 2001 From: scorebot Date: Fri, 4 Oct 2024 21:25:09 -0700 Subject: [PATCH] Refactor the code further --- README.md | 2 +- .../scala/org/pmml4s/common/predicates.scala | 118 ++++++++++-------- .../scala/org/pmml4s/model/MiningModel.scala | 2 +- .../scala/org/pmml4s/model/TreeModel.scala | 30 +++-- src/main/scala/org/pmml4s/util/Utils.scala | 18 ++- src/main/scala/org/pmml4s/xml/Builder.scala | 8 +- 6 files changed, 107 insertions(+), 71 deletions(-) diff --git a/README.md b/README.md index 49adf56..4c0dbc6 100644 --- a/README.md +++ b/README.md @@ -162,7 +162,7 @@ _PMML4S_ is really easy to use. Just do one or more of the following: val row = Map("sepal_length" -> "5.1", "sepal_width" -> "3.5", "petal_length" -> "1.4", "petal_width" -> "0.2") // You need to convert the data to the desired type defined by PMML, and keep the same order as defined in the input schema. - val values = inputSchema.map(x => Utils.toVal(row(x.name), x.dataType)) + val values = inputSchema.map(x => Utils.toDataVal(row(x.name), x.dataType)) scala> val result = model.predict(Series.fromSeq(values)) result: org.pmml4s.data.Series = [Iris-setosa,1.0,1.0,0.0,0.0,1],[(predicted_class,string),(probability,double),(probability_Iris-setosa,double),(probability_Iris-versicolor,double),(probability_Iris-virginica,double),(node_id,string)] diff --git a/src/main/scala/org/pmml4s/common/predicates.scala b/src/main/scala/org/pmml4s/common/predicates.scala index d480c09..7c53073 100644 --- a/src/main/scala/org/pmml4s/common/predicates.scala +++ b/src/main/scala/org/pmml4s/common/predicates.scala @@ -18,7 +18,6 @@ package org.pmml4s.common import org.pmml4s.common.Operator.Operator import org.pmml4s.data.Series import org.pmml4s.metadata.Field -import org.pmml4s.util.Utils import org.pmml4s.xml.ElemTags object Predication extends Enumeration { @@ -99,11 +98,11 @@ class SimplePredicate( val missing = (v != v) operator match { case `lessOrEqual` => if (missing) UNKNOWN else if (v <= value) TRUE else FALSE - case `equal` => if (missing) UNKNOWN else if (v == value) TRUE else FALSE - case `notEqual` => if (missing) UNKNOWN else if (v != value) TRUE else FALSE case `lessThan` => if (missing) UNKNOWN else if (v < value) TRUE else FALSE case `greaterThan` => if (missing) UNKNOWN else if (v > value) TRUE else FALSE case `greaterOrEqual` => if (missing) UNKNOWN else if (v >= value) TRUE else FALSE + case `equal` => if (missing) UNKNOWN else if (v == value) TRUE else FALSE + case `notEqual` => if (missing) UNKNOWN else if (v != value) TRUE else FALSE case `isMissing` => if (missing) TRUE else FALSE case `isNotMissing` => if (!missing) TRUE else FALSE } @@ -121,56 +120,72 @@ class CompoundPredicate( import CompoundPredicate.BooleanOperator._ - def eval(input: Series): Predication = booleanOperator match { - case `or` => { - var hasUnknown = false - for (child <- children) { - val r = child.eval(input) - if (r == TRUE) - return TRUE - else if (r == UNKNOWN) - hasUnknown = true + def eval(input: Series): Predication = { + val len = children.length + var i = 0 + booleanOperator match { + case `or` => { + var hasUnknown = false + while (i < len) { + val child = children(i) + val r = child.eval(input) + if (r == TRUE) + return TRUE + else if (r == UNKNOWN) + hasUnknown = true + + i += 1 + } + + if (hasUnknown) UNKNOWN else FALSE } - - if (hasUnknown) UNKNOWN else FALSE - } - case `and` => { - var hasUnknown = false - for (child <- children) { - val r = child.eval(input) - if (r == FALSE) - return FALSE - else if (r == UNKNOWN) - hasUnknown = true + case `and` => { + var hasUnknown = false + while (i < len) { + val child = children(i) + val r = child.eval(input) + if (r == FALSE) + return FALSE + else if (r == UNKNOWN) + hasUnknown = true + + i += 1 + } + + if (hasUnknown) UNKNOWN else TRUE } - - if (hasUnknown) UNKNOWN else TRUE - } - case `xor` => { - var count = 0 - for (child <- children) { - val r = child.eval(input) - if (r == UNKNOWN) - return UNKNOWN - else if (r == TRUE) - count += 1 + case `xor` => { + var count = 0 + while (9 < len) { + val child = children(i) + val r = child.eval(input) + if (r == UNKNOWN) + return UNKNOWN + else if (r == TRUE) + count += 1 + + i += 1 + } + + if (count % 2 == 1) TRUE else FALSE } - - if (count % 2 == 1) TRUE else FALSE - } - case `surrogate` => { - var isSurrogate = false - for (child <- children) { - val r = child.eval(input) - if (r != UNKNOWN) - return if (r == TRUE) { - if (isSurrogate) SURROGATE else TRUE - } else r - else - isSurrogate = true + case `surrogate` => { + var isSurrogate = false + while (i < len) { + val child = children(i) + val r = child.eval(input) + if (r != UNKNOWN) + return if (r == TRUE) { + if (isSurrogate) SURROGATE else TRUE + } else r + else + isSurrogate = true + + i += 1 + } + + UNKNOWN } - - UNKNOWN } } } @@ -188,9 +203,10 @@ class SimpleSetPredicate( def eval(input: Series): Predication = { val v = field.encode(input) + val missing = (v != v) booleanOperator match { - case `isIn` => if (Utils.isMissing(v)) UNKNOWN else if (values.contains(v)) TRUE else FALSE - case `isNotIn` => if (Utils.isMissing(v)) UNKNOWN else if (!values.contains(v)) TRUE else FALSE + case `isIn` => if (missing) UNKNOWN else if (values.contains(v)) TRUE else FALSE + case `isNotIn` => if (missing) UNKNOWN else if (!values.contains(v)) TRUE else FALSE } } } diff --git a/src/main/scala/org/pmml4s/model/MiningModel.scala b/src/main/scala/org/pmml4s/model/MiningModel.scala index bbdf521..3d4ecb7 100644 --- a/src/main/scala/org/pmml4s/model/MiningModel.scala +++ b/src/main/scala/org/pmml4s/model/MiningModel.scala @@ -141,7 +141,7 @@ class MiningModel( x._2.feature match { case ResultFeature.predictedValue => outputs.predictedValue = x._1 case ResultFeature.probability => x._2.value.foreach(y => { - probabilities += (y -> x._1.asInstanceOf[Double]) + probabilities += (y -> x._1.toDouble) }) } }) diff --git a/src/main/scala/org/pmml4s/model/TreeModel.scala b/src/main/scala/org/pmml4s/model/TreeModel.scala index bfe5344..388eacd 100644 --- a/src/main/scala/org/pmml4s/model/TreeModel.scala +++ b/src/main/scala/org/pmml4s/model/TreeModel.scala @@ -21,7 +21,6 @@ import org.pmml4s.common._ import org.pmml4s.data.{DataVal, Series} import org.pmml4s.metadata.{MiningSchema, Output, OutputField, Targets} import org.pmml4s.transformations.LocalTransformations -import org.pmml4s.util.Utils import scala.collection.mutable.ArrayBuffer import scala.collection.{immutable, mutable} @@ -91,20 +90,20 @@ class TreeModel( while (i < len && !hit) { val c = children(i) c.eval(series) match { - case Predication.TRUE => { + case Predication.TRUE => { r = Predication.TRUE child = c hit = true } - case Predication.SURROGATE => { + case Predication.FALSE => + case Predication.SURROGATE => { r = Predication.SURROGATE child = c hit = true } - case Predication.UNKNOWN => { + case Predication.UNKNOWN => { unknown = true } - case _ => } i += 1 } @@ -133,10 +132,15 @@ class TreeModel( val total = selected.recordCount.getOrElse(Double.NaN) val candidates = selected.children.filter { x => x.eval(series) == UNKNOWN } var max = 0.0 - for (cls <- classes) { + var i = 0 + while (i < numClasses) { + val cls = classes(i) var conf = 0.0 - for (cand <- candidates) { - conf += cand.getConfidence(cls) * cand.recordCount.getOrElse(0.0) / total + var j = 0 + while (j < candidates.length) { + val candi = candidates(j) + conf += candi.getConfidence(cls) * candi.recordCount.getOrElse(0.0) / total + j += 1 } if (conf > max) { @@ -144,6 +148,8 @@ class TreeModel( outputs.predictedValue = cls outputs.confidence = conf } + + i += 1 } done = true @@ -213,7 +219,7 @@ class TreeModel( result(series, outputs) } - /** The sub-classes can override this method to provide classes of target inside model. */ + /** The subclasses can override this method to provide classes of target inside model. */ override def inferClasses: Array[DataVal] = { firstLeaf.scoreDistributions.classes } @@ -289,8 +295,12 @@ class TreeModel( i += 1 } - for (child <- candidates) + i = 0 + while (i < candidates.length) { + val child = candidates(i) traverseLeaves(child, series, leaves) + i += 1 + } } } diff --git a/src/main/scala/org/pmml4s/util/Utils.scala b/src/main/scala/org/pmml4s/util/Utils.scala index bd2541a..6556825 100644 --- a/src/main/scala/org/pmml4s/util/Utils.scala +++ b/src/main/scala/org/pmml4s/util/Utils.scala @@ -109,12 +109,17 @@ object Utils { // Support such float number, for example "1.0", which is converted into double firstly, // then converted to integer again. case _: NumberFormatException => { - s.toDouble.toLong + val d = StringUtils.asDouble(s) + if (d != d) { + null + } else { + d.toLong + } } - case e: Throwable => throw e + case e: Throwable => null } } - case _: NumericType => s.toDouble + case _: NumericType => StringUtils.asDouble(s) case BooleanType => s.toBoolean case _ => s } @@ -127,7 +132,12 @@ object Utils { // Support such float number, for example "1.0", which is converted into double firstly, // then converted to integer again. case _: NumberFormatException => { - LongVal(s.toDouble.toLong) + val d = StringUtils.asDouble(s) + if (d != d) { + DataVal.NULL + } else { + LongVal(d.toLong) + } } case e: Throwable => DataVal.NULL } diff --git a/src/main/scala/org/pmml4s/xml/Builder.scala b/src/main/scala/org/pmml4s/xml/Builder.scala index 445c07f..fb18ded 100644 --- a/src/main/scala/org/pmml4s/xml/Builder.scala +++ b/src/main/scala/org/pmml4s/xml/Builder.scala @@ -502,13 +502,13 @@ trait Builder[T <: Model] extends TransformationsBuilder { } def makeMatrix(reader: XMLEventReader, attrs: XmlAttrs): Matrix = { - val kind = attrs.get(AttrTags.KIND).map(MatrixKind.withName(_)).getOrElse(MatrixKind.any) + val kind = attrs.get(AttrTags.KIND).map(MatrixKind.withName).getOrElse(MatrixKind.any) val nbRows = attrs.getInt(AttrTags.NB_ROWS) val nbCols = attrs.getInt(AttrTags.NB_COLS) val diagDefault = attrs.getDouble(AttrTags.DIAG_DEFAULT) val offDiagDefault = attrs.getDouble(AttrTags.OFF_DIAG_DEFAULT) val arrays = mutable.ArrayBuilder.make[Array[Double]] - nbRows.foreach(arrays.sizeHint(_)) + nbRows.foreach(arrays.sizeHint) val matCells = mutable.ArrayBuilder.make[MatCell] traverseElems(reader, ElemTags.MATRIX, { @@ -517,7 +517,7 @@ trait Builder[T <: Model] extends TransformationsBuilder { override def build(reader: XMLEventReader, attrs: XmlAttrs): MatCell = { val row = attrs.int(AttrTags.ROW) val col = attrs.int(AttrTags.COL) - val value = extractText(reader, ElemTags.MAT_CELL).toDouble + val value = StringUtils.asDouble(extractText(reader, ElemTags.MAT_CELL)) new MatCell(row, col, value) } @@ -625,7 +625,7 @@ trait Builder[T <: Model] extends TransformationsBuilder { override def build(reader: XMLEventReader, attrs: XmlAttrs): ComparisonMeasure = { val kind = ComparisonMeasureKind.withName(attrs(AttrTags.KIND)) val compareFunction = - attrs.get(AttrTags.COMPARE_FUNCTION).map(CompareFunction.withName(_)).getOrElse(CompareFunction.absDiff) + attrs.get(AttrTags.COMPARE_FUNCTION).map(CompareFunction.withName).getOrElse(CompareFunction.absDiff) val minimum = attrs.getDouble(AttrTags.MINIMUM) val maximum = attrs.getDouble(AttrTags.MAXIMUM) var distance: Distance = null