Add tests for hydra gemnet OC scaling factor generation and loading; Raise error on fail to load scaling factors #831

Merged · 11 commits · Oct 1, 2024
1 change: 0 additions & 1 deletion docs/legacy_tutorials/OCP_Tutorial.md
@@ -1162,7 +1162,6 @@ model = {
"otf_graph": False,
"output_init": "HeOrthogonal",
"activation": "silu",
"scale_file": "./gemnet-dT.json",
"regress_forces": False,
"direct_forces": False,
}
9 changes: 8 additions & 1 deletion src/fairchem/core/modules/scaling/compat.py
@@ -64,11 +64,18 @@ def load_scales_compat(module: nn.Module, scale_file: str | ScaleDict | None) ->
     logging.debug(
         f"Found the following scale factors: {[(k, name) for k, (_, name) in scale_factors.items()]}"
     )
+    missing_keys = set(scale_factors.keys()) - set(scale_dict.keys())
+    if len(missing_keys) > 0:
+        raise ValueError(
+            "Failed to load scaling values. Missing entries for,",
+            missing_keys,
+            "\nHave",
+            scale_dict.keys(),
+        )
     for name, scale in scale_dict.items():
         if name not in scale_factors:
             logging.warning(f"Scale factor {name} not found in model")
             continue

         scale_module, module_name = scale_factors[name]
         logging.debug(f"Loading scale factor {scale} for ({name} => {module_name})")
         scale_module.set_(scale)
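
The new guard changes load_scales_compat from silently warning to failing fast when the scale file lacks entries for factors the model registered. Below is a minimal, runnable sketch of the check in isolation, using plain dicts as stand-ins for the model's ScaleFactor registry and the loaded scale file; the key names are hypothetical:

# Hypothetical stand-ins for the real structures in load_scales_compat.
scale_factors = {"int_blocks.0.scale_sum": None, "int_blocks.0.scale_rbf": None}
scale_dict = {"int_blocks.0.scale_sum": 1.7}  # the file is missing one entry

missing_keys = set(scale_factors.keys()) - set(scale_dict.keys())
if len(missing_keys) > 0:
    # Raises a ValueError listing the missing names and the names that were found.
    raise ValueError(
        "Failed to load scaling values. Missing entries for,",
        missing_keys,
        "\nHave",
        scale_dict.keys(),
    )

Note that the reverse mismatch (a file entry with no matching module) still only logs a warning, as before.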
10 changes: 9 additions & 1 deletion src/fairchem/core/modules/scaling/fit.py
@@ -2,6 +2,7 @@

 import logging
 import math
+import re
 import readline
 import sys
 from itertools import islice
@@ -202,9 +203,16 @@ def index_fn(name: str = name) -> None:
     trainer.config["cmd"]["checkpoint_dir"] = ckpt_file.parent
     trainer.is_debug = False

+    def rename_module(name):
+        name = name.replace(".scale_factor", "")
+        # remove DDP wrapper
+        name = re.sub("^module.", "", name)
+        # remove hydra backbone
+        return re.sub("^backbone.", "", name)
+
     torch.save(
         {
-            x[0].replace(".scale_factor", ""): x[1]
+            rename_module(x[0]): x[1]
             for x in trainer.model.to("cpu").named_parameters()
             if ".scale_" in x[0]
         },
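
For intuition, here is a standalone sketch of what rename_module does to parameter names before they are saved; the names below are hypothetical examples of what named_parameters() might yield for a DDP-wrapped hydra model:

import re

def rename_module(name):
    name = name.replace(".scale_factor", "")
    # strip the DistributedDataParallel "module." prefix, if present
    name = re.sub("^module.", "", name)
    # strip the hydra "backbone." prefix, if present
    return re.sub("^backbone.", "", name)

print(rename_module("module.backbone.int_blocks.0.scale_sum.scale_factor"))
# -> int_blocks.0.scale_sum
print(rename_module("out_blocks.1.scale_rbf.scale_factor"))
# -> out_blocks.1.scale_rbf  (nothing stripped beyond the suffix)

Because the regexes are anchored at the start of the name, only the wrapper prefixes are removed, so the saved keys match what a plain (non-DDP, non-hydra) model expects to load.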
26 changes: 20 additions & 6 deletions tests/core/e2e/test_s2ef.py
@@ -108,7 +108,15 @@ def smoke_test_train(
             energy_from_train, energy_from_checkpoint, rtol=1e-6, atol=1e-6
         )

-    def test_gemnet_fit_scaling(self, configs, tutorial_val_src):
+    @pytest.mark.parametrize(
+        ("model_name"),
+        [
+            ("gemnet_oc"),
+            ("gemnet_oc_hydra"),
+            ("gemnet_oc_hydra_grad"),
+        ],
+    )
+    def test_gemnet_fit_scaling(self, model_name, configs, tutorial_val_src):

         with tempfile.TemporaryDirectory() as tempdirname:
             # (1) generate scaling factors for gemnet config
@@ -130,7 +138,7 @@ def test_gemnet_fit_scaling(self, configs, tutorial_val_src):
                 ]
             )
             update_yaml_with_dict(
-                configs["gemnet_oc"],
+                configs[model_name],
                 config_yaml,
                 update_dict_with={
                     "dataset": oc20_lmdb_train_and_val_from_paths(
@@ -143,24 +151,30 @@ def test_gemnet_fit_scaling(self, configs, tutorial_val_src):
             config = build_config(args, override_args)

             # (2) if existing scaling factors are present remove them
-            if "scale_file" in config["model"]:
-                config["model"].pop("scale_file")
+            config["model"].pop("scale_file", None)
+            if "backbone" in config["model"]:
+                config["model"]["backbone"].pop("scale_file", None)

             compute_scaling_factors(config)

+            model_config_change = (
+                {"backbone": {"scale_file": scaling_pt}}
+                if "backbone" in config["model"]
+                else {"scale_file": scaling_pt}
+            )
             # (3) try to run the config with the newly generated scaling factors
             _ = _run_main(
                 rundir=tempdirname,
                 update_dict_with={
                     "optim": {"max_epochs": 1},
-                    "model": {"use_pbc_single": True, "scale_file": scaling_pt},
+                    "model": model_config_change,
                     "dataset": oc20_lmdb_train_and_val_from_paths(
                         train_src=str(tutorial_val_src),
                         val_src=str(tutorial_val_src),
                         test_src=str(tutorial_val_src),
                     ),
                 },
-                input_yaml=configs["gemnet_oc"],
+                input_yaml=configs[model_name],
             )

     def test_convert_checkpoint_and_config_to_hydra(self, configs, tutorial_val_src):
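
The model_config_change conditional is what lets one parametrized test serve both the flat gemnet_oc config and the hydra variants, where backbone settings are nested one level down. A simplified sketch of the two shapes the override takes; the path and config contents are illustrative, and the real test checks "backbone" inside config["model"]:

scaling_pt = "/tmp/scale_factors.pt"  # hypothetical path to the fitted factors

flat_config = {"name": "gemnet_oc"}
hydra_config = {"name": "hydra", "backbone": {"model": "gemnet_oc_backbone"}}

for model_config in (flat_config, hydra_config):
    model_config_change = (
        {"backbone": {"scale_file": scaling_pt}}
        if "backbone" in model_config
        else {"scale_file": scaling_pt}
    )
    print(model_config_change)
# -> {'scale_file': '/tmp/scale_factors.pt'}
# -> {'backbone': {'scale_file': '/tmp/scale_factors.pt'}}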