From 39832b8a5743fb994ed39fcc51bff5ed36b14f27 Mon Sep 17 00:00:00 2001
From: Daniel Tenedorio
Date: Mon, 19 Aug 2024 13:14:51 -0700
Subject: [PATCH] commit commit

---
 .../catalyst/analysis/FunctionRegistry.scala  |  2 +
 .../expressions/nullExpressions.scala         | 41 +++++++++++++++++
 .../apache/spark/sql/MiscFunctionsSuite.scala | 44 ++++++++++++++++++-
 3 files changed, 86 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 7ef7c2f6345b2..39b51d0c7e8f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -376,6 +376,7 @@ object FunctionRegistry {
     expression[Least]("least"),
     expression[NaNvl]("nanvl"),
     expression[NullIf]("nullif"),
+    expression[NullIfZero]("nullifzero"),
     expression[Nvl]("nvl"),
     expression[Nvl2]("nvl2"),
     expression[PosExplode]("posexplode"),
@@ -384,6 +385,7 @@
     expression[Rand]("random", true, Some("3.0.0")),
     expression[Randn]("randn"),
     expression[Stack]("stack"),
+    expression[ZeroIfNull]("zeroifnull"),
     CaseWhen.registryEntry,

     // math functions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 010d79f808d10..290f523cc02c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -177,6 +177,47 @@ case class NullIf(left: Expression, right: Expression, replacement: Expression)
   }
 }

+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Returns null if `expr` is equal to zero, or `expr` otherwise.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(0);
+      NULL
+      > SELECT _FUNC_(2);
+      2
+  """,
+  since = "4.0.0",
+  group = "conditional_funcs")
+case class NullIfZero(input: Expression, replacement: Expression)
+  extends RuntimeReplaceable with InheritAnalysisRules {
+  def this(input: Expression) = this(input, If(EqualTo(input, Literal(0)), Literal(null), input))
+
+  override def parameters: Seq[Expression] = Seq(input)
+
+  override protected def withNewChildInternal(newInput: Expression): Expression =
+    copy(replacement = newInput)
+}
+
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Returns zero if `expr` is equal to null, or `expr` otherwise.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(NULL);
+      0
+      > SELECT _FUNC_(2);
+      2
+  """,
+  since = "4.0.0",
+  group = "conditional_funcs")
+case class ZeroIfNull(input: Expression, replacement: Expression)
+  extends RuntimeReplaceable with InheritAnalysisRules {
+  def this(input: Expression) = this(input, new Nvl(input, Literal(0)))
+
+  override def parameters: Seq[Expression] = Seq(input)
+
+  override protected def withNewChildInternal(newInput: Expression): Expression =
+    copy(replacement = newInput)
+}
 @ExpressionDescription(
   usage = "_FUNC_(expr1, expr2) - Returns `expr2` if `expr1` is null, or `expr1` otherwise.",
   examples = """
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala
index b9daece4913f2..e602e89b36b49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.sql

-import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT}
+import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT, SparkNumberFormatException}
 import org.apache.spark.sql.catalyst.expressions.Hex
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.functions._
@@ -285,6 +285,48 @@ class MiscFunctionsSuite extends QueryTest with SharedSparkSession {
     assert(df.selectExpr("random(1)").collect() != null)
     assert(df.select(random(lit(1))).collect() != null)
   }
+
+  test("SPARK-49306 nullifzero and zeroifnull functions") {
+    val df = Seq((1, 2, 3)).toDF("a", "b", "c")
+    checkAnswer(df.selectExpr("nullifzero(0)"), Row(null))
+    checkAnswer(df.selectExpr("nullifzero(cast(0 as tinyint))"), Row(null))
+    checkAnswer(df.selectExpr("nullifzero(cast(0 as bigint))"), Row(null))
+    checkAnswer(df.selectExpr("nullifzero('0')"), Row(null))
+    checkAnswer(df.selectExpr("nullifzero(0.0)"), Row(null))
+    checkAnswer(df.selectExpr("nullifzero(1)"), Row(1))
+    checkAnswer(df.selectExpr("nullifzero(null)"), Row(null))
+    var expr = "nullifzero('abc')"
+    checkError(
+      exception = intercept[SparkNumberFormatException] {
+        checkAnswer(df.selectExpr(expr), Row(null))
+      },
+      errorClass = "CAST_INVALID_INPUT",
+      parameters = Map(
+        "expression" -> "'abc'",
+        "sourceType" -> "\"STRING\"",
+        "targetType" -> "\"BIGINT\"",
+        "ansiConfig" -> "\"spark.sql.ansi.enabled\""
+      ),
+      context = ExpectedContext("", "", 0, expr.length - 1, expr))
+
+    checkAnswer(df.selectExpr("zeroifnull(null)"), Row(0))
+    checkAnswer(df.selectExpr("zeroifnull(1)"), Row(1))
+    checkAnswer(df.selectExpr("zeroifnull(cast(1 as tinyint))"), Row(1))
+    checkAnswer(df.selectExpr("zeroifnull(cast(1 as bigint))"), Row(1))
+    expr = "zeroifnull('abc')"
+    checkError(
+      exception = intercept[SparkNumberFormatException] {
+        checkAnswer(df.selectExpr(expr), Row(null))
+      },
+      errorClass = "CAST_INVALID_INPUT",
+      parameters = Map(
+        "expression" -> "'abc'",
+        "sourceType" -> "\"STRING\"",
+        "targetType" -> "\"BIGINT\"",
+        "ansiConfig" -> "\"spark.sql.ansi.enabled\""
+      ),
+      context = ExpectedContext("", "", 0, expr.length - 1, expr))
+  }
 }

 object ReflectClass {
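As a usage sketch (not part of the patch itself), the snippet below exercises the two new SQL functions through the public spark.sql API. It assumes a Spark build that already contains this change (4.0.0 per the since fields above) and a locally constructed SparkSession; the expected results simply mirror the ExpressionDescription examples and the MiscFunctionsSuite test cases in the diff.

import org.apache.spark.sql.SparkSession

// Illustrative sketch only: assumes a local Spark 4.0.0+ build that includes
// the nullifzero/zeroifnull registrations added by this patch.
object NullIfZeroZeroIfNullSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("nullifzero-zeroifnull-sketch")
      .master("local[*]")
      .getOrCreate()

    // nullifzero(expr): NULL when expr equals zero, otherwise expr unchanged.
    // Per the documented examples: nullifzero(0) -> NULL, nullifzero(2) -> 2.
    spark.sql("SELECT nullifzero(0) AS a, nullifzero(2) AS b").show()

    // zeroifnull(expr): 0 when expr is NULL, otherwise expr unchanged.
    // Per the documented examples: zeroifnull(NULL) -> 0, zeroifnull(2) -> 2.
    spark.sql("SELECT zeroifnull(NULL) AS a, zeroifnull(2) AS b").show()

    // Non-numeric strings such as 'abc' fail the implicit cast with
    // CAST_INVALID_INPUT under ANSI semantics, as exercised in the
    // MiscFunctionsSuite test above.

    spark.stop()
  }
}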