diff --git a/Cargo.lock b/Cargo.lock
index a5039b6..faf8852 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4859,6 +4859,50 @@ dependencies = [
  "zstd",
 ]
 
+[[package]]
+name = "parquet-variant"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "066f3e371a47d5b0bfd9491992f2400e4f155511d0c9bcfaca04bac7f1e33600"
+dependencies = [
+ "arrow-schema",
+ "chrono",
+ "half",
+ "indexmap 2.12.1",
+ "simdutf8",
+ "uuid",
+]
+
+[[package]]
+name = "parquet-variant-compute"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "497db08281ed670f7e5119116b961282d50645a28b2afb97686f18e2e4034414"
+dependencies = [
+ "arrow",
+ "arrow-schema",
+ "chrono",
+ "half",
+ "indexmap 2.12.1",
+ "parquet-variant",
+ "parquet-variant-json",
+ "uuid",
+]
+
+[[package]]
+name = "parquet-variant-json"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9958f58a045ec273727651a3d07e32b382a7732f7cae86d26d80e9ec0acf2e3b"
+dependencies = [
+ "arrow-schema",
+ "base64",
+ "chrono",
+ "parquet-variant",
+ "serde_json",
+ "uuid",
+]
+
 [[package]]
 name = "paste"
 version = "1.0.15"
@@ -6804,6 +6848,9 @@ dependencies = [
  "opentelemetry",
  "opentelemetry-otlp",
  "opentelemetry_sdk",
+ "parquet-variant",
+ "parquet-variant-compute",
+ "parquet-variant-json",
  "rand 0.9.2",
  "regex",
  "scopeguard",
diff --git a/Cargo.toml b/Cargo.toml
index 9c64646..92d6e4f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -75,6 +75,10 @@ bincode = { version = "2.0", features = ["serde"] }
 walrus-rust = "0.2.0"
 thiserror = "2.0"
 strum = { version = "0.27", features = ["derive"] }
+# Parquet Variant support for proper semi-structured data encoding
+parquet-variant = "0.2.0"
+parquet-variant-compute = "0.2.0"
+parquet-variant-json = "0.2.0"
 
 [dev-dependencies]
 sqllogictest = { git = "https://github.com/risinglightdb/sqllogictest-rs.git" }
diff --git a/schemas/otel_logs_and_spans.yaml b/schemas/otel_logs_and_spans.yaml
index bcb0aa4..aa81cc8 100644
--- a/schemas/otel_logs_and_spans.yaml
+++ b/schemas/otel_logs_and_spans.yaml
@@ -49,7 +49,7 @@ fields:
     data_type: Int32
     nullable: true
   - name: body
-    data_type: Utf8
+    data_type: Variant
     nullable: true
   - name: duration
     data_type: Int64
@@ -61,7 +61,7 @@ fields:
     data_type: 'Timestamp(Microsecond, Some("UTC"))'
     nullable: true
   - name: context
-    data_type: Utf8
+    data_type: Variant
    nullable: true
   - name: context___trace_id
     data_type: Utf8
@@ -79,13 +79,13 @@ fields:
     data_type: Utf8
     nullable: true
   - name: events
-    data_type: Utf8
+    data_type: Variant
     nullable: true
   - name: links
-    data_type: Utf8
+    data_type: Variant
     nullable: true
   - name: attributes
-    data_type: Utf8
+    data_type: Variant
     nullable: true
   - name: attributes___client___address
     data_type: Utf8
@@ -235,7 +235,7 @@ fields:
     data_type: Utf8
     nullable: true
   - name: resource
-    data_type: Utf8
+    data_type: Variant
     nullable: true
   - name: resource___service___name
     data_type: Utf8
@@ -268,7 +268,7 @@ fields:
     data_type: "List(Utf8)"
     nullable: false
   - name: errors
-    data_type: Utf8
+    data_type: Variant
     nullable: true
   - name: log_pattern
     data_type: Utf8
diff --git a/src/database.rs b/src/database.rs
index 39c644a..a18d8d9 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -31,6 +31,7 @@ use deltalake::PartitionFilter;
 use deltalake::datafusion::parquet::file::metadata::SortingColumn;
 use deltalake::datafusion::parquet::file::properties::WriterProperties;
 use deltalake::kernel::transaction::CommitProperties;
+use deltalake::kernel::{Action, Protocol};
 use deltalake::operations::create::CreateBuilder;
 use deltalake::{DeltaTable, DeltaTableBuilder};
 use futures::StreamExt;
@@ -66,6 +67,23 @@ pub fn extract_project_id(batch: &RecordBatch) -> Option<String> {
 // Compression level for parquet files - kept for WriterProperties fallback
 const ZSTD_COMPRESSION_LEVEL: i32 = 3;
 
+/// Create a Protocol with the variantType feature enabled.
+/// Required for tables with Variant columns per the Delta Lake protocol spec.
+/// Note: currently unused because delta-rs's ProtocolChecker doesn't support variantType yet.
+/// When delta-rs adds support, this can be enabled via the CreateBuilder::with_actions() call.
+#[allow(dead_code)]
+fn create_variant_protocol() -> Protocol {
+    // Protocol::try_new is pub(crate) in delta-kernel, so we use serde_json
+    // to create it (the same approach delta-rs uses internally)
+    serde_json::from_value(serde_json::json!({
+        "minReaderVersion": 3,
+        "minWriterVersion": 7,
+        "readerFeatures": ["variantType"],
+        "writerFeatures": ["variantType"]
+    }))
+    .expect("Valid protocol JSON")
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
 struct StorageConfig {
     project_id: String,
@@ -765,10 +783,13 @@ impl Database {
         ctx.register_udf(set_config_udf);
     }
 
-    /// Register JSON functions from datafusion-functions-json
+    /// Register JSON functions from datafusion-functions-json, plus Variant-aware wrappers
     pub fn register_json_functions(&self, ctx: &mut SessionContext) {
         datafusion_functions_json::register_all(ctx).expect("Failed to register JSON functions");
-        info!("Registered JSON functions with SessionContext");
+        // Register variant-aware wrappers that override the standard JSON functions.
+        // These handle both Variant (Struct) and Utf8 inputs transparently.
+        crate::variant_utils::register_variant_json_functions(ctx);
+        info!("Registered JSON functions with Variant support");
     }
 
     #[instrument(
@@ -993,7 +1014,15 @@
         let mut config = HashMap::new();
         config.insert("delta.checkpointInterval".to_string(), Some(checkpoint_interval));
-        config.insert("delta.checkpointPolicy".to_string(), Some("v2".to_string()));
+        // Note: the v2 checkpoint policy requires the v2Checkpoint feature, which delta-rs doesn't support yet
+        // config.insert("delta.checkpointPolicy".to_string(), Some("v2".to_string()));
+
+        // Note: delta-rs doesn't yet support variantType in its ProtocolChecker.
+        // Variant columns are stored as Struct<metadata, value>, which is the
+        // correct binary representation per the Parquet Variant spec.
+        // When delta-rs adds variantType support, we can enable the Protocol action.
+        // See: https://github.com/delta-io/delta-rs/blob/main/crates/core/src/kernel/transaction/protocol.rs
+        let actions: Vec<Action> = vec![];
 
         match CreateBuilder::new()
             .with_location(&storage_uri)
@@ -1002,6 +1031,7 @@
             .with_storage_options(storage_options.clone())
             .with_commit_properties(commit_properties)
             .with_configuration(config)
+            .with_actions(actions)
             .await
         {
             Ok(table) => break table,
@@ -2004,6 +2034,7 @@ mod tests {
         let db_arc = Arc::new(db.clone());
         let mut ctx = db_arc.create_session_context();
         datafusion_functions_json::register_all(&mut ctx)?;
+        crate::variant_utils::register_variant_json_functions(&mut ctx);
         db.setup_session_context(&mut ctx)?;
         Ok((db, ctx, test_prefix))
     }
diff --git a/src/lib.rs b/src/lib.rs
index 008cb8d..106fea8 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -14,4 +14,5 @@ pub mod schema_loader;
 pub mod statistics;
 pub mod telemetry;
 pub mod test_utils;
+pub mod variant_utils;
 pub mod wal;
diff --git a/src/schema_loader.rs b/src/schema_loader.rs
index cc360e2..d837cd9 100644
--- a/src/schema_loader.rs
+++ b/src/schema_loader.rs
@@ -90,6 +90,21 @@ impl TableSchema {
             })
             .collect()
     }
+
+    /// Check if this schema contains any Variant type columns
+    pub fn has_variant_columns(&self) -> bool {
+        self.fields.iter().any(|f| f.data_type == "Variant")
+    }
+}
+
+/// Get the Arrow DataType for Variant (a Struct with `metadata` and `value` BinaryView fields).
+/// Uses BinaryView to match the parquet-variant-compute output.
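+///
+/// # Example
+///
+/// An illustrative sketch (ignored doc-test; `parse_arrow_data_type` is private to this module):
+///
+/// ```ignore
+/// // "Variant" in schemas/otel_logs_and_spans.yaml resolves to the two-field struct type:
+/// let dt = parse_arrow_data_type("Variant").unwrap();
+/// assert_eq!(dt, variant_arrow_type());
+/// ```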
+pub fn variant_arrow_type() -> ArrowDataType {
+    use arrow::datatypes::Fields;
+    ArrowDataType::Struct(Fields::from(vec![
+        Arc::new(Field::new("metadata", ArrowDataType::BinaryView, false)),
+        Arc::new(Field::new("value", ArrowDataType::BinaryView, false)),
+    ]))
 }
 
 fn parse_arrow_data_type(s: &str) -> anyhow::Result<ArrowDataType> {
@@ -103,10 +118,17 @@ fn parse_arrow_data_type(s: &str) -> anyhow::Result<ArrowDataType> {
         "List(Utf8)" => ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Utf8, true))),
         "Timestamp(Microsecond, None)" => ArrowDataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None),
         "Timestamp(Microsecond, Some(\"UTC\"))" => ArrowDataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, Some("UTC".into())),
+        "Variant" => variant_arrow_type(),
         _ => anyhow::bail!("Unknown type: {}", s),
     })
 }
 
+/// Create a proper Delta Variant type using delta-kernel's unshredded_variant().
+/// This produces a Struct that delta-kernel recognizes as Variant.
+fn variant_delta_type() -> DeltaDataType {
+    DeltaDataType::unshredded_variant()
+}
+
 fn parse_delta_data_type(s: &str) -> anyhow::Result<DeltaDataType> {
     use PrimitiveType::*;
     Ok(match s {
@@ -115,6 +137,7 @@ fn parse_delta_data_type(s: &str) -> anyhow::Result<DeltaDataType> {
         "Int32" | "UInt32" => DeltaDataType::Primitive(Integer),
         "Int64" | "UInt64" => DeltaDataType::Primitive(Long),
         "List(Utf8)" => DeltaDataType::Array(Box::new(ArrayType::new(DeltaDataType::Primitive(String), true))),
+        "Variant" => variant_delta_type(),
         _ if s.starts_with("Timestamp") => DeltaDataType::Primitive(Timestamp),
         _ => anyhow::bail!("Unknown type: {}", s),
     })
diff --git a/src/test_utils.rs b/src/test_utils.rs
index fe7dbae..edc67c7 100644
--- a/src/test_utils.rs
+++ b/src/test_utils.rs
@@ -1,19 +1,34 @@
 pub mod test_helpers {
     use crate::schema_loader::get_default_schema;
+    use crate::variant_utils::{VARIANT_COLUMNS, convert_batch_json_to_variant, variant_data_type};
+    use arrow::datatypes::{DataType, Field, Schema};
     use arrow_json::ReaderBuilder;
     use datafusion::arrow::record_batch::RecordBatch;
     use serde_json::{Value, json};
     use std::collections::HashMap;
+    use std::sync::Arc;
 
     pub fn json_to_batch(records: Vec<Value>) -> anyhow::Result<RecordBatch> {
         let schema = get_default_schema().schema_ref();
+        // Build a parsing schema that uses Utf8 for Variant columns (arrow_json can't parse Variant)
+        let parse_schema = Arc::new(Schema::new(
+            schema
+                .fields()
+                .iter()
+                .map(|f| {
+                    if VARIANT_COLUMNS.contains(&f.name().as_str()) && *f.data_type() == variant_data_type() {
+                        Arc::new(Field::new(f.name(), DataType::Utf8, f.is_nullable()))
+                    } else {
+                        f.clone()
+                    }
+                })
+                .collect::<Vec<_>>(),
+        ));
         let json_data = records.into_iter().map(|v| v.to_string()).collect::<Vec<_>>().join("\n");
 
-        ReaderBuilder::new(schema.clone())
+        let batch = ReaderBuilder::new(parse_schema)
             .build(std::io::Cursor::new(json_data.as_bytes()))?
             .next()
-            .ok_or_else(|| anyhow::anyhow!("Failed to read batch"))?
-            .map_err(Into::into)
+            .ok_or_else(|| anyhow::anyhow!("Failed to read batch"))??;
+
+        // Convert Utf8 JSON columns to Variant
+        convert_batch_json_to_variant(batch).map_err(|e| anyhow::anyhow!("{}", e))
     }
 
     pub fn create_default_record() -> HashMap<String, Value> {
diff --git a/src/variant_utils.rs b/src/variant_utils.rs
new file mode 100644
index 0000000..7e8dc9a
--- /dev/null
+++ b/src/variant_utils.rs
@@ -0,0 +1,405 @@
+//! Variant type utilities for converting between JSON and Variant,
+//! and for providing UDFs that work with both Variant and Utf8 inputs.
+//!
+//! This module uses the parquet-variant crates for proper Variant binary encoding
+//! as specified in the Parquet Variant specification.
+
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{Array, ArrayRef, BinaryArray, BinaryViewArray, StringArray, StringBuilder, StructArray};
+use arrow::datatypes::{DataType, Field, Fields};
+use arrow::record_batch::RecordBatch;
+use datafusion::common::{DataFusionError, ScalarValue};
+use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility};
+use parquet_variant::Variant;
+use parquet_variant_compute::json_to_variant;
+use parquet_variant_json::VariantToJson;
+
+/// Columns that should be stored as the Variant type
+pub const VARIANT_COLUMNS: &[&str] = &["body", "context", "events", "links", "attributes", "resource", "errors"];
+
+/// Get the Arrow DataType for Variant (a Struct with `metadata` and `value` BinaryView fields).
+/// This matches the parquet-variant-compute output.
+pub fn variant_data_type() -> DataType {
+    DataType::Struct(Fields::from(vec![
+        Arc::new(Field::new("metadata", DataType::BinaryView, false)),
+        Arc::new(Field::new("value", DataType::BinaryView, false)),
+    ]))
+}
+
+/// Check if an array is a Variant type (a Struct with `metadata` and `value` fields)
+pub fn is_variant_array(array: &ArrayRef) -> bool {
+    if let DataType::Struct(fields) = array.data_type() {
+        fields.len() == 2
+            && fields.iter().any(|f| f.name() == "metadata" && matches!(f.data_type(), DataType::BinaryView | DataType::Binary))
+            && fields.iter().any(|f| f.name() == "value" && matches!(f.data_type(), DataType::BinaryView | DataType::Binary))
+    } else {
+        false
+    }
+}
+
+/// Check if a ColumnarValue is a Variant type
+pub fn is_variant_columnar(value: &ColumnarValue) -> bool {
+    match value {
+        ColumnarValue::Array(arr) => is_variant_array(arr),
+        ColumnarValue::Scalar(ScalarValue::Struct(arr)) => is_variant_array(&(arr.clone() as ArrayRef)),
+        _ => false,
+    }
+}
+
+/// Convert a JSON string array to a Variant array using proper parquet-variant encoding.
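+///
+/// # Example
+///
+/// A minimal sketch (ignored doc-test; the helpers are this module's own):
+///
+/// ```ignore
+/// use arrow::array::StringArray;
+///
+/// let json = StringArray::from(vec![Some(r#"{"a": 1}"#), None]);
+/// let variant = json_to_variant_array(&json)?; // Struct{metadata, value}
+/// assert!(is_variant_array(&variant));
+/// assert!(variant.is_null(1)); // nulls are preserved
+/// ```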
+pub fn json_to_variant_array(json_array: &StringArray) -> Result<ArrayRef, DataFusionError> {
+    let array_ref: ArrayRef = Arc::new(json_array.clone());
+    let variant_array = json_to_variant(&array_ref)
+        .map_err(|e| DataFusionError::Execution(format!("Failed to convert JSON to Variant: {}", e)))?;
+    Ok(Arc::new(variant_array.into_inner()))
+}
+
+/// Convert a Variant array to a JSON string array.
+/// Handles both Binary and BinaryView field types.
+pub fn variant_to_json_array(variant_array: &StructArray) -> Result<ArrayRef, DataFusionError> {
+    let metadata_col = variant_array
+        .column_by_name("metadata")
+        .ok_or_else(|| DataFusionError::Execution("Missing metadata field in Variant".to_string()))?;
+    let value_col = variant_array
+        .column_by_name("value")
+        .ok_or_else(|| DataFusionError::Execution("Missing value field in Variant".to_string()))?;
+
+    // Helper to get bytes from either a Binary or a BinaryView array
+    fn get_bytes<'a>(arr: &'a ArrayRef, idx: usize) -> Option<&'a [u8]> {
+        if let Some(binary) = arr.as_any().downcast_ref::<BinaryArray>() {
+            if binary.is_null(idx) { None } else { Some(binary.value(idx)) }
+        } else if let Some(binary_view) = arr.as_any().downcast_ref::<BinaryViewArray>() {
+            if binary_view.is_null(idx) { None } else { Some(binary_view.value(idx)) }
+        } else {
+            None
+        }
+    }
+
+    let mut builder = StringBuilder::new();
+    for i in 0..variant_array.len() {
+        if variant_array.is_null(i) {
+            builder.append_null();
+        } else {
+            let metadata = get_bytes(metadata_col, i)
+                .ok_or_else(|| DataFusionError::Execution("Missing metadata bytes".to_string()))?;
+            let value = get_bytes(value_col, i)
+                .ok_or_else(|| DataFusionError::Execution("Missing value bytes".to_string()))?;
+            let variant = Variant::new(metadata, value);
+            let json_str = variant
+                .to_json_string()
+                .map_err(|e| DataFusionError::Execution(format!("Failed to convert Variant to JSON: {}", e)))?;
+            builder.append_value(&json_str);
+        }
+    }
+
+    Ok(Arc::new(builder.finish()))
+}
+
+/// Convert a ColumnarValue to a JSON string if it's a Variant; otherwise pass it through unchanged.
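+///
+/// # Example
+///
+/// A hedged sketch of the dispatch behavior (ignored doc-test):
+///
+/// ```ignore
+/// // A Variant column is rewritten to a Utf8 JSON column...
+/// let variant_col = ColumnarValue::Array(json_to_variant_array(&StringArray::from(vec![r#"{"a": 1}"#]))?);
+/// assert!(!is_variant_columnar(&ensure_json_columnar(&variant_col)?));
+///
+/// // ...while a Utf8 column passes through unchanged.
+/// let utf8_col = ColumnarValue::Array(Arc::new(StringArray::from(vec![r#"{"a": 1}"#])));
+/// assert!(!is_variant_columnar(&ensure_json_columnar(&utf8_col)?));
+/// ```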
StringArray".to_string()))?; + let variant_array = json_to_variant_array(string_array)?; + let variant_field = Arc::new(Field::new(field.name(), variant_data_type(), field.is_nullable())); + new_columns.push(variant_array); + new_fields.push(variant_field); + } else { + new_columns.push(column.clone()); + new_fields.push(field.clone()); + } + } + + let new_schema = Arc::new(arrow::datatypes::Schema::new(new_fields)); + RecordBatch::try_new(new_schema, new_columns).map_err(|e| DataFusionError::Execution(format!("Failed to create batch: {}", e))) +} + +// ============================================================================ +// Replacement UDFs that work with both Variant and Utf8 inputs +// ============================================================================ + +/// Create the json_get UDF (-> operator) that works with both Variant and Utf8 +pub fn create_json_get_udf() -> ScalarUDF { + ScalarUDF::from(VariantAwareJsonGetUDF::new()) +} + +/// Create the json_get_str UDF (->> operator) that works with both Variant and Utf8 +pub fn create_json_get_str_udf() -> ScalarUDF { + ScalarUDF::from(VariantAwareJsonGetStrUDF::new()) +} + +/// Create the json_length UDF that works with both Variant and Utf8 +pub fn create_json_length_udf() -> ScalarUDF { + ScalarUDF::from(VariantAwareJsonLengthUDF::new()) +} + +/// Create the json_contains UDF that works with both Variant and Utf8 +pub fn create_json_contains_udf() -> ScalarUDF { + ScalarUDF::from(VariantAwareJsonContainsUDF::new()) +} + +// ============================================================================ +// UDF Implementations +// ============================================================================ + +#[derive(Debug, Hash, Eq, PartialEq)] +struct VariantAwareJsonGetUDF { + signature: Signature, +} + +impl VariantAwareJsonGetUDF { + fn new() -> Self { + Self { signature: Signature::variadic_any(Volatility::Immutable) } + } +} + +impl ScalarUDFImpl for VariantAwareJsonGetUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "json_get" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> datafusion::error::Result { + if args.args.len() < 2 { + return Err(DataFusionError::Execution("json_get requires at least 2 arguments".to_string())); + } + + // Convert first argument from Variant to JSON if needed + let json_arg = ensure_json_columnar(&args.args[0])?; + + // Call the underlying datafusion-functions-json implementation + let mut new_args = args.args.clone(); + new_args[0] = json_arg; + + // Use the json_get UDF from datafusion-functions-json + datafusion_functions_json::udfs::json_get_udf().invoke_with_args(ScalarFunctionArgs { args: new_args, ..args }) + } +} + +#[derive(Debug, Hash, Eq, PartialEq)] +struct VariantAwareJsonGetStrUDF { + signature: Signature, +} + +impl VariantAwareJsonGetStrUDF { + fn new() -> Self { + Self { signature: Signature::variadic_any(Volatility::Immutable) } + } +} + +impl ScalarUDFImpl for VariantAwareJsonGetStrUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "json_get_str" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> datafusion::error::Result { + if args.args.len() < 2 { 
+#[derive(Debug, Hash, Eq, PartialEq)]
+struct VariantAwareJsonGetUDF {
+    signature: Signature,
+}
+
+impl VariantAwareJsonGetUDF {
+    fn new() -> Self {
+        Self { signature: Signature::variadic_any(Volatility::Immutable) }
+    }
+}
+
+impl ScalarUDFImpl for VariantAwareJsonGetUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "json_get"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> datafusion::error::Result<ColumnarValue> {
+        if args.args.len() < 2 {
+            return Err(DataFusionError::Execution("json_get requires at least 2 arguments".to_string()));
+        }
+
+        // Convert the first argument from Variant to JSON if needed
+        let json_arg = ensure_json_columnar(&args.args[0])?;
+
+        // Call the underlying datafusion-functions-json implementation
+        let mut new_args = args.args.clone();
+        new_args[0] = json_arg;
+
+        // Use the json_get UDF from datafusion-functions-json
+        datafusion_functions_json::udfs::json_get_udf().invoke_with_args(ScalarFunctionArgs { args: new_args, ..args })
+    }
+}
+
+#[derive(Debug, Hash, Eq, PartialEq)]
+struct VariantAwareJsonGetStrUDF {
+    signature: Signature,
+}
+
+impl VariantAwareJsonGetStrUDF {
+    fn new() -> Self {
+        Self { signature: Signature::variadic_any(Volatility::Immutable) }
+    }
+}
+
+impl ScalarUDFImpl for VariantAwareJsonGetStrUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "json_get_str"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> datafusion::error::Result<ColumnarValue> {
+        if args.args.len() < 2 {
+            return Err(DataFusionError::Execution("json_get_str requires at least 2 arguments".to_string()));
+        }
+
+        let json_arg = ensure_json_columnar(&args.args[0])?;
+        let mut new_args = args.args.clone();
+        new_args[0] = json_arg;
+
+        datafusion_functions_json::udfs::json_get_str_udf().invoke_with_args(ScalarFunctionArgs { args: new_args, ..args })
+    }
+}
+
+#[derive(Debug, Hash, Eq, PartialEq)]
+struct VariantAwareJsonLengthUDF {
+    signature: Signature,
+}
+
+impl VariantAwareJsonLengthUDF {
+    fn new() -> Self {
+        Self { signature: Signature::variadic_any(Volatility::Immutable) }
+    }
+}
+
+impl ScalarUDFImpl for VariantAwareJsonLengthUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "json_length"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
+        Ok(DataType::UInt64)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> datafusion::error::Result<ColumnarValue> {
+        if args.args.is_empty() {
+            return Err(DataFusionError::Execution("json_length requires at least 1 argument".to_string()));
+        }
+
+        let json_arg = ensure_json_columnar(&args.args[0])?;
+        let mut new_args = args.args.clone();
+        new_args[0] = json_arg;
+
+        datafusion_functions_json::udfs::json_length_udf().invoke_with_args(ScalarFunctionArgs { args: new_args, ..args })
+    }
+}
+
+#[derive(Debug, Hash, Eq, PartialEq)]
+struct VariantAwareJsonContainsUDF {
+    signature: Signature,
+}
+
+impl VariantAwareJsonContainsUDF {
+    fn new() -> Self {
+        Self { signature: Signature::variadic_any(Volatility::Immutable) }
+    }
+}
+
+impl ScalarUDFImpl for VariantAwareJsonContainsUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "json_contains"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> datafusion::error::Result<ColumnarValue> {
+        if args.args.len() < 2 {
+            return Err(DataFusionError::Execution("json_contains requires at least 2 arguments".to_string()));
+        }
+
+        let json_arg = ensure_json_columnar(&args.args[0])?;
+        let mut new_args = args.args.clone();
+        new_args[0] = json_arg;
+
+        datafusion_functions_json::udfs::json_contains_udf().invoke_with_args(ScalarFunctionArgs { args: new_args, ..args })
+    }
+}
+
+/// Register all variant-aware JSON functions in the session context.
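+///
+/// # Example
+///
+/// A hedged sketch of the intended call pattern (ignored doc-test; the table and column
+/// names come from schemas/otel_logs_and_spans.yaml, and the key path is illustrative):
+///
+/// ```ignore
+/// let mut ctx = SessionContext::new();
+/// datafusion_functions_json::register_all(&mut ctx)?;
+/// register_variant_json_functions(&mut ctx); // overrides json_get & friends
+///
+/// // Works whether `attributes` is a Variant struct column or plain Utf8 JSON:
+/// ctx.sql("SELECT attributes->>'client.address' FROM otel_logs_and_spans").await?;
+/// ```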
+pub fn register_variant_json_functions(ctx: &mut datafusion::execution::context::SessionContext) {
+    // Register our variant-aware replacements, overriding the datafusion-functions-json ones
+    ctx.register_udf(create_json_get_udf());
+    ctx.register_udf(create_json_get_str_udf());
+    ctx.register_udf(create_json_length_udf());
+    ctx.register_udf(create_json_contains_udf());
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_variant_data_type() {
+        let dt = variant_data_type();
+        assert!(matches!(dt, DataType::Struct(_)));
+    }
+
+    #[test]
+    fn test_json_to_variant_roundtrip() {
+        let json_data = vec![Some(r#"{"name": "test", "value": 123}"#), None, Some(r#"[1, 2, 3]"#)];
+        let json_array = StringArray::from(json_data);
+
+        // Convert to Variant
+        let variant_arr = json_to_variant_array(&json_array).unwrap();
+
+        // Convert back to JSON
+        let struct_arr = variant_arr.as_any().downcast_ref::<StructArray>().unwrap();
+        let json_back = variant_to_json_array(struct_arr).unwrap();
+        let json_str = json_back.as_any().downcast_ref::<StringArray>().unwrap();
+
+        assert_eq!(json_str.len(), 3);
+        assert!(!json_str.is_null(0));
+        assert!(json_str.is_null(1));
+        assert!(!json_str.is_null(2));
+
+        // Verify content (JSON keys may be reordered, but both should parse to the same value)
+        let parsed_original: serde_json::Value = serde_json::from_str(r#"{"name": "test", "value": 123}"#).unwrap();
+        let parsed_roundtrip: serde_json::Value = serde_json::from_str(json_str.value(0)).unwrap();
+        assert_eq!(parsed_original, parsed_roundtrip);
+
+        let parsed_array: serde_json::Value = serde_json::from_str(json_str.value(2)).unwrap();
+        assert_eq!(parsed_array, serde_json::json!([1, 2, 3]));
+    }
+
+    #[test]
+    fn test_is_variant_array() {
+        let json_array = StringArray::from(vec![Some(r#"{"key": "value"}"#)]);
+        let variant_arr = json_to_variant_array(&json_array).unwrap();
+
+        assert!(is_variant_array(&variant_arr));
+
+        // A non-variant array should return false
+        let string_arr: ArrayRef = Arc::new(StringArray::from(vec!["test"]));
+        assert!(!is_variant_array(&string_arr));
+    }
+}
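+
+// A batch-level conversion check, added as a sketch: the two-column schema below is
+// illustrative (only "attributes" is in VARIANT_COLUMNS), not the production
+// otel_logs_and_spans schema.
+#[cfg(test)]
+mod batch_conversion_tests {
+    use super::*;
+    use arrow::datatypes::Schema;
+
+    #[test]
+    fn test_convert_batch_json_to_variant() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("attributes", DataType::Utf8, true), // converted to Variant
+            Field::new("name", DataType::Utf8, true),       // left untouched
+        ]));
+        let batch = RecordBatch::try_new(
+            schema,
+            vec![
+                Arc::new(StringArray::from(vec![Some(r#"{"k": "v"}"#)])) as ArrayRef,
+                Arc::new(StringArray::from(vec![Some("span")])) as ArrayRef,
+            ],
+        )
+        .unwrap();
+
+        let converted = convert_batch_json_to_variant(batch).unwrap();
+        assert_eq!(*converted.schema().field(0).data_type(), variant_data_type());
+        assert_eq!(*converted.schema().field(1).data_type(), DataType::Utf8);
+    }
+}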