From f1ea7ef94c43be8f41680929dfb1d201655543ad Mon Sep 17 00:00:00 2001 From: Thang Long Vu <107926660+longvu-db@users.noreply.github.com> Date: Wed, 21 Aug 2024 22:46:59 +0200 Subject: [PATCH 1/6] [Spark] Add Delta Connect Merge Server and Scala Client (#3580) ## Description Add support for `merge` for Delta Connect Server and Scala Client. ## How was this patch tested? Added UTs. ## Does this PR introduce _any_ user-facing changes? No. --- python/delta/connect/proto/relations_pb2.py | 54 ++- python/delta/connect/proto/relations_pb2.pyi | 267 ++++++++++++ .../io/delta/connect/tables/DeltaTable.scala | 110 +++++ .../tables/DeltaMergeBuilderSuite.scala | 411 ++++++++++++++++++ .../protobuf/delta/connect/relations.proto | 63 +++ .../delta/connect/DeltaRelationPlugin.scala | 97 ++++- .../connect/DeltaConnectPlannerSuite.scala | 153 +++++++ 7 files changed, 1133 insertions(+), 22 deletions(-) create mode 100644 spark-connect/client/src/test/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilderSuite.scala diff --git a/python/delta/connect/proto/relations_pb2.py b/python/delta/connect/proto/relations_pb2.py index 2407cff662e..11d6935361c 100644 --- a/python/delta/connect/proto/relations_pb2.py +++ b/python/delta/connect/proto/relations_pb2.py @@ -34,7 +34,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1d\x64\x65lta/connect/relations.proto\x12\rdelta.connect\x1a\x18\x64\x65lta/connect/base.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xc5\x04\n\rDeltaRelation\x12)\n\x04scan\x18\x01 \x01(\x0b\x32\x13.delta.connect.ScanH\x00R\x04scan\x12K\n\x10\x64\x65scribe_history\x18\x02 \x01(\x0b\x32\x1e.delta.connect.DescribeHistoryH\x00R\x0f\x64\x65scribeHistory\x12H\n\x0f\x64\x65scribe_detail\x18\x03 \x01(\x0b\x32\x1d.delta.connect.DescribeDetailH\x00R\x0e\x64\x65scribeDetail\x12I\n\x10\x63onvert_to_delta\x18\x04 \x01(\x0b\x32\x1d.delta.connect.ConvertToDeltaH\x00R\x0e\x63onvertToDelta\x12\x42\n\rrestore_table\x18\x05 \x01(\x0b\x32\x1b.delta.connect.RestoreTableH\x00R\x0crestoreTable\x12\x43\n\x0eis_delta_table\x18\x06 \x01(\x0b\x32\x1b.delta.connect.IsDeltaTableH\x00R\x0cisDeltaTable\x12L\n\x11\x64\x65lete_from_table\x18\x07 \x01(\x0b\x32\x1e.delta.connect.DeleteFromTableH\x00R\x0f\x64\x65leteFromTable\x12?\n\x0cupdate_table\x18\x08 \x01(\x0b\x32\x1a.delta.connect.UpdateTableH\x00R\x0bupdateTableB\x0f\n\rrelation_type"7\n\x04Scan\x12/\n\x05table\x18\x01 \x01(\x0b\x32\x19.delta.connect.DeltaTableR\x05table"B\n\x0f\x44\x65scribeHistory\x12/\n\x05table\x18\x01 \x01(\x0b\x32\x19.delta.connect.DeltaTableR\x05table"A\n\x0e\x44\x65scribeDetail\x12/\n\x05table\x18\x01 \x01(\x0b\x32\x19.delta.connect.DeltaTableR\x05table"\xd1\x01\n\x0e\x43onvertToDelta\x12\x1e\n\nidentifier\x18\x01 \x01(\tR\nidentifier\x12\x38\n\x17partition_schema_string\x18\x02 \x01(\tH\x00R\x15partitionSchemaString\x12Q\n\x17partition_schema_struct\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x15partitionSchemaStructB\x12\n\x10partition_schema"\x93\x01\n\x0cRestoreTable\x12/\n\x05table\x18\x01 \x01(\x0b\x32\x19.delta.connect.DeltaTableR\x05table\x12\x1a\n\x07version\x18\x02 \x01(\x03H\x00R\x07version\x12\x1e\n\ttimestamp\x18\x03 \x01(\tH\x00R\ttimestampB\x16\n\x14version_or_timestamp""\n\x0cIsDeltaTable\x12\x12\n\x04path\x18\x01 \x01(\tR\x04path"{\n\x0f\x44\x65leteFromTable\x12/\n\x06target\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x06target\x12\x37\n\tcondition\x18\x02 
\x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xb4\x01\n\x0bUpdateTable\x12/\n\x06target\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x06target\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition\x12;\n\x0b\x61ssignments\x18\x03 \x03(\x0b\x32\x19.delta.connect.AssignmentR\x0b\x61ssignments"n\n\nAssignment\x12/\n\x05\x66ield\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x66ield\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05valueB\x1a\n\x16io.delta.connect.protoP\x01\x62\x06proto3' + b'\n\x1d\x64\x65lta/connect/relations.proto\x12\rdelta.connect\x1a\x18\x64\x65lta/connect/base.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\x90\x05\n\rDeltaRelation\x12)\n\x04scan\x18\x01 \x01(\x0b\x32\x13.delta.connect.ScanH\x00R\x04scan\x12K\n\x10\x64\x65scribe_history\x18\x02 \x01(\x0b\x32\x1e.delta.connect.DescribeHistoryH\x00R\x0f\x64\x65scribeHistory\x12H\n\x0f\x64\x65scribe_detail\x18\x03 \x01(\x0b\x32\x1d.delta.connect.DescribeDetailH\x00R\x0e\x64\x65scribeDetail\x12I\n\x10\x63onvert_to_delta\x18\x04 \x01(\x0b\x32\x1d.delta.connect.ConvertToDeltaH\x00R\x0e\x63onvertToDelta\x12\x42\n\rrestore_table\x18\x05 \x01(\x0b\x32\x1b.delta.connect.RestoreTableH\x00R\x0crestoreTable\x12\x43\n\x0eis_delta_table\x18\x06 \x01(\x0b\x32\x1b.delta.connect.IsDeltaTableH\x00R\x0cisDeltaTable\x12L\n\x11\x64\x65lete_from_table\x18\x07 \x01(\x0b\x32\x1e.delta.connect.DeleteFromTableH\x00R\x0f\x64\x65leteFromTable\x12?\n\x0cupdate_table\x18\x08 \x01(\x0b\x32\x1a.delta.connect.UpdateTableH\x00R\x0bupdateTable\x12I\n\x10merge_into_table\x18\t \x01(\x0b\x32\x1d.delta.connect.MergeIntoTableH\x00R\x0emergeIntoTableB\x0f\n\rrelation_type"7\n\x04Scan\x12/\n\x05table\x18\x01 \x01(\x0b\x32\x19.delta.connect.DeltaTableR\x05table"B\n\x0f\x44\x65scribeHistory\x12/\n\x05table\x18\x01 \x01(\x0b\x32\x19.delta.connect.DeltaTableR\x05table"A\n\x0e\x44\x65scribeDetail\x12/\n\x05table\x18\x01 \x01(\x0b\x32\x19.delta.connect.DeltaTableR\x05table"\xd1\x01\n\x0e\x43onvertToDelta\x12\x1e\n\nidentifier\x18\x01 \x01(\tR\nidentifier\x12\x38\n\x17partition_schema_string\x18\x02 \x01(\tH\x00R\x15partitionSchemaString\x12Q\n\x17partition_schema_struct\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x15partitionSchemaStructB\x12\n\x10partition_schema"\x93\x01\n\x0cRestoreTable\x12/\n\x05table\x18\x01 \x01(\x0b\x32\x19.delta.connect.DeltaTableR\x05table\x12\x1a\n\x07version\x18\x02 \x01(\x03H\x00R\x07version\x12\x1e\n\ttimestamp\x18\x03 \x01(\tH\x00R\ttimestampB\x16\n\x14version_or_timestamp""\n\x0cIsDeltaTable\x12\x12\n\x04path\x18\x01 \x01(\tR\x04path"{\n\x0f\x44\x65leteFromTable\x12/\n\x06target\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x06target\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xb4\x01\n\x0bUpdateTable\x12/\n\x06target\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x06target\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition\x12;\n\x0b\x61ssignments\x18\x03 \x03(\x0b\x32\x19.delta.connect.AssignmentR\x0b\x61ssignments"\x8c\n\n\x0eMergeIntoTable\x12/\n\x06target\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x06target\x12/\n\x06source\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x06source\x12\x37\n\tcondition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition\x12M\n\x0fmatched_actions\x18\x04 
\x03(\x0b\x32$.delta.connect.MergeIntoTable.ActionR\x0ematchedActions\x12T\n\x13not_matched_actions\x18\x05 \x03(\x0b\x32$.delta.connect.MergeIntoTable.ActionR\x11notMatchedActions\x12\x66\n\x1dnot_matched_by_source_actions\x18\x06 \x03(\x0b\x32$.delta.connect.MergeIntoTable.ActionR\x19notMatchedBySourceActions\x12\x37\n\x15with_schema_evolution\x18\x07 \x01(\x08H\x00R\x13withSchemaEvolution\x88\x01\x01\x1a\xfe\x05\n\x06\x41\x63tion\x12\x37\n\tcondition\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition\x12X\n\rdelete_action\x18\x02 \x01(\x0b\x32\x31.delta.connect.MergeIntoTable.Action.DeleteActionH\x00R\x0c\x64\x65leteAction\x12X\n\rupdate_action\x18\x03 \x01(\x0b\x32\x31.delta.connect.MergeIntoTable.Action.UpdateActionH\x00R\x0cupdateAction\x12\x65\n\x12update_star_action\x18\x04 \x01(\x0b\x32\x35.delta.connect.MergeIntoTable.Action.UpdateStarActionH\x00R\x10updateStarAction\x12X\n\rinsert_action\x18\x05 \x01(\x0b\x32\x31.delta.connect.MergeIntoTable.Action.InsertActionH\x00R\x0cinsertAction\x12\x65\n\x12insert_star_action\x18\x06 \x01(\x0b\x32\x35.delta.connect.MergeIntoTable.Action.InsertStarActionH\x00R\x10insertStarAction\x1a\x0e\n\x0c\x44\x65leteAction\x1aK\n\x0cUpdateAction\x12;\n\x0b\x61ssignments\x18\x01 \x03(\x0b\x32\x19.delta.connect.AssignmentR\x0b\x61ssignments\x1a\x12\n\x10UpdateStarAction\x1aK\n\x0cInsertAction\x12;\n\x0b\x61ssignments\x18\x01 \x03(\x0b\x32\x19.delta.connect.AssignmentR\x0b\x61ssignments\x1a\x12\n\x10InsertStarActionB\r\n\x0b\x61\x63tion_typeB\x18\n\x16_with_schema_evolution"n\n\nAssignment\x12/\n\x05\x66ield\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x66ield\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05valueB\x1a\n\x16io.delta.connect.protoP\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -43,23 +43,37 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\026io.delta.connect.protoP\001" _DELTARELATION._serialized_start = 166 - _DELTARELATION._serialized_end = 747 - _SCAN._serialized_start = 749 - _SCAN._serialized_end = 804 - _DESCRIBEHISTORY._serialized_start = 806 - _DESCRIBEHISTORY._serialized_end = 872 - _DESCRIBEDETAIL._serialized_start = 874 - _DESCRIBEDETAIL._serialized_end = 939 - _CONVERTTODELTA._serialized_start = 942 - _CONVERTTODELTA._serialized_end = 1151 - _RESTORETABLE._serialized_start = 1154 - _RESTORETABLE._serialized_end = 1301 - _ISDELTATABLE._serialized_start = 1303 - _ISDELTATABLE._serialized_end = 1337 - _DELETEFROMTABLE._serialized_start = 1339 - _DELETEFROMTABLE._serialized_end = 1462 - _UPDATETABLE._serialized_start = 1465 - _UPDATETABLE._serialized_end = 1645 - _ASSIGNMENT._serialized_start = 1647 - _ASSIGNMENT._serialized_end = 1757 + _DELTARELATION._serialized_end = 822 + _SCAN._serialized_start = 824 + _SCAN._serialized_end = 879 + _DESCRIBEHISTORY._serialized_start = 881 + _DESCRIBEHISTORY._serialized_end = 947 + _DESCRIBEDETAIL._serialized_start = 949 + _DESCRIBEDETAIL._serialized_end = 1014 + _CONVERTTODELTA._serialized_start = 1017 + _CONVERTTODELTA._serialized_end = 1226 + _RESTORETABLE._serialized_start = 1229 + _RESTORETABLE._serialized_end = 1376 + _ISDELTATABLE._serialized_start = 1378 + _ISDELTATABLE._serialized_end = 1412 + _DELETEFROMTABLE._serialized_start = 1414 + _DELETEFROMTABLE._serialized_end = 1537 + _UPDATETABLE._serialized_start = 1540 + _UPDATETABLE._serialized_end = 1720 + _MERGEINTOTABLE._serialized_start = 1723 + _MERGEINTOTABLE._serialized_end = 3015 + _MERGEINTOTABLE_ACTION._serialized_start = 2223 
+ _MERGEINTOTABLE_ACTION._serialized_end = 2989 + _MERGEINTOTABLE_ACTION_DELETEACTION._serialized_start = 2766 + _MERGEINTOTABLE_ACTION_DELETEACTION._serialized_end = 2780 + _MERGEINTOTABLE_ACTION_UPDATEACTION._serialized_start = 2782 + _MERGEINTOTABLE_ACTION_UPDATEACTION._serialized_end = 2857 + _MERGEINTOTABLE_ACTION_UPDATESTARACTION._serialized_start = 2859 + _MERGEINTOTABLE_ACTION_UPDATESTARACTION._serialized_end = 2877 + _MERGEINTOTABLE_ACTION_INSERTACTION._serialized_start = 2879 + _MERGEINTOTABLE_ACTION_INSERTACTION._serialized_end = 2954 + _MERGEINTOTABLE_ACTION_INSERTSTARACTION._serialized_start = 2956 + _MERGEINTOTABLE_ACTION_INSERTSTARACTION._serialized_end = 2974 + _ASSIGNMENT._serialized_start = 3017 + _ASSIGNMENT._serialized_end = 3127 # @@protoc_insertion_point(module_scope) diff --git a/python/delta/connect/proto/relations_pb2.pyi b/python/delta/connect/proto/relations_pb2.pyi index f8f05f52bbb..3ebda3f402c 100644 --- a/python/delta/connect/proto/relations_pb2.pyi +++ b/python/delta/connect/proto/relations_pb2.pyi @@ -62,6 +62,7 @@ class DeltaRelation(google.protobuf.message.Message): IS_DELTA_TABLE_FIELD_NUMBER: builtins.int DELETE_FROM_TABLE_FIELD_NUMBER: builtins.int UPDATE_TABLE_FIELD_NUMBER: builtins.int + MERGE_INTO_TABLE_FIELD_NUMBER: builtins.int @property def scan(self) -> global___Scan: ... @property @@ -78,6 +79,8 @@ class DeltaRelation(google.protobuf.message.Message): def delete_from_table(self) -> global___DeleteFromTable: ... @property def update_table(self) -> global___UpdateTable: ... + @property + def merge_into_table(self) -> global___MergeIntoTable: ... def __init__( self, *, @@ -89,6 +92,7 @@ class DeltaRelation(google.protobuf.message.Message): is_delta_table: global___IsDeltaTable | None = ..., delete_from_table: global___DeleteFromTable | None = ..., update_table: global___UpdateTable | None = ..., + merge_into_table: global___MergeIntoTable | None = ..., ) -> None: ... def HasField( self, @@ -103,6 +107,8 @@ class DeltaRelation(google.protobuf.message.Message): b"describe_history", "is_delta_table", b"is_delta_table", + "merge_into_table", + b"merge_into_table", "relation_type", b"relation_type", "restore_table", @@ -126,6 +132,8 @@ class DeltaRelation(google.protobuf.message.Message): b"describe_history", "is_delta_table", b"is_delta_table", + "merge_into_table", + b"merge_into_table", "relation_type", b"relation_type", "restore_table", @@ -148,6 +156,7 @@ class DeltaRelation(google.protobuf.message.Message): "is_delta_table", "delete_from_table", "update_table", + "merge_into_table", ] | None ): ... @@ -433,6 +442,264 @@ class UpdateTable(google.protobuf.message.Message): global___UpdateTable = UpdateTable +class MergeIntoTable(google.protobuf.message.Message): + """Command that merges a source query/table into a Delta table. + + Needs to be a Relation, as it returns a row containing the execution metrics. + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class Action(google.protobuf.message.Message): + """Rule that specifies how the target table should be modified.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class DeleteAction(google.protobuf.message.Message): + """Action that deletes the target row.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... 
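# Illustrative sketch only (not generated output): how the MergeIntoTable message above
# could be assembled from these stubs. It assumes the generated `delta.connect.proto`
# and `pyspark.sql.connect.proto` packages are importable, and it mirrors a merge such
# as target.merge(source, "t.key = s.key").whenMatched().updateAll()
#        .whenNotMatched().insertAll() on the Scala client.
from delta.connect.proto import relations_pb2 as delta_relations
from pyspark.sql.connect.proto import expressions_pb2

merge = delta_relations.MergeIntoTable(
    # Join condition between target ("t") and source ("s"), sent as an expression string.
    condition=expressions_pb2.Expression(
        expression_string=expressions_pb2.Expression.ExpressionString(
            expression="t.key = s.key")),
    # WHEN MATCHED THEN UPDATE SET * (overwrite all target columns from the source).
    matched_actions=[delta_relations.MergeIntoTable.Action(
        update_star_action=delta_relations.MergeIntoTable.Action.UpdateStarAction())],
    # WHEN NOT MATCHED THEN INSERT * (insert the source row with all its columns).
    not_matched_actions=[delta_relations.MergeIntoTable.Action(
        insert_star_action=delta_relations.MergeIntoTable.Action.InsertStarAction())],
)
# The `target` and `source` fields (spark.connect.Relation messages for the two inputs)
# still need to be set before the plan is sent; executing the relation returns one row
# of merge metrics: num_affected_rows, num_updated_rows, num_deleted_rows,
# num_inserted_rows.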
+ + class UpdateAction(google.protobuf.message.Message): + """Action that updates the target row using a set of assignments.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ASSIGNMENTS_FIELD_NUMBER: builtins.int + @property + def assignments( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + global___Assignment + ]: + """(Optional) Set of assignments to apply.""" + def __init__( + self, + *, + assignments: collections.abc.Iterable[global___Assignment] | None = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["assignments", b"assignments"] + ) -> None: ... + + class UpdateStarAction(google.protobuf.message.Message): + """Action that updates the target row by overwriting all columns.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + + class InsertAction(google.protobuf.message.Message): + """Action that inserts the source row into the target using a set of assignments.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ASSIGNMENTS_FIELD_NUMBER: builtins.int + @property + def assignments( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + global___Assignment + ]: + """(Optional) Set of assignments to apply.""" + def __init__( + self, + *, + assignments: collections.abc.Iterable[global___Assignment] | None = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["assignments", b"assignments"] + ) -> None: ... + + class InsertStarAction(google.protobuf.message.Message): + """Action that inserts the source row into the target by setting all columns.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + + CONDITION_FIELD_NUMBER: builtins.int + DELETE_ACTION_FIELD_NUMBER: builtins.int + UPDATE_ACTION_FIELD_NUMBER: builtins.int + UPDATE_STAR_ACTION_FIELD_NUMBER: builtins.int + INSERT_ACTION_FIELD_NUMBER: builtins.int + INSERT_STAR_ACTION_FIELD_NUMBER: builtins.int + @property + def condition(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression: + """(Optional) Condition for the action to be applied.""" + @property + def delete_action(self) -> global___MergeIntoTable.Action.DeleteAction: ... + @property + def update_action(self) -> global___MergeIntoTable.Action.UpdateAction: ... + @property + def update_star_action(self) -> global___MergeIntoTable.Action.UpdateStarAction: ... + @property + def insert_action(self) -> global___MergeIntoTable.Action.InsertAction: ... + @property + def insert_star_action(self) -> global___MergeIntoTable.Action.InsertStarAction: ... + def __init__( + self, + *, + condition: pyspark.sql.connect.proto.expressions_pb2.Expression | None = ..., + delete_action: global___MergeIntoTable.Action.DeleteAction | None = ..., + update_action: global___MergeIntoTable.Action.UpdateAction | None = ..., + update_star_action: global___MergeIntoTable.Action.UpdateStarAction | None = ..., + insert_action: global___MergeIntoTable.Action.InsertAction | None = ..., + insert_star_action: global___MergeIntoTable.Action.InsertStarAction | None = ..., + ) -> None: ... 
+ def HasField( + self, + field_name: typing_extensions.Literal[ + "action_type", + b"action_type", + "condition", + b"condition", + "delete_action", + b"delete_action", + "insert_action", + b"insert_action", + "insert_star_action", + b"insert_star_action", + "update_action", + b"update_action", + "update_star_action", + b"update_star_action", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "action_type", + b"action_type", + "condition", + b"condition", + "delete_action", + b"delete_action", + "insert_action", + b"insert_action", + "insert_star_action", + b"insert_star_action", + "update_action", + b"update_action", + "update_star_action", + b"update_star_action", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["action_type", b"action_type"] + ) -> ( + typing_extensions.Literal[ + "delete_action", + "update_action", + "update_star_action", + "insert_action", + "insert_star_action", + ] + | None + ): ... + + TARGET_FIELD_NUMBER: builtins.int + SOURCE_FIELD_NUMBER: builtins.int + CONDITION_FIELD_NUMBER: builtins.int + MATCHED_ACTIONS_FIELD_NUMBER: builtins.int + NOT_MATCHED_ACTIONS_FIELD_NUMBER: builtins.int + NOT_MATCHED_BY_SOURCE_ACTIONS_FIELD_NUMBER: builtins.int + WITH_SCHEMA_EVOLUTION_FIELD_NUMBER: builtins.int + @property + def target(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: + """(Required) Target table to merge into.""" + @property + def source(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: + """(Required) Source data to merge from.""" + @property + def condition(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression: + """(Required) Condition for a source row to match with a target row.""" + @property + def matched_actions( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + global___MergeIntoTable.Action + ]: + """(Optional) Actions to apply when a source row matches a target row.""" + @property + def not_matched_actions( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + global___MergeIntoTable.Action + ]: + """(Optional) Actions to apply when a source row does not match a target row.""" + @property + def not_matched_by_source_actions( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + global___MergeIntoTable.Action + ]: + """(Optional) Actions to apply when a target row does not match a source row.""" + with_schema_evolution: builtins.bool + """(Optional) Whether Schema Evolution is enabled for this command.""" + def __init__( + self, + *, + target: pyspark.sql.connect.proto.relations_pb2.Relation | None = ..., + source: pyspark.sql.connect.proto.relations_pb2.Relation | None = ..., + condition: pyspark.sql.connect.proto.expressions_pb2.Expression | None = ..., + matched_actions: collections.abc.Iterable[global___MergeIntoTable.Action] | None = ..., + not_matched_actions: collections.abc.Iterable[global___MergeIntoTable.Action] | None = ..., + not_matched_by_source_actions: collections.abc.Iterable[global___MergeIntoTable.Action] + | None = ..., + with_schema_evolution: builtins.bool | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_with_schema_evolution", + b"_with_schema_evolution", + "condition", + b"condition", + "source", + b"source", + "target", + b"target", + "with_schema_evolution", + b"with_schema_evolution", + ], + ) -> builtins.bool: ... 
+ def ClearField( + self, + field_name: typing_extensions.Literal[ + "_with_schema_evolution", + b"_with_schema_evolution", + "condition", + b"condition", + "matched_actions", + b"matched_actions", + "not_matched_actions", + b"not_matched_actions", + "not_matched_by_source_actions", + b"not_matched_by_source_actions", + "source", + b"source", + "target", + b"target", + "with_schema_evolution", + b"with_schema_evolution", + ], + ) -> None: ... + def WhichOneof( + self, + oneof_group: typing_extensions.Literal["_with_schema_evolution", b"_with_schema_evolution"], + ) -> typing_extensions.Literal["with_schema_evolution"] | None: ... + +global___MergeIntoTable = MergeIntoTable + class Assignment(google.protobuf.message.Message): """Represents an assignment of a value to a field.""" diff --git a/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaTable.scala b/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaTable.scala index c54a8d6aa8a..a1718708a74 100644 --- a/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaTable.scala +++ b/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaTable.scala @@ -382,6 +382,116 @@ class DeltaTable private[tables]( executeUpdate(Some(functions.expr(condition)), toStrColumnMap(set.asScala.toMap)) } + /** + * Merge data from the `source` DataFrame based on the given merge `condition`. This returns + * a [[DeltaMergeBuilder]] object that can be used to specify the update, delete, or insert + * actions to be performed on rows based on whether the rows matched the condition or not. + * + * See the [[DeltaMergeBuilder]] for a full description of this operation and what combinations of + * update, delete and insert operations are allowed. + * + * Scala example to update a key-value Delta table with new key-values from a source DataFrame: + * {{{ + * deltaTable + * .as("target") + * .merge( + * source.as("source"), + * "target.key = source.key") + * .whenMatched + * .updateExpr(Map( + * "value" -> "source.value")) + * .whenNotMatched + * .insertExpr(Map( + * "key" -> "source.key", + * "value" -> "source.value")) + * .execute() + * }}} + * + * Java example to update a key-value Delta table with new key-values from a source DataFrame: + * {{{ + * deltaTable + * .as("target") + * .merge( + * source.as("source"), + * "target.key = source.key") + * .whenMatched + * .updateExpr( + * new HashMap() {{ + * put("value", "source.value"); + * }}) + * .whenNotMatched + * .insertExpr( + * new HashMap() {{ + * put("key", "source.key"); + * put("value", "source.value"); + * }}) + * .execute(); + * }}} + * + * @param source source DataFrame to be merged. + * @param condition boolean expression as SQL formatted string. + * + * @since 4.0.0 + */ + def merge(source: DataFrame, condition: String): DeltaMergeBuilder = { + merge(source, functions.expr(condition)) + } + + /** + * Merge data from the `source` DataFrame based on the given merge `condition`. This returns + * a [[DeltaMergeBuilder]] object that can be used to specify the update, delete, or insert + * actions to be performed on rows based on whether the rows matched the condition or not. + * + * See the [[DeltaMergeBuilder]] for a full description of this operation and what combinations of + * update, delete and insert operations are allowed.
+ * + * Scala example to update a key-value Delta table with new key-values from a source DataFrame: + * {{{ + * deltaTable + * .as("target") + * .merge( + * source.as("source"), + * "target.key = source.key") + * .whenMatched + * .updateExpr(Map( + * "value" -> "source.value")) + * .whenNotMatched + * .insertExpr(Map( + * "key" -> "source.key", + * "value" -> "source.value")) + * .execute() + * }}} + * + * Java example to update a key-value Delta table with new key-values from a source DataFrame: + * {{{ + * deltaTable + * .as("target") + * .merge( + * source.as("source"), + * "target.key = source.key") + * .whenMatched + * .updateExpr( + * new HashMap() {{ + * put("value", "source.value"); + * }}) + * .whenNotMatched + * .insertExpr( + * new HashMap() {{ + * put("key", "source.key"); + * put("value", "source.value"); + * }}) + * .execute(); + * }}} + * + * @param source source DataFrame to be merged. + * @param condition boolean expression as a Column object. + * + * @since 4.0.0 + */ + def merge(source: DataFrame, condition: Column): DeltaMergeBuilder = { + DeltaMergeBuilder(this, source, condition) + } + /** * Helper method for the restoreToVersion and restoreToTimestamp APIs. * diff --git a/spark-connect/client/src/test/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilderSuite.scala b/spark-connect/client/src/test/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilderSuite.scala new file mode 100644 index 00000000000..42175a50c7a --- /dev/null +++ b/spark-connect/client/src/test/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilderSuite.scala @@ -0,0 +1,411 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package io.delta.tables + +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.{col, expr} +import org.apache.spark.sql.test.{DeltaQueryTest, RemoteSparkSession} + +class DeltaMergeBuilderSuite extends DeltaQueryTest with RemoteSparkSession { + private def writeTargetTable(path: String): Unit = { + val session = spark + import session.implicits._ + Seq(("a", 1), ("b", 2), ("c", 3), ("d", 4)).toDF("key", "value") + .write.mode("overwrite").format("delta").save(path) + } + + private def testSource = { + val session = spark + import session.implicits._ + Seq(("a", -1), ("b", 0), ("e", -5), ("f", -6)).toDF("k", "v") + } + + test("string expressions in merge conditions and assignments") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable + .merge(testSource, "key = k") + .whenMatched().updateExpr(Map("value" -> "value + v")) + .whenNotMatched().insertExpr(Map("key" -> "k", "value" -> "v")) + .whenNotMatchedBySource().updateExpr(Map("value" -> "value - 1")) + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("a", 0), Row("b", 2), Row("c", 2), Row("d", 3), Row("e", -5), Row("f", -6))) + } + } + + test("column expressions in merge conditions and assignments") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable + .merge(testSource, col("key") === col("k")) + .whenMatched().update(Map("value" -> (col("value") + col("v")))) + .whenNotMatched().insert(Map("key" -> col("k"), "value" -> col("v"))) + .whenNotMatchedBySource().update(Map("value" -> (col("value") - 1))) + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("a", 0), Row("b", 2), Row("c", 2), Row("d", 3), Row("e", -5), Row("f", -6))) + } + } + + test("multiple when matched then update clauses") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable + .merge(testSource, expr("key = k")) + .whenMatched("key = 'a'").updateExpr(Map("value" -> "5")) + .whenMatched().updateExpr(Map("value" -> "0")) + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("a", 5), Row("b", 0), Row("c", 3), Row("d", 4))) + } + } + + test("multiple when matched then delete clauses") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable + .merge(testSource, "key = k") + .whenMatched("key = 'a'").delete() + .whenMatched().delete() + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("c", 3), Row("d", 4))) + } + } + + test("redundant when matched then update and delete clauses") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable + .merge(testSource, col("key") === col("k")) + .whenMatched("key = 'a'").updateExpr(Map("value" -> "5")) + .whenMatched("key = 'a'").updateExpr(Map("value" -> "0")) + .whenMatched("key = 'b'").updateExpr(Map("value" -> "6")) + .whenMatched("key = 'b'").delete() + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("a", 5), Row("b", 6), Row("c", 3), Row("d", 4))) + } + } + + test("interleaved when matched then update and delete clauses") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable.as("t") + .merge(testSource, col("t.key") === col("k"))
+ .whenMatched("t.key = 'a'").delete() + .whenMatched("t.key = 'a'").updateExpr(Map("value" -> "5")) + .whenMatched("t.key = 'b'").delete() + .whenMatched().updateExpr(Map("value" -> "6")) + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("c", 3), Row("d", 4))) + } + } + + test("multiple when not matched then insert clauses") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable.as("t") + .merge(testSource.toDF("key", "value").as("s"), col("t.key") === col("s.key")) + .whenNotMatched("s.key = 'e'").insertExpr(Map("t.key" -> "s.key", "t.value" -> "5")) + .whenNotMatched().insertAll() + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("a", 1), Row("b", 2), Row("c", 3), Row("d", 4), Row("e", 5), Row("f", -6))) + } + } + + test("redundant when not matched then insert clauses") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable + .merge(testSource, expr("key = k")) + .whenNotMatched("k = 'e'").insertExpr(Map("key" -> "k", "value" -> "5")) + .whenNotMatched("k = 'e'").insertExpr(Map("key" -> "k", "value" -> "6")) + .whenNotMatched("k = 'f'").insertExpr(Map("key" -> "k", "value" -> "7")) + .whenNotMatched("k = 'f'").insertExpr(Map("key" -> "k", "value" -> "8")) + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("a", 1), Row("b", 2), Row("c", 3), Row("d", 4), Row("e", 5), Row("f", 7))) + } + } + + test("multiple when not matched by source then update clauses") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable.merge(testSource, expr("key = k")) + .whenNotMatchedBySource("key = 'c'").updateExpr(Map("value" -> "5")) + .whenNotMatchedBySource().updateExpr(Map("value" -> "0")) + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("a", 1), Row("b", 2), Row("c", 5), Row("d", 0))) + } + } + + test("multiple when not matched by source then delete clauses") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable.merge(testSource, expr("key = k")) + .whenNotMatchedBySource("key = 'c'").delete() + .whenNotMatchedBySource().delete() + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("a", 1), Row("b", 2))) + } + } + + test("redundant when not matched by source then update and delete clauses") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable.merge(testSource, expr("key = k")) + .whenNotMatchedBySource("key = 'c'").updateExpr(Map("value" -> "5")) + .whenNotMatchedBySource("key = 'c'").updateExpr(Map("value" -> "0")) + .whenNotMatchedBySource("key = 'd'").updateExpr(Map("value" -> "6")) + .whenNotMatchedBySource("key = 'd'").delete() + .whenNotMatchedBySource().delete() + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("a", 1), Row("b", 2), Row("c", 5), Row("d", 6))) + } + } + + + test("interleaved when not matched by source then update and delete clauses") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable.merge(testSource, expr("key = k")) + .whenNotMatchedBySource("key = 'c'").delete() + .whenNotMatchedBySource("key = 'c'").updateExpr(Map("value" -> "5")) + .whenNotMatchedBySource("key = 
'd'").delete() + .whenNotMatchedBySource().updateExpr(Map("value" -> "6")) + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("a", 1), Row("b", 2))) + } + } + + test("string expressions in all conditions and assignments") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable + .merge(testSource, "key = k") + .whenMatched("k = 'a'").updateExpr(Map("value" -> "v + 0")) + .whenMatched("k = 'b'").delete() + .whenNotMatched("k = 'e'").insertExpr(Map("key" -> "k", "value" -> "v + 0")) + .whenNotMatchedBySource("key = 'c'").updateExpr(Map("value" -> "value + 0")) + .whenNotMatchedBySource("key = 'd'").delete() + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("a", -1), Row("c", 3), Row("e", -5))) + } + } + + test("column expressions in all conditions and assignments") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable + .merge(testSource, expr("key = k")) + .whenMatched(expr("k = 'a'")).update(Map("value" -> (col("v") + 0))) + .whenMatched(expr("k = 'b'")).delete() + .whenNotMatched(expr("k = 'e'")).insert(Map("key" -> col("k"), "value" -> (col("v") + 0))) + .whenNotMatchedBySource(expr("key = 'c'")).update(Map("value" -> (col("value") + 0))) + .whenNotMatchedBySource(expr("key = 'd'")).delete() + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("a", -1), Row("c", 3), Row("e", -5))) + } + } + + test("no clause conditions and insertAll/updateAll + aliases") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable.as("t") + .merge(testSource.toDF("key", "value").as("s"), expr("t.key = s.key")) + .whenMatched().updateAll() + .whenNotMatched().insertAll() + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("a", -1), Row("b", 0), Row("c", 3), Row("d", 4), Row("e", -5), Row("f", -6))) + } + } + + test("string expressions in all clause conditions and insertAll/updateAll + aliases") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable.as("t") + .merge(testSource.toDF("key", "value").as("s"), "t.key = s.key") + .whenMatched("s.key = 'a'").updateAll() + .whenNotMatched("s.key = 'e'").insertAll() + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("a", -1), Row("b", 2), Row("c", 3), Row("d", 4), Row("e", -5))) + } + } + + test("column expressions in all clause conditions and insertAll/updateAll + aliases") { + withTempPath { dir => + val path = dir.getAbsolutePath + writeTargetTable(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable.as("t") + .merge(testSource.toDF("key", "value").as("s"), expr("t.key = s.key")) + .whenMatched(expr("s.key = 'a'")).updateAll() + .whenNotMatched(expr("s.key = 'e'")).insertAll() + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(Row("a", -1), Row("b", 2), Row("c", 3), Row("d", 4), Row("e", -5))) + } + } + + test("automatic schema evolution") { + val session = spark + import session.implicits._ + + withTempPath { dir => + val path = dir.getAbsolutePath + Seq("a", "b", "c", "d").toDF("key") + .write.mode("overwrite").format("delta").save(path) + val deltaTable = DeltaTable.forPath(spark, path) + + withSQLConf("spark.databricks.delta.schema.autoMerge.enabled" -> "true") { + deltaTable.as("t") + .merge(testSource.toDF("key", 
"value").as("s"), expr("t.key = s.key")) + .whenMatched().updateAll() + .whenNotMatched().insertAll() + .execute() + } + + checkAnswer( + deltaTable.toDF, + Seq( + Row("a", -1), Row("b", 0), Row("c", null), Row("d", null), Row("e", -5), Row("f", -6))) + } + } + + test("merge dataframe with many columns") { + withTempPath { dir => + val path = dir.getAbsolutePath + var df1 = spark.range(1).toDF + val numColumns = 100 + for (i <- 0 until numColumns) { + df1 = df1.withColumn(s"col$i", col("id")) + } + df1.write.mode("overwrite").format("delta").save(path) + val deltaTable = io.delta.tables.DeltaTable.forPath(spark, path) + + var df2 = spark.range(1).toDF + for (i <- 0 until numColumns) { + df2 = df2.withColumn(s"col$i", col("id") + 1) + } + + deltaTable + .as("t") + .merge(df2.as("s"), "s.id = t.id") + .whenMatched().updateAll() + .execute() + + checkAnswer( + deltaTable.toDF, + Seq(df2.collectAsList().get(0))) + } + } +} diff --git a/spark-connect/common/src/main/protobuf/delta/connect/relations.proto b/spark-connect/common/src/main/protobuf/delta/connect/relations.proto index 35a9a40746f..fac426081ef 100644 --- a/spark-connect/common/src/main/protobuf/delta/connect/relations.proto +++ b/spark-connect/common/src/main/protobuf/delta/connect/relations.proto @@ -37,6 +37,7 @@ message DeltaRelation { IsDeltaTable is_delta_table = 6; DeleteFromTable delete_from_table = 7; UpdateTable update_table = 8; + MergeIntoTable merge_into_table = 9; } } @@ -130,6 +131,68 @@ message UpdateTable { repeated Assignment assignments = 3; } +// Command that merges a source query/table into a Delta table. +// +// Needs to be a Relation, as it returns a row containing the execution metrics. +message MergeIntoTable { + // (Required) Target table to merge into. + spark.connect.Relation target = 1; + + // (Required) Source data to merge from. + spark.connect.Relation source = 2; + + // (Required) Condition for a source row to match with a target row. + spark.connect.Expression condition = 3; + + // (Optional) Actions to apply when a source row matches a target row. + repeated Action matched_actions = 4; + + // (Optional) Actions to apply when a source row does not match a target row. + repeated Action not_matched_actions = 5; + + // (Optional) Actions to apply when a target row does not match a source row. + repeated Action not_matched_by_source_actions = 6; + + // (Optional) Whether Schema Evolution is enabled for this command. + optional bool with_schema_evolution = 7; + + // Rule that specifies how the target table should be modified. + message Action { + // (Optional) Condition for the action to be applied. + spark.connect.Expression condition = 1; + + // (Required) + oneof action_type { + DeleteAction delete_action = 2; + UpdateAction update_action = 3; + UpdateStarAction update_star_action = 4; + InsertAction insert_action = 5; + InsertStarAction insert_star_action = 6; + } + + // Action that deletes the target row. + message DeleteAction {} + + // Action that updates the target row using a set of assignments. + message UpdateAction { + // (Optional) Set of assignments to apply. + repeated Assignment assignments = 1; + } + + // Action that updates the target row by overwriting all columns. + message UpdateStarAction {} + + // Action that inserts the source row into the target using a set of assignments. + message InsertAction { + // (Optional) Set of assignments to apply. + repeated Assignment assignments = 1; + } + + // Action that inserts the source row into the target by setting all columns. 
+ message InsertStarAction {} + } +} + // Represents an assignment of a value to a field. message Assignment { // (Required) Expression identifying the (struct) field that is assigned a new value. diff --git a/spark-connect/server/src/main/scala-spark-master/io/delta/connect/DeltaRelationPlugin.scala b/spark-connect/server/src/main/scala-spark-master/io/delta/connect/DeltaRelationPlugin.scala index 32bf501e34d..e6b12b7100b 100644 --- a/spark-connect/server/src/main/scala-spark-master/io/delta/connect/DeltaRelationPlugin.scala +++ b/spark-connect/server/src/main/scala-spark-master/io/delta/connect/DeltaRelationPlugin.scala @@ -20,6 +20,7 @@ import java.util.Optional import scala.collection.JavaConverters._ +import org.apache.spark.sql.delta.commands.ConvertToDeltaCommand import com.google.protobuf import com.google.protobuf.{ByteString, InvalidProtocolBufferException} import io.delta.connect.proto @@ -27,7 +28,8 @@ import io.delta.tables.DeltaTable import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Encoders, SparkSession} -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.analysis.UnresolvedStar +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput} import org.apache.spark.sql.connect.config.Connect @@ -35,7 +37,6 @@ import org.apache.spark.sql.connect.delta.DeltaRelationPlugin.{parseAnyFrom, par import org.apache.spark.sql.connect.delta.ImplicitProtoConversions._ import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.connect.plugin.RelationPlugin -import org.apache.spark.sql.delta.commands.ConvertToDeltaCommand import org.apache.spark.sql.types.StructType /** @@ -77,6 +78,8 @@ class DeltaRelationPlugin extends RelationPlugin with DeltaPlannerBase { transformDeleteFromTable(planner, relation.getDeleteFromTable) case proto.DeltaRelation.RelationTypeCase.UPDATE_TABLE => transformUpdateTable(planner, relation.getUpdateTable) + case proto.DeltaRelation.RelationTypeCase.MERGE_INTO_TABLE => + transformMergeIntoTable(planner, relation.getMergeIntoTable) case _ => throw InvalidPlanInput(s"Unknown DeltaRelation ${relation.getRelationTypeCase}") } @@ -172,6 +175,96 @@ class DeltaRelationPlugin extends RelationPlugin with DeltaPlannerBase { .queryExecution.commandExecuted } + private def transformMergeIntoTable( + planner: SparkConnectPlanner, protoMerge: proto.MergeIntoTable): LogicalPlan = { + val target = planner.transformRelation(protoMerge.getTarget) + val source = planner.transformRelation(protoMerge.getSource) + val condition = planner.transformExpression(protoMerge.getCondition) + val matchedActions = protoMerge.getMatchedActionsList.asScala + .map(transformMergeWhenMatchedAction(planner, _)) + val notMatchedActions = protoMerge.getNotMatchedActionsList.asScala + .map(transformMergeWhenNotMatchedAction(planner, _)) + val notMatchedBySourceActions = protoMerge.getNotMatchedBySourceActionsList.asScala + .map(transformMergeWhenNotMatchedBySourceAction(planner, _)) + val withSchemaEvolution = protoMerge.getWithSchemaEvolution + + val merge = DeltaMergeInto( + target, + source, + condition, + matchedActions.toSeq ++ notMatchedActions.toSeq ++ notMatchedBySourceActions.toSeq, + withSchemaEvolution + ) + Dataset.ofRows(planner.session, merge).queryExecution.commandExecuted + } + + private def transformMergeActionCondition( + planner: 
SparkConnectPlanner, + protoAction: proto.MergeIntoTable.Action): Option[Expression] = { + if (protoAction.hasCondition) { + Some(planner.transformExpression(protoAction.getCondition)) + } else { + None + } + } + + private def transformMergeWhenMatchedAction( + planner: SparkConnectPlanner, + protoAction: proto.MergeIntoTable.Action): DeltaMergeIntoMatchedClause = { + val condition = transformMergeActionCondition(planner, protoAction) + + protoAction.getActionTypeCase match { + case proto.MergeIntoTable.Action.ActionTypeCase.DELETE_ACTION => + DeltaMergeIntoMatchedDeleteClause(condition) + case proto.MergeIntoTable.Action.ActionTypeCase.UPDATE_ACTION => + val actions = transformMergeAssignments( + planner, protoAction.getUpdateAction.getAssignmentsList.asScala.toSeq) + DeltaMergeIntoMatchedUpdateClause(condition, actions) + case proto.MergeIntoTable.Action.ActionTypeCase.UPDATE_STAR_ACTION => + DeltaMergeIntoMatchedUpdateClause(condition, Seq(UnresolvedStar(None))) + } + } + + private def transformMergeWhenNotMatchedAction( + planner: SparkConnectPlanner, + protoAction: proto.MergeIntoTable.Action): DeltaMergeIntoNotMatchedClause = { + val condition = transformMergeActionCondition(planner, protoAction) + + protoAction.getActionTypeCase match { + case proto.MergeIntoTable.Action.ActionTypeCase.INSERT_ACTION => + val actions = transformMergeAssignments( + planner, protoAction.getInsertAction.getAssignmentsList.asScala.toSeq) + DeltaMergeIntoNotMatchedInsertClause(condition, actions) + case proto.MergeIntoTable.Action.ActionTypeCase.INSERT_STAR_ACTION => + DeltaMergeIntoNotMatchedInsertClause(condition, Seq(UnresolvedStar(None))) + } + } + + private def transformMergeWhenNotMatchedBySourceAction( + planner: SparkConnectPlanner, + protoAction: proto.MergeIntoTable.Action): DeltaMergeIntoNotMatchedBySourceClause = { + val condition = transformMergeActionCondition(planner, protoAction) + + protoAction.getActionTypeCase match { + case proto.MergeIntoTable.Action.ActionTypeCase.DELETE_ACTION => + DeltaMergeIntoNotMatchedBySourceDeleteClause(condition) + case proto.MergeIntoTable.Action.ActionTypeCase.UPDATE_ACTION => + val actions = transformMergeAssignments( + planner, protoAction.getUpdateAction.getAssignmentsList.asScala.toSeq) + DeltaMergeIntoNotMatchedBySourceUpdateClause(condition, actions) + } + } + + private def transformMergeAssignments( + planner: SparkConnectPlanner, + protoAssignments: Seq[proto.Assignment]): Seq[Expression] = { + if (protoAssignments.isEmpty) { + Seq.empty + } else { + DeltaMergeIntoClause.toActions(protoAssignments.map(transformAssignment(planner, _))) + } + } + private def transformAssignment( planner: SparkConnectPlanner, assignment: proto.Assignment): Assignment = { Assignment( diff --git a/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala b/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala index c30afba8f0d..9e3d8815aa5 100644 --- a/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala +++ b/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala @@ -454,6 +454,148 @@ class DeltaConnectPlannerSuite } } + test("merge - insert only") { + val targetTableName = "target" + val sourceTableName = "source" + withTable(targetTableName, sourceTableName) { + spark.range(end = 100).select(col("id") as "key", col("id") as "value") + .write.format("delta").saveAsTable(targetTableName) + + 
spark.range(end = 100) + .select(col("id") + 50 as "id") + .select(col("id") as "key", col("id") + 1000 as "value") + .write.format("delta").saveAsTable(sourceTableName) + + val input = createSparkRelation( + proto.DeltaRelation.newBuilder() + .setMergeIntoTable( + proto.MergeIntoTable.newBuilder() + .setTarget(createSubqueryAlias(createScan(targetTableName), alias = "t")) + .setSource(createSubqueryAlias(createScan(sourceTableName), alias = "s")) + .setCondition(createExpression("t.key = s.key")) + .addNotMatchedActions( + proto.MergeIntoTable.Action.newBuilder() + .setInsertAction( + proto.MergeIntoTable.Action.InsertAction.newBuilder() + .addAssignments(createAssignment(field = "t.key", value = "s.key")) + .addAssignments(createAssignment(field = "t.value", value = "s.value")) + ) + ) + ) + ) + + val plan = transform(input) + val result = Dataset.ofRows(spark, plan).collect() + assert(result.length === 1) + assert(result.head.getLong(0) === 50) // num_affected_rows + assert(result.head.getLong(1) === 0) // num_updated_rows + assert(result.head.getLong(2) === 0) // num_deleted_rows + assert(result.head.getLong(3) === 50) // num_inserted_rows + + checkAnswer( + spark.read.table(targetTableName), + Seq.tabulate(100)(i => Row(i, i)) ++ Seq.tabulate(50)(i => Row(i + 100, i + 1100)) + ) + } + } + + test("merge - update only") { + val targetTableName = "target" + val sourceTableName = "source" + withTable(targetTableName, sourceTableName) { + spark.range(end = 100).select(col("id") as "key", col("id") as "value") + .write.format("delta").saveAsTable(targetTableName) + + spark.range(end = 100) + .select(col("id") + 50 as "id") + .select(col("id") as "key", col("id") + 1000 as "value") + .write.format("delta").saveAsTable(sourceTableName) + + val input = createSparkRelation( + proto.DeltaRelation.newBuilder() + .setMergeIntoTable( + proto.MergeIntoTable.newBuilder() + .setTarget(createSubqueryAlias(createScan(targetTableName), alias = "t")) + .setSource(createSubqueryAlias(createScan(sourceTableName), alias = "s")) + .setCondition(createExpression("t.key = s.key")) + .addMatchedActions( + proto.MergeIntoTable.Action.newBuilder() + .setUpdateAction( + proto.MergeIntoTable.Action.UpdateAction.newBuilder() + .addAssignments(createAssignment(field = "t.key", value = "s.key")) + .addAssignments(createAssignment(field = "t.value", value = "s.value")) + ) + ) + ) + ) + + val plan = transform(input) + val result = Dataset.ofRows(spark, plan).collect() + assert(result.length === 1) + assert(result.head.getLong(0) === 50) // num_affected_rows + assert(result.head.getLong(1) === 50) // num_updated_rows + assert(result.head.getLong(2) === 0) // num_deleted_rows + assert(result.head.getLong(3) === 0) // num_inserted_rows + + checkAnswer( + spark.read.table(targetTableName), + Seq.tabulate(50)(i => Row(i, i)) ++ Seq.tabulate(50)(i => Row(i + 50, i + 1050)) + ) + } + } + + test("merge - mixed") { + val targetTableName = "target" + val sourceTableName = "source" + withTable(targetTableName, sourceTableName) { + spark.range(end = 100).select(col("id") as "key", col("id") as "value") + .write.format("delta").saveAsTable(targetTableName) + + spark.range(end = 100) + .select(col("id") + 50 as "id") + .select(col("id") as "key", col("id") + 1000 as "value") + .write.format("delta").saveAsTable(sourceTableName) + + val input = createSparkRelation( + proto.DeltaRelation.newBuilder() + .setMergeIntoTable( + proto.MergeIntoTable.newBuilder() + .setTarget(createSubqueryAlias(createScan(targetTableName), alias = "t")) + 
.setSource(createSubqueryAlias(createScan(sourceTableName), alias = "s")) + .setCondition(createExpression("t.key = s.key")) + .addMatchedActions( + proto.MergeIntoTable.Action.newBuilder() + .setUpdateStarAction(proto.MergeIntoTable.Action.UpdateStarAction.newBuilder()) + ) + .addNotMatchedActions( + proto.MergeIntoTable.Action.newBuilder() + .setInsertStarAction(proto.MergeIntoTable.Action.InsertStarAction.newBuilder()) + ) + .addNotMatchedBySourceActions( + proto.MergeIntoTable.Action.newBuilder() + .setCondition(createExpression("t.value < 25")) + .setDeleteAction(proto.MergeIntoTable.Action.DeleteAction.newBuilder()) + ) + ) + ) + + val plan = transform(input) + val result = Dataset.ofRows(spark, plan).collect() + assert(result.length === 1) + assert(result.head.getLong(0) === 125) // num_affected_rows + assert(result.head.getLong(1) === 50) // num_updated_rows + assert(result.head.getLong(2) === 25) // num_deleted_rows + assert(result.head.getLong(3) === 50) // num_inserted_rows + + checkAnswer( + spark.read.table(targetTableName), + Seq.tabulate(25)(i => Row(25 + i, 25 + i)) ++ + Seq.tabulate(50)(i => Row(i + 50, i + 1050)) ++ + Seq.tabulate(50)(i => Row(i + 100, i + 1100)) + ) + } + } + private def createScan(tableName: String): spark_proto.Relation = { createSparkRelation( proto.DeltaRelation.newBuilder() @@ -464,6 +606,17 @@ class DeltaConnectPlannerSuite ) } + private def createSubqueryAlias( + input: spark_proto.Relation, alias: String): spark_proto.Relation = { + spark_proto.Relation.newBuilder() + .setSubqueryAlias( + spark_proto.SubqueryAlias.newBuilder() + .setAlias(alias) + .setInput(input) + ) + .build() + } + private def createExpression(expr: String): spark_proto.Expression = { spark_proto.Expression.newBuilder() .setExpressionString( From 5efd86dd1f84de30cbc0f06e32fe1ebbee1760c6 Mon Sep 17 00:00:00 2001 From: Thang Long VU Date: Wed, 11 Sep 2024 19:22:48 +0200 Subject: [PATCH 2/6] Fix description --- .github/workflows/spark_master_test.yaml | 4 +- .../connect/tables/DeltaMergeBuilder.scala | 690 ++++++++++++++++++ 2 files changed, 692 insertions(+), 2 deletions(-) create mode 100644 spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala diff --git a/.github/workflows/spark_master_test.yaml b/.github/workflows/spark_master_test.yaml index 3906c31f221..07f72e26ed2 100644 --- a/.github/workflows/spark_master_test.yaml +++ b/.github/workflows/spark_master_test.yaml @@ -51,7 +51,7 @@ jobs: - name: Run Spark Master tests # when changing TEST_PARALLELISM_COUNT make sure to also change it in spark_test.yaml run: | - TEST_PARALLELISM_COUNT=4 SHARD_ID=${{matrix.shard}} build/sbt -DsparkVersion=master "++ ${{ matrix.scala }}" clean spark/test - TEST_PARALLELISM_COUNT=4 build/sbt -DsparkVersion=master "++ ${{ matrix.scala }}" clean connectServer/test TEST_PARALLELISM_COUNT=4 build/sbt -DsparkVersion=master "++ ${{ matrix.scala }}" clean connectServer/assembly connectClient/test + TEST_PARALLELISM_COUNT=4 build/sbt -DsparkVersion=master "++ ${{ matrix.scala }}" clean connectServer/test + TEST_PARALLELISM_COUNT=4 SHARD_ID=${{matrix.shard}} build/sbt -DsparkVersion=master "++ ${{ matrix.scala }}" clean spark/test if: steps.git-diff.outputs.diff diff --git a/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala b/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala new file mode 100644 index 00000000000..31e1e0e2d3b --- /dev/null +++ 
b/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala @@ -0,0 +1,690 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.delta.tables + +import scala.collection.JavaConverters._ +import scala.collection.Map + +import io.delta.connect.proto +import io.delta.connect.spark.{proto => spark_proto} + +import org.apache.spark.annotation.Unstable +import org.apache.spark.sql.{functions, Column, DataFrame} +import org.apache.spark.sql.connect.delta.ImplicitProtoConversions._ +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.toExpr + +/** + * Builder to specify how to merge data from source DataFrame into the target Delta table. + * You can specify any number of `whenMatched` and `whenNotMatched` clauses. + * Here are the constraints on these clauses. + * + * - `whenMatched` clauses: + * + * - The condition in a `whenMatched` clause is optional. However, if there are multiple + * `whenMatched` clauses, then only the last one may omit the condition. + * + * - When there are more than one `whenMatched` clauses and there are conditions (or the lack + * of) such that a row satisfies multiple clauses, then the action for the first clause + * satisfied is executed. In other words, the order of the `whenMatched` clauses matters. + * + * - If none of the `whenMatched` clauses match a source-target row pair that satisfy + * the merge condition, then the target rows will not be updated or deleted. + * + * - If you want to update all the columns of the target Delta table with the + * corresponding column of the source DataFrame, then you can use the + * `whenMatched(...).updateAll()`. This is equivalent to + *
+ *         whenMatched(...).updateExpr(Map(
+ *           ("col1", "source.col1"),
+ *           ("col2", "source.col2"),
+ *           ...))
+ *       
+ *
+ * - `whenNotMatched` clauses:
+ *
+ *   - The condition in a `whenNotMatched` clause is optional. However, if there are
+ *     multiple `whenNotMatched` clauses, then only the last one may omit the condition.
+ *
+ *   - When there are more than one `whenNotMatched` clauses and there are conditions (or the
+ *     lack of) such that a row satisfies multiple clauses, then the action for the first clause
+ *     satisfied is executed. In other words, the order of the `whenNotMatched` clauses matters.
+ *
+ *   - If no `whenNotMatched` clause is present or if it is present but the non-matching source
+ *     row does not satisfy the condition, then the source row is not inserted.
+ *
+ *   - If you want to insert all the columns of the target Delta table with the
+ *     corresponding column of the source DataFrame, then you can use
+ *     `whenNotMatched(...).insertAll()`. This is equivalent to
+ *       <pre>
+ *         whenNotMatched(...).insertExpr(Map(
+ *           ("col1", "source.col1"),
+ *           ("col2", "source.col2"),
+ *           ...))
+ *       </pre>
+ * + * - `whenNotMatchedBySource` clauses: + * + * - The condition in a `whenNotMatchedBySource` clause is optional. However, if there are + * multiple `whenNotMatchedBySource` clauses, then only the last one may omit the condition. + * + * - When there are more than one `whenNotMatchedBySource` clauses and there are conditions (or + * the lack of) such that a row satisfies multiple clauses, then the action for the first + * clause satisfied is executed. In other words, the order of the `whenNotMatchedBySource` + * clauses matters. + * + * - If no `whenNotMatchedBySource` clause is present or if it is present but the + * non-matching target row does not satisfy any of the `whenNotMatchedBySource` clause + * condition, then the target row will not be updated or deleted. + * + * + * Scala example to update a key-value Delta table with new key-values from a source DataFrame: + * {{{ + * deltaTable + * .as("target") + * .merge( + * source.as("source"), + * "target.key = source.key") + * .whenMatched() + * .updateExpr(Map( + * "value" -> "source.value")) + * .whenNotMatched() + * .insertExpr(Map( + * "key" -> "source.key", + * "value" -> "source.value")) + * .whenNotMatchedBySource() + * .updateExpr(Map( + * "value" -> "target.value + 1")) + * .execute() + * }}} + * + * Java example to update a key-value Delta table with new key-values from a source DataFrame: + * {{{ + * deltaTable + * .as("target") + * .merge( + * source.as("source"), + * "target.key = source.key") + * .whenMatched() + * .updateExpr( + * new HashMap() {{ + * put("value", "source.value"); + * }}) + * .whenNotMatched() + * .insertExpr( + * new HashMap() {{ + * put("key", "source.key"); + * put("value", "source.value"); + * }}) + * .whenNotMatchedBySource() + * .updateExpr( + * new HashMap() {{ + * put("value", "target.value + 1"); + * }}) + * .execute(); + * }}} + * + * @since 4.0.0 + */ +class DeltaMergeBuilder private( + private val targetTable: DeltaTable, + private val source: DataFrame, + private val onCondition: Column, + private val whenMatchedClauses: Seq[proto.MergeIntoTable.Action], + private val whenNotMatchedClauses: Seq[proto.MergeIntoTable.Action], + private val whenNotMatchedBySourceClauses: Seq[proto.MergeIntoTable.Action]) { + + /** + * Build the actions to perform when the merge condition was matched. This returns + * [[DeltaMergeMatchedActionBuilder]] object which can be used to specify how + * to update or delete the matched target table row with the source row. + * + * @since 4.0.0 + */ + def whenMatched(): DeltaMergeMatchedActionBuilder = { + DeltaMergeMatchedActionBuilder(this, None) + } + + /** + * Build the actions to perform when the merge condition was matched and + * the given `condition` is true. This returns [[DeltaMergeMatchedActionBuilder]] object + * which can be used to specify how to update or delete the matched target table row with the + * source row. + * + * @param condition boolean expression as a SQL formatted string. + * + * @since 4.0.0 + */ + def whenMatched(condition: String): DeltaMergeMatchedActionBuilder = { + whenMatched(expr(condition)) + } + + /** + * Build the actions to perform when the merge condition was matched and + * the given `condition` is true. This returns a [[DeltaMergeMatchedActionBuilder]] object + * which can be used to specify how to update or delete the matched target table row with the + * source row. + * + * @param condition boolean expression as a Column object. 
+ * + * @since 4.0.0 + */ + def whenMatched(condition: Column): DeltaMergeMatchedActionBuilder = { + DeltaMergeMatchedActionBuilder(this, Some(condition)) + } + + /** + * Build the action to perform when the merge condition was not matched. This returns + * [[DeltaMergeNotMatchedActionBuilder]] object which can be used to specify how + * to insert the new sourced row into the target table. + * + * @since 4.0.0 + */ + def whenNotMatched(): DeltaMergeNotMatchedActionBuilder = { + DeltaMergeNotMatchedActionBuilder(this, None) + } + + /** + * Build the actions to perform when the merge condition was not matched and + * the given `condition` is true. This returns [[DeltaMergeMatchedActionBuilder]] object + * which can be used to specify how to insert the new sourced row into the target table. + * + * @param condition boolean expression as a SQL formatted string. + * + * @since 4.0.0 + */ + def whenNotMatched(condition: String): DeltaMergeNotMatchedActionBuilder = { + whenNotMatched(expr(condition)) + } + + /** + * Build the actions to perform when the merge condition was not matched and + * the given `condition` is true. This returns [[DeltaMergeMatchedActionBuilder]] object + * which can be used to specify how to insert the new sourced row into the target table. + * + * @param condition boolean expression as a Column object. + * + * @since 4.0.0 + */ + def whenNotMatched(condition: Column): DeltaMergeNotMatchedActionBuilder = { + DeltaMergeNotMatchedActionBuilder(this, Some(condition)) + } + + /** + * Build the actions to perform when the merge condition was not matched by the source. This + * returns [[DeltaMergeNotMatchedBySourceActionBuilder]] object which can be used to specify how + * to update or delete the target table row. + * + * @since 4.0.0 + */ + def whenNotMatchedBySource(): DeltaMergeNotMatchedBySourceActionBuilder = { + DeltaMergeNotMatchedBySourceActionBuilder(this, None) + } + + /** + * Build the actions to perform when the merge condition was not matched by the source and the + * given `condition` is true. This returns [[DeltaMergeNotMatchedBySourceActionBuilder]] object + * which can be used to specify how to update or delete the target table row. + * + * @param condition boolean expression as a SQL formatted string. + * + * @since 4.0.0 + */ + def whenNotMatchedBySource(condition: String): DeltaMergeNotMatchedBySourceActionBuilder = { + whenNotMatchedBySource(expr(condition)) + } + + /** + * Build the actions to perform when the merge condition was not matched by the source and the + * given `condition` is true. This returns [[DeltaMergeNotMatchedBySourceActionBuilder]] object + * which can be used to specify how to update or delete the target table row. + * + * @param condition boolean expression as a Column object. + * + * @since 4.0.0 + */ + def whenNotMatchedBySource(condition: Column): DeltaMergeNotMatchedBySourceActionBuilder = { + DeltaMergeNotMatchedBySourceActionBuilder(this, Some(condition)) + } + + /** + * Execute the merge operation based on the built matched and not matched actions. 
+ * + * @since 4.0.0 + */ + def execute(): Unit = { + val sparkSession = targetTable.toDF.sparkSession + val merge = proto.MergeIntoTable + .newBuilder() + .setTarget(targetTable.toDF.plan.getRoot) + .setSource(source.plan.getRoot) + .setCondition(toExpr(onCondition)) + .addAllMatchedActions(whenMatchedClauses.asJava) + .addAllNotMatchedActions(whenNotMatchedClauses.asJava) + .addAllNotMatchedBySourceActions(whenNotMatchedBySourceClauses.asJava) + val relation = proto.DeltaRelation.newBuilder().setMergeIntoTable(merge).build() + val extension = com.google.protobuf.Any.pack(relation) + val sparkRelation = spark_proto.Relation.newBuilder().setExtension(extension).build() + sparkSession.newDataFrame(_.mergeFrom(sparkRelation)).collect() + } + + /** + * :: Unstable :: + * + * Private method for internal usage only. Do not call this directly. + */ + @Unstable + private[delta] def withWhenMatchedClause( + clause: proto.MergeIntoTable.Action): DeltaMergeBuilder = { + new DeltaMergeBuilder( + this.targetTable, + this.source, + this.onCondition, + this.whenMatchedClauses :+ clause, + this.whenNotMatchedClauses, + this.whenNotMatchedBySourceClauses) + } + + /** + * :: Unstable :: + * + * Private method for internal usage only. Do not call this directly. + */ + @Unstable + private[delta] def withWhenNotMatchedClause( + clause: proto.MergeIntoTable.Action): DeltaMergeBuilder = { + new DeltaMergeBuilder( + this.targetTable, + this.source, + this.onCondition, + this.whenMatchedClauses, + this.whenNotMatchedClauses :+ clause, + this.whenNotMatchedBySourceClauses) + } + + /** + * :: Unstable :: + * + * Private method for internal usage only. Do not call this directly. + */ + @Unstable + private[delta] def withWhenNotMatchedBySourceClause( + clause: proto.MergeIntoTable.Action): DeltaMergeBuilder = { + new DeltaMergeBuilder( + this.targetTable, + this.source, + this.onCondition, + this.whenMatchedClauses, + this.whenNotMatchedClauses, + this.whenNotMatchedBySourceClauses :+ clause) + } +} + +object DeltaMergeBuilder { + /** + * :: Unstable :: + * + * Private method for internal usage only. Do not call this directly. + */ + @Unstable + private[delta] def apply( + targetTable: DeltaTable, source: DataFrame, onCondition: Column): DeltaMergeBuilder = { + new DeltaMergeBuilder(targetTable, source, onCondition, Nil, Nil, Nil) + } +} + +/** + * Builder class to specify the actions to perform when a target table row has matched a + * source row based on the given merge condition and optional match condition. + * + * See [[DeltaMergeBuilder]] for more information. + * + * @since 4.0.0 + */ +class DeltaMergeMatchedActionBuilder private( + private val mergeBuilder: DeltaMergeBuilder, + private val matchCondition: Option[Column]) { + + /** + * Update the matched table rows based on the rules defined by `set`. + * + * @param set rules to update a row as a Scala map between target column names and + * corresponding update expressions as Column objects. + * + * @since 4.0.0 + */ + def update(set: Map[String, Column]): DeltaMergeBuilder = { + addUpdateClause(set) + } + + /** + * Update the matched table rows based on the rules defined by `set`. + * + * @param set rules to update a row as a Scala map between target column names and + * corresponding update expressions as SQL formatted strings. + * + * @since 4.0.0 + */ + def updateExpr(set: Map[String, String]): DeltaMergeBuilder = { + addUpdateClause(toStrColumnMap(set)) + } + + /** + * Update a matched table row based on the rules defined by `set`. 
+ * + * @param set rules to update a row as a Java map between target column names and + * corresponding expressions as Column objects. + * + * @since 4.0.0 + */ + def update(set: java.util.Map[String, Column]): DeltaMergeBuilder = { + addUpdateClause(set.asScala.toMap) + } + + /** + * Update a matched table row based on the rules defined by `set`. + * + * @param set rules to update a row as a Java map between target column names and + * corresponding expressions as SQL formatted strings. + * + * @since 4.0.0 + */ + def updateExpr(set: java.util.Map[String, String]): DeltaMergeBuilder = { + addUpdateClause(toStrColumnMap(set.asScala.toMap)) + } + + /** + * Update all the columns of the matched table row with the values of the + * corresponding columns in the source row. + * + * @since 4.0.0 + */ + def updateAll(): DeltaMergeBuilder = { + val clause = proto.MergeIntoTable.Action + .newBuilder() + .setUpdateStarAction(proto.MergeIntoTable.Action.UpdateStarAction.newBuilder()) + matchCondition.foreach(c => clause.setCondition(toExpr(c))) + mergeBuilder.withWhenMatchedClause(clause.build()) + } + + /** + * Delete a matched row from the table. + * + * @since 4.0.0 + */ + def delete(): DeltaMergeBuilder = { + val clause = proto.MergeIntoTable.Action + .newBuilder() + .setDeleteAction(proto.MergeIntoTable.Action.DeleteAction.newBuilder()) + matchCondition.foreach(c => clause.setCondition(toExpr(c))) + mergeBuilder.withWhenMatchedClause(clause.build()) + } + + private def addUpdateClause(set: Map[String, Column]): DeltaMergeBuilder = { + if (set.isEmpty && matchCondition.isEmpty) { + // This is a catch all clause that doesn't update anything: we can ignore it. + mergeBuilder + } else { + val assignments = set.map { case (field, value) => + proto.Assignment.newBuilder().setField(toExpr(expr(field))).setValue(toExpr(value)).build() + } + val action = proto.MergeIntoTable.Action.UpdateAction + .newBuilder() + .addAllAssignments(assignments.asJava) + val clause = proto.MergeIntoTable.Action + .newBuilder() + .setUpdateAction(action) + matchCondition.foreach(c => clause.setCondition(toExpr(c))) + mergeBuilder.withWhenMatchedClause(clause.build()) + } + } + + private def toStrColumnMap(map: Map[String, String]): Map[String, Column] = + map.mapValues(functions.expr).toMap +} + +object DeltaMergeMatchedActionBuilder { + /** + * :: Unstable :: + * + * Private method for internal usage only. Do not call this directly. + */ + @Unstable + private[delta] def apply( + mergeBuilder: DeltaMergeBuilder, + matchCondition: Option[Column]): DeltaMergeMatchedActionBuilder = { + new DeltaMergeMatchedActionBuilder(mergeBuilder, matchCondition) + } +} + + +/** + * Builder class to specify the actions to perform when a source row has not matched any target + * Delta table row based on the merge condition, but has matched the additional condition + * if specified. + * + * See [[DeltaMergeBuilder]] for more information. + * + * @since 4.0.0 + */ +class DeltaMergeNotMatchedActionBuilder private( + private val mergeBuilder: DeltaMergeBuilder, + private val notMatchCondition: Option[Column]) { + + /** + * Insert a new row to the target table based on the rules defined by `values`. + * + * @param values rules to insert a row as a Scala map between target column names and + * corresponding expressions as Column objects. + * + * @since 4.0.0 + */ + def insert(values: Map[String, Column]): DeltaMergeBuilder = { + addInsertClause(values) + } + + /** + * Insert a new row to the target table based on the rules defined by `values`. 
+ * + * @param values rules to insert a row as a Scala map between target column names and + * corresponding expressions as SQL formatted strings. + * + * @since 4.0.0 + */ + def insertExpr(values: Map[String, String]): DeltaMergeBuilder = { + addInsertClause(toStrColumnMap(values)) + } + + /** + * Insert a new row to the target table based on the rules defined by `values`. + * + * @param values rules to insert a row as a Java map between target column names and + * corresponding expressions as Column objects. + * + * @since 4.0.0 + */ + def insert(values: java.util.Map[String, Column]): DeltaMergeBuilder = { + addInsertClause(values.asScala) + } + + /** + * Insert a new row to the target table based on the rules defined by `values`. + * + * @param values rules to insert a row as a Java map between target column names and + * corresponding expressions as SQL formatted strings. + * + * @since 4.0.0 + */ + def insertExpr(values: java.util.Map[String, String]): DeltaMergeBuilder = { + addInsertClause(toStrColumnMap(values.asScala)) + } + + /** + * Insert a new target Delta table row by assigning the target columns to the values of the + * corresponding columns in the source row. + * + * @since 4.0.0 + */ + def insertAll(): DeltaMergeBuilder = { + val clause = proto.MergeIntoTable.Action + .newBuilder() + .setInsertStarAction(proto.MergeIntoTable.Action.InsertStarAction.newBuilder()) + notMatchCondition.foreach(c => clause.setCondition(toExpr(c))) + mergeBuilder.withWhenNotMatchedClause(clause.build()) + } + + private def addInsertClause(setValues: Map[String, Column]): DeltaMergeBuilder = { + val assignments = setValues.map { case (field, value) => + proto.Assignment.newBuilder().setField(toExpr(expr(field))).setValue(toExpr(value)).build() + } + val action = proto.MergeIntoTable.Action.InsertAction + .newBuilder() + .addAllAssignments(assignments.asJava) + val clause = proto.MergeIntoTable.Action + .newBuilder() + .setInsertAction(action) + notMatchCondition.foreach(c => clause.setCondition(toExpr(c))) + mergeBuilder.withWhenNotMatchedClause(clause.build()) + } + + private def toStrColumnMap(map: Map[String, String]): Map[String, Column] = + map.mapValues(functions.expr).toMap +} + +object DeltaMergeNotMatchedActionBuilder { + /** + * :: Unstable :: + * + * Private method for internal usage only. Do not call this directly. + */ + @Unstable + private[delta] def apply( + mergeBuilder: DeltaMergeBuilder, + notMatchCondition: Option[Column]): DeltaMergeNotMatchedActionBuilder = { + new DeltaMergeNotMatchedActionBuilder(mergeBuilder, notMatchCondition) + } +} + +/** + * Builder class to specify the actions to perform when a target table row has no match in the + * source table based on the given merge condition and optional match condition. + * + * See [[DeltaMergeBuilder]] for more information. + * + * @since 4.0.0 + */ +class DeltaMergeNotMatchedBySourceActionBuilder private( + private val mergeBuilder: DeltaMergeBuilder, + private val notMatchBySourceCondition: Option[Column]) { + + /** + * Update an unmatched target table row based on the rules defined by `set`. + * + * @param set rules to update a row as a Scala map between target column names and + * corresponding update expressions as Column objects. + * + * @since 4.0.0 + */ + def update(set: Map[String, Column]): DeltaMergeBuilder = { + addUpdateClause(set) + } + + /** + * Update an unmatched target table row based on the rules defined by `set`. 
+ * + * @param set rules to update a row as a Scala map between target column names and + * corresponding update expressions as SQL formatted strings. + * + * @since 4.0.0 + */ + def updateExpr(set: Map[String, String]): DeltaMergeBuilder = { + addUpdateClause(toStrColumnMap(set)) + } + + /** + * Update an unmatched target table row based on the rules defined by `set`. + * + * @param set rules to update a row as a Java map between target column names and + * corresponding expressions as Column objects. + * + * @since 4.0.0 + */ + def update(set: java.util.Map[String, Column]): DeltaMergeBuilder = { + addUpdateClause(set.asScala) + } + + /** + * Update an unmatched target table row based on the rules defined by `set`. + * + * @param set rules to update a row as a Java map between target column names and + * corresponding expressions as SQL formatted strings. + * + * @since 4.0.0 + */ + def updateExpr(set: java.util.Map[String, String]): DeltaMergeBuilder = { + addUpdateClause(toStrColumnMap(set.asScala)) + } + + /** + * Delete an unmatched row from the target table. + * + * @since 4.0.0 + */ + def delete(): DeltaMergeBuilder = { + val clause = proto.MergeIntoTable.Action + .newBuilder() + .setDeleteAction(proto.MergeIntoTable.Action.DeleteAction.newBuilder()) + notMatchBySourceCondition.foreach(c => clause.setCondition(toExpr(c))) + mergeBuilder.withWhenNotMatchedBySourceClause(clause.build()) + } + + private def addUpdateClause(set: Map[String, Column]): DeltaMergeBuilder = { + if (set.isEmpty && notMatchBySourceCondition.isEmpty) { + // This is a catch all clause that doesn't update anything: we can ignore it. + mergeBuilder + } else { + val assignments = set.map { case (field, value) => + proto.Assignment.newBuilder().setField(toExpr(expr(field))).setValue(toExpr(value)).build() + } + val action = proto.MergeIntoTable.Action.UpdateAction + .newBuilder() + .addAllAssignments(assignments.asJava) + val clause = proto.MergeIntoTable.Action + .newBuilder() + .setUpdateAction(action) + notMatchBySourceCondition.foreach(c => clause.setCondition(toExpr(c))) + mergeBuilder.withWhenNotMatchedBySourceClause(clause.build()) + } + } + + private def toStrColumnMap(map: Map[String, String]): Map[String, Column] = + map.mapValues(functions.expr).toMap +} + +object DeltaMergeNotMatchedBySourceActionBuilder { + /** + * :: Unstable :: + * + * Private method for internal usage only. Do not call this directly. 
+ */ + @Unstable + private[delta] def apply( + mergeBuilder: DeltaMergeBuilder, + notMatchBySourceCondition: Option[Column]): DeltaMergeNotMatchedBySourceActionBuilder = { + new DeltaMergeNotMatchedBySourceActionBuilder(mergeBuilder, notMatchBySourceCondition) + } +} From 8afe303a0ac1edc09af8de5e10f0814663f8ca9a Mon Sep 17 00:00:00 2001 From: Thang Long VU Date: Thu, 19 Sep 2024 23:17:25 +0200 Subject: [PATCH 3/6] Add withSchemaEvolution --- .github/workflows/spark_master_test.yaml | 4 +- .../connect/tables/DeltaMergeBuilder.scala | 46 ++++++++++++-- .../tables/DeltaMergeBuilderSuite.scala | 49 ++++++++++++++- .../connect/DeltaConnectPlannerSuite.scala | 63 +++++++++++++++++-- 4 files changed, 149 insertions(+), 13 deletions(-) diff --git a/.github/workflows/spark_master_test.yaml b/.github/workflows/spark_master_test.yaml index 07f72e26ed2..3906c31f221 100644 --- a/.github/workflows/spark_master_test.yaml +++ b/.github/workflows/spark_master_test.yaml @@ -51,7 +51,7 @@ jobs: - name: Run Spark Master tests # when changing TEST_PARALLELISM_COUNT make sure to also change it in spark_test.yaml run: | - TEST_PARALLELISM_COUNT=4 build/sbt -DsparkVersion=master "++ ${{ matrix.scala }}" clean connectServer/assembly connectClient/test - TEST_PARALLELISM_COUNT=4 build/sbt -DsparkVersion=master "++ ${{ matrix.scala }}" clean connectServer/test TEST_PARALLELISM_COUNT=4 SHARD_ID=${{matrix.shard}} build/sbt -DsparkVersion=master "++ ${{ matrix.scala }}" clean spark/test + TEST_PARALLELISM_COUNT=4 build/sbt -DsparkVersion=master "++ ${{ matrix.scala }}" clean connectServer/test + TEST_PARALLELISM_COUNT=4 build/sbt -DsparkVersion=master "++ ${{ matrix.scala }}" clean connectServer/assembly connectClient/test if: steps.git-diff.outputs.diff diff --git a/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala b/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala index 31e1e0e2d3b..4a39d894cec 100644 --- a/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala +++ b/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala @@ -99,6 +99,7 @@ import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.toExpr * .merge( * source.as("source"), * "target.key = source.key") + * .withSchemaEvolution() * .whenMatched() * .updateExpr(Map( * "value" -> "source.value")) @@ -119,6 +120,7 @@ import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.toExpr * .merge( * source.as("source"), * "target.key = source.key") + * .withSchemaEvolution() * .whenMatched() * .updateExpr( * new HashMap() {{ @@ -146,7 +148,19 @@ class DeltaMergeBuilder private( private val onCondition: Column, private val whenMatchedClauses: Seq[proto.MergeIntoTable.Action], private val whenNotMatchedClauses: Seq[proto.MergeIntoTable.Action], - private val whenNotMatchedBySourceClauses: Seq[proto.MergeIntoTable.Action]) { + private val whenNotMatchedBySourceClauses: Seq[proto.MergeIntoTable.Action], + private val schemaEvolutionEnabled: Boolean) { + + + def this( + targetTable: DeltaTable, + source: DataFrame, + onCondition: Column, + whenMatchedClauses: Seq[proto.MergeIntoTable.Action], + whenNotMatchedClauses: Seq[proto.MergeIntoTable.Action], + whenNotMatchedBySourceClauses: Seq[proto.MergeIntoTable.Action]) = + this(targetTable, source, onCondition, whenMatchedClauses, + whenNotMatchedClauses, whenNotMatchedBySourceClauses, schemaEvolutionEnabled = false) /** * Build the 
actions to perform when the merge condition was matched. This returns @@ -200,7 +214,7 @@ class DeltaMergeBuilder private( /** * Build the actions to perform when the merge condition was not matched and - * the given `condition` is true. This returns [[DeltaMergeMatchedActionBuilder]] object + * the given `condition` is true. This returns [[DeltaMergeNotMatchedActionBuilder]] object * which can be used to specify how to insert the new sourced row into the target table. * * @param condition boolean expression as a SQL formatted string. @@ -213,7 +227,7 @@ class DeltaMergeBuilder private( /** * Build the actions to perform when the merge condition was not matched and - * the given `condition` is true. This returns [[DeltaMergeMatchedActionBuilder]] object + * the given `condition` is true. This returns [[DeltaMergeNotMatchedActionBuilder]] object * which can be used to specify how to insert the new sourced row into the target table. * * @param condition boolean expression as a Column object. @@ -261,6 +275,23 @@ class DeltaMergeBuilder private( DeltaMergeNotMatchedBySourceActionBuilder(this, Some(condition)) } + /** + * Enable schema evolution for the merge operation. This allows the schema of the target + * table/columns to be automatically updated based on the schema of the source table/columns. + * + * @since 4.0.0 + */ + def withSchemaEvolution(): DeltaMergeBuilder = { + new DeltaMergeBuilder( + this.targetTable, + this.source, + this.onCondition, + this.whenMatchedClauses, + this.whenNotMatchedClauses, + this.whenNotMatchedBySourceClauses, + schemaEvolutionEnabled = true) + } + /** * Execute the merge operation based on the built matched and not matched actions. * @@ -296,7 +327,8 @@ class DeltaMergeBuilder private( this.onCondition, this.whenMatchedClauses :+ clause, this.whenNotMatchedClauses, - this.whenNotMatchedBySourceClauses) + this.whenNotMatchedBySourceClauses, + this.schemaEvolutionEnabled) } /** @@ -313,7 +345,8 @@ class DeltaMergeBuilder private( this.onCondition, this.whenMatchedClauses, this.whenNotMatchedClauses :+ clause, - this.whenNotMatchedBySourceClauses) + this.whenNotMatchedBySourceClauses, + this.schemaEvolutionEnabled) } /** @@ -330,7 +363,8 @@ class DeltaMergeBuilder private( this.onCondition, this.whenMatchedClauses, this.whenNotMatchedClauses, - this.whenNotMatchedBySourceClauses :+ clause) + this.whenNotMatchedBySourceClauses :+ clause, + this.schemaEvolutionEnabled) } } diff --git a/spark-connect/client/src/test/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilderSuite.scala b/spark-connect/client/src/test/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilderSuite.scala index 42175a50c7a..5727176dab4 100644 --- a/spark-connect/client/src/test/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilderSuite.scala +++ b/spark-connect/client/src/test/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilderSuite.scala @@ -381,11 +381,58 @@ class DeltaMergeBuilderSuite extends DeltaQueryTest with RemoteSparkSession { } } + test("merge with the withSchemaEvolution API") { + val session = spark + import session.implicits._ + + withTempPath { dir => + val path = dir.getAbsolutePath + Seq("a", "b", "c", "d").toDF("key") + .write.mode("overwrite").format("delta").save(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable.as("t") + .merge(testSource.toDF("key", "value").as("s"), expr("t.key = s.key")) + .withSchemaEvolution() + .whenMatched().updateAll() + .whenNotMatched().insertAll() + .execute() + + checkAnswer( + 
deltaTable.toDF, + Seq( + Row("a", -1), Row("b", 0), Row("c", null), Row("d", null), Row("e", -5), Row("f", -6))) + } + } + + test("merge fails due to no withSchemaEvolution while schema evolution is needed") { + val session = spark + import session.implicits._ + + withTempPath { dir => + val path = dir.getAbsolutePath + Seq("a", "b", "c", "d").toDF("key") + .write.mode("overwrite").format("delta").save(path) + val deltaTable = DeltaTable.forPath(spark, path) + + deltaTable.as("t") + .merge(testSource.toDF("key", "value").as("s"), expr("t.key = s.key")) + .whenMatched().updateAll() + .whenNotMatched().insertAll() + .execute() + + checkAnswer( + deltaTable.toDF, + Seq( + Row("a", -1), Row("b", 0), Row("c", null), Row("d", null), Row("e", -5), Row("f", -6))) + } + } + test("merge dataframe with many columns") { withTempPath { dir => val path = dir.getAbsolutePath var df1 = spark.range(1).toDF - val numColumns = 100 + val numColumns = 20 for (i <- 0 until numColumns) { df1 = df1.withColumn(s"col$i", col("id")) } diff --git a/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala b/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala index 9e3d8815aa5..bbfccb348ae 100644 --- a/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala +++ b/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala @@ -544,7 +544,7 @@ class DeltaConnectPlannerSuite } } - test("merge - mixed") { + test("merge - withSchemaEvolution") { val targetTableName = "target" val sourceTableName = "source" withTable(targetTableName, sourceTableName) { @@ -554,6 +554,7 @@ class DeltaConnectPlannerSuite spark.range(end = 100) .select(col("id") + 50 as "id") .select(col("id") as "key", col("id") + 1000 as "value") + .select(col("id") + 1 as "deltaLake") .write.format("delta").saveAsTable(sourceTableName) val input = createSparkRelation( @@ -576,6 +577,7 @@ class DeltaConnectPlannerSuite .setCondition(createExpression("t.value < 25")) .setDeleteAction(proto.MergeIntoTable.Action.DeleteAction.newBuilder()) ) + .setWithSchemaEvolution(true) ) ) @@ -589,9 +591,62 @@ class DeltaConnectPlannerSuite checkAnswer( spark.read.table(targetTableName), - Seq.tabulate(25)(i => Row(25 + i, 25 + i)) ++ - Seq.tabulate(50)(i => Row(i + 50, i + 1050)) ++ - Seq.tabulate(50)(i => Row(i + 100, i + 1100)) + Seq.tabulate(25)(i => Row(25 + i, 25 + i, None)) ++ + Seq.tabulate(50)(i => Row(i + 50, i + 1050, i + 51)) ++ + Seq.tabulate(50)(i => Row(i + 100, i + 1100, i + 101)) + ) + } + } + + test("merge fails due to no withSchemaEvolution while schema evolution is needed") { + val targetTableName = "target" + val sourceTableName = "source" + withTable(targetTableName, sourceTableName) { + spark.range(end = 100).select(col("id") as "key", col("id") as "value") + .write.format("delta").saveAsTable(targetTableName) + + spark.range(end = 100) + .select(col("id") + 50 as "id") + .select(col("id") as "key", col("id") + 1000 as "value") + .select(col("id") + 1 as "deltaLake") + .write.format("delta").saveAsTable(sourceTableName) + + val input = createSparkRelation( + proto.DeltaRelation.newBuilder() + .setMergeIntoTable( + proto.MergeIntoTable.newBuilder() + .setTarget(createSubqueryAlias(createScan(targetTableName), alias = "t")) + .setSource(createSubqueryAlias(createScan(sourceTableName), alias = "s")) + .setCondition(createExpression("t.key = s.key")) + .addMatchedActions( + 
proto.MergeIntoTable.Action.newBuilder() + .setUpdateStarAction(proto.MergeIntoTable.Action.UpdateStarAction.newBuilder()) + ) + .addNotMatchedActions( + proto.MergeIntoTable.Action.newBuilder() + .setInsertStarAction(proto.MergeIntoTable.Action.InsertStarAction.newBuilder()) + ) + .addNotMatchedBySourceActions( + proto.MergeIntoTable.Action.newBuilder() + .setCondition(createExpression("t.value < 25")) + .setDeleteAction(proto.MergeIntoTable.Action.DeleteAction.newBuilder()) + ) + ) + ) + + val plan = transform(input) + val result = Dataset.ofRows(spark, plan).collect() + assert(result.length === 1) + assert(result.head.getLong(0) === 125) // num_affected_rows + assert(result.head.getLong(1) === 50) // num_updated_rows + assert(result.head.getLong(2) === 25) // num_deleted_rows + assert(result.head.getLong(3) === 50) // num_inserted_rows + + checkAnswer( + spark.read.table(targetTableName), + Seq.tabulate(25)(i => Row(25 + i, 25 + i, None)) ++ + Seq.tabulate(50)(i => Row(i + 50, i + 1050, i + 51)) ++ + Seq.tabulate(50)(i => Row(i + 100, i + 1100, i + 101)) ) } } From b86e9a21bef94ff63db797356b59a905ca88f7f1 Mon Sep 17 00:00:00 2001 From: Thang Long VU Date: Thu, 19 Sep 2024 23:42:06 +0200 Subject: [PATCH 4/6] Update --- .../connect/DeltaConnectPlannerSuite.scala | 61 ++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala b/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala index bbfccb348ae..82676030eae 100644 --- a/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala +++ b/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala @@ -485,6 +485,7 @@ class DeltaConnectPlannerSuite ) val plan = transform(input) + assert(plan.columns.toSeq === V2CommandOutputs.mergeOutput.map(_.name)) val result = Dataset.ofRows(spark, plan).collect() assert(result.length === 1) assert(result.head.getLong(0) === 50) // num_affected_rows @@ -530,6 +531,7 @@ class DeltaConnectPlannerSuite ) val plan = transform(input) + assert(plan.columns.toSeq === V2CommandOutputs.mergeOutput.map(_.name)) val result = Dataset.ofRows(spark, plan).collect() assert(result.length === 1) assert(result.head.getLong(0) === 50) // num_affected_rows @@ -544,6 +546,59 @@ class DeltaConnectPlannerSuite } } + test("merge - mixed") { + val targetTableName = "target" + val sourceTableName = "source" + withTable(targetTableName, sourceTableName) { + spark.range(end = 100).select(col("id") as "key", col("id") as "value") + .write.format("delta").saveAsTable(targetTableName) + + spark.range(end = 100) + .select(col("id") + 50 as "id") + .select(col("id") as "key", col("id") + 1000 as "value") + .write.format("delta").saveAsTable(sourceTableName) + + val input = createSparkRelation( + proto.DeltaRelation.newBuilder() + .setMergeIntoTable( + proto.MergeIntoTable.newBuilder() + .setTarget(createSubqueryAlias(createScan(targetTableName), alias = "t")) + .setSource(createSubqueryAlias(createScan(sourceTableName), alias = "s")) + .setCondition(createExpression("t.key = s.key")) + .addMatchedActions( + proto.MergeIntoTable.Action.newBuilder() + .setUpdateStarAction(proto.MergeIntoTable.Action.UpdateStarAction.newBuilder()) + ) + .addNotMatchedActions( + proto.MergeIntoTable.Action.newBuilder() + .setInsertStarAction(proto.MergeIntoTable.Action.InsertStarAction.newBuilder()) + ) + 
.addNotMatchedBySourceActions( + proto.MergeIntoTable.Action.newBuilder() + .setCondition(createExpression("t.value < 25")) + .setDeleteAction(proto.MergeIntoTable.Action.DeleteAction.newBuilder()) + ) + ) + ) + + val plan = transform(input) + assert(plan.columns.toSeq === V2CommandOutputs.mergeOutput.map(_.name)) + val result = Dataset.ofRows(spark, plan).collect() + assert(result.length === 1) + assert(result.head.getLong(0) === 125) // num_affected_rows + assert(result.head.getLong(1) === 50) // num_updated_rows + assert(result.head.getLong(2) === 25) // num_deleted_rows + assert(result.head.getLong(3) === 50) // num_inserted_rows + + checkAnswer( + spark.read.table(targetTableName), + Seq.tabulate(25)(i => Row(25 + i, 25 + i)) ++ + Seq.tabulate(50)(i => Row(i + 50, i + 1050)) ++ + Seq.tabulate(50)(i => Row(i + 100, i + 1100)) + ) + } + } + test("merge - withSchemaEvolution") { val targetTableName = "target" val sourceTableName = "source" @@ -554,7 +609,7 @@ class DeltaConnectPlannerSuite spark.range(end = 100) .select(col("id") + 50 as "id") .select(col("id") as "key", col("id") + 1000 as "value") - .select(col("id") + 1 as "deltaLake") + .select(col("id") + 1 as "keyplusone") .write.format("delta").saveAsTable(sourceTableName) val input = createSparkRelation( @@ -582,6 +637,7 @@ class DeltaConnectPlannerSuite ) val plan = transform(input) + assert(plan.columns.toSeq === V2CommandOutputs.mergeOutput.map(_.name)) val result = Dataset.ofRows(spark, plan).collect() assert(result.length === 1) assert(result.head.getLong(0) === 125) // num_affected_rows @@ -608,7 +664,7 @@ class DeltaConnectPlannerSuite spark.range(end = 100) .select(col("id") + 50 as "id") .select(col("id") as "key", col("id") + 1000 as "value") - .select(col("id") + 1 as "deltaLake") + .select(col("id") + 1 as "valueplusone") .write.format("delta").saveAsTable(sourceTableName) val input = createSparkRelation( @@ -635,6 +691,7 @@ class DeltaConnectPlannerSuite ) val plan = transform(input) + assert(plan.columns.toSeq === V2CommandOutputs.mergeOutput.map(_.name)) val result = Dataset.ofRows(spark, plan).collect() assert(result.length === 1) assert(result.head.getLong(0) === 125) // num_affected_rows From 63c3dd16cd6e4e63c6dffe163403bd5f11cd8743 Mon Sep 17 00:00:00 2001 From: Thang Long VU Date: Fri, 20 Sep 2024 13:50:02 +0200 Subject: [PATCH 5/6] Update withSchemaEvolution code --- .../connect/tables/DeltaMergeBuilder.scala | 3 ++- .../tables/DeltaMergeBuilderSuite.scala | 6 ++--- .../connect/DeltaConnectPlannerSuite.scala | 23 ++++++++++--------- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala b/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala index 4a39d894cec..a84ec01bba9 100644 --- a/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala +++ b/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala @@ -151,7 +151,7 @@ class DeltaMergeBuilder private( private val whenNotMatchedBySourceClauses: Seq[proto.MergeIntoTable.Action], private val schemaEvolutionEnabled: Boolean) { - + // Schema Evolution is off by default in Merge. 
def this( targetTable: DeltaTable, source: DataFrame, @@ -307,6 +307,7 @@ class DeltaMergeBuilder private( .addAllMatchedActions(whenMatchedClauses.asJava) .addAllNotMatchedActions(whenNotMatchedClauses.asJava) .addAllNotMatchedBySourceActions(whenNotMatchedBySourceClauses.asJava) + .setWithSchemaEvolution(schemaEvolutionEnabled) val relation = proto.DeltaRelation.newBuilder().setMergeIntoTable(merge).build() val extension = com.google.protobuf.Any.pack(relation) val sparkRelation = spark_proto.Relation.newBuilder().setExtension(extension).build() diff --git a/spark-connect/client/src/test/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilderSuite.scala b/spark-connect/client/src/test/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilderSuite.scala index 5727176dab4..a81e200abd4 100644 --- a/spark-connect/client/src/test/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilderSuite.scala +++ b/spark-connect/client/src/test/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilderSuite.scala @@ -405,7 +405,8 @@ class DeltaMergeBuilderSuite extends DeltaQueryTest with RemoteSparkSession { } } - test("merge fails due to no withSchemaEvolution while schema evolution is needed") { + test("merge with no withSchemaEvolution while the source's schema " + + "is different than the target's schema") { val session = spark import session.implicits._ @@ -423,8 +424,7 @@ class DeltaMergeBuilderSuite extends DeltaQueryTest with RemoteSparkSession { checkAnswer( deltaTable.toDF, - Seq( - Row("a", -1), Row("b", 0), Row("c", null), Row("d", null), Row("e", -5), Row("f", -6))) + Seq(Row("a"), Row("b"), Row("c"), Row("d"), Row("e"), Row("f"))) } } diff --git a/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala b/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala index 82676030eae..5fc6eb59586 100644 --- a/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala +++ b/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala @@ -608,8 +608,7 @@ class DeltaConnectPlannerSuite spark.range(end = 100) .select(col("id") + 50 as "id") - .select(col("id") as "key", col("id") + 1000 as "value") - .select(col("id") + 1 as "keyplusone") + .select(col("id") as "key", col("id") + 1000 as "value", col("id") + 1 as "extracol") .write.format("delta").saveAsTable(sourceTableName) val input = createSparkRelation( @@ -636,7 +635,8 @@ class DeltaConnectPlannerSuite ) ) - val plan = transform(input) + val plan = new SparkConnectPlanner( + SparkConnectTestUtils.createDummySessionHolder(spark)).transformRelation(input) assert(plan.columns.toSeq === V2CommandOutputs.mergeOutput.map(_.name)) val result = Dataset.ofRows(spark, plan).collect() assert(result.length === 1) @@ -647,14 +647,15 @@ class DeltaConnectPlannerSuite checkAnswer( spark.read.table(targetTableName), - Seq.tabulate(25)(i => Row(25 + i, 25 + i, None)) ++ + Seq.tabulate(25)(i => Row(25 + i, 25 + i, null)) ++ Seq.tabulate(50)(i => Row(i + 50, i + 1050, i + 51)) ++ Seq.tabulate(50)(i => Row(i + 100, i + 1100, i + 101)) ) } } - test("merge fails due to no withSchemaEvolution while schema evolution is needed") { + test("merge with no withSchemaEvolution while the source's schema " + + "is different than the target's schema") { val targetTableName = "target" val sourceTableName = "source" withTable(targetTableName, sourceTableName) { @@ -663,8 +664,7 @@ class DeltaConnectPlannerSuite 
spark.range(end = 100) .select(col("id") + 50 as "id") - .select(col("id") as "key", col("id") + 1000 as "value") - .select(col("id") + 1 as "valueplusone") + .select(col("id") as "key", col("id") + 1000 as "value", col("id") + 1 as "extracol") .write.format("delta").saveAsTable(sourceTableName) val input = createSparkRelation( @@ -690,7 +690,8 @@ class DeltaConnectPlannerSuite ) ) - val plan = transform(input) + val plan = new SparkConnectPlanner( + SparkConnectTestUtils.createDummySessionHolder(spark)).transformRelation(input) assert(plan.columns.toSeq === V2CommandOutputs.mergeOutput.map(_.name)) val result = Dataset.ofRows(spark, plan).collect() assert(result.length === 1) @@ -701,9 +702,9 @@ class DeltaConnectPlannerSuite checkAnswer( spark.read.table(targetTableName), - Seq.tabulate(25)(i => Row(25 + i, 25 + i, None)) ++ - Seq.tabulate(50)(i => Row(i + 50, i + 1050, i + 51)) ++ - Seq.tabulate(50)(i => Row(i + 100, i + 1100, i + 101)) + Seq.tabulate(25)(i => Row(25 + i, 25 + i)) ++ + Seq.tabulate(50)(i => Row(i + 50, i + 1050)) ++ + Seq.tabulate(50)(i => Row(i + 100, i + 1100)) ) } } From 66dbfa870d5db37b2755cb2a867c69b73e528506 Mon Sep 17 00:00:00 2001 From: Thang Long VU Date: Fri, 20 Sep 2024 15:11:19 +0200 Subject: [PATCH 6/6] Polish --- .../io/delta/connect/tables/DeltaMergeBuilder.scala | 7 +++++++ .../io/delta/connect/DeltaConnectPlannerSuite.scala | 11 ++--------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala b/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala index a84ec01bba9..d4f2de9cf13 100644 --- a/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala +++ b/spark-connect/client/src/main/scala-spark-master/io/delta/connect/tables/DeltaMergeBuilder.scala @@ -488,6 +488,13 @@ class DeltaMergeMatchedActionBuilder private( } } + /** + * Converts a map of strings to expressions as SQL formatted string + * into a map of strings to Column objects. + * + * @param map A map where the value is an expression as SQL formatted string. + * @return A map where the value is a Column object created from the expression. 
+ */ private def toStrColumnMap(map: Map[String, String]): Map[String, Column] = map.mapValues(functions.expr).toMap } diff --git a/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala b/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala index 5fc6eb59586..e6cf16b6b3e 100644 --- a/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala +++ b/spark-connect/server/src/test/scala-spark-master/io/delta/connect/DeltaConnectPlannerSuite.scala @@ -485,7 +485,6 @@ class DeltaConnectPlannerSuite ) val plan = transform(input) - assert(plan.columns.toSeq === V2CommandOutputs.mergeOutput.map(_.name)) val result = Dataset.ofRows(spark, plan).collect() assert(result.length === 1) assert(result.head.getLong(0) === 50) // num_affected_rows @@ -531,7 +530,6 @@ class DeltaConnectPlannerSuite ) val plan = transform(input) - assert(plan.columns.toSeq === V2CommandOutputs.mergeOutput.map(_.name)) val result = Dataset.ofRows(spark, plan).collect() assert(result.length === 1) assert(result.head.getLong(0) === 50) // num_affected_rows @@ -582,7 +580,6 @@ class DeltaConnectPlannerSuite ) val plan = transform(input) - assert(plan.columns.toSeq === V2CommandOutputs.mergeOutput.map(_.name)) val result = Dataset.ofRows(spark, plan).collect() assert(result.length === 1) assert(result.head.getLong(0) === 125) // num_affected_rows @@ -635,9 +632,7 @@ class DeltaConnectPlannerSuite ) ) - val plan = new SparkConnectPlanner( - SparkConnectTestUtils.createDummySessionHolder(spark)).transformRelation(input) - assert(plan.columns.toSeq === V2CommandOutputs.mergeOutput.map(_.name)) + val plan = transform(input) val result = Dataset.ofRows(spark, plan).collect() assert(result.length === 1) assert(result.head.getLong(0) === 125) // num_affected_rows @@ -690,9 +685,7 @@ class DeltaConnectPlannerSuite ) ) - val plan = new SparkConnectPlanner( - SparkConnectTestUtils.createDummySessionHolder(spark)).transformRelation(input) - assert(plan.columns.toSeq === V2CommandOutputs.mergeOutput.map(_.name)) + val plan = transform(input) val result = Dataset.ofRows(spark, plan).collect() assert(result.length === 1) assert(result.head.getLong(0) === 125) // num_affected_rows