Skip to content

Commit

Permalink
test fix in progress
Browse files · Browse the repository at this point in the history
  • Loading branch information
ravinkohli committed Aug 16, 2022
1 parent d6bb8c8 commit d4717fb
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 15 deletions.
7 changes: 3 additions & 4 deletions autoPyTorch/data/tabular_feature_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,9 @@ class TabularFeatureValidator(BaseFeatureValidator):
transformer.
Attributes:
categories (List[List[str]]):
List for which an element at each index is a
list containing the categories for the respective
categorical column.
num_categories_per_col (List[int]):
List for which an element at each index is the number
of categories for the respective categorical column.
transformed_columns (List[str])
List of columns that were transformed.
column_transformer (Optional[BaseEstimator])
Expand Down
4 changes: 2 additions & 2 deletions autoPyTorch/datasets/time_series_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def __init__(self,
self.num_features: int = self.validator.feature_validator.num_features # type: ignore[assignment]
self.num_targets: int = self.validator.target_validator.out_dimensionality # type: ignore[assignment]

self.categories = self.validator.feature_validator.categories
self.num_categories_per_col = self.validator.feature_validator.num_categories_per_col

self.feature_shapes = self.validator.feature_shapes
self.feature_names = tuple(self.validator.feature_names)
Expand Down Expand Up @@ -1072,7 +1072,7 @@ def get_required_dataset_info(self) -> Dict[str, Any]:
'categorical_features': self.categorical_features,
'numerical_columns': self.numerical_columns,
'categorical_columns': self.categorical_columns,
'categories': self.categories,
'num_categories_per_col': self.num_categories_per_col,
})
return info

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class ColumnSplitter(autoPyTorchTabularPreprocessingComponent):
"""
Removes features that have the same value in the training data.
Splits categorical columns into embed or encode columns based on a hyperparameter.
"""
def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ def __init__(self,
def fit(self, X: Dict[str, Any], y: Any = None) -> TimeSeriesBaseEncoder:
OneHotEncoder.fit(self, X, y)
categorical_columns = X['dataset_properties']['categorical_columns']
n_features_cat = X['dataset_properties']['categories']
num_categories_per_col = X['dataset_properties']['num_categories_per_col']
feature_names = X['dataset_properties']['feature_names']
feature_shapes = X['dataset_properties']['feature_shapes']

if len(n_features_cat) == 0:
n_features_cat = self.preprocessor['categorical'].categories # type: ignore
if len(num_categories_per_col) == 0:
num_categories_per_col = [len(cat) for cat in self.preprocessor['categorical'].categories] # type: ignore
for i, cat_column in enumerate(categorical_columns):
feature_shapes[feature_names[cat_column]] = len(n_features_cat[i])
feature_shapes[feature_names[cat_column]] = num_categories_per_col[i]
self.feature_shapes = feature_shapes
return self

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ def __init__(self) -> None:
super(TimeSeriesBaseEncoder, self).__init__()
self.add_fit_requirements([
FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True),
FitRequirement('categories', (List,), user_defined=True, dataset_property=True),
FitRequirement('num_categories_per_col', (List,), user_defined=True, dataset_property=True),
FitRequirement('feature_names', (tuple,), user_defined=True, dataset_property=True),
FitRequirement('feature_shapes', (Dict, ), user_defined=True, dataset_property=True),
])
self.feature_shapes: Union[Dict[str, int]] = {}
self.feature_shapes: Dict[str, int] = {}

def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
"""
Expand Down
9 changes: 7 additions & 2 deletions autoPyTorch/pipeline/components/training/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,13 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
raise RuntimeError("Budget exhausted without finishing an epoch.")

if self.choice.use_stochastic_weight_averaging and self.choice.swa_updated:
use_double = 'float64' in X['preprocessed_dtype']
# By default, we assume the data is double. Only if the data was preprocessed,
# we check the dtype and use it accordingly
preprocessed_dtype = X.get('preprocessed_dtype', None)
if preprocessed_dtype is None:
use_double = True
else:
use_double = 'float64' in preprocessed_dtype

# update batch norm statistics
swa_model = self.choice.swa_model.double() if use_double else self.choice.swa_model
Expand All @@ -458,7 +464,6 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
# we update only the last network which pertains to the stochastic weight averaging model
snapshot_model = self.choice.model_snapshots[-1].double() if use_double else self.choice.model_snapshots[-1]
swa_utils.update_bn(X['train_data_loader'], snapshot_model)
update_model_state_dict_from_swa(X['network_snapshots'][-1], self.choice.swa_model.state_dict())

# wrap up -- add score if not evaluating every epoch
if not self.eval_valid_each_epoch(X):
Expand Down

0 comments on commit d4717fb

Please sign in to comment.