Skip to content

Commit

Permalink
Updated the utils
Browse files Browse the repository at this point in the history
  • Loading branch information
prithagupta committed Aug 23, 2024
1 parent 0afec08 commit eabd75e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 84 deletions.
85 changes: 2 additions & 83 deletions autoqild/utilities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def check_and_delete_corrupt_h5_file(file_path, logger):
logger.info(f"File does not exist '{basename}'")


def standardize_features(x_train, x_test):
def standardize_features(x_train, x_test, scaler=RobustScaler):
"""Standardize the features in the training and test sets using
RobustScaler as a default.
Expand All @@ -246,89 +246,8 @@ def standardize_features(x_train, x_test):
x_test : array-like
Standardized test set features.
"""
standardize = Standardize()
standardize = scaler()
x_train = standardize.fit_transform(x_train)
x_test = standardize.transform(x_test)
return x_train, x_test


class Standardize:
"""A class for standardizing features using a specified scaler.
Parameters
----------
scalar : object, optional
The scaling class to use (default is "RobustScaler").
Attributes
----------
n_features : list or None
The list of feature names if `X` is a dictionary.
scalars : dict
A dictionary of scalers for each feature if `X` is a dictionary.
"""

def __init__(self, scalar=RobustScaler):
self.scalar = scalar
self.n_features = None
self.scalars = dict()

def fit(self, X):
"""Fit the scaler to the data.
Parameters
----------
X : array-like or dict
The data to fit the scaler on.
Returns
-------
self : object
Fitted scaler.
"""
if isinstance(X, dict):
self.n_features = list(X.keys())
for k, x in X.items():
scalar = self.scalar()
self.scalars[k] = scalar.fit(x)
if isinstance(X, (np.ndarray, np.generic)):
self.scalar = self.scalar()
self.scalar.fit(X)
self.n_features = X.shape[-1]

def transform(self, X):
"""Apply the scaling transformation to the data.
Parameters
----------
X : array-like or dict
The data to transform.
Returns
-------
X : array-like or dict
The transformed data.
"""
if isinstance(X, dict):
for n in self.n_features:
X[n] = self.scalars[n].transform(X[n])
if isinstance(X, (np.ndarray, np.generic)):
X = self.scalar.transform(X)
return X

def fit_transform(self, X):
"""Fit the scaler and transform the data.
Parameters
----------
X : array-like or dict
The data to fit and transform.
Returns
-------
X : array-like or dict
The transformed data.
"""
self.fit(X)
X = self.transform(X)
return X
2 changes: 1 addition & 1 deletion docs/source/references.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ List of references for the implemented learning algorithms, AutoML tools and bas
-------------------------
🚀 Baseline MI Estimators
-------------------------
- `Gaussain Mixture Models <https://ieeexplore.ieee.org/document/6889561>`_: Polo et al. (2022)
- `Gaussain Mixture Model (GMM) <https://ieeexplore.ieee.org/document/6889561>`_: Polo et al. (2022)
- `Mutual Information Neural Estimation (MINE) <https://proceedings.mlr.press/v80/belghazi18a/belghazi18a.pdf>`_: Belghazi et al. (2018)
- `PC-softmax <https://arxiv.org/abs/1911.10688>`_: Qin et al. (2020)

0 comments on commit eabd75e

Please sign in to comment.