Skip to content

Commit

Permalink
fix xgb losses and metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
dev-rinchin committed Jul 31, 2024
1 parent 0d0cf0e commit 5775cbf
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
12 changes: 6 additions & 6 deletions lightautoml/ml_algo/boost_xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ def _infer_params(
fobj = loss.fobj

# # get metric params
# params["metric"] = loss.metric_name
params["eval_metric"] = loss.metric_name
feval = loss.feval

# params["num_class"] = self.n_classes
params["num_class"] = self.n_classes
# add loss and tasks params if defined
params = {**params, **loss.fobj_params, **loss.metric_params}

Expand Down Expand Up @@ -188,24 +188,24 @@ def _get_default_search_spaces(self, suggested_params: Dict, estimated_n_trials:
"""
optimization_search_space = {}

optimization_search_space["feature_fraction"] = Uniform(
optimization_search_space["colsample_bytree"] = Uniform(
low=0.5,
high=1.0,
)

optimization_search_space["num_leaves"] = Uniform(
optimization_search_space["max_leaves"] = Uniform(
low=16,
high=255,
q=1,
)

if estimated_n_trials > 30:
optimization_search_space["bagging_fraction"] = Uniform(
optimization_search_space["subsample"] = Uniform(
low=0.5,
high=1.0,
)

optimization_search_space["min_sum_hessian_in_leaf"] = Uniform(
optimization_search_space["min_child_weight"] = Uniform(
low=1e-3,
high=10.0,
log=True,
Expand Down
10 changes: 5 additions & 5 deletions lightautoml/tasks/losses/xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@
}

_xgb_reg_metrics_dict = {
"mse": "mse", # TODO
"mse": "rmse",
"mae": "mae",
# "r2": "mse", # TODO
"r2": "rmse",
"rmsle": "rmsle",
"mape": "mape",
"quantile": "reg:quantileerror",
"huber": "reg:pseudohubererror",
# "fair": "fair",# TODO
# "quantile": "", # TODO
"huber": "mphe",
# "fair": "",# TODO
}

_xgb_multiclass_metrics_dict = {
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ scikit-learn = [
]
lightgbm = ">=2.3, <=3.2.1"
catboost = ">=0.26.1"
xgboost = "^2.0.0"
optuna = "*"
torch = [
{platform = "win32", python = "3.6.1", version = "1.7.0"},
Expand Down

0 comments on commit 5775cbf

Please sign in to comment.