fix gemnet scaling factors fit.py and add a test (#819)
* fix gemnet fit and add test

* only save factors, not the whole model!
misko authored Aug 22, 2024
1 parent 1bee0d7 commit c2b0c30
Showing 3 changed files with 104 additions and 46 deletions.
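
The diff below refactors the old main() entry point in fit.py into a reusable compute_scaling_factors(config) function, so the scaling-factor fit can be driven programmatically as well as from the command line (the new test does exactly this). A minimal sketch of that programmatic use, assuming the existing fairchem CLI flags and a placeholder config path (my_gemnet_oc.yml):

    from fairchem.core.common.flags import flags
    from fairchem.core.common.utils import build_config
    from fairchem.core.modules.scaling.fit import compute_scaling_factors

    # Parse the usual fairchem CLI flags into a config dict, then fit the
    # scaling factors on CPU and write them to scaling.pt.
    parser = flags.get_parser()
    args, override_args = parser.parse_known_args(
        ["--mode", "train", "--config-yml", "my_gemnet_oc.yml",
         "--cpu", "--checkpoint", "scaling.pt"]
    )
    compute_scaling_factors(build_config(args, override_args))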
64 changes: 28 additions & 36 deletions src/fairchem/core/modules/scaling/fit.py

@@ -37,18 +37,9 @@ def _train_batch(trainer: BaseTrainer, batch) -> None:
     del out, loss
 
 
-def main(*, num_batches: int = 16) -> None:
-    # region args/config setup
-    setup_logging()
-
-    parser = flags.get_parser()
-    args, override_args = parser.parse_known_args()
-    _config = build_config(args, override_args)
-    _config["logger"] = "wandb"
-    # endregion
-
-    assert not args.distributed, "This doesn't work with DDP"
-    with new_trainer_context(args=args, config=_config) as ctx:
+def compute_scaling_factors(config, num_batches: int = 16) -> None:
+
+    with new_trainer_context(config=config) as ctx:
         config = ctx.config
         trainer = ctx.trainer
 
@@ -61,8 +52,8 @@ def main(*, num_batches: int = 16) -> None:
         logging.info(f"Input checkpoint path: {ckpt_file}, {ckpt_file.exists()=}")
 
         model: nn.Module = trainer.model
-        val_loader = trainer.val_loader
-        assert val_loader is not None, "Val dataset is required for making predictions"
+        data_loader = trainer.train_loader
+        assert data_loader is not None, "Train set required to load batches"
 
         if ckpt_file.exists():
             trainer.load_checkpoint(checkpoint_path=str(ckpt_file))
@@ -122,15 +113,8 @@ def main(*, num_batches: int = 16) -> None:
             sys.exit(-1)
         # endregion
 
-        # region get the output path
-        out_path = Path(
-            _prefilled_input(
-                "Enter output path for fitted scale factors: ",
-                prefill=str(ckpt_file),
-            )
-        )
-        if out_path.exists():
-            logging.warning(f"Already found existing file: {out_path}")
+        if ckpt_file.exists():
+            logging.warning(f"Already found existing file: {ckpt_file}")
             flag = input(
                 "Do you want to continue and overwrite existing file (1), "
                 "or exit (2)? "
@@ -142,7 +126,7 @@ def main(*, num_batches: int = 16) -> None:
                 sys.exit()
 
         logging.info(
-            f"Output path for fitted scale factors: {out_path}, {out_path.exists()=}"
+            f"Output path for fitted scale factors: {ckpt_file}, {ckpt_file.exists()=}"
         )
         # endregion
 
@@ -175,7 +159,7 @@ def index_fn(name: str = name) -> None:
             module.initialize_(index_fn=index_fn)
 
         # single pass through network
-        _train_batch(trainer, next(iter(val_loader)))
+        _train_batch(trainer, next(iter(data_loader)))
 
         # sort the scale factors by their computation order
         sorted_factors = sorted(
@@ -200,7 +184,7 @@ def index_fn(name: str = name) -> None:
 
             logging.info(f"Fitting {name}...")
            with module.fit_context_():
-                for batch in islice(val_loader, num_batches):
+                for batch in islice(data_loader, num_batches):
                    _train_batch(trainer, batch)
                stats, ratio, value = module.fit_()
 
@@ -216,19 +200,27 @@ def index_fn(name: str = name) -> None:
            assert module.fitted, f"{name} is not fitted"
 
        # region save the scale factors to the checkpoint file
-        trainer.config["cmd"]["checkpoint_dir"] = out_path.parent
+        trainer.config["cmd"]["checkpoint_dir"] = ckpt_file.parent
         trainer.is_debug = False
-        out_file = trainer.save(
-            metrics=None,
-            checkpoint_file=out_path.name,
-            training_state=False,
+        torch.save(
+            {
+                x[0].replace(".scale_factor", ""): x[1]
+                for x in trainer.model.to("cpu").named_parameters()
+                if ".scale_" in x[0]
+            },
+            str(ckpt_file),
         )
-        assert out_file is not None, "Failed to save checkpoint"
-        out_file = Path(out_file)
-        assert out_file.exists(), f"Failed to save checkpoint to {out_file}"
         # endregion
-        logging.info(f"Saved results to: {out_file}")
+        logging.info(f"Saved results to: {ckpt_file}")
 
 
 if __name__ == "__main__":
-    main()
+    # region args/config setup
+    setup_logging()
+
+    parser = flags.get_parser()
+    args, override_args = parser.parse_known_args()
+    assert not args.distributed, "This doesn't work with DDP"
+    config = build_config(args, override_args)
+
+    compute_scaling_factors(config)
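
With this change the script writes only the fitted factors (a plain dict of tensors saved with torch.save) instead of a full trainer checkpoint, so the output file can be inspected directly. A small sketch, assuming the factors were written to scaling.pt:

    import torch

    # The saved dict maps each scale-factor module path (with the trailing
    # ".scale_factor" stripped) to its fitted tensor.
    factors = torch.load("scaling.pt", map_location="cpu")
    for name, value in factors.items():
        print(name, value)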
19 changes: 11 additions & 8 deletions tests/core/e2e/test_e2e_commons.py

@@ -93,6 +93,16 @@ def merge_dictionary(d, u):
     return d
 
 
+def update_yaml_with_dict(input_yaml, output_yaml, update_dict_with):
+    with open(input_yaml) as yaml_file:
+        yaml_config = yaml.safe_load(yaml_file)
+    if update_dict_with is not None:
+        yaml_config = merge_dictionary(yaml_config, update_dict_with)
+    yaml_config["backend"] = "gloo"
+    with open(str(output_yaml), "w") as yaml_file:
+        yaml.dump(yaml_config, yaml_file)
+
+
 def _run_main(
     rundir,
     input_yaml,
@@ -103,14 +113,7 @@ def _run_main(
     world_size=0,
 ):
     config_yaml = Path(rundir) / "train_and_val_on_val.yml"
-
-    with open(input_yaml) as yaml_file:
-        yaml_config = yaml.safe_load(yaml_file)
-    if update_dict_with is not None:
-        yaml_config = merge_dictionary(yaml_config, update_dict_with)
-    yaml_config["backend"] = "gloo"
-    with open(str(config_yaml), "w") as yaml_file:
-        yaml.dump(yaml_config, yaml_file)
+    update_yaml_with_dict(input_yaml, config_yaml, update_dict_with)
     run_args = {
         "run_dir": rundir,
         "logdir": f"{rundir}/logs",
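
The YAML-merging logic that used to live inside _run_main is now the standalone update_yaml_with_dict helper, so tests can prepare a config file without launching a run. A minimal sketch of standalone use (both paths are placeholders):

    from test_e2e_commons import update_yaml_with_dict

    # Merge an override dict into a base config and write the result where a
    # subsequent run (or compute_scaling_factors) can pick it up.
    update_yaml_with_dict(
        input_yaml="configs/s2ef/gemnet_oc.yml",
        output_yaml="/tmp/rundir/train_and_val_on_val.yml",
        update_dict_with={"optim": {"max_epochs": 1}},
    )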
67 changes: 65 additions & 2 deletions tests/core/e2e/test_s2ef.py

@@ -8,11 +8,19 @@
 import numpy as np
 import numpy.testing as npt
 import pytest
-from test_e2e_commons import _run_main, oc20_lmdb_train_and_val_from_paths
+from fairchem.core._cli import Runner
+from fairchem.core.modules.scaling.fit import compute_scaling_factors
+from test_e2e_commons import (
+    _run_main,
+    oc20_lmdb_train_and_val_from_paths,
+    update_yaml_with_dict,
+)
 
-from fairchem.core.common.utils import setup_logging
+from fairchem.core.common.utils import build_config, setup_logging
 from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes
 
+from fairchem.core.common.flags import flags
+
 setup_logging()
 
 
@@ -98,6 +106,61 @@ def smoke_test_train(
             energy_from_train, energy_from_checkpoint, rtol=1e-6, atol=1e-6
         )
 
+    def test_gemnet_fit_scaling(self, configs, tutorial_val_src):
+
+        with tempfile.TemporaryDirectory() as tempdirname:
+            # (1) generate scaling factors for gemnet config
+            config_yaml = f"{tempdirname}/train_and_val_on_val.yml"
+            scaling_pt = f"{tempdirname}/scaling.pt"
+            # run
+            parser = flags.get_parser()
+            args, override_args = parser.parse_known_args(
+                [
+                    "--mode",
+                    "train",
+                    "--seed",
+                    "100",
+                    "--config-yml",
+                    config_yaml,
+                    "--cpu",
+                    "--checkpoint",
+                    scaling_pt,
+                ]
+            )
+            update_yaml_with_dict(
+                configs["gemnet_oc"],
+                config_yaml,
+                update_dict_with={
+                    "dataset": oc20_lmdb_train_and_val_from_paths(
+                        train_src=str(tutorial_val_src),
+                        val_src=str(tutorial_val_src),
+                        test_src=str(tutorial_val_src),
+                    ),
+                },
+            )
+            config = build_config(args, override_args)
+
+            # (2) if existing scaling factors are present remove them
+            if "scale_file" in config["model"]:
+                config["model"].pop("scale_file")
+
+            compute_scaling_factors(config)
+
+            # (3) try to run the config with the newly generated scaling factors
+            _ = _run_main(
+                rundir=tempdirname,
+                update_dict_with={
+                    "optim": {"max_epochs": 1},
+                    "model": {"use_pbc_single": True, "scale_file": scaling_pt},
+                    "dataset": oc20_lmdb_train_and_val_from_paths(
+                        train_src=str(tutorial_val_src),
+                        val_src=str(tutorial_val_src),
+                        test_src=str(tutorial_val_src),
+                    ),
+                },
+                input_yaml=configs["gemnet_oc"],
+            )
+
     # not all models are tested with otf normalization estimation
     # only gemnet_oc, escn, equiformer, and their hydra versions
     @pytest.mark.parametrize(
