From 2d029cf41837e0a52d1a5b7630210fc8f23360fc Mon Sep 17 00:00:00 2001
From: TarekAbouChakra
Date: Mon, 7 Aug 2023 08:14:42 +0200
Subject: [PATCH] tensorboard functionality

---
 .../convenience/neps_tblogger_tutorial.py | 424 ++++++++++++++
 .../experimental/tensorboard_eval.py      | 387 -------------
 pyproject.toml                            |   4 +
 src/metahyper/api.py                      |  12 +
 src/neps/api.py                           |   7 +
 src/neps/plot/tensorboard_eval.py         | 537 ++++++++++++++++++
 6 files changed, 984 insertions(+), 387 deletions(-)
 create mode 100644 neps_examples/convenience/neps_tblogger_tutorial.py
 delete mode 100644 neps_examples/experimental/tensorboard_eval.py
 create mode 100644 src/neps/plot/tensorboard_eval.py

diff --git a/neps_examples/convenience/neps_tblogger_tutorial.py b/neps_examples/convenience/neps_tblogger_tutorial.py
new file mode 100644
index 00000000..24fbc108
--- /dev/null
+++ b/neps_examples/convenience/neps_tblogger_tutorial.py
@@ -0,0 +1,424 @@
+"""
+NePS tblogger With TensorBoard
+====================================
+This tutorial demonstrates how to use the TensorBoard plugin with the NePS tblogger class
+to track the performance data of different model configurations during training.
+
+
+Setup
+-----
+To install ``torchvision`` and ``tensorboard``, use the following command:
+
+.. code-block::
+
+   pip install torchvision tensorboard
+
+"""
+import argparse
+import logging
+import os
+import random
+import shutil
+import time
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torchvision
+from torch.optim import lr_scheduler
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.sampler import SubsetRandomSampler
+from torchvision.transforms import transforms
+
+import neps
+from neps.plot.tensorboard_eval import tblogger
+
+"""
+Steps:
+
+#1 Define the seeds for reproducibility.
+#2 Prepare the input data.
+#3 Design the model.
+#4 Design the pipeline search spaces.
+#5 Design the run pipeline function.
+#6 Use neps.run to run the entire search using your specified searcher.
+
+"""
+
+#############################################################
+# Defining the seeds for reproducibility
+
+
+def set_seed(seed=123):
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
+#############################################################
+# Prepare the input data. For this tutorial we use the MNIST dataset.
+
+
+def MNIST(
+    batch_size: int = 32, n_train: int = 8192, n_valid: int = 1024
+) -> Tuple[DataLoader, DataLoader, DataLoader]:
+    train_dataset = torchvision.datasets.MNIST(
+        root="./data", train=True, transform=transforms.ToTensor(), download=True
+    )
+    test_dataset = torchvision.datasets.MNIST(
+        root="./data", train=False, transform=transforms.ToTensor(), download=True
+    )
+
+    train_sampler = SubsetRandomSampler(range(n_train))
+    valid_sampler = SubsetRandomSampler(range(n_train, n_train + n_valid))
+    train_dataloader = DataLoader(
+        dataset=train_dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler
+    )
+    val_dataloader = DataLoader(
+        dataset=train_dataset, batch_size=batch_size, shuffle=False, sampler=valid_sampler
+    )
+    test_dataloader = DataLoader(
+        dataset=test_dataset, batch_size=batch_size, shuffle=False
+    )
+
+    return train_dataloader, val_dataloader, test_dataloader
+
+
+#############################################################
+# Design a small MLP model able to represent the input data.
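+# The MLP flattens each 28x28 MNIST image into a 784-dimensional vector and
+# passes it through two ReLU-activated hidden layers (392 and 196 units) down
+# to 10 class logits, one per digit.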
+
+
+class MLP(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.relu = nn.ReLU()
+        self.linear1 = nn.Linear(in_features=784, out_features=392)
+        self.linear2 = nn.Linear(in_features=392, out_features=196)
+        self.linear3 = nn.Linear(in_features=196, out_features=10)
+
+    def forward(self, x):
+        x = x.view(x.size(0), -1)
+        x = self.relu(self.linear1(x))
+        x = self.relu(self.linear2(x))
+        x = self.linear3(x)
+
+        return x
+
+
+#############################################################
+# Define the training step and return the validation error and misclassified images.
+
+
+def loss_ev(model: nn.Module, data_loader: DataLoader) -> float:
+    """Evaluate the model on a data loader and return the error rate (1 - accuracy)."""
+    model.eval()
+    correct = 0
+    total = 0
+    with torch.no_grad():
+        for x, y in data_loader:
+            output = model(x)
+            _, predicted = torch.max(output.data, 1)
+            correct += (predicted == y).sum().item()
+            total += y.size(0)
+
+    accuracy = correct / total
+    return 1 - accuracy
+
+
+def training(model, optimizer, criterion, train_loader, validation_loader):
+    """
+    Train the model for one epoch and evaluate it on the validation set. Used by the searcher.
+
+    Args:
+        model (nn.Module): Model to be trained.
+        optimizer (torch.optim.Optimizer): Optimizer used to update the weights (depends on the pipeline space).
+        criterion (nn.modules.loss): Loss function to use.
+        train_loader (torch.utils.data.DataLoader): Data loader containing the training data.
+        validation_loader (torch.utils.data.DataLoader): Data loader containing the validation data.
+
+    Returns:
+        The validation error (float) for the epoch and the images (torch.Tensor)
+        misclassified during training.
+    """
+    incorrect_images = []
+    model.train()
+    for x, y in train_loader:
+        optimizer.zero_grad()
+        output = model(x)
+        loss = criterion(output, y)
+        loss.backward()
+        optimizer.step()
+
+        predicted_labels = torch.argmax(output, dim=1)
+        incorrect_mask = predicted_labels != y
+        incorrect_images.append(x[incorrect_mask])
+
+    validation_loss = loss_ev(model, validation_loader)
+
+    if len(incorrect_images) > 0:
+        incorrect_images = torch.cat(incorrect_images, dim=0)
+
+    return validation_loss, incorrect_images
+
+
+#############################################################
+# Design the pipeline search spaces.
+
+
+# For BO:
+def pipeline_space_BO() -> dict:
+    pipeline = dict(
+        lr=neps.FloatParameter(lower=1e-5, upper=1e-1, log=True),
+        optim=neps.CategoricalParameter(choices=["Adam", "SGD"]),
+        weight_decay=neps.FloatParameter(lower=1e-4, upper=1e-1, log=True),
+    )
+
+    return pipeline
+
+
+# For Hyperband
+def pipeline_space_Hyperband() -> dict:
+    pipeline = dict(
+        lr=neps.FloatParameter(lower=1e-5, upper=1e-1, log=True),
+        optim=neps.CategoricalParameter(choices=["Adam", "SGD"]),
+        weight_decay=neps.FloatParameter(lower=1e-4, upper=1e-1, log=True),
+        epochs=neps.IntegerParameter(lower=1, upper=9, is_fidelity=True),
+    )
+
+    return pipeline
+
+
+#############################################################
+# Implement the run pipeline functions.
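+# Both run pipeline functions below follow the same contract: NePS calls them
+# with the sampled hyperparameters as arguments and expects back either a
+# scalar loss or a dictionary with a "loss" key, optionally carrying extra
+# metadata, for example:
+#
+#     return {"loss": val_error, "info_dict": {"test_accuracy": acc}, "cost": epochs}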
+
+
+# For BO:
+def run_pipeline_BO(lr, optim, weight_decay):
+    model = MLP()
+
+    if optim == "Adam":
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
+    elif optim == "SGD":
+        optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
+
+    max_epochs = 9
+
+    train_loader, validation_loader, test_loader = MNIST(
+        batch_size=64, n_train=4096, n_valid=512
+    )
+
+    scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.75)
+
+    criterion = nn.CrossEntropyLoss()
+    losses = []
+
+    tblogger.disable(False)
+
+    for i in range(max_epochs):
+        loss, miss_img = training(
+            optimizer=optimizer,
+            model=model,
+            criterion=criterion,
+            train_loader=train_loader,
+            validation_loader=validation_loader,
+        )
+        losses.append(loss)
+
+        tblogger.log(
+            loss=loss,
+            current_epoch=i,
+            data={
+                "lr_decay": tblogger.scalar_logging(value=scheduler.get_last_lr()[0]),
+                "miss_img": tblogger.image_logging(img_tensor=miss_img, counter=2),
+                "layer_gradient": tblogger.layer_gradient_logging(model=model),
+            },
+        )
+
+        scheduler.step()
+
+        print(f" Epoch {i + 1} / {max_epochs} Val Error: {loss} ")
+
+    # loss_ev returns an error rate, so convert it to an accuracy for reporting.
+    train_accuracy = 1 - loss_ev(model, train_loader)
+    test_accuracy = 1 - loss_ev(model, test_loader)
+
+    return {
+        "loss": loss,
+        "info_dict": {
+            "train_accuracy": train_accuracy,
+            "test_accuracy": test_accuracy,
+            "val_errors": losses,
+            "cost": max_epochs,
+        },
+    }
+
+
+# For Hyperband
+def run_pipeline_Hyperband(pipeline_directory, previous_pipeline_directory, **configs):
+    model = MLP()
+    checkpoint_name = "checkpoint.pth"
+    start_epoch = 0
+
+    train_loader, validation_loader, test_loader = MNIST(
+        batch_size=32, n_train=4096, n_valid=512
+    )
+
+    # Define the loss
+    criterion = nn.CrossEntropyLoss()
+
+    # Define the optimizer
+    if configs["optim"] == "Adam":
+        optimizer = torch.optim.Adam(
+            model.parameters(), lr=configs["lr"], weight_decay=configs["weight_decay"]
+        )
+    elif configs["optim"] == "SGD":
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=configs["lr"], weight_decay=configs["weight_decay"]
+        )
+
+    scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.75)
+
+    # We make use of checkpointing to resume training models on higher fidelities
+    if previous_pipeline_directory is not None:
+        # Read in the state of the model after the previous fidelity rung
+        checkpoint = torch.load(previous_pipeline_directory / checkpoint_name)
+        model.load_state_dict(checkpoint["model_state_dict"])
+        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+        epochs_previously_spent = checkpoint["epoch"]
+    else:
+        epochs_previously_spent = 0
+
+    start_epoch += epochs_previously_spent
+
+    losses = list()
+
+    tblogger.disable(False)
+
+    epochs = configs["epochs"]
+
+    for epoch in range(start_epoch, epochs):
+        # Call the training function, get the validation errors and append them to the losses
+        loss, miss_img = training(
+            model, optimizer, criterion, train_loader, validation_loader
+        )
+        losses.append(loss)
+
+        tblogger.log(
+            loss=loss,
+            current_epoch=epoch,
+            hparam_accuracy_mode=True,
+            data={
+                "lr_decay": tblogger.scalar_logging(value=scheduler.get_last_lr()[0]),
+                "miss_img": tblogger.image_logging(img_tensor=miss_img, counter=2),
+                "layer_gradient": tblogger.layer_gradient_logging(model=model),
+            },
+        )
+
+        scheduler.step()
+
+        print(f" Epoch {epoch + 1} / {epochs} Val Error: {loss} ")
+
+    # loss_ev returns an error rate, so convert it to an accuracy for reporting.
+    train_accuracy = 1 - loss_ev(model, train_loader)
+    test_accuracy = 1 - loss_ev(model, test_loader)
+
+    torch.save(
+        {
+            "epoch": epochs,
+            "model_state_dict": model.state_dict(),
+            "optimizer_state_dict": optimizer.state_dict(),
+        },
+        pipeline_directory / checkpoint_name,
+    )
+
+    return {
+        "loss": loss,
+        "info_dict": {
+            "train_accuracy": train_accuracy,
+            "test_accuracy": test_accuracy,
+            "val_errors": losses,
+            "cost": epochs - epochs_previously_spent,
+        },
+        "cost": epochs - epochs_previously_spent,
+    }
+
+
+#############################################################
+"""
+Define the main entry point with argument parsing to use either BO or Hyperband
+and to specify their respective budgets.
+"""
+
+if __name__ == "__main__":
+    argParser = argparse.ArgumentParser()
+    argParser.add_argument(
+        "--searcher",
+        type=str,
+        choices=["bayesian_optimization", "hyperband"],
+        default="bayesian_optimization",
+        help="Searcher type used",
+    )
+    argParser.add_argument(
+        "--max_cost_total", type=int, default=30, help="Max cost used for Hyperband"
+    )
+    argParser.add_argument(
+        "--max_evaluations_total", type=int, default=10, help="Max evaluations used for BO"
+    )
+    args = argParser.parse_args()
+
+    if args.searcher == "hyperband":
+        start_time = time.time()
+        set_seed(112)
+        logging.basicConfig(level=logging.INFO)
+        # Remove results from a previous run, if any, so the search starts fresh.
+        if os.path.exists("hyperband"):
+            shutil.rmtree("hyperband")
+        neps.run(
+            run_pipeline=run_pipeline_Hyperband,
+            pipeline_space=pipeline_space_Hyperband(),
+            root_directory="hyperband",
+            max_cost_total=args.max_cost_total,
+            searcher="hyperband",
+        )
+
+        """
+        To check the live plots while this command runs, open a new terminal in
+        this project directory and run
+
+        tensorboard --logdir hyperband
+        """
+
+        end_time = time.time()  # Record the end time
+        execution_time = end_time - start_time
+        print(f"Execution time: {execution_time} seconds")
+
+    elif args.searcher == "bayesian_optimization":
+        start_time = time.time()
+        set_seed(112)
+        logging.basicConfig(level=logging.INFO)
+        # Remove results from a previous run, if any, so the search starts fresh.
+        if os.path.exists("bayesian_optimization"):
+            shutil.rmtree("bayesian_optimization")
+        neps.run(
+            run_pipeline=run_pipeline_BO,
+            pipeline_space=pipeline_space_BO(),
+            root_directory="bayesian_optimization",
+            max_evaluations_total=args.max_evaluations_total,
+            searcher="bayesian_optimization",
+        )
+
+        """
+        To check the live plots while this command runs, open a new terminal in
+        this project directory and run
+
+        tensorboard --logdir bayesian_optimization
+        """
+
+        end_time = time.time()  # Record the end time
+        execution_time = end_time - start_time
+        print(f"Execution time: {execution_time} seconds")
+
+    """
+    When running this code without any arguments, it will by default run Bayesian
+    optimization with a maximum of 10 evaluations of 9 epochs each:
+
+    python neps_tblogger_tutorial.py
+
+
+    If you wish to run the hyperband searcher instead, with its default max cost total of 30.
Please run this command on the terminal: + + python neps_tblogger_tutorial.py --searcher hyperband + """ diff --git a/neps_examples/experimental/tensorboard_eval.py b/neps_examples/experimental/tensorboard_eval.py deleted file mode 100644 index d5667eac..00000000 --- a/neps_examples/experimental/tensorboard_eval.py +++ /dev/null @@ -1,387 +0,0 @@ -import math -import random -import warnings -from typing import Union - -import numpy as np -import torch -import torch.nn as nn -from torch.utils.tensorboard import SummaryWriter -from torch.utils.tensorboard.summary import hparams - - -# Inherit from class and change to fit purpose: -class SummaryWriter_(SummaryWriter): - """ - This function before the update used to create another subfolder inside the logdir and then create further 'tfevent' - which makes everything else uneasy to differentiate and hence this gives the same result with a much easier way and logs - everything on the same 'tfevent' as for other functions. - In addition, a change in the metric dictiornay was made for the cause of making the printed 'Loss' or 'Accuracy' display on the - Summary file - """ - - def add_hparams(self, hparam_dict, metric_dict, global_step): - if not isinstance(hparam_dict, dict) or not isinstance(metric_dict, dict): - raise TypeError("hparam_dict and metric_dict should be dictionary.") - updated_metric = {} - for key, value in metric_dict.items(): - updated_key = "Summary" + "/" + key - updated_metric[updated_key] = value - exp, ssi, sei = hparams(hparam_dict, updated_metric) - - self.file_writer.add_summary(exp) - self.file_writer.add_summary(ssi) - self.file_writer.add_summary(sei) - for k, v in updated_metric.items(): - self.add_scalar(tag=k, scalar_value=v, global_step=global_step) - - -class tensorboard_evaluations: - def __init__(self, log_dir: str = "/logs") -> None: - self._log_dir = log_dir - - self._best_incum_track = np.inf - self._step_update = 1 - - self._toggle_epoch_max_reached = False - - self._config_track = 1 - - self._fidelity_search_count = 0 - self._fidelity_counter = 0 - self._fidelity_bool = False - self._fidelity_was_bool = False - - self._config_dict: dict[str, dict[str, Union[list[str], float, int]]] = {} - self._config_track_last = 1 - self._prev_config_list: list[str] = [] - - self._writer_config = [] - self._writer_summary = SummaryWriter_(log_dir=self._log_dir + "/summary") - self._writer_config.append( - SummaryWriter_( - log_dir=self._log_dir + "/configs" + "/config_" + str(self._config_track) - ) - ) - - def _make_grid(self, images: torch.tensor, nrow: int, padding: int = 2): - batch_size, num_channels, height, width = images.size() - x_mapping = min(nrow, batch_size) - y_mapping = int(math.ceil(float(batch_size) / x_mapping)) - height, width = height + 2, width + 2 - - grid = torch.zeros( - (num_channels, height * y_mapping + padding, width * x_mapping + padding) - ) - - k = 0 - for y in range(y_mapping): - for x in range(x_mapping): - if k >= batch_size: - break - image = images[k] - grid[ - :, - y * height + padding : y * height + padding + height - padding, - x * width + padding : x * width + padding + width - padding, - ] = image - k += 1 - - return grid - - def _incumbent(self, **incum_data) -> None: - """ - A function used to mainly display out the incumbent trajectory based on the step update which is after finishing every computation. 
- In other words, epochs == max_epochs - """ - loss = incum_data["loss"] - if loss < self._best_incum_track: - self._best_incum_track = loss - self._writer_summary.add_scalar( - tag="Summary" + "/Incumbent_Graph", - scalar_value=self._best_incum_track, - global_step=self._step_update, - ) - self._step_update += 1 - - def _track_config(self, **config_data) -> None: - config_list = config_data["config_list"] - loss = float(config_data["loss"]) - - for config_dict in self._config_dict.values(): - if self._prev_config_list != config_list: - if config_dict["config_list"] == config_list: - if self._fidelity_search_count == 0: - self._config_track_last = self._config_track - self._fidelity_was_bool = True - self._fidelity_bool = True - self._fidelity_search_count += 1 - loss_prev = self._config_dict["config_" + str(self._config_track)][ - "loss" - ] - self._incumbent(loss=loss_prev) - config = config_dict["config"] - if isinstance(config, (int, float)): - self._config_track = int(config) - - if not self._fidelity_bool: - if len(self._prev_config_list) > 0: - if self._prev_config_list != config_list: - self._fidelity_search_count = 0 - if self._fidelity_was_bool: - loss_prev = self._config_dict[ - "config_" + str(self._config_track) - ]["loss"] - self._incumbent(loss=loss_prev) - self._config_track = self._config_track_last + 1 - self._fidelity_counter += 1 - self._config_dict.clear() - self._fidelity_was_bool = False - else: - loss_prev = self._config_dict[ - "config_" + str(self._config_track) - ]["loss"] - self._incumbent(loss=loss_prev) - self._config_track += 1 - self._writer_config.append( - SummaryWriter_( - log_dir=self._log_dir - + "/configs" - + "/config_" - + str(self._config_track) - ) - ) - else: - self._fidelity_bool = False - self._toggle_epoch_max_reached = False - - self._config_dict["config_" + str(self._config_track)] = { - "config_list": config_list, - "loss": float(loss), - "config": self._config_track, - } - - self._prev_config_list = config_list - - def write_scalar_configs( - self, config_list: list, current_epoch: int, loss: float, scalar: float, tag: str - ) -> None: - """ - Writes any scalar to the specific corresponding config, EX: Learning_rate decay tracking, Accuracy... - - Arguments: - conifg_list: a list (The configurations sved as a list in run_pipline and passed here as an argument) - current_epoch: an integer (The currecnt epoch running at the time) - loss: a float (The loss at the specific run, important for hypeband) - scalar: a float (The scalar value to be visualized) - tag: a string (The tag of the scalar EX: tag = 'Learning_Rate') - """ - if tag == "loss": - scalar = loss - - if loss is None or current_epoch is None or config_list is None: - raise ValueError( - "Loss, epochs, and max_epochs cannot be None. Please provide a valid value." - ) - - self._track_config(config_list=config_list, loss=loss) - - self._writer_config[self._config_track - 1].add_scalar( - tag="Config" + str(self._config_track) + "/" + tag, - scalar_value=scalar, - global_step=current_epoch, - ) - - def write_scalar_fidelity( - self, config_list: list, current_epoch: int, loss: float, Accuracy: bool = False - ) -> None: - """ - This function will take the each fidelity and show the accuracy or the loss during HPO search for each fidelity. 
- - Arguments: - conifg_list: a list (The configurations sved as a list in run_pipline and passed here as an argument) - current_epoch: an integer (The currecnt epoch running at the time) - loss: a float (The loss at the specific run, important for hypeband) - Accuracy: a bool (If true it will change the loss to accuracy % and display the results. - If false it will remain displaying with respect to the loss) - """ - if loss is None or current_epoch is None or config_list is None: - raise ValueError( - "Loss, epochs, and max_epochs cannot be None. Please provide a valid value." - ) - - self._track_config(config_list=config_list, loss=loss) - - if Accuracy: - acc = (1 - loss) * 100 - scalar_value = acc - else: - scalar_value = loss - - self._writer_config[self._config_track - 1].add_scalar( - tag="Summary" + "/Fidelity_" + str(self._fidelity_counter), - scalar_value=scalar_value, - global_step=current_epoch, - ) - - def write_histogram( - self, config_list: list, current_epoch: int, loss: float, model: nn.Module - ) -> None: - """ - By logging histograms for all parameters, you can gain insights into the distribution of different - parameter types and identify potential issues or patterns in their values. This comprehensive analysis - can help you better understand your model's behavior during training. - - Ex: Weights where their histograms do not show a change in shape from the first epoch up until the last prove to - mean that the training is not done properly and hence weights are not updated in the rythm they should - - Arguments: - conifg_list: a list (The configurations sved as a list in run_pipline and passed here as an argument) - current_epoch: an integer (The currecnt epoch running at the time) - loss: a float (The loss at the specific run, important for hypeband) - model: a nn.Module (The model which we want to analyze) - """ - if loss is None or current_epoch is None or config_list is None: - raise ValueError( - "Loss, epochs, and max_epochs cannot be None. Please provide a valid value." - ) - - self._track_config(config_list=config_list, loss=loss) - - for _, param in model.named_parameters(): - self._writer_config[self._config_track - 1].add_histogram( - "Config" + str(self._config_track), - param.clone().cpu().data.numpy(), - current_epoch, - ) - - def write_image( - self, - config_list: list, - max_epochs: int, - current_epoch: int, - loss: float, - image_input: torch.Tensor, - num_images: int = 10, - random_images: bool = False, - resize_images: np.array = None, - ignore_warning: bool = True, - ) -> None: - """ - The user is free on how they want to tackle image visualization on tensorboard, they specify the numebr of images - they want to show and if the images should be taken randomly or not. 
- - Arguments: - conifg_list: a list (The configurations sved as a list in run_pipline and passed here as an argument) - max_epochs: an integer (Maximum epoch that can be reached at that specific run) - current_epoch: an integer (The currecnt epoch running at the time) - loss: a float (The loss at the specific run) - image_imput: a Tensor (The input image in batch, shape: 12x3x28x28 'BxCxWxH') - num_images: an integer (The number of images ot be displayed for each config on tensorboard) - random_images: a bool (True is the images should be sampled randomly, False otherwise) - resize_images: an array (Resizing an the images to make them fit and be clearly visible on the grid) - ignore_warning: a bool (At the moment a warning is appearing, bug will be fixed later) - - Example code of displaying wrongly classified images: - - 1- In the trianing for loop: - predicted_labels = torch.argmax(output_of_model_after_input, dim=1) - misclassification_mask = predicted_labels != y_actual_labels - misclassified_images.append(x[misclassification_mask]) - - 2- Before the return, outside the training loop: - if len(misclassified_images) > 0: - misclassified_images = torch.cat(misclassified_images, dim=0) - - 3- Returning the misclassified images - return ..., misclassified_images - - Then use these misclassified_images as the image_input of this function - """ - if loss is None or current_epoch is None or config_list is None: - raise ValueError( - "Loss, epochs, and max_epochs cannot be None. Please provide a valid value." - ) - - self._track_config(config_list=config_list, loss=loss) - - if resize_images is None: - resize_images = [56, 56] - - if ignore_warning is True: - warnings.filterwarnings("ignore", category=DeprecationWarning) - - if current_epoch == max_epochs - 1: - if num_images > len(image_input): - num_images = len(image_input) - - if random_images is False: - subset_images = image_input[:num_images] - else: - random_indices = random.sample(range(len(image_input)), num_images) - subset_images = image_input[random_indices] - - resized_images = torch.nn.functional.interpolate( - subset_images, - size=(resize_images[0], resize_images[1]), - mode="bilinear", - align_corners=False, - ) - - nrow = int(resized_images.size(0) ** 0.75) - img_grid = self._make_grid(resized_images, nrow=nrow) - - self._writer_config[self._config_track - 1].add_image( - tag="IMG_config " + str(self._config_track), - img_tensor=img_grid, - global_step=self._config_track, - ) - - def write_hparam( - self, - config_list: list, - current_epoch: int, - loss: float, - Accuracy: bool = False, - **pipeline_space, - ) -> None: - """ - '.add_hparam' is a function in TensorBoard that allows you to log hyperparameters associated with your training run. - It takes a dictionary of hyperparameter names and values and associates them with the current run, making it easy to - compare and analyze different hyperparameter configurations. - - Arguments: - conifg_list: a list (The configurations sved as a list in run_pipline and passed here as an argument) - current_epoch: an integer (The currecnt epoch running at the time) - loss: a float (The loss at the specific run) - Accuracy: a bool (If true it will change the loss to accuracy % and display the results. - If false it will remain displaying with respect to the loss) - pipeline_space: The name of the hyperparameters in addition to their kwargs to be searched on. 
- """ - if loss is None or current_epoch is None or config_list is None: - raise ValueError( - "Loss, epochs, and max_epochs cannot be None. Please provide a valid value." - ) - - self._track_config(config_list=config_list, loss=loss) - - if Accuracy: - str_name = "Accuracy" - str_value = (1 - loss) * 100 - else: - str_name = "Loss" - str_value = loss - - values = {str_name: str_value} - - self._writer_config[self._config_track - 1].add_hparams( - pipeline_space, values, current_epoch - ) - - def close_writers(self) -> None: - """ - Closing the writers created after finishing all the tensorboard visualizations - """ - self._writer_summary.close() - for _, writer in enumerate(self._writer_config): - writer.close() diff --git a/pyproject.toml b/pyproject.toml index fa819ba0..fba3332d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,10 @@ more-itertools = "^9.0.0" portalocker = "^2.6.0" seaborn = "^0.12.1" pyyaml = "^6.0" +tensorboard = [ + {version = "^2.11", python = "<3.8"}, + {version = "^2.13", python = ">=3.8"} +] [tool.poetry.group.dev.dependencies] pre-commit = "^2.10" diff --git a/src/metahyper/api.py b/src/metahyper/api.py index 1babcdec..0da60626 100644 --- a/src/metahyper/api.py +++ b/src/metahyper/api.py @@ -12,6 +12,8 @@ from pathlib import Path from typing import Any +from neps.plot.tensorboard_eval import tblogger + from ._locker import Locker from .utils import YamlSerializer, find_files, non_empty_file @@ -391,6 +393,16 @@ def run( pipeline_directory, previous_pipeline_directory, ) = _sample_config(optimization_dir, sampler, serializer, logger) + # Take the config data in case tensorboard is to be used. + if tblogger.logger_init_bool or tblogger.logger_bool: + tblogger.config_track_init_api( + config_id=config_id, + config=config, + config_working_directory=pipeline_directory, + config_previous_directory=previous_pipeline_directory, + optim_path=optimization_dir, + ) + tblogger.logger_init_bool = False config_lock_file = pipeline_directory / ".config_lock" config_lock_file.touch(exist_ok=True) diff --git a/src/neps/api.py b/src/neps/api.py index 721919d2..6eb97bd9 100644 --- a/src/neps/api.py +++ b/src/neps/api.py @@ -15,6 +15,7 @@ from metahyper import instance_from_map from .optimizers import BaseOptimizer, SearcherMapping +from .plot.tensorboard_eval import tblogger from .search_spaces.parameter import Parameter from .search_spaces.search_space import SearchSpace, pipeline_space_from_configspace from .utils.result_utils import get_loss @@ -82,8 +83,14 @@ def write_loss_and_config(file_handle, loss_, config_id_, config_): f"Finished evaluating config {config_id}" f" -- new best with loss {float(loss) :.3f}" ) + if tblogger.logger_bool: + tblogger.tracking_incumbent_api(best_loss=loss) + else: logger.info(f"Finished evaluating config {config_id}") + # Track the incumbent from the best loss + if tblogger.logger_bool: + tblogger.tracking_incumbent_api(best_loss=best_loss) return _post_evaluation_hook diff --git a/src/neps/plot/tensorboard_eval.py b/src/neps/plot/tensorboard_eval.py new file mode 100644 index 00000000..18dca4c3 --- /dev/null +++ b/src/neps/plot/tensorboard_eval.py @@ -0,0 +1,537 @@ +import math +import os +import random +import warnings +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from torch.utils.tensorboard import SummaryWriter +from torch.utils.tensorboard.summary import hparams + + +class SummaryWriter_(SummaryWriter): + """ + This class inherits from the base SummaryWriter class and provides modifications 
+    to improve the logging. It simplifies the logging structure and ensures
+    consistent tag formatting for metrics.
+
+    Changes Made:
+    - Avoids creating unnecessary subfolders in the log directory.
+    - Ensures all logs are stored in the same 'tfevent' directory for better organization.
+    - Updates metric keys to have a consistent 'Summary/' prefix for clarity.
+    - Improves the display of 'Loss' or 'Accuracy' on the Summary file.
+
+    Methods:
+    - add_hparams: Overrides the base method to log hyperparameters and metrics with better formatting.
+    """
+
+    def add_hparams(self, hparam_dict, metric_dict, global_step):
+        if not isinstance(hparam_dict, dict) or not isinstance(metric_dict, dict):
+            raise TypeError("hparam_dict and metric_dict must be dictionaries.")
+        updated_metric = {}
+        for key, value in metric_dict.items():
+            updated_key = "Summary" + "/" + key
+            updated_metric[updated_key] = value
+        exp, ssi, sei = hparams(hparam_dict, updated_metric)
+
+        self.file_writer.add_summary(exp)
+        self.file_writer.add_summary(ssi)
+        self.file_writer.add_summary(sei)
+        for k, v in updated_metric.items():
+            self.add_scalar(tag=k, scalar_value=v, global_step=global_step)
+
+
+class tblogger:
+    config = None
+    config_id: Optional[int] = None
+    config_working_directory = None
+    config_previous_directory = None
+    optim_path = None
+
+    config_value_fid: Optional[str] = None
+    fidelity_mode: bool = False
+
+    logger_init_bool: bool = True
+    logger_bool: bool = False
+
+    image_logger: bool = False
+    image_value: Optional[torch.Tensor] = None
+    image_name: Optional[str] = None
+    epoch_value: Optional[int] = None
+
+    disable_logging: bool = False
+
+    loss: Optional[float] = None
+    current_epoch: int
+    scalar_accuracy_mode: bool = False
+    hparam_accuracy_mode: bool = False
+
+    config_writer: Optional[SummaryWriter_] = None
+    summary_writer: Optional[SummaryWriter_] = None
+
+    logging_mode: list = []
+
+    @staticmethod
+    def config_track_init_api(
+        config_id, config, config_working_directory, config_previous_directory, optim_path
+    ):
+        """
+        Track the configuration data the same way metahyper's '_sample_config'
+        produces it, to keep in sync with the config ids and directories NePS
+        is operating on.
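+
+        Args:
+            config_id: The id of the sampled configuration.
+            config: The sampled configuration itself.
+            config_working_directory: The directory NePS uses for this configuration.
+            config_previous_directory: The directory of the same configuration at the
+                previous fidelity, or None if there is none.
+            optim_path: The root optimization directory of the current run.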
+ """ + + tblogger.config = config + tblogger.config_id = config_id + tblogger.config_working_directory = config_working_directory + tblogger.config_previous_directory = config_previous_directory + tblogger.optim_path = optim_path + + @staticmethod + def _initialize_writers(): + if not tblogger.config_writer: + optim_config_path = tblogger.optim_path / "results" + if tblogger.config_previous_directory is not None: + tblogger.fidelity_mode = True + while not tblogger.config_writer: + if os.path.exists(tblogger.config_previous_directory / "tbevents"): + find_previous_config_id = ( + tblogger.config_working_directory / "previous_config.id" + ) + if os.path.exists(find_previous_config_id): + with open(find_previous_config_id) as file: + contents = file.read() + tblogger.config_value_fid = contents + tblogger.config_writer = SummaryWriter_( + tblogger.config_previous_directory / "tbevents" + ) + else: + find_previous_config_path = ( + tblogger.config_previous_directory / "previous_config.id" + ) + if os.path.exists(find_previous_config_path): + with open(find_previous_config_path) as file: + contents = file.read() + tblogger.config_value_fid = contents + tblogger.config_working_directory = ( + tblogger.config_previous_directory + ) + tblogger.config_previous_directory = ( + optim_config_path / f"config_{contents}" + ) + else: + tblogger.fidelity_mode = False + tblogger.config_writer = SummaryWriter_( + tblogger.config_working_directory / "tbevents" + ) + + @staticmethod + def _make_grid(images: torch.tensor, nrow: int, padding: int = 2): + """ + Create a grid of images from a batch of images. + + Args: + images (torch.Tensor): The input batch of images with shape (batch_size, num_channels, height, width). + nrow (int): The number rows on the grid. + padding (int, optional): The padding between images in the grid. Default is 2. + + Returns: + torch.Tensor: A grid of images with shape (num_channels, total_height, total_width), + where total_height and total_width depend on the number of images and the grid settings. + """ + batch_size, num_channels, height, width = images.size() + x_mapping = min(nrow, batch_size) + y_mapping = int(math.ceil(float(batch_size) / x_mapping)) + height, width = height + 2, width + 2 + + grid = torch.zeros( + (num_channels, height * y_mapping + padding, width * x_mapping + padding) + ) + + k = 0 + for y in range(y_mapping): + for x in range(x_mapping): + if k >= batch_size: + break + image = images[k] + grid[ + :, + y * height + padding : y * height + padding + height - padding, + x * width + padding : x * width + padding + width - padding, + ] = image + k += 1 + + return grid + + @staticmethod + def scalar_logging(value: float) -> list: + """ + Prepare a scalar value for logging. + + Args: + value (float): The scalar value to be logged. + + Returns: + list: A list containing the logging mode and the value for logging. + The list format is [logging_mode, value]. + """ + logging_mode = "scalar" + return [logging_mode, value] + + @staticmethod + def image_logging( + img_tensor: torch.Tensor, + counter: int, + resize_images: Optional[List[Optional[int]]] = None, + ignore_warning: bool = True, + random_images: bool = True, + num_images: int = 20, + ) -> List[Union[str, torch.Tensor, int, bool, List[Optional[int]]]]: + """ + Prepare an image tensor for logging. + + Args: + img_tensor (torch.Tensor): The image tensor to be logged. + counter (int): A counter value for teh frequency of image logging (ex: counter 2 means for every + 2 epochs a new set of images are logged). 
+            resize_images (list of int): A list of integers representing the image sizes
+                                         after resizing, or None if no resizing is required.
+                                         Default is None.
+            ignore_warning (bool, optional): Whether to ignore any warning during logging. Default is True.
+            random_images (bool, optional): Whether the images are selected randomly. Default is True.
+            num_images (int, optional): The number of images to log. Default is 20.
+
+        Returns:
+            list: A list containing the logging mode and all the necessary parameters for image logging.
+                  The list format is [logging_mode, img_tensor, counter, resize_images,
+                  ignore_warning, random_images, num_images].
+        """
+        logging_mode = "image"
+        return [
+            logging_mode,
+            img_tensor,
+            counter,
+            resize_images,
+            ignore_warning,
+            random_images,
+            num_images,
+        ]
+
+    @staticmethod
+    def layer_gradient_logging(model: nn.Module):
+        """
+        Prepare a model for logging layer gradients.
+
+        Args:
+            model (nn.Module): The PyTorch model for which layer gradients will be logged.
+
+        Returns:
+            list: A list containing the logging mode and the model for layer gradient logging.
+                  The list format is [logging_mode, model].
+        """
+        logging_mode = "gradient_mean"
+        return [logging_mode, model]
+
+    @staticmethod
+    def _file_arrange():
+        # TODO: Have only one tfevent file in the respective folders instead of multiple (especially in the summary folder)
+        pass
+
+    @staticmethod
+    def _write_scalar_config(tag: str, value: Union[float, int]):
+        """
+        Write scalar values to the TensorBoard log.
+
+        Args:
+            tag (str): The tag for the scalar value.
+            value (float or int): The scalar value to be logged.
+
+        Note:
+            If the tag is 'Loss' and scalar_accuracy_mode is True, the tag will be changed to 'Accuracy',
+            and the value will be transformed accordingly.
+
+            The function relies on _initialize_writers to ensure the TensorBoard writer is initialized at
+            the correct directory.
+
+            It also depends on the following class attributes:
+            - tblogger.scalar_accuracy_mode (bool)
+            - tblogger.fidelity_mode (bool)
+            - tblogger.config_writer (SummaryWriter_)
+
+            The function will log the scalar value under different tags based on fidelity mode and other configurations.
+        """
+        tblogger._initialize_writers()
+
+        if tag == "Loss":
+            if tblogger.scalar_accuracy_mode:
+                tag = "Accuracy"
+                value = (1 - value) * 100
+        if tblogger.config_writer is not None:
+            if tblogger.fidelity_mode:
+                tblogger.config_writer.add_scalar(
+                    tag="Config_" + str(tblogger.config_value_fid) + "/" + tag,
+                    scalar_value=value,
+                    global_step=tblogger.current_epoch,
+                )
+            else:
+                tblogger.config_writer.add_scalar(
+                    tag="Config_" + str(tblogger.config_id) + "/" + tag,
+                    scalar_value=value,
+                    global_step=tblogger.current_epoch,
+                )
+
+    @staticmethod
+    def _write_image_config(
+        tag: str,
+        image: torch.Tensor,
+        counter: int,
+        resize_images: Optional[List[Optional[int]]] = None,
+        ignore_warning: bool = True,
+        random_images: bool = True,
+        num_images: int = 20,
+    ):
+        """
+        Write images to the TensorBoard log.
+
+        Args:
+            tag (str): The tag for the images.
+            image (torch.Tensor): The image tensor to be logged.
+            counter (int): A counter value associated with the images.
+            resize_images (list of int): A list of integers representing the image sizes
+                                         after resizing, or None if no resizing is required.
+                                         Default is None.
+            ignore_warning (bool, optional): Whether to ignore any warning during logging. Default is True.
+            random_images (bool, optional): Whether the images are selected randomly. Default is True.
+            num_images (int, optional): The number of images to log. Default is 20.
+
+        Note:
+            The function relies on _initialize_writers to ensure the TensorBoard writer is initialized at
+            the correct directory.
+
+            It also depends on the following class attributes:
+            - tblogger.current_epoch (int)
+            - tblogger.fidelity_mode (bool)
+            - tblogger.config_writer (SummaryWriter_)
+            - tblogger.config_value_fid (str or None)
+            - tblogger.config_id (int)
+
+            The function will log a subset of images to TensorBoard based on the given configurations.
+        """
+        tblogger._initialize_writers()
+
+        if resize_images is None:
+            resize_images = [32, 32]
+
+        if ignore_warning is True:
+            warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+        if tblogger.current_epoch % counter == 0:
+            if num_images > len(image):
+                num_images = len(image)
+
+            if random_images is False:
+                subset_images = image[:num_images]
+            else:
+                random_indices = random.sample(range(len(image)), num_images)
+                subset_images = image[random_indices]
+
+            resized_images = torch.nn.functional.interpolate(
+                subset_images,
+                size=(resize_images[0], resize_images[1]),
+                mode="bilinear",
+                align_corners=False,
+            )
+
+            nrow = int(resized_images.size(0) ** 0.75)
+            img_grid = tblogger._make_grid(resized_images, nrow=nrow)
+            if tblogger.config_writer is not None:
+                if tblogger.fidelity_mode:
+                    tblogger.config_writer.add_image(
+                        tag="Config_" + str(tblogger.config_value_fid) + "/" + tag,
+                        img_tensor=img_grid,
+                        global_step=tblogger.current_epoch,
+                    )
+                else:
+                    tblogger.config_writer.add_image(
+                        tag="Config_" + str(tblogger.config_id) + "/" + tag,
+                        img_tensor=img_grid,
+                        global_step=tblogger.current_epoch,
+                    )
+
+    @staticmethod
+    def _write_hparam_config():
+        """
+        Write hyperparameter configurations to the TensorBoard log, inspired by
+        TensorBoard's original 'hparams' function.
+
+        Note:
+            The function relies on _initialize_writers to ensure the TensorBoard writer is initialized at
+            the correct directory.
+
+            It also depends on the following class attributes:
+            - tblogger.hparam_accuracy_mode (bool)
+            - tblogger.loss (float)
+            - tblogger.config_writer (SummaryWriter_)
+            - tblogger.config (dict)
+            - tblogger.current_epoch (int)
+
+            The function will log hyperparameter configurations along with a metric value (either accuracy or loss)
+            to TensorBoard based on the given configurations.
+        """
+        tblogger._initialize_writers()
+
+        if tblogger.hparam_accuracy_mode:
+            str_name = "Accuracy"
+            str_value = (1 - tblogger.loss) * 100
+        else:
+            str_name = "Loss"
+            str_value = tblogger.loss
+
+        values = {str_name: str_value}
+        if tblogger.config_writer is not None:
+            tblogger.config_writer.add_hparams(
+                hparam_dict=tblogger.config,
+                metric_dict=values,
+                global_step=tblogger.current_epoch,
+            )
+
+    @staticmethod
+    def tracking_incumbent_api(best_loss):
+        """
+        Track the incumbent (best) loss and log it in the TensorBoard summary.
+
+        Args:
+            best_loss (float): The best loss value to be tracked, according to the _post_evaluation_hook of NePS.
+
+        Note:
+            The function relies on the following class attributes:
+            - tblogger.config_writer (SummaryWriter_)
+            - tblogger.optim_path (str)
+            - tblogger.incum_tracker (int)
+            - tblogger.incum_val (float)
+            - tblogger.summary_writer (SummaryWriter_)
+
+            The function logs the incumbent loss in a TensorBoard summary with a graph.
+            It increments the incumbent tracker based on occurrences of "Config ID" in the 'all_losses_and_configs.txt' file.
+ """ + if tblogger.config_writer: + tblogger.config_writer.close() + tblogger.config_writer = None + + file_path = str(tblogger.optim_path) + "/all_losses_and_configs.txt" + tblogger.incum_tracker = 0 + with open(file_path) as f: + for line in f: + tblogger.incum_tracker += line.count("Config ID") + + tblogger.incum_val = float(best_loss) + + logdir = str(tblogger.optim_path) + "/summary" + + if tblogger.summary_writer is None: + tblogger.summary_writer = SummaryWriter_(logdir) + + tblogger.summary_writer.add_scalar( + tag="Summary" + "/Incumbent_graph", + scalar_value=tblogger.incum_val, + global_step=tblogger.incum_tracker, + ) + + tblogger.summary_writer.flush() + tblogger.summary_writer.close() + + @staticmethod + def disable(disable_logger: bool = True): + """ + The function allows for enabling or disabling the logger functionality + throughout the program execution by updating the value of 'tblogger.disable_logging'. + When the logger is disabled, it will not perform any logging operations. + + Args: + disable_logger (bool, optional): A boolean flag to control the logger. + If True (default), the logger will be disabled. + If False, the logger will be enabled. + + Example: + # Disable the logger + tblogger.disable() + + # Enable the logger + tblogger.disable(False) + """ + tblogger.disable_logging = disable_logger + + @staticmethod + def log( + loss: float, + current_epoch: int, + writer_scalar: bool = True, + writer_hparam: bool = True, + scalar_accuracy_mode: bool = False, + hparam_accuracy_mode: bool = False, + data: Optional[dict] = None, + ): + """ + Log experiment data to the logger, including scalar values, hyperparameters, images, and layer gradients. + + Args: + loss (float): The current loss value in training. + current_epoch (int): The current epoch of the experiment. + writer_scalar (bool, optional): Whether to write the loss or accuracy for the + configs during training. Default is True. + writer_hparam (bool, optional): Whether to write hyperparameters logging + of the configs during training. Default is True. + scalar_accuracy_mode (bool, optional): If True, interpret the 'loss' as 'accuracy' and transform it's + value accordingliy. Default is False. + hparam_accuracy_mode (bool, optional): If True, interpret the 'loss' as 'accuracy' and transform it's + value accordingliy. Default is False. + data (dict, optional): Additional experiment data to be logged. It should be in the format: + { + 'tag1': tblogger.scalar_logging(value=value1), + 'tag2': tblogger.image_logging(img_tensor=img, counter=2), + 'tag3': tblogger.layer_gradient_logging(model=model), + } + Default is None. 
+ + """ + tblogger.current_epoch = current_epoch + tblogger.loss = loss + tblogger.scalar_accuracy_mode = scalar_accuracy_mode + tblogger.hparam_accuracy_mode = hparam_accuracy_mode + + if not tblogger.disable_logging: + tblogger.logger_bool = True + + if writer_scalar: + tblogger._write_scalar_config(tag="Loss", value=loss) + + if writer_hparam: + tblogger._write_hparam_config() + + if data is not None: + for key in data: + if data[key][0] == "scalar": + tblogger._write_scalar_config(tag=str(key), value=data[key][1]) + + elif data[key][0] == "image": + tblogger._write_image_config( + tag=str(key), + image=data[key][1], + counter=data[key][2], + resize_images=data[key][3], + ignore_warning=data[key][4], + random_images=data[key][5], + num_images=data[key][6], + ) + + elif data[key][0] == "gradient_mean": + for i, layer in enumerate(data[key][1].children()): + layer_gradients = [param.grad for param in layer.parameters()] + if layer_gradients: + mean_gradient = torch.mean( + torch.cat([grad.view(-1) for grad in layer_gradients]) + ) + tblogger._write_scalar_config( + tag=f"{key}_gradient_{i}", value=mean_gradient.item() + ) + + else: + tblogger.logger_bool = False