feat: 3154 from_parquet should be able to read partial columns (#3156)
* feat: 3154 - Adding unit test and fix

* Taking this opportunity to fix an issue with the 2772 test

This had been bugging me, and with a better understanding I can now improve
the unit tests, mostly for future readability and correctness.

* Removing commented-out test return value

---------

Co-authored-by: Ianna Osborne <ianna.osborne@cern.ch>
tcawlfield and ianna authored Jun 24, 2024
1 parent 6f688e9 commit 03f6169
Showing 3 changed files with 86 additions and 4 deletions.
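For orientation, the capability this commit enables looks roughly like the sketch below, mirroring the new test added later in the diff. The temporary path, array contents, and column labels are purely illustrative, not part of the commit.

```python
# Hedged sketch of partial-column reading with ak.from_parquet.
import os
import tempfile

import awkward as ak

records = ak.Array(
    {
        "a": [{"lbl": "item 1", "idx": 11, "ids": [1, 2, 3]}],
        "b": [[[111, 112], [121, 122]]],
    }
)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "example.parquet")
    ak.to_parquet(records, path)
    # Read only the "a.ids" sub-column plus all of "b"; restoring the Awkward
    # types for such a partial selection is what this commit addresses.
    partial = ak.from_parquet(path, columns=["a.ids", "b"])
    assert partial["a"].to_list() == [{"ids": [1, 2, 3]}]
    assert partial["b"].to_list() == records["b"].to_list()
```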
33 changes: 29 additions & 4 deletions src/awkward/_connect/pyarrow/table_conv.py
@@ -1,3 +1,5 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import json
@@ -128,11 +130,34 @@ def native_arrow_field_to_akarraytype(
fields = _fields_of_strg_type(storage_type)
if len(fields) > 0:
# We need to replace storage_type with one that contains AwkwardArrowTypes.
awkwardized_fields = [
native_arrow_field_to_akarraytype(field, meta) # Recurse
for field, meta in zip(fields, metadata["subfield_metadata"])
]
sub_meta = metadata["subfield_metadata"]
awkwardized_fields = None # Temporary
if len(sub_meta) == len(fields):
awkwardized_fields = [
native_arrow_field_to_akarraytype(field, meta) # Recurse
for field, meta in zip(fields, metadata["subfield_metadata"])
]
elif len(fields) < len(sub_meta):
# If a user has read a partial column, we can have fewer Arrow fields than the original.
sub_meta_dict = {sm["field_name"]: sm for sm in sub_meta}
awkwardized_fields = []
for field in fields:
if field.name in sub_meta_dict:
awkwardized_fields.append(
native_arrow_field_to_akarraytype(
field, sub_meta_dict[field.name]
)
)
else:
raise ValueError(
f"Cannot find Awkward metadata for sub-field {field.name}"
)
else:
raise ValueError(
f"Not enough fields in Awkward metadata. Have {len(sub_meta)} need at least {len(fields)}."
)
storage_type = _make_pyarrow_type_like(storage_type, awkwardized_fields)

ak_type = AwkwardArrowType._from_metadata_object(storage_type, metadata)
return pyarrow.field(ntv_field.name, type=ak_type, nullable=ntv_field.nullable)

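A minimal, standalone sketch of the matching logic added above: when a partial-column read leaves fewer Arrow sub-fields than stored metadata entries, each remaining field is looked up by name in the saved `subfield_metadata`. The helper name here is an assumption for illustration, not part of the Awkward API; the `"field_name"` key comes from the diff itself.

```python
import pyarrow as pa


def pair_fields_with_metadata(fields, sub_meta):
    """Pair each present Arrow sub-field with its saved metadata entry by name."""
    if len(sub_meta) == len(fields):
        return list(zip(fields, sub_meta))
    if len(fields) < len(sub_meta):
        # A partial-column read dropped some sub-fields; look the rest up by name.
        by_name = {sm["field_name"]: sm for sm in sub_meta}
        pairs = []
        for field in fields:
            if field.name not in by_name:
                raise ValueError(f"Cannot find Awkward metadata for sub-field {field.name}")
            pairs.append((field, by_name[field.name]))
        return pairs
    raise ValueError(
        f"Not enough fields in Awkward metadata. Have {len(sub_meta)}, need at least {len(fields)}."
    )


# Only "ids" survived the column selection, but metadata for all three sub-fields remains.
fields = [pa.field("ids", pa.list_(pa.int64()))]
sub_meta = [{"field_name": "lbl"}, {"field_name": "idx"}, {"field_name": "ids"}]
print(pair_fields_with_metadata(fields, sub_meta))
```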
17 changes: 17 additions & 0 deletions tests/test_2772_parquet_extn_array_metadata.py
@@ -3,6 +3,7 @@

from __future__ import annotations

import io
import os

import numpy as np
@@ -137,6 +138,22 @@ def test_array_conversions(akarray, as_dict):
rt_array = ak.from_arrow(as_extn, highlevel=True)
assert to_list(rt_array) == to_list(akarray)

# Deeper test of types
akarray_high = ak.Array(akarray)
if akarray_high.type.content.parameters.get("__categorical__", False) == as_dict:
# as_dict is supposed to go hand-in-hand with __categorical__: True, and if it
# does not, we do not round-trip perfectly. So only test when this is set correctly.
assert rt_array.type == akarray_high.type

ak_type_str_orig = io.StringIO()
ak_type_str_rtrp = io.StringIO()
akarray_high.type.show(stream=ak_type_str_orig)
rt_array.type.show(stream=ak_type_str_rtrp)
if ak_type_str_orig.getvalue() != ak_type_str_rtrp.getvalue():
print(" Original type:", ak_type_str_orig.getvalue())
print(" Rnd-trip type:", ak_type_str_rtrp.getvalue())
assert ak_type_str_orig.getvalue() == ak_type_str_rtrp.getvalue()


def test_table_conversion():
ak_tbl_like = ak.Array(
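A condensed sketch of the round-trip type comparison introduced in the test above: capture the printed form of each type with `show(stream=...)` and compare the strings. The sample array and the explicit `extensionarray=True` argument to `ak.to_arrow` are assumptions for illustration; only the `show`/`StringIO` pattern comes from the test.

```python
import io

import awkward as ak

original = ak.Array([[1.1, 2.2], [], [3.3]])
round_tripped = ak.from_arrow(ak.to_arrow(original, extensionarray=True), highlevel=True)

# Compare the human-readable type descriptions before and after the round trip.
buf_orig, buf_rtrp = io.StringIO(), io.StringIO()
original.type.show(stream=buf_orig)
round_tripped.type.show(stream=buf_rtrp)
assert buf_orig.getvalue() == buf_rtrp.getvalue()
```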
40 changes: 40 additions & 0 deletions tests/test_3154_parquet_subcolumn_select.py
@@ -0,0 +1,40 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
# ruff: noqa: E402

from __future__ import annotations

import os

import pytest

import awkward as ak

pa = pytest.importorskip("pyarrow")
pq = pytest.importorskip("pyarrow.parquet")


def test_parquet_subcolumn_select(tmp_path):
ak_tbl = ak.Array(
{
"a": [
{"lbl": "item 1", "idx": 11, "ids": [1, 2, 3]},
{"lbl": "item 2", "idx": 12, "ids": [51, 52]},
{"lbl": "item 3", "idx": 13, "ids": [61, 62, 63, 64]},
],
"b": [
[[111, 112], [121, 122]],
[[211, 212], [221, 222]],
[[311, 312], [321, 322]],
],
}
)
parquet_file = os.path.join(tmp_path, "test_3514.parquet")
ak.to_parquet(ak_tbl, parquet_file)

selection = ak.from_parquet(parquet_file, columns=["a.ids", "b"])
assert selection["a"].to_list() == [
{"ids": [1, 2, 3]},
{"ids": [51, 52]},
{"ids": [61, 62, 63, 64]},
]
assert selection["b"].to_list() == ak_tbl["b"].to_list()
