Skip to content

Commit

Permalink
refactor objective select
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk committed Sep 28, 2024
1 parent d0a3987 commit ba87250
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions imlightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ def __init__(
Check http://lightgbm.readthedocs.io/en/latest/Parameters.html for more details.
"""
self.num_class = num_class
_objective_enum, _objective = self.__objective_select(objective=objective)
_objective_enum: Objective = Objective.get(objective)
self.__alpha_select(objective=_objective_enum, alpha=alpha)
self.__gamma_select(objective=_objective_enum, gamma=gamma)

_objective = self.__objective_select(objective_enum=_objective_enum)
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
Expand Down Expand Up @@ -127,10 +127,9 @@ def predict(

predict.__doc__ = LGBMClassifier.predict.__doc__

def __objective_select(self, objective: str) -> tuple[Objective, _SklearnObjLike]:
def __objective_select(self, objective_enum: Objective) -> _SklearnObjLike:
"""Select objective function."""
_objective: Objective = Objective.get(objective)
if _objective in {
if objective_enum in {
Objective.multiclass_focal,
Objective.multiclass_weighted,
} and not isinstance(self.num_class, int):
Expand All @@ -154,7 +153,7 @@ def __objective_select(self, objective: str) -> tuple[Objective, _SklearnObjLike
y_true=y_true, y_pred=y_pred, alpha=self.alpha, num_class=self.num_class
),
}
return _objective, _objective_mapper[_objective]
return _objective_mapper[objective_enum]

def __param_select(
self,
Expand Down

0 comments on commit ba87250

Please sign in to comment.