From 3b219651fb0c2eef5723f16e6f64d2e1a2ef7a2b Mon Sep 17 00:00:00 2001
From: "fangbo.0511"
Date: Tue, 3 Feb 2026 16:12:48 +0800
Subject: [PATCH 1/3] feat: support update using RewriteColumns mode

---
 docs/src/operations/dml/.pages                |   1 +
 docs/src/operations/dml/merge-into.md         |  68 +++
 docs/src/operations/dml/update.md             |  53 +++
 .../spark/LancePositionDeltaOperation.java    |  13 +-
 .../spark/write/SparkPositionDeltaWrite.java  | 217 ++++++++-
 .../write/SparkPositionDeltaWriteBuilder.java |   4 +
 .../optimizer/UpdateColumnsExtractor.scala    | 138 ++++++
 .../LanceSparkSessionExtensions.scala         |   9 +-
 .../update/MergeIntoRewriteRowsTest.java      |  46 ++
 .../update/UpdateTableRewriteRowsTest.java    |  71 +++
 .../lance/spark/update/UpdateTableTest.java   |  34 +-
 .../optimizer/UpdateColumnsExtractor.scala    | 138 ++++++
 .../LanceSparkSessionExtensions.scala         |   9 +-
 .../java/org/lance/spark/utils/SparkUtil.java |  13 +
 .../lance/spark/update/BaseMergeIntoTest.java | 431 +++++++++++++++++-
 15 files changed, 1222 insertions(+), 23 deletions(-)
 create mode 100644 docs/src/operations/dml/merge-into.md
 create mode 100644 lance-spark-3.5_2.12/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateColumnsExtractor.scala
 create mode 100644 lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/MergeIntoRewriteRowsTest.java
 create mode 100644 lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableRewriteRowsTest.java
 create mode 100644 lance-spark-4.0_2.13/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateColumnsExtractor.scala

diff --git a/docs/src/operations/dml/.pages b/docs/src/operations/dml/.pages
index 057a3360..17b0ac2d 100644
--- a/docs/src/operations/dml/.pages
+++ b/docs/src/operations/dml/.pages
@@ -3,5 +3,6 @@ nav:
   - insert-into.md
   - update.md
   - delete.md
+  - merge-into.md
   - add-columns.md
   - update-columns.md
diff --git a/docs/src/operations/dml/merge-into.md b/docs/src/operations/dml/merge-into.md
new file mode 100644
index 00000000..28eeb498
--- /dev/null
+++ b/docs/src/operations/dml/merge-into.md
@@ -0,0 +1,68 @@
+# MERGE INTO
+
+Currently, MERGE INTO is only supported on Spark 3.5+.
+
+Lance fully supports Spark's MERGE INTO operation. For detailed usage, refer to the Spark documentation.
+
+```sql
+MERGE INTO customers c
+USING new_updates u
+ON c.id = u.id
+WHEN MATCHED THEN
+UPDATE SET c.status = u.new_status, c.last_seen = u.timestamp;
+```
+
+Additionally, `lance-spark` introduces a column rewrite mode for `MERGE INTO` operations, which can significantly improve performance for narrow updates that only affect a few columns.
+
+## Column Rewrite Mode
+
+This mode allows the Lance data source to perform column-level updates by writing new data files for only the modified columns, avoiding the need to rewrite the entire data file (the "delete and insert" pattern).
+
+!!! warning "Spark Extension Required"
+This feature requires the Lance Spark SQL extension to be enabled.
+See [Spark SQL Extensions](../../config.md#spark-sql-extensions) for configuration details.
+
+### Configuration
+
+You can enable or disable this feature using the Spark SQL session configuration `spark.sql.lance.rewrite_columns`.
+
+**Using SQL:**
+
+To enable the feature for the current session:
+```sql
+SET spark.sql.lance.rewrite_columns = true;
+```
+
+To disable it:
+```sql
+SET spark.sql.lance.rewrite_columns = false;
+```
+
+### Behavior and Semantics
+
+- When `spark.sql.lance.rewrite_columns` is set to `true`, a `MERGE INTO ... 
WHEN MATCHED UPDATE` operation will attempt to perform column-level updates. Instead of deleting the matched rows and inserting new versions, the engine will only write new versions of the changed columns. +- When the configuration is set to `false` (the default behavior), the operation fall back to rewriting the affected rows (a "delete and insert" operation). + +### Examples + +**MERGE INTO with RewriteColumns** + +When enabled, the `UPDATE` clause in a `MERGE INTO` statement will benefit from this optimization. + +```sql +-- Enable column rewrite mode +SET spark.sql.lance.rewrite_columns = true; + +-- This will update the 'status' and 'last_seen' columns without rewriting the whole row +MERGE INTO customers c +USING new_updates u +ON c.id = u.id +WHEN MATCHED THEN + UPDATE SET c.status = u.new_status, c.last_seen = u.timestamp; +``` + +### Notes + +- **Spark Version**: `MERGE INTO` operation are supported on Spark 3.5 and newer. +- **Nested Fields**: Updating nested fields follows the existing semantics of `MERGE INTO`. The entire top-level column containing the nested field will be rewritten. +- **Troubleshooting**: If you encounter any issues or unexpected behavior with this feature, you can disable it by setting `spark.sql.lance.rewrite_columns` to `false` to revert to the row-rewrite behavior. diff --git a/docs/src/operations/dml/update.md b/docs/src/operations/dml/update.md index d2e68e07..7252bbce 100644 --- a/docs/src/operations/dml/update.md +++ b/docs/src/operations/dml/update.md @@ -34,3 +34,56 @@ SET tags = ARRAY('ios', 'mobile') WHERE event_id = 1001; ``` +## Column Rewrite Mode + +`lance-spark` introduces a column rewrite mode for `UPDATE` operations, which can significantly improve performance for narrow updates that only affect a few columns. + +This mode allows the Lance data source to perform column-level updates by writing new data files for only the modified columns, avoiding the need to rewrite the entire data file (the "delete and insert" pattern). + +!!! warning "Spark Extension Required" +This feature requires the Lance Spark SQL extension to be enabled. +See [Spark SQL Extensions](../../config.md#spark-sql-extensions) for configuration details. + +### Configuration + +You can enable or disable this feature using the Spark SQL session configuration `spark.sql.lance.rewrite_columns`. + +**Using SQL:** + +To enable the feature for the current session: +```sql +SET spark.sql.lance.rewrite_columns = true; +``` + +To disable it: +```sql +SET spark.sql.lance.rewrite_columns = false; +``` + +### Behavior and Semantics + +- When `spark.sql.lance.rewrite_columns` is set to `true`, `UPDATE` operations will attempt to perform column-level updates. Instead of deleting the matched rows and inserting new versions, the engine will only write new versions of the changed columns. +- When the configuration is set to `false` (the default behavior), the operations fall back to rewriting the affected rows (a "delete and insert" operation). + +### Examples + +**UPDATE with RewriteColumns** + +Here is an example of enabling the mode and performing an `UPDATE`. + +```sql +-- Enable column rewrite mode +SET spark.sql.lance.rewrite_columns = true; + +-- Assume 'users' table has columns: id, name, address +-- This operation will only write new data for the 'name' column +UPDATE users +SET name = 'New User Name' +WHERE id > 100; +``` + +### Notes + +- **Spark Version**: `UPDATE` operations are supported on Spark 3.5 and newer. 
+- **Nested Fields**: Updating nested fields follows the existing semantics of `UPDATE`. The entire top-level column containing the nested field will be rewritten. +- **Troubleshooting**: If you encounter any issues or unexpected behavior with this feature, you can disable it by setting `spark.sql.lance.rewrite_columns` to `false` to revert to the row-rewrite behavior. diff --git a/lance-spark-3.5_2.12/src/main/java/org/lance/spark/LancePositionDeltaOperation.java b/lance-spark-3.5_2.12/src/main/java/org/lance/spark/LancePositionDeltaOperation.java index dd08b7d6..25037ea6 100644 --- a/lance-spark-3.5_2.12/src/main/java/org/lance/spark/LancePositionDeltaOperation.java +++ b/lance-spark-3.5_2.12/src/main/java/org/lance/spark/LancePositionDeltaOperation.java @@ -14,8 +14,10 @@ package org.lance.spark; import org.lance.spark.read.LanceScanBuilder; +import org.lance.spark.utils.SparkUtil; import org.lance.spark.write.SparkPositionDeltaWriteBuilder; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.connector.expressions.Expressions; import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.read.ScanBuilder; @@ -26,9 +28,11 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import java.util.List; import java.util.Map; public class LancePositionDeltaOperation implements RowLevelOperation, SupportsDelta { + private final Command command; private final StructType sparkSchema; private final LanceSparkReadOptions readOptions; @@ -44,6 +48,8 @@ public class LancePositionDeltaOperation implements RowLevelOperation, SupportsD private final Map namespaceProperties; + private List updatedColumns; + public LancePositionDeltaOperation( Command command, StructType sparkSchema, @@ -82,6 +88,7 @@ public DeltaWriteBuilder newWriteBuilder(LogicalWriteInfo logicalWriteInfo) { .build(); return new SparkPositionDeltaWriteBuilder( sparkSchema, + updatedColumns, writeOptions, initialStorageOptions, namespaceImpl, @@ -103,6 +110,10 @@ public NamedReference[] requiredMetadataAttributes() { @Override public boolean representUpdateAsDeleteAndInsert() { - return true; + return !SparkUtil.rewriteColumns(SparkSession.active()); + } + + public void setUpdatedColumns(List updatedColumns) { + this.updatedColumns = updatedColumns; } } diff --git a/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java b/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java index 6378b1d7..55cead46 100644 --- a/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java +++ b/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java @@ -21,16 +21,25 @@ import org.lance.RowAddress; import org.lance.Transaction; import org.lance.WriteParams; +import org.lance.fragment.FragmentUpdateResult; import org.lance.io.StorageOptionsProvider; import org.lance.operation.Update; import org.lance.spark.LanceConstant; +import org.lance.spark.LanceDataset; import org.lance.spark.LanceRuntime; import org.lance.spark.LanceSparkWriteOptions; +import org.lance.spark.arrow.LanceArrowWriter; import org.lance.spark.function.LanceFragmentIdWithDefaultFunction; +import org.lance.spark.utils.SparkUtil; import com.google.common.collect.ImmutableList; import org.apache.arrow.c.ArrowArrayStream; import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import 
org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.pojo.Schema; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.distributions.Distribution; import org.apache.spark.sql.connector.distributions.Distributions; @@ -48,21 +57,34 @@ import org.apache.spark.sql.connector.write.PhysicalWriteInfo; import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering; import org.apache.spark.sql.connector.write.WriterCommitMessage; +import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.LanceArrowUtils; import org.roaringbitmap.RoaringBitmap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.FutureTask; +import java.util.stream.Collectors; public class SparkPositionDeltaWrite implements DeltaWrite, RequiresDistributionAndOrdering { + private static final Logger LOG = LoggerFactory.getLogger(SparkPositionDeltaWrite.class); + private final StructType sparkSchema; + private final List updatedColumns; private final LanceSparkWriteOptions writeOptions; /** @@ -79,12 +101,14 @@ public class SparkPositionDeltaWrite implements DeltaWrite, RequiresDistribution public SparkPositionDeltaWrite( StructType sparkSchema, + List updatedColumns, LanceSparkWriteOptions writeOptions, Map initialStorageOptions, String namespaceImpl, Map namespaceProperties, List tableId) { this.sparkSchema = sparkSchema; + this.updatedColumns = updatedColumns; this.writeOptions = writeOptions; this.initialStorageOptions = initialStorageOptions; this.namespaceImpl = namespaceImpl; @@ -120,6 +144,7 @@ private class PositionDeltaBatchWrite implements DeltaBatchWrite { public DeltaWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) { return new PositionDeltaWriteFactory( sparkSchema, + updatedColumns, writeOptions, initialStorageOptions, namespaceImpl, @@ -132,6 +157,7 @@ public void commit(WriterCommitMessage[] messages) { List removedFragmentIds = new ArrayList<>(); List updatedFragments = new ArrayList<>(); List newFragments = new ArrayList<>(); + Set fieldsModified = new HashSet<>(); Arrays.stream(messages) .map(m -> (DeltaWriteTaskCommit) m) @@ -140,6 +166,7 @@ public void commit(WriterCommitMessage[] messages) { removedFragmentIds.addAll(m.removedFragmentIds()); updatedFragments.addAll(m.updatedFragments()); newFragments.addAll(m.newFragments()); + fieldsModified.addAll(m.fieldsModified()); }); // Use SDK directly to update fragments @@ -149,6 +176,12 @@ public void commit(WriterCommitMessage[] messages) { .removedFragmentIds(removedFragmentIds) .updatedFragments(updatedFragments) .newFragments(newFragments) + .fieldsModified(fieldsModified.stream().mapToLong(Long::longValue).toArray()) + .updateMode( + Optional.of( + fieldsModified.isEmpty() + ? 
Update.UpdateMode.RewriteRows + : Update.UpdateMode.RewriteColumns)) .build(); try (Transaction txn = @@ -187,6 +220,7 @@ public void abort(WriterCommitMessage[] messages) {} private static class PositionDeltaWriteFactory implements DeltaWriterFactory { private final StructType sparkSchema; + private final List updatedColumns; private final LanceSparkWriteOptions writeOptions; /** @@ -203,12 +237,14 @@ private static class PositionDeltaWriteFactory implements DeltaWriterFactory { PositionDeltaWriteFactory( StructType sparkSchema, + List updatedColumns, LanceSparkWriteOptions writeOptions, Map initialStorageOptions, String namespaceImpl, Map namespaceProperties, List tableId) { this.sparkSchema = sparkSchema; + this.updatedColumns = updatedColumns; this.writeOptions = writeOptions; this.initialStorageOptions = initialStorageOptions; this.namespaceImpl = namespaceImpl; @@ -252,6 +288,8 @@ public DeltaWriter createWriter(int partitionId, long taskId) { fragmentCreationThread.start(); return new LanceDeltaWriter( + sparkSchema, + updatedColumns, writeOptions, new LanceDataWriter(writeBuffer, fragmentCreationTask, fragmentCreationThread), initialStorageOptions); @@ -286,6 +324,8 @@ private StorageOptionsProvider getStorageOptionsProvider() { } private static class LanceDeltaWriter implements DeltaWriter { + private final Dataset dataset; + private final LanceSparkWriteOptions writeOptions; private final LanceDataWriter writer; @@ -298,14 +338,60 @@ private static class LanceDeltaWriter implements DeltaWriter { // Key is fragmentId, Value is fragment's deleted row indexes private final Map deletedRows; + // Spark schema for updated columns + private Optional updatedSparkSchema; + + // Data stream arrow schema for updated columns + private Optional updatedArrowSchema; + + // Updated column ordinals in source input row + private Optional updatedColumnOrdinals; + + private Optional> fieldsModified; + private Optional> updatedFragments; + + private Optional currentUpdateFragmentId; + private Optional currentUpdateData; + private Optional currentUpdateWriter; + private LanceDeltaWriter( + StructType sparkSchema, + List updatedColumns, LanceSparkWriteOptions writeOptions, LanceDataWriter writer, Map initialStorageOptions) { + this.dataset = openDataset(writeOptions); this.writeOptions = writeOptions; this.writer = writer; this.initialStorageOptions = initialStorageOptions; this.deletedRows = new HashMap<>(); + + if (updatedColumns != null && !updatedColumns.isEmpty()) { + StructType schema = + new StructType( + updatedColumns.stream() + .map(sparkSchema::apply) + .collect(Collectors.toList()) + .toArray(new StructField[0])); + schema = + schema.add( + LanceDataset.ROW_ADDRESS_COLUMN.name(), + LanceDataset.ROW_ADDRESS_COLUMN.dataType(), + LanceDataset.ROW_ADDRESS_COLUMN.isNullable()); + + updatedSparkSchema = Optional.of(schema); + updatedArrowSchema = Optional.of(LanceArrowUtils.toArrowSchema(schema, "UTC", false)); + + updatedColumnOrdinals = + Optional.of( + updatedColumns.stream() + .map(sparkSchema::fieldIndex) + .mapToInt(Integer::intValue) + .toArray()); + } + fieldsModified = Optional.of(new HashSet<>()); + updatedFragments = Optional.of(new HashMap<>()); + currentUpdateFragmentId = Optional.of(-1); } @Override @@ -327,7 +413,78 @@ public void delete(InternalRow metadata, InternalRow id) throws IOException { @Override public void update(InternalRow metadata, InternalRow id, InternalRow row) throws IOException { - throw new UnsupportedOperationException("Update is not supported"); + if 
(updatedArrowSchema == null) { + throw new UnsupportedOperationException( + "Updated columns is empty. Maybe some bugs for updated columns extractor. " + + "You can set " + + SparkUtil.REWRITE_COLUMNS + + " to false to disable this feature."); + } + + int fragmentId = metadata.getInt(0); + if (currentUpdateFragmentId.get() != fragmentId) { + + // A new fragment is coming, update old fragment columns. + updateFragmentColumns(); + + // Initialize a new arrow batch writee for new fragment. + currentUpdateFragmentId = Optional.of(fragmentId); + currentUpdateData = + Optional.of( + VectorSchemaRoot.create(updatedArrowSchema.get(), LanceRuntime.allocator())); + currentUpdateWriter = + Optional.of( + org.lance.spark.arrow.LanceArrowWriter$.MODULE$.create( + currentUpdateData.get(), updatedSparkSchema.get())); + } + + // Copy updated columns from source row to updated data row. + for (int i = 0; i < updatedColumnOrdinals.get().length; i++) { + currentUpdateWriter.get().field(i).write(row, updatedColumnOrdinals.get()[i]); + } + // Add row address column to updated data row. + currentUpdateWriter.get().field(updatedColumnOrdinals.get().length).write(id, 0); + } + + private void updateFragmentColumns() throws IOException { + if (currentUpdateFragmentId.get() == -1) { + return; + } + + LOG.info("Update columns for fragment: {}", currentUpdateFragmentId); + + currentUpdateWriter.get().finish(); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (ArrowStreamWriter writer = new ArrowStreamWriter(currentUpdateData.get(), null, out)) { + writer.start(); + writer.writeBatch(); + writer.end(); + } + + byte[] arrowData = out.toByteArray(); + ByteArrayInputStream in = new ByteArrayInputStream(arrowData); + BufferAllocator allocator = LanceRuntime.allocator(); + + try (ArrowStreamReader reader = new ArrowStreamReader(in, allocator); + ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) { + Data.exportArrayStream(allocator, reader, stream); + + // Use Dataset to get the fragment and merge columns + Fragment fragment = new Fragment(dataset, currentUpdateFragmentId.get()); + FragmentUpdateResult result = + fragment.updateColumns( + stream, + LanceDataset.ROW_ADDRESS_COLUMN.name(), + LanceDataset.ROW_ADDRESS_COLUMN.name()); + + for (long fieldId : result.getFieldsModified()) { + fieldsModified.get().add(fieldId); + } + updatedFragments.get().put(currentUpdateFragmentId.get(), result.getUpdatedFragment()); + } + + currentUpdateData.get().close(); } @Override @@ -337,28 +494,47 @@ public void insert(InternalRow row) throws IOException { @Override public WriterCommitMessage commit() throws IOException { + updateFragmentColumns(); + // Write new fragments to store new updated rows. LanceBatchWrite.TaskCommit append = (LanceBatchWrite.TaskCommit) writer.commit(); List newFragments = append.getFragments(); List removedFragmentIds = new ArrayList<>(); - List updatedFragments = new ArrayList<>(); - // Deleting updated rows from old fragments using SDK directly. - try (Dataset dataset = openDataset(writeOptions)) { - this.deletedRows.forEach( - (fragmentId, rowIndexes) -> { - FragmentMetadata updatedFragment = - dataset.getFragment(fragmentId).deleteRows(ImmutableList.copyOf(rowIndexes)); - if (updatedFragment != null) { - updatedFragments.add(updatedFragment); - } else { - removedFragmentIds.add(Long.valueOf(fragmentId)); - } - }); - } + // Deleting updated rows from old fragments. 
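+      // If this task also rewrote columns for this fragment (RewriteColumns mode), the compute()
+      // below merges the new deletion file into the fragment metadata returned by updateColumns;
+      // otherwise the metadata produced by deleteRows is kept as-is.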
+ this.deletedRows.forEach( + (fragmentId, rowIndexes) -> { + FragmentMetadata updatedFragment = + dataset.getFragment(fragmentId).deleteRows(ImmutableList.copyOf(rowIndexes)); + + if (updatedFragment != null) { + updatedFragments + .get() + .compute( + fragmentId, + (k, v) -> { + if (v == null) { + return updatedFragment; + } else { + return new FragmentMetadata( + v.getId(), + v.getFiles(), + v.getPhysicalRows(), + updatedFragment.getDeletionFile(), + v.getRowIdMeta()); + } + }); + } else { + removedFragmentIds.add(Long.valueOf(fragmentId)); + } + }); - return new DeltaWriteTaskCommit(removedFragmentIds, updatedFragments, newFragments); + return new DeltaWriteTaskCommit( + removedFragmentIds, + new ArrayList<>(updatedFragments.get().values()), + fieldsModified.get(), + newFragments); } private Dataset openDataset(LanceSparkWriteOptions options) { @@ -380,25 +556,30 @@ private Dataset openDataset(LanceSparkWriteOptions options) { @Override public void abort() throws IOException { writer.abort(); + dataset.close(); } @Override public void close() throws IOException { writer.close(); + dataset.close(); } } private static class DeltaWriteTaskCommit implements WriterCommitMessage { private List removedFragmentIds; private List updatedFragments; + private Set fieldsModified; private List newFragments; DeltaWriteTaskCommit( List removedFragmentIds, List updatedFragments, + Set fieldsModified, List newFragments) { this.removedFragmentIds = removedFragmentIds; this.updatedFragments = updatedFragments; + this.fieldsModified = fieldsModified; this.newFragments = newFragments; } @@ -410,6 +591,10 @@ public List updatedFragments() { return updatedFragments == null ? Collections.emptyList() : updatedFragments; } + public Set fieldsModified() { + return fieldsModified; + } + public List newFragments() { return newFragments == null ? 
Collections.emptyList() : newFragments; } diff --git a/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWriteBuilder.java b/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWriteBuilder.java index d0794993..0076d72d 100644 --- a/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWriteBuilder.java +++ b/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWriteBuilder.java @@ -37,15 +37,18 @@ public class SparkPositionDeltaWriteBuilder implements DeltaWriteBuilder { private final Map namespaceProperties; private final List tableId; + private final List updatedColumns; public SparkPositionDeltaWriteBuilder( StructType sparkSchema, + List updatedColumns, LanceSparkWriteOptions writeOptions, Map initialStorageOptions, String namespaceImpl, Map namespaceProperties, List tableId) { this.sparkSchema = sparkSchema; + this.updatedColumns = updatedColumns; this.writeOptions = writeOptions; this.initialStorageOptions = initialStorageOptions; this.namespaceImpl = namespaceImpl; @@ -56,6 +59,7 @@ public SparkPositionDeltaWriteBuilder( public DeltaWrite build() { return new SparkPositionDeltaWrite( sparkSchema, + updatedColumns, writeOptions, initialStorageOptions, namespaceImpl, diff --git a/lance-spark-3.5_2.12/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateColumnsExtractor.scala b/lance-spark-3.5_2.12/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateColumnsExtractor.scala new file mode 100644 index 00000000..14851281 --- /dev/null +++ b/lance-spark-3.5_2.12/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateColumnsExtractor.scala @@ -0,0 +1,138 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.ProjectingInternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, MergeRows, Project, WriteDelta} +import org.apache.spark.sql.catalyst.plans.logical.MergeRows.Keep +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.RowDeltaUtils +import org.lance.spark.LancePositionDeltaOperation +import org.lance.spark.utils.SparkUtil + +import scala.collection.JavaConverters._ + +class UpdateColumnsExtractor(session: SparkSession) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case wd: WriteDelta if SparkUtil.rewriteColumns(session) => + try { + wd.operation match { + case op: LancePositionDeltaOperation => + val targetAttrs = wd.table.output + val targetColOrdinals = wd.projections.rowProjection match { + case Some(ProjectingInternalRow(_, colOrdinals: Seq[Int])) => colOrdinals + case _ => Seq.empty + } + + val updatedColumns = wd.query match { + case m: MergeRows => + extractMergeRowsUpdatedColumns(m, targetColOrdinals, targetAttrs) + + case p: Project => + extractProjectUpdatedColumns(p, targetColOrdinals, targetAttrs) + + case _ => Seq.empty + } + + op.setUpdatedColumns(updatedColumns.distinct.asJava) + + case _ => + } + + } catch { + case e: Exception => { + throw new RuntimeException( + "Error when extract updated columns, please set `" + SparkUtil.REWRITE_COLUMNS + "` to `false` do disable lance's RewriteColumns mode", + e) + } + } + + wd + } + + /** + * Extracts the names of columns that are updated from a MergeRows logical plan. + * It checks the merge instructions for update operations and compares target attributes with the output expressions + * to identify changed columns. + * @param mergeRows The MergeRows logical plan to process + * @param targetColOrdinals The ordinals of columns to consider + * @param targetAttrs The target table attributes + * @return Sequence of updated column names + */ + private def extractMergeRowsUpdatedColumns( + mergeRows: MergeRows, + targetColOrdinals: Seq[Int], + targetAttrs: Seq[Attribute]): Seq[String] = { + val actions = + mergeRows.matchedInstructions ++ mergeRows.notMatchedInstructions ++ mergeRows.notMatchedBySourceInstructions + + val operationColIndex = + mergeRows.output.indexWhere(_.name.equals(RowDeltaUtils.OPERATION_COLUMN)) + + actions.flatMap { + case Keep(_, output) => + val operation = output(operationColIndex).asInstanceOf[Literal].value + operation match { + case RowDeltaUtils.UPDATE_OPERATION => + // Only check update operation + + targetAttrs.zipWithIndex.flatMap { + case (attr, idx) if !attr.semanticEquals(output(targetColOrdinals(idx))) => + Some(attr.name) + case _ => None + } + + case _ => Seq.empty + } + case _ => Seq.empty + } + } + + /** + * Extracts the names of columns that are updated from a Project logical plan. + * It checks the operation column to determine if it's an update and compares target attributes with the project expressions + * to identify changed columns. 
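+   * For example, an `UPDATE t SET value = value + 1 WHERE id > 100` statement would report only
+   * `value` as updated (illustrative; the actual names depend on the target table's schema).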
+ * @param project The Project logical plan to process + * @param targetColOrdinals The ordinals of columns to consider + * @param targetAttrs The target table attributes + * @return Sequence of updated column names + */ + private def extractProjectUpdatedColumns( + project: Project, + targetColOrdinals: Seq[Int], + targetAttrs: Seq[Attribute]): Seq[String] = { + val projections = project.projectList + val operationColIndex = projections.indexWhere(_.name.equals(RowDeltaUtils.OPERATION_COLUMN)) + + if (operationColIndex == -1) return Seq.empty + + val operation = projections(operationColIndex).eval() + operation match { + case RowDeltaUtils.UPDATE_OPERATION => + // Only check update operation + + targetAttrs.zipWithIndex.flatMap { + case (attr, idx) if !attr.semanticEquals(projections(targetColOrdinals(idx))) => + Some(attr.name) + case _ => None + } + + case _ => Seq.empty + } + } + +} diff --git a/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala index 6f9a905d..30e64ff9 100644 --- a/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala +++ b/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala @@ -13,8 +13,8 @@ */ package org.lance.spark.extensions -import org.apache.spark.sql.SparkSessionExtensions -import org.apache.spark.sql.catalyst.optimizer.LanceFragmentAwareJoinRule +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} +import org.apache.spark.sql.catalyst.optimizer.{LanceFragmentAwareJoinRule, UpdateColumnsExtractor} import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy @@ -27,6 +27,11 @@ class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { // optimizer rules for fragment-aware joins extensions.injectOptimizerRule(_ => LanceFragmentAwareJoinRule()) + // optimizer rules for update using RewriteColumns mode + extensions.injectOptimizerRule { session: SparkSession => + new UpdateColumnsExtractor(session) + } + extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_)) } } diff --git a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/MergeIntoRewriteRowsTest.java b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/MergeIntoRewriteRowsTest.java new file mode 100644 index 00000000..f5f3b591 --- /dev/null +++ b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/MergeIntoRewriteRowsTest.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.lance.spark.update; + +import org.lance.spark.utils.SparkUtil; + +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.jupiter.api.BeforeEach; + +public class MergeIntoRewriteRowsTest extends BaseMergeIntoTest { + @BeforeEach + void setup() { + spark = + SparkSession.builder() + .appName("lance-merge-into-distribution-test") + .master("local[4]") + .config( + "spark.sql.catalog." + catalogName, "org.lance.spark.LanceNamespaceSparkCatalog") + .config( + "spark.sql.extensions", "org.lance.spark.extensions.LanceSparkSessionExtensions") + .config("spark.sql.catalog." + catalogName + ".impl", "dir") + .config("spark.sql.catalog." + catalogName + ".root", tempDir.toString()) + .config("spark.sql.shuffle.partitions", String.valueOf(SHUFFLE_PARTITIONS)) + .config("spark.sql.adaptive.enabled", "false") + .config("spark.default.parallelism", String.valueOf(SHUFFLE_PARTITIONS)) + .config("spark.ui.enabled", "false") + .getOrCreate(); + + // Use RewriteRows mode to do update/merge-into + spark.conf().set(SparkUtil.REWRITE_COLUMNS, "false"); + + catalog = (TableCatalog) spark.sessionState().catalogManager().catalog(catalogName); + } +} diff --git a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableRewriteRowsTest.java b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableRewriteRowsTest.java new file mode 100644 index 00000000..f0878f4c --- /dev/null +++ b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableRewriteRowsTest.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.update; + +import org.lance.spark.utils.SparkUtil; + +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Map; + +public class UpdateTableRewriteRowsTest extends UpdateTableTest { + @BeforeEach + void setup() { + spark = + SparkSession.builder() + .appName("lance-namespace-test") + .master("local[4]") + .config( + "spark.sql.catalog." + catalogName, "org.lance.spark.LanceNamespaceSparkCatalog") + .config( + "spark.sql.extensions", "org.lance.spark.extensions.LanceSparkSessionExtensions") + .config("spark.sql.catalog." + catalogName + ".impl", getNsImpl()) + .getOrCreate(); + + Map additionalConfigs = getAdditionalNsConfigs(); + for (Map.Entry entry : additionalConfigs.entrySet()) { + spark.conf().set("spark.sql.catalog." + catalogName + "." + entry.getKey(), entry.getValue()); + } + + // Use RewriteRows mode to do update/merge-into + spark.conf().set(SparkUtil.REWRITE_COLUMNS, "false"); + + catalog = (TableCatalog) spark.sessionState().catalogManager().catalog(catalogName); + } + + @Test + public void testUpdateChildSomeRows() { + // Because Fragment.updateColumns can't accept null for struct column. + // So update for partial rows must use RewriteRow mode. 
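+    // The setup above sets spark.sql.lance.rewrite_columns to false, so this UPDATE runs as a
+    // row rewrite (delete-and-insert) instead of a column rewrite.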
+ TableOperator op = new TableOperator(spark, catalogName); + op.create(); + + op.insert( + Arrays.asList( + Row.of(1, "Alice", 100, "Alice", 100, Arrays.asList(100, 101)), + Row.of(2, "Bob", 200, "Bob", 200, Arrays.asList(200, 201)), + Row.of(3, "Charlie", 300, "Charlie", 300, Arrays.asList(300, 301)))); + + op.updateStructValue(101, "id = 1"); + op.check( + Arrays.asList( + Row.of(1, "Alice", 100, "Alice", 101, Arrays.asList(100, 101)), + Row.of(2, "Bob", 200, "Bob", 200, Arrays.asList(200, 201)), + Row.of(3, "Charlie", 300, "Charlie", 300, Arrays.asList(300, 301)))); + } +} diff --git a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableTest.java b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableTest.java index 2edf8f8a..87cc3fae 100644 --- a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableTest.java +++ b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableTest.java @@ -13,4 +13,36 @@ */ package org.lance.spark.update; -public class UpdateTableTest extends BaseUpdateTableTest {} +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.jupiter.api.BeforeEach; +import org.lance.spark.utils.SparkUtil; + +import java.util.Map; + +public class UpdateTableTest extends BaseUpdateTableTest { + @BeforeEach + void setup() { + spark = + SparkSession.builder() + .appName("lance-namespace-test") + .master("local[4]") + .config( + "spark.sql.catalog." + catalogName, "org.lance.spark.LanceNamespaceSparkCatalog") + .config( + "spark.sql.extensions", "org.lance.spark.extensions.LanceSparkSessionExtensions") + .config("spark.sql.catalog." + catalogName + ".impl", getNsImpl()) + .getOrCreate(); + + Map additionalConfigs = getAdditionalNsConfigs(); + for (Map.Entry entry : additionalConfigs.entrySet()) { + spark.conf().set("spark.sql.catalog." + catalogName + "." + entry.getKey(), entry.getValue()); + } + + spark.conf().set(SparkUtil.REWRITE_COLUMNS, "true"); + + catalog = (TableCatalog) spark.sessionState().catalogManager().catalog(catalogName); + // Create default namespace for multi-level namespace mode + spark.sql("CREATE NAMESPACE IF NOT EXISTS " + catalogName + ".default"); + } +} diff --git a/lance-spark-4.0_2.13/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateColumnsExtractor.scala b/lance-spark-4.0_2.13/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateColumnsExtractor.scala new file mode 100644 index 00000000..14851281 --- /dev/null +++ b/lance-spark-4.0_2.13/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateColumnsExtractor.scala @@ -0,0 +1,138 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.ProjectingInternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, MergeRows, Project, WriteDelta} +import org.apache.spark.sql.catalyst.plans.logical.MergeRows.Keep +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.RowDeltaUtils +import org.lance.spark.LancePositionDeltaOperation +import org.lance.spark.utils.SparkUtil + +import scala.collection.JavaConverters._ + +class UpdateColumnsExtractor(session: SparkSession) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case wd: WriteDelta if SparkUtil.rewriteColumns(session) => + try { + wd.operation match { + case op: LancePositionDeltaOperation => + val targetAttrs = wd.table.output + val targetColOrdinals = wd.projections.rowProjection match { + case Some(ProjectingInternalRow(_, colOrdinals: Seq[Int])) => colOrdinals + case _ => Seq.empty + } + + val updatedColumns = wd.query match { + case m: MergeRows => + extractMergeRowsUpdatedColumns(m, targetColOrdinals, targetAttrs) + + case p: Project => + extractProjectUpdatedColumns(p, targetColOrdinals, targetAttrs) + + case _ => Seq.empty + } + + op.setUpdatedColumns(updatedColumns.distinct.asJava) + + case _ => + } + + } catch { + case e: Exception => { + throw new RuntimeException( + "Error when extract updated columns, please set `" + SparkUtil.REWRITE_COLUMNS + "` to `false` do disable lance's RewriteColumns mode", + e) + } + } + + wd + } + + /** + * Extracts the names of columns that are updated from a MergeRows logical plan. + * It checks the merge instructions for update operations and compares target attributes with the output expressions + * to identify changed columns. + * @param mergeRows The MergeRows logical plan to process + * @param targetColOrdinals The ordinals of columns to consider + * @param targetAttrs The target table attributes + * @return Sequence of updated column names + */ + private def extractMergeRowsUpdatedColumns( + mergeRows: MergeRows, + targetColOrdinals: Seq[Int], + targetAttrs: Seq[Attribute]): Seq[String] = { + val actions = + mergeRows.matchedInstructions ++ mergeRows.notMatchedInstructions ++ mergeRows.notMatchedBySourceInstructions + + val operationColIndex = + mergeRows.output.indexWhere(_.name.equals(RowDeltaUtils.OPERATION_COLUMN)) + + actions.flatMap { + case Keep(_, output) => + val operation = output(operationColIndex).asInstanceOf[Literal].value + operation match { + case RowDeltaUtils.UPDATE_OPERATION => + // Only check update operation + + targetAttrs.zipWithIndex.flatMap { + case (attr, idx) if !attr.semanticEquals(output(targetColOrdinals(idx))) => + Some(attr.name) + case _ => None + } + + case _ => Seq.empty + } + case _ => Seq.empty + } + } + + /** + * Extracts the names of columns that are updated from a Project logical plan. + * It checks the operation column to determine if it's an update and compares target attributes with the project expressions + * to identify changed columns. 
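+   * For example, an `UPDATE t SET value = value + 1 WHERE id > 100` statement would report only
+   * `value` as updated (illustrative; the actual names depend on the target table's schema).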
+ * @param project The Project logical plan to process + * @param targetColOrdinals The ordinals of columns to consider + * @param targetAttrs The target table attributes + * @return Sequence of updated column names + */ + private def extractProjectUpdatedColumns( + project: Project, + targetColOrdinals: Seq[Int], + targetAttrs: Seq[Attribute]): Seq[String] = { + val projections = project.projectList + val operationColIndex = projections.indexWhere(_.name.equals(RowDeltaUtils.OPERATION_COLUMN)) + + if (operationColIndex == -1) return Seq.empty + + val operation = projections(operationColIndex).eval() + operation match { + case RowDeltaUtils.UPDATE_OPERATION => + // Only check update operation + + targetAttrs.zipWithIndex.flatMap { + case (attr, idx) if !attr.semanticEquals(projections(targetColOrdinals(idx))) => + Some(attr.name) + case _ => None + } + + case _ => Seq.empty + } + } + +} diff --git a/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala index 6f9a905d..30e64ff9 100644 --- a/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala +++ b/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala @@ -13,8 +13,8 @@ */ package org.lance.spark.extensions -import org.apache.spark.sql.SparkSessionExtensions -import org.apache.spark.sql.catalyst.optimizer.LanceFragmentAwareJoinRule +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} +import org.apache.spark.sql.catalyst.optimizer.{LanceFragmentAwareJoinRule, UpdateColumnsExtractor} import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy @@ -27,6 +27,11 @@ class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { // optimizer rules for fragment-aware joins extensions.injectOptimizerRule(_ => LanceFragmentAwareJoinRule()) + // optimizer rules for update using RewriteColumns mode + extensions.injectOptimizerRule { session: SparkSession => + new UpdateColumnsExtractor(session) + } + extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_)) } } diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/SparkUtil.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/SparkUtil.java index 75311158..7720ab0f 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/SparkUtil.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/SparkUtil.java @@ -17,6 +17,9 @@ import org.apache.spark.sql.SparkSession; public class SparkUtil { + private static final String SPARK_LANCE_CONF_PREFIX = "spark.sql.lance"; + + public static final String REWRITE_COLUMNS = SPARK_LANCE_CONF_PREFIX + ".rewrite_columns"; private static final String SPARK_CATALOG_CONF_PREFIX = "spark.sql.catalog"; // Format string used as the prefix for Spark configuration keys to override Hadoop configuration @@ -66,4 +69,14 @@ public static Configuration hadoopConfCatalogOverrides(SparkSession spark, Strin private static String hadoopConfPrefixForCatalog(String catalogName) { return String.format(SPARK_CATALOG_HADOOP_CONF_OVERRIDE_FMT_STR, catalogName); } + + /** + * Check if update/merge-into use RewriteColumns mode or not. Default value is false. + * + * @param spark spark current session + * @return true means that RewriteColumns mode is used. 
false means that RewriteRows mode is used. + */ + public static boolean rewriteColumns(SparkSession spark) { + return Boolean.parseBoolean(spark.conf().get(REWRITE_COLUMNS, "false")); + } } diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/update/BaseMergeIntoTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/update/BaseMergeIntoTest.java index e871a6c1..0bbddedb 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/update/BaseMergeIntoTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/update/BaseMergeIntoTest.java @@ -13,6 +13,8 @@ */ package org.lance.spark.update; +import org.lance.spark.utils.SparkUtil; + import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.connector.catalog.TableCatalog; @@ -24,11 +26,14 @@ import java.nio.file.Path; import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.UUID; +import java.util.stream.Collectors; public abstract class BaseMergeIntoTest { - private static final int SHUFFLE_PARTITIONS = 4; + protected static final int SHUFFLE_PARTITIONS = 4; protected SparkSession spark; protected TableCatalog catalog; @@ -44,6 +49,8 @@ void setup() { .master("local[4]") .config( "spark.sql.catalog." + catalogName, "org.lance.spark.LanceNamespaceSparkCatalog") + .config( + "spark.sql.extensions", "org.lance.spark.extensions.LanceSparkSessionExtensions") .config("spark.sql.catalog." + catalogName + ".impl", "dir") .config("spark.sql.catalog." + catalogName + ".root", tempDir.toString()) .config("spark.sql.shuffle.partitions", String.valueOf(SHUFFLE_PARTITIONS)) @@ -52,6 +59,8 @@ void setup() { .config("spark.ui.enabled", "false") .getOrCreate(); + spark.conf().set(SparkUtil.REWRITE_COLUMNS, "true"); + catalog = (TableCatalog) spark.sessionState().catalogManager().catalog(catalogName); // Create default namespace for multi-level namespace mode spark.sql("CREATE NAMESPACE IF NOT EXISTS " + catalogName + ".default"); @@ -64,6 +73,13 @@ void tearDown() { } } + private List baseRows() { + return Arrays.asList( + Row.of(1, "Alice", 100, "Alice", 100, Arrays.asList(100, 101)), + Row.of(2, "Bob", 200, "Bob", 200, Arrays.asList(200, 201)), + Row.of(3, "Charlie", 300, "Charlie", 300, Arrays.asList(300, 301))); + } + @Test public void testMergeIntoInsertDistributionOnNullSegmentId() { String tableName = "merge_dist_" + UUID.randomUUID().toString().replace("-", ""); @@ -222,4 +238,417 @@ public void testMergeInto() { RowFactory.create(101, 1010)); Assertions.assertEquals(expected, actual, "Expected merged rows to match result set"); } + + @Test + public void testBasicMatchedUpdate() { + TableOperator op = new TableOperator(spark, catalogName); + op.create(); + op.insert(baseRows()); + + List srcRows = + Arrays.asList( + Row.of(1, "Alice_new", 500, "Alice_meta", 500, Arrays.asList(500, 501)), + Row.of(2, "Bob_new", 400, "Bob_meta", 400, Arrays.asList(400, 401)), + Row.of(3, "Charlie_new", 600, "Charlie_meta", 600, Arrays.asList(600, 601))); + op.createOrReplaceTempView("src", srcRows); + + String mergeSql = + "MERGE INTO " + + op.tableRef() + + " t " + + "USING src s " + + "ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET " + + "t.value = s.value + 1, " + + "t.name = s.name, " + + "t.values = s.values"; + + spark.sql(mergeSql); + + op.check( + Arrays.asList( + Row.of(1, "Alice_new", 501, "Alice", 100, Arrays.asList(500, 501)), + Row.of(2, "Bob_new", 401, "Bob", 200, Arrays.asList(400, 401)), + Row.of(3, 
"Charlie_new", 601, "Charlie", 300, Arrays.asList(600, 601)))); + } + + @Test + public void testConditionalMatchedUpdate() { + TableOperator op = new TableOperator(spark, catalogName); + op.create(); + op.insert(baseRows()); + + List srcRows = + Arrays.asList( + Row.of(1, "Alice_src", 100, "Alice_src", 100, Arrays.asList(100, 101)), + Row.of(2, "Bob_src", 250, "Bob_src", 250, Arrays.asList(250, 251)), + Row.of(3, "Charlie_src", 350, "Charlie_src", 350, Arrays.asList(350, 351))); + op.createOrReplaceTempView("src", srcRows); + + String mergeSql = + "MERGE INTO " + + op.tableRef() + + " t " + + "USING src s " + + "ON t.id = s.id " + + "WHEN MATCHED AND s.value >= 250 THEN UPDATE SET " + + "t.value = s.value + 1"; + + spark.sql(mergeSql); + + op.check( + Arrays.asList( + // id = 1 does not meet condition, unchanged + Row.of(1, "Alice", 100, "Alice", 100, Arrays.asList(100, 101)), + // id = 2 and 3 updated on value only + Row.of(2, "Bob", 251, "Bob", 200, Arrays.asList(200, 201)), + Row.of(3, "Charlie", 351, "Charlie", 300, Arrays.asList(300, 301)))); + } + + @Test + public void testConditionalMatchedDelete() { + TableOperator op = new TableOperator(spark, catalogName); + op.create(); + op.insert(baseRows()); + + List srcRows = + Arrays.asList( + Row.of(1, "Alice_src", 100, "Alice_src", 100, Arrays.asList(100, 101)), + Row.of(2, "Bob_src", 200, "Bob_src", 200, Arrays.asList(200, 201)), + Row.of(3, "Charlie_src", 300, "Charlie_src", 300, Arrays.asList(300, 301))); + op.createOrReplaceTempView("src", srcRows); + + String mergeSql = + "MERGE INTO " + + op.tableRef() + + " t " + + "USING src s " + + "ON t.id = s.id " + + "WHEN MATCHED AND s.value >= 300 THEN DELETE"; + + spark.sql(mergeSql); + + op.check( + Arrays.asList( + Row.of(1, "Alice", 100, "Alice", 100, Arrays.asList(100, 101)), + Row.of(2, "Bob", 200, "Bob", 200, Arrays.asList(200, 201)))); + } + + @Test + public void testNotMatchedInsert() { + TableOperator op = new TableOperator(spark, catalogName); + op.create(); + op.insert(baseRows()); + + List srcRows = + Arrays.asList( + Row.of(4, "Dave", 400, "Dave", 400, Arrays.asList(400, 401)), + Row.of(5, "Eve", 500, "Eve", 500, Arrays.asList(500, 501))); + op.createOrReplaceTempView("src", srcRows); + + String mergeSql = + "MERGE INTO " + + op.tableRef() + + " t " + + "USING src s " + + "ON t.id = s.id " + + "WHEN NOT MATCHED THEN INSERT (id, name, value, meta, values) " + + "VALUES (s.id, s.name, s.value, s.meta, s.values)"; + + spark.sql(mergeSql); + + op.check( + Arrays.asList( + Row.of(1, "Alice", 100, "Alice", 100, Arrays.asList(100, 101)), + Row.of(2, "Bob", 200, "Bob", 200, Arrays.asList(200, 201)), + Row.of(3, "Charlie", 300, "Charlie", 300, Arrays.asList(300, 301)), + Row.of(4, "Dave", 400, "Dave", 400, Arrays.asList(400, 401)), + Row.of(5, "Eve", 500, "Eve", 500, Arrays.asList(500, 501)))); + } + + @Test + public void testConditionalNotMatchedInsert() { + TableOperator op = new TableOperator(spark, catalogName); + op.create(); + op.insert(baseRows()); + + List srcRows = + Arrays.asList( + Row.of(4, "Dave", 100, "Dave", 100, Arrays.asList(100, 101)), + Row.of(5, "Eve", 400, "Eve", 400, Arrays.asList(400, 401))); + op.createOrReplaceTempView("src", srcRows); + + String mergeSql = + "MERGE INTO " + + op.tableRef() + + " t " + + "USING src s " + + "ON t.id = s.id " + + "WHEN NOT MATCHED AND s.value >= 300 THEN INSERT (id, name, value, meta, values) " + + "VALUES (s.id, s.name, s.value, s.meta, s.values)"; + + spark.sql(mergeSql); + + op.check( + Arrays.asList( + Row.of(1, "Alice", 100, 
"Alice", 100, Arrays.asList(100, 101)), + Row.of(2, "Bob", 200, "Bob", 200, Arrays.asList(200, 201)), + Row.of(3, "Charlie", 300, "Charlie", 300, Arrays.asList(300, 301)), + // only Eve is inserted because of the condition + Row.of(5, "Eve", 400, "Eve", 400, Arrays.asList(400, 401)))); + } + + @Test + public void testMatchedBranchPriority() { + TableOperator op = new TableOperator(spark, catalogName); + op.create(); + + List initialRows = + Collections.singletonList(Row.of(1, "Alice", 100, "Alice", 100, Arrays.asList(100, 101))); + op.insert(initialRows); + + List srcRows = + Collections.singletonList(Row.of(1, "Source", 100, "Source", 100, Arrays.asList(100, 101))); + op.createOrReplaceTempView("src", srcRows); + + String mergeSql = + "MERGE INTO " + + op.tableRef() + + " t " + + "USING src s " + + "ON t.id = s.id " + + "WHEN MATCHED AND s.value = 100 THEN UPDATE SET t.value = 1000 " + + "WHEN MATCHED AND s.value = 100 THEN UPDATE SET t.value = 2000"; + + spark.sql(mergeSql); + + op.check( + Collections.singletonList( + // first WHEN MATCHED branch should win + Row.of(1, "Alice", 1000, "Alice", 100, Arrays.asList(100, 101)))); + } + + @Test + public void testNoMatchWithoutInsertKeepsTargetUnchanged() { + TableOperator op = new TableOperator(spark, catalogName); + op.create(); + op.insert(baseRows()); + + List srcRows = + Arrays.asList( + Row.of(4, "Dave", 400, "Dave", 400, Arrays.asList(400, 401)), + Row.of(5, "Eve", 500, "Eve", 500, Arrays.asList(500, 501))); + op.createOrReplaceTempView("src", srcRows); + + String mergeSql = + "MERGE INTO " + + op.tableRef() + + " t " + + "USING src s " + + "ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET t.value = s.value + 1"; + + spark.sql(mergeSql); + + op.check(baseRows()); + } + + @Test + public void testSourceDuplicateRowsBehavior() { + TableOperator op = new TableOperator(spark, catalogName); + op.create(); + + List initialRows = + Collections.singletonList(Row.of(1, "Alice", 100, "Alice", 100, Arrays.asList(100, 101))); + op.insert(initialRows); + + List srcRows = + Arrays.asList( + Row.of(1, "Source1", 200, "Source1", 200, Arrays.asList(200, 201)), + Row.of(1, "Source2", 300, "Source2", 300, Arrays.asList(300, 301))); + op.createOrReplaceTempView("src", srcRows); + + String mergeSql = + "MERGE INTO " + + op.tableRef() + + " t " + + "USING src s " + + "ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET t.value = s.value"; + + Exception e = + Assertions.assertThrows( + Exception.class, + () -> { + spark.sql(mergeSql); + }); + System.out.println( + "Merge source duplicate rows behavior: " + e.getClass().getName() + ": " + e.getMessage()); + } + + @Test + public void testMatchedUpdateDeleteAndNotMatchedInsertInSingleMerge() { + TableOperator op = new TableOperator(spark, catalogName); + op.create(); + op.insert(baseRows()); + + List srcRows = + Arrays.asList( + // matched for delete + Row.of(3, "Charlie_src", 350, "Charlie_src", 350, Arrays.asList(350, 351)), + // matched for update + Row.of(1, "Alice_src", 250, "Alice_src", 250, Arrays.asList(250, 251)), + // not matched for insert + Row.of(4, "David", 400, "David", 400, Arrays.asList(400, 401))); + op.createOrReplaceTempView("src", srcRows); + + String mergeSql = + "MERGE INTO " + + op.tableRef() + + " t " + + "USING src s " + + "ON t.id = s.id " + + "WHEN MATCHED AND s.value >= 300 THEN DELETE " + + "WHEN MATCHED AND s.value < 300 THEN UPDATE SET " + + "t.value = s.value + 1, " + + "t.name = s.name, " + + "t.values = s.values " + + "WHEN NOT MATCHED THEN INSERT (id, name, value, meta, 
values) " + + "VALUES (s.id, s.name, s.value, s.meta, s.values)"; + + spark.sql(mergeSql); + + op.check( + Arrays.asList( + // id=1 updated from src, meta remains from target + Row.of(1, "Alice_src", 251, "Alice", 100, Arrays.asList(250, 251)), + // id=2 unchanged + Row.of(2, "Bob", 200, "Bob", 200, Arrays.asList(200, 201)), + // id=3 deleted + // id=4 inserted with meta and values from src + Row.of(4, "David", 400, "David", 400, Arrays.asList(400, 401)))); + } + + protected static class TableOperator { + private final SparkSession spark; + private final String catalogName; + private final String tableName; + + public TableOperator(SparkSession spark, String catalogName) { + this.spark = spark; + this.catalogName = catalogName; + String baseName = "merge_test_table"; + this.tableName = baseName + "_" + UUID.randomUUID().toString().replace("-", ""); + } + + public void create() { + spark.sql( + "CREATE TABLE " + + catalogName + + ".default." + + tableName + + " (id INT NOT NULL, name STRING, value INT, meta STRUCT, values ARRAY)"); + } + + public void insert(List rows) { + String sql = + String.format( + "INSERT INTO %s.default.%s VALUES %s", + catalogName, + tableName, + rows.stream().map(Row::insertSql).collect(Collectors.joining(", "))); + spark.sql(sql); + } + + public void createOrReplaceTempView(String viewName, List rows) { + String valuesSql = rows.stream().map(Row::insertSql).collect(Collectors.joining(", ")); + String sql = + String.format( + "CREATE OR REPLACE TEMP VIEW %s AS SELECT * FROM VALUES %s AS %s(id, name, value, meta, values)", + viewName, valuesSql, viewName); + spark.sql(sql); + } + + public String tableRef() { + return catalogName + ".default." + tableName; + } + + public void check(List expected) { + String sql = String.format("Select * from %s.default.%s order by id", catalogName, tableName); + List actual = + spark.sql(sql).collectAsList().stream() + .map( + row -> + Row.of( + row.getInt(0), + row.getString(1), + row.getInt(2), + row.getStruct(3).getString(0), + row.getStruct(3).getInt(1), + row.getList(4))) + .collect(Collectors.toList()); + Assertions.assertEquals(expected, actual); + } + } + + protected static class Row { + int id; + String name; + int value; + String metaName; + int metaValue; + List values; + + protected static Row of( + int id, String name, int value, String metaName, int metaValue, List values) { + Row row = new Row(); + row.id = id; + row.name = name; + row.value = value; + row.metaName = metaName; + row.metaValue = metaValue; + row.values = values; + return row; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Row row = (Row) o; + return id == row.id + && value == row.value + && metaValue == row.metaValue + && Objects.equals(name, row.name) + && Objects.equals(metaName, row.metaName) + && Objects.deepEquals(values, row.values); + } + + @Override + public int hashCode() { + return Objects.hash(id, name, value, metaName, metaValue, values); + } + + @Override + public String toString() { + return String.format( + "Row(id=%s, name=%s, value=%s, metaName=%s, metaValue=%s, values=%s)", + id, name, value, metaName, metaValue, values); + } + + private String insertSql() { + return String.format( + "(%d, '%s', %d, NAMED_STRUCT('name', '%s', 'value', %d), ARRAY(%s))", + id, + name, + value, + metaName, + metaValue, + values.stream().map(String::valueOf).collect(Collectors.joining(","))); + } + } } From 
From ececb75d423b6fe3118e6bbb614984d83cf25284 Mon Sep 17 00:00:00 2001
From: "fangbo.0511"
Date: Wed, 18 Mar 2026 11:21:48 +0800
Subject: [PATCH 2/3] rebase main

---
 .../update/UpdateTableRewriteRowsTest.java    |  2 +-
 .../lance/spark/update/UpdateTableTest.java   | 27 ++++++++++++++++++-
 .../lance/spark/update/UpdateTableTest.java   | 27 ++++++++++++++++++-
 .../spark/update/BaseUpdateTableTest.java     |  6 ++---
 4 files changed, 56 insertions(+), 6 deletions(-)

diff --git a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableRewriteRowsTest.java b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableRewriteRowsTest.java
index f0878f4c..d24609aa 100644
--- a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableRewriteRowsTest.java
+++ b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableRewriteRowsTest.java
@@ -23,7 +23,7 @@
 import java.util.Arrays;
 import java.util.Map;
 
-public class UpdateTableRewriteRowsTest extends UpdateTableTest {
+public class UpdateTableRewriteRowsTest extends BaseUpdateTableTest {
   @BeforeEach
   void setup() {
     spark =
diff --git a/lance-spark-4.0_2.13/src/test/java/org/lance/spark/update/UpdateTableTest.java b/lance-spark-4.0_2.13/src/test/java/org/lance/spark/update/UpdateTableTest.java
index 2edf8f8a..99de88a4 100644
--- a/lance-spark-4.0_2.13/src/test/java/org/lance/spark/update/UpdateTableTest.java
+++ b/lance-spark-4.0_2.13/src/test/java/org/lance/spark/update/UpdateTableTest.java
@@ -13,4 +13,29 @@
  */
 package org.lance.spark.update;
 
-public class UpdateTableTest extends BaseUpdateTableTest {}
+public class UpdateTableTest extends BaseUpdateTableTest {
+  @BeforeEach
+  void setup() {
+    spark =
+        SparkSession.builder()
+            .appName("lance-namespace-test")
+            .master("local[4]")
+            .config(
+                "spark.sql.catalog." + catalogName, "org.lance.spark.LanceNamespaceSparkCatalog")
+            .config(
+                "spark.sql.extensions", "org.lance.spark.extensions.LanceSparkSessionExtensions")
+            .config("spark.sql.catalog." + catalogName + ".impl", getNsImpl())
+            .getOrCreate();
+
+    Map<String, String> additionalConfigs = getAdditionalNsConfigs();
+    for (Map.Entry<String, String> entry : additionalConfigs.entrySet()) {
+      spark.conf().set("spark.sql.catalog." + catalogName + "." + entry.getKey(), entry.getValue());
+    }
+
+    spark.conf().set(SparkUtil.REWRITE_COLUMNS, "true");
+
+    catalog = (TableCatalog) spark.sessionState().catalogManager().catalog(catalogName);
+    // Create default namespace for multi-level namespace mode
+    spark.sql("CREATE NAMESPACE IF NOT EXISTS " + catalogName + ".default");
+  }
+}
diff --git a/lance-spark-4.1_2.13/src/test/java/org/lance/spark/update/UpdateTableTest.java b/lance-spark-4.1_2.13/src/test/java/org/lance/spark/update/UpdateTableTest.java
index 2edf8f8a..99de88a4 100644
--- a/lance-spark-4.1_2.13/src/test/java/org/lance/spark/update/UpdateTableTest.java
+++ b/lance-spark-4.1_2.13/src/test/java/org/lance/spark/update/UpdateTableTest.java
@@ -13,4 +13,29 @@
  */
 package org.lance.spark.update;
 
-public class UpdateTableTest extends BaseUpdateTableTest {}
+public class UpdateTableTest extends BaseUpdateTableTest {
+  @BeforeEach
+  void setup() {
+    spark =
+        SparkSession.builder()
+            .appName("lance-namespace-test")
+            .master("local[4]")
+            .config(
+                "spark.sql.catalog." + catalogName, "org.lance.spark.LanceNamespaceSparkCatalog")
+            .config(
+                "spark.sql.extensions", "org.lance.spark.extensions.LanceSparkSessionExtensions")
+            .config("spark.sql.catalog." + catalogName + ".impl", getNsImpl())
+            .getOrCreate();
+
+    Map<String, String> additionalConfigs = getAdditionalNsConfigs();
+    for (Map.Entry<String, String> entry : additionalConfigs.entrySet()) {
+      spark.conf().set("spark.sql.catalog." + catalogName + "." + entry.getKey(), entry.getValue());
+    }
+
+    spark.conf().set(SparkUtil.REWRITE_COLUMNS, "true");
+
+    catalog = (TableCatalog) spark.sessionState().catalogManager().catalog(catalogName);
+    // Create default namespace for multi-level namespace mode
+    spark.sql("CREATE NAMESPACE IF NOT EXISTS " + catalogName + ".default");
+  }
+}
diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/update/BaseUpdateTableTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/update/BaseUpdateTableTest.java
index c24db07e..dd88a6d5 100644
--- a/lance-spark-base_2.12/src/test/java/org/lance/spark/update/BaseUpdateTableTest.java
+++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/update/BaseUpdateTableTest.java
@@ -378,7 +378,7 @@ public void testTransformArrayAllRows() {
             Row.of(3, "Charlie", 300, "Charlie", 300, Arrays.asList(301, 302))));
   }
 
-  private static class TableOperator {
+  protected static class TableOperator {
     private final SparkSession spark;
     private final String catalogName;
     private final String tableName;
@@ -476,7 +476,7 @@ public void check(List<Row> expected) {
     }
   }
 
-  private static class Row {
+  protected static class Row {
     int id;
     String name;
     int value;
@@ -484,7 +484,7 @@ private static class Row {
     int metaValue;
     List<Integer> values;
 
-    private static Row of(
+    protected static Row of(
         int id, String name, int value, String metaName, int metaValue, List<Integer> values) {
       Row row = new Row();
       row.id = id;

From 7039168334ea3223303de2e8057ce4d9d06fbd01 Mon Sep 17 00:00:00 2001
From: "fangbo.0511"
Date: Wed, 18 Mar 2026 11:24:11 +0800
Subject: [PATCH 3/3] fix codestyle

---
 .../src/test/java/org/lance/spark/update/UpdateTableTest.java | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableTest.java b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableTest.java
index 87cc3fae..57be4a87 100644
--- a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableTest.java
+++ b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/update/UpdateTableTest.java
@@ -13,10 +13,11 @@
  */
 package org.lance.spark.update;
 
+import org.lance.spark.utils.SparkUtil;
+
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.connector.catalog.TableCatalog;
 import org.junit.jupiter.api.BeforeEach;
-import org.lance.spark.utils.SparkUtil;
 
 import java.util.Map;