Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package frameless.ml.feature

import frameless.ml.{AppendTransformer, TypedEstimator}
import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid
import frameless.ml.internals.UnaryInputsChecker

import org.apache.spark.ml.Estimator
import org.apache.spark.ml.feature.{OneHotEncoder, OneHotEncoderModel}
import org.apache.spark.ml.linalg.Vector

/**
  * A one-hot encoder that maps a column of category indices to a column of binary vectors, with
  * at most a single one-value per row that indicates the input category index.
  *
  * Setters return a new `TypedOneHotEncoder` backed by its own copy of the underlying Spark
  * estimator, so configuring a derived encoder never alters the encoder it was derived from.
  *
  * @tparam Inputs type describing the input column this encoder reads
  * @see `TypedStringIndexer` for converting categorical values into category indices
  */
class TypedOneHotEncoder[Inputs] private[ml](oneHotEncoder: OneHotEncoder, inputCol: String)
  extends TypedEstimator[Inputs, TypedOneHotEncoder.Outputs, OneHotEncoderModel] {

  /** The wrapped Spark estimator, wired to read `inputCol` and write the temporary output column. */
  override val estimator: Estimator[OneHotEncoderModel] = oneHotEncoder
    .setInputCols(Array(inputCol))
    .setOutputCols(Array(AppendTransformer.tempColumnName))

  /**
    * Returns a new encoder using the given invalid-data policy.
    *
    * @param value policy whose `sparkValue` is forwarded to Spark's `setHandleInvalid`
    */
  def setHandleInvalid(value: HandleInvalid): TypedOneHotEncoder[Inputs] =
    copy(detachedEncoder.setHandleInvalid(value.sparkValue))

  /**
    * Returns a new encoder with the `dropLast` flag set to `value`.
    *
    * @param value forwarded to Spark's `setDropLast`
    */
  def setDropLast(value: Boolean): TypedOneHotEncoder[Inputs] =
    copy(detachedEncoder.setDropLast(value))

  // Spark param setters mutate the estimator in place and return `this`; calling them directly on
  // `oneHotEncoder` would silently reconfigure this instance too. Work on a fresh copy instead.
  private def detachedEncoder: OneHotEncoder =
    oneHotEncoder.copy(org.apache.spark.ml.param.ParamMap.empty)

  private def copy(newOneHotEncoder: OneHotEncoder): TypedOneHotEncoder[Inputs] =
    new TypedOneHotEncoder[Inputs](newOneHotEncoder, inputCol)
}

object TypedOneHotEncoder {

  /** Shape of the data appended by the fitted model: a single vector-valued field. */
  case class Outputs(output: Vector)

  /**
    * Type-safe wrapper around the string accepted by Spark's `handleInvalid` parameter.
    * The constructor is private, so the only values are the ones listed in the companion.
    */
  final class HandleInvalid private(val sparkValue: String) extends AnyVal

  /** The supported `HandleInvalid` policies. */
  object HandleInvalid {
    val Error: HandleInvalid = new HandleInvalid("error")
    val Keep: HandleInvalid = new HandleInvalid("keep")
  }

  /**
    * Creates an encoder for the input type `T`.
    *
    * @param inputsChecker evidence relating `T` to an `Int` input; supplies the input column name
    * @tparam T type describing the input column
    */
  def apply[T](implicit inputsChecker: UnaryInputsChecker[T, Int]): TypedOneHotEncoder[T] = {
    val sparkEncoder = new OneHotEncoder()
    new TypedOneHotEncoder[T](sparkEncoder, inputsChecker.inputCol)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package frameless.ml.feature

import frameless.ml.FramelessMlSuite
import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid

import org.apache.spark.ml.linalg._
import org.scalacheck.{Arbitrary, Gen}
import org.scalacheck.Prop._
import shapeless.test.illTyped

/** Property and compile-time tests for `TypedOneHotEncoder`. */
final class TypedOneHotEncoderTests extends FramelessMlSuite {

  test(".fit() returns a correct TypedTransformer") {
    // Bound the generated index: the training set below holds `x2.a + 1` rows.
    implicit val arbInt: Arbitrary[Int] = Arbitrary(Gen.choose(0, 99))

    def prop[A: TypedEncoder : Arbitrary] = forAll { (x2: X2[Int, A], dropLast: Boolean) =>
      val encoder = TypedOneHotEncoder[X1[Int]].setDropLast(dropLast)
      // Train on every index 0..x2.a so that x2.a is the last of x2.a + 1 categories.
      val inputs = 0.to(x2.a).map(i => X2(i, x2.b))
      val ds = TypedDataset.create(inputs)
      val model = encoder.fit(ds).run()
      val resultDs = model.transform(TypedDataset.create(Seq(x2))).as[X3[Int, A, Vector]]
      val result = resultDs.collect.run()

      val expectedVector =
        if (dropLast) {
          // The last category is dropped: one slot fewer, and x2.a encodes as all zeros.
          Vectors.sparse(x2.a, Array.emptyIntArray, Array.emptyDoubleArray)
        } else {
          Vectors.sparse(x2.a + 1, Array(x2.a), Array(1.0))
        }

      result == Seq(X3(x2.a, x2.b, expectedVector))
    }

    check(prop[Double])
    check(prop[String])
  }

  test("param setting is retained") {
    implicit val arbHandleInvalid: Arbitrary[HandleInvalid] = Arbitrary {
      Gen.oneOf(HandleInvalid.Keep, HandleInvalid.Error)
    }

    val prop = forAll { handleInvalid: HandleInvalid =>
      val encoder = TypedOneHotEncoder[X1[Int]]
        .setHandleInvalid(handleInvalid)
      val ds = TypedDataset.create(Seq(X1(1)))
      val model = encoder.fit(ds).run()

      model.transformer.getHandleInvalid == handleInvalid.sparkValue
    }

    check(prop)
  }

  test("apply() compiles only with correct inputs") {
    // No trailing `()`: `apply` has only an implicit parameter list, so an explicit empty argument
    // list fails to compile for every type and the check would pass vacuously. Without it, the
    // snippet is rejected only when no `UnaryInputsChecker` can be resolved for the type.
    // The positive case, `TypedOneHotEncoder[X1[Int]]`, is compiled by the tests above.
    illTyped("TypedOneHotEncoder.apply[Double]")
    illTyped("TypedOneHotEncoder.apply[X1[Double]]")
    illTyped("TypedOneHotEncoder.apply[X2[String, Long]]")
  }
}