[Feature] add optuna based optimize (#6)
RektPunk authored Sep 18, 2024
1 parent 4e19e88 commit 53334a2
Showing 5 changed files with 481 additions and 3 deletions.
2 changes: 1 addition & 1 deletion imlightgbm/__init__.py
@@ -1,2 +1,2 @@
 # ruff: noqa
-from imlightgbm.engine import cv, train
+from imlightgbm.engine import cv, optimize, train
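
For orientation, here is a minimal usage sketch of the new export. The toy dataset is illustrative only (it is not part of this commit), and it assumes the imbalanced binary-classification setup the package targets:

import lightgbm as lgb
from sklearn.datasets import make_classification

from imlightgbm import optimize

# Illustrative imbalanced binary dataset; any lgb.Dataset should work here.
X, y = make_classification(n_samples=500, weights=[0.9, 0.1], random_state=42)
train_set = lgb.Dataset(X, label=y)

# Run 5 Optuna trials of 50 boosting rounds each and inspect the best trial.
study = optimize(train_set, num_trials=5, num_boost_round=50)
print(study.best_params)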
67 changes: 66 additions & 1 deletion imlightgbm/engine.py
@@ -3,10 +3,11 @@

 import lightgbm as lgb
 import numpy as np
+import optuna
 from sklearn.model_selection import BaseCrossValidator
 
 from imlightgbm.objective import set_params
-from imlightgbm.utils import docstring
+from imlightgbm.utils import docstring, optimize_doc
 
 
 @docstring(lgb.train.__doc__)
@@ -79,3 +80,67 @@ def cv(
         eval_train_metric=eval_train_metric,
         return_cvbooster=return_cvbooster,
     )
+
+
+def get_params(trial: optuna.Trial):
+    return {
+        "alpha": trial.suggest_float("alpha", 0.25, 0.75),
+        "gamma": trial.suggest_float("gamma", 0.0, 3.0),
+        "num_leaves": trial.suggest_int("num_leaves", 20, 150),
+        "learning_rate": trial.suggest_float("learning_rate", 0.005, 0.1),
+        "feature_fraction": trial.suggest_float("feature_fraction", 0.5, 1.0),
+        "bagging_fraction": trial.suggest_float("bagging_fraction", 0.5, 1.0),
+        "bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
+    }
+
+
+def optimize(
+    train_set: lgb.Dataset,
+    num_trials: int = 10,
+    num_boost_round: int = 100,
+    folds: Iterable[tuple[np.ndarray, np.ndarray]] | BaseCrossValidator | None = None,
+    nfold: int = 5,
+    stratified: bool = True,
+    shuffle: bool = True,
+    get_params: Callable[[optuna.Trial], dict[str, Any]] = get_params,
+    init_model: str | lgb.Path | lgb.Booster | None = None,
+    feature_name: list[str] | Literal["auto"] = "auto",
+    categorical_feature: list[str] | list[int] | Literal["auto"] = "auto",
+    fpreproc: Callable[
+        [lgb.Dataset, lgb.Dataset, dict[str, Any]],
+        tuple[lgb.Dataset, lgb.Dataset, dict[str, Any]],
+    ]
+    | None = None,
+    seed: int = 0,
+    callbacks: list[Callable] | None = None,
+) -> optuna.Study:
+    def _objective(trial: optuna.Trial):
+        """Optuna objective function."""
+        params = get_params(trial)
+        cv_results = cv(
+            params=params,
+            train_set=train_set,
+            num_boost_round=num_boost_round,
+            folds=folds,
+            nfold=nfold,
+            stratified=stratified,
+            shuffle=shuffle,
+            init_model=init_model,
+            feature_name=feature_name,
+            categorical_feature=categorical_feature,
+            fpreproc=fpreproc,
+            seed=seed,
+            callbacks=callbacks,
+            eval_train_metric=False,
+            return_cvbooster=False,
+        )
+        _keys = [_ for _ in cv_results.keys() if _.endswith("mean")]
+        assert len(_keys) == 1
+        return min(cv_results[_keys[0]])
+
+    study = optuna.create_study(direction="minimize")
+    study.optimize(_objective, n_trials=num_trials)
+    return study
+
+
+optimize.__doc__ = optimize_doc
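
Because ``get_params`` is an ordinary argument, the default search space above can be swapped out without touching the engine. A sketch with a hypothetical narrower space (the ranges here are arbitrary choices for illustration, not recommendations from this commit):

import optuna

def my_params(trial: optuna.Trial) -> dict:
    # Hypothetical narrower search space; ranges are illustrative only.
    return {
        "alpha": trial.suggest_float("alpha", 0.4, 0.6),
        "gamma": trial.suggest_float("gamma", 0.5, 2.0),
        "num_leaves": trial.suggest_int("num_leaves", 31, 63),
    }

study = optimize(train_set, num_trials=10, get_params=my_params)

Note the design choice in ``_objective``: it assumes ``cv`` reports exactly one aggregated ``*-mean`` metric (the assert enforces this) and returns that metric's minimum across boosting rounds, so the study always minimizes a single cross-validated score.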
64 changes: 64 additions & 0 deletions imlightgbm/utils.py
@@ -30,3 +30,67 @@ def init_logger() -> logging.Logger:


 logger = init_logger()
+
+
+optimize_doc = """Perform hyperparameter optimization with Optuna, scored by cross-validation.
+
+    Parameters
+    ----------
+    train_set : Dataset
+        Data to be trained on.
+    num_trials : int, optional (default=10)
+        Number of hyperparameter search trials.
+    num_boost_round : int, optional (default=100)
+        Number of boosting iterations.
+    folds : generator or iterator of (train_idx, test_idx) tuples, scikit-learn splitter object or None, optional (default=None)
+        If generator or iterator, it should yield the train and test indices for each fold.
+        If object, it should be one of the scikit-learn splitter classes
+        (https://scikit-learn.org/stable/modules/classes.html#splitter-classes)
+        and have ``split`` method.
+        This argument has highest priority over other data split arguments.
+    nfold : int, optional (default=5)
+        Number of folds in CV.
+    stratified : bool, optional (default=True)
+        Whether to perform stratified sampling.
+    shuffle : bool, optional (default=True)
+        Whether to shuffle before splitting data.
+    get_params : callable, optional (default=get_params)
+        Function that takes an ``optuna.Trial`` and returns the parameter
+        dictionary to evaluate, e.g.:
+
+        def get_params(trial: optuna.Trial):
+            return {
+                "alpha": trial.suggest_float("alpha", 0.25, 0.75),
+                "gamma": trial.suggest_float("gamma", 0.0, 3.0),
+                "num_leaves": trial.suggest_int("num_leaves", 20, 150),
+                "learning_rate": trial.suggest_float("learning_rate", 0.005, 0.1),
+                "feature_fraction": trial.suggest_float("feature_fraction", 0.5, 1.0),
+                "bagging_fraction": trial.suggest_float("bagging_fraction", 0.5, 1.0),
+                "bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
+            }
+    init_model : str, pathlib.Path, Booster or None, optional (default=None)
+        Filename of LightGBM model or Booster instance used for continue training.
+    feature_name : list of str, or 'auto', optional (default="auto")
+        **Deprecated.** Set ``feature_name`` on ``train_set`` instead.
+        Feature names.
+        If 'auto' and data is pandas DataFrame, data columns names are used.
+    categorical_feature : list of str or int, or 'auto', optional (default="auto")
+        **Deprecated.** Set ``categorical_feature`` on ``train_set`` instead.
+        Categorical features.
+        If list of int, interpreted as indices.
+        If list of str, interpreted as feature names (need to specify ``feature_name`` as well).
+        If 'auto' and data is pandas DataFrame, pandas unordered categorical columns are used.
+        All values in categorical features will be cast to int32 and thus should be less than int32 max value (2147483647).
+        Large values could be memory consuming. Consider using consecutive integers starting from zero.
+        All negative values in categorical features will be treated as missing values.
+        The output cannot be monotonically constrained with respect to a categorical feature.
+        Floating point numbers in categorical features will be rounded towards 0.
+    fpreproc : callable or None, optional (default=None)
+        Preprocessing function that takes (dtrain, dtest, params)
+        and returns transformed versions of those.
+    seed : int, optional (default=0)
+        Seed used to generate the folds (passed to numpy.random.seed).
+    callbacks : list of callable, or None, optional (default=None)
+        List of callback functions that are applied at each iteration.
+        See Callbacks in Python API for more information.
+
+    Returns
+    -------
+    study : optuna.Study
+        The study object with the finished trials.
+    """