diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6cee68f2e..495169a42 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 - Added `stype_encoder_dict` to some models ([#319](https://github.com/pyg-team/pytorch-frame/pull/319))
+- Added feature importance to GBDTs ([#292](https://github.com/pyg-team/pytorch-frame/pull/292))
+
 ### Changed
 - Removed implicit clones in `StypeEncoder` ([#286](https://github.com/pyg-team/pytorch-frame/pull/286))
 
diff --git a/examples/tuned_gbdt.py b/examples/tuned_gbdt.py
index f414dc7c0..9f2b2d5f0 100644
--- a/examples/tuned_gbdt.py
+++ b/examples/tuned_gbdt.py
@@ -27,6 +27,7 @@
 import random
 
 import numpy as np
+import pandas as pd
 import torch
 
 from torch_frame.datasets import TabularBenchmark
@@ -39,6 +40,7 @@
 parser.add_argument('--dataset', type=str, default='eye_movements')
 parser.add_argument('--saved_model_path', type=str,
                     default='storage/gbdts.txt')
+parser.add_argument('--feature_importance', action='store_true')
 # Add this flag to match the reported number.
 parser.add_argument('--seed', type=int, default=0)
 args = parser.parse_args()
@@ -88,6 +90,12 @@
     gbdt.tune(tf_train=train_dataset.tensor_frame,
               tf_val=val_dataset.tensor_frame, num_trials=20)
     gbdt.save(args.saved_model_path)
+    if args.feature_importance:
+        scores = pd.DataFrame({
+            'feature': dataset.feat_cols,
+            'importance': gbdt.feature_importance()
+        }).sort_values(by='importance', ascending=False)
+        print(scores)
 
 pred = gbdt.predict(tf_test=test_dataset.tensor_frame)
 score = gbdt.compute_metric(test_dataset.tensor_frame.y, pred)
diff --git a/test/gbdt/test_gbdt.py b/test/gbdt/test_gbdt.py
index 42787976c..2da03ff31 100644
--- a/test/gbdt/test_gbdt.py
+++ b/test/gbdt/test_gbdt.py
@@ -21,7 +21,7 @@
     [stype.numerical],
     [stype.categorical],
     [stype.text_embedded],
-    [stype.numerical, stype.numerical, stype.text_embedded],
+    [stype.numerical, stype.categorical, stype.text_embedded],
 ])
 @pytest.mark.parametrize('task_type_and_metric', [
     (TaskType.REGRESSION, Metric.RMSE),
@@ -76,7 +76,21 @@ def test_gbdt_with_save_load(gbdt_cls, stypes, task_type_and_metric):
     loaded_score = loaded_gbdt.compute_metric(dataset.tensor_frame.y, pred)
     dataset.tensor_frame.y = None
     loaded_pred = loaded_gbdt.predict(tf_test=dataset.tensor_frame)
+    # TODO: support more stypes
+    feat_dim = {
+        stype.numerical: 1,
+        stype.categorical: 1,
+        stype.embedding: 8,
+    }
+    num_features = sum([
+        feat_dim[feat_stype] * len(feat_list) for feat_stype, feat_list in
+        dataset.tensor_frame.col_names_dict.items()
+    ])
+    # XGBoost omits zero-importance features, so it may report fewer scores.
+    assert (gbdt_cls == XGBoost
+            and len(gbdt.feature_importance()) <= num_features) or (len(
+                gbdt.feature_importance()) == num_features)
 
     assert torch.allclose(pred, loaded_pred, atol=1e-5)
     assert gbdt.metric == metric
     assert score == loaded_score
diff --git a/torch_frame/gbdt/gbdt.py b/torch_frame/gbdt/gbdt.py
index b2aafc5c5..307665fd0 100644
--- a/torch_frame/gbdt/gbdt.py
+++ b/torch_frame/gbdt/gbdt.py
@@ -63,6 +63,10 @@ def _predict(self, tf_train: TensorFrame) -> Tensor:
     def _load(self, path: str) -> None:
         raise NotImplementedError
 
+    @abstractmethod
+    def _feature_importance(self, *args, **kwargs) -> list:
+        raise NotImplementedError
+
     @property
     def is_fitted(self) -> bool:
         r"""Whether the GBDT is already fitted."""
@@ -135,6 +139,19 @@ def load(self, path: str) -> None:
         self._load(path)
         self._is_fitted = True
 
+    def feature_importance(self, *args, **kwargs) -> list:
+        r"""Get the GBDT's feature importance.
+
+        Returns:
+            list: The feature importance scores.
+        """
+        if not self.is_fitted:
+            raise RuntimeError(
+                f"{self.__class__.__name__} is not yet fitted. Please run "
+                f"`tune()` first before attempting to get feature importance.")
+        scores = self._feature_importance(*args, **kwargs)
+        return scores
+
     @torch.no_grad()
     def compute_metric(
         self,
diff --git a/torch_frame/gbdt/tuned_catboost.py b/torch_frame/gbdt/tuned_catboost.py
index cc2659f35..4730fff9e 100644
--- a/torch_frame/gbdt/tuned_catboost.py
+++ b/torch_frame/gbdt/tuned_catboost.py
@@ -225,3 +225,8 @@ def _load(self, path: str) -> None:
 
         self.model = catboost.CatBoost()
         self.model.load_model(path)
+
+    def _feature_importance(self) -> list:
+        # `feature_importances_` is a numpy array; convert to a plain list.
+        scores = self.model.feature_importances_
+        return scores.tolist()
diff --git a/torch_frame/gbdt/tuned_lightgbm.py b/torch_frame/gbdt/tuned_lightgbm.py
index 732ad7418..79584497a 100644
--- a/torch_frame/gbdt/tuned_lightgbm.py
+++ b/torch_frame/gbdt/tuned_lightgbm.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any
+from typing import Any, Optional
 
 import numpy as np
 import pandas as pd
@@ -226,3 +226,29 @@ def _load(self, path: str) -> None:
         import lightgbm
 
         self.model = lightgbm.Booster(model_file=path)
+
+    def _feature_importance(self, importance_type: str = 'gain',
+                            iteration: Optional[int] = None) -> list:
+        r"""Get feature importances.
+
+        Args:
+            importance_type (str): How the importance is calculated. If
+                "split", the result contains the number of times the
+                feature is used in the model. If "gain", the result
+                contains the total gain of the splits that use the
+                feature. (default: "gain")
+            iteration (int, optional): Limit the number of iterations in
+                the feature importance calculation. If None, the best
+                iteration is used if it exists; otherwise, all trees are
+                used. If <= 0, all trees are used (no limit).
+                (default: None)
+
+        Returns:
+            list: A list of feature importances.
+        """
+        assert importance_type in [
+            'split', 'gain'
+        ], f"Expected 'split' or 'gain', got '{importance_type}'."
+        scores = self.model.feature_importance(
+            importance_type=importance_type, iteration=iteration)
+        return scores.tolist()
diff --git a/torch_frame/gbdt/tuned_xgboost.py b/torch_frame/gbdt/tuned_xgboost.py
index 9b939d324..19e28199e 100644
--- a/torch_frame/gbdt/tuned_xgboost.py
+++ b/torch_frame/gbdt/tuned_xgboost.py
@@ -232,3 +232,41 @@ def _load(self, path: str) -> None:
         import xgboost
 
         self.model = xgboost.Booster(model_file=path)
+
+    def _feature_importance(self, importance_type: str = 'weight') -> list:
+        r"""Get feature importances.
+
+        Args:
+            importance_type (str): How the importance is calculated.
+                For tree models, the importance type can be one of:
+
+                * 'weight': the number of times a feature is used to
+                  split the data across all trees.
+                * 'gain': the average gain across all splits the
+                  feature is used in.
+                * 'cover': the average coverage across all splits the
+                  feature is used in.
+                * 'total_gain': the total gain across all splits the
+                  feature is used in.
+                * 'total_cover': the total coverage across all splits
+                  the feature is used in.
+
+                (default: 'weight')
+
+            .. note::
+                For linear models, only 'weight' is defined, and it is
+                the normalized coefficients without bias.
+
+            .. note::
+                Zero-importance features, i.e. features that have not
+                been used in any split condition, are not included in
+                the result.
+
+        Returns:
+            list: A list of feature importances.
+ """ + assert importance_type in [ + 'weight', 'gain', 'cover', 'total_gain', 'total_cover' + ], f'{importance_type} is not supported.' + scores = self.model.get_score(importance_type=importance_type) + return list(scores.values())