adapt embedding for forecasting tasks
dengdifan committed Aug 22, 2022
1 parent 9d62c2b commit 5f9713a
Showing 9 changed files with 103 additions and 47 deletions.
@@ -24,6 +24,7 @@ def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = N
self.add_fit_requirements([
FitRequirement('numerical_features', (List,), user_defined=True, dataset_property=True),
FitRequirement('categorical_features', (List,), user_defined=True, dataset_property=True)])
self.output_feature_order = None

def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
"""
@@ -74,6 +75,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
X_train = X['backend'].load_datamanager().train_tensors[0]

self.preprocessor.fit(X_train)
self.output_feature_order = self.get_output_column_orders(len(X['dataset_properties']['feature_names']))
return self

def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
@@ -86,7 +88,8 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
Returns:
X (Dict[str, Any]): updated fit dictionary
"""
X.update({'time_series_feature_transformer': self})
X.update({'time_series_feature_transformer': self,
'feature_order_after_preprocessing': self.output_feature_order})
return X

def __call__(self, X: pd.DataFrame) -> pd.DataFrame:
@@ -108,6 +111,33 @@ def get_column_transformer(self) -> ColumnTransformer:
.format(self.__class__.__name__))
return self.preprocessor

def get_output_column_orders(self, n_input_columns: int) -> List[int]:
"""
get the order of the output features transformed by self.preprocessor
TODO: replace this function with self.preprocessor.get_feature_names_out() when switch to sklearn 1.0 !
Args:
n_input_columns (int): number of input columns that will be transformed
Returns:
np.ndarray: a list of index indicating the order of each columns after transformation. Its length should
equal to n_input_columns
"""
if self.preprocessor is None:
raise ValueError("cant call {} without fitting the column transformer first."
.format(self.__class__.__name__))
transformers = self.preprocessor.transformers

n_reordered_input = np.arange(n_input_columns)
processed_columns = np.asarray([], dtype=int)

for tran in transformers:
trans_columns = np.array(tran[-1], dtype=int)
unprocessed_columns = np.setdiff1d(processed_columns, trans_columns)
processed_columns = np.hstack([unprocessed_columns, trans_columns])
unprocessed_columns = np.setdiff1d(n_reordered_input, processed_columns)
return np.hstack([processed_columns, unprocessed_columns]).tolist() # type: ignore[return-value]
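As a minimal, self-contained sketch of what this helper computes (the transformer names and column assignments below are invented for illustration): a ColumnTransformer emits the columns of its first transformer first, so an input ordered [0..4] whose numerical pipeline covers columns [2, 3] and whose categorical pipeline covers [0, 1] comes out as [2, 3, 0, 1, 4], with the untouched column 4 appended last.

import numpy as np

# hypothetical transformer layout: (name, transformer, columns)
transformers = [
    ('numerical_pipeline', 'passthrough', [2, 3]),
    ('categorical_pipeline', 'passthrough', [0, 1]),
]
n_input_columns = 5

processed_columns = np.asarray([], dtype=int)
for tran in transformers:
    trans_columns = np.array(tran[-1], dtype=int)
    # previously seen columns that this transformer does not overwrite
    unprocessed_columns = np.setdiff1d(processed_columns, trans_columns)
    processed_columns = np.hstack([unprocessed_columns, trans_columns])
# columns untouched by any transformer are appended at the end
unprocessed_columns = np.setdiff1d(np.arange(n_input_columns), processed_columns)
print(np.hstack([processed_columns, unprocessed_columns]).tolist())  # -> [2, 3, 0, 1, 4]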


class TimeSeriesTargetTransformer(autoPyTorchTimeSeriesTargetPreprocessingComponent):
def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = None):
@@ -25,7 +25,11 @@ def __init__(

def fit(self, X: Dict[str, Any], y: Optional[Any] = None) -> 'TimeSeriesColumnSplitter':
super(TimeSeriesColumnSplitter, self).fit(X, y)

self.num_categories_per_col_encoded = X['dataset_properties']['num_categories_per_col']
for i in range(len(self.num_categories_per_col_encoded)):
if i in self.special_feature_types['embed_columns']:
self.num_categories_per_col_encoded[i] = 1
return self
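A small sketch of the bookkeeping above, under the assumption that a column routed to the embedding keeps a single (ordinal) column downstream instead of expanding into a one-hot block; the cardinalities and the embed set are invented:

num_categories_per_col_encoded = [5, 12, 3]  # hypothetical cardinalities per categorical column
embed_columns = {1}                          # hypothetical: column 1 will be embedded
for i in range(len(num_categories_per_col_encoded)):
    if i in embed_columns:
        # an embedded column contributes one input column, not a one-hot block
        num_categories_per_col_encoded[i] = 1
print(num_categories_per_col_encoded)  # -> [5, 1, 3]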

def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
@@ -22,18 +22,12 @@ def __init__(self, random_state: Optional[np.random.RandomState] = None) -> None
FitRequirement('X_train', (pd.DataFrame, ), user_defined=True,
dataset_property=False),
FitRequirement('feature_names', (tuple,), user_defined=True, dataset_property=True),
FitRequirement('numerical_columns', (List,), user_defined=True, dataset_property=True),
FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True),
FitRequirement('feature_order_after_preprocessing', (List,), user_defined=False, dataset_property=False)
])

def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
"""
If the dataset is small enough to process at once, we transform the entire dataset here.
Before the transformation, the order of the dataset is:
[(unknown_columns), categorical_columns, numerical_columns]
while after the transformation, the order of the dataset is:
[numerical_columns, categorical_columns, unknown_columns]
We need to change feature_names and feature_shapes accordingly
Args:
X(Dict): fit dictionary
@@ -52,20 +46,9 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
X['X_train'] = time_series_preprocess(dataset=X_train, transforms=transforms)

feature_names = X['dataset_properties']['feature_names']
numerical_columns = X['dataset_properties']['numerical_columns']
categorical_columns = X['dataset_properties']['categorical_columns']
# encoding_columns = X['dataset_properties']['encoding_columns']
encode_columns = X['encode_columns']
import pdb
pdb.set_trace()

# resort feature_names
# Previously, the categorical features were sorted before the numerical features. However,
# after the preprocessing, the numerical features come first.
new_feature_names = [feature_names[num_col] for num_col in numerical_columns]
new_feature_names += [feature_names[cat_col] for cat_col in categorical_columns]
if set(feature_names) != set(new_feature_names):
new_feature_names += list(set(feature_names) - set(new_feature_names))

feature_order_after_preprocessing = X['feature_order_after_preprocessing']
new_feature_names = (feature_names[i] for i in feature_order_after_preprocessing)
X['dataset_properties']['feature_names'] = tuple(new_feature_names)

preprocessed_dtype = get_preprocessed_dtype(X['X_train'])
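For illustration, reordering a hypothetical feature tuple by such an index list (names invented):

feature_names = ('cat_a', 'num_b', 'num_c')    # hypothetical
feature_order_after_preprocessing = [1, 2, 0]  # numerical columns first after preprocessing
print(tuple(feature_names[i] for i in feature_order_after_preprocessing))
# -> ('num_b', 'num_c', 'cat_a')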
@@ -2,6 +2,7 @@
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.distributions import AffineTransform, TransformedDistribution
@@ -205,6 +206,7 @@ def __init__(self,
auto_regressive: bool,
feature_names: Union[Tuple[str], Tuple[()]] = (),
known_future_features: Union[Tuple[str], Tuple[()]] = (),
embed_features_idx: Tuple[int, ...] = (),
feature_shapes: Dict[str, int] = {},
static_features: Union[Tuple[str], Tuple[()]] = (),
time_feature_names: Union[Tuple[str], Tuple[()]] = (),
@@ -218,7 +220,16 @@
self.embedding = network_embedding
if len(known_future_features) > 0:
known_future_features_idx = [feature_names.index(kff) for kff in known_future_features]
self.decoder_embedding = self.embedding.get_partial_models(known_future_features_idx)
known_future_embed_features = np.where(
np.in1d(embed_features_idx, known_future_features_idx, assume_unique=True)
)[0]
idx_excl_embed_future_features = np.setdiff1d(known_future_features_idx, embed_features_idx)
n_excl_embed_features = sum(feature_shapes[feature_names[i]] for i in idx_excl_embed_future_features)

self.decoder_embedding = self.embedding.get_partial_models(
n_excl_embed_features=n_excl_embed_features,
idx_embed_feat_partial=known_future_embed_features
)
else:
self.decoder_embedding = _NoEmbedding()
# modules that generate tensors while doing forward pass
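To make the index bookkeeping concrete, here is a hedged walk-through using the toy values from the test further below (features f1..f5, with f4 and f5 embedded and f1, f3, f5 known in the future):

import numpy as np

feature_names = ('f1', 'f2', 'f3', 'f4', 'f5')
feature_shapes = {'f1': 10, 'f2': 10, 'f3': 10, 'f4': 11, 'f5': 11}
embed_features_idx = (3, 4)            # f4 and f5 are embedded
known_future_features_idx = [0, 2, 4]  # f1, f3, f5 are known in the future

# positions inside embed_features_idx whose features remain known in the future
known_future_embed_features = np.where(
    np.in1d(embed_features_idx, known_future_features_idx, assume_unique=True)
)[0]
print(known_future_embed_features)  # -> [1]: only f5's embedding layer is inherited

# future features that are not embedded contribute their raw widths
idx_excl_embed_future_features = np.setdiff1d(known_future_features_idx, embed_features_idx)
n_excl_embed_features = sum(feature_shapes[feature_names[i]] for i in idx_excl_embed_future_features)
print(n_excl_embed_features)  # -> 20 (f1 and f3, 10 columns each)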
@@ -558,7 +569,7 @@ def pre_processing(self,
return x_past, x_future, x_static, loc, scale, static_context_initial_hidden, past_targets
else:
if past_features is not None:
x_past = torch.cat([truncated_past_targets, past_features], dim=-1).to(device=self.device)
x_past = torch.cat([past_features, truncated_past_targets], dim=-1).to(device=self.device)
x_past = self.embedding(x_past.to(device=self.device))
else:
x_past = self.embedding(truncated_past_targets.to(device=self.device))
@@ -615,8 +626,8 @@ def forward(self,
return self.rescale_output(output, loc, scale, self.device)

def _unwrap_past_targets(
self,
past_targets: dict
self,
past_targets: dict
) -> Tuple[torch.Tensor,
Optional[torch.Tensor],
Optional[torch.Tensor],
@@ -44,6 +44,7 @@ def __init__(
FitRequirement("auto_regressive", (bool,), user_defined=False, dataset_property=False),
FitRequirement("target_scaler", (BaseTargetScaler,), user_defined=False, dataset_property=False),
FitRequirement("net_output_type", (str,), user_defined=False, dataset_property=False),
FitRequirement('embed_features_idx', (tuple,), user_defined=False, dataset_property=False),
FitRequirement("feature_names", (Iterable,), user_defined=False, dataset_property=True),
FitRequirement("feature_shapes", (Iterable,), user_defined=False, dataset_property=True),
FitRequirement('transform_time_features', (bool,), user_defined=False, dataset_property=False),
@@ -85,6 +86,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchTrainingComponent:
feature_names=feature_names,
feature_shapes=feature_shapes,
known_future_features=known_future_features,
embed_features_idx=X['embed_features_idx'],
time_feature_names=time_feature_names,
static_features=X['dataset_properties']['static_features']
)
@@ -72,18 +72,37 @@ def __init__(self, config: Dict[str, Any], num_categories_per_col: np.ndarray, n

self.ee_layers = self._create_ee_layers()

def insert_new_input_features(self, n_new_features: int) -> None:
"""
Time series tasks need to add targets to the embeddings. However, the target information is not recorded
by autoPyTorch's embeddings. Therefore, we need to add the targets to the input features manually; they are
located in front of the features.
Args:
n_new_features (int):
number of new features that are inserted in front of the input features
"""
self.num_categories_per_col = np.hstack([np.zeros(n_new_features, dtype=np.int16), self.num_categories_per_col])
self.embed_features = np.hstack([np.zeros(n_new_features, dtype=bool), self.embed_features])

self.num_features_excl_embed += n_new_features
self.num_output_dimensions = [1] * n_new_features + self.num_output_dimensions
self.num_out_feats += n_new_features
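A quick sketch of the intended array bookkeeping (values invented; two target columns prepended for forecasting):

import numpy as np

num_categories_per_col = np.array([0, 0, 5, 12], dtype=np.int16)  # hypothetical
embed_features = num_categories_per_col > 0
n_new_features = 2  # e.g. two prepended target columns

num_categories_per_col = np.hstack([np.zeros(n_new_features, dtype=np.int16), num_categories_per_col])
embed_features = np.hstack([np.zeros(n_new_features, dtype=bool), embed_features])
print(num_categories_per_col.tolist())  # -> [0, 0, 0, 0, 5, 12]
print(embed_features.tolist())          # -> [False, False, False, False, True, True]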

def get_partial_models(self,
n_excl_embed_features: int,
idx_embed_feat_partial: List[int]) -> "_LearnedEntityEmbedding":
"""
Extract a partial model that only works on the subset of the data that ought to be passed to the embedding
network. This function is implemented for time series forecasting tasks, where the known future features are
only a subset of the past features.
Args:
n_excl_embed_features (int):
number of unembedded features
idx_embed_feat_partial (List[int]):
a list of indices identifying which embedding features will be inherited by the partial model
a list of indices identifying which embedding features will be inherited by the partial model. These
indices are used to extract the corresponding self.ee_layers
Returns:
partial_model (_LearnedEntityEmbedding)
@@ -119,11 +138,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
concat_seq = []

layer_pointer = 0
# Time series tasks need to add targets to the embeddings. However, the target information is not recorded
# by autoPyTorch's embeddings. Therefore, we need to add the target parts to `concat_seq` manually; they are
# the last few dimensions of the input x
# we assign x_pointer to 0 beforehand to avoid the case that self.embed_features has 0 length
x_pointer = 0
# For forecasting architectures, besides the input features, we might also need to feed targets and time features
# to the embedding layers; these are not counted by self.embed_features.
for x_pointer, embed in enumerate(self.embed_features):
if not embed:
current_feature_slice = x[..., [x_pointer]]
@@ -134,6 +151,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
concat_seq.append(self.ee_layers[layer_pointer](current_feature_slice))

layer_pointer += 1
concat_seq.append(x[..., x_pointer + 1:])

return torch.cat(concat_seq, dim=-1)
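A shape-only sketch of the pass-through above, assuming one counted plain column, one counted embedded column, and three trailing target/time columns (the doubling stands in for an ee_layer of output width 1):

import torch

embed_features = [False, True]  # hypothetical: the second counted column is embedded
x = torch.randn(8, 2 + 3)       # 2 counted feature columns + 3 trailing target/time columns
concat_seq = []
x_pointer = 0
for x_pointer, embed in enumerate(embed_features):
    current_feature_slice = x[..., [x_pointer]]
    # stand-in for self.ee_layers[layer_pointer] with output width 1
    concat_seq.append(current_feature_slice * 2 if embed else current_feature_slice)
# the trailing target/time columns are forwarded unchanged
concat_seq.append(x[..., x_pointer + 1:])
print(torch.cat(concat_seq, dim=-1).shape)  # -> torch.Size([8, 5])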

@@ -12,7 +12,10 @@


class _NoEmbedding(nn.Module):
def get_partial_models(self, **kwargs: Any) -> "_NoEmbedding":
def get_partial_models(self, *args: Any, **kwargs: Any) -> "_NoEmbedding":
return self

def insert_new_input_features(self, *args: Any, **kwargs: Any) -> "_NoEmbedding":
return self

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -37,6 +37,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
# forecasting tasks
feature_names = X['dataset_properties']['feature_names']
n_features_all = len(feature_names)
# embedded feature index
embed_features_idx = tuple(range(n_features_all - n_features_embedded, n_features_all))
for idx, n_output_embedded in zip(embed_features_idx, num_output_features[-n_features_embedded:]):
feat_name = feature_names[idx]
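The embedded features are assumed to sit at the end of feature_names, so the index tuple is simply the last n_features_embedded positions; e.g. (values invented):

feature_names = ('f1', 'f2', 'f3', 'f4', 'f5')  # hypothetical
n_features_embedded = 2
n_features_all = len(feature_names)
embed_features_idx = tuple(range(n_features_all - n_features_embedded, n_features_all))
print(embed_features_idx)  # -> (3, 4): f4 and f5 are the embedded features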
@@ -24,22 +24,22 @@
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdate


class ReducedEmbedding(torch.nn.Module):
class IncrementalEmbedding(torch.nn.Module):
# a dummy incremental embedding: it simply appends a copy of each embedded feature's column
def __init__(self, num_input_features, num_numerical_features: int):
super(ReducedEmbedding, self).__init__()
self.num_input_features = num_input_features
self.num_numerical_features = num_numerical_features
self.n_cat_features = len(num_input_features) - num_numerical_features
def __init__(self, n_excl_embed_features, embed_feat_idx):
super(IncrementalEmbedding, self).__init__()
self.n_excl_embed_features = n_excl_embed_features
self.embed_feat_idx = embed_feat_idx

def forward(self, x):
x = x[..., :-self.n_cat_features]
if len(self.embed_feat_idx) > 0:
x = torch.cat([x, x[..., -len(self.embed_feat_idx):]], dim=-1)
return x

def get_partial_models(self, subset_features):
num_numerical_features = sum([sf < self.num_numerical_features for sf in subset_features])
num_input_features = [self.num_input_features[sf] for sf in subset_features]
return ReducedEmbedding(num_input_features, num_numerical_features)
def get_partial_models(self, n_excl_embed_features, idx_embed_feat_partial):
embed_feat_idx = [self.embed_feat_idx[idx] for idx in idx_embed_feat_partial]
return IncrementalEmbedding(n_excl_embed_features, embed_feat_idx)
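A usage sketch of this test helper (batch size invented): each index in embed_feat_idx adds one duplicated column to the output, mimicking an embedding that grows the feature dimension.

import torch

emb = IncrementalEmbedding(50, (3, 4))  # two embedded features
x = torch.randn(8, 52)                  # hypothetical batch: 50 plain + 2 embedded columns
print(emb(x).shape)                     # -> torch.Size([8, 54]); one extra column per embedded feature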


@pytest.fixture(params=['ForecastingNet', 'ForecastingSeq2SeqNet', 'ForecastingDeepARNet', 'NBEATSNet'])
@@ -52,7 +52,7 @@ def network_encoder(request):
return request.param


@pytest.fixture(params=['ReducedEmbedding', 'NoEmbedding'])
@pytest.fixture(params=['IncrementalEmbedding', 'NoEmbedding'])
def embedding(request):
return request.param

@@ -110,7 +110,7 @@ def test_network_forward(self,
dataset_properties['known_future_features'] = ('f1', 'f3', 'f5')

if with_static_features:
dataset_properties['static_features'] = (0, 4)
dataset_properties['static_features'] = (0, 3)
else:
dataset_properties['static_features'] = tuple()

@@ -130,10 +130,14 @@ def test_network_forward(self,
fit_dictionary['net_output_type'] = net_output_type

if embedding == 'NoEmbedding':
embed_features_idx = ()
fit_dictionary['network_embedding'] = _NoEmbedding()
fit_dictionary['embed_features_idx'] = embed_features_idx
else:
fit_dictionary['network_embedding'] = ReducedEmbedding([10] * 5, 2)
dataset_properties['feature_shapes'] = {'f1': 10, 'f2': 10, 'f3': 9, 'f4': 9, 'f5': 9}
embed_features_idx = (3, 4)
fit_dictionary['network_embedding'] = IncrementalEmbedding(50, embed_features_idx)
fit_dictionary['embed_features_idx'] = embed_features_idx
dataset_properties['feature_shapes'] = {'f1': 10, 'f2': 10, 'f3': 10, 'f4': 11, 'f5': 11}

if uni_variant_data:
fit_dictionary['X_train'] = None