diff --git a/data/import_eventserver.py b/data/import_eventserver.py index 7e808e3..b04fae4 100644 --- a/data/import_eventserver.py +++ b/data/import_eventserver.py @@ -18,9 +18,7 @@ def import_events(client, file): entity_type="user", entity_id=str(count), # use the count num as user ID properties= { - "attr0" : int(attr[0]), - "attr1" : int(attr[1]), - "attr2" : int(attr[2]), + "attrs" : [int(attr[0]),int(attr[1]),int(attr[2]) ], "plan" : int(plan) } ) diff --git a/src/main/scala/DataSource.scala b/src/main/scala/DataSource.scala index 278d4a5..dbde706 100644 --- a/src/main/scala/DataSource.scala +++ b/src/main/scala/DataSource.scala @@ -30,17 +30,15 @@ class DataSource(val dsp: DataSourceParams) appName = dsp.appName, entityType = "user", // only keep entities with these required properties defined - required = Some(List("plan", "attr0", "attr1", "attr2")))(sc) + required = Some(List("plan", "attrs")))(sc) // aggregateProperties() returns RDD pair of // entity ID and its aggregated properties .map { case (entityId, properties) => try { LabeledPoint(properties.get[Double]("plan"), - Vectors.dense(Array( - properties.get[Double]("attr0"), - properties.get[Double]("attr1"), - properties.get[Double]("attr2") - )) + Vectors.dense( + properties.get[Array[Double]]("attrs") + ) ) } catch { case e: Exception => { @@ -68,17 +66,15 @@ class DataSource(val dsp: DataSourceParams) appName = dsp.appName, entityType = "user", // only keep entities with these required properties defined - required = Some(List("plan", "attr0", "attr1", "attr2")))(sc) + required = Some(List("plan", "attrs")))(sc) // aggregateProperties() returns RDD pair of // entity ID and its aggregated properties .map { case (entityId, properties) => try { LabeledPoint(properties.get[Double]("plan"), - Vectors.dense(Array( - properties.get[Double]("attr0"), - properties.get[Double]("attr1"), - properties.get[Double]("attr2") - )) + Vectors.dense( + properties.get[Array[Double]]("attrs") + ) ) } catch { case e: Exception => { diff --git a/src/main/scala/Engine.scala b/src/main/scala/Engine.scala index 1bce077..29bb878 100644 --- a/src/main/scala/Engine.scala +++ b/src/main/scala/Engine.scala @@ -3,11 +3,13 @@ package org.template.classification import io.prediction.controller.EngineFactory import io.prediction.controller.Engine -class Query( - val attr0 : Double, - val attr1 : Double, - val attr2 : Double -) extends Serializable +class Query extends Serializable{ + var attrs: Seq[Double] = null + def this(attrs: Double*) { + this() + this.attrs = attrs + } +} class PredictedResult( val label: Double diff --git a/src/main/scala/NaiveBayesAlgorithm.scala b/src/main/scala/NaiveBayesAlgorithm.scala index 56c86ff..6e27bb2 100644 --- a/src/main/scala/NaiveBayesAlgorithm.scala +++ b/src/main/scala/NaiveBayesAlgorithm.scala @@ -32,7 +32,7 @@ class NaiveBayesAlgorithm(val ap: AlgorithmParams) def predict(model: NaiveBayesModel, query: Query): PredictedResult = { val label = model.predict(Vectors.dense( - Array(query.attr0, query.attr1, query.attr2) + Array(query.attrs: _*) )) new PredictedResult(label) }