Skip to content

Commit ca4a0ae

Browse files
authored
[Variant] Enforce shredded-type validation in shred_variant (#8796)
# Which issue does this PR close?

- Closes #8795.

# Rationale for this change

Mentioned in the issue.

# What changes are included in this PR?

Add validation in `shred_variant` to allow only spec-approved types.

# Are these changes tested?

Yes.

# Are there any user-facing changes?
1 parent 3d5428d commit ca4a0ae

File tree

3 files changed

+310
-128
lines changed

3 files changed

+310
-128
lines changed

parquet-variant-compute/src/shred_variant.rs

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::{VariantArray, VariantValueArrayBuilder};
2525
use arrow::array::{ArrayRef, BinaryViewArray, NullBufferBuilder};
2626
use arrow::buffer::NullBuffer;
2727
use arrow::compute::CastOptions;
28-
use arrow::datatypes::{DataType, Fields};
28+
use arrow::datatypes::{DataType, Fields, TimeUnit};
2929
use arrow::error::{ArrowError, Result};
3030
use parquet_variant::{Variant, VariantBuilderExt};
3131

@@ -123,13 +123,39 @@ pub(crate) fn make_variant_to_shredded_variant_arrow_row_builder<'a>(
123123
"Shredding variant array values as arrow lists".to_string(),
124124
));
125125
}
126-
_ => {
126+
// Supported shredded primitive types, see Variant shredding spec:
127+
// https://github.com/apache/parquet-format/blob/master/VariantShredding.md#shredded-value-types
128+
DataType::Boolean
129+
| DataType::Int8
130+
| DataType::Int16
131+
| DataType::Int32
132+
| DataType::Int64
133+
| DataType::Float32
134+
| DataType::Float64
135+
| DataType::Decimal32(..)
136+
| DataType::Decimal64(..)
137+
| DataType::Decimal128(..)
138+
| DataType::Date32
139+
| DataType::Time64(TimeUnit::Microsecond)
140+
| DataType::Timestamp(TimeUnit::Microsecond | TimeUnit::Nanosecond, _)
141+
| DataType::Binary
142+
| DataType::BinaryView
143+
| DataType::Utf8
144+
| DataType::Utf8View
145+
| DataType::FixedSizeBinary(16) // UUID
146+
=> {
127147
let builder =
128148
make_primitive_variant_to_arrow_row_builder(data_type, cast_options, capacity)?;
129149
let typed_value_builder =
130150
VariantToShreddedPrimitiveVariantRowBuilder::new(builder, capacity, top_level);
131151
VariantToShreddedVariantRowBuilder::Primitive(typed_value_builder)
132152
}
153+
DataType::FixedSizeBinary(_) => {
154+
return Err(ArrowError::InvalidArgumentError(format!("{data_type} is not a valid variant shredding type. Only FixedSizeBinary(16) for UUID is supported.")))
155+
}
156+
_ => {
157+
return Err(ArrowError::InvalidArgumentError(format!("{data_type} is not a valid variant shredding type")))
158+
}
133159
};
134160
Ok(builder)
135161
}
@@ -327,7 +353,7 @@ mod tests {
327353
use super::*;
328354
use crate::VariantArrayBuilder;
329355
use arrow::array::{Array, FixedSizeBinaryArray, Float64Array, Int64Array};
330-
use arrow::datatypes::{DataType, Field, Fields};
356+
use arrow::datatypes::{DataType, Field, Fields, TimeUnit, UnionFields, UnionMode};
331357
use parquet_variant::{ObjectBuilder, ReadOnlyMetadataBuilder, Variant, VariantBuilder};
332358
use std::sync::Arc;
333359
use uuid::Uuid;
@@ -536,6 +562,60 @@ mod tests {
536562
assert!(typed_value_float64.is_null(2)); // string doesn't convert
537563
}
538564

565+
#[test]
566+
fn test_invalid_shredded_types_rejected() {
567+
let input = VariantArray::from_iter([Variant::from(42)]);
568+
569+
let invalid_types = vec![
570+
DataType::UInt8,
571+
DataType::Float16,
572+
DataType::Decimal256(38, 10),
573+
DataType::Date64,
574+
DataType::Time32(TimeUnit::Second),
575+
DataType::Time64(TimeUnit::Nanosecond),
576+
DataType::Timestamp(TimeUnit::Millisecond, None),
577+
DataType::LargeBinary,
578+
DataType::LargeUtf8,
579+
DataType::FixedSizeBinary(17),
580+
DataType::Union(
581+
UnionFields::new(
582+
vec![0_i8, 1_i8],
583+
vec![
584+
Field::new("int_field", DataType::Int32, false),
585+
Field::new("str_field", DataType::Utf8, true),
586+
],
587+
),
588+
UnionMode::Dense,
589+
),
590+
DataType::Map(
591+
Arc::new(Field::new(
592+
"entries",
593+
DataType::Struct(Fields::from(vec![
594+
Field::new("key", DataType::Utf8, false),
595+
Field::new("value", DataType::Int32, true),
596+
])),
597+
false,
598+
)),
599+
false,
600+
),
601+
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
602+
DataType::RunEndEncoded(
603+
Arc::new(Field::new("run_ends", DataType::Int32, false)),
604+
Arc::new(Field::new("values", DataType::Utf8, true)),
605+
),
606+
];
607+
608+
for data_type in invalid_types {
609+
let err = shred_variant(&input, &data_type).unwrap_err();
610+
assert!(
611+
matches!(err, ArrowError::InvalidArgumentError(_)),
612+
"expected InvalidArgumentError for {:?}, got {:?}",
613+
data_type,
614+
err
615+
);
616+
}
617+
}
618+
539619
#[test]
540620
fn test_object_shredding_comprehensive() {
541621
let mut builder = VariantArrayBuilder::new(7);

parquet-variant-compute/src/variant_get.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ mod test {
320320
use arrow::datatypes::DataType::{Int16, Int32, Int64};
321321
use arrow::datatypes::i256;
322322
use arrow_schema::DataType::{Boolean, Float32, Float64, Int8};
323-
use arrow_schema::{DataType, Field, FieldRef, Fields, TimeUnit};
323+
use arrow_schema::{DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit};
324324
use chrono::DateTime;
325325
use parquet_variant::{
326326
EMPTY_VARIANT_METADATA_BYTES, Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16,
@@ -3685,6 +3685,34 @@ mod test {
36853685
));
36863686
}
36873687

3688+
#[test]
3689+
fn get_non_supported_temporal_types_error() {
3690+
let values = vec![None, Some(Variant::Null), Some(Variant::BooleanFalse)];
3691+
let variant_array: ArrayRef = ArrayRef::from(VariantArray::from_iter(values));
3692+
3693+
let test_cases = vec![
3694+
FieldRef::from(Field::new(
3695+
"result",
3696+
DataType::Duration(TimeUnit::Microsecond),
3697+
true,
3698+
)),
3699+
FieldRef::from(Field::new(
3700+
"result",
3701+
DataType::Interval(IntervalUnit::YearMonth),
3702+
true,
3703+
)),
3704+
];
3705+
3706+
for field in test_cases {
3707+
let options = GetOptions::new().with_as_type(Some(field));
3708+
let err = variant_get(&variant_array, options).unwrap_err();
3709+
assert!(
3710+
err.to_string()
3711+
.contains("Casting Variant to duration/interval types is not supported")
3712+
);
3713+
}
3714+
}
3715+
36883716
perfectly_shredded_variant_array_fn!(perfectly_shredded_invalid_time_variant_array, || {
36893717
// 86401000000 is invalid for Time64Microsecond (valid values are below 86400000000, i.e. less than 24 hours)
36903718
Time64MicrosecondArray::from(vec![

0 commit comments

Comments
 (0)