Skip to content

Commit 8230d6e

Browse files
Merge pull request #220 from matthewwardrop/improve_contrast_prefixes
2 parents cb88cc2 + 6e562d3 commit 8230d6e

File tree

6 files changed

+140
-65
lines changed

6 files changed

+140
-65
lines changed

formulaic/materializers/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,7 @@ def wrapped(
758758
encoded = FactorValues(
759759
encoded.copy(),
760760
metadata=encoded.__formulaic_metadata__, # type: ignore
761+
reduced=True,
761762
)
762763
del encoded[encoded.__formulaic_metadata__.drop_field]
763764

@@ -782,7 +783,7 @@ def _flatten_encoded_evaled_factor(
782783
# Some nested dictionaries may not be a `FactorValues[dict]` instance,
783784
# in which case we impute the default formatter in `FactorValues.format`.
784785
if hasattr(values, "__formulaic_metadata__"):
785-
name_format = values.__formulaic_metadata__.format
786+
name_format = values.__formulaic_metadata__.get_format()
786787
else:
787788
name_format = FactorValuesMetadata.format
788789

formulaic/materializers/types/factor_values.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,6 @@ class FactorValuesMetadata:
4343
4444
Attributes:
4545
kind: The kind of the evaluated values.
46-
spans_intercept: Whether the values span the intercept or not.
47-
drop_field: If the values do span the intercept, and we want to reduce
48-
the rank, which field should be dropped.
4946
format: The format to use when exploding factors into multiple columns
5047
(e.g. when encoding categories via dummy-encoding).
5148
encoded: Whether the values should be treated as pre-encoded.
@@ -55,18 +52,37 @@ class FactorValuesMetadata:
5552
materializer. Note that this should only be used in cases where
5653
direct evaluation would yield different results in reduced vs.
5754
non-reduced rank scenarios.
55+
56+
Rank-Reduction Attributes:
57+
spans_intercept: Whether the values span the intercept or not.
58+
drop_field: If the values do span the intercept, and we want to reduce
59+
the rank, which field should be dropped.
60+
reduced: Whether the rank has already been reduced by dropping the
61+
`drop_field` above.
62+
format_reduced: The format to use when exploding factors (as above), but
63+
in the case where the rank has been reduced by dropping a field.
64+
(This defaults to `format`.)
5865
"""
5966

6067
kind: Factor.Kind = Factor.Kind.UNKNOWN
6168
column_names: Optional[Tuple[str]] = None
62-
spans_intercept: bool = False
63-
drop_field: Optional[str] = None
6469
format: str = "{name}[{field}]"
6570
encoded: bool = False
6671
encoder: Optional[
6772
Callable[[Any, bool, List[int], Dict[str, Any], ModelSpec], Any]
6873
] = None
6974

75+
# Rank-Reduction Attributes
76+
spans_intercept: bool = False
77+
drop_field: Optional[str] = None
78+
reduced: bool = False
79+
format_reduced: Optional[str] = None
80+
81+
def get_format(self) -> str:
82+
return (
83+
self.format_reduced if self.reduced and self.format_reduced else self.format
84+
)
85+
7086
def replace(self, **kwargs: Any) -> FactorValuesMetadata:
7187
"""
7288
Return a copy of this `FactorValuesMetadata` instance with the nominated
@@ -91,25 +107,29 @@ def __init__(
91107
*,
92108
kind: Union[str, Factor.Kind, _MissingType] = MISSING,
93109
column_names: Union[Tuple[Hashable, ...], _MissingType] = MISSING,
94-
spans_intercept: Union[bool, _MissingType] = MISSING,
95-
drop_field: Union[None, Hashable, _MissingType] = MISSING,
96110
format: Union[str, _MissingType] = MISSING, # pylint: disable=redefined-builtin
97111
encoded: Union[bool, _MissingType] = MISSING,
98112
encoder: Union[
99113
None,
100114
Callable[[Any, bool, List[int], Dict[str, Any], ModelSpec], Any],
101115
_MissingType,
102116
] = MISSING,
117+
spans_intercept: Union[bool, _MissingType] = MISSING,
118+
drop_field: Union[None, Hashable, _MissingType] = MISSING,
119+
reduced: Union[bool, _MissingType] = MISSING,
120+
format_reduced: Union[str, _MissingType] = MISSING,
103121
):
104122
metadata_constructor: Callable = FactorValuesMetadata
105123
metadata_kwargs = dict(
106124
kind=Factor.Kind(kind) if kind is not MISSING else kind,
107125
column_names=column_names,
108-
spans_intercept=spans_intercept,
109-
drop_field=drop_field,
110126
format=format,
111127
encoded=encoded,
112128
encoder=encoder,
129+
spans_intercept=spans_intercept,
130+
drop_field=drop_field,
131+
reduced=reduced,
132+
format_reduced=format_reduced,
113133
)
114134
for key in set(metadata_kwargs):
115135
if metadata_kwargs[key] is MISSING:

formulaic/transforms/contrasts.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ class Contrasts(metaclass=InterfaceMeta):
208208
INTERFACE_RAISE_ON_VIOLATION = True
209209

210210
FACTOR_FORMAT = "{name}[{field}]"
211+
FACTOR_FORMAT_REDUCED = "{name}[{field}]"
211212

212213
def apply(
213214
self,
@@ -251,9 +252,11 @@ def apply(
251252
if not levels or len(levels) == 1 and reduced_rank:
252253
if output == "pandas":
253254
encoded = pandas.DataFrame(
254-
index=dummies.index
255-
if isinstance(dummies, pandas.DataFrame)
256-
else range(dummies.shape[0])
255+
index=(
256+
dummies.index
257+
if isinstance(dummies, pandas.DataFrame)
258+
else range(dummies.shape[0])
259+
)
257260
)
258261
elif output == "numpy":
259262
encoded = numpy.ones((dummies.shape[0], 0))
@@ -269,6 +272,7 @@ def apply(
269272
column_names=cast(Tuple[Hashable], ()),
270273
spans_intercept=False,
271274
format=self.get_factor_format(levels, reduced_rank=reduced_rank),
275+
format_reduced=self.get_factor_format(levels, reduced_rank=True),
272276
encoded=True,
273277
)
274278

@@ -295,6 +299,7 @@ def apply(
295299
spans_intercept=self.get_spans_intercept(levels, reduced_rank=reduced_rank),
296300
drop_field=self.get_drop_field(levels, reduced_rank=reduced_rank),
297301
format=self.get_factor_format(levels, reduced_rank=reduced_rank),
302+
format_reduced=self.get_factor_format(levels, reduced_rank=True),
298303
encoded=True,
299304
)
300305

@@ -480,7 +485,7 @@ def get_factor_format(
480485
levels: The names of the levels/categories in the data.
481486
reduced_rank: Whether the contrast encoding used had reduced rank.
482487
"""
483-
return self.FACTOR_FORMAT
488+
return self.FACTOR_FORMAT_REDUCED if reduced_rank else self.FACTOR_FORMAT
484489

485490

486491
@dataclass
@@ -493,7 +498,7 @@ class TreatmentContrasts(Contrasts):
493498
is taken to be the first level.
494499
"""
495500

496-
FACTOR_FORMAT = "{name}[T.{field}]"
501+
FACTOR_FORMAT_REDUCED = "{name}[T.{field}]"
497502

498503
base: Hashable = UNSET
499504

@@ -609,7 +614,7 @@ class SumContrasts(Contrasts):
609614
(except the last, which is redundant) to the global average of all levels.
610615
"""
611616

612-
FACTOR_FORMAT = "{name}[S.{field}]"
617+
FACTOR_FORMAT_REDUCED = "{name}[S.{field}]"
613618

614619
@Contrasts.override
615620
def _get_coding_matrix(
@@ -659,7 +664,7 @@ class HelmertContrasts(Contrasts):
659664
integer one).
660665
"""
661666

662-
FACTOR_FORMAT = "{name}[H.{field}]"
667+
FACTOR_FORMAT_REDUCED = "{name}[H.{field}]"
663668

664669
reverse: bool = True
665670
scale: bool = False
@@ -729,7 +734,7 @@ class DiffContrasts(Contrasts):
729734
Level 1 cf. Level 1 - Level 2).
730735
"""
731736

732-
FACTOR_FORMAT = "{name}[D.{field}]"
737+
FACTOR_FORMAT_REDUCED = "{name}[D.{field}]"
733738

734739
backward: bool = True
735740

@@ -792,7 +797,6 @@ class PolyContrasts(Contrasts):
792797
have the same cardinality as the categories being coded.
793798
"""
794799

795-
FACTOR_FORMAT = "{name}{field}"
796800
NAME_ALIASES = {
797801
1: ".L",
798802
2: ".Q",

tests/materializers/test_arrow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ def check_for_pyarrow():
1919
"a": (["Intercept", "a"], ["Intercept", "a"]),
2020
"A": (
2121
["Intercept", "A[T.b]", "A[T.c]"],
22-
["Intercept", "A[T.a]", "A[T.b]", "A[T.c]"],
22+
["Intercept", "A[a]", "A[b]", "A[c]"],
2323
),
2424
"C(A)": (
2525
["Intercept", "C(A)[T.b]", "C(A)[T.c]"],
26-
["Intercept", "C(A)[T.a]", "C(A)[T.b]", "C(A)[T.c]"],
26+
["Intercept", "C(A)[a]", "C(A)[b]", "C(A)[c]"],
2727
),
2828
"a:A": (
29-
["Intercept", "a:A[T.a]", "a:A[T.b]", "a:A[T.c]"],
30-
["Intercept", "a:A[T.a]", "a:A[T.b]", "a:A[T.c]"],
29+
["Intercept", "a:A[a]", "a:A[b]", "a:A[c]"],
30+
["Intercept", "a:A[a]", "a:A[b]", "a:A[c]"],
3131
),
3232
}
3333

tests/materializers/test_pandas.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -24,45 +24,45 @@
2424
"a": (["Intercept", "a"], ["Intercept", "a"], ["Intercept", "a"], 2),
2525
"A": (
2626
["Intercept", "A[T.b]", "A[T.c]"],
27-
["Intercept", "A[T.a]", "A[T.b]", "A[T.c]"],
27+
["Intercept", "A[a]", "A[b]", "A[c]"],
2828
["Intercept", "A[T.c]"],
2929
2,
3030
),
3131
"C(A)": (
3232
["Intercept", "C(A)[T.b]", "C(A)[T.c]"],
33-
["Intercept", "C(A)[T.a]", "C(A)[T.b]", "C(A)[T.c]"],
33+
["Intercept", "C(A)[a]", "C(A)[b]", "C(A)[c]"],
3434
["Intercept", "C(A)[T.c]"],
3535
2,
3636
),
3737
"A:a": (
38-
["Intercept", "A[T.a]:a", "A[T.b]:a", "A[T.c]:a"],
39-
["Intercept", "A[T.a]:a", "A[T.b]:a", "A[T.c]:a"],
40-
["Intercept", "A[T.a]:a"],
38+
["Intercept", "A[a]:a", "A[b]:a", "A[c]:a"],
39+
["Intercept", "A[a]:a", "A[b]:a", "A[c]:a"],
40+
["Intercept", "A[a]:a"],
4141
1,
4242
),
4343
"A:B": (
4444
[
4545
"Intercept",
4646
"B[T.b]",
4747
"B[T.c]",
48-
"A[T.b]:B[T.a]",
49-
"A[T.c]:B[T.a]",
50-
"A[T.b]:B[T.b]",
51-
"A[T.c]:B[T.b]",
52-
"A[T.b]:B[T.c]",
53-
"A[T.c]:B[T.c]",
48+
"A[T.b]:B[a]",
49+
"A[T.c]:B[a]",
50+
"A[T.b]:B[b]",
51+
"A[T.c]:B[b]",
52+
"A[T.b]:B[c]",
53+
"A[T.c]:B[c]",
5454
],
5555
[
5656
"Intercept",
57-
"A[T.a]:B[T.a]",
58-
"A[T.b]:B[T.a]",
59-
"A[T.c]:B[T.a]",
60-
"A[T.a]:B[T.b]",
61-
"A[T.b]:B[T.b]",
62-
"A[T.c]:B[T.b]",
63-
"A[T.a]:B[T.c]",
64-
"A[T.b]:B[T.c]",
65-
"A[T.c]:B[T.c]",
57+
"A[a]:B[a]",
58+
"A[b]:B[a]",
59+
"A[c]:B[a]",
60+
"A[a]:B[b]",
61+
"A[b]:B[b]",
62+
"A[c]:B[b]",
63+
"A[a]:B[c]",
64+
"A[b]:B[c]",
65+
"A[c]:B[c]",
6666
],
6767
["Intercept"],
6868
1,
@@ -324,7 +324,7 @@ def test_encoding_edge_cases(self, materializer):
324324
spec=ModelSpec(formula=[]),
325325
drop_rows=[],
326326
)
327-
) == ["B[a][T.a]", "B[a][T.b]", "B[a][T.c]"]
327+
) == ["B[a][a]", "B[a][b]", "B[a][c]"]
328328

329329
def test_empty(self, materializer):
330330
mm = materializer.get_model_matrix("0", ensure_full_rank=True)
@@ -366,27 +366,27 @@ def test_category_reordering(self):
366366
)
367367

368368
m = PandasMaterializer(data).get_model_matrix("A + 0", ensure_full_rank=False)
369-
assert list(m.columns) == ["A[T.a]", "A[T.b]", "A[T.c]"]
369+
assert list(m.columns) == ["A[a]", "A[b]", "A[c]"]
370370
assert list(m.model_spec.get_model_matrix(data3).columns) == [
371-
"A[T.a]",
372-
"A[T.b]",
373-
"A[T.c]",
371+
"A[a]",
372+
"A[b]",
373+
"A[c]",
374374
]
375375

376376
m2 = PandasMaterializer(data2).get_model_matrix("A + 0", ensure_full_rank=False)
377-
assert list(m2.columns) == ["A[T.a]", "A[T.b]", "A[T.c]"]
377+
assert list(m2.columns) == ["A[a]", "A[b]", "A[c]"]
378378
assert list(m2.model_spec.get_model_matrix(data3).columns) == [
379-
"A[T.a]",
380-
"A[T.b]",
381-
"A[T.c]",
379+
"A[a]",
380+
"A[b]",
381+
"A[c]",
382382
]
383383

384384
m3 = PandasMaterializer(data3).get_model_matrix("A + 0", ensure_full_rank=False)
385-
assert list(m3.columns) == ["A[T.c]", "A[T.b]", "A[T.a]"]
385+
assert list(m3.columns) == ["A[c]", "A[b]", "A[a]"]
386386
assert list(m3.model_spec.get_model_matrix(data).columns) == [
387-
"A[T.c]",
388-
"A[T.b]",
389-
"A[T.a]",
387+
"A[c]",
388+
"A[b]",
389+
"A[a]",
390390
]
391391

392392
def test_term_clustering(self, materializer):

0 commit comments

Comments
 (0)