feat: Support reading arrow Map type from Delta #21330

Merged · 3 commits · Feb 19, 2025
61 changes: 41 additions & 20 deletions crates/polars-core/src/datatypes/field.rs
@@ -164,10 +164,16 @@ impl DataType {
ArrowDataType::Float32 => DataType::Float32,
ArrowDataType::Float64 => DataType::Float64,
#[cfg(feature = "dtype-array")]
ArrowDataType::FixedSizeList(f, size) => DataType::Array(DataType::from_arrow_field(f).boxed(), *size),
ArrowDataType::LargeList(f) | ArrowDataType::List(f) => DataType::List(DataType::from_arrow_field(f).boxed()),
ArrowDataType::FixedSizeList(f, size) => {
DataType::Array(DataType::from_arrow_field(f).boxed(), *size)
},
ArrowDataType::LargeList(f) | ArrowDataType::List(f) => {
DataType::List(DataType::from_arrow_field(f).boxed())
},
ArrowDataType::Date32 => DataType::Date,
ArrowDataType::Timestamp(tu, tz) => DataType::Datetime(tu.into(), DataType::canonical_timezone(tz)),
ArrowDataType::Timestamp(tu, tz) => {
DataType::Datetime(tu.into(), DataType::canonical_timezone(tz))
},
ArrowDataType::Duration(tu) => DataType::Duration(tu.into()),
ArrowDataType::Date64 => DataType::Datetime(TimeUnit::Milliseconds, None),
ArrowDataType::Time64(_) | ArrowDataType::Time32(_) => DataType::Time,
@@ -183,19 +189,25 @@ impl DataType {
// We know thus that len is only [0-9] and the first ';' doesn't belong to the
// payload.
while let Some(pos) = encoded.find(';') {
let (len, remainder) = encoded.split_at(pos);
// Split off ';'
encoded = &remainder[1..];
let len = len.parse::<usize>().unwrap();

let (value, remainder) = encoded.split_at(len);
cats.push_value(value);
encoded = remainder;
let (len, remainder) = encoded.split_at(pos);
// Split off ';'
encoded = &remainder[1..];
let len = len.parse::<usize>().unwrap();

let (value, remainder) = encoded.split_at(len);
cats.push_value(value);
encoded = remainder;
}
DataType::Enum(Some(Arc::new(RevMapping::build_local(cats.into()))), Default::default())
DataType::Enum(
Some(Arc::new(RevMapping::build_local(cats.into()))),
Default::default(),
)
} else if let Some(ordering) = md.and_then(|md| md.categorical()) {
DataType::Categorical(None, ordering)
} else if matches!(value_type.as_ref(), ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Utf8View) {
} else if matches!(
value_type.as_ref(),
ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Utf8View
) {
DataType::Categorical(None, Default::default())
} else {
Self::from_arrow(value_type, bin_to_view, None)
@@ -204,11 +216,11 @@ impl DataType {
#[cfg(feature = "dtype-struct")]
ArrowDataType::Struct(fields) => {
DataType::Struct(fields.iter().map(|fld| fld.into()).collect())
}
},
#[cfg(not(feature = "dtype-struct"))]
ArrowDataType::Struct(_) => {
panic!("activate the 'dtype-struct' feature to handle struct data types")
}
},
ArrowDataType::Extension(ext) if ext.name.as_str() == EXTENSION_NAME => {
#[cfg(feature = "object")]
{
@@ -218,21 +230,30 @@ impl DataType {
{
panic!("activate the 'object' feature to be able to load POLARS_EXTENSION_TYPE")
}
}
},
#[cfg(feature = "dtype-decimal")]
ArrowDataType::Decimal(precision, scale) => DataType::Decimal(Some(*precision), Some(*scale)),
ArrowDataType::Utf8View |ArrowDataType::LargeUtf8 | ArrowDataType::Utf8 => DataType::String,
ArrowDataType::Decimal(precision, scale) => {
DataType::Decimal(Some(*precision), Some(*scale))
},
ArrowDataType::Utf8View | ArrowDataType::LargeUtf8 | ArrowDataType::Utf8 => {
DataType::String
},
ArrowDataType::BinaryView => DataType::Binary,
ArrowDataType::LargeBinary | ArrowDataType::Binary => {
if bin_to_view {
DataType::Binary
} else {

DataType::BinaryOffset
}
},
ArrowDataType::FixedSizeBinary(_) => DataType::Binary,
dt => panic!("Arrow datatype {dt:?} not supported by Polars. You probably need to activate that data-type feature."),
Review comment (Collaborator, Author): rustfmt broke because this string line was too long
ArrowDataType::Map(inner, _is_sorted) => {
DataType::List(Self::from_arrow_field(inner).boxed())
},
dt => panic!(
"Arrow datatype {dt:?} not supported by Polars. \
You probably need to activate that data-type feature."
),
}
}
}
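Note on the new `ArrowDataType::Map` arm above: Polars has no dedicated map dtype, so a map column is exposed as a list of key/value structs. A minimal Python sketch of the resulting behaviour (the exact schema repr may vary between versions):

```python
import polars as pl
import pyarrow as pa

# Build an Arrow table with a map column, then load it into Polars.
table = pa.table(
    {
        "mapping": pa.array(
            [[("a", 1)], [("b", 2), ("c", 3)]],
            type=pa.map_(pa.string(), pa.int64()),
        )
    }
)

df = pl.from_arrow(table)
# The map column arrives as a list of key/value structs, e.g.
# {'mapping': List(Struct({'key': String, 'value': Int64}))}
print(df.schema)
```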
30 changes: 22 additions & 8 deletions crates/polars-io/src/parquet/read/reader.rs
@@ -89,16 +89,23 @@ impl<R: MmapBytesReader> ParquetReader<R> {
projected_arrow_schema: Option<&ArrowSchema>,
allow_missing_columns: bool,
) -> PolarsResult<Self> {
// `self.schema` gets overwritten if allow_missing_columns
let this_schema_width = self.schema()?.len();
let slf_schema = self.schema()?;
let slf_schema_width = slf_schema.len();

if allow_missing_columns {
// Must check the dtypes
ensure_matching_dtypes_if_found(
projected_arrow_schema.unwrap_or(first_schema.as_ref()),
self.schema()?.as_ref(),
)?;
self.schema.replace(first_schema.clone());
self.schema = Some(Arc::new(
first_schema
.iter()
.map(|(name, field)| {
(name.clone(), slf_schema.get(name).unwrap_or(field).clone())
Review comment (Collaborator, Author): Ensures we use the Field from the metadata of this file
})
.collect(),
));
}

let schema = self.schema()?;
@@ -110,7 +117,7 @@ impl<R: MmapBytesReader> ParquetReader<R> {
projected_arrow_schema,
)?;
} else {
if this_schema_width > first_schema.len() {
if slf_schema_width > first_schema.len() {
polars_bail!(
SchemaMismatch:
"parquet file contained extra columns and no selection was given"
@@ -334,16 +341,23 @@ impl ParquetAsyncReader {
projected_arrow_schema: Option<&ArrowSchema>,
allow_missing_columns: bool,
) -> PolarsResult<Self> {
// `self.schema` gets overwritten if allow_missing_columns
let this_schema_width = self.schema().await?.len();
let slf_schema = self.schema().await?;
let slf_schema_width = slf_schema.len();

if allow_missing_columns {
// Must check the dtypes
ensure_matching_dtypes_if_found(
projected_arrow_schema.unwrap_or(first_schema.as_ref()),
self.schema().await?.as_ref(),
)?;
self.schema.replace(first_schema.clone());
self.schema = Some(Arc::new(
first_schema
.iter()
.map(|(name, field)| {
(name.clone(), slf_schema.get(name).unwrap_or(field).clone())
})
.collect(),
));
}

let schema = self.schema().await?;
@@ -355,7 +369,7 @@ impl ParquetAsyncReader {
projected_arrow_schema,
)?;
} else {
if this_schema_width > first_schema.len() {
if slf_schema_width > first_schema.len() {
polars_bail!(
SchemaMismatch:
"parquet file contained extra columns and no selection was given"
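The schema handling above (see the inline review comment) feeds into the user-facing `allow_missing_columns` behaviour: columns absent from a file are null-filled, while columns that are present keep the field recorded in that file's own metadata. A rough sketch of that behaviour from Python, with illustrative file names:

```python
import tempfile
from pathlib import Path

import polars as pl

# Two parquet files with different column sets (paths are illustrative).
tmp = Path(tempfile.mkdtemp())
pl.DataFrame({"a": [1, 2], "b": ["x", "y"]}).write_parquet(tmp / "f1.parquet")
pl.DataFrame({"a": [3, 4]}).write_parquet(tmp / "f2.parquet")  # no "b" column

# With allow_missing_columns=True the reader null-fills "b" for f2.parquet
# instead of raising a schema-mismatch error.
out = pl.read_parquet(str(tmp / "*.parquet"), allow_missing_columns=True)
print(out)
```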
6 changes: 5 additions & 1 deletion py-polars/tests/unit/interop/test_interop.py
@@ -348,7 +348,11 @@ def test_from_pyarrow_map() -> None:
),
)

result = cast(pl.DataFrame, pl.from_arrow(pa_table))
# Convert from an empty table to trigger an ArrowSchema -> native schema
# conversion (checks that ArrowDataType::Map is handled in Rust).
pl.DataFrame(pa_table.slice(0, 0))

result = pl.DataFrame(pa_table)
assert result.to_dict(as_series=False) == {
"idx": [1, 2],
"mapping": [
31 changes: 30 additions & 1 deletion py-polars/tests/unit/io/test_delta.py
@@ -7,7 +7,7 @@
import pyarrow as pa
import pyarrow.fs
import pytest
from deltalake import DeltaTable
from deltalake import DeltaTable, write_deltalake
from deltalake.exceptions import DeltaError, TableNotFoundError
from deltalake.table import TableMerger

@@ -497,3 +497,32 @@ def test_read_delta_empty(tmp_path: Path) -> None:

DeltaTable.create(path, pl.DataFrame(schema={"x": pl.Int64}).to_arrow().schema)
assert_frame_equal(pl.read_delta(path), pl.DataFrame(schema={"x": pl.Int64}))


@pytest.mark.write_disk
def test_read_delta_arrow_map_type(tmp_path: Path) -> None:
payload = [
{"id": 1, "account_id": {17: "100.01.001 Cash"}},
{"id": 2, "account_id": {18: "180.01.001 Cash", 19: "foo"}},
]

schema = pa.schema(
[
pa.field("id", pa.int32()),
pa.field("account_id", pa.map_(pa.int32(), pa.string())),
]
)
table = pa.Table.from_pylist(payload, schema)

expect = pl.DataFrame(table)

table_path = str(tmp_path)
write_deltalake(
table_path,
table,
mode="overwrite",
engine="rust",
)

assert_frame_equal(pl.scan_delta(table_path).collect(), expect)
assert_frame_equal(pl.read_delta(table_path), expect)
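A small follow-on sketch, reusing `table_path` from the test above: since the map column comes back as a list of key/value structs, it can be exploded and unnested to get one row per map entry.

```python
import polars as pl

df = pl.read_delta(table_path)  # `table_path` as written in the test above

# One row per map entry: explode the list, then unnest the {key, value} struct.
pairs = df.explode("account_id").unnest("account_id")
print(pairs)  # columns: id, key, value
```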
15 changes: 15 additions & 0 deletions py-polars/tests/unit/io/test_parquet.py
@@ -1305,6 +1305,9 @@ def test_parquet_nested_struct_17933() -> None:
test_round_trip(df)


# This is fixed with POLARS_FORCE_MULTISCAN=1. Without it we have
# first_metadata.unwrap() on None.
@pytest.mark.may_fail_auto_streaming
def test_parquet_pyarrow_map() -> None:
xs = [
[
@@ -1341,6 +1344,18 @@ def test_parquet_pyarrow_map() -> None:
f.seek(0)
assert_frame_equal(pl.read_parquet(f).explode(["x"]), expected)

# Test for https://github.com/pola-rs/polars/issues/21317
# Specifying schema/allow_missing_columns
for allow_missing_columns in [True, False]:
assert_frame_equal(
pl.read_parquet(
f,
schema={"x": pl.List(pl.Struct({"key": pl.Int32, "value": pl.Int32}))},
allow_missing_columns=allow_missing_columns,
).explode(["x"]),
expected,
)


@pytest.mark.parametrize(
("s", "elem"),
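For context on the explicit schema= used in the new assertion: a parquet map column written by pyarrow reads back in Polars as List(Struct({'key': ..., 'value': ...})), which is the shape the override has to spell out. A brief sketch (the exact schema repr may differ between versions):

```python
import io

import polars as pl
import pyarrow as pa
import pyarrow.parquet as pq

# Write a parquet file containing a map column via pyarrow.
buf = io.BytesIO()
tbl = pa.table(
    {
        "x": pa.array(
            [[(1, 10)], [(2, 20), (3, 30)]],
            type=pa.map_(pa.int32(), pa.int32()),
        )
    }
)
pq.write_table(tbl, buf)
buf.seek(0)

# Polars reads the map as a list of key/value structs, e.g.
# {'x': List(Struct({'key': Int32, 'value': Int32}))}
print(pl.read_parquet(buf).schema)
```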