diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a2c619f..16bb14d 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -4,3 +4,6 @@ /.github/ @theissenhelen @jesperdramsch @gmertes /.pre-commit-config.yaml @theissenhelen @jesperdramsch @gmertes /pyproject.toml @theissenhelen @jesperdramsch @gmertes + +# Protect package exemptions +/src/anemoi/inference/checkpoint/package_exemptions.py @gmertes @hcookie @theissenhelen @jesperdramsch diff --git a/CHANGELOG.md b/CHANGELOG.md index f6e8cd3..32a33e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ Keep it human-readable, your future self will thank you! ### Added - ci: changelog release updater - earthkit-data replaces climetlab +- `validate_environment` on Checkpoint [#13](https://github.com/ecmwf/anemoi-inference/pull/13) +- Validate the environment against a checkpoint with `anemoi-inference inspect --validate path.ckpt` - ci-hpc-config - Add Condition to store data [#15](https://github.com/ecmwf/anemoi-inference/pull/15) diff --git a/pyproject.toml b/pyproject.toml index daabf40..78374ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dependencies = [ "anytree", "earthkit-data>=0.10", "numpy", + "packaging", "pyyaml", "semantic-version", "torch", diff --git a/src/anemoi/inference/checkpoint/__init__.py b/src/anemoi/inference/checkpoint/__init__.py index b4efbf1..f1e54cf 100644 --- a/src/anemoi/inference/checkpoint/__init__.py +++ b/src/anemoi/inference/checkpoint/__init__.py @@ -5,16 +5,22 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
def validate_environment(
    self,
    all_packages: bool = False,
    on_difference: Literal["warn", "error", "ignore"] = "warn",
    *,
    exempt_packages: list[str] | None = None,
) -> bool:
    """
    Validate environment of the checkpoint against the current environment.

    The training environment is read from ``self.provenance_training`` and
    compared with the provenance gathered for the running interpreter:
    Python version, module versions and git records with uncommitted changes.

    Parameters
    ----------
    all_packages : bool, optional
        Check all packages in environment or just `anemoi`'s, by default False
    on_difference : Literal['warn', 'error', 'ignore'], optional
        What to do on difference, by default "warn"
    exempt_packages : list[str], optional
        Extra packages to exempt from the check, in addition to
        EXEMPT_PACKAGES. The list passed in is not modified.

    Returns
    -------
    bool
        True if environment is valid, False otherwise

    Raises
    ------
    RuntimeError
        If found difference and `on_difference` is 'error'
    ValueError
        If `on_difference` is not 'warn', 'error' or 'ignore'
    """
    import importlib
    import importlib.metadata as imp_metadata

    # Reject a bad mode before doing any work instead of only when a difference is found.
    if on_difference not in ("warn", "error", "ignore"):
        raise ValueError(f"Invalid value for `on_difference`: {on_difference}")

    train_environment = self.provenance_training
    inference_environment = gather_provenance_info(full=False)

    # Override module information with more complete inference environment capture
    inference_modules = {}
    for distribution in imp_metadata.distributions():
        name = distribution.metadata["Name"]
        if name:  # skip distributions whose metadata carries no name
            inference_modules[name.replace("-", "_")] = distribution.metadata["Version"]
    inference_environment["module_versions"] = inference_modules

    # Build a private set so the caller's list is never mutated.
    exempt = set(EXEMPT_PACKAGES)
    if exempt_packages:
        exempt.update(exempt_packages)

    invalid_messages = {
        "python": [],
        "missing": [],
        "mismatch": [],
        "critical mismatch": [],
        "uncommitted": [],
    }

    # `.get` is used on the provenance dicts because the checkpoint may carry
    # no training provenance at all (provenance_training defaults to {}).
    train_python = train_environment.get("python")
    inference_python = inference_environment.get("python")
    if train_python != inference_python:
        invalid_messages["python"].append(f"Python version mismatch: {train_python} != {inference_python}")

    for module, train_version_raw in train_environment.get("module_versions", {}).items():
        if not all_packages and "anemoi" not in module:
            continue
        if module in exempt or module.split(".")[0] in EXEMPT_NAMESPACES:
            continue
        if module.startswith("_"):
            continue

        # Due to package name differences between retrieval methods this may change
        inference_module_name = module
        if module not in inference_modules:
            underscored = module.replace(".", "_")
            if "." in module and underscored in inference_modules:
                inference_module_name = underscored
            else:
                # Not listed as a distribution: accept it if it is importable,
                # otherwise report it as missing.
                try:
                    importlib.import_module(module)
                except ImportError:
                    invalid_messages["missing"].append(f"Missing module in inference environment: {module}")
                continue

        inference_version_raw = inference_modules[inference_module_name]
        try:
            train_version = Version(train_version_raw)
            inference_version = Version(inference_version_raw)
        except ValueError:
            # packaging's InvalidVersion subclasses ValueError; versions that
            # cannot be parsed can only be compared as raw strings.
            if str(train_version_raw) != str(inference_version_raw):
                invalid_messages["mismatch"].append(
                    f"Version of module {module} could not be parsed for comparison: {train_version_raw} != {inference_version_raw}"
                )
            continue

        if train_version < inference_version:
            invalid_messages["mismatch"].append(
                f"Version of module {module} was lower in training then in inference: {train_version!s} < {inference_version!s}"
            )
        elif train_version > inference_version:
            invalid_messages["critical mismatch"].append(
                f"CRITICAL: Version of module {module} was greater in training then in inference: {train_version!s} > {inference_version!s}"
            )

    train_git = train_environment.get("git_versions", {})
    inference_git = inference_environment.get("git_versions", {})

    def _sha1(record):
        # Keep reading a top-level "sha1" when present; otherwise fall back to
        # record["git"]["sha1"], next to the modified/untracked counters.
        # NOTE(review): confirm which layout gather_provenance_info emits.
        if "sha1" in record:
            return record["sha1"]
        return record.get("git", {}).get("sha1")

    for git_record, train_record in train_git.items():
        file_record = train_record["git"]
        if file_record["modified_files"] == 0 and file_record["untracked_files"] == 0:
            continue  # clean checkout at training time, nothing to compare

        if git_record not in inference_git:
            invalid_messages["uncommitted"].append(
                f"Training environment contained uncommitted change missing in inference environment: {git_record}"
            )
        elif _sha1(train_record) != _sha1(inference_git[git_record]):
            invalid_messages["uncommitted"].append(
                f"sha1 mismatch for git record between training and inference. {git_record} (training != inference): {train_record} != {inference_git[git_record]}"
            )

    for git_record, inference_record in inference_git.items():
        file_record = inference_record["git"]
        if file_record["modified_files"] == 0 and file_record["untracked_files"] == 0:
            continue

        if git_record not in train_git:
            invalid_messages["uncommitted"].append(
                f"Inference environment contains uncommited changes missing in training: {git_record}"
            )

    # The dict always has its five keys, so test the message lists, not the dict length.
    if any(invalid_messages.values()):
        text = "Environment validation failed. The following issues were found:\n" + "\n".join(
            [f" {key}:\n " + "\n ".join(value) for key, value in invalid_messages.items() if len(value) > 0]
        )
        if on_difference == "warn":
            LOG.warning(text)
        elif on_difference == "error":
            raise RuntimeError(text)
        # "ignore": report the result through the return value only
        return False

    LOG.info("Environment validation passed")
    return True
# Packages and namespaces that environment validation does not compare
# between the training and the inference environment.

# Exact package names to skip.
EXEMPT_PACKAGES = [
    "anemoi.training",
    "hydra",
    "hydra_plugins",
    "lightning",
    "pytorch_lightning",
    "lightning_fabric",
    "lightning_utilities",
]

# Top-level namespaces to skip entirely (matched on the part before the first ".").
EXEMPT_NAMESPACES = ["hydra_plugins"]