Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
commit
  • Loading branch information
dtenedor committed Aug 20, 2024
1 parent 542b24a commit 39832b8
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ object FunctionRegistry {
expression[Least]("least"),
expression[NaNvl]("nanvl"),
expression[NullIf]("nullif"),
expression[NullIfZero]("nullifzero"),
expression[Nvl]("nvl"),
expression[Nvl2]("nvl2"),
expression[PosExplode]("posexplode"),
Expand All @@ -384,6 +385,7 @@ object FunctionRegistry {
expression[Rand]("random", true, Some("3.0.0")),
expression[Randn]("randn"),
expression[Stack]("stack"),
expression[ZeroIfNull]("zeroifnull"),
CaseWhen.registryEntry,

// math functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,47 @@ case class NullIf(left: Expression, right: Expression, replacement: Expression)
}
}

// Returns NULL when its argument is equal to zero, otherwise the argument itself.
// Implemented as a RuntimeReplaceable: the analyzer substitutes `replacement`
// (an If/EqualTo tree built in the auxiliary constructor) for this node, so no
// dedicated codegen/eval is needed here.
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns null if `expr` is equal to zero, or `expr` otherwise.",
examples = """
Examples:
> SELECT _FUNC_(0);
NULL
> SELECT _FUNC_(2);
2
""",
since = "4.0.0",
group = "conditional_funcs")
case class NullIfZero(input: Expression, replacement: Expression)
extends RuntimeReplaceable with InheritAnalysisRules {
// Public single-argument form used by the function registry:
// nullifzero(x) == IF(x = 0, NULL, x).
def this(input: Expression) = this(input, If(EqualTo(input, Literal(0)), Literal(null), input))

// NOTE(review): per InheritAnalysisRules, `parameters` exposes the user-visible
// arguments (presumably for error messages / SQL text) — confirm against the trait.
override def parameters: Seq[Expression] = Seq(input)

// The only child of this node is the replacement expression tree.
override protected def withNewChildInternal(newInput: Expression): Expression =
copy(replacement = newInput)
}

// Returns 0 when its argument is NULL, otherwise the argument itself.
// Implemented as a RuntimeReplaceable: the analyzer substitutes `replacement`
// (an Nvl wrapper built in the auxiliary constructor) for this node, so no
// dedicated codegen/eval is needed here.
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns zero if `expr` is equal to null, or `expr` otherwise.",
examples = """
Examples:
> SELECT _FUNC_(NULL);
0
> SELECT _FUNC_(2);
2
""",
since = "4.0.0",
group = "conditional_funcs")
case class ZeroIfNull(input: Expression, replacement: Expression)
extends RuntimeReplaceable with InheritAnalysisRules {
// Public single-argument form used by the function registry:
// zeroifnull(x) == NVL(x, 0).
def this(input: Expression) = this(input, new Nvl(input, Literal(0)))

// NOTE(review): per InheritAnalysisRules, `parameters` exposes the user-visible
// arguments (presumably for error messages / SQL text) — confirm against the trait.
override def parameters: Seq[Expression] = Seq(input)

// The only child of this node is the replacement expression tree.
override protected def withNewChildInternal(newInput: Expression): Expression =
copy(replacement = newInput)
}

@ExpressionDescription(
usage = "_FUNC_(expr1, expr2) - Returns `expr2` if `expr1` is null, or `expr1` otherwise.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql

import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT}
import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT, SparkNumberFormatException}
import org.apache.spark.sql.catalyst.expressions.Hex
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -285,6 +285,48 @@ class MiscFunctionsSuite extends QueryTest with SharedSparkSession {
assert(df.selectExpr("random(1)").collect() != null)
assert(df.select(random(lit(1))).collect() != null)
}

test("SPARK-49306 nullifzero and zeroifnull functions") {
  val df = Seq((1, 2, 3)).toDF("a", "b", "c")

  // Asserts that evaluating `sqlText` fails with CAST_INVALID_INPUT: the string
  // literal 'abc' cannot be cast to BIGINT for the zero comparison / coalesce.
  def checkInvalidCast(sqlText: String): Unit = {
    checkError(
      exception = intercept[SparkNumberFormatException] {
        checkAnswer(df.selectExpr(sqlText), Row(null))
      },
      errorClass = "CAST_INVALID_INPUT",
      parameters = Map(
        "expression" -> "'abc'",
        "sourceType" -> "\"STRING\"",
        "targetType" -> "\"BIGINT\"",
        "ansiConfig" -> "\"spark.sql.ansi.enabled\""
      ),
      context = ExpectedContext("", "", 0, sqlText.length - 1, sqlText))
  }

  // nullifzero: a zero argument (int, tinyint, bigint, string, decimal) yields
  // NULL; non-zero and NULL inputs pass through unchanged.
  Seq(
    "nullifzero(0)" -> Row(null),
    "nullifzero(cast(0 as tinyint))" -> Row(null),
    "nullifzero(cast(0 as bigint))" -> Row(null),
    "nullifzero('0')" -> Row(null),
    "nullifzero(0.0)" -> Row(null),
    "nullifzero(1)" -> Row(1),
    "nullifzero(null)" -> Row(null)
  ).foreach { case (sqlText, expected) =>
    checkAnswer(df.selectExpr(sqlText), expected)
  }
  checkInvalidCast("nullifzero('abc')")

  // zeroifnull: NULL becomes 0; everything else passes through unchanged.
  Seq(
    "zeroifnull(null)" -> Row(0),
    "zeroifnull(1)" -> Row(1),
    "zeroifnull(cast(1 as tinyint))" -> Row(1),
    "zeroifnull(cast(1 as bigint))" -> Row(1)
  ).foreach { case (sqlText, expected) =>
    checkAnswer(df.selectExpr(sqlText), expected)
  }
  checkInvalidCast("zeroifnull('abc')")
}
}

object ReflectClass {
Expand Down

0 comments on commit 39832b8

Please sign in to comment.