From 98c4bb2ff738c861c0300cbf58f71eb7c46ff4b8 Mon Sep 17 00:00:00 2001
From: Misko
Date: Thu, 5 Sep 2024 01:51:52 +0000
Subject: [PATCH 1/6] raise on gemnet oc fail to load scaling; add scaling tests for hydra oc as well

---
 .../core/models/gemnet_oc/gemnet_oc.py      |  3 ++-
 src/fairchem/core/modules/scaling/compat.py | 10 +++++--
 src/fairchem/core/modules/scaling/fit.py    |  4 ++-
 tests/core/e2e/test_s2ef.py                 | 26 ++++++++++++++-----
 4 files changed, 33 insertions(+), 10 deletions(-)

diff --git a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py
index c5e6efb005..3442bfcf21 100644
--- a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py
+++ b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py
@@ -234,6 +234,7 @@ def __init__(
         num_elements: int = 83,
         otf_graph: bool = False,
         scale_file: str | None = None,
+        raise_on_scale_load_error: bool = True,
         **kwargs,  # backwards compatibility with deprecated arguments
     ) -> None:
         if qint_tags is None:
@@ -381,7 +382,7 @@ def __init__(
         if direct_forces:
             self.out_forces.reset_parameters(out_initializer)
 
-        load_scales_compat(self, scale_file)
+        load_scales_compat(self, scale_file, on_error_raise=raise_on_scale_load_error)
 
     def set_cutoffs(self, cutoff, cutoff_qint, cutoff_aeaint, cutoff_aint):
         self.cutoff = cutoff
diff --git a/src/fairchem/core/modules/scaling/compat.py b/src/fairchem/core/modules/scaling/compat.py
index fd0a42f0b4..7060bf49ff 100644
--- a/src/fairchem/core/modules/scaling/compat.py
+++ b/src/fairchem/core/modules/scaling/compat.py
@@ -51,7 +51,9 @@ def _load_scale_dict(scale_file: str | ScaleDict | None):
     return scale_dict
 
 
-def load_scales_compat(module: nn.Module, scale_file: str | ScaleDict | None) -> None:
+def load_scales_compat(
+    module: nn.Module, scale_file: str | ScaleDict | None, on_error_raise: bool = True
+) -> None:
     scale_dict = _load_scale_dict(scale_file)
     if not scale_dict:
         return
@@ -64,11 +66,15 @@ def load_scales_compat(module: nn.Module, scale_file: str | ScaleDict | None) ->
     logging.debug(
         f"Found the following scale factors: {[(k, name) for k, (_, name) in scale_factors.items()]}"
     )
+    missing_keys = set(scale_factors.keys()) - set(scale_dict.keys())
+    if len(missing_keys) > 0 and on_error_raise:
+        raise ValueError(
+            "Failed to load scaling values. Missing entries for,", missing_keys
+        )
     for name, scale in scale_dict.items():
         if name not in scale_factors:
             logging.warning(f"Scale factor {name} not found in model")
             continue
-
         scale_module, module_name = scale_factors[name]
         logging.debug(f"Loading scale factor {scale} for ({name} => {module_name})")
         scale_module.set_(scale)
diff --git a/src/fairchem/core/modules/scaling/fit.py b/src/fairchem/core/modules/scaling/fit.py
index 4bfc4bb62a..5090804d09 100644
--- a/src/fairchem/core/modules/scaling/fit.py
+++ b/src/fairchem/core/modules/scaling/fit.py
@@ -2,6 +2,7 @@
 
 import logging
 import math
+import re
 import readline
 import sys
 from itertools import islice
@@ -56,6 +57,7 @@ def compute_scaling_factors(config, num_batches: int = 16) -> None:
     assert data_loader is not None, "Train set required to load batches"
 
     if ckpt_file.exists():
+        assert 1 == 0
         trainer.load_checkpoint(checkpoint_path=str(ckpt_file))
 
     # region reoad scale file contents if necessary
@@ -205,7 +207,7 @@ def index_fn(name: str = name) -> None:
 
     torch.save(
         {
-            x[0].replace(".scale_factor", ""): x[1]
+            re.sub("^backbone.", "", x[0].replace(".scale_factor", "")): x[1]
             for x in trainer.model.to("cpu").named_parameters()
             if ".scale_" in x[0]
         },
diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py
index 10e3203c91..a7477cd1df 100644
--- a/tests/core/e2e/test_s2ef.py
+++ b/tests/core/e2e/test_s2ef.py
@@ -106,7 +106,15 @@ def smoke_test_train(
             energy_from_train, energy_from_checkpoint, rtol=1e-6, atol=1e-6
         )
 
-    def test_gemnet_fit_scaling(self, configs, tutorial_val_src):
+    @pytest.mark.parametrize(
+        ("model_name"),
+        [
+            ("gemnet_oc"),
+            ("gemnet_oc_hydra"),
+            ("gemnet_oc_hydra_grad"),
+        ],
+    )
+    def test_gemnet_fit_scaling(self, model_name, configs, tutorial_val_src):
         with tempfile.TemporaryDirectory() as tempdirname:
             # (1) generate scaling factors for gemnet config
 
@@ -128,7 +136,7 @@ def test_gemnet_fit_scaling(self, configs, tutorial_val_src):
                 ]
             )
             update_yaml_with_dict(
-                configs["gemnet_oc"],
+                configs[model_name],
                 config_yaml,
                 update_dict_with={
                     "dataset": oc20_lmdb_train_and_val_from_paths(
@@ -141,24 +149,30 @@ def test_gemnet_fit_scaling(self, configs, tutorial_val_src):
             config = build_config(args, override_args)
 
             # (2) if existing scaling factors are present remove them
-            if "scale_file" in config["model"]:
-                config["model"].pop("scale_file")
+            config["model"].pop("scale_file", None)
+            if "backbone" in config["model"]:
+                config["model"]["backbone"].pop("scale_file", None)
 
             compute_scaling_factors(config)
 
+            model_config_change = (
+                {"backbone": {"scale_file": scaling_pt}}
+                if "backbone" in config["model"]
+                else {"scale_file": scaling_pt}
+            )
             # (3) try to run the config with the newly generated scaling factors
             _ = _run_main(
                 rundir=tempdirname,
                 update_dict_with={
                     "optim": {"max_epochs": 1},
-                    "model": {"use_pbc_single": True, "scale_file": scaling_pt},
+                    "model": model_config_change,
                     "dataset": oc20_lmdb_train_and_val_from_paths(
                         train_src=str(tutorial_val_src),
                         val_src=str(tutorial_val_src),
                         test_src=str(tutorial_val_src),
                     ),
                 },
-                input_yaml=configs["gemnet_oc"],
+                input_yaml=configs[model_name],
             )
 
             # not all models are tested with otf normalization estimation

From 9089921bed41c20c6f51916d7681d170807f7977 Mon Sep 17 00:00:00 2001
From: Misko
Date: Tue, 10 Sep 2024 17:05:06 +0000
Subject: [PATCH 2/6] remove assert

---
 src/fairchem/core/modules/scaling/fit.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/fairchem/core/modules/scaling/fit.py b/src/fairchem/core/modules/scaling/fit.py
index 5090804d09..368c23479d 100644
--- a/src/fairchem/core/modules/scaling/fit.py
+++ b/src/fairchem/core/modules/scaling/fit.py
@@ -57,7 +57,6 @@ def compute_scaling_factors(config, num_batches: int = 16) -> None:
     assert data_loader is not None, "Train set required to load batches"
 
     if ckpt_file.exists():
-        assert 1 == 0
         trainer.load_checkpoint(checkpoint_path=str(ckpt_file))
 
     # region reoad scale file contents if necessary

From 937fb5cc5f9a63a945ba41e28d0b75c5bd1bbc2b Mon Sep 17 00:00:00 2001
From: Misko
Date: Sat, 14 Sep 2024 01:05:14 +0000
Subject: [PATCH 3/6] always raise error if scaling factors cannot be loaded

---
 src/fairchem/core/models/gemnet_oc/gemnet_oc.py | 3 +--
 src/fairchem/core/modules/scaling/compat.py     | 6 ++----
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py
index 475c842f52..c982b7d43a 100644
--- a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py
+++ b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py
@@ -234,7 +234,6 @@ def __init__(
         num_elements: int = 83,
         otf_graph: bool = False,
         scale_file: str | None = None,
-        raise_on_scale_load_error: bool = True,
         **kwargs,  # backwards compatibility with deprecated arguments
     ) -> None:
         if qint_tags is None:
@@ -382,7 +381,7 @@ def __init__(
         if direct_forces:
             self.out_forces.reset_parameters(out_initializer)
 
-        load_scales_compat(self, scale_file, on_error_raise=raise_on_scale_load_error)
+        load_scales_compat(self, scale_file)
 
     def set_cutoffs(self, cutoff, cutoff_qint, cutoff_aeaint, cutoff_aint):
         self.cutoff = cutoff
diff --git a/src/fairchem/core/modules/scaling/compat.py b/src/fairchem/core/modules/scaling/compat.py
index 7060bf49ff..1e42f6e89b 100644
--- a/src/fairchem/core/modules/scaling/compat.py
+++ b/src/fairchem/core/modules/scaling/compat.py
@@ -51,9 +51,7 @@ def _load_scale_dict(scale_file: str | ScaleDict | None):
     return scale_dict
 
 
-def load_scales_compat(
-    module: nn.Module, scale_file: str | ScaleDict | None, on_error_raise: bool = True
-) -> None:
+def load_scales_compat(module: nn.Module, scale_file: str | ScaleDict | None) -> None:
     scale_dict = _load_scale_dict(scale_file)
     if not scale_dict:
         return
@@ -67,7 +65,7 @@ def load_scales_compat(module: nn.Module, scale_file: str | ScaleDict | None) ->
         f"Found the following scale factors: {[(k, name) for k, (_, name) in scale_factors.items()]}"
     )
     missing_keys = set(scale_factors.keys()) - set(scale_dict.keys())
-    if len(missing_keys) > 0 and on_error_raise:
+    if len(missing_keys) > 0:
         raise ValueError(
             "Failed to load scaling values. Missing entries for,", missing_keys
         )

From 70df4ff524c919488fb75e0fa1f76a989661cfd4 Mon Sep 17 00:00:00 2001
From: Misko
Date: Sun, 15 Sep 2024 20:24:09 +0000
Subject: [PATCH 4/6] add debug to figure out why github ci is failing

---
 src/fairchem/core/modules/scaling/compat.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/fairchem/core/modules/scaling/compat.py b/src/fairchem/core/modules/scaling/compat.py
index 1e42f6e89b..1f82783377 100644
--- a/src/fairchem/core/modules/scaling/compat.py
+++ b/src/fairchem/core/modules/scaling/compat.py
@@ -67,7 +67,10 @@ def load_scales_compat(module: nn.Module, scale_file: str | ScaleDict | None) ->
     missing_keys = set(scale_factors.keys()) - set(scale_dict.keys())
     if len(missing_keys) > 0:
         raise ValueError(
-            "Failed to load scaling values. Missing entries for,", missing_keys
+            "Failed to load scaling values. Missing entries for,",
+            missing_keys,
+            "\nHave",
+            scale_dict.keys(),
         )
     for name, scale in scale_dict.items():
         if name not in scale_factors:

From 53a7bb1f9737780c4586dc72b92669cce0d1dc5b Mon Sep 17 00:00:00 2001
From: Misko
Date: Mon, 16 Sep 2024 20:06:12 +0000
Subject: [PATCH 5/6] rename for always ddp

---
 src/fairchem/core/modules/scaling/fit.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/fairchem/core/modules/scaling/fit.py b/src/fairchem/core/modules/scaling/fit.py
index 368c23479d..4f0f806dca 100644
--- a/src/fairchem/core/modules/scaling/fit.py
+++ b/src/fairchem/core/modules/scaling/fit.py
@@ -204,9 +204,16 @@ def index_fn(name: str = name) -> None:
     trainer.config["cmd"]["checkpoint_dir"] = ckpt_file.parent
     trainer.is_debug = False
 
+    def rename_module(name):
+        name = name.replace(".scale_factor", "")
+        # remove DDP wrapper
+        name = re.sub("^module.", "", name)
+        # remove hydra backbone
+        return re.sub("^backbone.", "", name)
+
     torch.save(
         {
-            re.sub("^backbone.", "", x[0].replace(".scale_factor", "")): x[1]
+            rename_module(x[0]): x[1]
             for x in trainer.model.to("cpu").named_parameters()
             if ".scale_" in x[0]
         },

From 6f2cf93c78d217be0c23a74dedd230539ec50ada Mon Sep 17 00:00:00 2001
From: Misko
Date: Tue, 1 Oct 2024 17:42:20 +0000
Subject: [PATCH 6/6] remove loading scaling factors for num blocks >3

---
 docs/legacy_tutorials/OCP_Tutorial.md | 1 -
 1 file changed, 1 deletion(-)

diff --git a/docs/legacy_tutorials/OCP_Tutorial.md b/docs/legacy_tutorials/OCP_Tutorial.md
index 19fd93f6bc..164583ae5b 100644
--- a/docs/legacy_tutorials/OCP_Tutorial.md
+++ b/docs/legacy_tutorials/OCP_Tutorial.md
@@ -1162,7 +1162,6 @@ model = {
     "otf_graph": False,
     "output_init": "HeOrthogonal",
     "activation": "silu",
-    "scale_file": "./gemnet-dT.json",
     "regress_forces": False,
     "direct_forces": False,
 }
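
Taken together, the series ends with two behaviours worth keeping in mind: scale-factor parameter names are normalised (the DDP "module." and hydra "backbone." prefixes are stripped) before the fitted factors are saved, and load_scales_compat now raises instead of silently continuing when an expected factor is missing from the scale file. The sketch below restates that logic in isolation for quick reference; the parameter names and dictionary contents are hypothetical stand-ins, not values taken from the repository.

    import re


    def rename_module(name: str) -> str:
        # normalisation applied before saving (patch 5): drop the ".scale_factor"
        # suffix, the DDP "module." prefix and the hydra "backbone." prefix
        name = name.replace(".scale_factor", "")
        name = re.sub("^module.", "", name)
        return re.sub("^backbone.", "", name)


    def check_scale_dict(scale_factors: dict, scale_dict: dict) -> None:
        # strict check added to load_scales_compat (patches 3 and 4): every factor
        # registered on the model must have an entry in the loaded scale file
        missing_keys = set(scale_factors.keys()) - set(scale_dict.keys())
        if len(missing_keys) > 0:
            raise ValueError(
                "Failed to load scaling values. Missing entries for,",
                missing_keys,
                "\nHave",
                scale_dict.keys(),
            )


    # hypothetical parameter name, for illustration only
    saved = {rename_module("module.backbone.int_blocks.0.scale_rbf.scale_factor"): 1.0}
    assert "int_blocks.0.scale_rbf" in saved

    check_scale_dict({"int_blocks.0.scale_rbf": None}, saved)  # passes silently
    try:
        check_scale_dict({"int_blocks.1.scale_rbf": None}, saved)
    except ValueError as err:
        print("missing scale factor detected:", err)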