diff --git a/crates/integrations/datafusion/src/physical_plan/mod.rs b/crates/integrations/datafusion/src/physical_plan/mod.rs index eb58082fe5..5a9845cde0 100644 --- a/crates/integrations/datafusion/src/physical_plan/mod.rs +++ b/crates/integrations/datafusion/src/physical_plan/mod.rs @@ -21,6 +21,7 @@ pub(crate) mod metadata_scan; pub(crate) mod project; pub(crate) mod repartition; pub(crate) mod scan; +pub(crate) mod sort; pub(crate) mod write; pub(crate) const DATA_FILES_COL_NAME: &str = "data_files"; diff --git a/crates/integrations/datafusion/src/physical_plan/sort.rs b/crates/integrations/datafusion/src/physical_plan/sort.rs new file mode 100644 index 0000000000..2a57e16e43 --- /dev/null +++ b/crates/integrations/datafusion/src/physical_plan/sort.rs @@ -0,0 +1,244 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Partition-based sorting for Iceberg tables. + +use std::sync::Arc; + +use datafusion::arrow::compute::SortOptions; +use datafusion::common::Result as DFResult; +use datafusion::error::DataFusionError; +use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::expressions::Column; +use datafusion::physical_plan::sorts::sort::SortExec; +use iceberg::arrow::PROJECTED_PARTITION_VALUE_COLUMN; + +/// Sorts an ExecutionPlan by partition values for Iceberg tables. +/// +/// This function takes an input ExecutionPlan that has been extended with partition values +/// (via `project_with_partition`) and returns a SortExec that sorts by the partition column. +/// The partition values are expected to be in a struct column named `PROJECTED_PARTITION_VALUE_COLUMN`. +/// +/// For unpartitioned tables or plans without the partition column, returns an error. +/// +/// # Arguments +/// * `input` - The input ExecutionPlan with projected partition values +/// +/// # Returns +/// * `Ok(Arc)` - A SortExec that sorts by partition values +/// * `Err` - If the partition column is not found +/// +/// TODO remove dead_code mark when integrating with insert_into +#[allow(dead_code)] +pub(crate) fn sort_by_partition(input: Arc) -> DFResult> { + let schema = input.schema(); + + // Find the partition column in the schema + let (partition_column_index, _partition_field) = schema + .column_with_name(PROJECTED_PARTITION_VALUE_COLUMN) + .ok_or_else(|| { + DataFusionError::Plan(format!( + "Partition column '{}' not found in schema. Ensure the plan has been extended with partition values using project_with_partition.", + PROJECTED_PARTITION_VALUE_COLUMN + )) + })?; + + // Create a single sort expression for the partition column + let column_expr = Arc::new(Column::new( + PROJECTED_PARTITION_VALUE_COLUMN, + partition_column_index, + )); + + let sort_expr = PhysicalSortExpr { + expr: column_expr, + options: SortOptions::default(), // Ascending, nulls last + }; + + // Create a SortExec with preserve_partitioning=true to ensure the output partitioning + // is the same as the input partitioning, and the data is sorted within each partition + let lex_ordering = LexOrdering::new(vec![sort_expr]).ok_or_else(|| { + DataFusionError::Plan("Failed to create LexOrdering from sort expression".to_string()) + })?; + + let sort_exec = SortExec::new(lex_ordering, input).with_preserve_partitioning(true); + + Ok(Arc::new(sort_exec)) +} + +#[cfg(test)] +mod tests { + use datafusion::arrow::array::{Int32Array, RecordBatch, StringArray, StructArray}; + use datafusion::arrow::datatypes::{DataType, Field, Fields, Schema as ArrowSchema}; + use datafusion::datasource::{MemTable, TableProvider}; + use datafusion::prelude::SessionContext; + + use super::*; + + #[tokio::test] + async fn test_sort_by_partition_basic() { + // Create a schema with a partition column + let partition_fields = + Fields::from(vec![Field::new("id_partition", DataType::Int32, false)]); + + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + Field::new( + PROJECTED_PARTITION_VALUE_COLUMN, + DataType::Struct(partition_fields.clone()), + false, + ), + ])); + + // Create test data with partition values + let id_array = Arc::new(Int32Array::from(vec![3, 1, 2])); + let name_array = Arc::new(StringArray::from(vec!["c", "a", "b"])); + let partition_array = Arc::new(StructArray::from(vec![( + Arc::new(Field::new("id_partition", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![3, 1, 2])) as _, + )])); + + let batch = + RecordBatch::try_new(schema.clone(), vec![id_array, name_array, partition_array]) + .unwrap(); + + let ctx = SessionContext::new(); + let mem_table = MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap(); + let input = mem_table.scan(&ctx.state(), None, &[], None).await.unwrap(); + + // Apply sort + let sorted_plan = sort_by_partition(input).unwrap(); + + // Execute and verify + let result = datafusion::physical_plan::collect(sorted_plan, ctx.task_ctx()) + .await + .unwrap(); + + assert_eq!(result.len(), 1); + let result_batch = &result[0]; + + let id_col = result_batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + // Verify data is sorted by partition value + assert_eq!(id_col.value(0), 1); + assert_eq!(id_col.value(1), 2); + assert_eq!(id_col.value(2), 3); + } + + #[tokio::test] + async fn test_sort_by_partition_missing_column() { + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ])); + + let batch = RecordBatch::try_new(schema.clone(), vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ]) + .unwrap(); + + let ctx = SessionContext::new(); + let mem_table = MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap(); + let input = mem_table.scan(&ctx.state(), None, &[], None).await.unwrap(); + + let result = sort_by_partition(input); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Partition column '_partition' not found") + ); + } + + #[tokio::test] + async fn test_sort_by_partition_multi_field() { + // Test with multiple partition fields in the struct + let partition_fields = Fields::from(vec![ + Field::new("year", DataType::Int32, false), + Field::new("month", DataType::Int32, false), + ]); + + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("data", DataType::Utf8, false), + Field::new( + PROJECTED_PARTITION_VALUE_COLUMN, + DataType::Struct(partition_fields.clone()), + false, + ), + ])); + + // Create test data with partition values (year, month) + let id_array = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let data_array = Arc::new(StringArray::from(vec!["a", "b", "c", "d"])); + + // Partition values: (2024, 2), (2024, 1), (2023, 12), (2024, 1) + let year_array = Arc::new(Int32Array::from(vec![2024, 2024, 2023, 2024])); + let month_array = Arc::new(Int32Array::from(vec![2, 1, 12, 1])); + + let partition_array = Arc::new(StructArray::from(vec![ + ( + Arc::new(Field::new("year", DataType::Int32, false)), + year_array as _, + ), + ( + Arc::new(Field::new("month", DataType::Int32, false)), + month_array as _, + ), + ])); + + let batch = + RecordBatch::try_new(schema.clone(), vec![id_array, data_array, partition_array]) + .unwrap(); + + let ctx = SessionContext::new(); + let mem_table = MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap(); + let input = mem_table.scan(&ctx.state(), None, &[], None).await.unwrap(); + + // Apply sort + let sorted_plan = sort_by_partition(input).unwrap(); + + // Execute and verify + let result = datafusion::physical_plan::collect(sorted_plan, ctx.task_ctx()) + .await + .unwrap(); + + assert_eq!(result.len(), 1); + let result_batch = &result[0]; + + let id_col = result_batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + // Verify data is sorted by partition value (struct comparison) + // Expected order: (2023, 12), (2024, 1), (2024, 1), (2024, 2) + // Which corresponds to ids: 3, 2, 4, 1 + assert_eq!(id_col.value(0), 3); + assert_eq!(id_col.value(1), 2); + assert_eq!(id_col.value(2), 4); + assert_eq!(id_col.value(3), 1); + } +}