Skip to content

Commit dada2b9

Browse files
committed
Review comments
1 parent 4ac5b0a commit dada2b9

File tree

2 files changed

+12
-86
lines changed

2 files changed

+12
-86
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,12 +1672,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
16721672
case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _)
16731673
if !m.resolved && targetTable.resolved && sourceTable.resolved =>
16741674

1675-
// Do not throw exception for schema evolution case if it has not had a chance to run.
1675+
// Do not throw an exception in the schema evolution case.
16761676
// This allows unresolved assignment keys a chance to be resolved by a second pass
16771677
// by the new columns/nested fields added by schema evolution.
1678-
// If schema evolution has already had a chance to run, this will be the final pass
1679-
val throws = !m.schemaEvolutionEnabled ||
1680-
(m.canEvaluateSchemaEvolution && !m.schemaChangesNonEmpty)
1678+
val throws = !m.schemaEvolutionEnabled
16811679

16821680
EliminateSubqueryAliases(targetTable) match {
16831681
case r: NamedRelation if r.skipSchemaResolution =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala

Lines changed: 10 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.sql.catalyst.analysis
1919

20-
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, GetStructField}
2120
import org.apache.spark.sql.catalyst.plans.logical._
2221
import org.apache.spark.sql.catalyst.rules.Rule
2322
import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -37,88 +36,17 @@ object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] {
3736
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
3837
// This rule should run only if all assignments are resolved, except those
3938
// that will be satisfied by schema evolution
40-
case m @ MergeIntoTable(_, _, _, _, _, _, _) if m.needSchemaEvolution =>
41-
val newTarget = m.targetTable.transform {
42-
case r : DataSourceV2Relation => performSchemaEvolution(r, m)
43-
}
44-
45-
// Unresolve all references based on old target output
46-
val targetOutput = m.targetTable.output
47-
val unresolvedMergeCondition = unresolveCondition(m.mergeCondition, targetOutput)
48-
val unresolvedMatchedActions = unresolveActions(m.matchedActions, targetOutput)
49-
val unresolvedNotMatchedActions = unresolveActions(m.notMatchedActions, targetOutput)
50-
val unresolvedNotMatchedBySourceActions =
51-
unresolveActions(m.notMatchedBySourceActions, targetOutput)
52-
53-
m.copy(
54-
targetTable = newTarget,
55-
mergeCondition = unresolvedMergeCondition,
56-
matchedActions = unresolvedMatchedActions,
57-
notMatchedActions = unresolvedNotMatchedActions,
58-
notMatchedBySourceActions = unresolvedNotMatchedBySourceActions)
59-
}
60-
61-
private def unresolveActions(actions: Seq[MergeAction], output: Seq[Attribute]):
62-
Seq[MergeAction] = {
63-
actions.map {
64-
case UpdateAction(condition, assignments) =>
65-
UpdateAction(condition.map(unresolveCondition(_, output)),
66-
unresolveAssignmentKeys(assignments))
67-
case InsertAction(condition, assignments) =>
68-
InsertAction(condition.map(unresolveCondition(_, output)),
69-
unresolveAssignmentKeys(assignments))
70-
case DeleteAction(condition) =>
71-
DeleteAction(condition.map(unresolveCondition(_, output)))
72-
case other => other
73-
}
74-
}
75-
76-
private def unresolveCondition(expr: Expression, output: Seq[Attribute]): Expression = {
77-
val outputSet = AttributeSet(output)
78-
expr.transform {
79-
case attr: AttributeReference if outputSet.contains(attr) =>
80-
val nameParts = if (attr.qualifier.nonEmpty) {
81-
attr.qualifier ++ Seq(attr.name)
82-
} else {
83-
Seq(attr.name)
84-
}
85-
UnresolvedAttribute(nameParts)
86-
}
87-
}
88-
89-
private def unresolveAssignmentKeys(assignments: Seq[Assignment]): Seq[Assignment] = {
90-
assignments.map { assignment =>
91-
val unresolvedKey = assignment.key match {
92-
case _: UnresolvedAttribute => assignment.key
93-
case gsf: GetStructField =>
94-
// Recursively collect all nested GetStructField names and the base AttributeReference
95-
val nameParts = collectStructFieldNames(gsf)
96-
nameParts match {
97-
case Some(names) => UnresolvedAttribute(names)
98-
case None => assignment.key
99-
}
100-
case attr: AttributeReference =>
101-
UnresolvedAttribute(Seq(attr.name))
102-
case attr: Attribute =>
103-
UnresolvedAttribute(Seq(attr.name))
104-
case other => other
39+
case m@MergeIntoTable(_, _, _, _, _, _, _) if m.needSchemaEvolution =>
40+
m transformUpWithNewOutput {
41+
case r: DataSourceV2Relation =>
42+
val newTarget = performSchemaEvolution(r, m)
43+
val oldTargetOutput = m.targetTable.output
44+
val newTargetOutput = newTarget.output
45+
val attributeMapping = oldTargetOutput.map(
46+
oldAttr => (oldAttr, newTargetOutput.find(_.name == oldAttr.name).getOrElse(oldAttr))
47+
)
48+
newTarget -> attributeMapping
10549
}
106-
Assignment(unresolvedKey, assignment.value)
107-
}
108-
}
109-
110-
private def collectStructFieldNames(expr: Expression): Option[Seq[String]] = {
111-
expr match {
112-
case GetStructField(child, _, Some(fieldName)) =>
113-
collectStructFieldNames(child) match {
114-
case Some(childNames) => Some(childNames :+ fieldName)
115-
case None => None
116-
}
117-
case attr: AttributeReference =>
118-
Some(Seq(attr.name))
119-
case _ =>
120-
None
121-
}
12250
}
12351

12452
private def performSchemaEvolution(relation: DataSourceV2Relation, m: MergeIntoTable)

0 commit comments

Comments
 (0)