Skip to content

Commit

Permalink
update docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk committed Oct 4, 2024
1 parent e335e1d commit 653b51b
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 17 deletions.
6 changes: 4 additions & 2 deletions imlightgbm/docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,12 @@
Parameters
----------
objective : str
Specify the learning objective
'binary_focal', 'binary_weighted'.
Specify the learning objective.
Options are 'binary_focal', 'binary_weighted', 'multiclass focal', and 'multiclas_weighted'.
alpha: float
Parameter used with 'binary_weighted' and 'multiclass_weighted' objective.
gamma: float
Parameter used with 'binary_focal' and 'multiclass_focal' objective.
boosting_type : str, optional (default='gbdt')
'gbdt', traditional Gradient Boosting Decision Tree.
'dart', Dropouts meet Multiple Additive Regression Trees.
Expand Down
2 changes: 2 additions & 0 deletions imlightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def predict(
else:
return softmax(_predict, axis=1)

predict.__doc__ = lgb.Booster.predict.__doc__


@add_docstring("train")
def train(
Expand Down
16 changes: 1 addition & 15 deletions imlightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


class ImbalancedLGBMClassifier(LGBMClassifier):
"""Inbalanced LightGBM classifier."""
"""Imbalanced LightGBM classifier."""

@add_docstring("classifier")
def __init__(
Expand Down Expand Up @@ -48,19 +48,6 @@ def __init__(
importance_type: str = "split",
num_class: int | None = None,
) -> None:
"""Construct a gradient boosting model.
Parameters
----------
objective : str
Specify the learning objective. Options are 'binary_focal' and 'binary_weighted'.
alpha: float
For 'binary_weighted' objective
gamma: float
For 'binary_focal' objective
other parameters:
Check http://lightgbm.readthedocs.io/en/latest/Parameters.html for more details.
"""
self.num_class = num_class
_objective_enum: Objective = Objective.get(objective)
self.__alpha_select(objective=_objective_enum, alpha=alpha)
Expand Down Expand Up @@ -100,7 +87,6 @@ def predict(
validate_features: bool = False,
**kwargs: Any,
) -> np.ndarray | spmatrix | list[spmatrix]:
"""Docstring is inherited from the LGBMClassifier."""
_predict = super().predict(
X=X,
raw_score=raw_score,
Expand Down

0 comments on commit 653b51b

Please sign in to comment.