Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
commit

uniform expression

commit

commit

commit
  • Loading branch information
dtenedor committed Sep 5, 2024
1 parent 48f9cc7 commit d9223e5
Show file tree
Hide file tree
Showing 6 changed files with 1,247 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,9 @@ object FunctionRegistry {
expression[Rand]("rand"),
expression[Rand]("random", true, Some("3.0.0")),
expression[Randn]("randn"),
expression[RandStr]("randstr"),
expression[Stack]("stack"),
expression[Uniform]("uniform"),
expression[ZeroIfNull]("zeroifnull"),
CaseWhen.registryEntry,

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@

package org.apache.spark.sql.catalyst.expressions

import scala.util.Random

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedSeed}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLType}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, TreePattern}
import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, RUNTIME_REPLACEABLE, TreePattern}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.random.XORShiftRandom

/**
Expand Down Expand Up @@ -181,3 +187,211 @@ case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG {
object Randn {
def apply(seed: Long): Randn = Randn(Literal(seed, LongType))
}

@ExpressionDescription(
usage = """
_FUNC_(min, max, seed) - Returns a random value with independent and identically
distributed (i.i.d.) values with the specified range of numbers. The random seed is optional.
The provided numbers specifying the minimum and maximum values of the range must be constant.
If both of these numbers are integers, then the result will also be an integer. Otherwise if
one or both of these are floating-point numbers, then the result will also be a floating-point
number.
""",
examples = """
Examples:
> SELECT _FUNC_(0, 1);
-0.3254147983080288
> SELECT _FUNC_(10, 20, 0);
26.034991609278433
""",
since = "4.0.0",
group = "math_funcs")
case class Uniform(min: Expression, max: Expression, seed: Expression)
extends RuntimeReplaceable with TernaryLike[Expression] with ExpressionWithRandomSeed {
def this(min: Expression, max: Expression) =
this(min, max, Literal(Uniform.random.nextLong(), LongType))

final override lazy val deterministic: Boolean = false
override val nodePatterns: Seq[TreePattern] =
Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED)

override val dataType: DataType = {
val first = min.dataType
val second = max.dataType
(min.dataType, max.dataType) match {
case _ if !valid(min) || !valid(max) => NullType
case (_, LongType) | (LongType, _) if Seq(first, second).forall(integer) => LongType
case (_, IntegerType) | (IntegerType, _) if Seq(first, second).forall(integer) => IntegerType
case (_, ShortType) | (ShortType, _) if Seq(first, second).forall(integer) => ShortType
case (_, DoubleType) | (DoubleType, _) => DoubleType
case (_, FloatType) | (FloatType, _) => FloatType
case _ => NullType
}
}

private def valid(e: Expression): Boolean = e.dataType match {
case _ if !e.foldable => false
case _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType => true
case _ => false
}

private def integer(t: DataType): Boolean = t match {
case _: ShortType | _: IntegerType | _: LongType => true
case _ => false
}

override def checkInputDataTypes(): TypeCheckResult = {
var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess
Seq(min, max, seed).zipWithIndex.foreach { case (expr: Expression, index: Int) =>
if (!valid(expr)) {
result = DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> ordinalNumber(index),
"requiredType" -> "constant value of integer or floating-point",
"inputSql" -> toSQLExpr(expr),
"inputType" -> toSQLType(expr.dataType)))
}
}
result
}

override def first: Expression = min
override def second: Expression = max
override def third: Expression = seed

override def seedExpression: Expression = seed
override def withNewSeed(newSeed: Long): Expression =
Uniform(min, max, Literal(newSeed, LongType))

override def withNewChildrenInternal(
newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
Uniform(newFirst, newSecond, newThird)

override def replacement: Expression = {
def cast(e: Expression, to: DataType): Expression = if (e.dataType == to) e else Cast(e, to)
cast(Add(
cast(min, DoubleType),
Multiply(
Subtract(
cast(max, DoubleType),
cast(min, DoubleType)),
Rand(seed))),
dataType)
}
}

object Uniform {
lazy val random = new Random()
}

@ExpressionDescription(
usage = """
_FUNC_(length, seed) - Returns a string of the specified length whose characters are chosen
uniformly at random from the following pool of characters: 0-9, a-z, A-Z. The random seed is
optional. The string length must be a constant two-byte or four-byte integer (SMALLINT or INT,
respectively).
""",
examples =
"""
Examples:
> SELECT _FUNC_(3, 0);
abc
""",
since = "4.0.0",
group = "math_funcs")
case class RandStr(length: Expression, override val seedExpression: Expression)
extends ExpressionWithRandomSeed with BinaryLike[Expression] with Nondeterministic {
def this(length: Expression) = this(length, Literal(Uniform.random.nextLong(), LongType))

override def nullable: Boolean = false
override def dataType: DataType = StringType
override def stateful: Boolean = true
override def left: Expression = length
override def right: Expression = seedExpression

/**
* Record ID within each partition. By being transient, the Random Number Generator is
* reset every time we serialize and deserialize and initialize it.
*/
@transient protected var rng: XORShiftRandom = _

