Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 71 additions & 40 deletions flaml/automl/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2724,6 +2724,58 @@ def cv_score_agg_func(val_loss_folds, log_metrics_folds):
)
self._auto_augment = auto_augment

if "auto" == estimator_list:
if self._state.task == "rank":
estimator_list = ["lgbm", "xgboost", "xgb_limitdepth"]
elif _is_nlp_task(self._state.task):
estimator_list = ["transformer"]
elif self._state.task == TS_FORECASTPANEL:
estimator_list = ["tft"]
else:
try:
import catboost

estimator_list = [
"lgbm",
"rf",
"catboost",
"xgboost",
"extra_tree",
"xgb_limitdepth",
]
except ImportError:
estimator_list = [
"lgbm",
"rf",
"xgboost",
"extra_tree",
"xgb_limitdepth",
]
if self._state.task in TS_FORECAST:
# catboost is removed because it has a `name` parameter, making it incompatible with hcrystalball
if "catboost" in estimator_list:
estimator_list.remove("catboost")
if self._state.task in TS_FORECASTREGRESSION:
try:
import prophet

estimator_list += ["prophet", "arima", "sarimax"]
except ImportError:
estimator_list += ["arima", "sarimax"]
elif "regression" != self._state.task:
estimator_list += ["lrl1"]

if isinstance(starting_points, dict):
_estimators_from_starting_points = starting_points.keys()
if not any(i in estimator_list for i in _estimators_from_starting_points):
logger.warning(
"The provided starting_points {} is removed as it does not contain relevant estimators as keys"
" and is thus NOT used. Please check the required format of starting_points.".format(
starting_points
)
)
starting_points = {}

_sample_size_from_starting_points = {}
if isinstance(starting_points, dict):
for _estimator, _point_per_estimator in starting_points.items():
Expand Down Expand Up @@ -2843,46 +2895,6 @@ def is_to_reverse_metric(metric, task):
error_metric = "customized metric"
logger.info(f"Minimizing error metric: {error_metric}")

if "auto" == estimator_list:
if self._state.task == "rank":
estimator_list = ["lgbm", "xgboost", "xgb_limitdepth"]
elif _is_nlp_task(self._state.task):
estimator_list = ["transformer"]
elif self._state.task == TS_FORECASTPANEL:
estimator_list = ["tft"]
else:
try:
import catboost

estimator_list = [
"lgbm",
"rf",
"catboost",
"xgboost",
"extra_tree",
"xgb_limitdepth",
]
except ImportError:
estimator_list = [
"lgbm",
"rf",
"xgboost",
"extra_tree",
"xgb_limitdepth",
]
if self._state.task in TS_FORECAST:
# catboost is removed because it has a `name` parameter, making it incompatible with hcrystalball
if "catboost" in estimator_list:
estimator_list.remove("catboost")
if self._state.task in TS_FORECASTREGRESSION:
try:
import prophet

estimator_list += ["prophet", "arima", "sarimax"]
except ImportError:
estimator_list += ["arima", "sarimax"]
elif "regression" != self._state.task:
estimator_list += ["lrl1"]
# When no search budget is specified
if no_budget:
max_iter = len(estimator_list)
Expand Down Expand Up @@ -3557,6 +3569,25 @@ def _search_sequential(self):
state.best_config,
self.data_size_full,
)
if getattr(self._trained_estimator, "params", {}) and getattr(
self._trained_estimator, "ITER_HP", None
):
_hp_trained_iter = self._trained_estimator.params.get(
self._trained_estimator.ITER_HP
)
_best_config_iter = self.best_config.get(
self._trained_estimator.ITER_HP
)
if _hp_trained_iter != _best_config_iter:
logger.warning(
"Early stopping happened when retraining a model with the best configuration."
f" The best config's {self._trained_estimator.ITER_HP} is {_best_config_iter}"
f" and the actual {self._trained_estimator.ITER_HP} used for retraining the model is {_hp_trained_iter}."
" This early stopping happens because flaml needs to do its best effort to"
" retrain without violating the time budget when retrain_full is set to 'budget'. "
" If this mismatch is not desired, please set retrain_full to True."
)

logger.info(
"retrain {} for {:.1f}s".format(self._best_estimator, retrain_time)
)
Expand Down
11 changes: 10 additions & 1 deletion test/automl/test_warmstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,15 @@ def test_FLAML_sample_size_in_starting_points(self):
except AssertionError:
pass


# In the following test case, the starting_points is not provided in the
# right format and thus we expect a warning for removing the provided
# starting_points when the fit function is called
automl5 = AutoML()
automl_settings["starting_points"] = automl3.best_config
automl5.fit(
X_train,
y_train,
**automl_settings,
)
if __name__ == "__main__":
unittest.main()