diff --git a/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala b/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala
new file mode 100644
index 00000000..89ad2bf6
--- /dev/null
+++ b/ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala
@@ -0,0 +1,67 @@
+package frameless.ml.feature
+
+import frameless.ml.{AppendTransformer, TypedEstimator}
+import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid
+import frameless.ml.internals.UnaryInputsChecker
+
+import org.apache.spark.ml.Estimator
+import org.apache.spark.ml.feature.{OneHotEncoder, OneHotEncoderModel}
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.param.ParamMap
+
+/**
+ * A one-hot encoder that maps a column of category indices to a column of binary vectors, with
+ * at most a single one-value per row that indicates the input category index.
+ *
+ * Setters never modify the receiver: each one returns a new `TypedOneHotEncoder` that wraps
+ * its own copy of the underlying Spark `OneHotEncoder`.
+ *
+ * @see `TypedStringIndexer` for converting categorical values into category indices
+ * @tparam Inputs input type; `TypedOneHotEncoder.apply` requires an implicit
+ *                `UnaryInputsChecker[Inputs, Int]` for it
+ */
+class TypedOneHotEncoder[Inputs] private[ml](oneHotEncoder: OneHotEncoder, inputCol: String)
+  extends TypedEstimator[Inputs, TypedOneHotEncoder.Outputs, OneHotEncoderModel] {
+
+  /** Spark estimator reading `inputCol` and writing to `AppendTransformer.tempColumnName`. */
+  override val estimator: Estimator[OneHotEncoderModel] = oneHotEncoder
+    .setInputCols(Array(inputCol))
+    .setOutputCols(Array(AppendTransformer.tempColumnName))
+
+  /** Returns a new encoder with Spark's `handleInvalid` param set to `value.sparkValue`. */
+  def setHandleInvalid(value: HandleInvalid): TypedOneHotEncoder[Inputs] =
+    copy(freshEncoder.setHandleInvalid(value.sparkValue))
+
+  /** Returns a new encoder with Spark's `dropLast` param set to `value`. */
+  def setDropLast(value: Boolean): TypedOneHotEncoder[Inputs] =
+    copy(freshEncoder.setDropLast(value))
+
+  // Spark's param setters mutate and return the instance they are called on. Applying them
+  // to a copy keeps the receiver's configuration untouched by later setter calls.
+  private def freshEncoder: OneHotEncoder = oneHotEncoder.copy(ParamMap.empty)
+
+  private def copy(newOneHotEncoder: OneHotEncoder): TypedOneHotEncoder[Inputs] =
+    new TypedOneHotEncoder[Inputs](newOneHotEncoder, inputCol)
+}
+
+object TypedOneHotEncoder {
+
+  /** Output of the encoder: the vector produced for each row. */
+  case class Outputs(output: Vector)
+
+  /** Type-safe wrapper for the string accepted by Spark's `handleInvalid` param. */
+  final class HandleInvalid private(val sparkValue: String) extends AnyVal
+
+  object HandleInvalid {
+    val Error: HandleInvalid = new HandleInvalid("error")
+    val Keep: HandleInvalid = new HandleInvalid("keep")
+  }
+
+  /**
+   * Creates an encoder backed by a new Spark `OneHotEncoder`.
+   *
+   * @tparam T input type; compiles only when an implicit `UnaryInputsChecker[T, Int]` exists
+   */
+  def apply[T](implicit inputsChecker: UnaryInputsChecker[T, Int]): TypedOneHotEncoder[T] =
+    new TypedOneHotEncoder[T](new OneHotEncoder(), inputsChecker.inputCol)
+}
diff --git a/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala b/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala
new file mode 100644
index 00000000..5e6d445f
--- /dev/null
+++ b/ml/src/test/scala/frameless/ml/feature/TypedOneHotEncoderTests.scala
@@ -0,0 +1,65 @@
+package frameless
+package ml
+package feature
+
+import frameless.ml.FramelessMlSuite
+import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid
+
+import org.apache.spark.ml.linalg._
+import org.scalacheck.{Arbitrary, Gen}
+import org.scalacheck.Prop._
+import shapeless.test.illTyped
+
+/** Property tests for `TypedOneHotEncoder`. */
+final class TypedOneHotEncoderTests extends FramelessMlSuite {
+
+  test(".fit() returns a correct TypedTransformer") {
+    implicit val arbInt: Arbitrary[Int] = Arbitrary(Gen.choose(0, 99))
+
+    def prop[A: TypedEncoder : Arbitrary] = forAll { (x2: X2[Int, A], dropLast: Boolean) =>
+      val encoder = TypedOneHotEncoder[X1[Int]].setDropLast(dropLast)
+      // Fit on every index in 0..x2.a, so that x2.a is the last of x2.a + 1 categories.
+      val inputs = 0.to(x2.a).map(i => X2(i, x2.b))
+      val ds = TypedDataset.create(inputs)
+      val model = encoder.fit(ds).run()
+      val resultDs = model.transform(TypedDataset.create(Seq(x2))).as[X3[Int, A, Vector]]
+      val result = resultDs.collect.run()
+
+      // With dropLast the last category becomes an all-zero vector of size x2.a; otherwise
+      // it is the only set entry of a vector of size x2.a + 1.
+      val expected =
+        if (dropLast) Vectors.sparse(x2.a, Array.emptyIntArray, Array.emptyDoubleArray)
+        else Vectors.sparse(x2.a + 1, Array(x2.a), Array(1.0))
+
+      result == Seq(X3(x2.a, x2.b, expected))
+    }
+
+    check(prop[Double])
+    check(prop[String])
+  }
+
+  test("param setting is retained") {
+    implicit val arbHandleInvalid: Arbitrary[HandleInvalid] = Arbitrary {
+      Gen.oneOf(HandleInvalid.Keep, HandleInvalid.Error)
+    }
+
+    val prop = forAll { handleInvalid: HandleInvalid =>
+      val encoder = TypedOneHotEncoder[X1[Int]]
+        .setHandleInvalid(handleInvalid)
+      val ds = TypedDataset.create(Seq(X1(1)))
+      val model = encoder.fit(ds).run()
+
+      model.transformer.getHandleInvalid == handleInvalid.sparkValue
+    }
+
+    check(prop)
+  }
+
+  test("apply() compiles only with correct inputs") {
+    // Call without `()`: an explicit empty argument list can never satisfy the implicit
+    // parameter list, so with it these checks would pass for every input type.
+    illTyped("TypedOneHotEncoder.apply[Double]")
+    illTyped("TypedOneHotEncoder.apply[X1[Double]]")
+    illTyped("TypedOneHotEncoder.apply[X2[String, Long]]")
+  }
+}