|
1 |
| -from typing import Callable, List, Literal, Union |
| 1 | +from typing import Callable, List, Literal, Tuple, Union |
2 | 2 |
|
3 | 3 | import numpy as np
|
4 | 4 | import torch
|
|
8 | 8 |
|
9 | 9 | from ..callbacks import History
|
10 | 10 | from ..inputs import (DenseFeat, SparseFeat, VarLenSparseFeat,
|
11 |
| - build_input_features, create_embedding_matrix) |
| 11 | + build_input_features, create_embedding_matrix, |
| 12 | + embedding_lookup, get_dense_inputs) |
12 | 13 | from ..layers import PredictionLayer
|
13 | 14 |
|
14 | 15 |
|
@@ -243,3 +244,34 @@ def _get_metrics(
|
243 | 244 | raise NotImplementedError(f"{metric} is not implemented")
|
244 | 245 | self.metrics_names.append(metric)
|
245 | 246 | return metrics_dict
|
| 247 | + |
| 248 | + def input_from_feature_columns( |
| 249 | + self, inputs: torch.Tensor, feature_columns: List[Union[SparseFeat, DenseFeat]] |
| 250 | + ) -> Tuple[List, List]: |
| 251 | + """ |
| 252 | + Get input data from feature columns. |
| 253 | + :param inputs: input tensor |
| 254 | + :param feature_columns: list about feature instances (SparseFeat, DenseFeat, VarLenSparseFeat) |
| 255 | + :return: sparse embedding value list and dense input value list |
| 256 | + """ |
| 257 | + |
| 258 | + sparse_feature_columns = ( |
| 259 | + list(filter(lambda x: isinstance(x, SparseFeat), feature_columns)) |
| 260 | + if len(feature_columns) |
| 261 | + else [] |
| 262 | + ) |
| 263 | + dense_feature_columns = ( |
| 264 | + list(filter(lambda x: isinstance(x, DenseFeat), feature_columns)) |
| 265 | + if len(feature_columns) |
| 266 | + else [] |
| 267 | + ) |
| 268 | + |
| 269 | + sparse_embedding_list = embedding_lookup( |
| 270 | + inputs, self.embedding_dict, self.feature_index, sparse_feature_columns |
| 271 | + ) |
| 272 | + |
| 273 | + dense_value_list = get_dense_inputs( |
| 274 | + inputs, self.feature_index, dense_feature_columns |
| 275 | + ) |
| 276 | + |
| 277 | + return sparse_embedding_list, dense_value_list |
0 commit comments