Skip to content

Commit

Permalink
Aggressive removal of code
Browse files Browse the repository at this point in the history
  • Loading branch information
Neeratyoy committed Aug 29, 2024
1 parent bb7abf2 commit eb02fa2
Show file tree
Hide file tree
Showing 14 changed files with 457 additions and 1,282 deletions.
10 changes: 5 additions & 5 deletions neps/optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

from .base_optimizer import BaseOptimizer
from .bayesian_optimization.cost_cooling import CostCooling
from .bayesian_optimization.mf_tpe import MultiFidelityPriorWeightedTreeParzenEstimator
from .bayesian_optimization.optimizer import BayesianOptimization
from .grid_search.optimizer import GridSearch
from .multi_fidelity.dyhpo import MFEIBO
from .multi_fidelity.ifbo import IFBO
from .multi_fidelity.hyperband import (
MOBSTER,
AsynchronousHyperband,
Expand Down Expand Up @@ -41,10 +40,11 @@
"asha": AsynchronousSuccessiveHalving,
"hyperband": Hyperband,
"asha_prior": AsynchronousSuccessiveHalvingWithPriors,
"multifidelity_tpe": MultiFidelityPriorWeightedTreeParzenEstimator,
"hyperband_custom_default": HyperbandCustomDefault,
"priorband": PriorBand,
"priorband_bo": partial(PriorBand, model_based=True),
"priorband_asha": PriorBandAsha,
"priorband_asha_hyperband": PriorBandAshaHB,
"mobster": MOBSTER,
"ifbo_ei": MFEIBO,
"ifbo": partial(MFEIBO, acquisition="MFPI-random"),
"ifbo": IFBO,
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
from neps.optimizers.bayesian_optimization.acquisition_functions.ei import (
ComprehensiveExpectedImprovement,
)
from neps.optimizers.bayesian_optimization.acquisition_functions.mf_ei import MFEI
from neps.optimizers.bayesian_optimization.acquisition_functions.mf_pi import MFPI_Random
from neps.optimizers.bayesian_optimization.acquisition_functions.ucb import (
UpperConfidenceBound,
MF_UCB,
)
from neps.optimizers.bayesian_optimization.acquisition_functions.prior_weighted import (
DecayingPriorWeightedAcquisition,
Expand All @@ -36,11 +34,6 @@
in_fill="posterior",
augmented_ei=True,
),
"MFEI": partial(
MFEI,
in_fill="best",
augmented_ei=False,
),
"MFPI-random": partial(
MFPI_Random,
in_fill="best",
Expand All @@ -50,17 +43,13 @@
UpperConfidenceBound,
maximize=False,
),
"MF-UCB": partial(
MF_UCB,
maximize=False,
),
}

__all__ = [
"AcquisitionMapping",
"ComprehensiveExpectedImprovement",
"MFEI",
"UpperConfidenceBound",
"MF_UCB",
"DecayingPriorWeightedAcquisition",
"MFPI_Random",
"UCB",
]
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,6 @@ def preprocess_gp(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]:
x, inc_list = self.preprocess(x)
return x, inc_list

def preprocess_deep_gp(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]:
x, inc_list = self.preprocess(x)
x_lcs = []
for idx in x.index:
if idx in self.observations.df.index.levels[0]:
# TODO: Samir, check if `budget_id=None` is okay?
# budget_level = self.get_budget_level(x[idx])
# extracting the available/observed learning curve
lc = self.observations.extract_learning_curve(idx, budget_id=None)
else:
# initialize a learning curve with a placeholder
# This is later padded accordingly for the Conv1D layer
lc = []
x_lcs.append(lc)
self.surrogate_model.set_prediction_learning_curves(x_lcs)
return x, inc_list

