|
1 |
| -from collections import namedtuple |
2 |
| -from typing import Literal |
| 1 | +from collections import OrderedDict, namedtuple |
| 2 | +from typing import List, Literal, Union |
3 | 3 |
|
4 | 4 | DEFAULT_GROUP_NAME = "default_group"
|
5 | 5 |
|
@@ -138,3 +138,53 @@ def __hash__(self):
|
138 | 138 | :return: self.name's hash
|
139 | 139 | """
|
140 | 140 | return self.name.__hash__()
|
| 141 | + |
| 142 | + |
def get_feature_names(feature_columns: List[Union[SparseFeat, DenseFeat, VarLenSparseFeat]]) -> list:
    """
    Collect the input feature names described by a list of feature columns.

    The names are the keys of the dictionary produced by ``build_input_features``,
    returned in that dictionary's iteration order.

    :param feature_columns: list of SparseFeat, DenseFeat or VarLenSparseFeat instances
    :return: list of the keys of the input feature dictionary
    :raises ValueError: if feature_columns is None or is not a list
    :raises TypeError: if any element is not a SparseFeat, DenseFeat or VarLenSparseFeat
    """
    # None gets its own message; any other non-list container is reported with its type.
    if feature_columns is None:
        raise ValueError("feature_columns is None. feature_columns must be list")
    if not isinstance(feature_columns, list):
        raise ValueError(f"feature_columns is {type(feature_columns)}, feature_columns must be list.")

    # Reject the whole list as soon as one element has an unsupported type.
    for column in feature_columns:
        if not isinstance(column, (SparseFeat, DenseFeat, VarLenSparseFeat)):
            raise TypeError(
                "All elements in feature_columns must be instances of SparseFeat, DenseFeat or VarLenSparseFeat.")

    index_map = build_input_features(feature_columns)
    return list(index_map.keys())
| 158 | + |
| 159 | + |
def build_input_features(feature_columns: List[Union[SparseFeat, DenseFeat, VarLenSparseFeat]]) -> dict:
    """
    Build an ordered mapping from feature name to its (start, end) index span.

    Spans are assigned consecutively in the order the columns are given, starting at 0:
    a SparseFeat takes 1 slot, a DenseFeat takes ``dimension`` slots and a
    VarLenSparseFeat takes ``maxlen`` slots. A VarLenSparseFeat whose ``length_name``
    is not None and not yet registered also gets one extra slot under that name,
    placed directly after the feature itself. Columns whose name is already in the
    mapping are skipped.

    :param feature_columns: list of SparseFeat, DenseFeat or VarLenSparseFeat instances
    :return: OrderedDict mapping each feature name to a (start, end) tuple
    :raises TypeError: if a column is none of the supported feature types
    """
    index_map = OrderedDict()
    offset = 0

    for column in feature_columns:
        name = column.name
        if name in index_map:
            # First occurrence wins; later duplicates do not consume slots.
            continue

        length_key = None
        if isinstance(column, SparseFeat):
            width = 1
        elif isinstance(column, DenseFeat):
            width = column.dimension
        elif isinstance(column, VarLenSparseFeat):
            width = column.maxlen
            # Only this branch may reserve an additional slot for the sequence length.
            length_key = column.length_name
        else:
            raise TypeError(f"Invalid feature column type, got {type(column)}")

        index_map[name] = (offset, offset + width)
        offset += width

        if length_key is not None and length_key not in index_map:
            index_map[length_key] = (offset, offset + 1)
            offset += 1

    return index_map
| 190 | + |
0 commit comments