nn.Embedding to avoid OneHotEncoding all categorical columns #425
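For readers skimming the diff, the idea in the PR title in a minimal, self-contained PyTorch sketch (illustrative only, not the PR's implementation): a high-cardinality categorical column gets a dense learned embedding instead of a wide one-hot vector.

    import torch
    import torch.nn as nn

    # One categorical column with 30 distinct categories: instead of a
    # 30-wide one-hot encoding, learn a dense 5-dimensional vector per category.
    num_categories, embedding_dim = 30, 5
    embedding = nn.Embedding(num_categories, embedding_dim)

    batch = torch.tensor([0, 7, 29])  # raw category indices for three rows
    dense = embedding(batch)          # shape: (3, 5) instead of (3, 30) one-hot
    print(dense.shape)                # torch.Size([3, 5])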
@@ -111,7 +111,7 @@ def send_warnings_to_log(
     return prediction


-def get_search_updates(categorical_indicator: List[bool]):
+def get_search_updates(categorical_indicator: List[bool]) -> HyperparameterSearchSpaceUpdates:
     """
     These updates mimic the autopytorch tabular paper.
     Returns:
@@ -120,8 +120,8 @@ def get_search_updates(categorical_indicator: List[bool]):
            The search space updates like setting different hps to different values or ranges.
     """

-    has_cat_features = any(categorical_indicator)
-    has_numerical_features = not all(categorical_indicator)
+    # has_cat_features = any(categorical_indicator)
+    # has_numerical_features = not all(categorical_indicator)
Review comment: I think this should be removed.

     search_space_updates = HyperparameterSearchSpaceUpdates()
@@ -267,7 +267,8 @@ def __init__(

         self.input_validator: Optional[BaseInputValidator] = None

-        self.search_space_updates = search_space_updates if search_space_updates is not None else get_search_updates(categorical_indicator)
+        # if search_space_updates is not None else get_search_updates(categorical_indicator)
Review comment: I think this could also be removed.
ravinkohli marked this conversation as resolved.
+        self.search_space_updates = search_space_updates
         if search_space_updates is not None:
             if not isinstance(self.search_space_updates,
                               HyperparameterSearchSpaceUpdates):
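For context, this is roughly how such updates are built by hand with autoPyTorch's HyperparameterSearchSpaceUpdates; a hedged sketch (the node and hyperparameter names below are illustrative, not taken from this PR, and the import path and append signature are the ones autoPyTorch documents):

    from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates

    search_space_updates = HyperparameterSearchSpaceUpdates()
    # Restrict a single node's hyperparameter, as get_search_updates does
    # for the paper-style settings.
    search_space_updates.append(
        node_name='data_loader',
        hyperparameter='batch_size',
        value_range=[32, 128],
        default_value=64,
    )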
@@ -1,6 +1,6 @@
 import json
-from multiprocessing.queues import Queue
 import os
+from multiprocessing.queues import Queue
 from typing import Any, Dict, List, Optional, Tuple, Union

 from ConfigSpace.configuration_space import Configuration
@@ -22,6 +22,7 @@
     fit_and_suppress_warnings
 )
 from autoPyTorch.evaluation.utils import DisableFileOutputParameters
+from autoPyTorch.pipeline.base_pipeline import BasePipeline
 from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
 from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
 from autoPyTorch.utils.common import dict_repr, subsampler
@@ -195,24 +196,7 @@ def fit_predict_and_loss(self) -> None:
             additional_run_info = pipeline.get_additional_run_info() if hasattr(
                 pipeline, 'get_additional_run_info') else {}

-            # add learning curve of configurations to additional_run_info
-            if isinstance(pipeline, TabularClassificationPipeline):
-                if hasattr(pipeline.named_steps['trainer'], 'run_summary'):
-                    run_summary = pipeline.named_steps['trainer'].run_summary
-                    split_types = ['train', 'val', 'test']
-                    run_summary_dict = dict(
-                        run_summary={},
-                        budget=self.budget,
-                        seed=self.seed,
-                        config_id=self.configuration.config_id,
-                        num_run=self.num_run
-                    )
-                    for split_type in split_types:
-                        run_summary_dict['run_summary'][f'{split_type}_loss'] = run_summary.performance_tracker.get(f'{split_type}_loss', None)
-                        run_summary_dict['run_summary'][f'{split_type}_metrics'] = run_summary.performance_tracker.get(f'{split_type}_metrics', None)
-                    self.logger.debug(f"run_summary_dict {json.dumps(run_summary_dict)}")
-                    with open(os.path.join(self.backend.temporary_directory, 'run_summary.txt'), 'a') as file:
-                        file.write(f"{json.dumps(run_summary_dict)}\n")
+            # self._write_run_summary(pipeline)
Review comment: Based on the functionality that was encapsulated in the function, I think this should be called here, right?

             status = StatusType.SUCCESS
@@ -370,6 +354,27 @@ def fit_predict_and_loss(self) -> None:
                 status=status,
             )

+    def _write_run_summary(self, pipeline: BasePipeline) -> None:
+        # add learning curve of configurations to additional_run_info
+        if isinstance(pipeline, TabularClassificationPipeline):
+            assert isinstance(self.configuration, Configuration)
+            if hasattr(pipeline.named_steps['trainer'], 'run_summary'):
+                run_summary = pipeline.named_steps['trainer'].run_summary
+                split_types = ['train', 'val', 'test']
+                run_summary_dict = dict(
+                    run_summary={},
+                    budget=self.budget,
+                    seed=self.seed,
+                    config_id=self.configuration.config_id,
+                    num_run=self.num_run)
+                for split_type in split_types:
+                    run_summary_dict['run_summary'][f'{split_type}_loss'] = run_summary.performance_tracker.get(
+                        f'{split_type}_loss', None)
+                    run_summary_dict['run_summary'][f'{split_type}_metrics'] = run_summary.performance_tracker.get(
+                        f'{split_type}_metrics', None)
+                with open(os.path.join(self.backend.temporary_directory, 'run_summary.txt'), 'a') as file:
+                    file.write(f"{json.dumps(run_summary_dict)}\n")

     def _fit_and_predict(self, pipeline: BaseEstimator, fold: int, train_indices: Union[np.ndarray, List],
                          test_indices: Union[np.ndarray, List],
                          add_pipeline_to_self: bool
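Since _write_run_summary appends one JSON object per line to run_summary.txt, the file can be read back as JSON lines. A minimal sketch, assuming only the format written above (the helper name is ours, not part of the PR):

    import json
    import os
    from typing import Any, Dict, List

    def read_run_summaries(temporary_directory: str) -> List[Dict[str, Any]]:
        """Parse the run_summary.txt that _write_run_summary appends to."""
        summaries = []
        with open(os.path.join(temporary_directory, 'run_summary.txt')) as file:
            for line in file:
                summaries.append(json.loads(line))
        return summaries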
@@ -1,4 +1,3 @@
-from copy import copy
 import warnings
 from abc import ABCMeta
 from collections import Counter
@@ -297,7 +296,7 @@ def _get_hyperparameter_search_space(self,
         """
         raise NotImplementedError()

-    def _add_forbidden_conditions(self, cs):
+    def _add_forbidden_conditions(self, cs: ConfigurationSpace) -> ConfigurationSpace:
         """
         Add forbidden conditions to ensure valid configurations.
         Currently, Learned Entity Embedding is only valid when encoder is one hot encoder
Review comment: Based on the changes introduced in the PR, I think the first condition mentioned in the docstring regarding Learned Entity Embedding should be removed.
@@ -310,33 +309,6 @@ def _add_forbidden_conditions(self, cs):

         """

-        # Learned Entity Embedding is only valid when encoder is one hot encoder
-        if 'network_embedding' in self.named_steps.keys() and 'encoder' in self.named_steps.keys():
-            embeddings = cs.get_hyperparameter('network_embedding:__choice__').choices
-            if 'LearnedEntityEmbedding' in embeddings:
-                encoders = cs.get_hyperparameter('encoder:__choice__').choices
-                possible_default_embeddings = copy(list(embeddings))
-                del possible_default_embeddings[possible_default_embeddings.index('LearnedEntityEmbedding')]
-
-                for encoder in encoders:
-                    if encoder == 'OneHotEncoder':
-                        continue
-                    while True:
-                        try:
-                            cs.add_forbidden_clause(ForbiddenAndConjunction(
-                                ForbiddenEqualsClause(cs.get_hyperparameter(
-                                    'network_embedding:__choice__'), 'LearnedEntityEmbedding'),
-                                ForbiddenEqualsClause(cs.get_hyperparameter('encoder:__choice__'), encoder)
-                            ))
-                            break
-                        except ValueError:
-                            # change the default and try again
-                            try:
-                                default = possible_default_embeddings.pop()
-                            except IndexError:
-                                raise ValueError("Cannot find a legal default configuration")
-                            cs.get_hyperparameter('network_embedding:__choice__').default_value = default
-
         # Disable CyclicLR until todo is completed.
         if 'lr_scheduler' in self.named_steps.keys() and 'trainer' in self.named_steps.keys():
             trainers = cs.get_hyperparameter('trainer:__choice__').choices
@@ -347,7 +319,8 @@ def _add_forbidden_conditions(self, cs):
                 if cyclic_lr_name in available_schedulers:
                     # disable snapshot ensembles and stochastic weight averaging
                     snapshot_ensemble_hyperparameter = cs.get_hyperparameter(f'trainer:{trainer}:use_snapshot_ensemble')
-                    if hasattr(snapshot_ensemble_hyperparameter, 'choices') and True in snapshot_ensemble_hyperparameter.choices:
+                    if hasattr(snapshot_ensemble_hyperparameter, 'choices') and \
+                            True in snapshot_ensemble_hyperparameter.choices:
                         cs.add_forbidden_clause(ForbiddenAndConjunction(
                             ForbiddenEqualsClause(snapshot_ensemble_hyperparameter, True),
                             ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
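The removed block above also illustrates a ConfigSpace quirk worth keeping in mind: add_forbidden_clause raises ValueError when the clause forbids the space's current default configuration, which is why the old code changed the default and retried. A standalone sketch of that pattern under plain ConfigSpace (the hyperparameter names are illustrative, not autoPyTorch's):

    from ConfigSpace.configuration_space import ConfigurationSpace
    from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause
    from ConfigSpace.hyperparameters import CategoricalHyperparameter

    cs = ConfigurationSpace()
    embedding = CategoricalHyperparameter('embedding', ['LearnedEntityEmbedding', 'NoEmbedding'],
                                          default_value='LearnedEntityEmbedding')
    encoder = CategoricalHyperparameter('encoder', ['OneHotEncoder', 'NoEncoder'],
                                        default_value='NoEncoder')
    cs.add_hyperparameters([embedding, encoder])

    forbidden = ForbiddenAndConjunction(
        ForbiddenEqualsClause(embedding, 'LearnedEntityEmbedding'),
        ForbiddenEqualsClause(encoder, 'NoEncoder'),
    )
    try:
        cs.add_forbidden_clause(forbidden)
    except ValueError:
        # The defaults (LearnedEntityEmbedding + NoEncoder) hit the clause;
        # change the default and add the clause again, as the removed code did.
        embedding.default_value = 'NoEmbedding'
        cs.add_forbidden_clause(forbidden)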
@@ -549,7 +522,6 @@ def _check_search_space_updates(self, include: Optional[Dict[str, Any]],
                                 node_hyperparameters,
                                 update.hyperparameter))

-
     def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]]
                             ) -> List[Tuple[str, PipelineStepType]]:
         """
@@ -0,0 +1,83 @@
from typing import Any, Dict, List, Optional, Union

from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import (
    UniformIntegerHyperparameter,
)

import numpy as np

from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.base_tabular_preprocessing import \
    autoPyTorchTabularPreprocessingComponent
from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter


class ColumnSplitter(autoPyTorchTabularPreprocessingComponent):
    """
    Splits categorical columns into embed columns and encode columns,
    depending on the number of categories per column.
    """
    def __init__(
        self,
        min_categories_for_embedding: float = 5,
        random_state: Optional[np.random.RandomState] = None
    ):
        self.min_categories_for_embedding = min_categories_for_embedding
        self.random_state = random_state

        self.special_feature_types: Dict[str, List] = dict(encode_columns=[], embed_columns=[])
        self.num_categories_per_col: Optional[List] = None
        super().__init__()

    def fit(self, X: Dict[str, Any], y: Optional[Any] = None) -> 'ColumnSplitter':

        self.check_requirements(X, y)

        if len(X['dataset_properties']['categorical_columns']) > 0:
            self.num_categories_per_col = []
            for categories_per_column, column in zip(X['dataset_properties']['num_categories_per_col'],
                                                     X['dataset_properties']['categorical_columns']):
                if (
                    categories_per_column >= self.min_categories_for_embedding
                ):
                    self.special_feature_types['embed_columns'].append(column)
                    # we only care about the categories for columns to be embedded
                    self.num_categories_per_col.append(categories_per_column)
                else:
                    self.special_feature_types['encode_columns'].append(column)

        return self

    def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
        if self.num_categories_per_col is not None:
Review comment: It will be None when there were no categorical columns, see line 38.
Reply: Hm, but line 38 initializes it to an empty list. I'm mentioning this because I thought in line 53 we check if there are columns to be embedded; currently the if condition evaluates to true both for embedded and encoded columns.
            # update such that only n categories for embedding columns is passed
            X['dataset_properties']['num_categories_per_col'] = self.num_categories_per_col
        X.update(self.special_feature_types)
        return X

    @staticmethod
    def get_properties(
        dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None
    ) -> Dict[str, Union[str, bool]]:

        return {
            'shortname': 'ColumnSplitter',
            'name': 'Column Splitter',
            'handles_sparse': False,
        }

    @staticmethod
    def get_hyperparameter_search_space(
        dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None,
        min_categories_for_embedding: HyperparameterSearchSpace = HyperparameterSearchSpace(
            hyperparameter="min_categories_for_embedding",
            value_range=(3, 7),
            default_value=3,
            log=True),
    ) -> ConfigurationSpace:
        cs = ConfigurationSpace()

        add_hyperparameter(cs, min_categories_for_embedding, UniformIntegerHyperparameter)

        return cs
Review comment: The method argument is not used, I believe it could be removed.
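To make the fit/transform contract of ColumnSplitter concrete, a hypothetical usage sketch (the input dict is shaped after what fit() reads above; in practice check_requirements may demand further keys):

    splitter = ColumnSplitter(min_categories_for_embedding=5)
    X = {
        'dataset_properties': {
            'categorical_columns': [0, 1, 2],
            'num_categories_per_col': [2, 8, 30],
        }
    }
    X = splitter.fit(X).transform(X)
    # Column 0 (2 categories) falls below the threshold and goes to
    # 'encode_columns'; columns 1 and 2 go to 'embed_columns', and
    # X['dataset_properties']['num_categories_per_col'] becomes [8, 30].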