
Commit 94452d6

smurching authored and sueann committed
[ML-3035] Add persistence & test for DeepImageFeaturizer in Scala (#86)
Adds persistence for the Scala implementation of DeepImageFeaturizer. Since DeepImageFeaturizer stores all of its data as instances of org.apache.spark.ml.Param, all we need to do is extend Spark's default ML persistence traits (DefaultParamsReadable, DefaultParamsWritable). See MLlib's Binarizer and BinarizerSuite for an example of a similar Transformer/test suite pair. As is, this PR enables DeepImageFeaturizer persistence in Scala but not Python; to enable it in Python, we just need to inherit from JavaMLWriter & JavaMLReader, which define Python persistence methods (save, load) that call the corresponding Scala methods. However, we need to ensure that the behavior in Python and Scala is consistent before we support persistence read/write in Python.
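For reference, the pattern this commit relies on is small: a Transformer whose state lives entirely in Param instances mixes DefaultParamsWritable into the class and extends its companion object with DefaultParamsReadable, and Spark's default Params-based (de)serialization does the rest. A minimal sketch of the shape involved (the IdentityTransformer below is hypothetical, not part of this commit):

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// Hypothetical Transformer whose entire state lives in Params, so the
// default Params-based persistence is sufficient.
class IdentityTransformer(override val uid: String)
    extends Transformer with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("identityTransformer"))

  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF()

  override def transformSchema(schema: StructType): StructType = schema

  override def copy(extra: ParamMap): IdentityTransformer = defaultCopy(extra)
}

// The companion picks up `read` and `load` from DefaultParamsReadable,
// mirroring the change made to object DeepImageFeaturizer below.
object IdentityTransformer extends DefaultParamsReadable[IdentityTransformer]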
1 parent 7fbabd0

File tree: 4 files changed (+144 additions, −4 deletions)

src/main/scala/com/databricks/sparkdl/DeepImageFeaturizer.scala

Lines changed: 2 additions & 3 deletions

@@ -23,11 +23,10 @@ import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.StructType
-
 import org.tensorflow.framework.GraphDef
 import org.tensorframes.{Shape, ShapeDescription}
 import org.tensorframes.impl.DebugRowOps
@@ -126,7 +125,7 @@ class DeepImageFeaturizer(override val uid: String) extends Transformer with Def
   }
 }
 
-object DeepImageFeaturizer {
+object DeepImageFeaturizer extends DefaultParamsReadable[DeepImageFeaturizer] {
   /**
    * The deep image featurizer uses the information provided by named Image model to apply the
    * tensorflow graph, given in NamedImageModel.graph as a GraphDef, to an image column of a
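With the companion object now extending DefaultParamsReadable, a full save/load round trip works from Scala. A minimal usage sketch, assuming an illustrative output path and the InceptionV3 named model (both hypothetical values here, not taken from this commit):

import com.databricks.sparkdl.DeepImageFeaturizer

val featurizer = new DeepImageFeaturizer()
  .setModelName("InceptionV3")  // illustrative model name
  .setInputCol("image")
  .setOutputCol("features")

// DefaultParamsWritable serializes the transformer's Params as metadata.
featurizer.write.overwrite().save("/tmp/featurizer")  // hypothetical path

// DefaultParamsReadable supplies the matching load on the companion object.
val restored = DeepImageFeaturizer.load("/tmp/featurizer")
assert(restored.uid == featurizer.uid)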

src/test/scala/com/databricks/sparkdl/DeepImageFeaturizerSuite.scala

Lines changed: 9 additions & 1 deletion

@@ -23,7 +23,7 @@ import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.types.{StructField, StructType}
 import org.scalatest.FunSuite
 
-class DeepImageFeaturizerSuite extends FunSuite with TestSparkContext {
+class DeepImageFeaturizerSuite extends FunSuite with TestSparkContext with DefaultReadWriteTest {
 
   var data: DataFrame = _
 
@@ -118,4 +118,12 @@ class DeepImageFeaturizerSuite extends FunSuite with TestSparkContext {
       featurizer.setModelName("noSuchModel")
     }
   }
+
+  test("DeepImageFeaturizer persistence") {
+    val featurizer = new DeepImageFeaturizer()
+      .setModelName("_test")
+      .setInputCol("myInput")
+      .setOutputCol("myOutput")
+    testDefaultReadWrite(featurizer)
+  }
 }
src/test/scala/com/databricks/sparkdl/DefaultReadWriteTest.scala

Lines changed: 83 additions & 0 deletions (new file)

/*
 * Copyright 2017 Databricks, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.databricks.sparkdl

import java.io.{File, IOException}

import org.scalatest.Suite

import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.{Identifiable, MLReader, MLWritable, TempDirectory}

/**
 * Copied from Spark (https://github.com/apache/spark/blob/branch-2.2/mllib/src/test/scala/org/
 * apache/spark/ml/util/DefaultReadWriteTest.scala).
 *
 * Trait containing a default persistence test for Estimators/Transformers whose data is stored
 * entirely in [[org.apache.spark.ml.param.Param]] instances.
 */
trait DefaultReadWriteTest extends TempDirectory { self: Suite =>

  /**
   * Checks "overwrite" option and params.
   * This saves to and loads from [[tempDir]], but creates a subdirectory with a random name
   * in order to avoid conflicts from multiple calls to this method.
   *
   * @param instance ML instance to test saving/loading
   * @param testParams If true, then test values of Params. Otherwise, just test overwrite option.
   * @tparam T ML instance type
   * @return Instance loaded from file
   */
  def testDefaultReadWrite[T <: Params with MLWritable](
      instance: T,
      testParams: Boolean = true): T = {
    val uid = instance.uid
    val subdirName = Identifiable.randomUID("test")

    val subdir = new File(tempDir, subdirName)
    val path = new File(subdir, uid).getPath

    instance.save(path)
    intercept[IOException] {
      instance.save(path)
    }
    instance.write.overwrite().save(path)
    val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]]
    val newInstance = loader.load(path)
    assert(newInstance.uid === instance.uid)
    if (testParams) {
      instance.params.foreach { p =>
        if (instance.isDefined(p)) {
          (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
            case (Array(values), Array(newValues)) =>
              assert(values === newValues, s"Values do not match on param ${p.name}.")
            case (value, newValue) =>
              assert(value === newValue, s"Values do not match on param ${p.name}.")
          }
        } else {
          assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
        }
      }
    }

    val load = instance.getClass.getMethod("load", classOf[String])
    val another = load.invoke(instance, path).asInstanceOf[T]
    assert(another.uid === instance.uid)
    another
  }
}
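One detail of testDefaultReadWrite worth calling out: instance.getClass.getMethod("read").invoke(null) resolves because Scala emits static forwarders on a class for its companion object's methods, so the companion's read is reachable through the instance's Class. Specialized to DeepImageFeaturizer, the reflective lookup is equivalent to this direct sketch (the path value is hypothetical):

import org.apache.spark.ml.util.MLReader
import com.databricks.sparkdl.DeepImageFeaturizer

val path = "/tmp/featurizer"  // hypothetical location written by a prior save
// Non-reflective equivalent of getMethod("read").invoke(null) in the test:
val reader: MLReader[DeepImageFeaturizer] = DeepImageFeaturizer.read
val restored: DeepImageFeaturizer = reader.load(path)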
src/test/scala/org/apache/spark/ml/util/TempDirectory.scala

Lines changed: 50 additions & 0 deletions (new file)

/*
 * Copyright 2017 Databricks, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.util

import java.io.File

import org.scalatest.{BeforeAndAfterAll, Suite}

import org.apache.spark.util.Utils

/**
 * Trait that creates a temporary directory before all tests and deletes it after all.
 */
trait TempDirectory extends BeforeAndAfterAll { self: Suite =>

  private var _tempDir: File = _

  /**
   * Returns the temporary directory as a `File` instance.
   */
  protected def tempDir: File = _tempDir

  override def beforeAll(): Unit = {
    super.beforeAll()
    _tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
  }

  override def afterAll(): Unit = {
    try {
      Utils.deleteRecursively(_tempDir)
    } finally {
      super.afterAll()
    }
  }
}
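A minimal sketch of consuming this trait (the suite below is hypothetical): because TempDirectory stacks on BeforeAndAfterAll, tempDir is created once before the first test and recursively deleted after the last, and the try/finally in afterAll guarantees super.afterAll() still runs if deletion throws.

import java.io.File

import org.scalatest.FunSuite

import org.apache.spark.ml.util.TempDirectory

// Hypothetical suite: each test may write freely under `tempDir`.
class ScratchSpaceSuite extends FunSuite with TempDirectory {
  test("writes land under the shared temp directory") {
    val scratch = new File(tempDir, "scratch")
    assert(scratch.mkdirs())
  }
}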
