Commit a930c6b

[#3] Implement functions that return feature names and size information (#9)

* [feat] Implement the build_input_features() function
* [feat] Implement the get_feature_names() function
* [fix] Sort the imported modules and packages
* [fix] Correct the type hint of the feature_columns parameter in get_feature_names()
* [fix] Add exception handling for the feature_columns variable in get_feature_names()
* [fix] Add a type hint for the feature_columns variable in build_input_features()
* [fix] Rename the start variable in build_input_features() (start -> curr_features_idx)
* [fix] Change the exception message in build_input_features() to an f-string
1 parent 04e7df0 commit a930c6b

1 file changed

+52
-2
lines changed

CATS/inputs.py

Lines changed: 52 additions & 2 deletions
@@ -1,5 +1,5 @@
-from collections import namedtuple
-from typing import Literal
+from collections import OrderedDict, namedtuple
+from typing import List, Literal, Union
 
 DEFAULT_GROUP_NAME = "default_group"
 
@@ -138,3 +138,53 @@ def __hash__(self):
         :return: self.name's hash
         """
         return self.name.__hash__()
+
+
+def get_feature_names(feature_columns: List[Union[SparseFeat, DenseFeat, VarLenSparseFeat]]) -> list:
+    """
+    Get list of feature names
+    :param feature_columns: list about feature instances (SparseFeat, DenseFeat, VarLenSparseFeat)
+    :return: list about features dictionary's keys
+    """
+    if feature_columns is None:
+        raise ValueError("feature_columns is None. feature_columns must be list")
+    if not isinstance(feature_columns, list):
+        raise ValueError(f"feature_columns is {type(feature_columns)}, feature_columns must be list.")
+    if not all(isinstance(feature, (SparseFeat, DenseFeat, VarLenSparseFeat)) for feature in feature_columns):
+        raise TypeError(
+            "All elements in feature_columns must be instances of SparseFeat, DenseFeat or VarLenSparseFeat.")
+    features = build_input_features(feature_columns)
+    return list(features.keys())
+
+
+def build_input_features(feature_columns: List[Union[SparseFeat, DenseFeat, VarLenSparseFeat]]) -> dict:
+    """
+    Return an input feature dictionary based on various types of features (SparseFeat, DenseFeat, VarLenSparseFeat).
+    input feature dictionary stores the start and end indices of each feature, helping the model identify the location of
+    each feature in the input data.
+    :param feature_columns: list about feature instances (SparseFeat, DenseFeat, VarLenSparseFeat)
+    :return: dictionary about features
+    """
+    features = OrderedDict()
+
+    curr_features_idx = 0
+    for feat in feature_columns:
+        feat_name = feat.name
+        if feat_name in features:
+            continue
+        if isinstance(feat, SparseFeat):
+            features[feat_name] = (curr_features_idx, curr_features_idx + 1)
+            curr_features_idx += 1
+        elif isinstance(feat, DenseFeat):
+            features[feat_name] = (curr_features_idx, curr_features_idx + feat.dimension)
+            curr_features_idx += feat.dimension
+        elif isinstance(feat, VarLenSparseFeat):
+            features[feat_name] = (curr_features_idx, curr_features_idx + feat.maxlen)
+            curr_features_idx += feat.maxlen
+            if feat.length_name is not None and feat.length_name not in features:
+                features[feat.length_name] = (curr_features_idx, curr_features_idx+1)
+                curr_features_idx += 1
+        else:
+            raise TypeError(f"Invalid feature column type, got {type(feat)}")
+    return features
+
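The index layout that build_input_features() produces may be easier to follow with a small standalone sketch. The three namedtuples below are hypothetical stand-ins, not the real classes in CATS/inputs.py (whose constructors are not shown in this diff); they carry only the attributes the function reads (name, dimension, maxlen, length_name). The feature names and the sample row are likewise made up for illustration, and the function body is a condensed restatement of the logic in the diff rather than the committed code.

```python
from collections import OrderedDict, namedtuple

# Hypothetical stand-ins: only the attributes used by build_input_features().
SparseFeat = namedtuple("SparseFeat", ["name"])
DenseFeat = namedtuple("DenseFeat", ["name", "dimension"])
VarLenSparseFeat = namedtuple("VarLenSparseFeat", ["name", "maxlen", "length_name"])


def build_input_features(feature_columns):
    """Map each feature name to its (start, end) column span."""
    features = OrderedDict()
    idx = 0
    for feat in feature_columns:
        if feat.name in features:
            continue  # a repeated name keeps its first span
        if isinstance(feat, SparseFeat):
            width = 1
        elif isinstance(feat, DenseFeat):
            width = feat.dimension
        elif isinstance(feat, VarLenSparseFeat):
            width = feat.maxlen
        else:
            raise TypeError(f"Invalid feature column type, got {type(feat)}")
        features[feat.name] = (idx, idx + width)
        idx += width
        # A variable-length feature may reserve one extra column for its length.
        length_name = getattr(feat, "length_name", None)
        if length_name is not None and length_name not in features:
            features[length_name] = (idx, idx + 1)
            idx += 1
    return features


columns = [
    SparseFeat("user_id"),
    DenseFeat("price", 3),
    VarLenSparseFeat("history", 4, "history_len"),
    SparseFeat("user_id"),  # duplicate name, skipped
]
layout = build_input_features(columns)
feature_names = list(layout.keys())  # what get_feature_names() would return

# Each span is half-open, so it can slice a flattened input row directly.
row = list(range(10, 19))  # one made-up input row with 9 columns
start, end = layout["price"]
price_values = row[start:end]
```

Because the spans are half-open and contiguous, the end index of the last entry equals the total input width, which is 9 columns in this sketch.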
