Skip to content

Commit

Permalink
Only create valid decimals because ANSI error doesn't just go null an…
Browse files Browse the repository at this point in the history
…ymore
  • Loading branch information
holdenk committed Sep 30, 2024
1 parent 0e990fb commit 4e2719c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,15 @@ object DataFrameGenerator {
def arbitraryDataFrameWithCustomFields(
    sqlContext: SQLContext, schema: StructType, minPartitions: Int = 1)
    (userGenerators: ColumnGeneratorBase*): Arbitrary[DataFrame] = {
  // Builds an Arbitrary[DataFrame] for `schema`:
  //  - rows come from getRowGenerator(schema, userGenerators), so the
  //    caller-supplied column generators are handed to the row generator;
  //  - RDDGenerator.genRDD wraps that row generator into a generator of RDDs
  //    (minPartitions is forwarded to it unchanged);
  //  - each generated RDD is turned into a DataFrame with the given schema.
  import sqlContext._

  val arbitraryRDDs = RDDGenerator.genRDD(
    sqlContext.sparkContext, minPartitions)(
    getRowGenerator(schema, userGenerators))
  Arbitrary {
    // One mapping step only: every generated RDD becomes exactly one DataFrame.
    arbitraryRDDs.map { r =>
      sqlContext.createDataFrame(r, schema)
    }
  }
}

Expand Down Expand Up @@ -128,9 +131,21 @@ object DataFrameGenerator {
l => new Date(l/10000)
}
case dec: DecimalType => {
  // With the new ANSI default we need to make sure we're passing in
  // valid values, so only decimals that fit the column's precision and
  // scale are kept.
  Arbitrary.arbitrary[BigDecimal]
    // Validity check: try to load the candidate into a Spark Decimal with the
    // target precision/scale; a failure means the value cannot be converted
    // and the generator retries. Try only captures non-fatal errors, so
    // interrupts and VM errors still propagate instead of being swallowed.
    .retryUntil { d =>
      scala.util.Try(new Decimal().set(d, dec.precision, dec.scale)).isSuccess
    }
    // Normalize the accepted value to the column's scale.
    .map(_.bigDecimal.setScale(dec.scale, RoundingMode.HALF_UP))
    .asInstanceOf[Gen[java.math.BigDecimal]]
}
case arr: ArrayType => {
val elementGenerator = getGenerator(arr.elementType, nullable = arr.containsNull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,24 @@ class SampleScalaCheckTest extends AnyFunSuite
check(property)
}

// Smoke test for a single nullable DecimalType(38, 2) column: every generated
// DataFrame must carry the requested schema and be countable.
test("decimal generation mini") {
  val decimalSchema = StructType(
    StructField("bloop", DecimalType(38, 2), nullable = true) :: Nil)
  val ctx = new SQLContext(sc)
  val frames = DataFrameGenerator.arbitraryDataFrame(ctx, decimalSchema)

  // NOTE(review): the count is presumably here to force the generated rows to
  // be evaluated, not for its value -- confirm.
  val schemaAndCountHold = forAll(frames.arbitrary) { df =>
    df.schema === decimalSchema && df.count >= 0
  }

  check(schemaAndCountHold)
}


test("decimal generation") {
val schema = StructType(List(
StructField("small", DecimalType(3, 1)),
Expand Down

0 comments on commit 4e2719c

Please sign in to comment.