Skip to content

Commit d850bb8

Browse files
committed
[feat] BaseModel.input_from_feature_columns() 메서드 구현
1 parent 9ea8a30 commit d850bb8

File tree

1 file changed

+34
-2
lines changed

1 file changed

+34
-2
lines changed

CATS/models/basemodel.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, List, Literal, Union
1+
from typing import Callable, List, Literal, Tuple, Union
22

33
import numpy as np
44
import torch
@@ -8,7 +8,8 @@
88

99
from ..callbacks import History
1010
from ..inputs import (DenseFeat, SparseFeat, VarLenSparseFeat,
11-
build_input_features, create_embedding_matrix)
11+
build_input_features, create_embedding_matrix,
12+
embedding_lookup, get_dense_inputs)
1213
from ..layers import PredictionLayer
1314

1415

@@ -243,3 +244,34 @@ def _get_metrics(
243244
raise NotImplementedError(f"{metric} is not implemented")
244245
self.metrics_names.append(metric)
245246
return metrics_dict
247+
248+
def input_from_feature_columns(
249+
self, inputs: torch.Tensor, feature_columns: List[Union[SparseFeat, DenseFeat]]
250+
) -> Tuple[List, List]:
251+
"""
252+
Get input data from feature columns.
253+
:param inputs: input tensor
254+
:param feature_columns: list about feature instances (SparseFeat, DenseFeat, VarLenSparseFeat)
255+
:return: sparse embedding value list and dense input value list
256+
"""
257+
258+
sparse_feature_columns = (
259+
list(filter(lambda x: isinstance(x, SparseFeat), feature_columns))
260+
if len(feature_columns)
261+
else []
262+
)
263+
dense_feature_columns = (
264+
list(filter(lambda x: isinstance(x, DenseFeat), feature_columns))
265+
if len(feature_columns)
266+
else []
267+
)
268+
269+
sparse_embedding_list = embedding_lookup(
270+
inputs, self.embedding_dict, self.feature_index, sparse_feature_columns
271+
)
272+
273+
dense_value_list = get_dense_inputs(
274+
inputs, self.feature_index, dense_feature_columns
275+
)
276+
277+
return sparse_embedding_list, dense_value_list

0 commit comments

Comments
 (0)