diff --git a/src/mattersim/training/finetune_mattersim.py b/src/mattersim/training/finetune_mattersim.py index 09e8a7d..307afac 100644 --- a/src/mattersim/training/finetune_mattersim.py +++ b/src/mattersim/training/finetune_mattersim.py @@ -17,11 +17,14 @@ from mattersim.utils.logger_utils import get_logger logger = get_logger() -torch.distributed.init_process_group(backend="nccl") local_rank = int(os.environ["LOCAL_RANK"]) def main(args): + if args.device == "cuda": + torch.distributed.init_process_group(backend="nccl") + else: + torch.distributed.init_process_group(backend="gloo") args_dict = vars(args) if args.wandb and local_rank == 0: wandb_api_key = ( @@ -48,7 +51,8 @@ def main(args): np.random.seed(args.seed) torch.manual_seed(args.seed) - torch.cuda.set_device(local_rank) + if args.device == "cuda": + torch.cuda.set_device(local_rank) if args.train_data_path.endswith(".pkl"): with open(args.train_data_path, "rb") as f: @@ -72,12 +76,12 @@ def main(args): forces, stresses, shuffle=True, - pin_memory=True, + pin_memory=(args.device == "cuda"), is_distributed=True, **args_dict, ) - device = "cuda" if torch.cuda.is_available() else "cpu" + device = args.device # build energy normalization module if args.re_normalize: scale = AtomScaling( @@ -109,7 +113,7 @@ def main(args): energies, forces, stresses, - pin_memory=True, + pin_memory=(args.device == "cuda"), is_distributed=True, **args_dict, ) @@ -125,7 +129,8 @@ def main(args): if args.re_normalize: potential.model.set_normalizer(scale) - potential.model = torch.nn.parallel.DistributedDataParallel(potential.model) + if args.device == "cuda": + potential.model = torch.nn.parallel.DistributedDataParallel(potential.model) torch.distributed.barrier() potential.train_model( @@ -136,7 +141,7 @@ def main(args): **args_dict, ) - if local_rank == 0 and args.save_checkpoint: + if local_rank == 0 and args.save_checkpoint and args.wandb: wandb.save(os.path.join(args.save_path, "best_model.pth")) diff --git a/src/mattersim/utils/vasprun_to_xyz.py b/src/mattersim/utils/vasprun_to_xyz.py index 39a1e11..2be5a82 100644 --- a/src/mattersim/utils/vasprun_to_xyz.py +++ b/src/mattersim/utils/vasprun_to_xyz.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import argparse import os import random @@ -6,32 +7,25 @@ from mattersim.utils.atoms_utils import AtomsAdaptor -vasp_files = [ - "work/data/H/vasp/vasprun.xml", - "work/data/H/vasp_2/vasprun.xml", - "work/data/H/vasp_3/vasprun.xml", - "work/data/H/vasp_4/vasprun.xml", - "work/data/H/vasp_5/vasprun.xml", - "work/data/H/vasp_6/vasprun.xml", - "work/data/H/vasp_7/vasprun.xml", - "work/data/H/vasp_8/vasprun.xml", - "work/data/H/vasp_9/vasprun.xml", - "work/data/H/vasp_10/vasprun.xml", -] -train_ratio = 0.8 -validation_ratio = 0.1 -test_ratio = 0.1 - -save_dir = "./xyz_files" -os.makedirs(save_dir, exist_ok=True) - - -def main(): + +def main(args): + vasp_files = [] + for root, dirs, files in os.walk(args.data_path): + for file in files: + if file.endswith(".xml"): + vasp_files.append(os.path.join(root, file)) + + train_ratio = args.train_ratio + validation_ratio = args.validation_ratio + + save_dir = args.save_path + os.makedirs(save_dir, exist_ok=True) + atoms_train = [] atoms_validation = [] atoms_test = [] - random.seed(42) + random.seed(args.seed) for vasp_file in vasp_files: atoms_list = AtomsAdaptor.from_file(filename=vasp_file) @@ -54,10 +48,31 @@ def main(): # Save the training, validation, and test datasets to xyz files - write(f"{save_dir}/train.xyz", atoms_train) - write(f"{save_dir}/valid.xyz", atoms_validation) - write(f"{save_dir}/test.xyz", atoms_test) + write(f"{save_dir}/train.xyz", atoms_train, format="extxyz") + write(f"{save_dir}/valid.xyz", atoms_validation, format="extxyz") + write(f"{save_dir}/test.xyz", atoms_test, format="extxyz") if __name__ == "__main__": - main() + # Some important arguments + parser = argparse.ArgumentParser() + + # path parameters + parser.add_argument("--data_path", type=str, default=None, help="vasprun data path") + parser.add_argument("--train_ratio", type=float, default=0.8, help="train ratio") + parser.add_argument( + "--validation_ratio", type=float, default=0.1, help="validation ratio" + ) + parser.add_argument( + "--save_path", + type=str, + default="./xyz_files", + help="path to save the xyz files", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + ) + args = parser.parse_args() + main(args)