@transient protected lazy val seed: Long = seedExpression match {
case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
}
override protected def initializeInternal(partitionIndex: Int): Unit = {
rng = new XORShiftRandom(seed + partitionIndex)
}

override def withNewSeed(newSeed: Long): Expression = RandStr(length, Literal(newSeed, LongType))
override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression =
RandStr(newFirst, newSecond)

override def checkInputDataTypes(): TypeCheckResult = {
var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess
Seq(length, seedExpression).zipWithIndex.foreach { case (expr: Expression, index: Int) =>
val valid = expr.dataType match {
case _ if !expr.foldable => false
case _: ShortType | _: IntegerType => true
case _: LongType if index == 1 => true
case _ => false
}
if (!valid) {
result = DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> ordinalNumber(index),
"requiredType" -> "constant value of INT or SMALLINT",
"inputSql" -> toSQLExpr(expr),
"inputType" -> toSQLType(expr.dataType)))
}
}
result
}

override def evalInternal(input: InternalRow): Any = {
val numChars: Int = length.eval(input).asInstanceOf[Int]
val bytes = new Array[Byte](numChars)
(0 until numChars).foreach { i =>
val num = (rng.nextInt() % 30).abs
num match {
case _ if num < 10 =>
bytes.update(i, ('0' + num).toByte)
case _ if num < 20 =>
bytes.update(i, ('a' + num - 10).toByte)
case _ =>
bytes.update(i, ('A' + num - 20).toByte)
}
}
val result: UTF8String = UTF8String.fromBytes(bytes.toArray)
result
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = classOf[XORShiftRandom].getName
val rngTerm = ctx.addMutableState(className, "rng")
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
val eval = length.genCode(ctx)
ev.copy(code =
code"""
|${eval.code}
|int length = (int)(${eval.value});
|char[] chars = new char[length];
|for (int i = 0; i < length; i++) {
| int v = Math.abs($rngTerm.nextInt() % 30);
| if (v < 10) {
| chars[i] = (char)('0' + v);
| } else if (v < 20) {
| chars[i] = (char)('a' + (v - 10));
| } else {
| chars[i] = (char)('A' + (v - 20));
| }
|}
|UTF8String ${ev.value} = UTF8String.fromString(new String(chars));
|boolean ${ev.isNull} = false;
|""".stripMargin,
isNull = FalseLiteral)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@
| org.apache.spark.sql.catalyst.expressions.RaiseErrorExpressionBuilder | raise_error | SELECT raise_error('custom error message') | struct<raise_error(USER_RAISED_EXCEPTION, map(errorMessage, custom error message)):void> |
| org.apache.spark.sql.catalyst.expressions.Rand | rand | SELECT rand() | struct<rand():double> |
| org.apache.spark.sql.catalyst.expressions.Rand | random | SELECT random() | struct<rand():double> |
| org.apache.spark.sql.catalyst.expressions.RandStr | randstr | SELECT randstr(3, 0) | struct<randstr(3, 0):string> |
| org.apache.spark.sql.catalyst.expressions.Randn | randn | SELECT randn() | struct<randn():double> |
| org.apache.spark.sql.catalyst.expressions.Rank | rank | SELECT a, b, rank(b) OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct<a:string,b:int,RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):int> |
| org.apache.spark.sql.catalyst.expressions.RegExpCount | regexp_count | SELECT regexp_count('Steven Jones and Stephen Smith are the best players', 'Ste(v&#124;ph)en') | struct<regexp_count(Steven Jones and Stephen Smith are the best players, Ste(v&#124;ph)en):int> |
Expand Down Expand Up @@ -367,6 +368,7 @@
| org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct<negative(1):int> |
| org.apache.spark.sql.catalyst.expressions.UnaryPositive | positive | SELECT positive(1) | struct<(+ 1):int> |
| org.apache.spark.sql.catalyst.expressions.Unhex | unhex | SELECT decode(unhex('537061726B2053514C'), 'UTF-8') | struct<decode(unhex(537061726B2053514C), UTF-8):string> |
| org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(0, 1) | struct<uniform(0, 1, 8191290556685094889):int> |
| org.apache.spark.sql.catalyst.expressions.UnixDate | unix_date | SELECT unix_date(DATE("1970-01-02")) | struct<unix_date(1970-01-02):int> |
| org.apache.spark.sql.catalyst.expressions.UnixMicros | unix_micros | SELECT unix_micros(TIMESTAMP('1970-01-01 00:00:01Z')) | struct<unix_micros(1970-01-01 00:00:01Z):bigint> |
| org.apache.spark.sql.catalyst.expressions.UnixMillis | unix_millis | SELECT unix_millis(TIMESTAMP('1970-01-01 00:00:01Z')) | struct<unix_millis(1970-01-01 00:00:01Z):bigint> |
Expand Down
Loading

0 comments on commit d9223e5

Please sign in to comment.