Skip to content

Commit 2c217dd

Browse files
Merge pull request #225 from matthewwardrop/pandas_dict_recarray
2 parents 05bfa25 + c92cc83 commit 2c217dd

File tree

3 files changed

+36
-1
lines changed

3 files changed

+36
-1
lines changed

formulaic/materializers/pandas.py

Lines changed: 17 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
import functools
44
import itertools
5+
from collections.abc import Mapping
56
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Set, Tuple, cast
67

78
import numpy
@@ -22,9 +23,24 @@
2223

2324
class PandasMaterializer(FormulaMaterializer):
2425
REGISTER_NAME = "pandas"
25-
REGISTER_INPUTS: Sequence[str] = ("pandas.core.frame.DataFrame", "pandas.DataFrame")
26+
REGISTER_INPUTS: Sequence[str] = (
27+
"pandas.core.frame.DataFrame",
28+
"pandas.DataFrame",
29+
"dict",
30+
"numpy.rec.recarray",
31+
)
2632
REGISTER_OUTPUTS: Sequence[str] = ("pandas", "numpy", "sparse")
2733

34+
@override
def _init(self) -> None:
    """Coerce supported non-DataFrame inputs in ``self.data`` to a DataFrame.

    Mappings (including plain ``dict``) and ``numpy.rec.recarray`` inputs
    are replaced in place by an equivalent ``pandas.DataFrame``; any other
    value of ``self.data`` is left untouched.
    """
    data = self.data
    if isinstance(data, Mapping):
        # ``dict`` is itself a ``Mapping``, so a single check covers both.
        # Copy into a plain dict so pandas always takes its dict-of-columns
        # path, whatever the concrete mapping type is.
        columns = dict(data)
        if all(numpy.isscalar(value) for value in columns.values()):
            # pandas refuses all-scalar columns without an explicit index;
            # treat the mapping as a single row.
            self.data = pandas.DataFrame(columns, index=[0])
        else:
            self.data = pandas.DataFrame(columns)
    elif isinstance(data, numpy.rec.recarray):
        self.data = pandas.DataFrame.from_records(data)
43+
2844
@override
2945
def _is_categorical(self, values: Any) -> bool:
3046
if isinstance(values, (pandas.Series, pandas.Categorical)):

tests/materializers/test_base.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,8 @@ class TestFormulaMaterializer:
1818
def test_registrations(self):
1919
assert sorted(FormulaMaterializer.REGISTERED_NAMES) == ["arrow", "pandas"]
2020
assert sorted(FormulaMaterializer.REGISTERED_INPUTS) == [
21+
"dict",
22+
"numpy.rec.recarray",
2123
"pandas.DataFrame",
2224
"pandas.core.frame.DataFrame",
2325
"pyarrow.lib.Table",

tests/materializers/test_pandas.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -117,6 +117,23 @@ def data_with_nulls(self):
117117
def materializer(self, data):
118118
return PandasMaterializer(data)
119119

120+
def test_data_conversion(self):
    """Check that dict and recarray inputs are exposed as DataFrames."""
    # Mapping of column name -> sequence of values.
    df = PandasMaterializer({"a": [1, 2, 3]}).data
    assert isinstance(df, pandas.DataFrame)
    # Compare as lists: `Index == list` is elementwise and yields an array,
    # whose truth value is only defined for a single column.
    assert list(df.columns) == ["a"]
    assert list(df["a"]) == [1, 2, 3]

    # Mapping of column name -> scalar becomes a single-row frame.
    df2 = PandasMaterializer({"a": 1}).data
    assert isinstance(df2, pandas.DataFrame)
    assert list(df2.columns) == ["a"]
    assert list(df2["a"]) == [1]

    # Record arrays keep their field names as columns. The contents are
    # uninitialised, so only the shape is checked.
    df3 = PandasMaterializer(
        numpy.recarray((2,), dtype=[("x", int), ("y", float), ("z", int)])
    ).data
    assert isinstance(df3, pandas.DataFrame)
    assert list(df3.columns) == ["x", "y", "z"]
    assert len(df3["x"]) == 2
136+
120137
@pytest.mark.parametrize("formula,tests", PANDAS_TESTS.items())
121138
def test_get_model_matrix(self, materializer, formula, tests):
122139
mm = materializer.get_model_matrix(formula, ensure_full_rank=True)

0 commit comments

Comments
 (0)