Skip to content

Commit

Permalink
add test for xgb
Browse files Browse the repository at this point in the history
  • Loading branch information
dev-rinchin committed Jul 31, 2024
1 parent 2ad9c35 commit 5b5e243
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ license = "Apache-2.0"
homepage = "https://lightautoml.readthedocs.io/en/latest/"
repository = "https://github.com/AILab-MLTools/LightAutoML"
classifiers = [
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Operating System :: OS Independent",
"Intended Audience :: Science/Research",
"Development Status :: 3 - Alpha",
"Development Status :: 5 - Production/Stable",
"Environment :: Console",
"Natural Language :: English",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
Expand Down
28 changes: 28 additions & 0 deletions tests/unit/test_automl/test_presets/test_tabularautoml_xgb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from sklearn.metrics import roc_auc_score

from lightautoml.automl.presets.tabular_presets import TabularAutoML
from lightautoml.tasks import Task
from tests.unit.test_automl.test_presets.presets_utils import check_pickling
from tests.unit.test_automl.test_presets.presets_utils import get_target_name


class TestTabularAutoMLXGB:
def test_fit_predict(self, sampled_app_train_test, sampled_app_roles, binary_task):
# load and prepare data
train, test = sampled_app_train_test

# run automl
automl = TabularAutoML(task=binary_task, general_params={"use_algos": [["xgb"]]})
oof_predictions = automl.fit_predict(train, roles=sampled_app_roles, verbose=10)
ho_predictions = automl.predict(test)

# calculate scores
target_name = get_target_name(sampled_app_roles)
oof_score = roc_auc_score(train[target_name].values, oof_predictions.data[:, 0])
ho_score = roc_auc_score(test[target_name].values, ho_predictions.data[:, 0])

# checks
assert oof_score > 0.65
assert ho_score > 0.65

check_pickling(automl, ho_score, binary_task, test, target_name)

0 comments on commit 5b5e243

Please sign in to comment.