Fix/mlflow sync tag #83

Merged: 45 commits, Oct 21, 2024
Changes from all commits

Commits (45)
d6fa830
feat: authentication support for mlflow sync
gmertes Sep 9, 2024
4606e76
chore: formatting
gmertes Sep 9, 2024
620f3a0
chore: changelog
gmertes Sep 9, 2024
2dd693b
Merge remote-tracking branch 'origin/develop' into feat/mlflow-sync-auth
JesperDramsch Sep 23, 2024
debf2ab
chore: changelog add link
gmertes Sep 23, 2024
a70e22c
Merge branch 'develop' into feat/mlflow-sync-auth
gmertes Oct 1, 2024
60d0dcc
fix: mlflow auth use web seed token
gmertes Oct 9, 2024
e5ab3ec
feat: make target env var an optional argument
gmertes Oct 9, 2024
93ce7a8
chore: docstrings
gmertes Oct 9, 2024
391a494
fix: tests
gmertes Oct 9, 2024
2f7f1c2
chore: add comment
gmertes Oct 9, 2024
498c6de
chore: changelog
gmertes Oct 9, 2024
b54c475
fix: sync authentication flag
gmertes Oct 9, 2024
4693972
refactor: move `health_check` to submodule top level
gmertes Oct 9, 2024
6fed2e7
feat: add health check
gmertes Oct 9, 2024
34a3f37
chore: update error msg
gmertes Oct 9, 2024
0116027
refactor: mlflow utils
gmertes Oct 10, 2024
6b97a56
chore: docstring
gmertes Oct 10, 2024
9b3088c
Merge branch 'fix/mlflow-auth-api' into fix/mlflow_sync_tag
anaprietonem Oct 10, 2024
81ee192
tag and catch for new experiments
anaprietonem Oct 11, 2024
caae529
Merge branch 'develop' into fix/mlflow_sync_tag
anaprietonem Oct 14, 2024
a69872e
update changelog
anaprietonem Oct 14, 2024
d1b4ee3
download artifacts if server2server
anaprietonem Oct 14, 2024
ee6f3f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2024
a7b7462
correct changelog typo
anaprietonem Oct 14, 2024
03bf8c8
Merge branch 'fix/mlflow_sync_tag' of github.com:ecmwf/anemoi-trainin…
anaprietonem Oct 14, 2024
13a3436
remove artifact path to clean things up
anaprietonem Oct 14, 2024
10052fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2024
6901f7f
test env variables for server2server
anaprietonem Oct 17, 2024
9a31392
Merge branch 'develop' into feat/mlflow-sync-auth-copy2
anaprietonem Oct 17, 2024
5654211
fixing code style according to pre-commit hooks
anaprietonem Oct 18, 2024
dcb7b27
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 18, 2024
e09d1a7
update changelog
anaprietonem Oct 18, 2024
9407b97
Merge branch 'fix/mlflow_sync_tag' of github.com:ecmwf/anemoi-trainin…
anaprietonem Oct 18, 2024
f830788
fix missing return
anaprietonem Oct 18, 2024
48472d2
move to correct section
anaprietonem Oct 18, 2024
af3f315
better catching of cleaning tempdirectory
anaprietonem Oct 18, 2024
e7b8a08
remove sys.exit
anaprietonem Oct 18, 2024
e4a5702
implementing suggestions to reduce tech debt
anaprietonem Oct 18, 2024
9ef468e
removing tech debt
anaprietonem Oct 18, 2024
fcb247b
add logging message
anaprietonem Oct 18, 2024
7d0b2f6
refactor and remove env variables
anaprietonem Oct 20, 2024
f8f8bc6
fixes according to PR feedback to remove long logs and authentication…
anaprietonem Oct 21, 2024
e6f5af2
Merge branch 'develop' into fix/mlflow_sync_tag
anaprietonem Oct 21, 2024
8082359
update changelog
anaprietonem Oct 21, 2024
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -11,13 +11,17 @@ Keep it human-readable, your future self will thank you!
## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.2.0...HEAD)

### Added

- Mlflow-sync to include new tag for server to server syncing [#83] (https://github.com/ecmwf/anemoi-training/pull/83)
- Mlflow-sync to include functionality to resume and fork server2server runs [#83] (https://github.com/ecmwf/anemoi-training/pull/83)
- Rollout training for Limited Area Models. [#79](https://github.com/ecmwf/anemoi-training/pulls/79)
- Feature: New `Boolean1DMask` class. Enables rollout training for limited area models. [#79](https://github.com/ecmwf/anemoi-training/pulls/79)

### Fixed
- Mlflow-sync to handle creation of new experiments in the remote server [#83] (https://github.com/ecmwf/anemoi-training/pull/83)

### Changed


## [0.2.0 - Feature release](https://github.com/ecmwf/anemoi-training/compare/0.1.0...0.2.0) - 2024-10-16

- Make pin_memory of the Dataloader configurable (#64)
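For orientation, the changelog entries above revolve around a handful of config keys that also appear in the diffs below. A minimal, purely illustrative OmegaConf sketch (all values are made up, not project defaults) would be:

```python
# Illustrative sketch only: these keys mirror the config paths referenced in the diffs below.
from omegaconf import OmegaConf

config = OmegaConf.create(
    {
        "training": {
            "run_id": "abc123",   # resume this run; server2server lineage is resolved by the logger
            "fork_run_id": None,  # or set this instead to fork from a parent run
        },
        "diagnostics": {
            "log": {
                "mlflow": {
                    "enabled": True,
                    "offline": False,
                    "authentication": True,
                    "on_resume_create_child": True,
                },
            },
        },
    }
)
```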
11 changes: 6 additions & 5 deletions src/anemoi/training/diagnostics/logger.py
@@ -25,7 +25,6 @@ def get_mlflow_logger(config: DictConfig) -> None:
return None

from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger
from anemoi.training.diagnostics.mlflow.logger import get_mlflow_run_params

resumed = config.training.run_id is not None
forked = config.training.fork_run_id is not None
@@ -39,7 +38,6 @@ def get_mlflow_logger(config: DictConfig) -> None:
tracking_uri = save_dir
# create directory if it does not exist
Path(config.hardware.paths.logs.mlflow).mkdir(parents=True, exist_ok=True)
run_id, run_name, tags = get_mlflow_run_params(config, tracking_uri)

log_hyperparams = True
if resumed and not config.diagnostics.log.mlflow.on_resume_create_child:
@@ -53,19 +51,22 @@ def get_mlflow_logger(config: DictConfig) -> None:
)
log_hyperparams = False

LOGGER.info("AnemoiMLFlow logging to %s", tracking_uri)
logger = AnemoiMLflowLogger(
experiment_name=config.diagnostics.log.mlflow.experiment_name,
project_name=config.diagnostics.log.mlflow.project_name,
tracking_uri=tracking_uri,
save_dir=save_dir,
run_name=run_name,
run_id=run_id,
run_name=config.diagnostics.log.mlflow.run_name,
run_id=config.training.run_id,
fork_run_id=config.training.fork_run_id,
log_model=config.diagnostics.log.mlflow.log_model,
offline=offline,
tags=tags,
resumed=resumed,
forked=forked,
log_hyperparams=log_hyperparams,
authentication=config.diagnostics.log.mlflow.authentication,
on_resume_create_child=config.diagnostics.log.mlflow.on_resume_create_child,
)
config_params = OmegaConf.to_container(config, resolve=True)

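To summarise the change above: `get_mlflow_logger` no longer calls `get_mlflow_run_params` itself; the run id, run name and tags are now resolved inside the logger. A minimal sketch of the resulting call, using placeholder values rather than project defaults, might look like this:

```python
# Sketch only: mirrors the AnemoiMLflowLogger call in the diff above; all values are placeholders.
from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger

logger = AnemoiMLflowLogger(
    experiment_name="lightning_logs",
    project_name="anemoi",
    tracking_uri="https://mlflow.example",  # hypothetical tracking server
    save_dir="./mlruns",
    run_name=None,        # a UUID-based name is generated when none is given
    run_id=None,          # set to an existing run id to resume
    fork_run_id=None,     # set to a parent run id to fork
    log_model=False,
    offline=False,
    resumed=False,
    forked=False,
    log_hyperparams=True,
    authentication=True,  # token auth is only attempted when not offline
    on_resume_create_child=True,
)
```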
152 changes: 101 additions & 51 deletions src/anemoi/training/diagnostics/mlflow/logger.py
@@ -33,55 +33,11 @@
if TYPE_CHECKING:
from argparse import Namespace

from omegaconf import OmegaConf
import mlflow

LOGGER = logging.getLogger(__name__)


def get_mlflow_run_params(config: OmegaConf, tracking_uri: str) -> tuple[str | None, str, dict[str, Any]]:
run_id = None
tags = {"projectName": config.diagnostics.log.mlflow.project_name}
# create a tag with the command used to run the script
command = os.environ.get("ANEMOI_TRAINING_CMD", sys.argv[0])
tags["command"] = command.split("/")[-1] # get the python script name
tags["mlflow.source.name"] = command
if len(sys.argv) > 1:
# add the arguments to the command tag
tags["command"] = tags["command"] + " " + " ".join(sys.argv[1:])
if config.training.run_id or config.training.fork_run_id:
"Either run_id or fork_run_id must be provided to resume a run."

import mlflow

if config.diagnostics.log.mlflow.authentication and not config.diagnostics.log.mlflow.offline:
TokenAuth(tracking_uri).authenticate()

mlflow_client = mlflow.MlflowClient(tracking_uri)

if config.training.run_id and config.diagnostics.log.mlflow.on_resume_create_child:
parent_run_id = config.training.run_id # parent_run_id
run_name = mlflow_client.get_run(parent_run_id).info.run_name
tags["mlflow.parentRunId"] = parent_run_id
tags["resumedRun"] = "True" # tags can't take boolean values
elif config.training.run_id and not config.diagnostics.log.mlflow.on_resume_create_child:
run_id = config.training.run_id
run_name = mlflow_client.get_run(run_id).info.run_name
mlflow_client.update_run(run_id=run_id, status="RUNNING")
tags["resumedRun"] = "True"
else:
parent_run_id = config.training.fork_run_id
tags["forkedRun"] = "True"
tags["forkedRunId"] = parent_run_id

if config.diagnostics.log.mlflow.run_name:
run_name = config.diagnostics.log.mlflow.run_name
else:
import uuid

run_name = f"{uuid.uuid4()!s}"
return run_id, run_name, tags


class LogsMonitor:
"""Class for logging terminal output.

@@ -90,7 +46,7 @@ class LogsMonitor:

Note: If there is an error, the terminal output logging ends before the error message is printed into the log file.
In order for the user to see the error message, the user must look at the slurm output file.
We provide the SLRM job id in the very beginning of the log file and print the final status of the run in the end.
We provide the SLURM job id in the very beginning of the log file and print the final status of the run in the end.

Parameters
----------
@@ -191,7 +147,7 @@ def start(self) -> None:
self._buffer_registry[id(self)] = self._io_buffer
# Start thread to asynchronously collect logs
self._th_collector.start()
LOGGER.info("Termial Log Path: %s", self.file_save_path)
LOGGER.info("Terminal Log Path: %s", self.file_save_path)
if os.getenv("SLURM_JOB_ID"):
LOGGER.info("SLURM job id: %s", os.getenv("SLURM_JOB_ID"))

@@ -288,31 +244,33 @@ class AnemoiMLflowLogger(MLFlowLogger):
def __init__(
self,
experiment_name: str = "lightning_logs",
project_name: str = "anemoi",
run_name: str | None = None,
tracking_uri: str | None = os.getenv("MLFLOW_TRACKING_URI"),
tags: dict[str, Any] | None = None,
save_dir: str | None = "./mlruns",
log_model: Literal[True, False, "all"] = False,
prefix: str = "",
resumed: bool | None = False,
forked: bool | None = False,
run_id: str | None = None,
fork_run_id: str | None = None,
offline: bool | None = False,
authentication: bool | None = None,
log_hyperparams: bool | None = True,
on_resume_create_child: bool | None = True,
) -> None:
"""Initialize the AnemoiMLflowLogger.

Parameters
----------
experiment_name : str, optional
Name of experiment, by default "lightning_logs"
project_name : str, optional
Name of the project, by default "anemoi"
run_name : str | None, optional
Name of run, by default None
tracking_uri : str | None, optional
Tracking URI of server, by default os.getenv("MLFLOW_TRACKING_URI")
tags : dict[str, Any] | None, optional
Tags to apply, by default None
save_dir : str | None, optional
Directory to save logs to, by default "./mlruns"
log_model : Literal[True, False, "all"], optional
@@ -325,13 +283,16 @@ def __init__(
Whether the run was forked or not, by default False
run_id : str | None, optional
Run id of current run, by default None
fork_run_id : str | None, optional
Fork Run id from parent run, by default None
offline : bool | None, optional
Whether to run offline or not, by default False
authentication : bool | None, optional
Whether to authenticate with server or not, by default None
log_hyperparams : bool | None, optional
Whether to log hyperparameters, by default True

on_resume_create_child: bool | None, optional
Whether to create a child run when resuming a run, by default False
"""
if offline:
# OFFLINE - When we run offline we can pass a save_dir pointing to a local path
@@ -346,6 +307,9 @@ def __init__(
self._forked = forked
self._flag_log_hparams = log_hyperparams

self._fork_run_server2server = None
self._parent_run_server2server = None

if rank_zero_only.rank == 0:
enabled = authentication and not offline
self.auth = TokenAuth(tracking_uri, enabled=enabled)
@@ -357,6 +321,15 @@ def __init__(
self.auth.authenticate()
health_check(tracking_uri)

run_id, run_name, tags = self._get_mlflow_run_params(
project_name=project_name,
run_name=run_name,
config_run_id=run_id,
fork_run_id=fork_run_id,
tracking_uri=tracking_uri,
on_resume_create_child=on_resume_create_child,
)

super().__init__(
experiment_name=experiment_name,
run_name=run_name,
@@ -368,6 +341,83 @@ def __init__(
run_id=run_id,
)

def _check_server2server_lineage(self, run: mlflow.entities.Run) -> bool:
"""Address lineage and metadata for server2server runs.

Those are runs that have been sync from one remote server to another
"""
server2server = run.data.tags.get("server2server", "False") == "True"
LOGGER.info("Server2Server: %s", server2server)
if server2server:
parent_run_across_servers = run.data.params.get(
"metadata.offline_run_id",
run.data.params.get("metadata.server2server_run_id"),
)
if self._forked:
# if we want to fork a resume run we need to set the parent_run_across_servers
# but just to restore the checkpoint
self._fork_run_server2server = parent_run_across_servers
else:
self._parent_run_server2server = parent_run_across_servers

def _get_mlflow_run_params(
self,
project_name: str,
run_name: str,
config_run_id: str,
fork_run_id: str,
tracking_uri: str,
on_resume_create_child: bool,
) -> tuple[str | None, str, dict[str, Any]]:

run_id = None
tags = {"projectName": project_name}

# create a tag with the command used to run the script
command = os.environ.get("ANEMOI_TRAINING_CMD", sys.argv[0])
tags["command"] = command.split("/")[-1] # get the python script name
tags["mlflow.source.name"] = command
if len(sys.argv) > 1:
# add the arguments to the command tag
tags["command"] = tags["command"] + " " + " ".join(sys.argv[1:])

if config_run_id or fork_run_id:
"Either run_id or fork_run_id must be provided to resume a run."
import mlflow

mlflow_client = mlflow.MlflowClient(tracking_uri)

if config_run_id and on_resume_create_child:
parent_run_id = config_run_id # parent_run_id
parent_run = mlflow_client.get_run(parent_run_id)
run_name = parent_run.info.run_name
self._check_server2server_lineage(parent_run)
tags["mlflow.parentRunId"] = parent_run_id
tags["resumedRun"] = "True" # tags can't take boolean values
elif config_run_id and not on_resume_create_child:
run_id = config_run_id
run = mlflow_client.get_run(run_id)
run_name = run.info.run_name
self._check_server2server_lineage(run)
mlflow_client.update_run(run_id=run_id, status="RUNNING")
tags["resumedRun"] = "True"
else:
parent_run_id = fork_run_id
tags["forkedRun"] = "True"
tags["forkedRunId"] = parent_run_id
run = mlflow_client.get_run(parent_run_id)
self._check_server2server_lineage(run)

if not run_name:
import uuid

run_name = f"{uuid.uuid4()!s}"

if os.getenv("SLURM_JOB_ID"):
tags["SLURM_JOB_ID"] = os.getenv("SLURM_JOB_ID")

return run_id, run_name, tags

@property
def experiment(self) -> MLFlowLogger.experiment:
if rank_zero_only.rank == 0:
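For readability, here is a hedged illustration of the tag dictionaries the new `_get_mlflow_run_params` produces in its two main branches (run ids, paths and arguments are invented for the example):

```python
# Illustrative only: mirrors the branches of _get_mlflow_run_params above.
# Assume the process was launched as "/path/to/train.py --config=debug" and
# ANEMOI_TRAINING_CMD is unset, so the command falls back to sys.argv[0].

# run_id set and on_resume_create_child=True -> a child run is created under the parent
tags_resume_as_child = {
    "projectName": "anemoi",
    "command": "train.py --config=debug",       # script name plus CLI arguments
    "mlflow.source.name": "/path/to/train.py",
    "mlflow.parentRunId": "abc123",
    "resumedRun": "True",                       # tags can't take boolean values
}

# fork_run_id set -> a brand-new run that records its lineage
tags_forked = {
    "projectName": "anemoi",
    "command": "train.py --config=debug",
    "mlflow.source.name": "/path/to/train.py",
    "forkedRun": "True",
    "forkedRunId": "abc123",
}
# In both cases a SLURM_JOB_ID tag is added when that environment variable is present.
```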
42 changes: 32 additions & 10 deletions src/anemoi/training/train/train.py
@@ -68,6 +68,10 @@ def __init__(self, config: DictConfig) -> None:

self.config.training.run_id = self.run_id
LOGGER.info("Run id: %s", self.config.training.run_id)

# Get the server2server lineage
self._get_server2server_lineage()

# Update paths to contain the run ID
self._update_paths()

@@ -147,7 +151,7 @@ def model(self) -> GraphForecaster:
@rank_zero_only
def _get_mlflow_run_id(self) -> str:
run_id = self.mlflow_logger.run_id
# for resumed runs or offline runs logging this can be uesful
# for resumed runs or offline runs logging this can be useful
LOGGER.info("Mlflow Run id: %s", run_id)
return run_id

@@ -188,19 +192,18 @@ def last_checkpoint(self) -> str | None:
if not self.start_from_checkpoint:
return None

fork_id = self.fork_run_server2server or self.config.training.fork_run_id
checkpoint = Path(
self.config.hardware.paths.checkpoints.parent,
self.config.training.fork_run_id or self.run_id,
fork_id or self.lineage_run,
self.config.hardware.files.warm_start or "last.ckpt",
)

# Check if the last checkpoint exists
if Path(checkpoint).exists():
LOGGER.info("Resuming training from last checkpoint: %s", checkpoint)
return checkpoint

LOGGER.warning("Could not find last checkpoint: %s", checkpoint)
return None
msg = "Could not find last checkpoint: %s", checkpoint
raise RuntimeError(msg)

@cached_property
def callbacks(self) -> list[pl.callbacks.Callback]:
@@ -252,10 +255,13 @@ def profiler(self) -> PyTorchProfiler | None:
def loggers(self) -> list:
loggers = []
if self.config.diagnostics.log.wandb.enabled:
LOGGER.info("W&B logger enabled")
loggers.append(self.wandb_logger)
if self.config.diagnostics.log.tensorboard.enabled:
LOGGER.info("TensorBoard logger enabled")
loggers.append(self.tensorboard_logger)
if self.config.diagnostics.log.mlflow.enabled:
LOGGER.info("MLFlow logger enabled")
loggers.append(self.mlflow_logger)
return loggers

@@ -291,17 +297,33 @@ def _log_information(self) -> None:
LOGGER.debug("Effective learning rate: %.3e", total_number_of_model_instances * self.config.training.lr.rate)
LOGGER.debug("Rollout window length: %d", self.config.training.rollout.start)

def _get_server2server_lineage(self) -> None:
"""Get the server2server lineage."""
self.parent_run_server2server = None
self.fork_run_server2server = None
if self.config.diagnostics.log.mlflow.enabled:
self.parent_run_server2server = self.mlflow_logger._parent_run_server2server
LOGGER.info("Parent run server2server: %s", self.parent_run_server2server)
self.fork_run_server2server = self.mlflow_logger._fork_run_server2server
LOGGER.info("Fork run server2server: %s", self.fork_run_server2server)

def _update_paths(self) -> None:
"""Update the paths in the configuration."""
self.lineage_run = None
if self.run_id: # when using mlflow only rank0 will have a run_id except when resuming runs
# Multi-gpu new runs or forked runs - only rank 0
# Multi-gpu resumed runs - all ranks
self.config.hardware.paths.checkpoints = Path(self.config.hardware.paths.checkpoints, self.run_id)
self.config.hardware.paths.plots = Path(self.config.hardware.paths.plots, self.run_id)
self.lineage_run = self.parent_run_server2server or self.run_id
self.config.hardware.paths.checkpoints = Path(self.config.hardware.paths.checkpoints, self.lineage_run)
self.config.hardware.paths.plots = Path(self.config.hardware.paths.plots, self.lineage_run)
elif self.config.training.fork_run_id:
# WHEN USING MANY NODES/GPUS
self.lineage_run = self.parent_run_server2server or self.config.training.fork_run_id
# Only rank non zero in the forked run will go here
parent_run = self.config.training.fork_run_id
self.config.hardware.paths.checkpoints = Path(self.config.hardware.paths.checkpoints, parent_run)
self.config.hardware.paths.checkpoints = Path(self.config.hardware.paths.checkpoints, self.lineage_run)

LOGGER.info("Checkpoints path: %s", self.config.hardware.paths.checkpoints)
LOGGER.info("Plots path: %s", self.config.hardware.paths.plots)

@cached_property
def strategy(self) -> DDPGroupStrategy:
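Finally, a short sketch of the path behaviour introduced in `_update_paths` above (directories and run ids are made up): when a run has a server2server parent, the checkpoint directory now follows that lineage run rather than the local run id.

```python
from pathlib import Path

# Illustrative only: mirrors the run_id branch of _update_paths above with invented values.
checkpoints_root = Path("/scratch/checkpoints")
run_id = "local-run-id"
parent_run_server2server = "remote-parent-id"  # set when the resumed run was synced from another server

lineage_run = parent_run_server2server or run_id
checkpoints_path = checkpoints_root / lineage_run
print(checkpoints_path)  # /scratch/checkpoints/remote-parent-id
```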