[Spark][2.4] Fix a data loss bug in MergeIntoCommand #2157

Open · wants to merge 3 commits into base: branch-2.4
Changes from all commits
44 changes: 24 additions & 20 deletions core/src/main/scala/io/delta/tables/DeltaMergeBuilder.scala
@@ -20,6 +20,7 @@ import scala.collection.JavaConverters._
import scala.collection.Map

import org.apache.spark.sql.delta.{DeltaErrors, PreprocessTableMerge}
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.DeltaViewHelper
import org.apache.spark.sql.delta.commands.MergeIntoCommand
import org.apache.spark.sql.delta.util.AnalysisHelper
@@ -265,29 +266,32 @@ class DeltaMergeBuilder private(
*/
def execute(): Unit = improveUnsupportedOpError {
val sparkSession = targetTable.toDF.sparkSession
// Note: We are explicitly resolving DeltaMergeInto plan rather than going through the
// Analyzer using `Dataset.ofRows()` because the Analyzer incorrectly resolves all
// references in the DeltaMergeInto using both source and target child plans, even before
// DeltaAnalysis rule kicks in. This is because the Analyzer understands only MergeIntoTable,
// and handles that separately by skipping resolution (for Delta) and letting the
// DeltaAnalysis rule do the resolving correctly. This can be solved by generating
// MergeIntoTable instead, which is blocked by a different issue with MergeIntoTable as explained
// in the function `mergePlan` and https://issues.apache.org/jira/browse/SPARK-34962.
val resolvedMergeInto =
withActiveSession(sparkSession) {
// Note: We are explicitly resolving DeltaMergeInto plan rather than going through the
// Analyzer using `Dataset.ofRows()` because the Analyzer incorrectly resolves all
// references in the DeltaMergeInto using both source and target child plans, even before
// DeltaAnalysis rule kicks in. This is because the Analyzer understands only MergeIntoTable,
// and handles that separately by skipping resolution (for Delta) and letting the
// DeltaAnalysis rule do the resolving correctly. This can be solved by generating
// MergeIntoTable instead, which is blocked by a different issue with MergeIntoTable as
// explained in the function `mergePlan` and
// https://issues.apache.org/jira/browse/SPARK-34962.
val resolvedMergeInto =
DeltaMergeInto.resolveReferencesAndSchema(mergePlan, sparkSession.sessionState.conf)(
tryResolveReferences(sparkSession) _)
if (!resolvedMergeInto.resolved) {
throw DeltaErrors.analysisException("Failed to resolve\n", plan = Some(resolvedMergeInto))
if (!resolvedMergeInto.resolved) {
throw DeltaErrors.analysisException("Failed to resolve\n", plan = Some(resolvedMergeInto))
}
val strippedMergeInto = resolvedMergeInto.copy(
target = DeltaViewHelper.stripTempViewForMerge(resolvedMergeInto.target, SQLConf.get)
)
// Preprocess the actions and verify
val mergeIntoCommand =
PreprocessTableMerge(sparkSession.sessionState.conf)(strippedMergeInto)
.asInstanceOf[MergeIntoCommand]
sparkSession.sessionState.analyzer.checkAnalysis(mergeIntoCommand)
mergeIntoCommand.run(sparkSession)
}
val strippedMergeInto = resolvedMergeInto.copy(
target = DeltaViewHelper.stripTempViewForMerge(resolvedMergeInto.target, SQLConf.get)
)
// Preprocess the actions and verify
val mergeIntoCommand =
PreprocessTableMerge(sparkSession.sessionState.conf)(strippedMergeInto)
.asInstanceOf[MergeIntoCommand]
sparkSession.sessionState.analyzer.checkAnalysis(mergeIntoCommand)
mergeIntoCommand.run(sparkSession)
}

/**
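For context on why pinning the active session fixes a data-loss bug (a hedged reading, not text from this PR): parts of merge planning read configuration through `SQLConf.get`, which resolves against whichever `SparkSession` is active on the executing thread, not necessarily the session the `DeltaTable` was created from. A minimal sketch of that behavior, assuming an existing session `spark` whose `spark.sql.shuffle.partitions` is still at its default:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

// Sketch (assumed behavior): SQLConf.get follows the thread's *active* session,
// so two sessions on the same thread can observe different values for one key.
val spark2 = spark.newSession()
spark2.conf.set("spark.sql.shuffle.partitions", "7")

SparkSession.setActiveSession(spark)
assert(SQLConf.get.numShufflePartitions != 7)   // reads spark's value

SparkSession.setActiveSession(spark2)
assert(SQLConf.get.numShufflePartitions == 7)   // same thread, now reads spark2's value
```

Wrapping `execute()` in `withActiveSession(sparkSession)` makes the merge resolve and run against the session the builder was constructed from, which is what the new tests below exercise.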
20 changes: 11 additions & 9 deletions core/src/main/scala/io/delta/tables/DeltaOptimizeBuilder.scala
@@ -17,6 +17,7 @@
package io.delta.tables

// scalastyle:off import.ordering.noEmptyLine
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.commands.OptimizeTableCommand
import org.apache.spark.sql.delta.util.AnalysisHelper

@@ -75,15 +76,16 @@ class DeltaOptimizeBuilder private(
execute(attrs)
}

private def execute(zOrderBy: Seq[UnresolvedAttribute]): DataFrame = {
val tableId: TableIdentifier = sparkSession
.sessionState
.sqlParser
.parseTableIdentifier(tableIdentifier)
val optimize =
OptimizeTableCommand(None, Some(tableId), partitionFilter, options)(zOrderBy = zOrderBy)
toDataset(sparkSession, optimize)
}
private def execute(zOrderBy: Seq[UnresolvedAttribute]): DataFrame =
withActiveSession(sparkSession) {
val tableId: TableIdentifier = sparkSession
.sessionState
.sqlParser
.parseTableIdentifier(tableIdentifier)
val optimize =
OptimizeTableCommand(None, Some(tableId), partitionFilter, options)(zOrderBy = zOrderBy)
toDataset(sparkSession, optimize)
}
}

private[delta] object DeltaOptimizeBuilder {
3 changes: 2 additions & 1 deletion core/src/main/scala/io/delta/tables/DeltaTable.scala
@@ -19,6 +19,7 @@ package io.delta.tables
import scala.collection.JavaConverters._

import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.actions.{Protocol, TableFeatureProtocolUtils}
import org.apache.spark.sql.delta.catalog.DeltaTableV2
import org.apache.spark.sql.delta.commands.AlterTableSetPropertiesDeltaCommand
@@ -976,7 +977,7 @@ object DeltaTable {
* @since 1.0.0
*/
@Evolving
def createOrReplace(spark: SparkSession): DeltaTableBuilder = {
def createOrReplace(spark: SparkSession): DeltaTableBuilder = withActiveSession(spark) {
new DeltaTableBuilder(spark, ReplaceTableOptions(orCreate = true))
}

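As a usage-level illustration (a hedged sketch that is not part of this diff; the table name and columns are made up), the builder entry points now execute with the session they were handed, even if a different session is active on the calling thread:

```scala
// spark2 is an isolated session; createOrReplace(spark2) now pins it as the
// active session while the builder runs, so spark2's catalog and confs apply.
val spark2 = spark.newSession()
io.delta.tables.DeltaTable.createOrReplace(spark2)
  .tableName("events")           // hypothetical table name
  .addColumn("id", "LONG")
  .addColumn("ts", "TIMESTAMP")
  .execute()
```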
3 changes: 2 additions & 1 deletion core/src/main/scala/io/delta/tables/DeltaTableBuilder.scala
@@ -19,6 +19,7 @@ package io.delta.tables
import scala.collection.mutable

import org.apache.spark.sql.delta.{DeltaErrors, DeltaTableUtils}
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import io.delta.tables.execution._

@@ -302,7 +303,7 @@ class DeltaTableBuilder private[tables](
* @since 1.0.0
*/
@Evolving
def execute(): DeltaTable = {
def execute(): DeltaTable = withActiveSession(spark) {
if (identifier == null && location.isEmpty) {
throw DeltaErrors.analysisException("Table name or location has to be specified")
}
@@ -16,6 +16,7 @@

package io.delta.tables.execution

import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.commands.ConvertToDeltaCommand
import io.delta.tables.DeltaTable

@@ -28,7 +29,7 @@ trait DeltaConvertBase {
spark: SparkSession,
tableIdentifier: TableIdentifier,
partitionSchema: Option[StructType],
deltaPath: Option[String]): DeltaTable = {
deltaPath: Option[String]): DeltaTable = withActiveSession(spark) {
val cvt = ConvertToDeltaCommand(tableIdentifier, partitionSchema, collectStats = true,
deltaPath)
cvt.run(spark)
@@ -20,6 +20,7 @@ import scala.collection.Map

import org.apache.spark.sql.catalyst.TimeTravel
import org.apache.spark.sql.delta.DeltaLog
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.catalog.DeltaTableV2
import org.apache.spark.sql.delta.commands.{DeltaGenerateCommand, DescribeDeltaDetailCommand, VacuumCommand}
import org.apache.spark.sql.delta.util.AnalysisHelper
@@ -39,59 +40,64 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
trait DeltaTableOperations extends AnalysisHelper { self: DeltaTable =>

protected def executeDelete(condition: Option[Expression]): Unit = improveUnsupportedOpError {
val delete = DeleteFromTable(
self.toDF.queryExecution.analyzed,
condition.getOrElse(Literal.TrueLiteral))
toDataset(sparkSession, delete)
withActiveSession(sparkSession) {
val delete = DeleteFromTable(
self.toDF.queryExecution.analyzed,
condition.getOrElse(Literal.TrueLiteral))
toDataset(sparkSession, delete)
}
}

protected def executeHistory(
deltaLog: DeltaLog,
limit: Option[Int] = None,
tableId: Option[TableIdentifier] = None): DataFrame = {
tableId: Option[TableIdentifier] = None): DataFrame = withActiveSession(sparkSession) {
val history = deltaLog.history
val spark = self.toDF.sparkSession
spark.createDataFrame(history.getHistory(limit))
sparkSession.createDataFrame(history.getHistory(limit))
}

protected def executeDetails(
path: String,
tableIdentifier: Option[TableIdentifier]): DataFrame = {
tableIdentifier: Option[TableIdentifier]): DataFrame = withActiveSession(sparkSession) {
val details = DescribeDeltaDetailCommand(Option(path), tableIdentifier, self.deltaLog.options)
toDataset(sparkSession, details)
}

protected def executeGenerate(tblIdentifier: String, mode: String): Unit = {
val tableId: TableIdentifier = sparkSession
.sessionState
.sqlParser
.parseTableIdentifier(tblIdentifier)
val generate = DeltaGenerateCommand(mode, tableId, self.deltaLog.options)
toDataset(sparkSession, generate)
}
protected def executeGenerate(tblIdentifier: String, mode: String): Unit =
withActiveSession(sparkSession) {
val tableId: TableIdentifier = sparkSession
.sessionState
.sqlParser
.parseTableIdentifier(tblIdentifier)
val generate = DeltaGenerateCommand(mode, tableId, self.deltaLog.options)
toDataset(sparkSession, generate)
}

protected def executeUpdate(
set: Map[String, Column],
condition: Option[Column]): Unit = improveUnsupportedOpError {
val assignments = set.map { case (targetColName, column) =>
Assignment(UnresolvedAttribute.quotedString(targetColName), column.expr)
}.toSeq
val update = UpdateTable(self.toDF.queryExecution.analyzed, assignments, condition.map(_.expr))
toDataset(sparkSession, update)
withActiveSession(sparkSession) {
val assignments = set.map { case (targetColName, column) =>
Assignment(UnresolvedAttribute.quotedString(targetColName), column.expr)
}.toSeq
val update =
UpdateTable(self.toDF.queryExecution.analyzed, assignments, condition.map(_.expr))
toDataset(sparkSession, update)
}
}

protected def executeVacuum(
deltaLog: DeltaLog,
retentionHours: Option[Double],
tableId: Option[TableIdentifier] = None): DataFrame = {
tableId: Option[TableIdentifier] = None): DataFrame = withActiveSession(sparkSession) {
VacuumCommand.gc(sparkSession, deltaLog, false, retentionHours)
sparkSession.emptyDataFrame
}

protected def executeRestore(
table: DeltaTableV2,
versionAsOf: Option[Long],
timestampAsOf: Option[String]): DataFrame = {
timestampAsOf: Option[String]): DataFrame = withActiveSession(sparkSession) {
val identifier = table.getTableIdentifierIfExists.map(
id => Identifier.of(id.database.toArray, id.table))
val sourceRelation = DataSourceV2Relation.create(table, None, identifier)
@@ -387,6 +387,9 @@ object DeltaTableUtils extends PredicateHelper
}
}

// Workaround for withActive not being visible in io/delta.
def withActiveSession[T](spark: SparkSession)(body: => T): T = spark.withActive(body)

def parseColToTransform(col: String): IdentityTransform = {
IdentityTransform(FieldReference(Seq(col)))
}
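The helper added above forwards to `SparkSession.withActive`, which (per the comment in the diff) is not visible from the `io.delta` package. Roughly, it behaves like the following sketch built on the public session APIs; this is an approximation for illustration, not Spark's actual implementation:

```scala
import org.apache.spark.sql.SparkSession

// Make `spark` the thread's active session while `body` runs, then restore
// whatever was active before (or clear it if nothing was).
def withActiveSessionSketch[T](spark: SparkSession)(body: => T): T = {
  val previous = SparkSession.getActiveSession
  SparkSession.setActiveSession(spark)
  try body
  finally previous match {
    case Some(s) => SparkSession.setActiveSession(s)
    case None    => SparkSession.clearActiveSession()
  }
}
```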
@@ -5102,6 +5102,35 @@ abstract class MergeIntoSuiteBase
Option("Aggregate functions are not supported in the .* condition of MERGE operation.*")
)

test("Merge should use the same SparkSession consistently") {
withTempDir { dir =>
withSQLConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE.key -> "false") {
val r = dir.getCanonicalPath
val sourcePath = s"$r/source"
val targetPath = s"$r/target"
val numSourceRecords = 20
spark.range(numSourceRecords)
.withColumn("x", $"id")
.withColumn("y", $"id")
.write.mode("overwrite").format("delta").save(sourcePath)
spark.range(1)
.withColumn("x", $"id")
.write.mode("overwrite").format("delta").save(targetPath)
val spark2 = spark.newSession
spark2.conf.set(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE.key, "true")
val target = io.delta.tables.DeltaTable.forPath(spark2, targetPath)
val source = spark.read.format("delta").load(sourcePath).alias("s")
val merge = target.alias("t")
.merge(source, "t.id = s.id")
.whenMatched.updateExpr(Map("t.x" -> "t.x + 1"))
.whenNotMatched.insertAll()
.execute()
// The target table should have the same number of rows as the source after the merge
assert(spark.read.format("delta").load(targetPath).count() == numSourceRecords)
}
}
}

testWithTempView("test merge on temp view - basic") { isSQLTempView =>
withTable("tab") {
withTempView("src") {
30 changes: 30 additions & 0 deletions python/delta/tests/test_deltatable.py
@@ -19,6 +19,7 @@
import unittest
import os
from typing import List, Set, Dict, Optional, Any, Callable, Union, Tuple
from multiprocessing.pool import ThreadPool

from pyspark.sql import DataFrame, Row
from pyspark.sql.column import _to_seq # type: ignore[attr-defined]
@@ -469,6 +470,35 @@ def reset_table() -> None:
with self.assertRaisesRegex(TypeError, "must be a Spark SQL Column or a string"):
dt.merge(source, "key = k").whenNotMatchedBySourceDelete(1) # type: ignore[arg-type]

def test_merge_with_inconsistent_sessions(self) -> None:
source_path = os.path.join(self.tempFile, "source")
target_path = os.path.join(self.tempFile, "target")
spark = self.spark

def f(spark):
spark.range(20) \
.withColumn("x", col("id")) \
.withColumn("y", col("id")) \
.write.mode("overwrite").format("delta").save(source_path)
spark.range(1) \
.withColumn("x", col("id")) \
.write.mode("overwrite").format("delta").save(target_path)
target = DeltaTable.forPath(spark, target_path)
source = spark.read.format("delta").load(source_path).alias("s")
target.alias("t") \
.merge(source, "t.id = s.id") \
.whenMatchedUpdate(set={"t.x": "t.x + 1"}) \
.whenNotMatchedInsertAll() \
.execute()
assert(spark.read.format("delta").load(target_path).count() == 20)

pool = ThreadPool(3)
spark.conf.set("spark.databricks.delta.schema.autoMerge.enabled", "true")
try:
pool.starmap(f, [(spark,)])
finally:
spark.conf.unset("spark.databricks.delta.schema.autoMerge.enabled")

def test_history(self) -> None:
self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3)])
self.__overwriteDeltaTable([('a', 3), ('b', 2), ('c', 1)])