Skip to content

Commit

Permalink
add _get_default_search_spaces for xgb
Browse files Browse the repository at this point in the history
  • Loading branch information
dev-rinchin committed Jul 31, 2024
1 parent 5b5e243 commit 5deac11
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 25 deletions.
32 changes: 8 additions & 24 deletions lightautoml/ml_algo/boost_xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ..validation.base import TrainValidIterator
from .base import TabularDataset
from .base import TabularMLAlgo
from .tuning.base import Uniform
from .tuning.base import Uniform, Choice


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -188,37 +188,21 @@ def _get_default_search_spaces(self, suggested_params: Dict, estimated_n_trials:
"""
optimization_search_space = {}

optimization_search_space["colsample_bytree"] = Uniform(
low=0.5,
high=1.0,
)

optimization_search_space["max_leaves"] = Uniform(
low=16,
high=255,
q=1,
)
optimization_search_space["colsample_bytree"] = Choice(options=[0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
optimization_search_space["subsample"] = Choice(options=[0.4, 0.5, 0.6, 0.7, 0.8, 1.0])
optimization_search_space["max_depth"] = Choice(options=[5, 7, 9, 11, 13, 15, 17])
optimization_search_space["learning_rate"] = Choice(options=[0.008, 0.01, 0.012, 0.014, 0.016, 0.018, 0.02])

if estimated_n_trials > 30:
optimization_search_space["subsample"] = Uniform(
low=0.5,
high=1.0,
)
optimization_search_space["min_child_weight"] = Uniform(low=1, high=300, q=1, log=False)

optimization_search_space["min_child_weight"] = Uniform(
low=1e-3,
high=10.0,
log=True,
)

if estimated_n_trials > 100:
optimization_search_space["reg_alpha"] = Uniform(
low=1e-8,
low=1e-3,
high=10.0,
log=True,
)
optimization_search_space["reg_lambda"] = Uniform(
low=1e-8,
low=1e-3,
high=10.0,
log=True,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from sklearn.metrics import roc_auc_score

from lightautoml.automl.presets.tabular_presets import TabularAutoML
from lightautoml.tasks import Task
from tests.unit.test_automl.test_presets.presets_utils import check_pickling
from tests.unit.test_automl.test_presets.presets_utils import get_target_name

Expand Down

0 comments on commit 5deac11

Please sign in to comment.