Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,25 @@ package org.apache.spark.sql.execution.datasources.v2

import org.roaringbitmap.longlong.Roaring64Bitmap

import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.AttributeSet
import org.apache.spark.sql.catalyst.expressions.BasePredicate
import org.apache.spark.sql.catalyst.expressions.BindReferences
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.Projection
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral, GeneratePredicate, JavaCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Context, Copy, Delete, Discard, Insert, Instruction, Keep, ROW_ID, Split, Update}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.UnaryExecNode
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.BooleanType

case class MergeRowsExec(
isSourceRowPresent: Expression,
Expand All @@ -44,7 +47,7 @@ case class MergeRowsExec(
notMatchedBySourceInstructions: Seq[Instruction],
checkCardinality: Boolean,
output: Seq[Attribute],
child: SparkPlan) extends UnaryExecNode {
child: SparkPlan) extends UnaryExecNode with CodegenSupport {

override lazy val metrics: Map[String, SQLMetric] = Map(
"numTargetRowsCopied" -> SQLMetrics.createMetric(sparkContext,
Expand Down Expand Up @@ -92,6 +95,277 @@ case class MergeRowsExec(
child.execute().mapPartitions(processPartition)
}

override def inputRDDs(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].inputRDDs()
}

protected override def doProduce(ctx: CodegenContext): String = {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
// Save the input variables that were passed to doConsume
val inputCurrentVars = input

// code for instruction execution code
generateInstructionExecutionCode(ctx, inputCurrentVars)
}


/**
* code for cardinality validation
*/
private def generateCardinalityValidationCode(ctx: CodegenContext, rowIdOrdinal: Int,
input: Seq[ExprCode]): ExprCode = {
val bitmapClass = classOf[Roaring64Bitmap]
val rowIdBitmap = ctx.addMutableState(bitmapClass.getName, "matchedRowIds",
v => s"$v = new ${bitmapClass.getName}();")

val currentRowId = input(rowIdOrdinal)
val queryExecutionErrorsClass = QueryExecutionErrors.getClass.getName + ".MODULE$"
val code =
code"""
|${currentRowId.code}
|if ($rowIdBitmap.contains(${currentRowId.value})) {
| throw $queryExecutionErrorsClass.mergeCardinalityViolationError();
|}
|$rowIdBitmap.add(${currentRowId.value});
""".stripMargin
ExprCode(code, FalseLiteral, JavaCode.variable(rowIdBitmap, bitmapClass))
}

/**
* Generate code for instruction execution based on row presence conditions
*/
private def generateInstructionExecutionCode(ctx: CodegenContext,
inputExprs: Seq[ExprCode]): String = {

// code for evaluating src/tgt presence conditions
val sourcePresentExpr = generatePredicateCode(ctx, isSourceRowPresent, child.output, inputExprs)
val targetPresentExpr = generatePredicateCode(ctx, isTargetRowPresent, child.output, inputExprs)

// code for each instruction type
val matchedInstructionsCode = generateInstructionsCode(ctx, matchedInstructions,
"matched", inputExprs, sourcePresent = true)
val notMatchedInstructionsCode = generateInstructionsCode(ctx, notMatchedInstructions,
"notMatched", inputExprs, sourcePresent = true)
val notMatchedBySourceInstructionsCode = generateInstructionsCode(ctx,
notMatchedBySourceInstructions, "notMatchedBySource", inputExprs, sourcePresent = false)

val cardinalityValidationCode = if (checkCardinality) {
val rowIdOrdinal = child.output.indexWhere(attr => conf.resolver(attr.name, ROW_ID))
assert(rowIdOrdinal != -1, "Cannot find row ID attr")
generateCardinalityValidationCode(ctx, rowIdOrdinal, inputExprs).code
} else {
""
}

s"""
|${sourcePresentExpr.code}
|${targetPresentExpr.code}
|
|if (${targetPresentExpr.value} && ${sourcePresentExpr.value}) {
| $cardinalityValidationCode
| $matchedInstructionsCode
|} else if (${sourcePresentExpr.value}) {
| $notMatchedInstructionsCode
|} else if (${targetPresentExpr.value}) {
| $notMatchedBySourceInstructionsCode
|}
""".stripMargin
}

/**
* Generate code for executing a sequence of instructions
*/
private def generateInstructionsCode(ctx: CodegenContext, instructions: Seq[Instruction],
instructionType: String,
inputExprs: Seq[ExprCode],
sourcePresent: Boolean): String = {
if (instructions.isEmpty) {
""
} else {
val instructionCodes = instructions.map(instruction =>
generateSingleInstructionCode(ctx, instruction, inputExprs, sourcePresent))

s"""
|${instructionCodes.mkString("\n")}
|return;
""".stripMargin
}
}

private def generateSingleInstructionCode(ctx: CodegenContext,
instruction: Instruction,
inputExprs: Seq[ExprCode],
sourcePresent: Boolean): String = {
instruction match {
case Keep(context, condition, outputExprs) =>
val projectionExpr = generateProjectionCode(ctx, outputExprs, inputExprs)
val code = generatePredicateCode(ctx, condition, child.output, inputExprs)

// Generate metric updates based on context
val metricUpdateCode = generateMetricUpdateCode(ctx, context, sourcePresent)

s"""
|${code.code}
|if (${code.value}) {
| $metricUpdateCode
| ${consume(ctx, projectionExpr)}
| return;
|}
""".stripMargin

case Discard(condition) =>
val code = generatePredicateCode(ctx, condition, child.output, inputExprs)
val metricUpdateCode = generateDeleteMetricUpdateCode(ctx, sourcePresent)

s"""
|${code.code}
|if (${code.value}) {
| $metricUpdateCode
| return; // Discar row
|}
""".stripMargin

case Split(condition, outputExprs, otherOutputExprs) =>
val projectionExpr = generateProjectionCode(ctx, outputExprs, inputExprs)
val otherProjectionExpr = generateProjectionCode(ctx, otherOutputExprs, inputExprs)
val code = generatePredicateCode(ctx, condition, child.output, inputExprs)
val metricUpdateCode = generateUpdateMetricUpdateCode(ctx, sourcePresent)

s"""
|${code.code}
|if (${code.value}) {
| $metricUpdateCode
| ${consume(ctx, projectionExpr)}
| ${consume(ctx, otherProjectionExpr)}
| return;
|}
""".stripMargin
case _ =>
// Codegen not implemented
throw new SparkUnsupportedOperationException(
errorClass = "_LEGACY_ERROR_TEMP_3073",
messageParameters = Map("instruction" -> instruction.toString))
}
}

/**
* metric update code based on Keep's context
*/
private def generateMetricUpdateCode(ctx: CodegenContext, context: Context,
sourcePresent: Boolean): String = {
context match {
case Copy =>
val copyMetric = metricTerm(ctx, "numTargetRowsCopied")
s"$copyMetric.add(1);"

case Insert =>
val insertMetric = metricTerm(ctx, "numTargetRowsInserted")
s"$insertMetric.add(1);"

case Update =>
generateUpdateMetricUpdateCode(ctx, sourcePresent)

case Delete =>
generateDeleteMetricUpdateCode(ctx, sourcePresent)

case _ =>
throw new IllegalArgumentException(s"Unexpected context for KeepExec: $context")
}
}

private def generateUpdateMetricUpdateCode(ctx: CodegenContext,
sourcePresent: Boolean): String = {
val updateMetric = metricTerm(ctx, "numTargetRowsUpdated")
if (sourcePresent) {
val matchedUpdateMetric = metricTerm(ctx, "numTargetRowsMatchedUpdated")

s"""
|$updateMetric.add(1);
|$matchedUpdateMetric.add(1);
""".stripMargin
} else {
val notMatchedBySourceUpdateMetric = metricTerm(ctx, "numTargetRowsNotMatchedBySourceUpdated")

s"""
|$updateMetric.add(1);
|$notMatchedBySourceUpdateMetric.add(1);
""".stripMargin
}
}

private def generateDeleteMetricUpdateCode(ctx: CodegenContext,
sourcePresent: Boolean): String = {
val deleteMetric = metricTerm(ctx, "numTargetRowsDeleted")
if (sourcePresent) {
val matchedDeleteMetric = metricTerm(ctx, "numTargetRowsMatchedDeleted")

s"""
|$deleteMetric.add(1);
|$matchedDeleteMetric.add(1);
""".stripMargin
} else {
val notMatchedBySourceDeleteMetric = metricTerm(ctx, "numTargetRowsNotMatchedBySourceDeleted")

s"""
|$deleteMetric.add(1);
|$notMatchedBySourceDeleteMetric.add(1);
""".stripMargin
}
}

/**
* Helper method to save and restore CodegenContext state for code generation.
*
* This is needed because when generating code for expressions, the CodegenContext
* state (currentVars and INPUT_ROW) gets modified during expression evaluation.
* This method temporarily sets the context to the input variables from doConsume
* and restores the original state after the block completes.
*/
private def withCodegenContext[T](
ctx: CodegenContext,
inputCurrentVars: Seq[ExprCode])(block: => T): T = {
val originalCurrentVars = ctx.currentVars
val originalInputRow = ctx.INPUT_ROW
try {
// Set to the input variables saved in doConsume
ctx.currentVars = inputCurrentVars
block
} finally {
// Restore original context
ctx.currentVars = originalCurrentVars
ctx.INPUT_ROW = originalInputRow
}
}

private def generatePredicateCode(ctx: CodegenContext,
predicate: Expression,
inputAttrs: Seq[Attribute],
inputCurrentVars: Seq[ExprCode]): ExprCode = {
withCodegenContext(ctx, inputCurrentVars) {
val boundPredicate = BindReferences.bindReference(predicate, inputAttrs)
val ev = boundPredicate.genCode(ctx)
val predicateVar = ctx.freshName("predicateResult")
val code = code"""
|${ev.code}
|boolean $predicateVar = !${ev.isNull} && ${ev.value};
""".stripMargin
ExprCode(code, FalseLiteral,
JavaCode.variable(predicateVar, BooleanType))
}
}

private def generateProjectionCode(ctx: CodegenContext,
outputExprs: Seq[Expression],
inputCurrentVars: Seq[ExprCode]): Seq[ExprCode] = {
withCodegenContext(ctx, inputCurrentVars) {
val boundExprs = outputExprs.map(BindReferences.bindReference(_, child.output))
boundExprs.map(_.genCode(ctx))
}
}

private def processPartition(rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
val isSourceRowPresentPred = createPredicate(isSourceRowPresent)
val isTargetRowPresentPred = createPredicate(isTargetRowPresent)
Expand Down