From f489105c32300ee51421acc8cefac18dbeba6663 Mon Sep 17 00:00:00 2001 From: Misko Date: Tue, 17 Sep 2024 00:01:09 +0000 Subject: [PATCH] make it work with old checkpoints --- .../equiformer_v2/eqv2_to_eqv2_hydra.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/fairchem/core/models/equiformer_v2/eqv2_to_eqv2_hydra.py b/src/fairchem/core/models/equiformer_v2/eqv2_to_eqv2_hydra.py index 70d3ee470..cc92f7e85 100644 --- a/src/fairchem/core/models/equiformer_v2/eqv2_to_eqv2_hydra.py +++ b/src/fairchem/core/models/equiformer_v2/eqv2_to_eqv2_hydra.py @@ -17,22 +17,30 @@ def convert_checkpoint_and_config_to_hydra( new_checkpoint_fn ), "Output checkpoint cannot already exist!" + def remove_module_prefix(x): + while x[: len("module.")] == "module.": + x = x[len("module.") :] + return x + def eqv2_state_dict_to_hydra_state_dict(eqv2_state_dict): hydra_state_dict = OrderedDict() for og_key in list(eqv2_state_dict.keys()): + key_without_module = remove_module_prefix(og_key) if "force_block" in og_key or "energy_block" in og_key: - key = og_key.replace( + key = "module." + key_without_module.replace( "force_block", "output_heads.forces.force_block" ).replace("energy_block", "output_heads.energy.energy_block") else: - offset = 0 - if og_key[: len("module.")] == "module.": - offset += len("module.") - key = og_key[:offset] + "backbone." + og_key[offset:] + key = "module.backbone." + key_without_module hydra_state_dict[key] = eqv2_state_dict[og_key] return hydra_state_dict def convert_configs_to_hydra(yaml_config, checkpoint_config): + if isinstance(checkpoint_config["model"], str): + name = checkpoint_config["model"] + checkpoint_config["model"] = checkpoint_config.pop("model_attributes") + checkpoint_config["model"]["name"] = name + new_model_config = { "name": "hydra", "backbone": checkpoint_config["model"].copy(), @@ -80,7 +88,7 @@ def convert_configs_to_hydra(yaml_config, checkpoint_config): checkpoint["state_dict"] ) for key in ["ema", "optimizer", "scheduler"]: - new_checkpoint.pop(key) + new_checkpoint.pop(key, None) # write output torch.save(new_checkpoint, new_checkpoint_fn)