Fix: finetune error with cpu device (#76)
* fix: finetune error with cpu device

* refactor: argparse to vasprun_to_xyz

* fix wandb

---------

Co-authored-by: Xixian Liu <v-xixianliu@microsoft.com>
ZeroKnighting and Xixian Liu authored Jan 12, 2025
1 parent ba7c4e6 commit bebbbd0
Showing 2 changed files with 53 additions and 33 deletions.
19 changes: 12 additions & 7 deletions src/mattersim/training/finetune_mattersim.py
@@ -17,11 +17,14 @@
 from mattersim.utils.logger_utils import get_logger
 
 logger = get_logger()
-torch.distributed.init_process_group(backend="nccl")
 local_rank = int(os.environ["LOCAL_RANK"])
 
 
 def main(args):
+    if args.device == "cuda":
+        torch.distributed.init_process_group(backend="nccl")
+    else:
+        torch.distributed.init_process_group(backend="gloo")
     args_dict = vars(args)
     if args.wandb and local_rank == 0:
         wandb_api_key = (
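Why the backend switch matters: NCCL implements collectives only for CUDA tensors, so a CPU-only run has to fall back to gloo. A minimal standalone sketch (not part of this commit; the file name and world size are made up) to confirm the gloo path works on a CPU-only machine:

    # gloo_check.py -- launch with: torchrun --nproc_per_node=2 gloo_check.py
    import os

    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="gloo")  # CPU-capable backend, as in this fix
    rank = int(os.environ["LOCAL_RANK"])  # set by torchrun, mirroring the script above
    t = torch.ones(1) * rank
    dist.all_reduce(t)  # default op is SUM, so every rank ends up with 0 + 1 = 1
    print(f"rank {rank}: reduced value {t.item()}")
    dist.destroy_process_group()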
@@ -48,7 +51,8 @@ def main(args):
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
 
-    torch.cuda.set_device(local_rank)
+    if args.device == "cuda":
+        torch.cuda.set_device(local_rank)
 
     if args.train_data_path.endswith(".pkl"):
         with open(args.train_data_path, "rb") as f:
@@ -72,12 +76,12 @@
         forces,
         stresses,
         shuffle=True,
-        pin_memory=True,
+        pin_memory=(args.device == "cuda"),
         is_distributed=True,
         **args_dict,
     )
 
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = args.device
     # build energy normalization module
     if args.re_normalize:
         scale = AtomScaling(
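The pin_memory guard follows the same logic: page-locked host memory only pays off when batches are copied to a GPU, and requesting it in a CPU-only run is wasted work. A self-contained sketch of the pattern (names are illustrative, not from the repo):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(8, 3))
    use_cuda = torch.cuda.is_available()
    # Pinned pages enable fast asynchronous .to("cuda", non_blocking=True)
    # copies; on a CPU-only run the flag is simply left off.
    loader = DataLoader(dataset, batch_size=4, pin_memory=use_cuda)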
@@ -109,7 +113,7 @@
         energies,
         forces,
         stresses,
-        pin_memory=True,
+        pin_memory=(args.device == "cuda"),
         is_distributed=True,
         **args_dict,
     )
@@ -125,7 +129,8 @@
     if args.re_normalize:
         potential.model.set_normalizer(scale)
 
-    potential.model = torch.nn.parallel.DistributedDataParallel(potential.model)
+    if args.device == "cuda":
+        potential.model = torch.nn.parallel.DistributedDataParallel(potential.model)
     torch.distributed.barrier()
 
     potential.train_model(
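Note: DistributedDataParallel can also run on CPU over the gloo backend, so gating the wrapper on CUDA is a simplification rather than a hard requirement. The unconditional torch.distributed.barrier() stays safe either way, because main() now always initializes a process group (nccl or gloo) on entry.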
@@ -136,7 +141,7 @@
         **args_dict,
     )
 
-    if local_rank == 0 and args.save_checkpoint:
+    if local_rank == 0 and args.save_checkpoint and args.wandb:
         wandb.save(os.path.join(args.save_path, "best_model.pth"))
 
 
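A hypothetical CPU launch of the finetuning script (torchrun is assumed because the script reads LOCAL_RANK; the --device and --train_data_path flags are inferred from the args attributes used above, and other flags are omitted):

    torchrun --nproc_per_node=1 src/mattersim/training/finetune_mattersim.py \
        --device cpu --train_data_path ./train_data.pkl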
67 changes: 41 additions & 26 deletions src/mattersim/utils/vasprun_to_xyz.py
@@ -1,37 +1,31 @@
 # -*- coding: utf-8 -*-
+import argparse
 import os
 import random
 
 from ase.io import write
 
 from mattersim.utils.atoms_utils import AtomsAdaptor
 
-vasp_files = [
-    "work/data/H/vasp/vasprun.xml",
-    "work/data/H/vasp_2/vasprun.xml",
-    "work/data/H/vasp_3/vasprun.xml",
-    "work/data/H/vasp_4/vasprun.xml",
-    "work/data/H/vasp_5/vasprun.xml",
-    "work/data/H/vasp_6/vasprun.xml",
-    "work/data/H/vasp_7/vasprun.xml",
-    "work/data/H/vasp_8/vasprun.xml",
-    "work/data/H/vasp_9/vasprun.xml",
-    "work/data/H/vasp_10/vasprun.xml",
-]
-train_ratio = 0.8
-validation_ratio = 0.1
-test_ratio = 0.1
-
-save_dir = "./xyz_files"
-os.makedirs(save_dir, exist_ok=True)
-
 
-def main():
+def main(args):
+    vasp_files = []
+    for root, dirs, files in os.walk(args.data_path):
+        for file in files:
+            if file.endswith(".xml"):
+                vasp_files.append(os.path.join(root, file))
+
+    train_ratio = args.train_ratio
+    validation_ratio = args.validation_ratio
+
+    save_dir = args.save_path
+    os.makedirs(save_dir, exist_ok=True)
+
     atoms_train = []
     atoms_validation = []
     atoms_test = []
 
-    random.seed(42)
+    random.seed(args.seed)
 
     for vasp_file in vasp_files:
         atoms_list = AtomsAdaptor.from_file(filename=vasp_file)
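The refactor replaces the hard-coded ten-directory list with a recursive scan: anything ending in .xml under --data_path is treated as a vasprun file. An equivalent sketch under the same assumption, using pathlib instead of os.walk (a hypothetical alternative, not in the commit):

    from pathlib import Path

    vasp_files = [str(p) for p in Path(args.data_path).rglob("*.xml")]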
@@ -54,10 +48,31 @@ def main():
 
     # Save the training, validation, and test datasets to xyz files
 
-    write(f"{save_dir}/train.xyz", atoms_train)
-    write(f"{save_dir}/valid.xyz", atoms_validation)
-    write(f"{save_dir}/test.xyz", atoms_test)
+    write(f"{save_dir}/train.xyz", atoms_train, format="extxyz")
+    write(f"{save_dir}/valid.xyz", atoms_validation, format="extxyz")
+    write(f"{save_dir}/test.xyz", atoms_test, format="extxyz")
 
 
 if __name__ == "__main__":
-    main()
+    # Some important arguments
+    parser = argparse.ArgumentParser()
+
+    # path parameters
+    parser.add_argument("--data_path", type=str, default=None, help="vasprun data path")
+    parser.add_argument("--train_ratio", type=float, default=0.8, help="train ratio")
+    parser.add_argument(
+        "--validation_ratio", type=float, default=0.1, help="validation ratio"
+    )
+    parser.add_argument(
+        "--save_path",
+        type=str,
+        default="./xyz_files",
+        help="path to save the xyz files",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+    )
+    args = parser.parse_args()
+    main(args)

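Usage sketch for the refactored script (the module path comes from the diff header; the data path is illustrative). The test ratio is now implicit: whatever remains after the train and validation fractions.

    python src/mattersim/utils/vasprun_to_xyz.py --data_path work/data/H --save_path ./xyz_files

Reading the result back with ASE shows what the explicit format="extxyz" buys: the extended-XYZ writer stores cell and per-structure properties alongside positions, and the explicit argument makes that choice unambiguous.

    from ase.io import read

    train = read("./xyz_files/train.xyz", index=":", format="extxyz")
    print(len(train), "training structures")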