Commit

[FIX] SWA and SE with non cyclic schedulers (#395)
* Enable learned embeddings, fix bug with non cyclic schedulers

* add forbidden condition cyclic lr

* refactor base_pipeline forbidden conditions

* Apply suggestions from code review

Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
ravinkohli and nabenabe0928 committed Mar 9, 2022
1 parent 45a7043 commit 9d622db
Showing 6 changed files with 156 additions and 142 deletions.
63 changes: 63 additions & 0 deletions autoPyTorch/pipeline/base_pipeline.py
@@ -1,10 +1,12 @@
from copy import copy
import warnings
from abc import ABCMeta
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple, Union

from ConfigSpace import Configuration
from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause

import numpy as np

@@ -295,6 +297,67 @@ def _get_hyperparameter_search_space(self,
"""
raise NotImplementedError()

def _add_forbidden_conditions(self, cs: ConfigurationSpace) -> ConfigurationSpace:
"""
Add forbidden conditions to ensure valid configurations.
Currently, LearnedEntityEmbedding is only valid when the encoder is a
one-hot encoder, and CyclicLR is disabled when stochastic weight averaging
or snapshot ensembling is used.
Args:
cs (ConfigurationSpace):
Configuration space to which the forbidden conditions are added.
Returns:
ConfigurationSpace:
The configuration space with the forbidden conditions added.
"""

# Learned Entity Embedding is only valid when encoder is one hot encoder
if 'network_embedding' in self.named_steps.keys() and 'encoder' in self.named_steps.keys():
embeddings = cs.get_hyperparameter('network_embedding:__choice__').choices
if 'LearnedEntityEmbedding' in embeddings:
encoders = cs.get_hyperparameter('encoder:__choice__').choices
possible_default_embeddings = copy(list(embeddings))
del possible_default_embeddings[possible_default_embeddings.index('LearnedEntityEmbedding')]

for encoder in encoders:
if encoder == 'OneHotEncoder':
continue
while True:
try:
cs.add_forbidden_clause(ForbiddenAndConjunction(
ForbiddenEqualsClause(cs.get_hyperparameter(
'network_embedding:__choice__'), 'LearnedEntityEmbedding'),
ForbiddenEqualsClause(cs.get_hyperparameter('encoder:__choice__'), encoder)
))
break
except ValueError:
# change the default and try again
try:
default = possible_default_embeddings.pop()
except IndexError:
raise ValueError("Cannot find a legal default configuration")
cs.get_hyperparameter('network_embedding:__choice__').default_value = default

# Disable snapshot ensembling and stochastic weight averaging when CyclicLR is used, until the TODO below is completed.
if 'lr_scheduler' in self.named_steps.keys() and 'trainer' in self.named_steps.keys():
trainers = cs.get_hyperparameter('trainer:__choice__').choices
for trainer in trainers:
available_schedulers = cs.get_hyperparameter('lr_scheduler:__choice__').choices
# TODO: update cyclic lr to use n_restarts and adjust according to batch size
cyclic_lr_name = 'CyclicLR'
if cyclic_lr_name in available_schedulers:
# disable snapshot ensembles and stochastic weight averaging
cs.add_forbidden_clause(ForbiddenAndConjunction(
ForbiddenEqualsClause(cs.get_hyperparameter(
f'trainer:{trainer}:use_snapshot_ensemble'), True),
ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
))
cs.add_forbidden_clause(ForbiddenAndConjunction(
ForbiddenEqualsClause(cs.get_hyperparameter(
f'trainer:{trainer}:use_stochastic_weight_averaging'), True),
ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
))
return cs

def __repr__(self) -> str:
"""Retrieves a str representation of the current pipeline
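
To make the new helper easier to follow, here is a minimal, self-contained ConfigSpace sketch of the same mechanism: a ForbiddenAndConjunction of ForbiddenEqualsClause objects prevents SWA from being sampled together with CyclicLR. The hyperparameter names are simplified placeholders, not the pipeline's real keys.

    from ConfigSpace.configuration_space import ConfigurationSpace
    from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause
    from ConfigSpace.hyperparameters import CategoricalHyperparameter

    cs = ConfigurationSpace()
    scheduler = CategoricalHyperparameter('lr_scheduler', ['CyclicLR', 'CosineAnnealingLR'],
                                          default_value='CosineAnnealingLR')
    use_swa = CategoricalHyperparameter('use_swa', [True, False], default_value=False)
    cs.add_hyperparameters([scheduler, use_swa])

    # Forbid sampling SWA together with CyclicLR; the default configuration
    # (CosineAnnealingLR, use_swa=False) stays legal, so no ValueError is raised.
    cs.add_forbidden_clause(ForbiddenAndConjunction(
        ForbiddenEqualsClause(use_swa, True),
        ForbiddenEqualsClause(scheduler, 'CyclicLR'),
    ))

    # Sampled configurations never combine the two forbidden values.
    for config in cs.sample_configuration(10):
        assert not (config['use_swa'] and config['lr_scheduler'] == 'CyclicLR')

This also shows why _add_forbidden_conditions retries with a different default embedding: add_forbidden_clause raises a ValueError when the current default configuration would violate the new clause.
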
103 changes: 48 additions & 55 deletions autoPyTorch/pipeline/components/setup/network_embedding/__init__.py
@@ -148,71 +148,64 @@ def get_hyperparameter_search_space(
if default is None:
defaults = [
'NoEmbedding',
# 'LearnedEntityEmbedding',
'LearnedEntityEmbedding',
]
for default_ in defaults:
if default_ in available_embedding:
default = default_
break

# Restrict embedding to NoEmbedding until preprocessing is fixed
embedding = CSH.CategoricalHyperparameter('__choice__',
['NoEmbedding'],
default_value=default)
if isinstance(dataset_properties['categorical_columns'], list):
categorical_columns = dataset_properties['categorical_columns']
else:
categorical_columns = []

updates = self._get_search_space_updates()
if '__choice__' in updates.keys():
choice_hyperparameter = updates['__choice__']
if not set(choice_hyperparameter.value_range).issubset(available_embedding):
raise ValueError("Expected given update for {} to have "
"choices in {} got {}".format(self.__class__.__name__,
available_embedding,
choice_hyperparameter.value_range))
if len(categorical_columns) == 0:
assert len(choice_hyperparameter.value_range) == 1
if 'NoEmbedding' not in choice_hyperparameter.value_range:
raise ValueError("Provided {} in choices, however, the dataset "
"is incompatible with it".format(choice_hyperparameter.value_range))
embedding = CSH.CategoricalHyperparameter('__choice__',
choice_hyperparameter.value_range,
default_value=choice_hyperparameter.default_value)
else:

if len(categorical_columns) == 0:
default = 'NoEmbedding'
if include is not None and default not in include:
raise ValueError("Provided {} in include, however, the dataset "
"is incompatible with it".format(include))
embedding = CSH.CategoricalHyperparameter('__choice__',
['NoEmbedding'],
default_value=default)
else:
embedding = CSH.CategoricalHyperparameter('__choice__',
list(available_embedding.keys()),
default_value=default)

cs.add_hyperparameter(embedding)
for name in embedding.choices:
updates = self._get_search_space_updates(prefix=name)
config_space = available_embedding[name].get_hyperparameter_search_space(dataset_properties, # type: ignore
**updates)
parent_hyperparameter = {'parent': embedding, 'value': name}
cs.add_configuration_space(
name,
config_space,
parent_hyperparameter=parent_hyperparameter
)

self.configuration_space_ = cs
self.dataset_properties_ = dataset_properties
return cs
# categorical_columns = dataset_properties['categorical_columns'] \
# if isinstance(dataset_properties['categorical_columns'], List) else []

# updates = self._get_search_space_updates()
# if '__choice__' in updates.keys():
# choice_hyperparameter = updates['__choice__']
# if not set(choice_hyperparameter.value_range).issubset(available_embedding):
# raise ValueError("Expected given update for {} to have "
# "choices in {} got {}".format(self.__class__.__name__,
# available_embedding,
# choice_hyperparameter.value_range))
# if len(categorical_columns) == 0:
# assert len(choice_hyperparameter.value_range) == 1
# if 'NoEmbedding' not in choice_hyperparameter.value_range:
# raise ValueError("Provided {} in choices, however, the dataset "
# "is incompatible with it".format(choice_hyperparameter.value_range))
# embedding = CSH.CategoricalHyperparameter('__choice__',
# choice_hyperparameter.value_range,
# default_value=choice_hyperparameter.default_value)
# else:

# if len(categorical_columns) == 0:
# default = 'NoEmbedding'
# if include is not None and default not in include:
# raise ValueError("Provided {} in include, however, the dataset "
# "is incompatible with it".format(include))
# embedding = CSH.CategoricalHyperparameter('__choice__',
# ['NoEmbedding'],
# default_value=default)
# else:
# embedding = CSH.CategoricalHyperparameter('__choice__',
# list(available_embedding.keys()),
# default_value=default)

# cs.add_hyperparameter(embedding)
# for name in embedding.choices:
# updates = self._get_search_space_updates(prefix=name)
# config_space = available_embedding[name].get_hyperparameter_search_space(
# dataset_properties, # type: ignore
# **updates)
# parent_hyperparameter = {'parent': embedding, 'value': name}
# cs.add_configuration_space(
# name,
# config_space,
# parent_hyperparameter=parent_hyperparameter
# )

# self.configuration_space_ = cs
# self.dataset_properties_ = dataset_properties
# return cs

def transform(self, X: np.ndarray) -> np.ndarray:
assert self.choice is not None, "Cannot call transform before the object is initialized"
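
The per-choice subspaces above rely on ConfigSpace's parent_hyperparameter mechanism. The sketch below (with illustrative hyperparameter names and ranges) shows the same pattern: a child subspace that only becomes active when a particular value of __choice__ is selected.

    from ConfigSpace.configuration_space import ConfigurationSpace
    from ConfigSpace.hyperparameters import CategoricalHyperparameter, UniformIntegerHyperparameter

    cs = ConfigurationSpace()
    embedding = CategoricalHyperparameter('__choice__',
                                          ['NoEmbedding', 'LearnedEntityEmbedding'],
                                          default_value='NoEmbedding')
    cs.add_hyperparameter(embedding)

    # Child subspace, only active when LearnedEntityEmbedding is chosen.
    child_cs = ConfigurationSpace()
    child_cs.add_hyperparameter(UniformIntegerHyperparameter('num_embedding_dims', 2, 8))
    cs.add_configuration_space(
        'LearnedEntityEmbedding',
        child_cs,
        parent_hyperparameter={'parent': embedding, 'value': 'LearnedEntityEmbedding'}
    )

    # Configurations with __choice__ == 'NoEmbedding' simply omit the child hyperparameter.
    print(cs.sample_configuration())
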
71 changes: 40 additions & 31 deletions autoPyTorch/pipeline/components/training/trainer/base_trainer.py
@@ -334,6 +334,35 @@ def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
"""
pass

def _swa_update(self) -> None:
"""
Perform an update of the stochastic weight averaging (SWA) model
"""
if self.swa_model is None:
raise ValueError("SWA model cannot be none when stochastic weight averaging is enabled")
self.swa_model.update_parameters(self.model)
self.swa_updated = True

def _se_update(self, epoch: int) -> None:
"""
Add the latest model (or the SWA model on the final epoch) to the model snapshot ensemble
Args:
epoch (int):
current epoch
"""
if self.model_snapshots is None:
raise ValueError("model snapshots cannot be None when snapshot ensembling is enabled")
is_last_epoch = (epoch == self.budget_tracker.max_epochs)
if is_last_epoch and self.use_stochastic_weight_averaging:
model_copy = deepcopy(self.swa_model)
else:
model_copy = deepcopy(self.model)

assert model_copy is not None
model_copy.cpu()
self.model_snapshots.append(model_copy)
self.model_snapshots = self.model_snapshots[-self.se_lastk:]

def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
"""
Optional place holder for AutoPytorch Extensions.
@@ -344,39 +373,19 @@ def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
if X['is_cyclic_scheduler']:
if hasattr(self.scheduler, 'T_cur') and self.scheduler.T_cur == 0 and epoch != 1:
if self.use_stochastic_weight_averaging:
assert self.swa_model is not None, "SWA model can't be none when" \
" stochastic weight averaging is enabled"
self.swa_model.update_parameters(self.model)
self.swa_updated = True
self._swa_update()
if self.use_snapshot_ensemble:
assert self.model_snapshots is not None, "model snapshots container can't be " \
"none when snapshot ensembling is enabled"
is_last_epoch = (epoch == self.budget_tracker.max_epochs)
if is_last_epoch and self.use_stochastic_weight_averaging:
model_copy = deepcopy(self.swa_model)
else:
model_copy = deepcopy(self.model)

assert model_copy is not None
model_copy.cpu()
self.model_snapshots.append(model_copy)
self.model_snapshots = self.model_snapshots[-self.se_lastk:]
self._se_update(epoch=epoch)
else:
if epoch > self._budget_threshold:
if self.use_stochastic_weight_averaging:
assert self.swa_model is not None, "SWA model can't be none when" \
" stochastic weight averaging is enabled"
self.swa_model.update_parameters(self.model)
self.swa_updated = True
if self.use_snapshot_ensemble:
assert self.model_snapshots is not None, "model snapshots container can't be " \
"none when snapshot ensembling is enabled"
model_copy = deepcopy(self.swa_model) if self.use_stochastic_weight_averaging \
else deepcopy(self.model)
assert model_copy is not None
model_copy.cpu()
self.model_snapshots.append(model_copy)
self.model_snapshots = self.model_snapshots[-self.se_lastk:]
if epoch > self._budget_threshold and self.use_stochastic_weight_averaging:
self._swa_update()

if (
self.use_snapshot_ensemble
and self.budget_tracker.max_epochs is not None
and epoch > (self.budget_tracker.max_epochs - self.se_lastk)
):
self._se_update(epoch=epoch)
return False

def _scheduler_step(
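
For readers less familiar with the two techniques that _swa_update and _se_update encapsulate, the following rough, standalone sketch (placeholder model, no real training step) shows how stochastic weight averaging and a last-k snapshot ensemble are typically maintained with torch.optim.swa_utils:

    from copy import deepcopy

    from torch import nn
    from torch.optim.swa_utils import AveragedModel

    model = nn.Linear(10, 2)            # placeholder network
    swa_model = AveragedModel(model)    # running average of the weights
    model_snapshots = []                # snapshot-ensemble container
    se_lastk = 3                        # keep only the last k snapshots
    max_epochs = 10

    for epoch in range(1, max_epochs + 1):
        # ... a real training step for `model` would go here ...

        # SWA update: fold the current weights into the running average.
        swa_model.update_parameters(model)

        # Snapshot-ensemble update: store a CPU copy and keep the last k members.
        member = deepcopy(swa_model if epoch == max_epochs else model)
        member.cpu()
        model_snapshots.append(member)
        model_snapshots = model_snapshots[-se_lastk:]

In the trainer itself these updates are gated further: with a cyclic scheduler they run only at cycle restarts (T_cur == 0), while with other schedulers SWA updates start once epoch exceeds self._budget_threshold and snapshots are taken only during the last se_lastk epochs, as the diff above shows.
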
1 change: 1 addition & 0 deletions autoPyTorch/pipeline/image_classification.py
@@ -156,6 +156,7 @@ def _get_hyperparameter_search_space(self,

# Here we add custom code to forbid combinations that are
# not a valid configuration
cs = self._add_forbidden_conditions(cs)

self.configuration_space = cs
self.dataset_properties = dataset_properties
31 changes: 3 additions & 28 deletions autoPyTorch/pipeline/tabular_classification.py
@@ -3,7 +3,6 @@
from typing import Any, Dict, List, Optional, Tuple, Union

from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause

import numpy as np

@@ -261,33 +260,9 @@ def _get_hyperparameter_search_space(self,
cs=cs, dataset_properties=dataset_properties,
exclude=exclude, include=include, pipeline=self.steps)

# Here we add custom code, that is used to ensure valid configurations, For example
# Learned Entity Embedding is only valid when encoder is one hot encoder
if 'network_embedding' in self.named_steps.keys() and 'encoder' in self.named_steps.keys():
embeddings = cs.get_hyperparameter('network_embedding:__choice__').choices
if 'LearnedEntityEmbedding' in embeddings:
encoders = cs.get_hyperparameter('encoder:__choice__').choices
possible_default_embeddings = copy.copy(list(embeddings))
del possible_default_embeddings[possible_default_embeddings.index('LearnedEntityEmbedding')]

for encoder in encoders:
if encoder == 'OneHotEncoder':
continue
while True:
try:
cs.add_forbidden_clause(ForbiddenAndConjunction(
ForbiddenEqualsClause(cs.get_hyperparameter(
'network_embedding:__choice__'), 'LearnedEntityEmbedding'),
ForbiddenEqualsClause(cs.get_hyperparameter('encoder:__choice__'), encoder)
))
break
except ValueError:
# change the default and try again
try:
default = possible_default_embeddings.pop()
except IndexError:
raise ValueError("Cannot find a legal default configuration")
cs.get_hyperparameter('network_embedding:__choice__').default_value = default
# Here we add custom code to forbid combinations that are
# not a valid configuration
cs = self._add_forbidden_conditions(cs)

self.configuration_space = cs
self.dataset_properties = dataset_properties
