fix(tensorboard): Ensure tensorboard knows about end of config
eddiebergman committed Sep 17, 2024
1 parent c19801e commit defa370
Showing 3 changed files with 48 additions and 20 deletions. The commit adds a per-worker trial-end callback hook in neps/runtime.py, has the TensorBoard logger register its end_of_config handler with that hook so its writer is closed when a configuration finishes, and adds a small lookup cache to get_initial_directory in neps/utils/common.py.
neps/plot/tensorboard_eval.py (16 additions & 6 deletions)
@@ -4,18 +4,25 @@
 import math
 from pathlib import Path
-from typing import Any, ClassVar, Mapping
+from typing import TYPE_CHECKING, Any, ClassVar, Mapping
 from typing_extensions import override
 
 import numpy as np
 import torch
 from torch.utils.tensorboard.summary import hparams
 from torch.utils.tensorboard.writer import SummaryWriter
 
-from neps.runtime import get_in_progress_trial, get_workers_neps_state
+from neps.runtime import (
+    get_in_progress_trial,
+    get_workers_neps_state,
+    register_notify_trial_end,
+)
 from neps.status.status import get_summary_dict
 from neps.utils.common import get_initial_directory
 
+if TYPE_CHECKING:
+    from neps.state.trial import Trial
+
 
 class SummaryWriter_(SummaryWriter):  # noqa: N801
     """This class inherits from the base SummaryWriter class and provides

@@ -87,6 +94,8 @@ def _initiate_internal_configurations() -> None:
         trial = get_in_progress_trial()
         neps_state = get_workers_neps_state()
 
+        register_notify_trial_end("NEPS_TBLOGGER", tblogger.end_of_config)
+
         # We are assuming that neps state is all filebased here
         root_dir = Path(neps_state.location)
         assert root_dir.exists()

@@ -97,20 +106,20 @@ def _initiate_internal_configurations() -> None:
             if trial.metadata.previous_trial_location is not None
             else None
         )
-        tblogger.config_id = trial.metadata.id
         tblogger.optimizer_dir = root_dir
         tblogger.config = trial.config
 
     @staticmethod
     def _is_initialized() -> bool:
         # Returns 'True' if config_writer is already initialized. 'False' otherwise
         return tblogger.config_writer is not None
 
     @staticmethod
     def _initialize_writers() -> None:
         # This code runs only once per config, to assign that config a config_writer.
         if (
             tblogger.config_previous_directory is None
-            and tblogger.config_working_directory
+            and tblogger.config_working_directory is not None
         ):
             # If no fidelities are there yet, define the writer via the config_id
             tblogger.config_id = str(tblogger.config_working_directory).rsplit(

@@ -120,8 +129,9 @@ def _initialize_writers() -> None:
                 tblogger.config_working_directory / "tbevents"
             )
             return
+
         # Searching for the initial directory where tensorboard events are stored.
-        if tblogger.config_working_directory:
+        if tblogger.config_working_directory is not None:
             init_dir = get_initial_directory(
                 pipeline_directory=tblogger.config_working_directory
             )

@@ -135,7 +145,7 @@ def _initialize_writers() -> None:
         )
 
     @staticmethod
-    def end_of_config() -> None:
+    def end_of_config(trial: Trial) -> None:  # noqa: ARG004
         """Closes the writer."""
         if tblogger.config_writer:
             # Close and reset previous config writers for consistent logging.
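Why the logger needs this hook: tblogger keeps one class-level SummaryWriter per configuration, and as far as the diff shows, end_of_config previously had no caller wired to trial completion, so the last writer was never reliably closed. The stand-in below (a hypothetical MiniLogger, not the real class) condenses that lifecycle; the actual fix registers tblogger.end_of_config through the new register_notify_trial_end hook.

from __future__ import annotations

from pathlib import Path

from torch.utils.tensorboard.writer import SummaryWriter


class MiniLogger:
    """Hypothetical stand-in for tblogger: one shared writer per config."""

    config_writer: SummaryWriter | None = None

    @classmethod
    def initialize(cls, working_dir: Path) -> None:
        # Create the writer lazily for the current config, roughly what
        # _initialize_writers does in the no-fidelity case.
        if cls.config_writer is None:
            cls.config_writer = SummaryWriter(log_dir=str(working_dir / "tbevents"))

    @classmethod
    def end_of_config(cls, trial: object) -> None:
        # `trial` is accepted but unused (hence the `noqa: ARG004` in the
        # diff): every trial-end callback must fit Callable[[Trial], None].
        if cls.config_writer is not None:
            cls.config_writer.close()  # flush pending events to disk
            cls.config_writer = None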
neps/runtime.py (9 additions & 0 deletions)

@@ -93,6 +93,13 @@ def get_in_progress_trial() -> Trial:
     return _CURRENTLY_RUNNING_TRIAL_IN_PROCESS
 
 
+_TRIAL_END_CALLBACKS: dict[str, Callable[[Trial], None]] = {}
+
+
+def register_notify_trial_end(key: str, callback: Callable[[Trial], None]) -> None:
+    _TRIAL_END_CALLBACKS[key] = callback
+
+
 @contextmanager
 def _set_global_trial(trial: Trial) -> Iterator[None]:
     global _CURRENTLY_RUNNING_TRIAL_IN_PROCESS  # noqa: PLW0603

@@ -107,6 +114,8 @@ def _set_global_trial(trial: Trial) -> Iterator[None]:
         )
     _CURRENTLY_RUNNING_TRIAL_IN_PROCESS = trial
     yield
+    for _key, callback in _TRIAL_END_CALLBACKS.items():
+        callback(trial)
     _CURRENTLY_RUNNING_TRIAL_IN_PROCESS = None
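The runtime side is a plain keyed callback registry fired by the trial context manager. Below is a self-contained sketch of the same pattern, with a minimal Trial stand-in in place of neps.state.trial.Trial:

from __future__ import annotations

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Iterator


@dataclass
class Trial:
    """Minimal stand-in for neps.state.trial.Trial."""

    id: str


# Keyed by name, so re-registering under the same key replaces the old
# callback instead of accumulating duplicates across evaluations.
_TRIAL_END_CALLBACKS: dict[str, Callable[[Trial], None]] = {}


def register_notify_trial_end(key: str, callback: Callable[[Trial], None]) -> None:
    _TRIAL_END_CALLBACKS[key] = callback


@contextmanager
def set_global_trial(trial: Trial) -> Iterator[None]:
    # Simplified analogue of _set_global_trial: run the trial body, then
    # notify every registered listener that the trial has ended.
    yield
    for callback in _TRIAL_END_CALLBACKS.values():
        callback(trial)


if __name__ == "__main__":
    register_notify_trial_end("NEPS_TBLOGGER", lambda t: print(f"end of {t.id}"))
    with set_global_trial(Trial(id="config_1")):
        pass  # evaluate the configuration here
    # prints: end of config_1

One design consequence, at least in the lines shown: the callbacks run after yield without a try/finally, so they are skipped if the evaluation raises.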
neps/utils/common.py (23 additions & 14 deletions)

@@ -3,11 +3,12 @@
 from __future__ import annotations
 
 import inspect
+import random
 from functools import partial
-import numpy as np
 from pathlib import Path
-import random
 from typing import Any, Iterable, Mapping, Sequence
 
+import numpy as np
 import torch
 import yaml

@@ -140,6 +141,9 @@ def load_lightning_checkpoint(
     return checkpoint_path, checkpoint
 
 
+_INTIAL_DIRECTORY_CACHE: dict[str, Path] = {}
+
+
 # TODO: We should have a better way to have a shared folder between trials.
 # Right now, the fidelity lineage is linear, however this will be a difficulty
 # when/if we have a tree structure.

@@ -155,13 +159,15 @@ def get_initial_directory(pipeline_directory: Path | str | None = None) -> Path:
     """
     neps_state = get_workers_neps_state()
     if pipeline_directory is not None:
-        pipeline_directory = Path(pipeline_directory)
         # TODO: Hard coded assumption
-        config_id = pipeline_directory.name.split("_", maxsplit=1)[-1]
+        config_id = Path(pipeline_directory).name.split("_", maxsplit=1)[-1]
         trial = neps_state.get_trial_by_id(config_id)
     else:
         trial = get_in_progress_trial()
 
+    if trial.metadata.id in _INTIAL_DIRECTORY_CACHE:
+        return _INTIAL_DIRECTORY_CACHE[trial.metadata.id]
+
     # Recursively find the initial directory
     while (prev_trial_id := trial.metadata.previous_trial_id) is not None:
         trial = neps_state.get_trial_by_id(prev_trial_id)

@@ -170,7 +176,10 @@ def get_initial_directory(pipeline_directory: Path | str | None = None) -> Path:
     # TODO: Hard coded assumption that we are operating in a filebased neps
     assert isinstance(initial_dir, str)
-    return Path(initial_dir)
+    path = Path(initial_dir)
+
+    _INTIAL_DIRECTORY_CACHE[trial.metadata.id] = path
+    return path
 
 
 def get_searcher_data(

@@ -396,13 +405,13 @@ def __init__(self, *args, **kwargs):
 class DataWriter:
-    """
-    A class to specify how to save/write a data to the folder by
-    implementing your own write_data function.
-    Use the set_attributes function to set all your necessary attributes and the data
-    and then write_data will be called with only the directory path as argument
-    during the write process
-    """
+    """A class to specify how to save/write a data to the folder by
+    implementing your own write_data function.
+    Use the set_attributes function to set all your necessary attributes and the data
+    and then write_data will be called with only the directory path as argument
+    during the write process.
+    """
 
     def __init__(self, name: str):
         self.name = name

@@ -415,10 +424,10 @@ def write_data(self, to_directory: Path):
 class EvaluationData:
-    """
-    A class to store some data for a single evaluation (configuration)
-    and write that data to its corresponding config folder
-    """
+    """A class to store some data for a single evaluation (configuration)
+    and write that data to its corresponding config folder.
+    """
 
     def __init__(self):
         self.data_dict: dict[str, DataWriter] = {}
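The other functional change here memoizes get_initial_directory, which otherwise walks the fidelity lineage through trial metadata on every call. A standalone sketch of that caching idea, with a made-up in-memory lineage (the cfg_* ids and paths are illustrative, not neps state; the real module spells the cache _INTIAL_DIRECTORY_CACHE):

from __future__ import annotations

from pathlib import Path

# Illustrative linear lineage: each config points at its previous trial,
# mirroring trial.metadata.previous_trial_id in the diff.
_PREVIOUS: dict[str, str | None] = {"cfg_3": "cfg_2", "cfg_2": "cfg_1", "cfg_1": None}
_LOCATION: dict[str, str] = {"cfg_1": "results/config_1"}

_INITIAL_DIRECTORY_CACHE: dict[str, Path] = {}


def get_initial_directory(config_id: str) -> Path:
    """Walk back to the first trial in the lineage, memoized per config id."""
    if config_id in _INITIAL_DIRECTORY_CACHE:
        return _INITIAL_DIRECTORY_CACHE[config_id]

    current = config_id
    while (prev := _PREVIOUS[current]) is not None:  # same walrus walk as the diff
        current = prev

    path = Path(_LOCATION[current])
    # Like the diff, only the id that was asked about is cached; trials
    # visited along the walk are not.
    _INITIAL_DIRECTORY_CACHE[config_id] = path
    return path


print(get_initial_directory("cfg_3"))  # results/config_1 (cached afterwards)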
