nn.Embedding to avoid OneHotEncoding all categorical columns #425
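For context, a minimal sketch (not part of the PR) of the idea in the title: instead of one-hot encoding a high-cardinality categorical column into many binary features, the column is kept as an integer code and mapped to a small dense vector with torch.nn.Embedding. The sizes and values below are illustrative.

import torch
import torch.nn as nn

# A categorical column with 10 distinct categories, ordinally encoded as 0..9.
codes = torch.tensor([0, 3, 9, 3])                               # shape: (4,)

# One-hot encoding expands the single column into 10 input features.
one_hot = nn.functional.one_hot(codes, num_classes=10).float()  # shape: (4, 10)

# An embedding maps each code to a dense vector of chosen width instead.
embed = nn.Embedding(num_embeddings=10, embedding_dim=3)
dense = embed(codes)                                             # shape: (4, 3)

print(one_hot.shape, dense.shape)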

9 changes: 5 additions & 4 deletions autoPyTorch/api/base_task.py
@@ -111,7 +111,7 @@ def send_warnings_to_log(
return prediction


def get_search_updates(categorical_indicator: List[bool]):
def get_search_updates(categorical_indicator: List[bool]) -> HyperparameterSearchSpaceUpdates:
Review comment (Collaborator):
The method argument is not used, I believe it could be removed.

"""
These updates mimic the autopytorch tabular paper.
Returns:
@@ -120,8 +120,8 @@ def get_search_updates(categorical_indicator: List[bool]):
The search space updates like setting different hps to different values or ranges.
"""

has_cat_features = any(categorical_indicator)
has_numerical_features = not all(categorical_indicator)
# has_cat_features = any(categorical_indicator)
# has_numerical_features = not all(categorical_indicator)
Review comment (Collaborator):
I think this should be removed.


search_space_updates = HyperparameterSearchSpaceUpdates()

@@ -267,7 +267,8 @@ def __init__(

self.input_validator: Optional[BaseInputValidator] = None

self.search_space_updates = search_space_updates if search_space_updates is not None else get_search_updates(categorical_indicator)
# if search_space_updates is not None else get_search_updates(categorical_indicator)
Review comment (Collaborator):
I think this could also be removed.

ravinkohli marked this conversation as resolved.
self.search_space_updates = search_space_updates
if search_space_updates is not None:
if not isinstance(self.search_space_updates,
HyperparameterSearchSpaceUpdates):
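With the default call to get_search_updates(...) dropped, search space updates are only applied when the caller passes them explicitly. A hedged sketch of what that could look like, assuming the module path and append signature that HyperparameterSearchSpaceUpdates already exposes elsewhere in autoPyTorch; the node and hyperparameter names are illustrative:

from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates

updates = HyperparameterSearchSpaceUpdates()
updates.append(
    node_name="data_loader",
    hyperparameter="batch_size",
    value_range=[32, 256],
    default_value=64,
)

# The updates object is then handed to the task and lands in self.search_space_updates:
# api = TabularClassificationTask(search_space_updates=updates)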
2 changes: 1 addition & 1 deletion autoPyTorch/data/base_feature_validator.py
@@ -46,10 +46,10 @@ def __init__(

# Required for dataset properties
self.num_features: Optional[int] = None
self.categories: List[List[int]] = []
self.categorical_columns: List[int] = []
self.numerical_columns: List[int] = []

self.num_categories_per_col: Optional[List[int]] = []
self.all_nan_columns: Optional[Set[Union[int, str]]] = None

self._is_fitted = False
8 changes: 2 additions & 6 deletions autoPyTorch/data/tabular_feature_validator.py
@@ -193,10 +193,8 @@ def _fit(
encoded_categories = self.column_transformer.\
named_transformers_['categorical_pipeline'].\
named_steps['ordinalencoder'].categories_
self.categories = [
list(range(len(cat)))
for cat in encoded_categories
]

self.num_categories_per_col = [len(cat) for cat in encoded_categories]

# differently to categorical_columns and numerical_columns,
# this saves the index of the column.
@@ -274,8 +272,6 @@ def transform(
X = self.numpy_to_pandas(X)

if ispandas(X) and not issparse(X):
X = cast(pd.DataFrame, X)

if self.all_nan_columns is None:
raise ValueError('_fit must be called before calling transform')

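To illustrate the new bookkeeping (a standalone sketch, not the validator itself): after fitting an OrdinalEncoder, the number of categories per column is just the length of each fitted category array, which is what num_categories_per_col now stores instead of explicit category index lists.

import pandas as pd
from sklearn.preprocessing import OrdinalEncoder

X = pd.DataFrame({"color": ["red", "green", "blue", "red"],
                  "size": ["S", "M", "S", "S"]})

enc = OrdinalEncoder().fit(X)
num_categories_per_col = [len(cat) for cat in enc.categories_]
print(num_categories_per_col)  # [3, 2]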
2 changes: 1 addition & 1 deletion autoPyTorch/datasets/tabular_dataset.py
@@ -81,7 +81,7 @@ def __init__(self,
self.categorical_columns = validator.feature_validator.categorical_columns
self.numerical_columns = validator.feature_validator.numerical_columns
self.num_features = validator.feature_validator.num_features
self.categories = validator.feature_validator.categories
self.num_categories_per_col = validator.feature_validator.num_categories_per_col

super().__init__(train_tensors=(X, Y), test_tensors=(X_test, Y_test), shuffle=shuffle,
resampling_strategy=resampling_strategy,
43 changes: 24 additions & 19 deletions autoPyTorch/evaluation/train_evaluator.py
@@ -1,6 +1,6 @@
import json
from multiprocessing.queues import Queue
import os
from multiprocessing.queues import Queue
from typing import Any, Dict, List, Optional, Tuple, Union

from ConfigSpace.configuration_space import Configuration
@@ -22,6 +22,7 @@
fit_and_suppress_warnings
)
from autoPyTorch.evaluation.utils import DisableFileOutputParameters
from autoPyTorch.pipeline.base_pipeline import BasePipeline
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
from autoPyTorch.utils.common import dict_repr, subsampler
@@ -195,24 +196,7 @@ def fit_predict_and_loss(self) -> None:
additional_run_info = pipeline.get_additional_run_info() if hasattr(
pipeline, 'get_additional_run_info') else {}

# add learning curve of configurations to additional_run_info
if isinstance(pipeline, TabularClassificationPipeline):
if hasattr(pipeline.named_steps['trainer'], 'run_summary'):
run_summary = pipeline.named_steps['trainer'].run_summary
split_types = ['train', 'val', 'test']
run_summary_dict = dict(
run_summary={},
budget=self.budget,
seed=self.seed,
config_id=self.configuration.config_id,
num_run=self.num_run
)
for split_type in split_types:
run_summary_dict['run_summary'][f'{split_type}_loss'] = run_summary.performance_tracker.get(f'{split_type}_loss', None)
run_summary_dict['run_summary'][f'{split_type}_metrics'] = run_summary.performance_tracker.get(f'{split_type}_metrics', None)
self.logger.debug(f"run_summary_dict {json.dumps(run_summary_dict)}")
with open(os.path.join(self.backend.temporary_directory, 'run_summary.txt'), 'a') as file:
file.write(f"{json.dumps(run_summary_dict)}\n")
# self._write_run_summary(pipeline)
Review comment (Collaborator):
Based on the functionality that was encapsulated in the function, I think this should be called here, right?

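If the suggestion is followed, the commented-out placeholder above would presumably be replaced by a direct call to the new helper:

# replacing the commented-out "# self._write_run_summary(pipeline)" with
self._write_run_summary(pipeline)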

status = StatusType.SUCCESS

@@ -370,6 +354,27 @@ def fit_predict_and_loss(self) -> None:
status=status,
)

def _write_run_summary(self, pipeline: BasePipeline) -> None:
# add learning curve of configurations to additional_run_info
if isinstance(pipeline, TabularClassificationPipeline):
assert isinstance(self.configuration, Configuration)
if hasattr(pipeline.named_steps['trainer'], 'run_summary'):
run_summary = pipeline.named_steps['trainer'].run_summary
split_types = ['train', 'val', 'test']
run_summary_dict = dict(
run_summary={},
budget=self.budget,
seed=self.seed,
config_id=self.configuration.config_id,
num_run=self.num_run)
for split_type in split_types:
run_summary_dict['run_summary'][f'{split_type}_loss'] = run_summary.performance_tracker.get(
f'{split_type}_loss', None)
run_summary_dict['run_summary'][f'{split_type}_metrics'] = run_summary.performance_tracker.get(
f'{split_type}_metrics', None)
with open(os.path.join(self.backend.temporary_directory, 'run_summary.txt'), 'a') as file:
file.write(f"{json.dumps(run_summary_dict)}\n")

def _fit_and_predict(self, pipeline: BaseEstimator, fold: int, train_indices: Union[np.ndarray, List],
test_indices: Union[np.ndarray, List],
add_pipeline_to_self: bool
34 changes: 3 additions & 31 deletions autoPyTorch/pipeline/base_pipeline.py
@@ -1,4 +1,3 @@
from copy import copy
import warnings
from abc import ABCMeta
from collections import Counter
@@ -297,7 +296,7 @@ def _get_hyperparameter_search_space(self,
"""
raise NotImplementedError()

def _add_forbidden_conditions(self, cs):
def _add_forbidden_conditions(self, cs: ConfigurationSpace) -> ConfigurationSpace:
"""
Add forbidden conditions to ensure valid configurations.
Currently, Learned Entity Embedding is only valid when encoder is one hot encoder
Review comment (Collaborator):
Based on the changes introduced in the PR, I think the first condition mentioned in the docstring, regarding the Learned Entity Embedding, should be removed.

@@ -310,33 +309,6 @@ def _add_forbidden_conditions(self, cs):

"""

# Learned Entity Embedding is only valid when encoder is one hot encoder
if 'network_embedding' in self.named_steps.keys() and 'encoder' in self.named_steps.keys():
embeddings = cs.get_hyperparameter('network_embedding:__choice__').choices
if 'LearnedEntityEmbedding' in embeddings:
encoders = cs.get_hyperparameter('encoder:__choice__').choices
possible_default_embeddings = copy(list(embeddings))
del possible_default_embeddings[possible_default_embeddings.index('LearnedEntityEmbedding')]

for encoder in encoders:
if encoder == 'OneHotEncoder':
continue
while True:
try:
cs.add_forbidden_clause(ForbiddenAndConjunction(
ForbiddenEqualsClause(cs.get_hyperparameter(
'network_embedding:__choice__'), 'LearnedEntityEmbedding'),
ForbiddenEqualsClause(cs.get_hyperparameter('encoder:__choice__'), encoder)
))
break
except ValueError:
# change the default and try again
try:
default = possible_default_embeddings.pop()
except IndexError:
raise ValueError("Cannot find a legal default configuration")
cs.get_hyperparameter('network_embedding:__choice__').default_value = default

# Disable CyclicLR until todo is completed.
if 'lr_scheduler' in self.named_steps.keys() and 'trainer' in self.named_steps.keys():
trainers = cs.get_hyperparameter('trainer:__choice__').choices
@@ -347,7 +319,8 @@ def _add_forbidden_conditions(self, cs):
if cyclic_lr_name in available_schedulers:
# disable snapshot ensembles and stochastic weight averaging
snapshot_ensemble_hyperparameter = cs.get_hyperparameter(f'trainer:{trainer}:use_snapshot_ensemble')
if hasattr(snapshot_ensemble_hyperparameter, 'choices') and True in snapshot_ensemble_hyperparameter.choices:
if hasattr(snapshot_ensemble_hyperparameter, 'choices') and \
True in snapshot_ensemble_hyperparameter.choices:
cs.add_forbidden_clause(ForbiddenAndConjunction(
ForbiddenEqualsClause(snapshot_ensemble_hyperparameter, True),
ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
@@ -549,7 +522,6 @@ def _check_search_space_updates(self, include: Optional[Dict[str, Any]],
node_hyperparameters,
update.hyperparameter))


def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]]
) -> List[Tuple[str, PipelineStepType]]:
"""
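For reference, a minimal standalone sketch of the ConfigSpace forbidden-clause pattern that the remaining CyclicLR handling relies on (hyperparameter names here are made up, and the choice order is picked so the default configuration stays legal):

from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause
from ConfigSpace.hyperparameters import CategoricalHyperparameter

cs = ConfigurationSpace()
scheduler = CategoricalHyperparameter("lr_scheduler", ["StepLR", "CyclicLR"])
use_swa = CategoricalHyperparameter("use_stochastic_weight_averaging", [False, True])
cs.add_hyperparameters([scheduler, use_swa])

# Forbid stochastic weight averaging together with CyclicLR, mirroring the clause
# added in _add_forbidden_conditions.
cs.add_forbidden_clause(ForbiddenAndConjunction(
    ForbiddenEqualsClause(use_swa, True),
    ForbiddenEqualsClause(scheduler, "CyclicLR"),
))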
(diff continues in another file; filename not shown)
@@ -12,7 +12,6 @@ def __init__(self) -> None:
self._processing = True
self.add_fit_requirements([
FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True),
FitRequirement('categories', (List,), user_defined=True, dataset_property=True)
])

def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
(new file added by this PR: the ColumnSplitter preprocessing component; full path not shown)
@@ -0,0 +1,83 @@
from typing import Any, Dict, List, Optional, Union

from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import (
UniformIntegerHyperparameter,
)

import numpy as np


from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.base_tabular_preprocessing import \
autoPyTorchTabularPreprocessingComponent
from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter


class ColumnSplitter(autoPyTorchTabularPreprocessingComponent):
"""
Splits categorical columns into those to be one-hot encoded and those to be embedded, based on the number of categories per column.
"""
def __init__(
self,
min_categories_for_embedding: float = 5,
random_state: Optional[np.random.RandomState] = None
):
self.min_categories_for_embedding = min_categories_for_embedding
self.random_state = random_state

self.special_feature_types: Dict[str, List] = dict(encode_columns=[], embed_columns=[])
self.num_categories_per_col: Optional[List] = None
super().__init__()

def fit(self, X: Dict[str, Any], y: Optional[Any] = None) -> 'ColumnSplitter':

self.check_requirements(X, y)

if len(X['dataset_properties']['categorical_columns']) > 0:
self.num_categories_per_col = []
for categories_per_column, column in zip(X['dataset_properties']['num_categories_per_col'],
X['dataset_properties']['categorical_columns']):
if (
categories_per_column >= self.min_categories_for_embedding
):
self.special_feature_types['embed_columns'].append(column)
# we only care about the categories for columns to be embedded
self.num_categories_per_col.append(categories_per_column)
else:
self.special_feature_types['encode_columns'].append(column)

return self

def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
if self.num_categories_per_col is not None:
Review comment (Collaborator):
self.num_categories_per_col is initialized as an empty list, which means it will not be None even when only encoded columns are present. Maybe this condition should be changed to:

if self.num_categories_per_col:
    ...

Reply (Contributor, author):

It will be None when there are no categorical columns, see line 38.

Reply (Collaborator, @theodorju, Jul 18, 2022):

Hm, but line 38 initializes self.num_categories_per_col to an empty list if there are categorical columns, and [] is not None returns True.

I'm mentioning this because I thought that on line 53 we check whether there are columns to be embedded; currently the if condition evaluates to true for both embedded and encoded columns.

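A tiny illustration of the distinction being discussed (plain Python semantics, independent of this PR):

num_categories_per_col = []                 # fit() ran, but no column qualified for embedding
print(num_categories_per_col is not None)   # True  -> the current check still passes
print(bool(num_categories_per_col))         # False -> the suggested truthiness check would not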
# update such that only n categories for embedding columns is passed
X['dataset_properties']['num_categories_per_col'] = self.num_categories_per_col
X.update(self.special_feature_types)
return X

@staticmethod
def get_properties(
dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None
) -> Dict[str, Union[str, bool]]:

return {
'shortname': 'ColumnSplitter',
'name': 'Column Splitter',
'handles_sparse': False,
}

@staticmethod
def get_hyperparameter_search_space(
dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None,
min_categories_for_embedding: HyperparameterSearchSpace = HyperparameterSearchSpace(
hyperparameter="min_categories_for_embedding",
value_range=(3, 7),
default_value=3,
log=True),
) -> ConfigurationSpace:
cs = ConfigurationSpace()

add_hyperparameter(cs, min_categories_for_embedding, UniformIntegerHyperparameter)

return cs
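Stripped of the pipeline plumbing, the split decision in fit amounts to the following (a standalone sketch with made-up column indices and category counts):

min_categories_for_embedding = 5
categorical_columns = [0, 2]
num_categories_per_col = [12, 3]

embed_columns, encode_columns, kept_categories = [], [], []
for n_cat, col in zip(num_categories_per_col, categorical_columns):
    if n_cat >= min_categories_for_embedding:
        embed_columns.append(col)        # enough categories: use an embedding
        kept_categories.append(n_cat)    # only embedded columns keep their category counts
    else:
        encode_columns.append(col)       # few categories: fall back to one-hot encoding

print(embed_columns, encode_columns, kept_categories)   # [0] [2] [12]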
(diff continues in another file; filename not shown)
@@ -22,8 +22,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEncoder:

self.preprocessor['categorical'] = OHE(
# It is safer to have the OHE produce a 0 array than to crash a good configuration
categories=X['dataset_properties']['categories']
if len(X['dataset_properties']['categories']) > 0 else 'auto',
categories='auto',
sparse=False,
handle_unknown='ignore')
return self
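The switch to categories='auto' leans on scikit-learn's behaviour with handle_unknown='ignore': categories are learned from the training data, and an unseen category at transform time yields an all-zero row rather than an error. A minimal sketch (note that recent scikit-learn versions spell the dense-output flag sparse_output rather than sparse):

import numpy as np
from sklearn.preprocessing import OneHotEncoder

ohe = OneHotEncoder(categories='auto', sparse=False, handle_unknown='ignore')
ohe.fit(np.array([['red'], ['green'], ['blue']]))

print(ohe.transform(np.array([['green'], ['purple']])))
# [[0. 1. 0.]
#  [0. 0. 0.]]   <- unseen category becomes an all-zero row instead of raising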
(diff continues in another file; filename not shown)
@@ -13,8 +13,7 @@ class BaseEncoder(autoPyTorchTabularPreprocessingComponent):
def __init__(self) -> None:
super().__init__()
self.add_fit_requirements([
FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True),
FitRequirement('categories', (List,), user_defined=True, dataset_property=True)])
FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True), ])

def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
"""
(diff continues in another file; filename not shown)
@@ -1,6 +1,6 @@
import warnings
from math import ceil, floor
from typing import Dict, List, Optional, Sequence
from typing import Dict, List, Optional, Sequence, Tuple

from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.utils.common import HyperparameterSearchSpace, HyperparameterValueType
@@ -81,11 +81,17 @@ def percentage_value_range_to_integer_range(
log = False
else:
log = hyperparameter_search_space.log

value_range: Tuple
if len(hyperparameter_search_space.value_range) == 2:
value_range = (floor(float(hyperparameter_search_space.value_range[0]) * n_features),
floor(float(hyperparameter_search_space.value_range[-1]) * n_features))
else:
value_range = (floor(float(hyperparameter_search_space.value_range[0]) * n_features),)

hyperparameter_search_space = HyperparameterSearchSpace(
hyperparameter=hyperparameter_name,
value_range=(
floor(float(hyperparameter_search_space.value_range[0]) * n_features),
floor(float(hyperparameter_search_space.value_range[1]) * n_features)),
value_range=value_range,
default_value=ceil(float(hyperparameter_search_space.default_value) * n_features),
log=log)
else:
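A quick worked example of the updated conversion (illustrative numbers): with n_features = 20, a two-sided percentage range (0.1, 0.5) becomes the integer range (2, 10), while a single-element range such as (0.5,) now becomes (10,) instead of failing on the missing second bound.

from math import floor

n_features = 20
for value_range in [(0.1, 0.5), (0.5,)]:
    converted = tuple(floor(float(v) * n_features) for v in value_range)
    print(value_range, "->", converted)   # (0.1, 0.5) -> (2, 10)   and   (0.5,) -> (10,)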
(diff continues in another file; filename not shown)
@@ -40,7 +40,10 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
X['X_train'] = preprocess(dataset=X_train, transforms=transforms)

# We need to also save the preprocess transforms for inference
X.update({'preprocess_transforms': transforms})
X.update({
'preprocess_transforms': transforms,
'shape_after_preprocessing': X['X_train'].shape[1:]
})
return X

@staticmethod
(diff continues in another file; filename not shown)
@@ -30,8 +30,7 @@ def __init__(self,
self.add_fit_requirements([
FitRequirement('X_train', (np.ndarray, pd.DataFrame, spmatrix), user_defined=True,
dataset_property=False),
FitRequirement('input_shape', (Iterable,), user_defined=True, dataset_property=True),
FitRequirement('tabular_transformer', (BaseEstimator,), user_defined=False, dataset_property=False),
FitRequirement('shape_after_preprocessing', (Iterable,), user_defined=False, dataset_property=False),
FitRequirement('network_embedding', (nn.Module,), user_defined=False, dataset_property=False)
])
self.backbone: nn.Module = None
@@ -49,9 +48,8 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
Self
"""
self.check_requirements(X, y)
X_train = X['X_train']

input_shape = X_train.shape[1:]
input_shape = X['shape_after_preprocessing']

input_shape = get_output_shape(X['network_embedding'], input_shape=input_shape)
self.input_shape = input_shape
(diff continues in another file; filename not shown)
@@ -25,7 +25,9 @@ def get_output_shape(network: torch.nn.Module, input_shape: Tuple[int, ...]
:param input_shape: shape of the input
:return: output_shape
"""
placeholder = torch.randn((2, *input_shape), dtype=torch.float)
# as we are using nn.Embedding, 2 is a safe upper limit since 3
# is the lowest `min_categories_for_embedding` can be
placeholder = torch.randint(high=2, size=(2, *input_shape), dtype=torch.float)
with torch.no_grad():
output = network(placeholder)

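Why the placeholder switched from randn to randint (a standalone sketch, not the autoPyTorch module): once the network starts with an nn.Embedding, the dummy forward pass used to probe the output shape must feed valid integer indices, and since every embedded column has at least 3 categories, values drawn from {0, 1} are always in range.

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # at least 3 categories per embedded column, so indices 0 and 1 are always valid
        self.embed = nn.Embedding(num_embeddings=3, embedding_dim=4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.embed(x.long())

net = TinyNet()
placeholder = torch.randint(high=2, size=(2, 5), dtype=torch.float)  # values in {0, 1}
with torch.no_grad():
    out = net(placeholder)
print(out.shape)  # torch.Size([2, 5, 4])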