diff --git a/neps/plot/tensorboard_eval.py b/neps/plot/tensorboard_eval.py index e77329b4..c23cb427 100644 --- a/neps/plot/tensorboard_eval.py +++ b/neps/plot/tensorboard_eval.py @@ -4,7 +4,7 @@ import math from pathlib import Path -from typing import Any, ClassVar, Mapping +from typing import TYPE_CHECKING, Any, ClassVar, Mapping from typing_extensions import override import numpy as np @@ -12,10 +12,17 @@ from torch.utils.tensorboard.summary import hparams from torch.utils.tensorboard.writer import SummaryWriter -from neps.runtime import get_in_progress_trial, get_workers_neps_state +from neps.runtime import ( + get_in_progress_trial, + get_workers_neps_state, + register_notify_trial_end, +) from neps.status.status import get_summary_dict from neps.utils.common import get_initial_directory +if TYPE_CHECKING: + from neps.state.trial import Trial + class SummaryWriter_(SummaryWriter): # noqa: N801 """This class inherits from the base SummaryWriter class and provides @@ -87,6 +94,8 @@ def _initiate_internal_configurations() -> None: trial = get_in_progress_trial() neps_state = get_workers_neps_state() + register_notify_trial_end("NEPS_TBLOGGER", tblogger.end_of_config) + # We are assuming that neps state is all filebased here root_dir = Path(neps_state.location) assert root_dir.exists() @@ -97,12 +106,12 @@ def _initiate_internal_configurations() -> None: if trial.metadata.previous_trial_location is not None else None ) + tblogger.config_id = trial.metadata.id tblogger.optimizer_dir = root_dir tblogger.config = trial.config @staticmethod def _is_initialized() -> bool: - # Returns 'True' if config_writer is already initialized. 'False' otherwise return tblogger.config_writer is not None @staticmethod @@ -110,7 +119,7 @@ def _initialize_writers() -> None: # This code runs only once per config, to assign that config a config_writer. 
if ( tblogger.config_previous_directory is None - and tblogger.config_working_directory + and tblogger.config_working_directory is not None ): # If no fidelities are there yet, define the writer via the config_id tblogger.config_id = str(tblogger.config_working_directory).rsplit( @@ -120,8 +129,9 @@ def _initialize_writers() -> None: tblogger.config_working_directory / "tbevents" ) return + # Searching for the initial directory where tensorboard events are stored. - if tblogger.config_working_directory: + if tblogger.config_working_directory is not None: init_dir = get_initial_directory( pipeline_directory=tblogger.config_working_directory ) @@ -135,7 +145,7 @@ def _initialize_writers() -> None: ) @staticmethod - def end_of_config() -> None: + def end_of_config(trial: Trial) -> None: # noqa: ARG004 """Closes the writer.""" if tblogger.config_writer: # Close and reset previous config writers for consistent logging. diff --git a/neps/runtime.py b/neps/runtime.py index 5cf0f29f..c7108a28 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -93,6 +93,13 @@ def get_in_progress_trial() -> Trial: return _CURRENTLY_RUNNING_TRIAL_IN_PROCESS +_TRIAL_END_CALLBACKS: dict[str, Callable[[Trial], None]] = {} + + +def register_notify_trial_end(key: str, callback: Callable[[Trial], None]) -> None: + _TRIAL_END_CALLBACKS[key] = callback + + @contextmanager def _set_global_trial(trial: Trial) -> Iterator[None]: global _CURRENTLY_RUNNING_TRIAL_IN_PROCESS # noqa: PLW0603 @@ -107,6 +114,8 @@ def _set_global_trial(trial: Trial) -> Iterator[None]: ) _CURRENTLY_RUNNING_TRIAL_IN_PROCESS = trial yield + for _key, callback in _TRIAL_END_CALLBACKS.items(): + callback(trial) _CURRENTLY_RUNNING_TRIAL_IN_PROCESS = None diff --git a/neps/utils/common.py b/neps/utils/common.py index be0782f7..27e6691b 100644 --- a/neps/utils/common.py +++ b/neps/utils/common.py @@ -3,11 +3,12 @@ from __future__ import annotations import inspect +import random from functools import partial -import numpy as np from 
pathlib import Path -import random from typing import Any, Iterable, Mapping, Sequence + +import numpy as np import torch import yaml @@ -140,6 +141,9 @@ def load_lightning_checkpoint( return checkpoint_path, checkpoint +_INITIAL_DIRECTORY_CACHE: dict[str, Path] = {} + + # TODO: We should have a better way to have a shared folder between trials. # Right now, the fidelity lineage is linear, however this will be a difficulty # when/if we have a tree structure. @@ -155,13 +159,15 @@ def get_initial_directory(pipeline_directory: Path | str | None = None) -> Path: """ neps_state = get_workers_neps_state() if pipeline_directory is not None: - pipeline_directory = Path(pipeline_directory) # TODO: Hard coded assumption - config_id = pipeline_directory.name.split("_", maxsplit=1)[-1] + config_id = Path(pipeline_directory).name.split("_", maxsplit=1)[-1] trial = neps_state.get_trial_by_id(config_id) else: trial = get_in_progress_trial() + if trial.metadata.id in _INITIAL_DIRECTORY_CACHE: + return _INITIAL_DIRECTORY_CACHE[trial.metadata.id] + # Recursively find the initial directory while (prev_trial_id := trial.metadata.previous_trial_id) is not None: trial = neps_state.get_trial_by_id(prev_trial_id) @@ -170,7 +176,10 @@ def get_initial_directory(pipeline_directory: Path | str | None = None) -> Path: # TODO: Hard coded assumption that we are operating in a filebased neps assert isinstance(initial_dir, str) - return Path(initial_dir) + path = Path(initial_dir) + + _INITIAL_DIRECTORY_CACHE[trial.metadata.id] = path + return path def get_searcher_data( @@ -396,13 +405,13 @@ def __init__(self, *args, **kwargs): class DataWriter: + """A class to specify how to save/write data to the folder by + implementing your own write_data function. + Use the set_attributes function to set all your necessary attributes and the data + and then write_data will be called with only the directory path as argument + during the write process. 
""" - A class to specify how to save/write a data to the folder by - implementing your own write_data function. - Use the set_attributes function to set all your necessary attributes and the data - and then write_data will be called with only the directory path as argument - during the write process - """ + def __init__(self, name: str): self.name = name @@ -415,10 +424,10 @@ def write_data(self, to_directory: Path): class EvaluationData: + """A class to store some data for a single evaluation (configuration) + and write that data to its corresponding config folder. """ - A class to store some data for a single evaluation (configuration) - and write that data to its corresponding config folder - """ + def __init__(self): self.data_dict: dict[str, DataWriter] = {}