Commit f5077ec

Support for PdReader kwargs (AmpX-AI#8)

Authored by tmivtuma and vojta tuma
Co-authored-by: vojta tuma <vtuma@amp.energy>

1 parent: 497673d

File tree

2 files changed: +68 -7 lines changed
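Before the diff itself, a quick usage sketch of the feature this commit adds, based on the commit message and the test file below. The `sep=";"` value and the local path are illustrative choices for the example, not something the commit prescribes:

from fsql.api import read_partitioned_table
from fsql.deser import PandasReader
from fsql.query import Q_TRUE

# Any keyword arguments given to the constructor are forwarded to the
# underlying pandas read call (pd.read_csv here, for CSV partitions).
reader = PandasReader(sep=";")
df = read_partitioned_table("file:///tmp/table1/", Q_TRUE, data_reader=reader)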

fsql/deser.py (+31 -7)
@@ -17,14 +17,18 @@
 opt for the lazy approach (such as in Dask), and don't materialize inside neither the `read_single` nor
 `concat` methods.
 
+Existing readers such as PandasReader allow customisation via passing through any kwargs to the underlying
+pandas read method.
+
 The user should *not* bake in any specific business logic in here -- a more prefered approach is to
-return an object such as data frame as early as possible, and apply any transformations later on.
+return an object such as (lazy) data frame as early as possible, and apply any transformations later on.
 """
 from __future__ import annotations
 
 import json
 import logging
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from collections.abc import Iterable
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum, auto, unique
@@ -87,31 +91,51 @@ def read_and_concat(
 
 
 class PandasReader(DataReader):
+    """Wraps various pandas read methods (parquet, json, csv, excel) into a single interface.
+    Behaviour can be customised via passing any kwargs to the constructor.
+    """
+
+    def __init__(self, input_format=InputFormat.AUTO, **pdread_kwargs):
+        super().__init__(input_format=input_format)
+        self.pdread_user_kwargs = pdread_kwargs
+        self.pdread_default_kwargs = defaultdict(dict)
+        self.pdread_default_kwargs[InputFormat.PARQUET] = {
+            "engine": "fastparquet",
+        }
+        self.pdread_default_kwargs[InputFormat.JSON] = {
+            "lines": True,
+        }
+        self.pdread_default_kwargs[InputFormat.XLSX] = {
+            "engine": "openpyxl",
+        }
+
     def read_single(self, partition: Partition, fs: AbstractFileSystem) -> pd.DataFrame:
         logger.debug(f"read dataframe for partition {partition}")
         input_format = self.detect_format(partition.url)
-        # TODO allow for user spec of engine and other params, essentially any quark
+        logger.debug(f"format detected for partition {input_format} <- {partition}")
         if input_format is InputFormat.PARQUET:
-            reader = lambda fd: pd.read_parquet(fd, engine="fastparquet")  # noqa: E731
+            reader = pd.read_parquet
         elif input_format is InputFormat.JSON:
-            reader = lambda fd: pd.read_json(fd, lines=True)  # noqa: E731
+            reader = pd.read_json
         elif input_format is InputFormat.CSV:
             reader = pd.read_csv
         elif input_format is InputFormat.XLSX:
-            reader = lambda fd: pd.read_excel(fd, engine="openpyxl")  # noqa: E731
+            reader = pd.read_excel
         elif input_format is InputFormat.AUTO:
             raise ValueError(f"partition had format detected as auto -> invalid state. Partition: {partition}")
         else:
             assert_exhaustive_enum(input_format)
 
+        pdread_kwargs = {**self.pdread_default_kwargs[input_format], **self.pdread_user_kwargs}
+        logger.debug(f"reader kwargs {pdread_kwargs} for partition {partition}")
         try:
             with fs.open(partition.url, "rb") as fd:
-                df = reader(fd)
+                df = reader(fd, **pdread_kwargs)
         except FileNotFoundError as e:
             logger.warning(f"file {partition} reading exception {type(e)}, attempting cache invalidation and reread")
             fs.invalidate_cache()
             with fs.open(partition.url, "rb") as fd:
-                df = reader(fd)
+                df = reader(fd, **pdread_kwargs)
 
         for key, value in partition.columns.items():
             df[key] = value
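The merge on new line 129 is plain dict unpacking, so user-supplied kwargs take precedence over the per-format defaults, and `defaultdict(dict)` means formats with no defaults (such as CSV) simply contribute an empty dict. A minimal standalone illustration of that precedence; the values are made up for the example:

# later keys win in dict unpacking, so the user's choice overrides the default
defaults = {"engine": "fastparquet"}
user_kwargs = {"engine": "pyarrow", "columns": ["c2"]}
merged = {**defaults, **user_kwargs}
assert merged == {"engine": "pyarrow", "columns": ["c2"]}

Note that the same user kwargs are applied to every detected format, since only the defaults are keyed by format.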

tests/test_pandasreader.py (new file, +37)
@@ -0,0 +1,37 @@
+import pandas as pd
+import pytest
+from pandas.testing import assert_frame_equal
+
+from fsql.api import read_partitioned_table
+from fsql.deser import InputFormat, PandasReader
+from fsql.query import Q_TRUE
+
+df1 = pd.DataFrame(data={"c1": [0, 1], "c2": ["hello", "world"]})
+
+
+def test_input_format_override(tmp_path):
+    """Test that explicitly setting format overrides suffix."""
+
+    case1_path = tmp_path / "table1"
+    case1_path.mkdir(parents=True)
+    df1.to_csv(case1_path / "f1.json", index=False)  # confuse the default by bad suffix
+
+    with pytest.raises(ValueError, match="Expected object or value"):
+        # this test condition is quite brittle! A better match would be desired
+        read_partitioned_table(f"file://{case1_path}/", Q_TRUE)
+
+    reader = PandasReader(input_format=InputFormat.CSV)
+    succ_result = read_partitioned_table(f"file://{case1_path}/", Q_TRUE, data_reader=reader)
+    assert_frame_equal(df1, succ_result)
+
+
+def test_parquet_kwargs(tmp_path):
+    """Test that a kwarg (`columns`) gets passed through and obeyed."""
+
+    case1_path = tmp_path / "table1"
+    case1_path.mkdir(parents=True)
+    df1.to_parquet(case1_path / "f1.parquet", index=False)
+
+    reader = PandasReader(columns=["c2"])
+    result = read_partitioned_table(f"file://{case1_path}/", Q_TRUE, data_reader=reader)
+    assert_frame_equal(df1[["c2"]], result)
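Assuming a standard pytest setup in the repository, the new tests can be run in isolation with something like:

pytest tests/test_pandasreader.py -q

Both tests rely on pytest's built-in tmp_path fixture, so they need no external storage.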
