Skip to content

Commit 512740d

Browse files
Remove code that is not needed in tabmat v4 / glum v3 (#741)
* Remove check_array from predict()

  We don't need it here, as predict() calls linear_predictor(), and the latter already performs this check. This avoids doing it twice.

* Remove _name_categorical_variables parts

  There is no need for those, as tabmat v4 handles variable names internally.

---------

Co-authored-by: Martin Stancsics <martin.stancsics@gmail.com>
1 parent 248c1dc commit 512740d

File tree

1 file changed

+0
-36
lines changed

1 file changed

+0
-36
lines changed

src/glum/_glm.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -228,20 +228,6 @@ def _check_offset(
228228
return offset
229229

230230

231-
def _name_categorical_variables(
232-
categories: tuple[str], column_name: str, drop_first: bool
233-
):
234-
new_names = [
235-
f"{column_name}__{category}" for category in categories[int(drop_first) :]
236-
]
237-
if len(new_names) == 0:
238-
raise ValueError(
239-
f"Categorical column: {column_name}, contains only one category. "
240-
+ "This should be dropped from the feature matrix."
241-
)
242-
return new_names
243-
244-
245231
def _parse_formula(
246232
formula: FormulaSpec, include_intercept: bool = True
247233
) -> tuple[Optional[Formula], Formula]:
@@ -1424,16 +1410,6 @@ def predict(
14241410
)
14251411
X = self._convert_from_pandas(X, context=captured_context)
14261412

1427-
X = check_array_tabmat_compliant(
1428-
X,
1429-
accept_sparse=["csr", "csc", "coo"],
1430-
dtype="numeric",
1431-
copy=self._should_copy_X(),
1432-
ensure_2d=True,
1433-
allow_nd=False,
1434-
drop_first=getattr(self, "drop_first", False),
1435-
)
1436-
14371413
eta = self.linear_predictor(
14381414
X, offset=offset, alpha_index=alpha_index, alpha=alpha
14391415
)
@@ -2718,18 +2694,6 @@ def _set_up_and_check_fit_args(
27182694
self.feature_dtypes_ = X.dtypes.to_dict()
27192695

27202696
if any(X.dtypes == "category"):
2721-
self.feature_names_ = list(
2722-
chain.from_iterable(
2723-
_name_categorical_variables(
2724-
dtype.categories,
2725-
column,
2726-
getattr(self, "drop_first", False),
2727-
)
2728-
if isinstance(dtype, pd.CategoricalDtype)
2729-
else [column]
2730-
for column, dtype in zip(X.columns, X.dtypes)
2731-
)
2732-
)
27332697

27342698
def _expand_categorical_penalties(penalty, X, drop_first):
27352699
"""

0 commit comments

Comments
 (0)