diff --git a/docs/analyse.md b/docs/analyse.md index cd78a14d..80020cdd 100644 --- a/docs/analyse.md +++ b/docs/analyse.md @@ -166,6 +166,165 @@ The scatter plot matrix view provides an in-depth analysis of pairwise relations ![hparam_loggings3](doc_images/tensorboard/tblogger_hparam3.jpg) +## Summary CSV + +The argument `post_run_summary` in `neps.run` allows for the automatic generation of CSV files after a run is complete. The new root directory after utilizing this argument will look like the following: + +``` +ROOT_DIRECTORY +├── results +│ └── config_1 +│ ├── config.yaml +│ ├── metadata.yaml +│ └── result.yaml +├── summary_csv +│ ├── config_data.csv +│ └── run_status.csv +├── all_losses_and_configs.txt +├── best_loss_trajectory.txt +└── best_loss_with_config_trajectory.txt +``` + +- *`config_data.csv`*: Contains all configuration details in CSV format, ordered by ascending `loss`. Details include configuration hyperparameters, any returned result from the `run_pipeline` function, and metadata information. + +- *`run_status.csv`*: Provides general run details, such as the number of sampled configs, best configs, number of failed configs, best loss, etc. + +## TensorBoard Integration + +### Introduction + +[TensorBoard](https://www.tensorflow.org/tensorboard) serves as a valuable tool for visualizing machine learning experiments, offering the ability to observe losses and metrics throughout the model training process. In NePS, we use this powerful tool to show metrics of configurations during training in addition to comparisons to different hyperparameters used in the search for better diagnosis of the model. + +### The Logging Function + +The `tblogger.log` function is invoked within the model's training loop to facilitate logging of key metrics. + +!!! tip + +``` +The logger function is primarily designed for implementation within the `run_pipeline` function during the training of the neural network. +``` + +- **Signature:** + +```python +tblogger.log( + loss: float, + current_epoch: int, + write_config_scalar: bool = False, + write_config_hparam: bool = True, + write_summary_incumbent: bool = False, + extra_data: dict | None = None +) +``` + +- **Parameters:** + - `loss` (float): The loss value to be logged. + - `current_epoch` (int): The current epoch or iteration number. + - `write_config_scalar` (bool, optional): Set to `True` for a live loss trajectory for each configuration. + - `write_config_hparam` (bool, optional): Set to `True` for live parallel coordinate, scatter plot matrix, and table view. + - `write_summary_incumbent` (bool, optional): Set to `True` for a live incumbent trajectory. + - `extra_data` (dict, optional): Additional data to be logged, provided as a dictionary. + +### Extra Custom Logging + +NePS provides dedicated functions for customized logging using the `extra_data` argument. + +!!! note "Custom Logging Instructions" + +``` +Name the dictionary keys as the names of the values you want to log and pass one of the following functions as the values for a successful logging process. +``` + +#### 1- Extra Scalar Logging + +Logs new scalar data during training. Uses `current_epoch` from the log function as its `global_step`. + +- **Signature:** + +```python +tblogger.scalar_logging(value: float) +``` + +- **Parameters:** + - `value` (float): Any scalar value to be logged at the current epoch of `tblogger.log` function. + +#### 2- Extra Image Logging + +Logs images during training. 
Images can be resized, randomly selected, and a specified number can be logged at specified intervals. Uses `current_epoch` from the log function as its `global_step`. + +- **Signature:** + +```python +tblogger.image_logging( + image: torch.Tensor, + counter: int = 1, + resize_images: list[None | int] | None = None, + random_images: bool = True, + num_images: int = 20, + seed: int | np.random.RandomState | None = None, +) +``` + +- **Parameters:** + - `image` (torch.Tensor): Image tensor to be logged. + - `counter` (int): Log images every counter epochs (i.e., when current_epoch % counter equals 0). + - `resize_images` (list of int, optional): List of integers for image sizes after resizing (default: \[32, 32\]). + - `random_images` (bool, optional): Images are randomly selected if True (default: True). + - `num_images` (int, optional): Number of images to log (default: 20). + - `seed` (int or np.random.RandomState or None, optional): Seed value or RandomState instance to control randomness and reproducibility (default: None). + +### Logging Example + +For illustration purposes, we have employed a straightforward example involving the tuning of hyperparameters for a model utilized in the classification of the MNIST dataset provided by [torchvision](https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html). + +You can find this example [here](https://github.com/automl/neps/blob/master/neps_examples/convenience/neps_tblogger_tutorial.py) + +!!! info "Important" +We have optimized the example for computational efficiency. If you wish to replicate the exact results showcased in the following section, we recommend the following modifications: + +``` +1- Increase maximum epochs from 2 to 10 + +2- Set the `write_summary_incumbent` argument to `True` + +3- Change the searcher from `random_search` to `bayesian_optimization` + +4- Increase the maximum evaluations before disabling `tblogger` from 2 to 14 + +5- Increase the maximum evaluations after disabling `tblogger` from 3 to 15 +``` + +### Visualization Results + +The following command will open a local host for TensorBoard visualizations, allowing you to view them either in real-time or after the run is complete. + +```bash +tensorboard --logdir path/to/root_directory +``` + +This image shows visualizations related to scalar values logged during training. Scalars typically include metrics such as loss, incumbent trajectory, a summary of losses for all configurations, and any additional data provided via the `extra_data` argument in the `tblogger.log` function. + +![scalar_loggings](doc_images/tensorboard/tblogger_scalar.jpg) + +This image represents visualizations related to logged images during training. It could include snapshots of input data, model predictions, or any other image-related information. In our case, we use images to depict instances of incorrect predictions made by the model. + +![image_loggings](doc_images/tensorboard/tblogger_image.jpg) + +The following images showcase visualizations related to hyperparameter logging in TensorBoard. These plots include three different views, providing insights into the relationship between different hyperparameters and their impact on the model. + +In the table view, you can explore hyperparameter configurations across five different trials. The table displays various hyperparameter values alongside corresponding evaluation metrics. 
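+
+For orientation, here is a minimal, hypothetical `run_pipeline` sketch showing the kind of `tblogger.log` call that produces the scalar, image, and hyperparameter views in this section. The import path and the placeholder training values are assumptions (follow the linked tutorial for the exact setup); the `tblogger.log`, `scalar_logging`, and `image_logging` signatures follow the ones documented above.
+
+```python
+import torch
+
+# Assumed import path -- consult the linked tutorial for the exact location.
+from neps.plot.tensorboard_eval import tblogger
+
+
+def run_pipeline(learning_rate: float, optimizer: str, epochs: int = 2) -> float:
+    loss = float("inf")
+    for epoch in range(epochs):
+        # Placeholders standing in for a real training/validation step.
+        loss = 1.0 / (epoch + 1)
+        wrong_preds = torch.rand(16, 1, 28, 28)  # e.g. misclassified MNIST images
+
+        tblogger.log(
+            loss=loss,
+            current_epoch=epoch,
+            write_config_scalar=True,   # live loss trajectory per configuration
+            write_config_hparam=True,   # parallel coordinate / scatter / table views
+            extra_data={
+                "train_loss": tblogger.scalar_logging(value=loss),
+                "miss_img": tblogger.image_logging(image=wrong_preds, counter=2, num_images=8),
+            },
+        )
+    return loss
+```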
+ +![hparam_loggings1](doc_images/tensorboard/tblogger_hparam1.jpg) + +The parallel coordinate plot offers a holistic perspective on hyperparameter configurations. By presenting multiple hyperparameters simultaneously, this view allows you to observe the interactions between variables, providing insights into their combined influence on the model. + +![hparam_loggings2](doc_images/tensorboard/tblogger_hparam2.jpg) + +The scatter plot matrix view provides an in-depth analysis of pairwise relationships between different hyperparameters. By visualizing correlations and patterns, this view aids in identifying key interactions that may influence the model's performance. + +![hparam_loggings3](doc_images/tensorboard/tblogger_hparam3.jpg) + ## Status To show status information about a neural pipeline search run, use diff --git a/neps/metahyper/__init__.py b/neps/metahyper/__init__.py index 6e2aa0f7..6a76b9d1 100644 --- a/neps/metahyper/__init__.py +++ b/neps/metahyper/__init__.py @@ -1,2 +1,2 @@ -from .api import ConfigResult, Sampler, read, metahyper_run +from .api import ConfigResult, Sampler, metahyper_run, read from .utils import instance_from_map diff --git a/neps/metahyper/api.py b/neps/metahyper/api.py index 1dc9c3bf..8c75c0b0 100644 --- a/neps/metahyper/api.py +++ b/neps/metahyper/api.py @@ -2,6 +2,7 @@ import inspect import logging +import os import shutil import time import warnings @@ -273,6 +274,7 @@ def _check_max_evaluations( continue_until_max_evaluation_completed, ): logger.debug("Checking if max evaluations is reached") + # TODO: maybe not read everything again? previous_results, pending_configs, pending_configs_free = read( optimization_dir, serializer, logger ) @@ -292,6 +294,9 @@ def _sample_config(optimization_dir, sampler, serializer, logger, pre_load_hooks ) base_result_directory = optimization_dir / "results" + logger.debug(f"Previous results: {previous_results}") + logger.debug(f"Pending configs: {pending_configs}") + logger.debug(f"Pending configs: {pending_configs_free}") logger.debug("Sampling a new configuration") @@ -315,6 +320,9 @@ def _sample_config(optimization_dir, sampler, serializer, logger, pre_load_hooks "communication, but most likely some configs crashed during their execution " "or a jobtime-limit was reached." 
) + # write some extra data per configuration if the optimizer has any + # if hasattr(sampler, "evaluation_data"): + # sampler.evaluation_data.write_all(pipeline_directory) if previous_config_id is not None: previous_config_id_file = pipeline_directory / "previous_config.id" @@ -452,6 +460,7 @@ def metahyper_run( ) evaluations_in_this_run = 0 + while True: if max_evaluations_total is not None and _check_max_evaluations( optimization_dir, diff --git a/neps/optimizers/base_optimizer.py b/neps/optimizers/base_optimizer.py index b4502eca..33bd1435 100644 --- a/neps/optimizers/base_optimizer.py +++ b/neps/optimizers/base_optimizer.py @@ -60,7 +60,7 @@ def load_state(self, state: Any): # pylint: disable=no-self-use super().load_state(state) def load_config(self, config_dict): - config = deepcopy(self.pipeline_space) + config = self.pipeline_space.copy() config.load_from(config_dict) return config diff --git a/neps/optimizers/bayesian_optimization/acquisition_functions/__init__.py b/neps/optimizers/bayesian_optimization/acquisition_functions/__init__.py index aedb483f..e78b59ff 100644 --- a/neps/optimizers/bayesian_optimization/acquisition_functions/__init__.py +++ b/neps/optimizers/bayesian_optimization/acquisition_functions/__init__.py @@ -4,9 +4,11 @@ from typing import Callable from .ei import ComprehensiveExpectedImprovement -from .mf_ei import MFEI -from .ucb import UpperConfidenceBound, MF_UCB - +from .mf_ei import MFEI, MFEI_AtMax, MFEI_Dyna, MFEI_Random +from .mf_pi import MFPI, MFPI_AtMax, MFPI_Dyna, MFPI_Random, MFPI_Random_HiT +from .mf_two_step import MF_TwoStep +from .mf_ucb import MF_UCB, MF_UCB_AtMax, MF_UCB_Dyna +from .ucb import UpperConfidenceBound AcquisitionMapping: dict[str, Callable] = { "EI": partial( @@ -40,4 +42,73 @@ MF_UCB, maximize=False, ), + "MFEI-max": partial( + MFEI_AtMax, + in_fill="best", + augmented_ei=False, + ), + "MFEI-dyna": partial( + MFEI_Dyna, + in_fill="best", + augmented_ei=False, + ), + "MFEI-random": partial( + MFPI_Random, + in_fill="best", + augmented_ei=False, + ), + "MF-UCB-max": partial( + MF_UCB_AtMax, + maximize=False, + ), + "MF-UCB-dyna": partial( + MF_UCB_Dyna, + maximize=False, + ), + "MF_TwoStep": partial( + MF_TwoStep, + maximize=False, + ), + "MFPI": partial( + MFPI, + in_fill="best", + augmented_ei=False, + ), + "MFPI-max": partial( + MFPI_AtMax, + in_fill="best", + augmented_ei=False, + ), + "MFPI-random-horizon": partial( + MFPI_Random, + in_fill="best", + augmented_ei=False, + horizon="random", + threshold="0.0", + ), + "MFPI-dyna": partial( + MFPI_Dyna, + in_fill="best", + augmented_ei=False, + ), + "MFPI-random": partial( + MFPI_Random, + in_fill="best", + augmented_ei=False, + ), + +# Further Potential Acquisition Functions + +# "MFPI-thresh-max": partial( +# MFPI_Random, +# in_fill="best", +# augmented_ei=False, +# horizon="max", +# threshold="random", +# ), +# "MFPI-random-hit": partial( +# MFPI_Random_HiT, +# in_fill="best", +# augmented_ei=False, +# ), } diff --git a/neps/optimizers/bayesian_optimization/acquisition_functions/mf_ei.py b/neps/optimizers/bayesian_optimization/acquisition_functions/mf_ei.py index f677f894..cc3d702f 100644 --- a/neps/optimizers/bayesian_optimization/acquisition_functions/mf_ei.py +++ b/neps/optimizers/bayesian_optimization/acquisition_functions/mf_ei.py @@ -7,12 +7,81 @@ from torch.distributions import Normal from ....optimizers.utils import map_real_hyperparameters_from_tabular_ids -from ....search_spaces.search_space import SearchSpace +from ....search_spaces.search_space import 
IntegerParameter, SearchSpace from ...multi_fidelity.utils import MFObservedData +from .base_acquisition import BaseAcquisition from .ei import ComprehensiveExpectedImprovement -class MFEI(ComprehensiveExpectedImprovement): +class MFStepBase(BaseAcquisition): + """A class holding common operations that can be inherited. + + WARNING: Unsafe use of self attributes, can break if not used correctly. + """ + + def set_state( + self, + pipeline_space: SearchSpace, + surrogate_model: Any, + observations: MFObservedData, + b_step: Union[int, float], + **kwargs, + ): + # overload to select incumbent differently through observations + self.pipeline_space = pipeline_space + self.surrogate_model = surrogate_model + self.observations = observations + self.b_step = b_step + return + + def get_budget_level(self, config) -> int: + return int((config.fidelity.value - config.fidelity.lower) / self.b_step) + + def preprocess_gp(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]: + x, inc_list = self.preprocess(x) + return x, inc_list + + def preprocess_deep_gp(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]: + x, inc_list = self.preprocess(x) + x_lcs = [] + for idx in x.index: + if idx in self.observations.df.index.levels[0]: + # TODO: Samir, check if `budget_id=None` is okay? + # budget_level = self.get_budget_level(x[idx]) + # extracting the available/observed learning curve + lc = self.observations.extract_learning_curve(idx, budget_id=None) + else: + # initialize a learning curve with a placeholder + # This is later padded accordingly for the Conv1D layer + lc = [] + x_lcs.append(lc) + self.surrogate_model.set_prediction_learning_curves(x_lcs) + return x, inc_list + + def preprocess_pfn( + self, x: pd.Series + ) -> Tuple[torch.Tensor, pd.Series, torch.Tensor]: + """Prepares the configurations for appropriate EI calculation. + + Takes a set of points and computes the budget and incumbent for each point, as + required by the multi-fidelity Expected Improvement acquisition function. + """ + _x, inc_list = self.preprocess(x.copy()) + _x_tok = self.observations.tokenize(_x, as_tensor=True) + len_partial = len(self.observations.seen_config_ids) + z_min = x[0].fidelity.lower + z_max = x[0].fidelity.upper + # converting fidelity to the discrete budget level + # STRICT ASSUMPTION: fidelity is the second dimension + _x_tok[:len_partial, 1] = ( + _x_tok[:len_partial, 1] + self.b_step - z_min + ) / self.b_step + _x_tok[:, 1] = _x_tok[:, 1] / z_max + return _x, _x_tok, inc_list + + +# NOTE: the order of inheritance is important +class MFEI(MFStepBase, ComprehensiveExpectedImprovement): def __init__( self, pipeline_space: SearchSpace, @@ -20,30 +89,40 @@ def __init__( augmented_ei: bool = False, xi: float = 0.0, in_fill: str = "best", + inc_normalization: bool = False, log_ei: bool = False, ): super().__init__(augmented_ei, xi, in_fill, log_ei) self.pipeline_space = pipeline_space self.surrogate_model_name = surrogate_model_name + self.inc_normalization = inc_normalization self.surrogate_model = None self.observations = None self.b_step = None - def get_budget_level(self, config) -> int: - return int((config.fidelity.value - config.fidelity.lower) / self.b_step) + def preprocess_inc_list(self, **kwargs) -> list: + assert "budget_list" in kwargs, "Requires a list of query step for candidate set." 
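+ # Per-budget incumbent: use the best observed score at each candidate's target budget level, falling back to the best score seen at any budget when that level has no observations yet.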
+ budget_list = kwargs["budget_list"] + performances = self.observations.get_best_performance_for_each_budget() + inc_list = [] + for budget_level in budget_list: + if budget_level in performances.index: + inc = performances[budget_level] + else: + inc = self.observations.get_best_seen_performance() + inc_list.append(inc) + return inc_list - def preprocess(self, x: pd.Series) -> Tuple[Iterable, Iterable]: + def preprocess(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]: """Prepares the configurations for appropriate EI calculation. Takes a set of points and computes the budget and incumbent for each point, as required by the multi-fidelity Expected Improvement acquisition function. """ budget_list = [] - if self.pipeline_space.has_tabular: # preprocess tabular space differently # expected input: IDs pertaining to the tabular data - # expected output: IDs pertaining to current observations and set of HPs x = map_real_hyperparameters_from_tabular_ids(x, self.pipeline_space) indices_to_drop = [] for i, config in x.items(): @@ -67,76 +146,36 @@ def preprocess(self, x: pd.Series) -> Tuple[Iterable, Iterable]: # Drop unused configs x.drop(labels=indices_to_drop, inplace=True) - performances = self.observations.get_best_performance_for_each_budget() - inc_list = [] - for budget_level in budget_list: - if budget_level in performances.index: - inc = performances[budget_level] - else: - inc = self.observations.get_best_seen_performance() - inc_list.append(inc) + # Collecting incumbent list per configuration + inc_list = self.preprocess_inc_list(budget_list=budget_list) return x, torch.Tensor(inc_list) - def preprocess_gp(self, x: Iterable) -> Tuple[Iterable, Iterable]: - x, inc_list = self.preprocess(x) - return x.values.tolist(), inc_list - - def preprocess_deep_gp(self, x: Iterable) -> Tuple[Iterable, Iterable]: - x, inc_list = self.preprocess(x) - x_lcs = [] - for idx in x.index: - if idx in self.observations.df.index.levels[0]: - budget_level = self.get_budget_level(x[idx]) - lc = self.observations.extract_learning_curve(idx, budget_level) - else: - # initialize a learning curve with a place holder - # This is later padded accordingly for the Conv1D layer - lc = [0.0] - x_lcs.append(lc) - self.surrogate_model.set_prediction_learning_curves(x_lcs) - return x.values.tolist(), inc_list - - def preprocess_pfn(self, x: Iterable) -> Tuple[Iterable, Iterable, Iterable]: - """Prepares the configurations for appropriate EI calculation. - - Takes a set of points and computes the budget and incumbent for each point, as - required by the multi-fidelity Expected Improvement acquisition function. 
- """ - _x, inc_list = self.preprocess(x.copy()) - _x_tok = self.observations.tokenize(_x, as_tensor=True) - len_partial = len(self.observations.seen_config_ids) - z_min = x[0].fidelity.lower - # converting fidelity to the discrete budget level - # STRICT ASSUMPTION: fidelity is the first dimension - _x_tok[:len_partial, 0] = ( - _x_tok[:len_partial, 0] + self.b_step - z_min - ) / self.b_step - return _x_tok, _x, inc_list - def eval(self, x: pd.Series, asscalar: bool = False) -> Tuple[np.ndarray, pd.Series]: - # _x = x.copy() # preprocessing needs to change the reference x Series so we don't copy here + # deepcopy + _x = pd.Series([x.loc[idx].copy() for idx in x.index.values], index=x.index) if self.surrogate_model_name == "pfn": - _x_tok, _x, inc_list = self.preprocess_pfn( + _x, _x_tok, inc_list = self.preprocess_pfn( x.copy() ) # IMPORTANT change from vanilla-EI ei = self.eval_pfn_ei(_x_tok, inc_list) - elif self.surrogate_model_name == "deep_gp": - _x, inc_list = self.preprocess_deep_gp( - x.copy() - ) # IMPORTANT change from vanilla-EI - ei = self.eval_gp_ei(_x, inc_list) - _x = pd.Series(_x, index=np.arange(len(_x))) + elif self.surrogate_model_name in ["deep_gp", "dpl"]: + _x, inc_list = self.preprocess_deep_gp(_x) # IMPORTANT change from vanilla-EI + ei = self.eval_gp_ei(_x.values.tolist(), inc_list) + elif self.surrogate_model_name == "gp": + _x, inc_list = self.preprocess_gp(_x) # IMPORTANT change from vanilla-EI + ei = self.eval_gp_ei(_x.values.tolist(), inc_list) else: - _x, inc_list = self.preprocess_gp( - x.copy() - ) # IMPORTANT change from vanilla-EI - ei = self.eval_gp_ei(_x, inc_list) - _x = pd.Series(_x, index=np.arange(len(_x))) + raise ValueError( + f"Unrecognized surrogate model name: {self.surrogate_model_name}" + ) + + if self.inc_normalization: + ei = ei / inc_list if ei.is_cuda: ei = ei.cpu() - if len(x) > 1 and asscalar: + if len(_x) > 1 and asscalar: return ei.detach().numpy(), _x else: return ei.detach().numpy().item(), _x @@ -145,8 +184,6 @@ def eval_pfn_ei( self, x: Iterable, inc_list: Iterable ) -> Union[np.ndarray, torch.Tensor, float]: """PFN-EI modified to preprocess samples and accept list of incumbents.""" - # x, inc_list = self.preprocess(x) # IMPORTANT change from vanilla-EI - # _x = x.copy() ei = self.surrogate_model.get_ei(x.to(self.surrogate_model.device), inc_list) if len(ei.shape) == 2: ei = ei.flatten() @@ -156,7 +193,6 @@ def eval_gp_ei( self, x: Iterable, inc_list: Iterable ) -> Union[np.ndarray, torch.Tensor, float]: """Vanilla-EI modified to preprocess samples and accept list of incumbents.""" - # x, inc_list = self.preprocess(x) # IMPORTANT change from vanilla-EI _x = x.copy() try: mu, cov = self.surrogate_model.predict(_x) @@ -182,13 +218,137 @@ def eval_gp_ei( ucdf = gauss.cdf(u) updf = torch.exp(gauss.log_prob(u)) ei = std * updf + (mu_star - mu - self.xi) * ucdf + # Clip ei if std == 0.0 + # ei = torch.where(torch.isclose(std, torch.tensor(0.0)), 0, ei) if self.augmented_ei: sigma_n = self.surrogate_model.likelihood ei *= 1.0 - torch.sqrt(torch.tensor(sigma_n, device=mu.device)) / torch.sqrt( sigma_n + torch.diag(cov) ) + + # Save data for writing + self.mu_star = mu_star.detach().numpy().tolist() + self.mu = mu.detach().numpy().tolist() + self.std = std.detach().numpy().tolist() return ei + +class MFEI_AtMax(MFEI): + def preprocess_inc_list(self, **kwargs) -> list: + assert "len_x" in kwargs, "Requires the length of the candidate set." 
+ len_x = kwargs["len_x"] + # finds global incumbent + inc_value = min(self.observations.get_best_performance_for_each_budget()) + # uses the best seen value as the incumbent in EI computation for all candidates + inc_list = [inc_value] * len_x + return inc_list + + def preprocess(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]: + """Prepares the configurations for appropriate EI calculation. + + Takes a set of points and computes the budget and incumbent for each point. + Unlike the base class MFEI, sets the target fidelity to be max budget and the + incumbent choice to be the max seen across history for all candidates. + """ + budget_list = [] + if self.pipeline_space.has_tabular: + # preprocess tabular space differently + # expected input: IDs pertaining to the tabular data + x = map_real_hyperparameters_from_tabular_ids(x, self.pipeline_space) + + indices_to_drop = [] + for i, config in x.items(): + target_fidelity = config.fidelity.upper # change from MFEI + + if config.fidelity.value == target_fidelity: + # if the target_fidelity already reached, drop the configuration + indices_to_drop.append(i) + else: + config.fidelity.value = target_fidelity + budget_list.append(self.get_budget_level(config)) + + # drop unused configs + x.drop(labels=indices_to_drop, inplace=True) + + # create the same incumbent for all candidates + inc_list = self.preprocess_inc_list(len_x=len(x.index.values)) + + return x, torch.Tensor(inc_list) + + +class MFEI_Dyna(MFEI_AtMax): + """ + Computes extrapolation length of curves to maximum fidelity seen. + Uses the global incumbent as the best score in EI computation. + """ + + def preprocess(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]: + """Prepares the configurations for appropriate EI calculation. + + Takes a set of points and computes the budget and incumbent for each point. + Unlike the base class MFEI, sets the target fidelity to be max budget and the + incumbent choice to be the max seen across history for all candidates. 
+ """ + if self.pipeline_space.has_tabular: + # preprocess tabular space differently + # expected input: IDs pertaining to the tabular data + x = map_real_hyperparameters_from_tabular_ids(x, self.pipeline_space) + + # find the maximum observed steps per config to obtain the current pseudo_z_max + max_z_level_per_x = self.observations.get_max_observed_fidelity_level_per_config() + pseudo_z_level_max = max_z_level_per_x.max() # highest seen fidelity step so far + # find the fidelity step at which the best seen performance was recorded + z_inc_level = self.observations.get_budget_level_for_best_performance() + # retrieving actual fidelity values from budget level + ## marker 1: the fidelity value at which the best seen performance was recorded + z_inc = self.b_step * z_inc_level + self.pipeline_space.fidelity.lower + ## marker 2: the maximum fidelity value recorded in observation history + pseudo_z_max = ( + self.b_step * pseudo_z_level_max + self.pipeline_space.fidelity.lower + ) + + # TODO: compare with this first draft logic + # def update_fidelity(config): + # ### DO NOT DELETE THIS FUNCTION YET + # # for all configs, set the min(max(current fidelity + step, z_inc), pseudo_z_max) + # ## that is, choose the next highest marker from 1 and 2 + # z_extrapolate = min( + # max(config.fidelity.value + self.b_step, z_inc), + # pseudo_z_max + # ) + # config.fidelity.value = z_extrapolate + # return config + + def update_fidelity(config): + # for all configs, set to pseudo_z_max + ## that is, choose the highest seen fidelity in observation history + z_extrapolate = pseudo_z_max + config.fidelity.value = z_extrapolate + return config + + # collect IDs for partial configurations + _partial_config_ids = x.index <= max(self.observations.seen_config_ids) + # filter for configurations that reached max budget + indices_to_drop = [ + _idx + for _idx, _x in x.loc[_partial_config_ids].items() + if _x.fidelity.value == self.pipeline_space.fidelity.upper + ] + # drop unused configs + x.drop(labels=indices_to_drop, inplace=True) + + # set fidelity for all partial configs + x = x.apply(update_fidelity) + + # create the same incumbent for all candidates + inc_list = self.preprocess_inc_list(len_x=len(x.index.values)) + + return x, torch.Tensor(inc_list) + + +class MFEI_Random(MFEI): + BUDGET = 1000 + def set_state( self, pipeline_space: SearchSpace, @@ -197,9 +357,73 @@ def set_state( b_step: Union[int, float], **kwargs, ): - # overload to select incumbent differently through observations - self.pipeline_space = pipeline_space - self.surrogate_model = surrogate_model - self.observations = observations - self.b_step = b_step - return + # set RNG + self.rng = np.random.RandomState(seed=42) + for i in range(len(observations.completed_runs)): + self.rng.uniform(-4, -1) + self.rng.randint(1, 51) + + return super().set_state(pipeline_space, surrogate_model, observations, b_step) + + def sample_horizon(self, steps_passed): + shortest = self.pipeline_space.fidelity.lower + longest = min(self.pipeline_space.fidelity.upper, self.BUDGET - steps_passed) + return self.rng.randint(shortest, longest + 1) + + def sample_threshold(self, f_inc): + lu = 10 ** self.rng.uniform(-4, -1) # % of gap closed + return f_inc * (1 - lu) + + def preprocess(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]: + """Prepares the configurations for appropriate EI calculation. + + Takes a set of points and computes the budget and incumbent for each point, as + required by the multi-fidelity Expected Improvement acquisition function. 
+ """ + if self.pipeline_space.has_tabular: + # preprocess tabular space differently + # expected input: IDs pertaining to the tabular data + x = map_real_hyperparameters_from_tabular_ids(x, self.pipeline_space) + + indices_to_drop = [] + inc_list = [] + + steps_passed = len(self.observations.completed_runs) + print(f"Steps acquired: {steps_passed}") + + # Like EI-AtMax, use the global incumbent as a basis for the EI threshold + inc_value = min(self.observations.get_best_performance_for_each_budget()) + # Extension: Add a random min improvement threshold to encourage high risk high gain + inc_value = self.sample_threshold(inc_value) + print(f"Threshold for EI: {inc_value}") + + # Like MFEI: set fidelities to query using horizon as self.b_step + # Extension: Unlike DyHPO, we sample the horizon randomly over the full range + horizon = self.sample_horizon(steps_passed) + print(f"Horizon for EI: {horizon}") + for i, config in x.items(): + if i <= max(self.observations.seen_config_ids): + current_fidelity = config.fidelity.value + if np.equal(config.fidelity.value, config.fidelity.upper): + # this training run has ended, drop it from future selection + indices_to_drop.append(i) + else: + # a candidate partial training run to continue + target_fidelity = config.fidelity.value + horizon + config.fidelity.value = min( + config.fidelity.value + horizon, config.fidelity.upper + ) # if horizon exceeds max, query at max + inc_list.append(inc_value) + else: + # a candidate new training run that we would need to start + current_fidelity = 0 + config.fidelity.value = horizon + inc_list.append(inc_value) + # print(f"- {x.index.values[i]}: {current_fidelity} --> {config.fidelity.value}") + + # Drop unused configs + x.drop(labels=indices_to_drop, inplace=True) + + assert len(inc_list) == len(x) + + return x, torch.Tensor(inc_list) diff --git a/neps/optimizers/bayesian_optimization/acquisition_functions/mf_pi.py b/neps/optimizers/bayesian_optimization/acquisition_functions/mf_pi.py new file mode 100644 index 00000000..c8c4ca5c --- /dev/null +++ b/neps/optimizers/bayesian_optimization/acquisition_functions/mf_pi.py @@ -0,0 +1,444 @@ +# type: ignore +from copy import deepcopy +from pathlib import Path +from typing import Any, Iterable, Tuple, Union + +import numpy as np +import pandas as pd +import torch +from torch.distributions import Normal + +from ....optimizers.utils import map_real_hyperparameters_from_tabular_ids +from ....search_spaces.search_space import IntegerParameter, SearchSpace +from ....utils.common import SimpleCSVWriter +from ...multi_fidelity.utils import MFObservedData +from .base_acquisition import BaseAcquisition +from .ei import ComprehensiveExpectedImprovement +from .mf_ei import MFStepBase + + +# NOTE: the order of inheritance is important +class MFPI(MFStepBase, ComprehensiveExpectedImprovement): + def __init__( + self, + pipeline_space: SearchSpace, + surrogate_model_name: str = None, + augmented_ei: bool = False, + xi: float = 0.0, + in_fill: str = "best", + log_ei: bool = False, + ): + super().__init__(augmented_ei, xi, in_fill, log_ei) + self.pipeline_space = pipeline_space + self.surrogate_model_name = surrogate_model_name + self.surrogate_model = None + self.observations = None + self.b_step = None + + def preprocess_inc_list(self, **kwargs) -> list: + assert "budget_list" in kwargs, "Requires a list of query step for candidate set." 
+ budget_list = kwargs["budget_list"] + performances = self.observations.get_best_performance_for_each_budget() + inc_list = [] + for budget_level in budget_list: + if budget_level in performances.index: + inc = performances[budget_level] + else: + inc = self.observations.get_best_seen_performance() + inc_list.append(inc) + return inc_list + + def preprocess(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]: + """Prepares the configurations for appropriate EI calculation. + + Takes a set of points and computes the budget and incumbent for each point, as + required by the multi-fidelity Expected Improvement acquisition function. + """ + budget_list = [] + if self.pipeline_space.has_tabular: + # preprocess tabular space differently + # expected input: IDs pertaining to the tabular data + x = map_real_hyperparameters_from_tabular_ids(x, self.pipeline_space) + indices_to_drop = [] + for i, config in x.items(): + target_fidelity = config.fidelity.lower + if i <= max(self.observations.seen_config_ids): + # IMPORTANT to set the fidelity at which EI will be calculated only for + # the partial configs that have been observed already + target_fidelity = config.fidelity.value + self.b_step + + if np.less_equal(target_fidelity, config.fidelity.upper): + # only consider the configs with fidelity lower than the max fidelity + config.fidelity.value = target_fidelity + budget_list.append(self.get_budget_level(config)) + else: + # if the target_fidelity higher than the max drop the configuration + indices_to_drop.append(i) + else: + config.fidelity.value = target_fidelity + budget_list.append(self.get_budget_level(config)) + + # Drop unused configs + x.drop(labels=indices_to_drop, inplace=True) + + # Collecting incumbent list per configuration + inc_list = self.preprocess_inc_list(budget_list=budget_list) + + return x, torch.Tensor(inc_list) + + def eval(self, x: pd.Series, asscalar: bool = False) -> Tuple[np.ndarray, pd.Series]: + # deepcopy + _x = pd.Series([x.loc[idx].copy() for idx in x.index.values], index=x.index) + if self.surrogate_model_name == "pfn": + _x, _x_tok, inc_list = self.preprocess_pfn( + x.copy() + ) # IMPORTANT change from vanilla-EI + pi = self.eval_pfn_pi(_x_tok, inc_list) + elif self.surrogate_model_name in ["deep_gp", "dpl"]: + _x, inc_list = self.preprocess_deep_gp(_x) # IMPORTANT change from vanilla-EI + pi = self.eval_gp_pi(_x.values.tolist(), inc_list) + elif self.surrogate_model_name == "gp": + _x, inc_list = self.preprocess_gp(_x) # IMPORTANT change from vanilla-EI + pi = self.eval_gp_pi(_x.values.tolist(), inc_list) + else: + raise ValueError( + f"Unrecognized surrogate model name: {self.surrogate_model_name}" + ) + + if pi.is_cuda: + pi = ei.cpu() + if len(_x) > 1 and asscalar: + return pi.detach().numpy(), _x + else: + return pi.detach().numpy().item(), _x + + def eval_pfn_pi( + self, x: Iterable, inc_list: Iterable + ) -> Union[np.ndarray, torch.Tensor, float]: + """PFN-PI modified to preprocess samples and accept list of incumbents.""" + pi = self.surrogate_model.get_pi(x.to(self.surrogate_model.device), inc_list) + if len(pi.shape) == 2: + pi = pi.flatten() + print(f"Maximum PI: {pi.max()}") + return pi + + def eval_gp_pi( + self, x: Iterable, inc_list: Iterable + ) -> Union[np.ndarray, torch.Tensor, float]: + _x = x.copy() + try: + mu, cov = self.surrogate_model.predict(_x) + except ValueError as e: + raise e + std = torch.sqrt(torch.diag(cov)) + mu_star = inc_list.to(mu.device) + + gauss = Normal(torch.zeros(1, device=mu.device), torch.ones(1, device=mu.device)) + 
pi = gauss.cdf((mu - mu_star) / (std + 1e-9)) + return pi + + +class MFPI_AtMax(MFPI): + def preprocess_inc_list(self, **kwargs) -> list: + assert "len_x" in kwargs, "Requires the length of the candidate set." + len_x = kwargs["len_x"] + # finds global incumbent + inc_value = min(self.observations.get_best_performance_for_each_budget()) + # uses the best seen value as the incumbent in EI computation for all candidates + inc_list = [inc_value] * len_x + return inc_list + + def preprocess(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]: + """Prepares the configurations for appropriate EI calculation. + + Takes a set of points and computes the budget and incumbent for each point. + Unlike the base class MFPI, sets the target fidelity to be max budget and the + incumbent choice to be the max seen across history for all candidates. + """ + budget_list = [] + if self.pipeline_space.has_tabular: + # preprocess tabular space differently + # expected input: IDs pertaining to the tabular data + x = map_real_hyperparameters_from_tabular_ids(x, self.pipeline_space) + + indices_to_drop = [] + for i, config in x.items(): + target_fidelity = config.fidelity.upper # change from MFEI + + if config.fidelity.value == target_fidelity: + # if the target_fidelity already reached, drop the configuration + indices_to_drop.append(i) + else: + config.fidelity.value = target_fidelity + budget_list.append(self.get_budget_level(config)) + + # drop unused configs + x.drop(labels=indices_to_drop, inplace=True) + + # create the same incumbent for all candidates + inc_list = self.preprocess_inc_list(len_x=len(x.index.values)) + + return x, torch.Tensor(inc_list) + + +class MFPI_Dyna(MFPI_AtMax): + """ + Computes extrapolation length of curves to maximum fidelity seen. + Uses the global incumbent as the best score in EI computation. + """ + + def preprocess(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]: + """Prepares the configurations for appropriate EI calculation. + + Takes a set of points and computes the budget and incumbent for each point. + Unlike the base class MFEI, sets the target fidelity to be max budget and the + incumbent choice to be the max seen across history for all candidates. 
+ """ + if self.pipeline_space.has_tabular: + # preprocess tabular space differently + # expected input: IDs pertaining to the tabular data + x = map_real_hyperparameters_from_tabular_ids(x, self.pipeline_space) + + # find the maximum observed steps per config to obtain the current pseudo_z_max + max_z_level_per_x = self.observations.get_max_observed_fidelity_level_per_config() + pseudo_z_level_max = max_z_level_per_x.max() # highest seen fidelity step so far + # find the fidelity step at which the best seen performance was recorded + z_inc_level = self.observations.get_budget_level_for_best_performance() + # retrieving actual fidelity values from budget level + ## marker 1: the fidelity value at which the best seen performance was recorded + z_inc = self.b_step * z_inc_level + self.pipeline_space.fidelity.lower + ## marker 2: the maximum fidelity value recorded in observation history + pseudo_z_max = ( + self.b_step * pseudo_z_level_max + self.pipeline_space.fidelity.lower + ) + + # TODO: compare with this first draft logic + # def update_fidelity(config): + # ### DO NOT DELETE THIS FUNCTION YET + # # for all configs, set the min(max(current fidelity + step, z_inc), pseudo_z_max) + # ## that is, choose the next highest marker from 1 and 2 + # z_extrapolate = min( + # max(config.fidelity.value + self.b_step, z_inc), + # pseudo_z_max + # ) + # config.fidelity.value = z_extrapolate + # return config + + def update_fidelity(config): + # for all configs, set to pseudo_z_max + ## that is, choose the highest seen fidelity in observation history + z_extrapolate = pseudo_z_max + config.fidelity.value = z_extrapolate + return config + + # collect IDs for partial configurations + _partial_config_ids = x.index <= max(self.observations.seen_config_ids) + # filter for configurations that reached max budget + indices_to_drop = [ + _idx + for _idx, _x in x.loc[_partial_config_ids].items() + if _x.fidelity.value == self.pipeline_space.fidelity.upper + ] + # drop unused configs + x.drop(labels=indices_to_drop, inplace=True) + + # set fidelity for all partial configs + x = x.apply(update_fidelity) + + # create the same incumbent for all candidates + inc_list = self.preprocess_inc_list(len_x=len(x.index.values)) + + return x, torch.Tensor(inc_list) + + +class MFPI_Random(MFPI): + BUDGET = 1000 + + def __init__( + self, + pipeline_space: SearchSpace, + horizon: str = "random", + threshold: str = "random", + surrogate_model_name: str = None, + augmented_ei: bool = False, + xi: float = 0.0, + in_fill: str = "best", + log_ei: bool = False, + ): + super().__init__( + pipeline_space, surrogate_model_name, augmented_ei, xi, in_fill, log_ei + ) + self.horizon = horizon + self.threshold = threshold + + def set_state( + self, + pipeline_space: SearchSpace, + surrogate_model: Any, + observations: MFObservedData, + b_step: Union[int, float], + **kwargs, + ): + # set RNG + self.rng = np.random.RandomState(seed=42) + for i in range(len(observations.completed_runs)): + self.rng.uniform(-4, -1) + self.rng.randint(1, 51) + + return super().set_state(pipeline_space, surrogate_model, observations, b_step) + + def sample_horizon(self, steps_passed): + if self.horizon == "random": + shortest = self.pipeline_space.fidelity.lower + longest = min(self.pipeline_space.fidelity.upper, self.BUDGET - steps_passed) + return self.rng.randint(shortest, longest + 1) + elif self.horizon == "max": + return min(self.pipeline_space.fidelity.upper, self.BUDGET - steps_passed) + else: + return int(self.horizon) + + def sample_threshold(self, 
f_inc): + if self.threshold == "random": + lu = 10 ** self.rng.uniform(-4, -1) # % of gap closed + else: + lu = float(self.threshold) + return f_inc * (1 - lu) + + def preprocess(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]: + """Prepares the configurations for appropriate EI calculation. + + Takes a set of points and computes the budget and incumbent for each point, as + required by the multi-fidelity Expected Improvement acquisition function. + """ + if self.pipeline_space.has_tabular: + # preprocess tabular space differently + # expected input: IDs pertaining to the tabular data + x = map_real_hyperparameters_from_tabular_ids(x, self.pipeline_space) + + indices_to_drop = [] + inc_list = [] + + steps_passed = len(self.observations.completed_runs) + print(f"Steps acquired: {steps_passed}") + + # Like EI-AtMax, use the global incumbent as a basis for the EI threshold + inc_value = min(self.observations.get_best_performance_for_each_budget()) + # Extension: Add a random min improvement threshold to encourage high risk high gain + t_value = self.sample_threshold(inc_value) + print(f"Threshold for PI: {inc_value - t_value}") + inc_value = t_value + + # Like MFEI: set fidelities to query using horizon as self.b_step + # Extension: Unlike DyHPO, we sample the horizon randomly over the full range + horizon = self.sample_horizon(steps_passed) + print(f"Horizon for PI: {horizon}") + for i, config in x.items(): + if i <= max(self.observations.seen_config_ids): + current_fidelity = config.fidelity.value + if np.equal(config.fidelity.value, config.fidelity.upper): + # this training run has ended, drop it from future selection + indices_to_drop.append(i) + else: + # a candidate partial training run to continue + target_fidelity = config.fidelity.value + horizon + config.fidelity.value = min( + config.fidelity.value + horizon, config.fidelity.upper + ) # if horizon exceeds max, query at max + inc_list.append(inc_value) + else: + # a candidate new training run that we would need to start + current_fidelity = 0 + config.fidelity.value = horizon + inc_list.append(inc_value) + # print(f"- {x.index.values[i]}: {current_fidelity} --> {config.fidelity.value}") + + # Drop unused configs + x.drop(labels=indices_to_drop, inplace=True) + + assert len(inc_list) == len(x) + + return x, torch.Tensor(inc_list) + + +class MFPI_Random_HiT(MFPI): + BUDGET = 1000 + + def set_state( + self, + pipeline_space: SearchSpace, + surrogate_model: Any, + observations: MFObservedData, + b_step: Union[int, float], + **kwargs, + ): + # set RNG + self.rng = np.random.RandomState(seed=42) + for i in range(len(observations.completed_runs)): + self.rng.uniform(-4, 0) + self.rng.randint(1, 51) + + return super().set_state(pipeline_space, surrogate_model, observations, b_step) + + def sample_horizon(self, steps_passed): + shortest = self.pipeline_space.fidelity.lower + longest = min(self.pipeline_space.fidelity.upper, self.BUDGET - steps_passed) + return self.rng.randint(shortest, longest + 1) + + def sample_threshold(self, f_inc): + lu = 10 ** self.rng.uniform(-4, 0) # % of gap closed + return f_inc * (1 - lu) + + def preprocess(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]: + """Prepares the configurations for appropriate EI calculation. + + Takes a set of points and computes the budget and incumbent for each point, as + required by the multi-fidelity Expected Improvement acquisition function. 
+ """ + if self.pipeline_space.has_tabular: + # preprocess tabular space differently + # expected input: IDs pertaining to the tabular data + x = map_real_hyperparameters_from_tabular_ids(x, self.pipeline_space) + + indices_to_drop = [] + inc_list = [] + + steps_passed = len(self.observations.completed_runs) + print(f"Steps acquired: {steps_passed}") + + # Like EI-AtMax, use the global incumbent as a basis for the EI threshold + inc_value = min(self.observations.get_best_performance_for_each_budget()) + # Extension: Add a random min improvement threshold to encourage high risk high gain + t_value = self.sample_threshold(inc_value) + print(f"Threshold for EI: {inc_value - t_value}") + inc_value = t_value + + # Like MFEI: set fidelities to query using horizon as self.b_step + # Extension: Unlike DyHPO, we sample the horizon randomly over the full range + horizon = self.sample_horizon(steps_passed) + print(f"Horizon for EI: {horizon}") + for i, config in x.items(): + if i <= max(self.observations.seen_config_ids): + current_fidelity = config.fidelity.value + if np.equal(config.fidelity.value, config.fidelity.upper): + # this training run has ended, drop it from future selection + indices_to_drop.append(i) + else: + # a candidate partial training run to continue + target_fidelity = config.fidelity.value + horizon + config.fidelity.value = min( + config.fidelity.value + horizon, config.fidelity.upper + ) # if horizon exceeds max, query at max + inc_list.append(inc_value) + else: + # a candidate new training run that we would need to start + current_fidelity = 0 + config.fidelity.value = horizon + inc_list.append(inc_value) + # print(f"- {x.index.values[i]}: {current_fidelity} --> {config.fidelity.value}") + + # Drop unused configs + x.drop(labels=indices_to_drop, inplace=True) + + assert len(inc_list) == len(x) + + return x, torch.Tensor(inc_list) diff --git a/neps/optimizers/bayesian_optimization/acquisition_functions/mf_two_step.py b/neps/optimizers/bayesian_optimization/acquisition_functions/mf_two_step.py new file mode 100644 index 00000000..1d1175a2 --- /dev/null +++ b/neps/optimizers/bayesian_optimization/acquisition_functions/mf_two_step.py @@ -0,0 +1,241 @@ +from typing import Any, Tuple, Union + +import numpy as np +import pandas as pd + +from ....search_spaces.search_space import SearchSpace +from ...multi_fidelity.utils import MFObservedData +from .base_acquisition import BaseAcquisition +from .mf_ei import MFEI, MFEI_Dyna +from .mf_ucb import MF_UCB_Dyna + + +class MF_TwoStep(BaseAcquisition): + """2-step acquisition: employs 3 different acquisition calls.""" + + # HYPER-PARAMETERS: Going with the Freeze-Thaw BO (Swersky et al. 2014) values + N_PARTIAL = 10 + N_NEW = 3 + + def __init__( + self, + pipeline_space: SearchSpace, + surrogate_model_name: str = None, + beta: float = 1.0, + maximize: bool = False, + augmented_ei: bool = False, + xi: float = 0.0, + in_fill: str = "best", + log_ei: bool = False, + ): + """Upper Confidence Bound (UCB) acquisition function. + + Args: + beta: Controls the balance between exploration and exploitation. + maximize: If True, maximize the given model, else minimize. + DEFAULT=False, assumes minimzation. 
+ """ + super().__init__() + # Acquisition 1: For trimming down partial candidate set + self.acq_partial_filter = MFEI_Dyna_PartialFilter( # defined below + pipeline_space=pipeline_space, + surrogate_model_name=surrogate_model_name, + augmented_ei=augmented_ei, + xi=xi, + in_fill=in_fill, + log_ei=log_ei, + ) + # Acquisition 2: For trimming down new candidate set + self.acq_new_filter = MFEI( + pipeline_space=pipeline_space, + surrogate_model_name=surrogate_model_name, + augmented_ei=augmented_ei, + xi=xi, + in_fill=in_fill, + log_ei=log_ei, + ) + # Acquisition 3: For final selection of winners from Acquisitions 1 & 2 + self.acq_combined = MF_UCB_Dyna( + pipeline_space=pipeline_space, + surrogate_model_name=surrogate_model_name, + beta=beta, + maximize=maximize, + ) + self.pipeline_space = pipeline_space + self.surrogate_model_name = surrogate_model_name + self.surrogate_model = None + self.observations = None + self.b_step = None + + def set_state( + self, + pipeline_space: SearchSpace, + surrogate_model: Any, + observations: MFObservedData, + b_step: Union[int, float], + **kwargs, + ): + # overload to select incumbent differently through observations + self.pipeline_space = pipeline_space + self.surrogate_model = surrogate_model + self.observations = observations + self.b_step = b_step + self.acq_partial_filter.set_state( + self.pipeline_space, self.surrogate_model, self.observations, self.b_step + ) + self.acq_new_filter.set_state( + self.pipeline_space, self.surrogate_model, self.observations, self.b_step + ) + self.acq_combined.set_state( + self.pipeline_space, self.surrogate_model, self.observations, self.b_step + ) + + def eval(self, x: pd.Series, asscalar: bool = False) -> Tuple[np.ndarray, pd.Series]: + # Filter self.N_NEW among the new configuration IDs + # Filter self.N_PARTIAL among the partial configuration IDs + max_seen_id = max(self.observations.seen_config_ids) + total_seen_id = len(self.observations.seen_config_ids) + new_ids = x.index[x.index > max_seen_id].values + partial_ids = x.index[x.index <= max_seen_id].values + + # for new candidate set + acq, _samples = self.acq_new_filter.eval(x, asscalar=True) + acq = pd.Series(acq, index=_samples.index) + # drop partial configurations + acq.loc[_samples.index.values <= max_seen_id] = 0 + # NOTE: setting to 0 works as EI-based AF returns > 0 + # find configs not in top-N_NEW set as per acquisition value, to be dropped + not_top_new_idx = acq.sort_values().index[: -self.N_NEW] # len(acq) - N_NEW + # drop these configurations + acq.loc[ + not_top_new_idx + ] = 0 # to ignore in the argmax of the acquisition function + # NOTE: setting to 0 works as EI-based AF returns > 0 + # result of first round of filtering of new candidates + + acq_new_mask = pd.Series( + {idx: val for idx, val in _samples.items() if acq.loc[idx] > 0} + ) + # for partial candidate set + acq, _samples = self.acq_partial_filter.eval(x, asscalar=True) + acq = pd.Series(acq, index=_samples.index) + # weigh the acq value based on max seen for each config + acq = self._weigh_partial_acq_scores(acq=acq) + # drop new configurations + acq.loc[ + _samples.index.values > max_seen_id + ] = 0 # to ignore in the argmax of the acquisition function + # find configs not in top-N_NEW set as per acquisition value + _top_n_partial = min(self.N_PARTIAL, total_seen_id) + not_top_new_idx = acq.sort_values().index[ + :-_top_n_partial + ] # acq.argsort()[::-1][_top_n_partial:] # sorts in ascending-flips-leaves out top-N_PARTIAL + # drop these configurations + acq.loc[ + 
not_top_new_idx + ] = 0 # to ignore in the argmax of the acquisition function + # NOTE: setting to 0 works as EI-based AF returns > 0 + # result of first round of filtering of partial candidates + acq_partial_mask = pd.Series( + {idx: val for idx, val in _samples.items() if acq.loc[idx] > 0} + ) + + eligible_set = set( + np.concatenate( + [ + acq_partial_mask.index.values.tolist(), + acq_new_mask.index.values.tolist(), + ] + ) + ) + + # for combined selection + acq, _samples = self.acq_combined.eval(x, asscalar=True) + acq = pd.Series(acq, index=_samples.index) + # applying mask from step-1 to make final selection among (N_NEW + N_PARTIAL) + mask = acq.index.isin(eligible_set) + # NOTE: setting to -np.inf works as MF-UCB here is max.(-LCB) instead of min.(LCB) + acq[~mask] = -np.inf # will be ignored in the argmax of the acquisition function + acq_combined = pd.Series( + { + idx: acq.loc[idx] + for idx, val in _samples.items() + if acq.loc[idx] != -np.inf + } + ) + # NOTE: setting to -np.inf works as MF-UCB here is max.(-LCB) instead of min.(LCB) + acq_combined = acq_combined.reindex(acq.index, fill_value=-np.inf) + acq = acq_combined.values + + return acq, _samples + + def _weigh_partial_acq_scores(self, acq: pd.Series) -> pd.Series: + # find the best performance per configuration seen + inc_list_partial = self.observations.get_best_performance_per_config() + + # removing any config indicey that have not made it till here + _idx_drop = [_i for _i in inc_list_partial.index if _i not in acq.index] + inc_list_partial.drop(labels=_idx_drop, inplace=True) + + # normalize the scores based on relative best seen performance per config + _inc, _max = inc_list_partial.min(), inc_list_partial.max() + inc_list_partial = ( + (inc_list_partial - _inc) / (_max - _inc) if _inc < _max else inc_list_partial + ) + + # calculate weights per candidate + weights = pd.Series(1 - inc_list_partial, index=inc_list_partial.index) + + # scaling the acquisition score with weights + acq = acq * weights + + return acq + + +class MFEI_PartialFilter(MFEI): + """Custom redefinition of MF-EI with Dynamic extrapolation length to adjust incumbents.""" + + def preprocess_inc_list(self, **kwargs) -> list: + # the assertion exists to forcibly check the call to the super().preprocess() + # this function overload should only affect the operation inside it + assert "budget_list" in kwargs, "Requires the length of the candidate set." + # we still need this as placeholder for the new candidate set + # in this class we only work on the partial candidate set + inc_list = super().preprocess_inc_list(budget_list=kwargs["budget_list"]) + + n_partial = len(self.observations.seen_config_ids) + + # NOTE: Here we set the incumbent for EI calculation for each config to the + # maximum it has seen, in a bid to get an expected improvement over its previous + # observed score. This could act as a filter to diverging configurations even if + # their overall score relative to the incumbent can be high. 
+ inc_list_partial = self.observations.get_best_performance_per_config() + # updating incumbent for EI computation for the partial configs + inc_list[:n_partial] = inc_list_partial + + return inc_list + + +class MFEI_Dyna_PartialFilter(MFEI_Dyna): + """Custom redefinition of MF-EI with Dynamic extrapolation length to adjust incumbents.""" + + def preprocess_inc_list(self, **kwargs) -> list: + # the assertion exists to forcibly check the call to the super().preprocess() + # this function overload should only affect the operation inside it + assert "len_x" in kwargs, "Requires the length of the candidate set." + # we still need this as placeholder for the new candidate set + # in this class we only work on the partial candidate set + inc_list = super().preprocess_inc_list(len_x=kwargs["len_x"]) + + n_partial = len(self.observations.seen_config_ids) + + # NOTE: Here we set the incumbent for EI calculation for each config to the + # maximum it has seen, in a bid to get an expected improvement over its previous + # observed score. This could act as a filter to diverging configurations even if + # their overall score relative to the incumbent can be high. + inc_list_partial = self.observations.get_best_performance_per_config() + + # updating incumbent for EI computation for the partial configs + inc_list[:n_partial] = inc_list_partial + + return inc_list diff --git a/neps/optimizers/bayesian_optimization/acquisition_functions/mf_ucb.py b/neps/optimizers/bayesian_optimization/acquisition_functions/mf_ucb.py new file mode 100644 index 00000000..0fffada7 --- /dev/null +++ b/neps/optimizers/bayesian_optimization/acquisition_functions/mf_ucb.py @@ -0,0 +1,224 @@ +from typing import Any, Iterable, Tuple, Union + +import numpy as np +import pandas as pd +import torch + +from ....optimizers.utils import map_real_hyperparameters_from_tabular_ids +from ....search_spaces.search_space import IntegerParameter, SearchSpace +from ...multi_fidelity.utils import MFObservedData +from .mf_ei import MFStepBase +from .ucb import UpperConfidenceBound + + +# NOTE: the order of inheritance is important +class MF_UCB(MFStepBase, UpperConfidenceBound): + def __init__( + self, + pipeline_space: SearchSpace, + surrogate_model_name: str = None, + beta: float = 1.0, + maximize: bool = False, + ): + """Upper Confidence Bound (UCB) acquisition function. + + Args: + beta: Controls the balance between exploration and exploitation. + maximize: If True, maximize the given model, else minimize. + DEFAULT=False, assumes minimzation. + """ + super().__init__(beta, maximize) + self.pipeline_space = pipeline_space + self.surrogate_model_name = surrogate_model_name + self.surrogate_model = None + self.observations = None + self.b_step = None + + def preprocess(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]: + """Prepares the configurations for appropriate EI calculation. + + Takes a set of points and computes the budget and incumbent for each point, as + required by the multi-fidelity Expected Improvement acquisition function. 
+ """ + budget_list = [] + if self.pipeline_space.has_tabular: + # preprocess tabular space differently + # expected input: IDs pertaining to the tabular data + x = map_real_hyperparameters_from_tabular_ids(x, self.pipeline_space) + + indices_to_drop = [] + betas = [] + for i, config in x.items(): + target_fidelity = config.fidelity.lower + if i <= max(self.observations.seen_config_ids): + # IMPORTANT to set the fidelity at which EI will be calculated only for + # the partial configs that have been observed already + target_fidelity = config.fidelity.value + self.b_step + if np.less_equal(target_fidelity, config.fidelity.upper): + # only consider the configs with fidelity lower than the max fidelity + config.fidelity.value = target_fidelity + budget_list.append(self.get_budget_level(config)) + # CAN ADAPT BETA PER-SAMPLE HERE + betas.append(self.beta) + else: + # if the target_fidelity higher than the max drop the configuration + indices_to_drop.append(i) + else: + config.fidelity.value = target_fidelity + budget_list.append(self.get_budget_level(config)) + # CAN ADAPT BETA PER-SAMPLE HERE + betas.append(self.beta) + + # Drop unused configs + x.drop(labels=indices_to_drop, inplace=True) + + return x, torch.Tensor(betas) + + def preprocess_gp( + self, x: pd.Series, surrogate_name: str = "gp" + ) -> Tuple[pd.Series, torch.Tensor]: + if surrogate_name == "gp": + x, inc_list = self.preprocess(x) + return x, inc_list + elif surrogate_name in ["deep_gp", "dpl"]: + x, inc_list = self.preprocess(x) + x_lcs = [] + for idx in x.index: + if idx in self.observations.df.index.levels[0]: + # extracting the available/observed learning curve + lc = self.observations.extract_learning_curve(idx, budget_id=None) + else: + # initialize a learning curve with a placeholder + # This is later padded accordingly for the Conv1D layer + lc = [] + x_lcs.append(lc) + self.surrogate_model.set_prediction_learning_curves(x_lcs) + return x, inc_list + else: + raise ValueError(f"Unrecognized surrogate model name: {surrogate_name}") + + def eval_pfn_ucb( + self, x: Iterable, beta: float = (1 - 0.682) / 2 + ) -> Union[np.ndarray, torch.Tensor, float]: + """PFN-UCB modified to preprocess samples and accept list of incumbents.""" + ucb = self.surrogate_model.get_ucb( + x_test=x.to(self.surrogate_model.device), + beta=beta, # TODO: extend to have different betas for each candidates in x + ) + if len(ucb.shape) == 2: + ucb = ucb.flatten() + return ucb + + def eval(self, x: pd.Series, asscalar: bool = False) -> Tuple[np.ndarray, pd.Series]: + if self.surrogate_model_name == "pfn": + _x, _x_tok, _ = self.preprocess_pfn(x.copy()) + ucb = self.eval_pfn_ucb(_x_tok) + elif self.surrogate_model_name in ["deep_gp", "gp", "dpl"]: + _x, betas = self.preprocess_gp(x.copy(), self.surrogate_model_name) + ucb = super().eval(_x.values.tolist(), betas, asscalar) + else: + raise ValueError( + f"Unrecognized surrogate model name: {self.surrogate_model_name}" + ) + + return ucb, _x + + +class MF_UCB_AtMax(MF_UCB): + def preprocess(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]: + """Prepares the configurations for appropriate EI calculation. + + Takes a set of points and computes the budget and incumbent for each point. + Unlike the base class MFEI, sets the target fidelity to be max budget and the + incumbent choice to be the max seen across history for all candidates. 
+ """ + budget_list = [] + if self.pipeline_space.has_tabular: + # preprocess tabular space differently + # expected input: IDs pertaining to the tabular data + x = map_real_hyperparameters_from_tabular_ids(x, self.pipeline_space) + + indices_to_drop = [] + betas = [] + for i, config in x.items(): + target_fidelity = config.fidelity.upper # change from MFEI + + if config.fidelity.value == target_fidelity: + # if the target_fidelity already reached, drop the configuration + indices_to_drop.append(i) + else: + config.fidelity.value = target_fidelity + budget_list.append(self.get_budget_level(config)) + + # CAN ADAPT BETA PER-SAMPLE HERE + betas.append(self.beta) + + # drop unused configs + x.drop(labels=indices_to_drop, inplace=True) + + return x, torch.Tensor(betas) + + +class MF_UCB_Dyna(MF_UCB): + def preprocess(self, x: pd.Series) -> Tuple[pd.Series, torch.Tensor]: + """Prepares the configurations for appropriate EI calculation. + + Takes a set of points and computes the budget and incumbent for each point. + Unlike the base class MFEI, sets the target fidelity to be max budget and the + incumbent choice to be the max seen across history for all candidates. + """ + if self.pipeline_space.has_tabular: + # preprocess tabular space differently + # expected input: IDs pertaining to the tabular data + x = map_real_hyperparameters_from_tabular_ids(x, self.pipeline_space) + + # find the maximum observed steps per config to obtain the current pseudo_z_max + max_z_level_per_x = self.observations.get_max_observed_fidelity_level_per_config() + pseudo_z_level_max = max_z_level_per_x.max() # highest seen fidelity step so far + # find the fidelity step at which the best seen performance was recorded + z_inc_level = self.observations.get_budget_level_for_best_performance() + # retrieving actual fidelity values from budget level + ## marker 1: the fidelity value at which the best seen performance was recorded + z_inc = self.b_step * z_inc_level + self.pipeline_space.fidelity.lower + ## marker 2: the maximum fidelity value recorded in observation history + pseudo_z_max = ( + self.b_step * pseudo_z_level_max + self.pipeline_space.fidelity.lower + ) + + # TODO: compare with this first draft logic + # def update_fidelity(config): + # ### DO NOT DELETE THIS FUNCTION YET + # # for all configs, set the min(max(current fidelity + step, z_inc), pseudo_z_max) + # ## that is, choose the next highest marker from 1 and 2 + # z_extrapolate = min( + # max(config.fidelity.value + self.b_step, z_inc), + # pseudo_z_max + # ) + # config.fidelity.value = z_extrapolate + # return config + + def update_fidelity(config): + # for all configs, set to pseudo_z_max + ## that is, choose the highest seen fidelity in observation history + z_extrapolate = pseudo_z_max + config.fidelity.value = z_extrapolate + return config + + # collect IDs for partial configurations + _partial_config_ids = x.index <= max(self.observations.seen_config_ids) + # filter for configurations that reached max budget + indices_to_drop = [ + _idx + for _idx, _x in x.loc[_partial_config_ids].items() + if _x.fidelity.value == self.pipeline_space.fidelity.upper + ] + # drop unused configs + x.drop(labels=indices_to_drop, inplace=True) + + # set fidelity for all partial configs + x = x.apply(update_fidelity) + + # CAN ADAPT BETA PER-SAMPLE HERE + betas = [self.beta] * len(x) # TODO: have tighter order check to Pd.Series index + + return x, torch.Tensor(betas) diff --git a/neps/optimizers/bayesian_optimization/acquisition_functions/ucb.py 
b/neps/optimizers/bayesian_optimization/acquisition_functions/ucb.py index adf57266..f3c89a39 100644 --- a/neps/optimizers/bayesian_optimization/acquisition_functions/ucb.py +++ b/neps/optimizers/bayesian_optimization/acquisition_functions/ucb.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Iterable, Union import numpy as np @@ -7,7 +9,7 @@ class UpperConfidenceBound(BaseAcquisition): - def __init__(self, beta: float=1.0, maximize: bool=False): + def __init__(self, beta: float = 1.0, maximize: bool = False): """Upper Confidence Bound (UCB) acquisition function. Args: @@ -18,7 +20,7 @@ def __init__(self, beta: float=1.0, maximize: bool=False): super().__init__() self.beta = beta # can be updated as part of the state for dynamism or a schedule self.maximize = maximize - + # to be initialized as part of the state self.surrogate_model = None @@ -30,31 +32,23 @@ def set_state(self, surrogate_model, **kwargs): self.beta = kwargs["beta"] else: self.logger.warning("Beta is a list, not updating beta value!") - + def eval( - self, x: Iterable, asscalar: bool = False - ) -> Union[np.ndarray, torch.Tensor, float]: + self, x: Iterable, betas: torch.Tensor | None = None, asscalar: bool = False + ) -> np.ndarray | torch.Tensor | float: try: mu, cov = self.surrogate_model.predict(x) std = torch.sqrt(torch.diag(cov)) except ValueError as e: raise e sign = 1 if self.maximize else -1 # LCB is performed if minimize=True - ucb_scores = mu + sign * np.sqrt(self.beta) * std - # if LCB, minimize acquisition, or maximize -acquisition - ucb_scores = ucb_scores.detach().numpy() * sign - - return ucb_scores - - -class MF_UCB(UpperConfidenceBound): - - def preprocess(self, x: Iterable) -> Iterable: - performances = self.observations.get_best_performance_for_each_budget() - pass - - def eval( - self, x: Iterable, asscalar: bool = False - ) -> Union[np.ndarray, torch.Tensor, float]: - x = self.preprocess(x) - return self.eval(x, asscalar=asscalar) + ucb_scores = mu + sign * torch.sqrt(self.beta if betas is None else betas) * std + # if LCB, minimize acquisition, or maximize -acquisition + ucb_scores = ucb_scores * sign + + if ucb_scores.is_cuda: + ucb_scores = ucb_scores.cpu() + if len(x) > 1 and asscalar: + return ucb_scores.detach().numpy() + else: + return ucb_scores.detach().numpy().item() diff --git a/neps/optimizers/bayesian_optimization/acquisition_samplers/freeze_thaw_sampler.py b/neps/optimizers/bayesian_optimization/acquisition_samplers/freeze_thaw_sampler.py index abae00b1..8b5c6145 100644 --- a/neps/optimizers/bayesian_optimization/acquisition_samplers/freeze_thaw_sampler.py +++ b/neps/optimizers/bayesian_optimization/acquisition_samplers/freeze_thaw_sampler.py @@ -1,6 +1,7 @@ # type: ignore from __future__ import annotations +import time import warnings from copy import deepcopy @@ -13,7 +14,6 @@ class FreezeThawSampler(AcquisitionSampler): - SAMPLES_TO_DRAW = 100 # number of random samples to draw at lowest fidelity def __init__(self, **kwargs): @@ -27,7 +27,7 @@ def __init__(self, **kwargs): self.sample_full_table = None self.set_sample_full_tabular(True) # sets flag that samples full table - def set_sample_full_tabular(self, flag: bool=False): + def set_sample_full_tabular(self, flag: bool = False): if self.is_tabular: self.sample_full_table = flag @@ -106,60 +106,70 @@ def sample( set_new_sample_fidelity: int | float = None, ) -> list(): """Samples a new set and returns the total set of observed + new configs.""" + start = time.time() partial_configs = 
self.observations.get_partial_configs_at_max_seen() - new_configs = self._sample_new( - index_from=self.observations.next_config_id(), n=n, ignore_fidelity=False - ) - - def __sample_single_new_tabular(index: int): - """ - A function to use in a list comprehension to slightly speed up - the sampling process when self.SAMPLE_TO_DRAW is large - """ - config = self.pipeline_space.sample( - patience=self.patience, user_priors=False, ignore_fidelity=False - ) - config["id"].value = _new_configs[index] - config.fidelity.value = set_new_sample_fidelity - return config + # print("-" * 50) + # print(f"| freeze-thaw:get_partial_at_max_seen(): {time.time()-start:.2f}s") + # print("-" * 50) + _n = n if n is not None else self.SAMPLES_TO_DRAW if self.is_tabular: - _n = n if n is not None else self.SAMPLES_TO_DRAW + # handles tabular data such that the entire unseen set of configs from the + # table is considered to be the new set of candidates _partial_ids = {conf["id"].value for conf in partial_configs} - _all_ids = set(self.pipeline_space.custom_grid_table.index.values) + _all_ids = set(list(self.pipeline_space.custom_grid_table.keys())) # accounting for unseen configs only, samples remaining table if flag is set max_n = len(_all_ids) + 1 if self.sample_full_table else _n _n = min(max_n, len(_all_ids - _partial_ids)) - + + start = time.time() _new_configs = np.random.choice( list(_all_ids - _partial_ids), size=_n, replace=False ) - new_configs = [__sample_single_new_tabular(i) for i in range(_n)] + placeholder_config = self.pipeline_space.sample( + patience=self.patience, user_priors=False, ignore_fidelity=False + ) + _configs = [deepcopy(placeholder_config) for _id in _new_configs] + for _i, val in enumerate(_new_configs): + _configs[_i]["id"].value = val + + # print("-" * 50) + # print(f"| freeze-thaw:sample:new_configs_extraction: {time.time()-start:.2f}s") + # print("-" * 50) new_configs = pd.Series( - new_configs, + _configs, index=np.arange( - len(partial_configs), len(partial_configs) + len(new_configs) + len(partial_configs), len(partial_configs) + len(_new_configs) ), ) + else: + # handles sampling new configurations for continuous spaces + new_configs = self._sample_new( + index_from=self.observations.next_config_id(), n=_n, ignore_fidelity=False + ) + # Continuous benchmarks need to deepcopy individual configs here, + # because in contrast to tabular benchmarks + # they are not reset in every sampling step + partial_configs = pd.Series( + [deepcopy(p_config_) for idx, p_config_ in partial_configs.items()], + index=partial_configs.index, + ) - elif set_new_sample_fidelity is not None: + # Updating fidelity values + start = time.time() + if set_new_sample_fidelity is not None: for config in new_configs: config.fidelity.value = set_new_sample_fidelity + # print("-" * 50) + # print(f"| freeze-thaw:sample:new_configs_set_fidelity: {time.time()-start:.2f}s") + # print("-" * 50) - # Deep copy configs for fidelity updates - partial_configs_list = [] - index_list = [] - for idx, config in partial_configs.items(): - _config = deepcopy(config) - partial_configs_list.append(_config) - index_list.append(idx) - - # We build a new series of partial configs to avoid - # incrementing fidelities multiple times due to pass-by-reference - partial_configs = pd.Series(partial_configs_list, index=index_list) - - configs = pd.concat([partial_configs, new_configs]) + start = time.time() + configs = pd.concat([deepcopy(partial_configs), new_configs]) + # print("-" * 50) + # print(f"| 
freeze-thaw:sample:concat_configs: {time.time()-start:.2f}s") + # print("-" * 50) return configs @@ -180,3 +190,4 @@ def set_state( and self.pipeline_space.custom_grid_table is not None ): self.is_tabular = True + self.set_sample_full_tabular(True) diff --git a/neps/optimizers/bayesian_optimization/kernels/combine_kernels.py b/neps/optimizers/bayesian_optimization/kernels/combine_kernels.py index 833602e2..849198ca 100644 --- a/neps/optimizers/bayesian_optimization/kernels/combine_kernels.py +++ b/neps/optimizers/bayesian_optimization/kernels/combine_kernels.py @@ -51,7 +51,6 @@ def fit_transform( ): N = len(configs) K = torch.zeros(N, N) if self.combined_by == "sum" else torch.ones(N, N) - gr1, x1 = extract_configs(configs) for i, k in enumerate(self.kernels): diff --git a/neps/optimizers/bayesian_optimization/mf_tpe.py b/neps/optimizers/bayesian_optimization/mf_tpe.py index f705f9f0..cb72940a 100644 --- a/neps/optimizers/bayesian_optimization/mf_tpe.py +++ b/neps/optimizers/bayesian_optimization/mf_tpe.py @@ -7,7 +7,7 @@ import numpy as np import torch from scipy.stats import spearmanr -from typing_extensions import Literal +from typing import Literal from ...metahyper import ConfigResult, instance_from_map from ...search_spaces import ( diff --git a/neps/optimizers/bayesian_optimization/models/DPL.py b/neps/optimizers/bayesian_optimization/models/DPL.py new file mode 100644 index 00000000..37279824 --- /dev/null +++ b/neps/optimizers/bayesian_optimization/models/DPL.py @@ -0,0 +1,851 @@ +from __future__ import annotations + +import logging +import os +import time +from copy import deepcopy +from pathlib import Path +from typing import Any, List, Tuple, Type + +import numpy as np +import torch +import torch.nn as nn +from scipy.stats import norm + +from ....search_spaces.search_space import ( + CategoricalParameter, + FloatParameter, + IntegerParameter, + SearchSpace, +) +from ...multi_fidelity.utils import MFObservedData + + +# TODO: Move to utils +def get_optimizer_losses(root_directory: Path | str) -> list[float]: + all_losses_file = root_directory / "all_losses_and_configs.txt" + + # Read all losses from the file in the order they are explored + losses = [ + float(line[6:]) + for line in all_losses_file.read_text(encoding="utf-8").splitlines() + if "Loss: " in line + ] + return losses + + +def get_best_loss(root_directory: Path | str) -> float: + root_directory = Path(root_directory) + best_loss_fiel = root_directory / "best_loss_trajectory.txt" + + # Get the best seen loss value + best_loss = float(best_loss_fiel.read_text(encoding="utf-8").splitlines()[-1].strip()) + + return best_loss + + +class ConditionedPowerLaw(nn.Module): + def __init__( + self, + nr_initial_features=10, + nr_units=200, + nr_layers=3, + use_learning_curve: bool = True, + kernel_size: int = 3, + nr_filters: int = 4, + nr_cnn_layers: int = 2, + ): + """ + Args: + nr_initial_features: int + The number of features per example. + nr_units: int + The number of units for every layer. + nr_layers: int + The number of layers for the neural network. + use_learning_curve: bool + If the learning curve should be use in the network. + kernel_size: int + The size of the kernel that is applied in the cnn layer. + nr_filters: int + The number of filters that are used in the cnn layers. + nr_cnn_layers: int + The number of cnn layers to be used. 
+ """ + super().__init__() + + self.use_learning_curve = use_learning_curve + self.kernel_size = kernel_size + self.nr_filters = nr_filters + self.nr_cnn_layers = nr_cnn_layers + + self.act_func = torch.nn.LeakyReLU() + self.last_act_func = torch.nn.GLU() + self.tan_func = torch.nn.Tanh() + self.batch_norm = torch.nn.BatchNorm1d + + layers = [] + # adding one since we concatenate the features with the budget + nr_initial_features = nr_initial_features + if self.use_learning_curve: + nr_initial_features = nr_initial_features + nr_filters + + layers.append(nn.Linear(nr_initial_features, nr_units)) + layers.append(self.act_func) + + for i in range(2, nr_layers + 1): + layers.append(nn.Linear(nr_units, nr_units)) + layers.append(self.act_func) + + last_layer = nn.Linear(nr_units, 3) + layers.append(last_layer) + + self.layers = torch.nn.Sequential(*layers) + + cnn_part = [] + if use_learning_curve: + cnn_part.append( + nn.Conv1d( + in_channels=2, + kernel_size=(self.kernel_size,), + out_channels=self.nr_filters, + ), + ) + for i in range(1, self.nr_cnn_layers): + cnn_part.append(self.act_func) + cnn_part.append( + nn.Conv1d( + in_channels=self.nr_filters, + kernel_size=(self.kernel_size,), + out_channels=self.nr_filters, + ), + ), + cnn_part.append(nn.AdaptiveAvgPool1d(1)) + + self.cnn = nn.Sequential(*cnn_part) + + def forward( + self, + x: torch.Tensor, + predict_budgets: torch.Tensor, + evaluated_budgets: torch.Tensor, + learning_curves: torch.Tensor, + ): + """ + Args: + x: torch.Tensor + The examples. + predict_budgets: torch.Tensor + The budgets for which the performance will be predicted for the + hyperparameter configurations. + evaluated_budgets: torch.Tensor + The budgets for which the hyperparameter configurations have been + evaluated so far. + learning_curves: torch.Tensor + The learning curves for the hyperparameter configurations. + """ + # print(x.shape) + # print(learning_curves.shape) + # x = torch.cat((x, torch.unsqueeze(evaluated_budgets, 1)), dim=1) + if self.use_learning_curve: + lc_features = self.cnn(learning_curves) + # print(lc_features.shape) + # revert the output from the cnn into nr_rows x nr_kernels. + lc_features = torch.squeeze(lc_features, 2) + # print(lc_features) + x = torch.cat((x, lc_features), dim=1) + # print(x.shape) + if torch.any(torch.isnan(x)): + raise ValueError("NaN values in input, the network probably diverged") + x = self.layers(x) + alphas = x[:, 0] + betas = x[:, 1] + gammas = x[:, 2] + # print(x) + output = torch.add( + alphas, + torch.mul( + self.last_act_func(torch.cat((betas, betas))), + torch.pow( + predict_budgets, + torch.mul(self.last_act_func(torch.cat((gammas, gammas))), -1), + ), + ), + ) + + return output + + +ModelClass = ConditionedPowerLaw + +MODEL_MAPPING: dict[str, type[ModelClass]] = {"power_law": ConditionedPowerLaw} + + +class PowerLawSurrogate: + # defaults to be used for functions + # fit params + default_lr = 0.001 + default_batch_size = 64 + default_nr_epochs = 250 + default_refine_epochs = 20 + default_early_stopping = False + default_early_stopping_patience = 10 + + # init params + default_n_initial_full_trainings = 10 + default_n_models = 5 + default_model_config = dict( + nr_units=128, + nr_layers=2, + use_learning_curve=False, + kernel_size=3, + nr_filters=4, + nr_cnn_layers=2, + ) + + # fit+predict params + default_padding_type = "zero" + default_budget_normalize = True + default_use_min_budget = False + default_y_normalize = False + + # Defined in __init__(...) + default_no_improvement_patience = ... 
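The `ConditionedPowerLaw` module defined above emits three values per configuration (`alphas`, `betas`, `gammas`) and combines them with the requested budget into a power-law extrapolation of performance. As a minimal sketch of that parametric shape, assuming illustrative constants and omitting the GLU activations and budget normalization used in the actual `forward` pass (the helper name below is purely illustrative):

```python
# Sketch only (not part of this patch): y(b) = alpha + beta * b**(-gamma),
# the power-law curve that ConditionedPowerLaw parameterizes per configuration.
import torch

def power_law_extrapolation(alpha, beta, gamma, budget):
    # Performance approaches `alpha` as the budget grows; `beta` and `gamma`
    # control the initial offset and how quickly the curve flattens.
    return alpha + beta * torch.pow(budget, -gamma)

budgets = torch.tensor([1.0, 5.0, 25.0, 50.0])
print(power_law_extrapolation(0.1, 0.5, 0.8, budgets))
# -> monotonically decreasing values that level off near alpha = 0.1
```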
+ + def __init__( + self, + pipeline_space: SearchSpace, + observed_data: MFObservedData | None = None, + logger=None, + surrogate_model_fit_args: dict | None = None, + # IMPORTANT: Checkpointing does not use file locking, + # IMPORTANT: hence, it is not suitable for multiprocessing settings + # IMPORTANT: For parallel runs lock the checkpoint file during the whole training + checkpointing: bool = False, + root_directory: Path | str | None = None, + # IMPORTANT: For parallel runs use a different checkpoint_file name for each + # IMPORTANT: surrogate. This makes sure that parallel runs don't override each + # IMPORTANT: others saved checkpoint. Although they will still have some conflicts due to + # IMPORTANT: global optimizer step tracking + checkpoint_file: Path | str = "surrogate_checkpoint.pth", + refine_epochs: int = default_refine_epochs, + n_initial_full_trainings: int = default_n_initial_full_trainings, + default_model_class: str = "power_law", + default_model_config: dict[str, Any] = default_model_config, + n_models: int = default_n_models, + model_classes: list[str] | None = None, + model_configs: list[dict[str, Any]] | None = None, + refine_batch_size: int | None = None, + ): + if pipeline_space.has_tabular: + self.cover_pipeline_space = pipeline_space + self.real_pipeline_space = pipeline_space.raw_tabular_space + else: + self.cover_pipeline_space = pipeline_space + self.real_pipeline_space = pipeline_space + # self.pipeline_space = pipeline_space + + self.observed_data = observed_data + self.__preprocess_search_space(self.real_pipeline_space) + self.seeds = np.random.choice(100, n_models, replace=False) + self.model_configs = ( + [dict(nr_initial_features=self.input_size, **default_model_config)] * n_models + if not model_configs + else model_configs + ) + self.model_classes = ( + [MODEL_MAPPING[default_model_class]] * n_models + if not model_classes + else [MODEL_MAPPING[m_class] for m_class in model_classes] + ) + self.device = "cpu" + self.models: list[ModelClass] = [ + self.__initialize_model(config, self.model_classes[index], self.device) + for index, config in enumerate(self.model_configs) + ] + + self.checkpointing = checkpointing + self.refine_epochs = refine_epochs + self.refine_batch_size = refine_batch_size + self.n_initial_full_trainings = n_initial_full_trainings + self.default_no_improvement_patience = int( + self.max_fidelity + 0.2 * self.max_fidelity + ) + + if checkpointing: + assert ( + root_directory is not None + ), "neps root_directory must be provided for the checkpointing" + self.root_dir = Path(os.getcwd(), root_directory) + self.checkpoint_path = Path(os.getcwd(), root_directory, checkpoint_file) + + self.surrogate_model_fit_args = ( + surrogate_model_fit_args if surrogate_model_fit_args is not None else {} + ) + + if self.surrogate_model_fit_args.get("no_improvement_patience", None) is None: + # To replicate how the original DPL implementation handles the + # no_improvement_threshold + self.surrogate_model_fit_args[ + "no_improvement_patience" + ] = self.default_no_improvement_patience + + self.categories_array = np.array(self.categories) + + self.best_state = None + self.prediction_learning_curves = None + + self.criterion = torch.nn.L1Loss() + + self.logger = logger or logging.getLogger("neps") + + def __preprocess_search_space(self, pipeline_space: SearchSpace): + self.categories = [] + self.categorical_hps = [] + + parameter_count = 0 + for hp_name, hp in pipeline_space.items(): + # Collect all categories in a list for the encoder + if 
hp.is_fidelity: + continue # Ignore fidelity + if isinstance(hp, CategoricalParameter): + self.categorical_hps.append(hp_name) + self.categories.extend(hp.choices) + parameter_count += len(hp.choices) + else: + parameter_count += 1 + + # add 1 for budget + self.input_size = parameter_count + self.continuous_params_size = self.input_size - len(self.categories) + self.min_fidelity = pipeline_space.fidelity.lower + self.max_fidelity = pipeline_space.fidelity.upper + + def __encode_config(self, config: SearchSpace) -> np.ndarray: + categorical_encoding = np.zeros_like(self.categories_array, dtype=np.single) + continuous_values = [] + + for hp_name, hp in config.items(): + if hp.is_fidelity: + continue # Ignore fidelity + if hp_name in self.categorical_hps: + label = hp.value + categorical_encoding[np.argwhere(self.categories_array == label)] = 1 + else: + continuous_values.append(hp.normalized().value) + + continuous_encoding = np.array(continuous_values) + + encoding = np.concatenate([categorical_encoding, continuous_encoding]) + return encoding + + def __normalize_budgets( + self, budgets: np.ndarray, use_min_budget: bool + ) -> np.ndarray: + min_budget = self.min_fidelity if use_min_budget else 0 + normalized_budgets = (budgets - min_budget) / (self.max_fidelity - min_budget) + return normalized_budgets + + def __extract_budgets( + self, x_train: list[SearchSpace], normalized: bool, use_min_budget: bool + ) -> np.ndarray: + budgets = np.array([config.fidelity.value for config in x_train], dtype=np.single) + + if normalized: + budgets = self.__normalize_budgets(budgets, use_min_budget) + return budgets + + def __preprocess_learning_curves( + self, learning_curves: list[list[float]], padding_type: str + ) -> np.ndarray: + # Add padding to the learning curves to make them the same size + existing_values_mask = [] + max_length = self.max_fidelity - 1 + + if padding_type == "last": + init_value = self.__get_mean_initial_value() + else: + init_value = 0.0 + + for lc in learning_curves: + if len(lc) == 0: + padding_value = init_value + elif padding_type == "last": + padding_value = lc[-1] + else: + padding_value = 0.0 + + padding_length = int(max_length - len(lc)) + + mask = [1] * len(lc) + [0] * padding_length + existing_values_mask.append(mask) + + lc.extend([padding_value] * padding_length) + # print(learning_curves) + learning_curves = np.array(learning_curves, dtype=np.single) + existing_values_mask = np.array(existing_values_mask, dtype=np.single) + + learning_curves = np.stack((learning_curves, existing_values_mask), axis=1) + + return learning_curves + + def __reset_xy( + self, + x_train: list[SearchSpace], + y_train: list[float], + learning_curves: list[list[float]], + normalize_y: bool = default_y_normalize, + normalize_budget: bool = default_budget_normalize, + use_min_budget: bool = default_use_min_budget, + padding_type: str = default_padding_type, + ): + self.normalize_budget = ( # pylint: disable=attribute-defined-outside-init + normalize_budget + ) + self.use_min_budget = ( # pylint: disable=attribute-defined-outside-init + use_min_budget + ) + self.padding_type = padding_type # pylint: disable=attribute-defined-outside-init + self.normalize_y = normalize_y # pylint: disable=attribute-defined-outside-init + + x_train, train_budgets, learning_curves = self._preprocess_input( + x_train, + learning_curves, + self.normalize_budget, + self.use_min_budget, + self.padding_type, + ) + + y_train = self._preprocess_y(y_train, normalize_y) + + self.x_train = x_train # pylint: 
disable=attribute-defined-outside-init + self.train_budgets = ( # pylint: disable=attribute-defined-outside-init + train_budgets + ) + self.learning_curves = ( # pylint: disable=attribute-defined-outside-init + learning_curves + ) + self.y_train = y_train # pylint: disable=attribute-defined-outside-init + + def _preprocess_input( + self, + x: list[SearchSpace], + learning_curves: list[list[float]], + normalize_budget: bool, + use_min_budget: bool, + padding_type: str, + ) -> [torch.tensor, torch.tensor, torch.tensor]: + budgets = self.__extract_budgets(x, normalize_budget, use_min_budget) + learning_curves = self.__preprocess_learning_curves(learning_curves, padding_type) + + x = np.array([self.__encode_config(config) for config in x], dtype=np.single) + + x = torch.tensor(x).to(device=self.device) + budgets = torch.tensor(budgets).to(device=self.device) + learning_curves = torch.tensor(learning_curves).to(device=self.device) + + return x, budgets, learning_curves + + def _preprocess_y(self, y_train: list[float], normalize_y: bool) -> torch.tensor: + y_train_array = np.array(y_train, dtype=np.single) + self.min_y = y_train_array.min() # pylint: disable=attribute-defined-outside-init + self.max_y = y_train_array.max() # pylint: disable=attribute-defined-outside-init + if normalize_y: + y_train_array = (y_train_array - self.min_y) / (self.max_y - self.min_y) + y_train_array = torch.tensor(y_train_array).to(device=self.device) + return y_train_array + + def __is_refine(self, no_improvement_patience: int) -> bool: + losses = get_optimizer_losses(self.root_dir) + + best_loss = get_best_loss(self.root_dir) + + total_optimizer_steps = len(losses) + + # Count the non-improvement + non_improvement_steps = 0 + for loss in reversed(losses): + if np.greater(loss, best_loss): + non_improvement_steps += 1 + else: + break + + self.logger.debug(f"No improvement for: {non_improvement_steps} evaulations") + + return (non_improvement_steps < no_improvement_patience) and ( + self.n_initial_full_trainings <= total_optimizer_steps + ) + + def fit( + self, + x_train: list[SearchSpace], + y_train: list[float], + learning_curves: list[list[float]], + ): + self._fit(x_train, y_train, learning_curves, **self.surrogate_model_fit_args) + + def _fit( + self, + x_train: list[SearchSpace], + y_train: list[float], + learning_curves: list[list[float]], + nr_epochs: int = default_nr_epochs, + batch_size: int = default_batch_size, + early_stopping: bool = default_early_stopping, + early_stopping_patience: int = default_early_stopping_patience, + no_improvement_patience: int = default_no_improvement_patience, + optimizer_args: dict[str, Any] | None = None, + normalize_y: bool = default_y_normalize, + normalize_budget: bool = default_budget_normalize, + use_min_budget: bool = default_use_min_budget, + padding_type: str = default_padding_type, + ): + self.__reset_xy( + x_train, + y_train, + learning_curves, + normalize_y=normalize_y, + normalize_budget=normalize_budget, + use_min_budget=use_min_budget, + padding_type=padding_type, + ) + # check when to refine + if ( + self.checkpointing + and self.__is_refine(no_improvement_patience) + and self.checkpoint_path.exists() + ): + # self.__initialize_model() + self.load_state() + weight_new_point = True + nr_epochs = self.refine_epochs + batch_size = self.refine_batch_size if self.refine_batch_size else batch_size + else: + weight_new_point = False + + if optimizer_args is None: + optimizer_args = {"lr": self.default_lr} + + for model_index, model in enumerate(self.models): + 
self._train_a_model( + model_index, + self.x_train, + self.train_budgets, + self.y_train, + self.learning_curves, + nr_epochs=nr_epochs, + batch_size=batch_size, + early_stopping_patience=early_stopping_patience, + early_stopping=early_stopping, + weight_new_point=weight_new_point, + optimizer_args=optimizer_args, + ) + + # save model after training if checkpointing + if self.checkpointing: + self.save_state() + + def _train_a_model( + self, + model_index: int, + x_train: torch.tensor, + train_budgets: torch.tensor, + y_train: torch.tensor, + learning_curves: torch.tensor, + nr_epochs: int, + batch_size: int, + early_stopping_patience: int, + early_stopping: bool, + weight_new_point: bool, + optimizer_args: dict[str, Any], + ): + # Setting seeds will interfere with SearchSpace random sampling + if self.cover_pipeline_space.has_tabular: + seed = self.seeds[model_index] + torch.manual_seed(seed) + np.random.seed(seed) + + model = self.models[model_index] + + optimizer = torch.optim.Adam( + **dict({"params": model.parameters()}, **optimizer_args) + ) + + count_down = early_stopping_patience + best_loss = np.inf + best_state = deepcopy(model.state_dict()) + + model.train() + + if weight_new_point: + new_x, new_b, new_lc, new_y = self.prep_new_point() + else: + new_x, new_b, new_lc, new_y = [torch.tensor([])] * 4 + + for epoch in range(0, nr_epochs): + if early_stopping and count_down == 0: + self.logger.info( + f"Epoch: {epoch - 1} surrogate training stops due to early " + f"stopping with the patience: {early_stopping_patience} and " + f"the minimum average loss of {best_loss} and " + f"the final average loss of {best_loss}" + ) + model.load_state_dict(best_state) + break + + n_examples_batch = x_train.size(dim=0) + + # get a random permutation for mini-batches + permutation = torch.randperm(n_examples_batch) + + # optimize over mini-batches + total_scaled_loss = 0.0 + for batch_idx, start_index in enumerate( + range(0, n_examples_batch, batch_size) + ): + end_index = start_index + batch_size + if end_index > n_examples_batch: + end_index = n_examples_batch + indices = permutation[start_index:end_index] + batch_x, batch_budget, batch_lc, batch_y = ( + x_train[indices], + train_budgets[indices], + learning_curves[indices], + y_train[indices], + ) + + minibatch_size = end_index - start_index + + if weight_new_point: + batch_x = torch.cat((batch_x, new_x)) + batch_budget = torch.cat((batch_budget, new_b)) + batch_lc = torch.cat((batch_lc, new_lc)) + batch_y = torch.cat((batch_y, new_y)) + + # increase the batchsize + minibatch_size += new_x.shape[0] + + # if only one example in the batch, skip the batch. 
+ # Otherwise, the code will fail because of batchnorm + if minibatch_size <= 1: + continue + + # Zero backprop gradients + optimizer.zero_grad(set_to_none=True) + + outputs = model(batch_x, batch_budget, batch_budget, batch_lc) + loss = self.criterion(outputs, batch_y) + loss.backward() + optimizer.step() + + total_scaled_loss += loss.detach().item() * minibatch_size + + running_loss = total_scaled_loss / n_examples_batch + + if running_loss < best_loss: + best_loss = running_loss + count_down = early_stopping_patience + best_state = deepcopy(model.state_dict()) + elif early_stopping: + self.logger.debug( + f"No improvement over the minimum loss value of {best_loss} " + f"for the past {early_stopping_patience - count_down} epochs " + f"the training will stop in {count_down} epochs" + ) + count_down -= 1 + if early_stopping: + model.load_state_dict(best_state) + return model + + def set_prediction_learning_curves(self, learning_curves: list[list[float]]): + # pylint: disable=attribute-defined-outside-init + self.prediction_learning_curves = learning_curves + # pylint: enable=attribute-defined-outside-init + + def predict( + self, + x: list[SearchSpace], + learning_curves: list[list[float]] | None = None, + real_budgets: list[int | float] | None = None, + ) -> [torch.tensor, torch.tensor]: + # Preprocess input + # [print(_x.hp_values()) for _x in x] + if learning_curves is None: + learning_curves = self.prediction_learning_curves + + if real_budgets is None: + # Get the list of budgets the configs are evaluated for + real_budgets = [len(lc) + 1 for lc in learning_curves] + + x_test, prediction_budgets, learning_curves = self._preprocess_input( + x, + learning_curves, + self.normalize_budget, + self.use_min_budget, + self.padding_type, + ) + # preprocess the list of budgets the configs are evaluated for + real_budgets = np.array(real_budgets, dtype=np.single) + real_budgets = self.__normalize_budgets(real_budgets, self.use_min_budget) + real_budgets = torch.tensor(real_budgets).to(self.device) + + all_predictions = [] + for model in self.models: + model.eval() + + preds = model(x_test, prediction_budgets, real_budgets, learning_curves) + all_predictions.append(preds.detach().cpu().numpy()) + + means = torch.tensor(np.mean(all_predictions, axis=0)).cpu() + std_predictions = np.std(all_predictions, axis=0) + cov = torch.diag(torch.tensor(np.power(std_predictions, 2))).cpu() + + return means, cov + + def load_state(self, state: dict[str, int | str | dict[str, Any]] | None = None): + # load and save last evaluated config as well + if state is None: + checkpoint = torch.load(self.checkpoint_path) + else: + checkpoint = state + + self.last_point = checkpoint["last_point"] + + for model_index in range(checkpoint["n_models"]): + self.models[model_index].load_state_dict( + checkpoint[f"model_{model_index}_state_dict"] + ) + self.models[model_index].to(self.device) + + def get_state(self) -> dict[str, int | str | dict[str, Any]]: + n_models = len(self.models) + model_states = { + f"model_{model_index}_state_dict": deepcopy( + self.models[model_index].state_dict() + ) + for model_index in range(n_models) + } + + # get last point + last_point = self.get_last_point() + current_state = dict(n_models=n_models, last_point=last_point, **model_states) + + return current_state + + def __config_ids(self) -> list[str]: + # Parallelization issues + all_losses_file = self.root_dir / "all_losses_and_configs.txt" + + if all_losses_file.exists(): + # Read all losses from the file in the order they are explored + 
config_ids = [ + str(line[11:]) + for line in all_losses_file.read_text(encoding="utf-8").splitlines() + if "Config ID: " in line + ] + else: + config_ids = [] + + return config_ids + + def save_state(self, state: dict[str, int | str | dict[str, Any]] | None = None): + # TODO: save last evaluated config as well + if state is None: + torch.save( + self.get_state(), + self.checkpoint_path, + ) + else: + assert ( + "last_point" in state and "n_models" in state + ), "The state dictionary is not complete" + torch.save( + state, + self.checkpoint_path, + ) + + def get_last_point(self) -> str: + # Only for single worker case + last_config_id = self.__config_ids()[-1] + # For parallel runs + # get the last config_id that's also in self.observed_configs + return last_config_id + + def get_new_points(self) -> [list[SearchSpace], list[list[float]], list[float]]: + # Get points that haven't been trained on before + + config_ids = self.__config_ids() + + if self.last_point: + index = config_ids.index(self.last_point) + 1 + else: + index = len(config_ids) - 1 + + new_config_indices = [ + tuple(map(int, config_id.split("_"))) for config_id in config_ids[index:] + ] + + # Only include the points that exist in the observed data already + # (not a use case for single worker runs) + existing_index_map = self.observed_data.df.index.isin(new_config_indices) + + new_config_df = self.observed_data.df.loc[existing_index_map, :].copy(deep=True) + + new_configs, new_lcs, new_y = self.observed_data.get_training_data_4DyHPO( + new_config_df, self.cover_pipeline_space + ) + + return new_configs, new_lcs, new_y + + @staticmethod + def __initialize_model( + model_params: dict[str, Any], model_class: type[ModelClass], device: str + ) -> ModelClass: + model = model_class(**model_params) + model.to(device) + return model + + def prep_new_point(self) -> [torch.tensor, torch.tensor, torch.tensor, torch.tensor]: + new_point, new_lc, new_y = self.get_new_points() + + new_x, new_b, new_lc = self._preprocess_input( + new_point, + new_lc, + self.normalize_budget, + self.use_min_budget, + self.padding_type, + ) + new_y = self._preprocess_y(new_y, self.normalize_y) + + return new_x, new_b, new_lc, new_y + + def __get_mean_initial_value(self): + mean = self.observed_data.get_trajectories().loc[:, 0].mean() + + return mean + + +if __name__ == "__main__": + max_fidelity = 50 + pipe_space = SearchSpace( + float_=FloatParameter(lower=0.0, upper=5.0), + e=IntegerParameter(lower=1, upper=max_fidelity, is_fidelity=True), + ) + + configs = [pipe_space.sample(ignore_fidelity=False) for _ in range(100)] + + y = np.random.random(100).tolist() + + lcs = [ + np.random.random(size=np.random.randint(low=1, high=max_fidelity)).tolist() + for _ in range(100) + ] + + surrogate = PowerLawSurrogate(pipe_space) + + surrogate.fit(x_train=configs, learning_curves=lcs, y_train=y) + + means, stds = surrogate.predict(configs, lcs) + + print(list(zip(means, y))) + print(stds) diff --git a/neps/optimizers/bayesian_optimization/models/__init__.py b/neps/optimizers/bayesian_optimization/models/__init__.py index 0eaeb127..4e8bf42f 100755 --- a/neps/optimizers/bayesian_optimization/models/__init__.py +++ b/neps/optimizers/bayesian_optimization/models/__init__.py @@ -1,4 +1,5 @@ from ....metahyper.utils import MissingDependencyError +from .DPL import PowerLawSurrogate from .gp import ComprehensiveGP from .gp_hierarchy import ComprehensiveGPHierarchy @@ -9,12 +10,13 @@ try: from .pfn import PFN_SURROGATE # only if available locally -except Exception as e: +except 
ImportError as e: PFN_SURROGATE = MissingDependencyError("pfn", e) SurrogateModelMapping = { "deep_gp": DeepGP, "gp": ComprehensiveGP, "gp_hierarchy": ComprehensiveGPHierarchy, + "dpl": PowerLawSurrogate, "pfn": PFN_SURROGATE, } diff --git a/neps/optimizers/bayesian_optimization/models/deepGP.py b/neps/optimizers/bayesian_optimization/models/deepGP.py index ee47ce70..9b7fd841 100644 --- a/neps/optimizers/bayesian_optimization/models/deepGP.py +++ b/neps/optimizers/bayesian_optimization/models/deepGP.py @@ -18,11 +18,8 @@ ) -def count_non_improvement_steps(root_directory: Path | str) -> int: - root_directory = Path(root_directory) - +def get_optimizer_losses(root_directory: Path | str) -> list[float]: all_losses_file = root_directory / "all_losses_and_configs.txt" - best_loss_fiel = root_directory / "best_loss_trajectory.txt" # Read all losses from the file in the order they are explored losses = [ @@ -30,18 +27,17 @@ def count_non_improvement_steps(root_directory: Path | str) -> int: for line in all_losses_file.read_text(encoding="utf-8").splitlines() if "Loss: " in line ] + return losses + + +def get_best_loss(root_directory: Path | str) -> float: + root_directory = Path(root_directory) + best_loss_fiel = root_directory / "best_loss_trajectory.txt" + # Get the best seen loss value best_loss = float(best_loss_fiel.read_text(encoding="utf-8").splitlines()[-1].strip()) - # Count the non-improvement - count = 0 - for loss in reversed(losses): - if np.greater(loss, best_loss): - count += 1 - else: - break - - return count + return best_loss class NeuralFeatureExtractor(nn.Module): @@ -167,13 +163,19 @@ def __init__( # IMPORTANT: hence, it is not suitable for multiprocessing settings checkpointing: bool = False, root_directory: Path | str | None = None, + # IMPORTANT: For parallel runs use a different checkpoint_file name for each + # IMPORTANT: surrogate. This makes sure that parallel runs don't override each + # IMPORTANT: others saved checkpoint. 
Although they will still have some conflicts due to + # IMPORTANT: global optimizer step tracking checkpoint_file: Path | str = "surrogate_checkpoint.pth", refine_epochs: int = 50, + n_initial_full_trainings: int = 10, **kwargs, # pylint: disable=unused-argument - ): + ): self.surrogate_model_fit_args = ( surrogate_model_fit_args if surrogate_model_fit_args is not None else {} ) + self.n_initial_full_trainings = n_initial_full_trainings self.checkpointing = checkpointing self.refine_epochs = refine_epochs @@ -204,9 +206,18 @@ def __init__( neural_network_args.get("n_layers", 2) ) + if self.surrogate_model_fit_args.get("perf_patience", -1) is None: + # To replicate how the original DyHPO implementation handles the + # no_improvement_threshold + self.surrogate_model_fit_args["perf_patience"] = int( + self.max_fidelity + 0.2 * self.max_fidelity + ) + # build the neural network self.nn = NeuralFeatureExtractor(self.input_size, **neural_network_args) + self.best_state = None + self.logger = logger or logging.getLogger("neps") def __initialize_gp_model( @@ -237,6 +248,27 @@ def __initialize_gp_model( mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model).to(self.device) return model, likelihood, mll + def __is_refine(self, perf_patience: int): + losses = get_optimizer_losses(self.root_dir) + + best_loss = get_best_loss(self.root_dir) + + total_optimizer_steps = len(losses) + + # Count the non-improvement + non_improvement_steps = 0 + for loss in reversed(losses): + if np.greater(loss, best_loss): + non_improvement_steps += 1 + else: + break + + self.logger.debug(f"No improvement for: {non_improvement_steps} evaulations") + + return (non_improvement_steps < perf_patience) and ( + self.n_initial_full_trainings <= total_optimizer_steps + ) + def __preprocess_search_space(self, pipeline_space: SearchSpace): self.categories = [] self.categorical_hps = [] @@ -258,9 +290,9 @@ def __preprocess_search_space(self, pipeline_space: SearchSpace): self.max_fidelity = pipeline_space.fidelity.upper def __encode_config(self, config: SearchSpace): - categorical_encoding = np.zeros_like(self.categories_array) + categorical_encoding = np.zeros_like(self.categories_array, dtype=np.single) continuous_values = [] - + # print(config.hp_values()) for hp_name, hp in config.items(): if hp.is_fidelity: continue # Ignore fidelity @@ -268,6 +300,7 @@ def __encode_config(self, config: SearchSpace): label = hp.value categorical_encoding[np.argwhere(self.categories_array == label)] = 1 else: + # self.logger.info(f"{hp_name} Value: {hp.value} Normalized {hp.normalized().value}") continuous_values.append(hp.normalized().value) continuous_encoding = np.array(continuous_values) @@ -276,13 +309,15 @@ def __encode_config(self, config: SearchSpace): return encoding def __extract_budgets( - self, x_train: list[SearchSpace], normalized: bool = True + self, + x_train: list[SearchSpace], + normalized: bool = True, + use_min_budget: bool = False, ) -> np.ndarray: + min_budget = self.min_fidelity if use_min_budget else 0 budgets = np.array([config.fidelity.value for config in x_train], dtype=np.single) if normalized: - normalized_budgets = (budgets - self.min_fidelity) / ( - self.max_fidelity - self.min_fidelity - ) + normalized_budgets = (budgets - min_budget) / (self.max_fidelity - min_budget) budgets = normalized_budgets return budgets @@ -316,14 +351,18 @@ def __reset_xy( learning_curves: list[list[float]], normalize_y: bool = False, normalize_budget: bool = True, + use_min_budget: bool = False, ): self.normalize_budget = ( # 
pylint: disable=attribute-defined-outside-init normalize_budget ) + self.use_min_budget = ( # pylint: disable=attribute-defined-outside-init + use_min_budget + ) self.normalize_y = normalize_y # pylint: disable=attribute-defined-outside-init x_train, train_budgets, learning_curves = self._preprocess_input( - x_train, learning_curves, normalize_budget + x_train, learning_curves, normalize_budget, use_min_budget ) y_train = self._preprocess_y(y_train, normalize_y) @@ -342,8 +381,9 @@ def _preprocess_input( x: list[SearchSpace], learning_curves: list[list[float]], normalize_budget: bool = True, + use_min_budget: bool = False, ): - budgets = self.__extract_budgets(x, normalize_budget) + budgets = self.__extract_budgets(x, normalize_budget, use_min_budget) learning_curves = self.__preprocess_learning_curves(learning_curves) x = np.array([self.__encode_config(config) for config in x], dtype=np.single) @@ -369,6 +409,7 @@ def fit( y_train: list[float], learning_curves: list[list[float]], ): + self.logger.info(f"FIT ARGS: {self.surrogate_model_fit_args}") self._fit(x_train, y_train, learning_curves, **self.surrogate_model_fit_args) def _fit( @@ -378,6 +419,7 @@ def _fit( learning_curves: list[list[float]], normalize_y: bool = False, normalize_budget: bool = True, + use_min_budget: bool = False, n_epochs: int = 1000, batch_size: int = 64, optimizer_args: dict | None = None, @@ -391,6 +433,7 @@ def _fit( learning_curves, normalize_y=normalize_y, normalize_budget=normalize_budget, + use_min_budget=use_min_budget, ) self.model, self.likelihood, self.mll = self.__initialize_gp_model(len(y_train)) self.nn = NeuralFeatureExtractor(self.input_size, **self.nn_args) @@ -399,12 +442,10 @@ def _fit( self.nn.to(self.device) if self.checkpointing and self.checkpoint_path.exists(): - non_improvement_steps = count_non_improvement_steps(self.root_dir) # If checkpointing and patience is not exhausted load a partial model - if non_improvement_steps < perf_patience: + if self.__is_refine(perf_patience): n_epochs = self.refine_epochs self.load_checkpoint() - self.logger.debug(f"No improvement for: {non_improvement_steps} evaulations") self.logger.debug(f"N Epochs for the full training: {n_epochs}") initial_state = self.get_state() @@ -421,7 +462,7 @@ def _fit( patience=patience, ) if self.checkpointing: - self.save_checkpoint() + self.save_checkpoint(self.best_state) except gpytorch.utils.errors.NotPSDError: self.logger.info("Model training failed loading the untrained model") self.load_checkpoint(initial_state) @@ -467,6 +508,7 @@ def __train_model( f"the minimum average loss of {min_avg_loss_val} and " f"the final average loss of {average_loss}" ) + self.load_checkpoint(self.best_state) break n_examples_batch = x_train.size(dim=0) @@ -530,6 +572,7 @@ def __train_model( if average_loss < min_avg_loss_val: min_avg_loss_val = average_loss count_down = patience + self.best_state = self.get_state() elif early_stopping: self.logger.debug( f"No improvement over the minimum loss value of {min_avg_loss_val} " @@ -558,7 +601,7 @@ def predict( if learning_curves is None: learning_curves = self.prediction_learning_curves x_test, test_budgets, learning_curves = self._preprocess_input( - x, learning_curves, self.normalize_budget + x, learning_curves, self.normalize_budget, self.use_min_budget ) self.model.eval() diff --git a/neps/optimizers/bayesian_optimization/optimizer.py b/neps/optimizers/bayesian_optimization/optimizer.py index 6c47ac8b..90aac0ac 100644 --- a/neps/optimizers/bayesian_optimization/optimizer.py +++ 
b/neps/optimizers/bayesian_optimization/optimizer.py @@ -3,7 +3,7 @@ import random from typing import Any -from typing_extensions import Literal +from typing import Literal from ...metahyper import ConfigResult, instance_from_map from ...search_spaces.hyperparameters.categorical import ( diff --git a/neps/optimizers/default_searchers/priorband.yaml b/neps/optimizers/default_searchers/priorband.yaml index 9b11ae9a..0b2cf1ae 100644 --- a/neps/optimizers/default_searchers/priorband.yaml +++ b/neps/optimizers/default_searchers/priorband.yaml @@ -9,7 +9,7 @@ searcher_kwargs: sample_default_first: true sample_default_at_target: false prior_weight_type: geometric - inc_sample_type: mutation + inc_sample_type: mutation inc_mutation_rate: 0.5 inc_mutation_std: 0.25 inc_style: dynamic diff --git a/neps/optimizers/default_searchers/priorband_bo.yaml b/neps/optimizers/default_searchers/priorband_bo.yaml index 04e530c1..4c00280e 100644 --- a/neps/optimizers/default_searchers/priorband_bo.yaml +++ b/neps/optimizers/default_searchers/priorband_bo.yaml @@ -16,7 +16,7 @@ searcher_kwargs: # arguments for model model_based: true # crucial argument to set to allow model-search - modelling_type: joint + modelling_type: joint initial_design_size: 10 surrogate_model: gp # or {"gp_hierarchy"} acquisition: EI # or {"LogEI", "AEI"} diff --git a/neps/optimizers/multi_fidelity/dyhpo.py b/neps/optimizers/multi_fidelity/dyhpo.py index b72afd15..b23f82ba 100755 --- a/neps/optimizers/multi_fidelity/dyhpo.py +++ b/neps/optimizers/multi_fidelity/dyhpo.py @@ -5,9 +5,11 @@ from typing import Any import numpy as np +import pandas as pd from ...metahyper import ConfigResult, instance_from_map from ...search_spaces.search_space import FloatParameter, IntegerParameter, SearchSpace +from ...utils.common import EvaluationData, SimpleCSVWriter from ..base_optimizer import BaseOptimizer from ..bayesian_optimization.acquisition_functions import AcquisitionMapping from ..bayesian_optimization.acquisition_functions.base_acquisition import BaseAcquisition @@ -22,6 +24,19 @@ from .utils import MFObservedData +class AcqWriter(SimpleCSVWriter): + def set_data(self, sample_configs: pd.Series, acq_vals: pd.Series): + config_vals = pd.DataFrame( + [config.hp_values() for config in sample_configs], index=sample_configs.index + ) + if isinstance(acq_vals, pd.Series): + acq_vals.name = "Acq Value" + # pylint: disable=attribute-defined-outside-init + self.df = config_vals.join(acq_vals) + self.df = self.df.sort_values(by="Acq Value") + # pylint: enable=attribute-defined-outside-init + + class MFEIBO(BaseOptimizer): """Base class for MF-BO algorithms that use DyHPO-like acquisition and budgeting.""" @@ -105,8 +120,11 @@ def __init__( self.total_fevals: int = 0 self.observed_configs = MFObservedData( - columns=["config", "perf", "learning_curves"], - index_names=["config_id", "budget_id"], + config_id="config_id", + budget_id="budget_id", + config_col="config", + perf_col="perf", + learning_curve_col="learning_curves", ) # Preparing model @@ -123,7 +141,7 @@ def __init__( self._prep_model_args(self.hp_kernels, self.graph_kernels, pipeline_space) # TODO: Better solution than branching based on the surrogate name is needed - if surrogate_model in ["deep_gp", "gp"]: + if surrogate_model in ["deep_gp", "gp", "dpl"]: model_policy = FreezeThawModel elif surrogate_model == "pfn": model_policy = PFNSurrogate @@ -163,6 +181,8 @@ def __init__( ) self.count = 0 + self.evaluation_data = EvaluationData() + def _prep_model_args(self, hp_kernels, graph_kernels, 
pipeline_space): if self.surrogate_model_name in ["gp", "gp_hierarchy"]: # setup for GP implemented in NePS @@ -263,7 +283,7 @@ def total_budget_spent(self) -> int | float: return total_budget_spent - def is_init_phase(self, budget_based: bool = True) -> bool: + def is_init_phase(self, budget_based: bool = False) -> bool: if budget_based: # Check if we are still in the initial design phase based on # either the budget spent so far or the number of configurations evaluated @@ -290,10 +310,12 @@ def load_results( pending_evaluations (dict[str, ConfigResult]): [description] """ self.observed_configs = MFObservedData( - columns=["config", "perf", "learning_curves"], - index_names=["config_id", "budget_id"], + config_id="config_id", + budget_id="budget_id", + config_col="config", + perf_col="perf", + learning_curve_col="learning_curves", ) - # previous optimization run exists and needs to be loaded self._load_previous_observations(previous_results) self.total_fevals = len(previous_results) + len(pending_evaluations) @@ -305,7 +327,6 @@ def load_results( self.observed_configs.df.sort_index( level=self.observed_configs.df.index.names, inplace=True ) - # TODO: can we do better than keeping a copy of the observed configs? # TODO: can we not hide this in load_results and have something that pops out # more, like a set_state or policy_args @@ -419,7 +440,7 @@ def get_config_and_ids( # pylint: disable=no-self-use ) config.fidelity.value = self.min_budget _config_id = self.observed_configs.next_config_id() - elif self.is_init_phase(budget_based=True) or self._model_update_failed: + elif self.is_init_phase() or self._model_update_failed: # promote a config randomly if initial design size is satisfied but the # initial design budget has not been exhausted self.logger.info("promoting...") @@ -427,7 +448,7 @@ def get_config_and_ids( # pylint: disable=no-self-use else: if self.count == 0: self.logger.info("\nPartial learning curves as initial design:\n") - self.logger.info(f"{self.observed_configs.get_learning_curves()}\n") + self.logger.info(f"{self.observed_configs.get_trajectories()}\n") self.count += 1 # main acquisition call here after initial design is turned off self.logger.info("acquiring...") @@ -435,29 +456,76 @@ def get_config_and_ids( # pylint: disable=no-self-use samples = self.acquisition_sampler.sample( set_new_sample_fidelity=self.pipeline_space.fidelity.lower ) # fidelity values here should be the observations or min. fidelity + # calculating acquisition function values for the candidate samples acq, _samples = self.acquisition.eval( # type: ignore[attr-defined] x=samples, asscalar=True ) + acq = pd.Series(acq, index=_samples.index) + # maximizing acquisition function - _idx = np.argsort(acq)[-1] + best_idx = acq.sort_values().index[-1] # extracting the config ID for the selected maximizer - _config_id = samples.index[_samples.index.values[_idx]] + _config_id = best_idx # samples.index[_samples.index.values[_idx]] # `_samples` should have new configs with fidelities set to as required # NOTE: len(samples) need not be equal to len(_samples) as `samples` contain # all (partials + new) configurations obtained from the sampler, but # in `_samples`, configs are removed that have reached maximum epochs allowed # NOTE: `samples` and `_samples` should share the same index values, hence, - # avoid using `.iloc` and work with `.loc` on pandas DataFrame/Series - - # Is this "config = _samples.loc[_config_id]"? 
+ # avoid using `.iloc` and work with `.loc` on these pandas DataFrame/Series + + acq_writer = AcqWriter("Acq_values") + # Writes extra information into Acq_values.csv + # if hasattr(self.acquisition, "mu"): + # # collect prediction learning_curves + # lcs = [] + # # and tabular ids + # tabular_ids = [] + # for idx in _samples.index: + # if self.acquisition_sampler.is_tabular: + # tabular_ids.append(samples[idx]["id"].value) + # if idx in self.observed_configs.df.index.levels[0]: + # # budget_level = self.get_budget_level(_samples[idx]) + # # extracting the available/observed learning curve + # lc = self.observed_configs.extract_learning_curve(idx, budget_id=None) + # else: + # # initialize a learning curve with a placeholder + # # This is later padded accordingly for the Conv1D layer + # lc = [] + # lcs.append(lc) + # + # data = {"Acq Value": acq.values, + # "preds": self.acquisition.mu, + # "incumbents": self.acquisition.mu_star, + # "std": self.acquisition.std, + # "pred_learning_curves": lcs} + # if self.acquisition_sampler.is_tabular: + # data["tabular_ids"] = tabular_ids + # + # acq = pd.DataFrame(data, index=_samples.index) + acq_writer.set_data(_samples, acq) + self.evaluation_data.data_dict["acq"] = acq_writer + + # assigning config hyperparameters config = samples.loc[_config_id] - config.fidelity.value = _samples.loc[_config_id].fidelity.value + # IMPORTANT: setting the fidelity appropriately + + config.fidelity.value = ( + config.fidelity.lower + if best_idx > max(self.observed_configs.seen_config_ids) + else ( + self.get_budget_value( + self.observed_configs.get_max_observed_fidelity_level_per_config().loc[ + best_idx + ] + ) + + self.step_size # ONE-STEP FIDELITY QUERY + ) + ) # generating correct IDs if _config_id in self.observed_configs.seen_config_ids: config_id = f"{_config_id}_{self.get_budget_level(config)}" previous_config_id = f"{_config_id}_{self.get_budget_level(config) - 1}" else: config_id = f"{self.observed_configs.next_config_id()}_{self.get_budget_level(config)}" - return config.hp_values(), config_id, previous_config_id diff --git a/neps/optimizers/multi_fidelity/hyperband.py b/neps/optimizers/multi_fidelity/hyperband.py index 86ff2f5f..1db92056 100644 --- a/neps/optimizers/multi_fidelity/hyperband.py +++ b/neps/optimizers/multi_fidelity/hyperband.py @@ -4,10 +4,9 @@ import typing from copy import deepcopy -from typing import Any +from typing import Any, Literal import numpy as np -from typing_extensions import Literal from ...metahyper import ConfigResult from ...search_spaces.search_space import SearchSpace @@ -99,7 +98,9 @@ def _update_sh_bracket_state(self) -> None: # for the current SH bracket in HB # TODO: can we avoid copying full observation history bracket = self.sh_brackets[self.current_sh_bracket] # type: ignore - bracket.observed_configs = self.observed_configs.copy() + # bracket.max_budget_configs = self.max_budget_configs.copy() + # TODO: Do we NEED to copy here instead? 
+ bracket.MFobserved_configs = self.MFobserved_configs # pylint: disable=no-self-use def clear_old_brackets(self): @@ -169,7 +170,7 @@ def clear_old_brackets(self): base_rung_sizes = [] # sorted(self.config_map.values(), reverse=True) for bracket in self.sh_brackets.values(): base_rung_sizes.append(sorted(bracket.config_map.values(), reverse=True)[0]) - while end <= len(self.observed_configs): + while end <= len(self.max_budget_configs): # subsetting only this SH bracket from the history sh_bracket = self.sh_brackets[self.current_sh_bracket] sh_bracket.clean_rung_information() @@ -177,14 +178,14 @@ def clear_old_brackets(self): # correct SH bracket object to make the right budget calculations # pylint: disable=protected-access bracket_budget_used = sh_bracket._calc_budget_used_in_bracket( - deepcopy(self.observed_configs.rung.values[start:end]) + deepcopy(self.max_budget_configs.rung.values[start:end]) ) # if budget used is less than the total SH budget then still an active bracket current_bracket_full_budget = sum(sh_bracket.full_rung_trace) if bracket_budget_used < current_bracket_full_budget: # updating rung information of the current bracket # pylint: disable=protected-access - sh_bracket._get_rungs_state(self.observed_configs.iloc[start:end]) + sh_bracket._get_rungs_state(self.max_budget_configs.iloc[start:end]) # extra call to use the updated rung member info to find promotions # SyncPromotion signals a wait if a rung is full but with # incomplete/pending evaluations, signals to starts a new SH bracket @@ -209,7 +210,7 @@ def clear_old_brackets(self): # updates rung info with the latest active, incomplete bracket sh_bracket = self.sh_brackets[self.current_sh_bracket] # pylint: disable=protected-access - sh_bracket._get_rungs_state(self.observed_configs.iloc[start:end]) + sh_bracket._get_rungs_state(self.max_budget_configs.iloc[start:end]) sh_bracket._handle_promotions() # self._handle_promotion() need not be called as it is called by load_results() @@ -302,7 +303,7 @@ def __init__( prior_confidence=prior_confidence, random_interleave_prob=random_interleave_prob, sample_default_first=sample_default_first, - sample_default_at_target=sample_default_at_target + sample_default_at_target=sample_default_at_target, ) self.sampling_args = { "inc": None, @@ -379,7 +380,7 @@ def _update_sh_bracket_state(self) -> None: config_map=bracket.config_map, ) bracket.rung_promotions = bracket.promotion_policy.retrieve_promotions() - bracket.observed_configs = self.observed_configs.copy() + bracket.max_budget_configs = self.max_budget_configs.copy() def _get_bracket_to_run(self): """Samples the ASHA bracket to run. 
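Back in the freeze-thaw/BO hunk further up, the fidelity assignment implements a one-step query rule: a configuration that has never been observed starts at the lowest fidelity, while a previously observed configuration is continued one `step_size` above the highest fidelity it has reached. A hedged, standalone sketch of that rule, with invented names and plain dict bookkeeping in place of the optimizer's DataFrame:

```python
from __future__ import annotations

def next_fidelity(config_id: int,
                  max_seen_fidelity: dict[int, float],
                  min_fidelity: float,
                  step_size: float) -> float:
    """New configs start cheap; known configs get a one-step continuation."""
    if config_id not in max_seen_fidelity:
        return min_fidelity                              # fresh sample -> lowest fidelity
    return max_seen_fidelity[config_id] + step_size      # one-step fidelity query

max_seen = {0: 3.0, 1: 5.0}   # assumed: highest fidelity observed so far per config
print(next_fidelity(2, max_seen, min_fidelity=1.0, step_size=1.0))   # 1.0 (unseen config)
print(next_fidelity(1, max_seen, min_fidelity=1.0, step_size=1.0))   # 6.0 (continuation)
```

In the actual change, "unseen" is detected via `best_idx > max(seen_config_ids)` and the continuation budget comes from `get_max_observed_fidelity_level_per_config`, but the branching is the same.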
diff --git a/neps/optimizers/multi_fidelity/mf_bo.py b/neps/optimizers/multi_fidelity/mf_bo.py index 8b14d70b..18703903 100755 --- a/neps/optimizers/multi_fidelity/mf_bo.py +++ b/neps/optimizers/multi_fidelity/mf_bo.py @@ -33,7 +33,7 @@ def _fit_models(self): if self.pipeline_space.has_prior: # PriorBand + BO total_resources = calc_total_resources_spent( - self.observed_configs, self.rung_map + self.max_budget_configs, self.rung_map ) decay_t = total_resources / self.max_budget else: @@ -42,7 +42,7 @@ def _fit_models(self): # extract pending configurations # doing this separately as `rung_histories` do not record pending configs - pending_df = self.observed_configs[self.observed_configs.perf.isna()] + pending_df = self.max_budget_configs[self.max_budget_configs.perf.isna()] if self.modelling_type == "rung": # collect only the finished configurations at the highest active `rung` # for training the surrogate and considering only those pending @@ -58,7 +58,7 @@ def _fit_models(self): ) self.logger.info(f"Building model at rung {rung}") # collecting finished evaluations at `rung` - train_df = self.observed_configs.loc[ + train_df = self.max_budget_configs.loc[ self.rung_histories[rung]["config"] ].copy() @@ -89,7 +89,7 @@ def _fit_models(self): pending_x = [] for rung in range(self.min_rung, self.max_rung + 1): _ids = self.rung_histories[rung]["config"] - _x = deepcopy(self.observed_configs.loc[_ids].config.values.tolist()) + _x = deepcopy(self.max_budget_configs.loc[_ids].config.values.tolist()) # update fidelity fidelity = [self.rung_map[rung]] * len(_x) _x = list(map(update_fidelity, _x, fidelity)) @@ -131,7 +131,7 @@ def is_init_phase(self) -> bool: # builds a model across all fidelities with the fidelity as a dimension # in this case, calculate the total number of function evaluations spent # and in vanilla BO fashion use that to compare with the initital design size - resources = calc_total_resources_spent(self.observed_configs, self.rung_map) + resources = calc_total_resources_spent(self.max_budget_configs, self.rung_map) resources /= self.max_budget if resources < self.init_size: return True @@ -198,6 +198,13 @@ def __init__( ) if self.surrogate_model_name in ["deep_gp", "pfn"]: self.surrogate_model_args.update({"pipeline_space": pipeline_space}) + elif self.surrogate_model_name == "dpl": + self.surrogate_model_args.update( + { + "pipeline_space": self.pipeline_space, + "observed_data": self.observed_configs, + } + ) # instantiate the surrogate model self.surrogate_model = instance_from_map( @@ -233,7 +240,7 @@ def _fantasize_pending(self, train_x, train_y, pending_x): def _fit(self, train_x, train_y, train_lcs): if self.surrogate_model_name in ["gp", "gp_hierarchy"]: self.surrogate_model.fit(train_x, train_y) - elif self.surrogate_model_name in ["deep_gp", "pfn"]: + elif self.surrogate_model_name in ["deep_gp", "pfn", "dpl"]: self.surrogate_model.fit(train_x, train_y, train_lcs) else: # check neps/optimizers/bayesian_optimization/models/__init__.py for options @@ -244,7 +251,7 @@ def _fit(self, train_x, train_y, train_lcs): def _predict(self, test_x, test_lcs): if self.surrogate_model_name in ["gp", "gp_hierarchy"]: return self.surrogate_model.predict(test_x) - elif self.surrogate_model_name in ["deep_gp", "pfn"]: + elif self.surrogate_model_name in ["deep_gp", "pfn", "dpl"]: return self.surrogate_model.predict(test_x, test_lcs) else: # check neps/optimizers/bayesian_optimization/models/__init__.py for options @@ -262,12 +269,33 @@ def set_state( self.surrogate_model_args = ( 
surrogate_model_args if surrogate_model_args is not None else {} ) + if self.surrogate_model_name == "dpl": + self.surrogate_model_args.update( + { + "pipeline_space": self.pipeline_space, + "observed_data": self.observed_configs, + } + ) + self.surrogate_model = instance_from_map( + SurrogateModelMapping, + self.surrogate_model_name, + name="surrogate model", + kwargs=self.surrogate_model_args, + ) + # only to handle tabular spaces if self.pipeline_space.has_tabular: if self.surrogate_model_name in ["deep_gp", "pfn"]: self.surrogate_model_args.update( {"pipeline_space": self.pipeline_space.raw_tabular_space} ) + elif self.surrogate_model_name == "dpl": + self.surrogate_model_args.update( + { + "pipeline_space": self.pipeline_space, + "observed_data": self.observed_configs, + } + ) # instantiate the surrogate model, again, with the new pipeline space self.surrogate_model = instance_from_map( SurrogateModelMapping, @@ -275,6 +303,19 @@ def set_state( name="surrogate model", kwargs=self.surrogate_model_args, ) + elif self.surrogate_model_name == "dpl": + self.surrogate_model_args.update( + { + "pipeline_space": self.pipeline_space, + "observed_data": self.observed_configs, + } + ) + self.surrogate_model = instance_from_map( + SurrogateModelMapping, + self.surrogate_model_name, + name="surrogate model", + kwargs=self.surrogate_model_args, + ) def update_model(self, train_x=None, train_y=None, pending_x=None, decay_t=None): if train_x is None: @@ -322,6 +363,8 @@ def preprocess_training_set(self): configs, idxs, performances = self.observed_configs.get_tokenized_data( self.observed_configs.df.copy().assign(config=_configs) ) + idxs = idxs.astype(float) + idxs[:, 1] = idxs[:, 1] / _configs[0].fidelity.upper # TODO: account for fantasization self.train_x = torch.Tensor(np.hstack([idxs, configs])).to(device) self.train_y = torch.Tensor(performances).to(device) diff --git a/neps/optimizers/multi_fidelity/successive_halving.py b/neps/optimizers/multi_fidelity/successive_halving.py index a3145dc2..eb540e69 100644 --- a/neps/optimizers/multi_fidelity/successive_halving.py +++ b/neps/optimizers/multi_fidelity/successive_halving.py @@ -5,10 +5,9 @@ import random import typing from copy import deepcopy +from typing import Literal import numpy as np -import pandas as pd -from typing_extensions import Literal from ...metahyper import ConfigResult from ...search_spaces.hyperparameters.categorical import ( @@ -22,6 +21,7 @@ from ..base_optimizer import BaseOptimizer from .promotion_policy import AsyncPromotionPolicy, SyncPromotionPolicy from .sampling_policy import FixedPriorPolicy, RandomUniformPolicy +from .utils import MFObservedData CUSTOM_FLOAT_CONFIDENCE_SCORES = FLOAT_CONFIDENCE_SCORES.copy() CUSTOM_FLOAT_CONFIDENCE_SCORES.update({"ultra": 0.05}) @@ -102,8 +102,7 @@ def __init__( # the parameter is exposed to allow HB to call SH with different stopping rates self.early_stopping_rate = early_stopping_rate self.sampling_policy = sampling_policy( - pipeline_space=self.pipeline_space, - logger=self.logger + pipeline_space=self.pipeline_space, logger=self.logger ) self.promotion_policy = promotion_policy(self.eta) @@ -132,9 +131,20 @@ def __init__( self.sampling_args: dict = {} self.fidelities = list(self.rung_map.values()) + + self.MFobserved_configs = MFObservedData( + config_id="config_id", + budget_id="budget_id", + config_col="config", + perf_col="perf", + auxiliary_cols=["rung"], + ) + # TODO: replace with MFobserved_configs # stores the observations made and the corresponding fidelity explored # 
crucial data structure used for determining promotion candidates - self.observed_configs = pd.DataFrame([], columns=("config", "rung", "perf")) + self.__max_observed_configs = None + self.history_length = 0 + # self.max_budget_configs = pd.DataFrame([], columns=("config", "rung", "perf")) # stores which configs occupy each rung at any time self.rung_members: dict = dict() # stores config IDs per rung self.rung_members_performance: dict = dict() # performances recorded per rung @@ -155,6 +165,22 @@ def __init__( self._enhance_priors() self.rung_histories = None + @property + def max_budget_configs(self): + """ + Make this property dynamically dependent on self.MFobserved_configs. So the state + of the algo only depends on self.MFobserved_configs. + """ + if self.__max_observed_configs is None or self.history_length != len( + self.MFobserved_configs.df + ): + self.__max_observed_configs = self.MFobserved_configs.copy_df( + df=self.MFobserved_configs.reduce_to_max_seen_budgets() + ) + self.history_length = len(self.MFobserved_configs.df) + + return self.__max_observed_configs + @classmethod def _get_rung_trace(cls, rung_map: dict, config_map: dict) -> list[int]: """Lists the rung IDs in sequence of the flattened SH tree.""" @@ -164,9 +190,13 @@ def _get_rung_trace(cls, rung_map: dict, config_map: dict) -> list[int]: return rung_trace def get_incumbent_score(self): + # budget_perf = self.MFobserved_configs.get_best_performance_for_each_budget() + # y_star = budget_perf[budget_perf.index.max] + + # TODO: replace this with existing method y_star = np.inf # minimizing optimizer - if len(self.observed_configs): - y_star = self.observed_configs.perf.values.min() + if len(self.max_budget_configs): + y_star = self.max_budget_configs.perf.values.min() return y_star def _get_rung_map(self, s: int = 0) -> dict: @@ -219,52 +249,88 @@ def _get_config_id_split(cls, config_id: str) -> tuple[str, str]: def _load_previous_observations( self, previous_results: dict[str, ConfigResult] ) -> None: - for config_id, config_val in previous_results.items(): + def index_data_split(config_id: str, config_val): _config, _rung = self._get_config_id_split(config_id) perf = self.get_loss(config_val.result) - if int(_config) in self.observed_configs.index: - # config already recorded in dataframe - rung_recorded = self.observed_configs.at[int(_config), "rung"] - if rung_recorded < int(_rung): - # config recorded for a lower rung but higher rung eval available - self.observed_configs.at[int(_config), "config"] = config_val.config - self.observed_configs.at[int(_config), "rung"] = int(_rung) - self.observed_configs.at[int(_config), "perf"] = perf - else: - _df = pd.DataFrame( - [[config_val.config, int(_rung), perf]], - columns=self.observed_configs.columns, - index=pd.Series(int(_config)), # key for config_id - ) - self.observed_configs = pd.concat( - (self.observed_configs, _df) - ).sort_index() - # for efficiency, redefining the function to have the - # `rung_histories` assignment inside the for loop - # rung histories are collected only for `previous` and not `pending` configs - self.rung_histories[int(_rung)]["config"].append(int(_config)) - self.rung_histories[int(_rung)]["perf"].append(perf) + index = int(_config), int(_rung) + _data = [config_val.config, perf, int(_rung)] + return index, _data + + if len(previous_results) > 0: + index_row = [ + tuple(index_data_split(config_id, config_val)) + for config_id, config_val in previous_results.items() + ] + indices, rows = zip(*index_row) + 
self.MFobserved_configs.add_data(data=list(rows), index=list(indices)) + # TODO: replace this with new optimized method + # for config_id, config_val in previous_results.items(): + # _config, _rung = self._get_config_id_split(config_id) + # perf = self.get_loss(config_val.result) + # if int(_config) in self.observed_configs.index: + # # config already recorded in dataframe + # rung_recorded = self.observed_configs.at[int(_config), "rung"] + # if rung_recorded < int(_rung): + # # config recorded for a lower rung but higher rung eval available + # self.observed_configs.at[int(_config), "config"] = config_val.config + # self.observed_configs.at[int(_config), "rung"] = int(_rung) + # self.observed_configs.at[int(_config), "perf"] = perf + # else: + # _df = pd.DataFrame( + # [[config_val.config, int(_rung), perf]], + # columns=self.observed_configs.columns, + # index=pd.Series(int(_config)), # key for config_id + # ) + # self.observed_configs = pd.concat( + # (self.observed_configs, _df) + # ).sort_index() + # # for efficiency, redefining the function to have the + # # `rung_histories` assignment inside the for loop + # # rung histories are collected only for `previous` and not `pending` configs + # self.rung_histories[int(_rung)]["config"].append(int(_config)) + # self.rung_histories[int(_rung)]["perf"].append(perf) return def _handle_pending_evaluations( self, pending_evaluations: dict[str, ConfigResult] ) -> None: + def index_data_split(config_id: str, config_val): + _config, _rung = self._get_config_id_split(config_id) + # perf = self.get_loss(config_val.result) + index = int(_config), int(_rung) + _data = [ + # use `config_val` instead of `config_val.config` + # unlike `previous_results` case + config_val, + np.nan, + int(_rung), + ] + return index, _data + + if len(pending_evaluations) > 0: + index_row = [ + tuple(index_data_split(config_id, config_val)) + for config_id, config_val in pending_evaluations.items() + ] + indices, rows = zip(*index_row) + self.MFobserved_configs.add_data(data=list(rows), index=list(indices)) + # TODO: replace this # iterates over all pending evaluations and updates the list of observed # configs with the rung and performance as None - for config_id, config in pending_evaluations.items(): - _config, _rung = self._get_config_id_split(config_id) - if int(_config) not in self.observed_configs.index: - _df = pd.DataFrame( - [[config, int(_rung), np.nan]], - columns=self.observed_configs.columns, - index=pd.Series(int(_config)), # key for config_id - ) - self.observed_configs = pd.concat( - (self.observed_configs, _df) - ).sort_index() - else: - self.observed_configs.at[int(_config), "rung"] = int(_rung) - self.observed_configs.at[int(_config), "perf"] = np.nan + # for config_id, config in pending_evaluations.items(): + # _config, _rung = self._get_config_id_split(config_id) + # if int(_config) not in self.observed_configs.index: + # _df = pd.DataFrame( + # [[config, int(_rung), np.nan]], + # columns=self.observed_configs.columns, + # index=pd.Series(int(_config)), # key for config_id + # ) + # self.observed_configs = pd.concat( + # (self.observed_configs, _df) + # ).sort_index() + # else: + # self.observed_configs.at[int(_config), "rung"] = int(_rung) + # self.observed_configs.at[int(_config), "perf"] = np.nan return def clean_rung_information(self): @@ -276,7 +342,7 @@ def _get_rungs_state(self, observed_configs=None): """Collects info on configs at a rung and their performance there.""" # to account for incomplete evaluations from being promoted --> working on a 
copy observed_configs = ( - self.observed_configs.copy().dropna(inplace=False) + self.max_budget_configs.copy().dropna(inplace=False) if observed_configs is None else observed_configs ) @@ -290,6 +356,7 @@ def _get_rungs_state(self, observed_configs=None): # iterates over the list of explored configs and buckets them to respective # rungs depending on the highest fidelity it was evaluated at self.clean_rung_information() + # TODO: create a new method for this for _rung in observed_configs.rung.unique(): idxs = observed_configs.rung == _rung self.rung_members[_rung] = observed_configs.index[idxs].values @@ -331,7 +398,15 @@ def load_results( for rung in range(self.min_rung, self.max_rung + 1) } - self.observed_configs = pd.DataFrame([], columns=("config", "rung", "perf")) + self.MFobserved_configs = MFObservedData( + config_id="config_id", + budget_id="budget_id", + config_col="config", + perf_col="perf", + auxiliary_cols=["rung"], + ) + + # self.observed_configs = pd.DataFrame([], columns=("config", "rung", "perf")) # previous optimization run exists and needs to be loaded self._load_previous_observations(previous_results) @@ -340,6 +415,12 @@ def load_results( # account for pending evaluations self._handle_pending_evaluations(pending_evaluations) + # TODO: change this after testing + # Copy data into old format + # self.max_budget_configs = self.MFobserved_configs.copy_df( + # df=self.MFobserved_configs.reduce_to_max_seen_budgets() + # ) + # process optimization state and bucket observations per rung self._get_rungs_state() @@ -374,7 +455,9 @@ def sample_new_config( return config def _generate_new_config_id(self): - return self.observed_configs.index.max() + 1 if len(self.observed_configs) else 0 + return self.MFobserved_configs.next_config_id() + # TODO: replace this with existing + # return self.observed_configs.index.max() + 1 if len(self.observed_configs) else 0 def get_default_configuration(self): pass @@ -403,7 +486,8 @@ def get_config_and_ids( # pylint: disable=no-self-use rung_to_promote = self.is_promotable() if rung_to_promote is not None: # promotes the first recorded promotable config in the argsort-ed rung - row = self.observed_configs.iloc[self.rung_promotions[rung_to_promote][0]] + # TODO: What to do with this? 
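The `max_budget_configs` property above (and `reduce_to_max_seen_budgets` in multi_fidelity/utils.py further down) amounts to: keep only the highest-budget row per configuration and recompute that view only when the history has grown. A small self-contained sketch of the pattern, with simplified names and made-up observations:

```python
import pandas as pd

class ObservedHistory:
    """Toy stand-in for MFObservedData (illustrative, not the NePS class)."""

    def __init__(self, df: pd.DataFrame):
        self.df = df                 # MultiIndex rows: (config_id, budget_id)
        self._cache = None
        self._cached_len = -1

    @property
    def at_max_seen_budget(self) -> pd.DataFrame:
        # recompute the reduced view only if new observations were appended
        if self._cache is None or self._cached_len != len(self.df):
            self._cache = self.df.sort_index().groupby(level=0).last()
            self._cached_len = len(self.df)
        return self._cache

idx = pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 0)], names=["config_id", "budget_id"])
hist = ObservedHistory(
    pd.DataFrame({"config": ["a", "a", "b"], "perf": [0.9, 0.7, 0.8]}, index=idx)
)
print(hist.at_max_seen_budget)   # config 0 reported at its highest budget (perf 0.7), config 1 at budget 0
```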
+ row = self.max_budget_configs.iloc[self.rung_promotions[rung_to_promote][0]] config = deepcopy(row["config"]) rung = rung_to_promote + 1 # assigning the fidelity to evaluate the config at @@ -417,7 +501,7 @@ def get_config_and_ids( # pylint: disable=no-self-use if ( self.use_priors and self.sample_default_first - and len(self.observed_configs) == 0 + and len(self.max_budget_configs) == 0 ): if self.sample_default_at_target: # sets the default config to be evaluated at the target fidelity @@ -501,15 +585,15 @@ def clear_old_brackets(self): start += 1 end += 1 # iterates over the different SH brackets which span start-end by index - while end <= len(self.observed_configs): + while end <= len(self.max_budget_configs): # for the SH bracket in start-end, calculate total SH budget used bracket_budget_used = self._calc_budget_used_in_bracket( - deepcopy(self.observed_configs.rung.values[start:end]) + deepcopy(self.max_budget_configs.rung.values[start:end]) ) # if budget used is less than a SH bracket budget then still an active bracket if bracket_budget_used < sum(self.full_rung_trace): # subsetting only this SH bracket from the history - self._get_rungs_state(self.observed_configs.iloc[start:end]) + self._get_rungs_state(self.max_budget_configs.iloc[start:end]) # extra call to use the updated rung member info to find promotions # SyncPromotion signals a wait if a rung is full but with # incomplete/pending evaluations, and signals to starts a new SH bracket @@ -527,7 +611,7 @@ def clear_old_brackets(self): end = start + self.config_map[self.min_rung] # updates rung info with the latest active, incomplete bracket - self._get_rungs_state(self.observed_configs.iloc[start:end]) + self._get_rungs_state(self.max_budget_configs.iloc[start:end]) # _handle_promotion() need not be called as it is called by load_results() return diff --git a/neps/optimizers/multi_fidelity/utils.py b/neps/optimizers/multi_fidelity/utils.py index f1d359a6..cbc456db 100644 --- a/neps/optimizers/multi_fidelity/utils.py +++ b/neps/optimizers/multi_fidelity/utils.py @@ -1,6 +1,7 @@ # type: ignore from __future__ import annotations +from copy import deepcopy from typing import Any, Sequence import numpy as np @@ -10,6 +11,9 @@ from ...optimizers.utils import map_real_hyperparameters_from_tabular_ids from ...search_spaces.search_space import SearchSpace +# from neps.optimizers.utils import map_real_hyperparameters_from_tabular_ids +# from neps.search_spaces.search_space import SearchSpace + def continuous_to_tabular( config: SearchSpace, categorical_space: SearchSpace @@ -55,34 +59,36 @@ class MFObservedData: default_config_col = "config" default_perf_col = "perf" default_lc_col = "learning_curves" + # TODO: deepcopy all the mutable outputs from the dataframe def __init__( self, - columns: list[str] | None = None, - index_names: list[str] | None = None, + config_id: str | None = None, + budget_id: str | None = None, + config_col: str | None = None, + perf_col: str | None = None, + learning_curve_col: str | None = None, + auxiliary_cols: list[str] | None = None, ): - if columns is None: - columns = [self.default_config_col, self.default_perf_col] - if index_names is None: - index_names = [self.default_config_idx, self.default_budget_idx] + self.config_col = self.default_config_col if config_col is None else config_col + self.perf_col = self.default_perf_col if perf_col is None else perf_col - self.config_col = columns[0] - self.perf_col = columns[1] + self.config_idx = self.default_config_idx if config_id is None else config_id + 
self.budget_idx = self.default_budget_idx if budget_id is None else budget_id - if len(columns) > 2: - self.lc_col_name = columns[2] - else: - self.lc_col_name = self.default_lc_col + self.lc_col_name = learning_curve_col + + auxiliary_cols = [] if auxiliary_cols is None else auxiliary_cols - if len(index_names) == 1: - index_names += ["budget_id"] + self.index_names = [self.config_idx, self.budget_idx] + col_names = [self.config_col, self.perf_col, self.lc_col_name] + auxiliary_cols + self.columns = [col_name for col_name in col_names if col_name is not None] - self.config_idx = index_names[0] - self.budget_idx = index_names[1] + index = pd.MultiIndex.from_tuples([], names=self.index_names) - index = pd.MultiIndex.from_tuples([], names=index_names) + self.df = pd.DataFrame([], columns=self.columns, index=index) - self.df = pd.DataFrame([], columns=columns, index=index) + self.mutable_columns = [self.config_col, self.lc_col_name] @property def pending_condition(self): @@ -111,6 +117,13 @@ def next_config_id(self) -> int: else: return 0 + @staticmethod + def __validate_index(index_list): + """Extends single indices to multi-index case""" + if all([isinstance(idx, int) for idx in index_list]): + index_list = list(zip(index_list, [0] * len(index_list))) + return index_list + def add_data( self, data: list[Any] | list[list[Any]], @@ -120,7 +133,6 @@ def add_data( """ Add data only if none of the indices are already existing in the DataFrame """ - # TODO: If index is only config_id extend it if not isinstance(index, list): index_list = [index] data_list = [data] @@ -128,9 +140,12 @@ def add_data( index_list = index data_list = data + index_list = self.__validate_index(index_list) + if not self.df.index.isin(index_list).any(): - _df = pd.DataFrame(data_list, columns=self.df.columns, index=index_list) - self.df = pd.concat((self.df, _df)) + index = pd.MultiIndex.from_tuples(index_list, names=self.index_names) + _df = pd.DataFrame(data_list, columns=self.df.columns, index=index) + self.df = _df.copy() if self.df.empty else pd.concat((self.df, _df)) elif error: raise ValueError( f"Data with at least one of the given indices already " @@ -151,6 +166,9 @@ def update_data( index_list = [index] else: index_list = index + + index_list = self.__validate_index(index_list) + if self.df.index.isin(index_list).sum() == len(index_list): column_names, data = zip(*data_dict.items()) data = list(zip(*data)) @@ -163,7 +181,7 @@ def update_data( f"Given indices: {index_list}" ) - def get_learning_curves(self): + def get_trajectories(self): return self.df.pivot_table( index=self.df.index.names[0], columns=self.df.index.names[1], @@ -178,13 +196,13 @@ def get_incumbents_for_budgets(self, maximize: bool = False): Returns a series object with the best partial configuration for each budget id Note: this will always map the best lowest ID if two configurations - has the same performance at the same fidelity + have the same performance at the same fidelity """ - learning_curves = self.get_learning_curves() + trajectories = self.get_trajectories() if maximize: - config_ids = learning_curves.idxmax(axis=0) + config_ids = trajectories.idxmax(axis=0) else: - config_ids = learning_curves.idxmin(axis=0) + config_ids = trajectories.idxmin(axis=0) indices = list(zip(config_ids.values.tolist(), config_ids.index.to_list())) partial_configs = self.df.loc[indices, self.config_col].to_list() @@ -197,14 +215,23 @@ def get_best_performance_for_each_budget(self, maximize: bool = False): Note: this will always map the best lowest ID if 
two configurations have the same performance at the same fidelity
        """
-        learning_curves = self.get_learning_curves()
+        trajectories = self.get_trajectories()
        if maximize:
-            performance = learning_curves.max(axis=0)
+            performance = trajectories.max(axis=0)
        else:
-            performance = learning_curves.min(axis=0)
+            performance = trajectories.min(axis=0)
        return performance
+    def get_budget_level_for_best_performance(self, maximize: bool = False) -> int:
+        """Returns the lowest budget level at which the best observed performance was recorded."""
+        perf_per_z = self.get_best_performance_for_each_budget(maximize=maximize)
+        y_star = self.get_best_seen_performance(maximize=maximize)
+        # picks the lowest budget level that attains the best observed score
+        op = max if maximize else min
+        z_inc = int(op([_z for _z, _y in perf_per_z.items() if _y == y_star]))
+        return z_inc
+
    def get_best_learning_curve_id(self, maximize: bool = False):
        """
        Returns a single configuration id of the best observed performance
@@ -212,44 +239,89 @@ def get_best_learning_curve_id(self, maximize: bool = False):
        Note: this will always return the single best lowest ID
        if two configurations have the same performance
        """
-        learning_curves = self.get_learning_curves()
+        trajectories = self.get_trajectories()
        if maximize:
-            return learning_curves.max(axis=1).idxmax()
+            return trajectories.max(axis=1).idxmax()
        else:
-            return learning_curves.min(axis=1).idxmin()
+            return trajectories.min(axis=1).idxmin()

    def get_best_seen_performance(self, maximize: bool = False):
-        learning_curves = self.get_learning_curves()
+        trajectories = self.get_trajectories()
        if maximize:
-            return learning_curves.max(axis=1).max()
+            return trajectories.max(axis=1).max()
        else:
-            return learning_curves.min(axis=1).min()
+            return trajectories.min(axis=1).min()

    def add_budget_column(self):
-        combined_df = self.df.reset_index(level=1)
-        combined_df.set_index(
-            keys=[self.budget_idx], drop=False, append=True, inplace=True
-        )
-        return combined_df
+        pass
+        # budget_column = self.df.index.get_level_values(1)
+        # self.df[self.budget_idx] = budget_column
+        # combined_df = self.df.reset_index(level=1)
+        # combined_df.set_index(
+        #     keys=[self.budget_idx], drop=False, append=True, inplace=True
+        # )
+        # return combined_df
+
+    def copy_df(self, df: pd.DataFrame | None = None):
+        """
+        Use this function to copy the df if you are going to
+        perform some operations on its elements.
+
+        DataFrames are not meant for mutable data types;
+        nevertheless, we do put mutable SearchSpace objects into the config_col of the DF.
+        In order not to change the values of the objects stored in the DF, we deepcopy all
+        mutable columns here.
+
+        self.mutable_columns must keep track of
+        the mutable columns at all times.
+ """ + if df is None: + df = self.df + new_df = pd.DataFrame() + new_df.index = df.index.copy(deep=True) + + for column in df.columns: + if column in self.mutable_columns: + new_column = [deepcopy(value) for value in df[column].values] + new_df[column] = new_column + else: + new_df[column] = df[column].copy(deep=True) + + return new_df def reduce_to_max_seen_budgets(self): self.df.sort_index(inplace=True) - combined_df = self.add_budget_column() + budget_column = self.df.index.get_level_values(1) + combined_df = self.df.copy(deep=True) + combined_df[self.budget_idx] = budget_column + # combined_df = self.copy_df(df=combined_df) return combined_df.groupby(level=0).last() def get_partial_configs_at_max_seen(self): return self.reduce_to_max_seen_budgets()[self.config_col] - def extract_learning_curve(self, config_id: int, budget_id: int) -> list[float]: + def extract_learning_curve( + self, config_id: int, budget_id: int | None = None + ) -> list[float]: + if budget_id is None: + # budget_id only None when predicting + # extract full observed learning curve for prediction pipeline + budget_id = ( + max(self.df.loc[config_id].index.get_level_values("budget_id").values) + 1 + ) + + # For the first epoch we have no learning curve available + if budget_id == 0: + return [] # reduce budget_id to discount the current validation loss # both during training and prediction phase budget_id = max(0, budget_id - 1) if self.lc_col_name in self.df.columns: lc = self.df.loc[(config_id, budget_id), self.lc_col_name] else: - lcs = self.get_learning_curves() - lc = lcs.loc[config_id, :budget_id].values.flatten().tolist() - return lc + trajectories = self.get_trajectories() + lc = trajectories.loc[config_id, :budget_id].values.flatten().tolist() + return deepcopy(lc) def get_training_data_4DyHPO( self, df: pd.DataFrame, pipeline_space: SearchSpace | None = None @@ -270,6 +342,29 @@ def get_training_data_4DyHPO( learning_curves.append(self.extract_learning_curve(config_id, budget_id)) return configs, learning_curves, performance + def get_best_performance_per_config(self, maximize: bool = False) -> pd.Series: + """Returns the best score recorded per config across fidelities seen.""" + op = np.max if maximize else np.min + perf = ( + self.df.sort_values( + "budget_id", ascending=False + ) # sorts with largest budget first + .groupby("config_id") # retains only config_id + .first() # retrieves the largest budget seen for each config_id + .learning_curves.apply( # extracts all values seen till largest budget for a config + op + ) # finds the minimum over per-config learning curve + ) + return perf + + def get_max_observed_fidelity_level_per_config(self) -> pd.Series: + """Returns the highest fidelity level recorded per config seen.""" + max_z_observed = { + _id: self.df.loc[_id, :].index.sort_values()[-1] + for _id in self.df.index.get_level_values("config_id").sort_values() + } + return pd.Series(max_z_observed) + def get_tokenized_data(self, df: pd.DataFrame): idxs = df.index.values idxs = np.array([list(idx) for idx in idxs]) @@ -300,48 +395,93 @@ def token_ids(self) -> np.ndarray: if __name__ == "__main__": # TODO: Either delete these or convert them to tests (karibbov) + + def multi_index_parallel(): + data = MFObservedData( + config_id="config_id", + budget_id="budget_id", + config_col="config", + perf_col="perf", + ) + + # When adding multiple indices data should be list of rows(lists) and the index should be list of tuples + data.add_data( + [["conf1", 0.5], ["conf2", 0.7], ["conf1", 0.6], ["conf2", 0.4]], 
+ index=[(0, 0), (1, 1), (0, 3), (1, 0)], + ) + # print(data.df) + + data.add_data( + [["conf1", 0.5], ["conf2", 0.10], ["conf1", 0.11]], + index=[(0, 2), (1, 2), (0, 1)], + ) + + print(data.df) + # print(data.get_trajectories()) + # print( + # "Mapping of budget IDs into best performing configurations at each fidelity:\n", + # data.get_incumbents_for_budgets(), + # ) + # print( + # "Best Performance at each budget level:\n", + # data.get_best_performance_for_each_budget(), + # ) + # print( + # "Configuration ID of the best observed performance so far: ", + # data.get_best_learning_curve_id(), + # ) + # print(data.extract_learning_curve(0, 2)) + # # data.df.sort_index(inplace=True) + # print(data.get_partial_configs_at_max_seen()) + # + # # When updating multiple indices at a time both the values in the data dictionary and the indices should be lists + data.update_data({"perf": [1.8, 1.5]}, index=[(1, 1), (0, 0)]) + print(data.df) + + def multi_index_single(): + data = MFObservedData( + config_id="config_id", + budget_id="budget_id", + config_col="config", + perf_col="perf", + ) + + # when adding a single row second level list is not necessary + data.add_data(["conf1", 0.5], index=(0, 0)) + data.add_data(["conf1", 0.8], index=(1, 0)) + print(data.df) + + data.update_data({"perf": [1.8], "budget_col": [5]}, index=(0, 0)) + print(data.df) + + def single_index_parallel(): + data = MFObservedData( + config_id="config_id", + budget_id="budget_id", + config_col="config", + perf_col="perf", + ) + + # When adding multiple indices data should be list of rows(lists) and the index should be list of tuples + data.add_data( + [["conf1", 0.5], ["conf2", 0.7]], + index=[(0), (1)], + ) + print(data.df) + + data.add_data( + [["conf1", 0.5], ["conf2", 0.10]], + index=[(2), (3)], + ) + + print(data.df) + + data.update_data({"perf": [1.8, 1.5]}, index=[(1), (0)]) + print(data.df) + """ Here are a few examples of how to manage data with this class: """ - data = MFObservedData(["config", "perf"], index_names=["config_id", "budget_id"]) - - # When adding multiple indices data should be list of rows(lists) and the index should be list of tuples - data.add_data( - [["conf1", 0.5], ["conf2", 0.7], ["conf1", 0.6], ["conf2", 0.4]], - index=[(0, 0), (1, 1), (0, 3), (1, 0)], - ) - data.add_data( - [["conf1", 0.5], ["conf2", 0.10], ["conf1", 0.11]], - index=[(0, 2), (1, 2), (0, 1)], - ) - - print(data.df) - print(data.get_learning_curves()) - print( - "Mapping of budget IDs into best performing configurations at each fidelity:\n", - data.get_incumbents_for_budgets(), - ) - print( - "Best Performance at each budget level:\n", - data.get_best_performance_for_each_budget(), - ) - print( - "Configuration ID of the best observed performance so far: ", - data.get_best_learning_curve_id(), - ) - print(data.extract_learning_curve(0, 2)) - # data.df.sort_index(inplace=True) - print(data.get_partial_configs_at_max_seen()) - - # When updating multiple indices at a time both the values in the data dictionary and the indices should be lists - data.update_data({"perf": [1.8, 1.5]}, index=[(1, 1), (0, 0)]) - print(data.df) - - data = MFObservedData(["config", "perf"], index_names=["config_id", "budget_id"]) - - # when adding a single row second level list is not necessary - data.add_data(["conf1", 0.5], index=(0, 0)) - print(data.df) - - data.update_data({"perf": [1.8], "budget_col": [5]}, index=(0, 0)) - print(data.df) + single_index_parallel() + multi_index_single() + single_index_parallel() diff --git 
a/neps/optimizers/multi_fidelity_prior/async_priorband.py b/neps/optimizers/multi_fidelity_prior/async_priorband.py index c932d45d..df58b5a3 100644 --- a/neps/optimizers/multi_fidelity_prior/async_priorband.py +++ b/neps/optimizers/multi_fidelity_prior/async_priorband.py @@ -1,9 +1,9 @@ from __future__ import annotations import typing +from typing import Literal import numpy as np -from typing_extensions import Literal from ...metahyper import ConfigResult from ...search_spaces.search_space import SearchSpace @@ -238,7 +238,7 @@ def _update_sh_bracket_state(self) -> None: config_map=bracket.config_map, ) bracket.rung_promotions = bracket.promotion_policy.retrieve_promotions() - bracket.observed_configs = self.observed_configs.copy() + bracket.max_budget_configs = self.max_budget_configs.copy() bracket.rung_histories = self.rung_histories def load_results( diff --git a/neps/optimizers/multi_fidelity_prior/priorband.py b/neps/optimizers/multi_fidelity_prior/priorband.py index eca50ebe..ba66d98f 100644 --- a/neps/optimizers/multi_fidelity_prior/priorband.py +++ b/neps/optimizers/multi_fidelity_prior/priorband.py @@ -3,9 +3,9 @@ from __future__ import annotations import typing +from typing import Literal import numpy as np -from typing_extensions import Literal from ...search_spaces.search_space import SearchSpace from ..bayesian_optimization.acquisition_functions.base_acquisition import BaseAcquisition @@ -34,7 +34,7 @@ def find_all_distances_from_incumbent(self, incumbent): """Finds the distance to the nearest neighbour.""" dist = lambda x: compute_config_dist(incumbent, x) # computing distance of incumbent from all seen points in history - distances = [dist(config) for config in self.observed_configs.config] + distances = [dist(config) for config in self.max_budget_configs.config] # ensuring the distances exclude 0 or the distance from itself distances = [d for d in distances if d > 0] return distances @@ -47,14 +47,14 @@ def find_1nn_distance_from_incumbent(self, incumbent): def find_incumbent(self, rung: int = None) -> SearchSpace: """Find the best performing configuration seen so far.""" - rungs = self.observed_configs.rung.values - idxs = self.observed_configs.index.values + rungs = self.max_budget_configs.rung.values + idxs = self.max_budget_configs.index.values while rung is not None: # enters this scope is `rung` argument passed and not left empty or None if rung not in rungs: self.logger.warn(f"{rung} not in {np.unique(idxs)}") # filtering by rung based on argument passed - idxs = self.observed_configs.rung.values == rung + idxs = self.max_budget_configs.rung.values == rung # checking width of current rung if len(idxs) < self.eta: self.logger.warn( @@ -63,9 +63,9 @@ def find_incumbent(self, rung: int = None) -> SearchSpace: # extracting the incumbent configuration if len(idxs): # finding the config with the lowest recorded performance - _perfs = self.observed_configs.loc[idxs].perf.values + _perfs = self.max_budget_configs.loc[idxs].perf.values inc_idx = np.nanargmin([np.nan if t is None else t for t in _perfs]) - inc = self.observed_configs.loc[idxs].iloc[inc_idx].config + inc = self.max_budget_configs.loc[idxs].iloc[inc_idx].config else: # THIS block should not ever execute, but for runtime anomalies, if no # incumbent can be extracted, the prior is treated as the incumbent @@ -126,7 +126,9 @@ def is_activate_inc(self) -> bool: resources += bracket.config_map[rung] * continuation_resources # find resources spent so far for all finished evaluations - resources_used = 
calc_total_resources_spent(self.observed_configs, self.rung_map) + resources_used = calc_total_resources_spent( + self.max_budget_configs, self.rung_map + ) if resources_used >= resources and len( self.rung_histories[self.max_rung]["config"] @@ -190,7 +192,7 @@ def prior_to_incumbent_ratio(self) -> float | float: if self.inc_style == "constant": return self._prior_to_incumbent_ratio_constant() elif self.inc_style == "decay": - resources = calc_total_resources_spent(self.observed_configs, self.rung_map) + resources = calc_total_resources_spent(self.max_budget_configs, self.rung_map) return self._prior_to_incumbent_ratio_decay( resources, self.eta, self.min_budget, self.max_budget ) @@ -244,7 +246,7 @@ def _prior_to_incumbent_ratio_dynamic(self, rung: int) -> float | float: [ # `compute_scores` returns a tuple of scores resp. by prior and inc compute_scores( - self.observed_configs.loc[config_id].config, prior, inc + self.max_budget_configs.loc[config_id].config, prior, inc ) for config_id in top_configs ] diff --git a/neps/optimizers/utils.py b/neps/optimizers/utils.py index c203f4db..32fa87fa 100644 --- a/neps/optimizers/utils.py +++ b/neps/optimizers/utils.py @@ -3,45 +3,31 @@ from ..search_spaces.search_space import SearchSpace -# def map_real_hyperparameters_from_tabular_ids( -# ids: pd.Series, pipeline_space: SearchSpace -# ) -> pd.Series: -# return x - - def map_real_hyperparameters_from_tabular_ids( x: pd.Series, pipeline_space: SearchSpace ) -> pd.Series: - """ Maps the tabular IDs to the actual HPs from the pipeline space. - + """Maps the tabular IDs to the actual HPs from the pipeline space. + Args: x (pd.Series): A pandas series with the tabular IDs. TODO: Mention expected format of the series. pipeline_space (SearchSpace): The pipeline space. - Returns: + Returns: pd.Series: A pandas series with the actual HPs. TODO: Mention expected format of the series. """ if len(x) == 0: return x - # extract fid name - _x = x.iloc[0].hp_values() - _x.pop("id") - fid_name = list(_x.keys())[0] - for i in x.index.values: - # extracting actual HPs from the tabular space - _config = pipeline_space.custom_grid_table.loc[x.loc[i]["id"].value].to_dict() - # updating fidelities as per the candidate set passed - _config.update({fid_name: x.loc[i][fid_name].value}) - # placeholder config from the raw tabular space - config = pipeline_space.raw_tabular_space.sample( - patience=100, - user_priors=True, - ignore_fidelity=True # True allows fidelity to appear in the sample - ) - # copying values from table to placeholder config of type SearchSpace - config.load_from(_config) - # replacing the ID in the candidate set with the actual HPs of the config - x.loc[i] = config - return x + # copying hyperparameter configs based on IDs + _x = pd.Series( + [ + pipeline_space.custom_grid_table[x.loc[idx]["id"].value] + for idx in x.index.values + ], + index=x.index, + ) + # setting the passed fidelities for the corresponding IDs + for idx in _x.index.values: + _x.loc[idx].fidelity.value = x.loc[idx].fidelity.value + return _x diff --git a/neps/plot/tensorboard_eval.py b/neps/plot/tensorboard_eval.py index 6463563e..9e4b7393 100644 --- a/neps/plot/tensorboard_eval.py +++ b/neps/plot/tensorboard_eval.py @@ -484,7 +484,7 @@ def log( curve on tensorboard (default: True) writer_config_hparam (bool, optional): Write hyperparameters logging of the configs (default: True). 
- write_summary_incumbent (bool, optional): Set to `True` for a live + write_summary_incumbent (bool, optional): Set to `True` for a live incumbent trajectory. extra_data (dict, optional): Additional experiment data for logging. """ diff --git a/neps/search_spaces/architecture/graph.py b/neps/search_spaces/architecture/graph.py index 42b2c388..d6a58124 100644 --- a/neps/search_spaces/architecture/graph.py +++ b/neps/search_spaces/architecture/graph.py @@ -6,13 +6,13 @@ import sys import types from collections import Counter +from pathlib import Path from typing import Callable from typing import Counter as CounterType import networkx as nx import torch from networkx.algorithms.dag import lexicographical_topological_sort -from pathlib import Path from torch import nn from ...utils.common import AttrDict diff --git a/neps/search_spaces/hyperparameters/integer.py b/neps/search_spaces/hyperparameters/integer.py index 32789251..122a3514 100644 --- a/neps/search_spaces/hyperparameters/integer.py +++ b/neps/search_spaces/hyperparameters/integer.py @@ -2,7 +2,7 @@ from copy import deepcopy -from typing_extensions import Literal +from typing import Literal from .float import FloatParameter @@ -33,10 +33,10 @@ def __init__( def __repr__(self): return f"" - + def load_from(self, value): super().load_from(int(value)) - + def _set_float_hp_val(self): # IMPORTANT function to call wherever `self.float_hp` is used in this class self.float_hp.value = None if self.value is None else float(self.value) diff --git a/neps/search_spaces/search_space.py b/neps/search_spaces/search_space.py index 0280f14c..c0a6b36b 100644 --- a/neps/search_spaces/search_space.py +++ b/neps/search_spaces/search_space.py @@ -247,6 +247,20 @@ def set_custom_grid_space( "hyperparameter for accurate modeling." ) self.has_tabular = True + # Updating `custom_grid_table` as a map for quick lookup with placeholder fidelity + placeholder_config = self.raw_tabular_space.sample( + ignore_fidelity=True + ) # sets fidelity as None + # `placeholder_config` allows to store map values as NePS SearchSpace type + # and also create a placeholder for fideity value + _map = { + idx: deepcopy(placeholder_config) + for idx in self.custom_grid_table.index.values + } + _ = [ + v.load_from(self.custom_grid_table.loc[k].to_dict()) for k, v in _map.items() + ] + self.custom_grid_table = _map @property def has_fidelity(self): @@ -483,7 +497,14 @@ def load_from(self, config: dict): self.hyperparameters[name].load_from(config[name]) def copy(self): - return deepcopy(self) + _copy = deepcopy(self) + + if _copy.has_tabular: + # each configuration does not need to carry the tabular data + _copy.has_tabular = False + _copy.custom_grid_table = None + _copy.raw_tabular_space = None + return _copy def sample_default_configuration( self, patience: int = 1, ignore_fidelity=True, ignore_missing_defaults=False diff --git a/neps/search_spaces/yaml_search_space_utils.py b/neps/search_spaces/yaml_search_space_utils.py index e0efe616..d3ed7041 100644 --- a/neps/search_spaces/yaml_search_space_utils.py +++ b/neps/search_spaces/yaml_search_space_utils.py @@ -3,8 +3,9 @@ import re -def convert_scientific_notation(value: str | int | float, show_usage_flag=False) \ - -> float | (float, bool): +def convert_scientific_notation( + value: str | int | float, show_usage_flag=False +) -> float | (float, bool): """ Convert a given value to a float if it's a string that matches scientific e notation. 
This is especially useful for numbers like "3.3e-5" which YAML parsers may not
diff --git a/neps/status/status.py b/neps/status/status.py
index e7bc99d3..247dfbe7 100644
--- a/neps/status/status.py
+++ b/neps/status/status.py
@@ -294,7 +294,9 @@ def _save_data_to_csv(
                run_data_df.index == "num_evaluated_configs", "value"
            ]
            # checks if the current worker has more evaluated configs than the previous
-            if int(num_evaluated_configs_csv) < num_evaluated_configs_run.iloc[0]:
+            if int(num_evaluated_configs_csv) < int(
+                num_evaluated_configs_run.iloc[0]
+            ):
                config_data_df = config_data_df.sort_values(
                    by="result.loss", ascending=True
                )
@@ -319,6 +321,8 @@ def _save_data_to_csv(

def post_run_csv(root_directory: str | Path, logger=None) -> None:
+    root_directory = Path(root_directory)
+
    if logger is None:
        logger = logging.getLogger("neps_status")
diff --git a/neps/utils/common.py b/neps/utils/common.py
index bcc3066e..dad37e37 100644
--- a/neps/utils/common.py
+++ b/neps/utils/common.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import glob
+import json
import os
import random
from pathlib import Path
@@ -285,3 +286,55 @@ class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self
+
+
+class DataWriter:
+    """
+    A class to specify how to save/write data to the folder by
+    implementing your own write_data function.
+    Use the set_attributes function to set all your necessary attributes and the data;
+    write_data will then be called with only the directory path as argument
+    during the write process.
+    """
+
+    def __init__(self, name: str):
+        self.name = name
+
+    def set_attributes(self, attribute_dict: dict[str, Any]):
+        for attribute_name, attribute in attribute_dict.items():
+            setattr(self, attribute_name, attribute)
+
+    def write_data(self, to_directory: Path):
+        raise NotImplementedError
+
+
+class EvaluationData:
+    """
+    A class to store the data for a single evaluation (configuration)
+    and write that data to its corresponding config folder.
+    """
+
+    def __init__(self):
+        self.data_dict: dict[str, DataWriter] = {}
+
+    def write_all(self, directory: Path):
+        for _, data_writer in self.data_dict.items():
+            data_writer.write_data(directory)
+
+
+class SimpleCSVWriter(DataWriter):
+    def write_data(self, to_directory: Path):
+        # `self.df` (a pd.DataFrame) is expected to be set via `set_attributes` before writing
+        path = to_directory / str(self.name + ".csv")
+        self.df.to_csv(path, float_format="%g")
+
+
+class SimpleJSONWriter(DataWriter):
+    def __init__(self, name: str):
+        super().__init__(name)
+        self.data: dict[str, Any] = {}
+
+    def write_data(self, to_directory: Path):
+        path = to_directory / str(self.name + ".json")
+        with open(path, "w") as file:
+            json.dump(self.data, file)
diff --git a/neps_examples/basic_usage/architecture_and_hyperparameters.py b/neps_examples/basic_usage/architecture_and_hyperparameters.py
index e0b63fe4..53a65e23 100644
--- a/neps_examples/basic_usage/architecture_and_hyperparameters.py
+++ b/neps_examples/basic_usage/architecture_and_hyperparameters.py
@@ -117,6 +117,7 @@ def run_pipeline(**config):
neps.run(
    run_pipeline=run_pipeline,
    pipeline_space=pipeline_space,
+    searcher="random_search",
    root_directory="results/hyperparameters_architecture_example",
    max_evaluations_total=15,
)
diff --git a/neps_examples/convenience/neps_tblogger_tutorial.py b/neps_examples/convenience/neps_tblogger_tutorial.py
index a70cc494..c435cce8 100644
--- a/neps_examples/convenience/neps_tblogger_tutorial.py
+++
b/neps_examples/convenience/neps_tblogger_tutorial.py @@ -57,13 +57,13 @@ import torch.nn as nn import torch.nn.functional as F import torchvision +from neps.plot.tensorboard_eval import tblogger from torch.optim import lr_scheduler from torch.utils.data.dataloader import DataLoader from torch.utils.data.sampler import SubsetRandomSampler from torchvision.transforms import transforms import neps -from neps.plot.tensorboard_eval import tblogger """ Steps for a successful training pipeline: diff --git a/neps_examples/experimental/hierarchical_architecture_hierarchical_GP.py b/neps_examples/experimental/hierarchical_architecture_hierarchical_GP.py index 3db93bde..15ca9c9e 100644 --- a/neps_examples/experimental/hierarchical_architecture_hierarchical_GP.py +++ b/neps_examples/experimental/hierarchical_architecture_hierarchical_GP.py @@ -3,13 +3,13 @@ import logging import time +from neps.optimizers.bayesian_optimization.models.gp_hierarchy import ( + ComprehensiveGPHierarchy, +) from torch import nn import neps from neps.optimizers.bayesian_optimization.kernels import GraphKernelMapping -from neps.optimizers.bayesian_optimization.models.gp_hierarchy import ( - ComprehensiveGPHierarchy, -) from neps.search_spaces.architecture import primitives as ops from neps.search_spaces.architecture import topologies as topos diff --git a/neps_examples/template/lightning_template.py b/neps_examples/template/lightning_template.py index 9c674fc4..b91f856a 100644 --- a/neps_examples/template/lightning_template.py +++ b/neps_examples/template/lightning_template.py @@ -37,9 +37,9 @@ import torch from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger +from neps.utils.common import get_initial_directory, load_lightning_checkpoint import neps -from neps.utils.common import get_initial_directory, load_lightning_checkpoint logger = logging.getLogger("neps_template.run") diff --git a/neps_examples/template/priorband_template.py b/neps_examples/template/priorband_template.py index a8bd8f3c..da47feaf 100644 --- a/neps_examples/template/priorband_template.py +++ b/neps_examples/template/priorband_template.py @@ -31,9 +31,9 @@ import torch import torch.nn as nn import torch.nn.functional as F +from neps.utils.common import load_checkpoint, save_checkpoint import neps -from neps.utils.common import load_checkpoint, save_checkpoint logger = logging.getLogger("neps_template.run") diff --git a/pyproject.toml b/pyproject.toml index 66be36bb..2867b998 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,6 @@ packages = [ { include = "neps_examples" }, ] - [tool.poetry.dependencies] python = ">=3.8,<3.12" ConfigSpace = "^0.7" diff --git a/tests/test_neps_api/testing_scripts/baseoptimizer_neps.py b/tests/test_neps_api/testing_scripts/baseoptimizer_neps.py index 1fe9a219..d1634267 100644 --- a/tests/test_neps_api/testing_scripts/baseoptimizer_neps.py +++ b/tests/test_neps_api/testing_scripts/baseoptimizer_neps.py @@ -1,10 +1,11 @@ import logging -import neps from neps.optimizers.bayesian_optimization.optimizer import BayesianOptimization from neps.optimizers.multi_fidelity.hyperband import Hyperband from neps.search_spaces.search_space import SearchSpace +import neps + pipeline_space_fidelity = dict( val1=neps.FloatParameter(lower=-10, upper=10), val2=neps.IntegerParameter(lower=1, upper=5, is_fidelity=True), diff --git a/tests/test_neps_api/testing_yaml/optimizer_test.yaml b/tests/test_neps_api/testing_yaml/optimizer_test.yaml index cad2221d..96c585a7 100644 --- 
a/tests/test_neps_api/testing_yaml/optimizer_test.yaml +++ b/tests/test_neps_api/testing_yaml/optimizer_test.yaml @@ -9,4 +9,4 @@ searcher_kwargs: # Specific arguments depending on the searcher random_interleave_prob: 0.1 disable_priors: false prior_confidence: high - sample_default_first: false \ No newline at end of file + sample_default_first: false
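A closing note on the `copy_df` helper added to `MFObservedData` in the multi_fidelity/utils.py hunk above: its docstring argues that mutable `SearchSpace` objects stored in object columns must be deep-copied. The hazard it guards against can be shown with plain pandas and toy dicts standing in for configs (illustrative only, not NePS code):

```python
from copy import deepcopy
import pandas as pd

df = pd.DataFrame({"config": [{"lr": 0.1}, {"lr": 0.01}], "perf": [0.3, 0.2]})

plain = df.copy(deep=True)           # copies the frame, but not the dicts stored inside it
plain.loc[0, "config"]["lr"] = 99    # mutating through the copy ...
print(df.loc[0, "config"])           # {'lr': 99}  ... silently changed the original history

safe = df.copy()
safe["config"] = [deepcopy(v) for v in df["config"]]   # per-element deepcopy of the mutable column
safe.loc[1, "config"]["lr"] = 99
print(df.loc[1, "config"])           # {'lr': 0.01} -> original stays intact
```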
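Finally, a usage sketch for the `DataWriter`/`EvaluationData` helpers added to `neps/utils/common.py` above. This assumes a NePS checkout with this patch applied (so the import below exists); the output directory and data values are made up, and `SimpleCSVWriter` stands in for more specialized writers such as the `AcqWriter` used in the optimizer hunk.

```python
from pathlib import Path

import pandas as pd

from neps.utils.common import EvaluationData, SimpleCSVWriter  # added by this patch

out_dir = Path("results/config_1")
out_dir.mkdir(parents=True, exist_ok=True)

# one writer per artifact: set its attributes first, write later
acq_writer = SimpleCSVWriter("Acq_values")
acq_writer.set_attributes({"df": pd.DataFrame({"acq": [0.12, 0.47, 0.31]})})

# collect all writers for this evaluation and flush them into its config folder
evaluation_data = EvaluationData()
evaluation_data.data_dict["acq"] = acq_writer
evaluation_data.write_all(out_dir)        # -> results/config_1/Acq_values.csv
```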