[SPARK-49967][SQL] Codegen Support for StructsToJson(to_json)
### What changes were proposed in this pull request?
This PR adds `Codegen` support for `StructsToJson` (`to_json`).

### Why are the changes needed?
- Improve codegen coverage (see the sketch below).
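
As a quick way to see the effect (a sketch, assuming a running `SparkSession` named `spark`), the code generated for a `to_json` projection can be inspected with Spark's debug API. Before this patch `StructsToJson` was a `CodegenFallback` expression, so generated code called back into interpreted evaluation; with this patch it compiles to a direct method call on a `StructsToJsonEvaluator`:

```scala
import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.functions.{col, struct, to_json}

val df = spark.range(3).select(to_json(struct(col("id"))).as("json"))

// Prints the Janino-generated code for each WholeStageCodegen subtree;
// the to_json call now appears as an invoke on the boxed evaluator.
df.debugCodegen()
```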

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Pass GA and existing UTs (e.g. the `*to_json*` tests in `JsonFunctionsSuite`).

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48467 from panbingkun/SPARK-49967.

Authored-by: panbingkun <panbingkun@baidu.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
panbingkun authored and MaxGekk committed Oct 21, 2024
1 parent f86df1e commit b4eb034
Showing 7 changed files with 78 additions and 57 deletions.
@@ -16,17 +16,19 @@
*/
package org.apache.spark.sql.catalyst.expressions.json

import java.io.CharArrayWriter

import com.fasterxml.jackson.core.JsonFactory

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions}
import org.apache.spark.sql.catalyst.util.{FailFastMode, FailureSafeParser, PermissiveMode}
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonGenerator, JacksonParser, JsonInferSchema, JSONOptions}
import org.apache.spark.sql.catalyst.util.{ArrayData, FailFastMode, FailureSafeParser, MapData, PermissiveMode}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType, VariantType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
import org.apache.spark.util.Utils

object JsonExpressionEvalUtils {
@@ -111,3 +113,51 @@ case class JsonToStructsEvaluator(
    }
  }
}

case class StructsToJsonEvaluator(
    options: Map[String, String],
    inputSchema: DataType,
    timeZoneId: Option[String]) extends Serializable {

  @transient
  private lazy val writer = new CharArrayWriter()

  @transient
  private lazy val gen = new JacksonGenerator(
    inputSchema, writer, new JSONOptions(options, timeZoneId.get))

  // This converts rows to the JSON output according to the given schema.
  @transient
  private lazy val converter: Any => UTF8String = {
    def getAndReset(): UTF8String = {
      gen.flush()
      val json = writer.toString
      writer.reset()
      UTF8String.fromString(json)
    }

    inputSchema match {
      case _: StructType =>
        (row: Any) =>
          gen.write(row.asInstanceOf[InternalRow])
          getAndReset()
      case _: ArrayType =>
        (arr: Any) =>
          gen.write(arr.asInstanceOf[ArrayData])
          getAndReset()
      case _: MapType =>
        (map: Any) =>
          gen.write(map.asInstanceOf[MapData])
          getAndReset()
      case _: VariantType =>
        (v: Any) =>
          gen.write(v.asInstanceOf[VariantVal])
          getAndReset()
    }
  }

  final def evaluate(value: Any): Any = {
    if (value == null) return null
    converter(value)
  }
}
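
A minimal sketch of the new evaluator in isolation (hypothetical direct use; in the engine it is constructed only through `StructsToJson.replacement`):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val evaluator = StructsToJsonEvaluator(
  options = Map.empty,
  inputSchema = StructType(StructField("a", IntegerType) :: Nil),
  timeZoneId = Some("UTC"))

// Yields a UTF8String containing {"a":1}; a null input short-circuits to null.
val json = evaluator.evaluate(InternalRow(1))
```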
@@ -30,16 +30,15 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils, JsonToStructsEvaluator}
import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils, JsonToStructsEvaluator, StructsToJsonEvaluator}
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, RUNTIME_REPLACEABLE, TreePattern}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

private[this] sealed trait PathInstruction
@@ -748,14 +747,15 @@ case class StructsToJson(
    child: Expression,
    timeZoneId: Option[String] = None)
  extends UnaryExpression
  with TimeZoneAwareExpression
  with CodegenFallback
  with RuntimeReplaceable
  with ExpectsInputTypes
  with NullIntolerant
  with TimeZoneAwareExpression
  with QueryErrorsBase {

  override def nullable: Boolean = true

  override def nodePatternsInternal(): Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)

  def this(options: Map[String, String], child: Expression) = this(options, child, None)

  // Used in `FunctionRegistry`
@@ -767,44 +767,7 @@ case class StructsToJson(
      timeZoneId = None)

  @transient
  lazy val writer = new CharArrayWriter()

  @transient
  lazy val gen = new JacksonGenerator(
    inputSchema, writer, new JSONOptions(options, timeZoneId.get))

  @transient
  lazy val inputSchema = child.dataType

  // This converts rows to the JSON output according to the given schema.
  @transient
  lazy val converter: Any => UTF8String = {
    def getAndReset(): UTF8String = {
      gen.flush()
      val json = writer.toString
      writer.reset()
      UTF8String.fromString(json)
    }

    inputSchema match {
      case _: StructType =>
        (row: Any) =>
          gen.write(row.asInstanceOf[InternalRow])
          getAndReset()
      case _: ArrayType =>
        (arr: Any) =>
          gen.write(arr.asInstanceOf[ArrayData])
          getAndReset()
      case _: MapType =>
        (map: Any) =>
          gen.write(map.asInstanceOf[MapData])
          getAndReset()
      case _: VariantType =>
        (v: Any) =>
          gen.write(v.asInstanceOf[VariantVal])
          getAndReset()
    }
  }
  private lazy val inputSchema = child.dataType

  override def dataType: DataType = SQLConf.get.defaultStringType

Expand All @@ -820,14 +783,23 @@ case class StructsToJson(
  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
    copy(timeZoneId = Option(timeZoneId))

  override def nullSafeEval(value: Any): Any = converter(value)

  override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil

  override def prettyName: String = "to_json"

  override protected def withNewChildInternal(newChild: Expression): StructsToJson =
    copy(child = newChild)

  @transient
  private lazy val evaluator = StructsToJsonEvaluator(options, inputSchema, timeZoneId)

  override def replacement: Expression = Invoke(
    Literal.create(evaluator, ObjectType(classOf[StructsToJsonEvaluator])),
    "evaluate",
    dataType,
    Seq(child),
    Seq(child.dataType)
  )
}
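
For context on the pattern: a `RuntimeReplaceable` expression is never evaluated itself; the optimizer rule `ReplaceExpressions` swaps in `replacement`, here an `Invoke` over a `Literal` that boxes the serialized evaluator, so both the interpreted and codegen paths reduce to a plain method call. A sketch of exercising that path directly (hypothetical test-style code, assuming the imports shown):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema = StructType(StructField("a", IntegerType) :: Nil)
val expr = StructsToJson(Map.empty, BoundReference(0, schema, nullable = true), Some("UTC"))

// Evaluating the Invoke-based replacement produces the JSON text {"a":1}.
val json = expr.replacement.eval(InternalRow(InternalRow(1)))
```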

/**
@@ -79,7 +79,7 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB
  private def prepareEvaluation(expression: Expression): Expression = {
    val serializer = new JavaSerializer(new SparkConf()).newInstance()
    val resolver = ResolveTimeZone
    val expr = resolver.resolveTimeZones(replace(expression))
    val expr = replace(resolver.resolveTimeZones(expression))
    assert(expr.resolved)
    serializer.deserialize(serializer.serialize(expr))
  }
@@ -582,7 +582,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
    val schema = StructType(StructField("\"quote", IntegerType) :: Nil)
    val struct = Literal.create(create_row(1), schema)
    GenerateUnsafeProjection.generate(
      StructsToJson(Map.empty, struct, UTC_OPT) :: Nil)
      StructsToJson(Map.empty, struct, UTC_OPT).replacement :: Nil)
  }

test("to_json - struct") {
@@ -729,8 +729,7 @@
test("from/to json - interval support") {
val schema = StructType(StructField("i", CalendarIntervalType) :: Nil)
checkEvaluation(
JsonToStructs(schema, Map.empty, Literal.create("""{"i":"1 year 1 day"}""", StringType),
UTC_OPT),
JsonToStructs(schema, Map.empty, Literal.create("""{"i":"1 year 1 day"}""", StringType)),
InternalRow(new CalendarInterval(12, 1, 0)))

Seq(MapType(CalendarIntervalType, IntegerType), MapType(IntegerType, CalendarIntervalType))
@@ -1,2 +1,2 @@
Project [to_json((timestampFormat,dd/MM/yyyy), d#0, Some(America/Los_Angeles)) AS to_json(d)#0]
Project [invoke(StructsToJsonEvaluator(Map(timestampFormat -> dd/MM/yyyy),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),Some(America/Los_Angeles)).evaluate(d#0)) AS to_json(d)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -1,2 +1,2 @@
Project [to_json(struct(id, id#0L, a, a#0, b, b#0, d, d#0, e, e#0, f, f#0, g, g#0), Some(America/Los_Angeles)) AS to_json(struct(id, a, b, d, e, f, g))#0]
Project [invoke(StructsToJsonEvaluator(Map(),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true),StructField(d,StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),true),StructField(e,ArrayType(IntegerType,true),true),StructField(f,MapType(StringType,StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),true),true),StructField(g,StringType,true)),Some(America/Los_Angeles)).evaluate(struct(id, id#0L, a, a#0, b, b#0, d, d#0, e, e#0, f, f#0, g, g#0))) AS to_json(struct(id, a, b, d, e, f, g))#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
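
The two golden files above capture the user-visible plan change: where the explain output previously printed a `to_json(...)` expression, it now prints the `invoke(StructsToJsonEvaluator(...).evaluate(...))` replacement. A sketch for reproducing a similar plan locally (assuming a running `SparkSession` named `spark`):

```scala
import org.apache.spark.sql.functions.{col, struct, to_json}

val df = spark.range(1).select(to_json(struct(col("id"))).as("j"))

// The optimized plan prints the Invoke-based replacement rather than to_json.
df.explain(extended = true)
```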
@@ -51,8 +51,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
    child.output.map { in =>
      in.dataType match {
        case _: ArrayType | _: MapType | _: StructType =>
          new StructsToJson(ioschema.inputSerdeProps.toMap, in)
            .withTimeZone(conf.sessionLocalTimeZone)
          StructsToJson(ioschema.inputSerdeProps.toMap, in,
            Some(conf.sessionLocalTimeZone)).replacement
        case _ => Cast(in, StringType).withTimeZone(conf.sessionLocalTimeZone)
      }
    }