set active session for commands
Signed-off-by: Eunjin Song <sezruby@gmail.com>
Co-authored-by: Chungmin Lee <lee@chungmin.dev>
sezruby and Chungmin Lee committed Oct 4, 2023
1 parent d42a22d commit 2f0c669
Showing 6 changed files with 123 additions and 94 deletions.
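
All six files route command execution through a new DeltaTableUtils.withActiveSession helper, so that each Delta command runs with the table's own SparkSession as the thread's active session instead of whichever session happened to be active on the calling thread. A minimal sketch of the scenario this guards against, assuming local sessions and illustrative paths (not from the commit):

    import io.delta.tables.DeltaTable
    import org.apache.spark.sql.SparkSession

    val base = SparkSession.builder().master("local[2]").appName("base").getOrCreate()
    val table = DeltaTable.forPath(base, "/tmp/delta/events") // bound to `base`

    // A sibling session with its own SQL conf and catalog becomes active.
    val other = base.newSession()
    SparkSession.setActiveSession(other)

    // Plan resolution that goes through the active session would otherwise
    // pick up `other`; with this change the command is pinned to `base`.
    table.delete("date < '2021-01-01'")
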
38 changes: 20 additions & 18 deletions core/src/main/scala/io/delta/tables/DeltaMergeBuilder.scala
@@ -19,7 +19,7 @@ package io.delta.tables
 import scala.collection.JavaConverters._
 import scala.collection.Map
 
-import org.apache.spark.sql.delta.{DeltaErrors, PreprocessTableMerge}
+import org.apache.spark.sql.delta.{DeltaErrors, DeltaTableUtils, PreprocessTableMerge}
 import org.apache.spark.sql.delta.metering.DeltaLogging
 import org.apache.spark.sql.delta.util.AnalysisHelper
 
@@ -203,24 +203,26 @@ class DeltaMergeBuilder private(
    */
   def execute(): Unit = improveUnsupportedOpError {
     val sparkSession = targetTable.toDF.sparkSession
-    // Note: We are explicitly resolving DeltaMergeInto plan rather than going to through the
-    // Analyzer using `Dataset.ofRows()` because the Analyzer incorrectly resolves all
-    // references in the DeltaMergeInto using both source and target child plans, even before
-    // DeltaAnalysis rule kicks in. This is because the Analyzer understands only MergeIntoTable,
-    // and handles that separately by skipping resolution (for Delta) and letting the
-    // DeltaAnalysis rule do the resolving correctly. This can be solved by generating
-    // MergeIntoTable instead, which blocked by the different issue with MergeIntoTable as explained
-    // in the function `mergePlan` and https://issues.apache.org/jira/browse/SPARK-34962.
-    val resolvedMergeInto =
-      DeltaMergeInto.resolveReferences(mergePlan, sparkSession.sessionState.conf)(
-        tryResolveReferences(sparkSession) _)
-    if (!resolvedMergeInto.resolved) {
-      throw DeltaErrors.analysisException("Failed to resolve\n", plan = Some(resolvedMergeInto))
+    DeltaTableUtils.withActiveSession(sparkSession) {
+      // Analyzer using `Dataset.ofRows()` because the Analyzer incorrectly resolves all
+      // references in the DeltaMergeInto using both source and target child plans, even before
+      // DeltaAnalysis rule kicks in. This is because the Analyzer understands only MergeIntoTable,
+      // and handles that separately by skipping resolution (for Delta) and letting the
+      // DeltaAnalysis rule do the resolving correctly. This can be solved by generating
+      // MergeIntoTable instead, which blocked by the different issue with MergeIntoTable as
+      // explained in the function `mergePlan` and
+      // https://issues.apache.org/jira/browse/SPARK-34962.
+      val resolvedMergeInto =
+        DeltaMergeInto.resolveReferences(mergePlan, sparkSession.sessionState.conf)(
+          tryResolveReferences(sparkSession) _)
+      if (!resolvedMergeInto.resolved) {
+        throw DeltaErrors.analysisException("Failed to resolve\n", plan = Some(resolvedMergeInto))
+      }
+      // Preprocess the actions and verify
+      val mergeIntoCommand = PreprocessTableMerge(sparkSession.sessionState.conf)(resolvedMergeInto)
+      sparkSession.sessionState.analyzer.checkAnalysis(mergeIntoCommand)
+      mergeIntoCommand.run(sparkSession)
     }
-    // Preprocess the actions and verify
-    val mergeIntoCommand = PreprocessTableMerge(sparkSession.sessionState.conf)(resolvedMergeInto)
-    sparkSession.sessionState.analyzer.checkAnalysis(mergeIntoCommand)
-    mergeIntoCommand.run(sparkSession)
   }
 
   /**
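
For reference, a hedged sketch of the API path this hunk affects, assuming a running `spark` session and illustrative paths and columns:

    val target = DeltaTable.forPath(spark, "/tmp/delta/target")
    val updatesDF = spark.read.format("delta").load("/tmp/delta/updates")

    target.as("t")
      .merge(updatesDF.as("s"), "s.id = t.id")
      .whenMatched().updateAll()
      .whenNotMatched().insertAll()
      .execute() // now resolves and runs under the target table's session
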
4 changes: 3 additions & 1 deletion core/src/main/scala/io/delta/tables/DeltaTable.scala
@@ -807,7 +807,9 @@ object DeltaTable {
    */
   @Evolving
   def createOrReplace(spark: SparkSession): DeltaTableBuilder = {
-    new DeltaTableBuilder(spark, ReplaceTableOptions(orCreate = true))
+    DeltaTableUtils.withActiveSession(spark) {
+      new DeltaTableBuilder(spark, ReplaceTableOptions(orCreate = true))
+    }
   }
 
   /**
114 changes: 58 additions & 56 deletions core/src/main/scala/io/delta/tables/DeltaTableBuilder.scala
@@ -296,70 +296,72 @@ class DeltaTableBuilder private[tables](
    */
   @Evolving
   def execute(): DeltaTable = {
-    if (identifier == null && location.isEmpty) {
-      throw DeltaErrors.analysisException("Table name or location has to be specified")
-    }
+    DeltaTableUtils.withActiveSession(spark) {
+      if (identifier == null && location.isEmpty) {
+        throw DeltaErrors.analysisException("Table name or location has to be specified")
+      }
 
-    if (this.identifier == null) {
-      identifier = s"delta.`${location.get}`"
-    }
+      if (this.identifier == null) {
+        identifier = s"delta.`${location.get}`"
+      }
 
-    // Return DeltaTable Object.
-    val tableId: TableIdentifier = spark.sessionState.sqlParser.parseTableIdentifier(identifier)
+      // Return DeltaTable Object.
+      val tableId: TableIdentifier = spark.sessionState.sqlParser.parseTableIdentifier(identifier)
 
-    if (DeltaTableUtils.isValidPath(tableId) && location.nonEmpty
-      && tableId.table != location.get) {
-      throw DeltaErrors.analysisException(
-        s"Creating path-based Delta table with a different location isn't supported. "
-          + s"Identifier: $identifier, Location: ${location.get}")
-    }
+      if (DeltaTableUtils.isValidPath(tableId) && location.nonEmpty
+        && tableId.table != location.get) {
+        throw DeltaErrors.analysisException(
+          s"Creating path-based Delta table with a different location isn't supported. "
+            + s"Identifier: $identifier, Location: ${location.get}")
+      }
 
-    val table = spark.sessionState.sqlParser.parseMultipartIdentifier(identifier)
+      val table = spark.sessionState.sqlParser.parseMultipartIdentifier(identifier)
 
-    val partitioning = partitioningColumns.map { colNames =>
-      colNames.map(name => DeltaTableUtils.parseColToTransform(name))
-    }.getOrElse(Seq.empty[Transform])
+      val partitioning = partitioningColumns.map { colNames =>
+        colNames.map(name => DeltaTableUtils.parseColToTransform(name))
+      }.getOrElse(Seq.empty[Transform])
 
-    val stmt = builderOption match {
-      case CreateTableOptions(ifNotExists) =>
-        CreateTableStatement(
-          table,
-          StructType(columns),
-          partitioning,
-          None,
-          this.properties,
-          Some(FORMAT_NAME),
-          Map.empty,
-          location,
-          tblComment,
-          None,
-          false,
-          ifNotExists
-        )
-      case ReplaceTableOptions(orCreate) =>
-        ReplaceTableStatement(
-          table,
-          StructType(columns),
-          partitioning,
-          None,
-          this.properties,
-          Some(FORMAT_NAME),
-          Map.empty,
-          location,
-          tblComment,
-          None,
-          orCreate
-        )
-    }
-    val qe = spark.sessionState.executePlan(stmt)
-    // call `QueryExecution.toRDD` to trigger the execution of commands.
-    SQLExecution.withNewExecutionId(qe, Some("create delta table"))(qe.toRdd)
+      val stmt = builderOption match {
+        case CreateTableOptions(ifNotExists) =>
+          CreateTableStatement(
+            table,
+            StructType(columns),
+            partitioning,
+            None,
+            this.properties,
+            Some(FORMAT_NAME),
+            Map.empty,
+            location,
+            tblComment,
+            None,
+            false,
+            ifNotExists
+          )
+        case ReplaceTableOptions(orCreate) =>
+          ReplaceTableStatement(
+            table,
+            StructType(columns),
+            partitioning,
+            None,
+            this.properties,
+            Some(FORMAT_NAME),
+            Map.empty,
+            location,
+            tblComment,
+            None,
+            orCreate
+          )
+      }
+      val qe = spark.sessionState.executePlan(stmt)
+      // call `QueryExecution.toRDD` to trigger the execution of commands.
+      SQLExecution.withNewExecutionId(qe, Some("create delta table"))(qe.toRdd)
 
-    // Return DeltaTable Object.
-    if (DeltaTableUtils.isValidPath(tableId)) {
+      // Return DeltaTable Object.
+      if (DeltaTableUtils.isValidPath(tableId)) {
       DeltaTable.forPath(location.get)
-    } else {
-      DeltaTable.forName(this.identifier)
+      } else {
+        DeltaTable.forName(this.identifier)
+      }
     }
   }
 }
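
A short usage sketch of the builder wrapped above, assuming a running `spark` session and illustrative table and column names; both the name-based and the path-based form end in this execute():

    // Name-based: goes through ReplaceTableStatement with orCreate = true.
    DeltaTable.createOrReplace(spark)
      .tableName("events")
      .addColumn("id", "LONG")
      .addColumn("date", "DATE")
      .execute()

    // Path-based: the identifier becomes delta.`/tmp/delta/events`.
    DeltaTable.createOrReplace(spark)
      .location("/tmp/delta/events")
      .addColumn("id", "LONG")
      .execute()
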
15 changes: 9 additions & 6 deletions core/src/main/scala/io/delta/tables/execution/DeltaConvert.scala
@@ -16,6 +16,7 @@
 
 package io.delta.tables.execution
 
+import org.apache.spark.sql.delta.DeltaTableUtils
 import org.apache.spark.sql.delta.commands.ConvertToDeltaCommand
 import io.delta.tables.DeltaTable
 
@@ -29,12 +30,14 @@ trait DeltaConvertBase {
       tableIdentifier: TableIdentifier,
       partitionSchema: Option[StructType],
       deltaPath: Option[String]): DeltaTable = {
-    val cvt = ConvertToDeltaCommand(tableIdentifier, partitionSchema, deltaPath)
-    cvt.run(spark)
-    if (cvt.isCatalogTable(spark.sessionState.analyzer, tableIdentifier)) {
-      DeltaTable.forName(spark, tableIdentifier.toString)
-    } else {
-      DeltaTable.forPath(spark, tableIdentifier.table)
+    DeltaTableUtils.withActiveSession(spark) {
+      val cvt = ConvertToDeltaCommand(tableIdentifier, partitionSchema, deltaPath)
+      cvt.run(spark)
+      if (cvt.isCatalogTable(spark.sessionState.analyzer, tableIdentifier)) {
+        DeltaTable.forName(spark, tableIdentifier.toString)
+      } else {
+        DeltaTable.forPath(spark, tableIdentifier.table)
+      }
     }
   }
 }
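
The public entry point over this trait is DeltaTable.convertToDelta; a minimal sketch, assuming a running `spark` session and an illustrative Parquet path:

    // Path-based Parquet data; the returned DeltaTable is path-based too.
    DeltaTable.convertToDelta(spark, "parquet.`/tmp/parquet/events`")

    // For partitioned data, the partition schema is spelled out explicitly.
    DeltaTable.convertToDelta(spark, "parquet.`/tmp/parquet/events`", "date DATE")
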
36 changes: 23 additions & 13 deletions core/src/main/scala/io/delta/tables/execution/DeltaTableOperations.scala
@@ -18,7 +18,7 @@ package io.delta.tables.execution
 
 import scala.collection.Map
 
-import org.apache.spark.sql.delta.{DeltaErrors, DeltaHistoryManager, DeltaLog, PreprocessTableUpdate}
+import org.apache.spark.sql.delta.{DeltaErrors, DeltaHistoryManager, DeltaLog, DeltaTableUtils, PreprocessTableUpdate}
 import org.apache.spark.sql.delta.commands.{DeleteCommand, DeltaGenerateCommand, VacuumCommand}
 import org.apache.spark.sql.delta.util.AnalysisHelper
 import io.delta.tables.DeltaTable
@@ -28,24 +28,29 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.SparkSession
 
 /**
  * Interface to provide the actual implementations of DeltaTable operations.
  */
 trait DeltaTableOperations extends AnalysisHelper { self: DeltaTable =>
 
   protected def executeDelete(condition: Option[Expression]): Unit = improveUnsupportedOpError {
-    val delete = DeleteFromTable(self.toDF.queryExecution.analyzed, condition)
-    toDataset(sparkSession, delete)
+    DeltaTableUtils.withActiveSession(sparkSession) {
+      val delete = DeleteFromTable(self.toDF.queryExecution.analyzed, condition)
+      toDataset(sparkSession, delete)
+    }
   }
 
   protected def executeHistory(
       deltaLog: DeltaLog,
       limit: Option[Int] = None,
       tableId: Option[TableIdentifier] = None): DataFrame = {
-    val history = deltaLog.history
-    val spark = self.toDF.sparkSession
-    spark.createDataFrame(history.getHistory(limit))
+    DeltaTableUtils.withActiveSession(sparkSession) {
+      val history = deltaLog.history
+      val spark = self.toDF.sparkSession
+      spark.createDataFrame(history.getHistory(limit))
+    }
   }
 
   protected def executeGenerate(tblIdentifier: String, mode: String): Unit = {
@@ -60,19 +65,24 @@ trait DeltaTableOperations extends AnalysisHelper { self: DeltaTable =>
   protected def executeUpdate(
       set: Map[String, Column],
       condition: Option[Column]): Unit = improveUnsupportedOpError {
-    val assignments = set.map { case (targetColName, column) =>
-      Assignment(UnresolvedAttribute.quotedString(targetColName), column.expr)
-    }.toSeq
-    val update = UpdateTable(self.toDF.queryExecution.analyzed, assignments, condition.map(_.expr))
-    toDataset(sparkSession, update)
+    DeltaTableUtils.withActiveSession(sparkSession) {
+      val assignments = set.map { case (targetColName, column) =>
+        Assignment(UnresolvedAttribute.quotedString(targetColName), column.expr)
+      }.toSeq
+      val update =
+        UpdateTable(self.toDF.queryExecution.analyzed, assignments, condition.map(_.expr))
+      toDataset(sparkSession, update)
+    }
   }
 
   protected def executeVacuum(
       deltaLog: DeltaLog,
       retentionHours: Option[Double],
       tableId: Option[TableIdentifier] = None): DataFrame = {
-    VacuumCommand.gc(sparkSession, deltaLog, false, retentionHours)
-    sparkSession.emptyDataFrame
+    DeltaTableUtils.withActiveSession(sparkSession) {
+      VacuumCommand.gc(sparkSession, deltaLog, false, retentionHours)
+      sparkSession.emptyDataFrame
+    }
  }
 
   protected def toStrColumnMap(map: Map[String, String]): Map[String, Column] = {
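
These protected methods back the public DeltaTable operations; a hedged sketch of the corresponding calls, assuming a running `spark` session and illustrative names:

    val table = DeltaTable.forPath(spark, "/tmp/delta/events")

    table.delete("date < '2021-01-01'")                // executeDelete
    table.updateExpr("id = 42", Map("flag" -> "true")) // executeUpdate
    val history = table.history(10)                    // executeHistory
    table.vacuum(168)                                  // executeVacuum
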
10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/sql/delta/DeltaTable.scala
@@ -356,4 +356,14 @@ object DeltaTableUtils extends PredicateHelper
   def parseColToTransform(col: String): IdentityTransform = {
     IdentityTransform(FieldReference(Seq(col)))
   }
+
+  def withActiveSession[T](spark: SparkSession)(body: => T): T = {
+    val old = SparkSession.getActiveSession
+    SparkSession.setActiveSession(spark)
+    try {
+      body
+    } finally {
+      SparkSession.setActiveSession(old.getOrElse(null))
+    }
+  }
 }
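
The helper saves the thread's current active session, swaps in the table's session for the duration of body, and restores the previous one even when body throws. A minimal sketch of that contract, with illustrative session names:

    val before = SparkSession.getActiveSession
    try {
      DeltaTableUtils.withActiveSession(other) {
        assert(SparkSession.getActiveSession.contains(other))
        sys.error("boom") // the finally block still restores `before`
      }
    } catch { case _: RuntimeException => }
    assert(SparkSession.getActiveSession == before)

Restoring old.getOrElse(null) passes null back when no session was active before, which effectively clears the thread-local again, similar to SparkSession.clearActiveSession().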
