diff --git a/core/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameGenerator.scala b/core/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameGenerator.scala index f6519395..85b2fe00 100644 --- a/core/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameGenerator.scala +++ b/core/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameGenerator.scala @@ -49,12 +49,15 @@ object DataFrameGenerator { def arbitraryDataFrameWithCustomFields( sqlContext: SQLContext, schema: StructType, minPartitions: Int = 1) (userGenerators: ColumnGeneratorBase*): Arbitrary[DataFrame] = { + import sqlContext._ val arbitraryRDDs = RDDGenerator.genRDD( sqlContext.sparkContext, minPartitions)( getRowGenerator(schema, userGenerators)) Arbitrary { - arbitraryRDDs.map(sqlContext.createDataFrame(_, schema)) + arbitraryRDDs.map { r => + sqlContext.createDataFrame(r, schema) + } } } @@ -128,9 +131,21 @@ object DataFrameGenerator { l => new Date(l/10000) } case dec: DecimalType => { + // With the new ANSI default we need to make sure we're passing in + // valid values. 
Arbitrary.arbitrary[BigDecimal] - .retryUntil(_.precision <= dec.precision) + .retryUntil { d => + try { + val sd = new Decimal() + // Make sure it can be converted + sd.set(d, dec.precision, dec.scale) + true + } catch { + case e: Exception => false + } + } .map(_.bigDecimal.setScale(dec.scale, RoundingMode.HALF_UP)) + .asInstanceOf[Gen[java.math.BigDecimal]] } case arr: ArrayType => { val elementGenerator = getGenerator(arr.elementType, nullable = arr.containsNull) diff --git a/core/src/test/2.0/scala/com/holdenkarau/spark/testing/SampleScalaCheckTest.scala b/core/src/test/2.0/scala/com/holdenkarau/spark/testing/SampleScalaCheckTest.scala index 2569121c..85301ffc 100644 --- a/core/src/test/2.0/scala/com/holdenkarau/spark/testing/SampleScalaCheckTest.scala +++ b/core/src/test/2.0/scala/com/holdenkarau/spark/testing/SampleScalaCheckTest.scala @@ -319,6 +319,24 @@ class SampleScalaCheckTest extends AnyFunSuite check(property) } + test("decimal generation mini") { + val schema = StructType(List( + StructField("bloop", DecimalType(38, 2), nullable=true))) + + val sqlContext = new SQLContext(sc) + val dataframeGen = DataFrameGenerator.arbitraryDataFrame(sqlContext, schema) + + val property = + forAll(dataframeGen.arbitrary) { + dataframe => { + dataframe.schema === schema && dataframe.count >= 0 + } + } + + check(property) + } + + test("decimal generation") { val schema = StructType(List( StructField("small", DecimalType(3, 1)),