1717
1818package org .apache .spark .sql .catalyst .analysis
1919
20- import org .apache .spark .sql .catalyst .expressions .{Attribute , AttributeReference , AttributeSet , Expression , GetStructField }
2120import org .apache .spark .sql .catalyst .plans .logical ._
2221import org .apache .spark .sql .catalyst .rules .Rule
2322import org .apache .spark .sql .catalyst .types .DataTypeUtils
@@ -37,88 +36,17 @@ object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] {
3736 override def apply (plan : LogicalPlan ): LogicalPlan = plan resolveOperators {
3837 // This rule should run only if all assignments are resolved, except those
3938 // that will be satisfied by schema evolution
40- case m @ MergeIntoTable (_, _, _, _, _, _, _) if m.needSchemaEvolution =>
41- val newTarget = m.targetTable.transform {
42- case r : DataSourceV2Relation => performSchemaEvolution(r, m)
43- }
44-
45- // Unresolve all references based on old target output
46- val targetOutput = m.targetTable.output
47- val unresolvedMergeCondition = unresolveCondition(m.mergeCondition, targetOutput)
48- val unresolvedMatchedActions = unresolveActions(m.matchedActions, targetOutput)
49- val unresolvedNotMatchedActions = unresolveActions(m.notMatchedActions, targetOutput)
50- val unresolvedNotMatchedBySourceActions =
51- unresolveActions(m.notMatchedBySourceActions, targetOutput)
52-
53- m.copy(
54- targetTable = newTarget,
55- mergeCondition = unresolvedMergeCondition,
56- matchedActions = unresolvedMatchedActions,
57- notMatchedActions = unresolvedNotMatchedActions,
58- notMatchedBySourceActions = unresolvedNotMatchedBySourceActions)
59- }
60-
61- private def unresolveActions (actions : Seq [MergeAction ], output : Seq [Attribute ]):
62- Seq [MergeAction ] = {
63- actions.map {
64- case UpdateAction (condition, assignments) =>
65- UpdateAction (condition.map(unresolveCondition(_, output)),
66- unresolveAssignmentKeys(assignments))
67- case InsertAction (condition, assignments) =>
68- InsertAction (condition.map(unresolveCondition(_, output)),
69- unresolveAssignmentKeys(assignments))
70- case DeleteAction (condition) =>
71- DeleteAction (condition.map(unresolveCondition(_, output)))
72- case other => other
73- }
74- }
75-
76- private def unresolveCondition (expr : Expression , output : Seq [Attribute ]): Expression = {
77- val outputSet = AttributeSet (output)
78- expr.transform {
79- case attr : AttributeReference if outputSet.contains(attr) =>
80- val nameParts = if (attr.qualifier.nonEmpty) {
81- attr.qualifier ++ Seq (attr.name)
82- } else {
83- Seq (attr.name)
84- }
85- UnresolvedAttribute (nameParts)
86- }
87- }
88-
89- private def unresolveAssignmentKeys (assignments : Seq [Assignment ]): Seq [Assignment ] = {
90- assignments.map { assignment =>
91- val unresolvedKey = assignment.key match {
92- case _ : UnresolvedAttribute => assignment.key
93- case gsf : GetStructField =>
94- // Recursively collect all nested GetStructField names and the base AttributeReference
95- val nameParts = collectStructFieldNames(gsf)
96- nameParts match {
97- case Some (names) => UnresolvedAttribute (names)
98- case None => assignment.key
99- }
100- case attr : AttributeReference =>
101- UnresolvedAttribute (Seq (attr.name))
102- case attr : Attribute =>
103- UnresolvedAttribute (Seq (attr.name))
104- case other => other
39+ case m@ MergeIntoTable (_, _, _, _, _, _, _) if m.needSchemaEvolution =>
40+ m transformUpWithNewOutput {
41+ case r : DataSourceV2Relation =>
42+ val newTarget = performSchemaEvolution(r, m)
43+ val oldTargetOutput = m.targetTable.output
44+ val newTargetOutput = newTarget.output
45+ val attributeMapping = oldTargetOutput.map(
46+ oldAttr => (oldAttr, newTargetOutput.find(_.name == oldAttr.name).getOrElse(oldAttr))
47+ )
48+ newTarget -> attributeMapping
10549 }
106- Assignment (unresolvedKey, assignment.value)
107- }
108- }
109-
110- private def collectStructFieldNames (expr : Expression ): Option [Seq [String ]] = {
111- expr match {
112- case GetStructField (child, _, Some (fieldName)) =>
113- collectStructFieldNames(child) match {
114- case Some (childNames) => Some (childNames :+ fieldName)
115- case None => None
116- }
117- case attr : AttributeReference =>
118- Some (Seq (attr.name))
119- case _ =>
120- None
121- }
12250 }
12351
12452 private def performSchemaEvolution (relation : DataSourceV2Relation , m : MergeIntoTable )
0 commit comments