Skip to content

Commit

Permalink
Fix the defect of modelChain with Output
Browse files Browse the repository at this point in the history
  • Loading branch information
scorebot committed Dec 9, 2024
1 parent 67c9bb7 commit 5e4a672
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/main/scala/org/pmml4s/model/MiningModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class MiningModel(
case ResultFeature.probability => x._2.value.foreach(y => {
probabilities += (y -> x._1.toDouble)
})
case _ =>
}
})
outputs.probabilities = probabilities.toMap
Expand All @@ -151,7 +152,7 @@ class MiningModel(
if (outputs.probabilities.nonEmpty && outputs.predictedValue == null) {
outputs.evalPredictedValueByProbabilities()
}
result(series, outputs)
result(last, outputs)
} else last
}
case method => {
Expand Down
84 changes: 84 additions & 0 deletions src/test/scala/org/pmml4s/model/MiningModelTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,88 @@ class MiningModelTest extends BaseModelTest {
assert(r(2) === 0.5555556666666667)
assert(r(3) === 3)
}

test("a modelChain model with Output") {
val model = Model.fromString(
"""
|<PMML xmlns="http://www.dmg.org/PMML-4_4" version="4.4">
| <Header description="Test Rule Model"/>
| <DataDictionary numberOfFields="4">
| <DataField name="X" optype="categorical" dataType="string">
| <Value value="A"/>
| <Value value="B"/>
| </DataField>
| <DataField name="_predictedValue1" optype="continuous" dataType="double"/>
| <DataField name="_predictedValue2" optype="continuous" dataType="double"/>
| <DataField name="_predictedValueSum" optype="continuous" dataType="double"/>
| <DataField name="_predictedValueFinal" optype="continuous" dataType="double"/>
| <DataField name="predictedValueFinal" optype="continuous" dataType="double"/>
| </DataDictionary>
| <MiningModel functionName="regression">
| <MiningSchema>
| <MiningField name="X" usageType="active"/>
| <MiningField name="predictedValueFinal" usageType="target"/>
| </MiningSchema>
| <Segmentation multipleModelMethod="modelChain">
| <Segment id="0">
| <True/>
| <RuleSetModel functionName="regression" algorithmName="RuleSet">
| <MiningSchema>
| <MiningField name="X" usageType="active"/>
| <MiningField name="_predictedValue1" usageType="target"/>
| </MiningSchema>
| <Output>
| <OutputField name="_predictedValue1" optype="continuous" dataType="double" feature="predictedValue"/>
| </Output>
| <RuleSet>
| <RuleSelectionMethod criterion="firstHit"/>
| <SimpleRule score="100" confidence="1" weight="1">
| <SimplePredicate field="X" operator="equal" value="A" />
| </SimpleRule>
| <SimpleRule score="-100" confidence="1" weight="1">
| <True/>
| </SimpleRule>
| </RuleSet>
| </RuleSetModel>
| </Segment>
| <Segment id="1">
| <True/>
| <RuleSetModel functionName="regression" algorithmName="RuleSet">
| <MiningSchema>
| <MiningField name="X" usageType="active"/>
| <MiningField name="_predictedValue1" usageType="active"/>
| <MiningField name="_predictedValue2" usageType="active"/>
| </MiningSchema>
| <Output>
| <OutputField name="_predictedValue2" optype="continuous" dataType="double" feature="predictedValue"/>
| <OutputField name="_predictedValueSum" optype="continuous" dataType="double" feature="transformedValue">
| <Apply function="+">
| <FieldRef field="_predictedValue2"/>
| <FieldRef field="_predictedValue1"/>
| </Apply>
| </OutputField>
| </Output>
| <RuleSet>
| <RuleSelectionMethod criterion="firstHit"/>
| <SimpleRule score="10" confidence="1" weight="1">
| <SimplePredicate field="X" operator="equal" value="A" />
| </SimpleRule>
| <SimpleRule score="-10" confidence="1" weight="1">
| <True/>
| </SimpleRule>
| </RuleSet>
| </RuleSetModel>
| </Segment>
| </Segmentation>
| <Output>
| <OutputField name="predictedValueFinal" optype="continuous" dataType="double" feature="transformedValue">
| <FieldRef field="_predictedValueSum"/>
| </OutputField>
| </Output>
| </MiningModel>
|</PMML>
|""".stripMargin)
assert(model.predict(Map("X" -> "A"))("predictedValueFinal") === 110.0)
assert(model.predict(Map("X" -> "B"))("predictedValueFinal") === -110.0)
}
}

0 comments on commit 5e4a672

Please sign in to comment.