From b2aeb178da163706aa54a1f4dd03426b8b5e7177 Mon Sep 17 00:00:00 2001 From: OlivierBinette-Upstart Date: Mon, 28 Jul 2025 14:46:11 +0000 Subject: [PATCH] Initial commit --- include/xgboost/predictor.h | 2 +- src/predictor/cpu_predictor.cc | 10 +++++----- src/predictor/cpu_treeshap.cc | 23 +++++++++++++---------- src/predictor/cpu_treeshap.h | 3 ++- 4 files changed, 21 insertions(+), 17 deletions(-) diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 020e0a59d1e8..d624ee6c294b 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -158,7 +158,7 @@ class Predictor { gbm::GBTreeModel const& model, bst_tree_t tree_end = 0, std::vector const* tree_weights = nullptr, bool approximate = false, int condition = 0, - unsigned condition_feature = 0) const = 0; + unsigned condition_feature = 0, HostDeviceVector* feature_reprs = nullptr) const = 0; virtual void PredictInteractionContributions(DMatrix* dmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 18759edd8e95..febc8979f1af 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -747,7 +747,7 @@ class CPUPredictor : public Predictor { DataView batch, const MetaInfo &info, const gbm::GBTreeModel &model, const std::vector *tree_weights, std::vector> *mean_values, std::vector *feat_vecs, std::vector *contribs, - bst_tree_t ntree_limit, bool approximate, int condition, unsigned condition_feature) const { + bst_tree_t ntree_limit, bool approximate, int condition, unsigned condition_feature, std::vector *feature_reprs = nullptr) const { const int num_feature = model.learner_model_param->num_feature; const int ngroup = model.learner_model_param->num_output_group; CHECK_NE(ngroup, 0); @@ -778,7 +778,7 @@ class CPUPredictor : public Predictor { } if (!approximate) { CalculateContributions(*model.trees[j], feats, tree_mean_values, - &this_tree_contribs[0], condition, condition_feature); + &this_tree_contribs[0], condition, condition_feature, feature_reprs == nullptr ? nullptr : feature_reprs->data()); } else { model.trees[j]->CalculateContributionsApprox( feats, tree_mean_values, &this_tree_contribs[0]); @@ -950,7 +950,7 @@ class CPUPredictor : public Predictor { void PredictContribution(DMatrix *p_fmat, HostDeviceVector *out_contribs, const gbm::GBTreeModel &model, bst_tree_t ntree_limit, std::vector const *tree_weights, bool approximate, - int condition, unsigned condition_feature) const override { + int condition, unsigned condition_feature, HostDeviceVector *feature_reprs = nullptr) const override { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); CHECK(!p_fmat->Info().IsColumnSplit()) @@ -982,13 +982,13 @@ class CPUPredictor : public Predictor { for (const auto &batch : p_fmat->GetBatches(ctx_, {})) { PredictContributionKernel(GHistIndexMatrixView{batch, std::forward(acc), ft}, info, model, tree_weights, &mean_values, &feat_vecs, &contribs, - ntree_limit, approximate, condition, condition_feature); + ntree_limit, approximate, condition, condition_feature, feature_reprs == nullptr ? nullptr : &feature_reprs->HostVector()); } } else { for (const auto &batch : p_fmat->GetBatches()) { PredictContributionKernel(SparsePageView{&batch, std::forward(acc)}, info, model, tree_weights, &mean_values, &feat_vecs, &contribs, ntree_limit, - approximate, condition, condition_feature); + approximate, condition, condition_feature, feature_reprs == nullptr ? nullptr : &feature_reprs->HostVector()); } } }; diff --git a/src/predictor/cpu_treeshap.cc b/src/predictor/cpu_treeshap.cc index 64b195d78221..fec09cb9cc89 100644 --- a/src/predictor/cpu_treeshap.cc +++ b/src/predictor/cpu_treeshap.cc @@ -106,11 +106,13 @@ float UnwoundPathSum(const PathElement* unique_path, std::uint32_t unique_depth, * \param condition fix one feature to either off (-1) on (1) or not fixed (0 default) * \param condition_feature the index of the feature to fix * \param condition_fraction what fraction of the current weight matches our conditioning feature + * \param feature_reprs mapping of features to group representatives, for groupSHAP calculation */ void TreeShap(RegTree const& tree, const RegTree::FVec& feat, float* phi, bst_node_t node_index, std::uint32_t unique_depth, PathElement* parent_unique_path, float parent_zero_fraction, float parent_one_fraction, int parent_feature_index, - int condition, std::uint32_t condition_feature, float condition_fraction) { + int condition, std::uint32_t condition_feature, float condition_fraction, + std::uint32_t* feature_reprs) { const auto node = tree[node_index]; // stop if we have no weight coming down to us @@ -125,6 +127,7 @@ void TreeShap(RegTree const& tree, const RegTree::FVec& feat, float* phi, bst_no parent_feature_index); } const std::uint32_t split_index = node.SplitIndex(); + const std::uint32_t split_index_repr = feature_reprs == nullptr ? split_index : feature_reprs[split_index]; // leaf node if (node.IsLeaf()) { @@ -153,7 +156,7 @@ void TreeShap(RegTree const& tree, const RegTree::FVec& feat, float* phi, bst_no // if so we undo that split so we can redo it for this node std::uint32_t path_index = 0; for (; path_index <= unique_depth; ++path_index) { - if (static_cast(unique_path[path_index].feature_index) == split_index) break; + if (static_cast(unique_path[path_index].feature_index) == split_index_repr) break; } if (path_index != unique_depth + 1) { incoming_zero_fraction = unique_path[path_index].zero_fraction; @@ -165,28 +168,28 @@ void TreeShap(RegTree const& tree, const RegTree::FVec& feat, float* phi, bst_no // divide up the condition_fraction among the recursive calls float hot_condition_fraction = condition_fraction; float cold_condition_fraction = condition_fraction; - if (condition > 0 && split_index == condition_feature) { + if (condition > 0 && split_index_repr == condition_feature) { cold_condition_fraction = 0; unique_depth -= 1; - } else if (condition < 0 && split_index == condition_feature) { + } else if (condition < 0 && split_index_repr == condition_feature) { hot_condition_fraction *= hot_zero_fraction; cold_condition_fraction *= cold_zero_fraction; unique_depth -= 1; } TreeShap(tree, feat, phi, hot_index, unique_depth + 1, unique_path, - hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction, split_index, - condition, condition_feature, hot_condition_fraction); + hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction, split_index_repr, + condition, condition_feature, hot_condition_fraction, feature_reprs); TreeShap(tree, feat, phi, cold_index, unique_depth + 1, unique_path, - cold_zero_fraction * incoming_zero_fraction, 0, split_index, condition, - condition_feature, cold_condition_fraction); + cold_zero_fraction * incoming_zero_fraction, 0, split_index_repr, condition, + condition_feature, cold_condition_fraction, feature_reprs); } } void CalculateContributions(RegTree const& tree, const RegTree::FVec& feat, std::vector* mean_values, float* out_contribs, int condition, - std::uint32_t condition_feature) { + std::uint32_t condition_feature, std::uint32_t* feature_reprs) { // find the expected value of the tree's predictions if (condition == 0) { float node_value = (*mean_values)[0]; @@ -198,6 +201,6 @@ void CalculateContributions(RegTree const& tree, const RegTree::FVec& feat, std::vector unique_path_data((maxd * (maxd + 1)) / 2); TreeShap(tree, feat, out_contribs, 0, 0, unique_path_data.data(), 1, 1, -1, condition, - condition_feature, 1); + condition_feature, 1, feature_reprs); } } // namespace xgboost diff --git a/src/predictor/cpu_treeshap.h b/src/predictor/cpu_treeshap.h index 3cdbcc4a998e..c8f60a56a19b 100644 --- a/src/predictor/cpu_treeshap.h +++ b/src/predictor/cpu_treeshap.h @@ -14,9 +14,10 @@ namespace xgboost { * \param out_contribs output vector to hold the contributions * \param condition fix one feature to either off (-1) on (1) or not fixed (0 default) * \param condition_feature the index of the feature to fix + * \param feature_reprs mapping of features to group representatives, for groupSHAP calculation */ void CalculateContributions(RegTree const &tree, const RegTree::FVec &feat, std::vector *mean_values, bst_float *out_contribs, int condition, - unsigned condition_feature); + unsigned condition_feature, std::uint32_t* feature_reprs); } // namespace xgboost #endif // XGBOOST_PREDICTOR_CPU_TREESHAP_H_