def preprocess_pfn(self, x: pd.Series) -> Tuple[torch.Tensor, pd.Series, torch.Tensor]:
"""Prepares the configurations for appropriate EI calculation.
Expand All @@ -78,7 +61,7 @@ def preprocess_pfn(self, x: pd.Series) -> Tuple[torch.Tensor, pd.Series, torch.T
return _x, _x_tok, inc_list


# NOTE: the order of inheritance is important
# NOTE: the order of inheritance is important by MRO
class MFEI(MFStepBase, ComprehensiveExpectedImprovement):
def __init__(
self,
Expand Down Expand Up @@ -152,17 +135,12 @@ def preprocess(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]:
def eval(self, x: pd.Series, asscalar: bool = False) -> Tuple[np.ndarray, pd.Series]:
# deepcopy
_x = pd.Series([x.loc[idx].copy() for idx in x.index.values], index=x.index)
if self.surrogate_model_name == "pfn":
if self.surrogate_model_name == "ftpfn":
_x, _x_tok, inc_list = self.preprocess_pfn(
x.copy()
) # IMPORTANT change from vanilla-EI
ei = self.eval_pfn_ei(_x_tok, inc_list)
elif self.surrogate_model_name in ["deep_gp", "dpl"]:
_x, inc_list = self.preprocess_deep_gp(
_x
) # IMPORTANT change from vanilla-EI
ei = self.eval_gp_ei(_x.values.tolist(), inc_list)
elif self.surrogate_model_name == "gp":
elif self.surrogate_model_name in ["gp", "gp_hierarchy"]:
_x, inc_list = self.preprocess_gp(
_x
) # IMPORTANT change from vanilla-EI
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,10 @@ def eval(self, x: pd.Series, asscalar: bool = False) -> Tuple[np.ndarray, pd.Ser
_x = pd.Series([deepcopy(x.loc[idx]) for idx in x.index.values], index=x.index)
if self.surrogate_model_name == "ftpfn":
_x, _x_tok, inc_list = self.preprocess_pfn(
deepcopy(x.copy())
) # IMPORTANT change from vanilla-EI
pi = self.eval_pfn_pi(_x_tok, inc_list)
elif self.surrogate_model_name in ["deep_gp", "dpl"]:
_x, inc_list = self.preprocess_deep_gp(
_x
) # IMPORTANT change from vanilla-EI
pi = self.eval_gp_pi(_x.values.tolist(), inc_list)
elif self.surrogate_model_name == "gp":
pi = self.eval_pfn_pi(_x_tok, inc_list)
elif self.surrogate_model_name in ["gp", "gp_hierarchy"]:
_x, inc_list = self.preprocess_gp(
_x
) # IMPORTANT change from vanilla-EI
Expand Down
13 changes: 0 additions & 13 deletions neps/optimizers/bayesian_optimization/acquisition_functions/ucb.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,3 @@ def eval(
ucb_scores = ucb_scores.detach().numpy() * sign

return ucb_scores


class MF_UCB(UpperConfidenceBound):

def preprocess(self, x: Iterable) -> Iterable:
performances = self.observations.get_best_performance_for_each_budget()
pass

def eval(
self, x: Iterable, asscalar: bool = False
) -> Union[np.ndarray, torch.Tensor, float]:
x = self.preprocess(x)
return self.eval(x, asscalar=asscalar)
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,21 @@
from neps.optimizers.bayesian_optimization.acquisition_samplers.base_acq_sampler import AcquisitionSampler


class FreezeThawSampler(AcquisitionSampler):
SAMPLES_TO_DRAW = 100 # number of random samples to draw for optimizing acquisition function


SAMPLES_TO_DRAW = 100 # number of random samples to draw at lowest fidelity
class FreezeThawSampler(AcquisitionSampler):

def __init__(self, **kwargs):
def __init__(self, samples_to_draw: int = None, **kwargs):
super().__init__(**kwargs)
self.observations = None
self.b_step = None
self.n = None
self.pipeline_space = None
# args to manage tabular spaces/grid
self.is_tabular = False
self.is_tabular = False # flag is set by `set_state()`
self.sample_full_table = None
self.samples_to_draw = samples_to_draw if samples_to_draw is not None else SAMPLES_TO_DRAW
self.set_sample_full_tabular(True) # sets flag that samples full table

def set_sample_full_tabular(self, flag: bool=False):
Expand All @@ -34,7 +36,7 @@ def set_sample_full_tabular(self, flag: bool=False):
def _sample_new(
self, index_from: int, n: int = None, ignore_fidelity: bool = False
) -> pd.Series:
n = n if n is not None else self.SAMPLES_TO_DRAW
n = n if n is not None else self.samples_to_draw
new_configs = [
self.pipeline_space.sample(
patience=self.patience, user_priors=False, ignore_fidelity=ignore_fidelity
Expand All @@ -53,10 +55,10 @@ def _sample_new_unique(
patience: int = 10,
ignore_fidelity: bool = False,
) -> pd.Series:
n = n if n is not None else self.SAMPLES_TO_DRAW
n = n if n is not None else self.samples_to_draw
assert (
patience > 0 and n > 0
), "Patience and SAMPLES_TO_DRAW must be larger than 0"
), "Patience and `samples_to_draw` must be larger than 0"

existing_configs = self.observations.all_configs_list()
new_configs = []
Expand Down Expand Up @@ -108,7 +110,7 @@ def sample(
"""Samples a new set and returns the total set of observed + new configs."""
partial_configs = self.observations.get_partial_configs_at_max_seen()

_n = n if n is not None else self.SAMPLES_TO_DRAW
_n = n if n is not None else self.samples_to_draw
if self.is_tabular:
# handles tabular data such that the entire unseen set of configs from the
# table is considered to be the new set of candidates
Expand Down Expand Up @@ -168,7 +170,7 @@ def set_state(
self.pipeline_space = pipeline_space
self.observations = observations
self.b_step = b_step
self.n = n if n is not None else self.SAMPLES_TO_DRAW
self.n = n if n is not None else self.samples_to_draw
if (
hasattr(self.pipeline_space, "custom_grid_table")
and self.pipeline_space.custom_grid_table is not None
Expand Down
Loading

0 comments on commit eb02fa2

Please sign in to comment.