Skip to content

Commit

Permalink
Lightning trainer updates: Added KIM-API model export capabilities + …
Browse files Browse the repository at this point in the history
…ckpt resume capabilities + docstrings
  • Loading branch information
ipcamit committed Jun 10, 2024
2 parents 1f4e6eb + c494f5e commit b7a157f
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 40 deletions.
67 changes: 57 additions & 10 deletions kliff/trainer/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,23 @@
import multiprocessing
import os
import random
import sys
from copy import deepcopy
from datetime import datetime, timedelta
from glob import glob
from pathlib import Path
from typing import Callable, List, Union
from typing import TYPE_CHECKING, Callable, List, Union

Check warning on line 12 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L1-L12

Added lines #L1 - L12 were not covered by tests

import dill # TODO: include dill in requirements.txt
import numpy as np
import yaml
from loguru import logger

Check warning on line 17 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L14-L17

Added lines #L14 - L17 were not covered by tests

from kliff.dataset import Dataset

Check warning on line 19 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L19

Added line #L19 was not covered by tests

if TYPE_CHECKING:
from kliff.transforms.configuration_transforms import ConfigurationTransform

Check warning on line 22 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L21-L22

Added lines #L21 - L22 were not covered by tests

from kliff.utils import get_n_configs_in_xyz

Check warning on line 24 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L24

Added line #L24 was not covered by tests


Expand Down Expand Up @@ -87,7 +92,7 @@ def __init__(self, training_manifest: dict, model=None):
},
}
self.property_transforms = []
self.configuration_transform = None
self.configuration_transform: "ConfigurationTransform" = None

Check warning on line 95 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L94-L95

Added lines #L94 - L95 were not covered by tests

# training variables
# this is too complicated to put it in singe dict, therefore the training
Expand All @@ -112,7 +117,7 @@ def __init__(self, training_manifest: dict, model=None):
"kwargs": None,
"epochs": 10000,
"stop_condition": None,
"num_workers": 1,
"num_workers": None,
"batch_size": 1,
}
self.optimizer = None

Check warning on line 123 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L123

Added line #L123 was not covered by tests
Expand All @@ -130,7 +135,6 @@ def __init__(self, training_manifest: dict, model=None):

# export trained model
self.export_manifest: dict = {

Check warning on line 137 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L137

Added line #L137 was not covered by tests
"model_type": None,
"model_name": None,
"model_path": None,
}
Expand Down Expand Up @@ -250,7 +254,7 @@ def parse_manifest(self, manifest: dict):
"stop_condition", None
)
self.optimizer_manifest["num_workers"] = self.training_manifest.get(

Check warning on line 256 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L256

Added line #L256 was not covered by tests
"num_workers", 1
"num_workers", None
)
self.optimizer_manifest["batch_size"] = self.training_manifest.get(

Check warning on line 259 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L259

Added line #L259 was not covered by tests
"batch_size", 1
Expand Down Expand Up @@ -348,8 +352,15 @@ def setup_workspace(self):
"""
Check all the existing runs in the root directory and see if it finished the run
"""
dir_list = sorted(glob(f"{self.workspace['name']}*"))
dir_list = sorted(

Check warning on line 355 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L355

Added line #L355 was not covered by tests
glob(f"{self.workspace['name']}/{self.model_manifest['name']}*")
)
dir_list = [p for p in dir_list if os.path.isdir(p)]

Check warning on line 358 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L358

Added line #L358 was not covered by tests

if len(dir_list) == 0 or not self.workspace["resume"]:
logger.info(

Check warning on line 361 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L360-L361

Added lines #L360 - L361 were not covered by tests
"Either a fresh run or resume is not requested. Starting a new run."
)
self.current["appending_to_previous_run"] = False
self.current["run_dir"] = (

Check warning on line 365 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L364-L365

Added lines #L364 - L365 were not covered by tests
f"{self.workspace['name']}/{self.current['run_title']}"
Expand All @@ -359,12 +370,17 @@ def setup_workspace(self):
last_dir = dir_list[-1]
was_it_finished = os.path.exists(f"{last_dir}/.finished")
if was_it_finished: # start new run
current_run_dir = (
f"{self.workspace['name']}/{self.current['run_title']}"
logger.warning(

Check warning on line 373 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L370-L373

Added lines #L370 - L373 were not covered by tests
"Resuming from last training was requested, but it was completed. Exiting."
)
os.makedirs(current_run_dir, exist_ok=True)
self.current["appending_to_previous_run"] = False
# current_run_dir = (
# f"{self.workspace['name']}/{self.current['run_title']}"
# )
# os.makedirs(current_run_dir, exist_ok=True)
# self.current["appending_to_previous_run"] = False
sys.exit()

Check warning on line 381 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L381

Added line #L381 was not covered by tests
else:
logger.info("Last trainer was not finished. Resuming the training.")
self.current["appending_to_previous_run"] = True
self.current["run_dir"] = dir_list[-1]

Check warning on line 385 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L383-L385

Added lines #L383 - L385 were not covered by tests

Expand Down Expand Up @@ -626,6 +642,37 @@ def train(self, *args, **kwargs):
def save_kim_model(self, *args, **kwargs):
raise TrainerError("save_kim_model not implemented.")

Check warning on line 643 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L642-L643

Added lines #L642 - L643 were not covered by tests

@staticmethod
def _generate_kim_cmake(model_name: str, driver_name: str, file_list: List) -> str:

Check warning on line 646 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L645-L646

Added lines #L645 - L646 were not covered by tests
"""
Generate the CMakeLists.txt file for KIM API. This will be used to compile the
driver with the KIM API. The driver name is the name of the driver, and the file
list is the list of files to be included in the CMakeLists.txt file.
Private method.
Args:
driver_name: Name of the driver
file_list: List of files to be included in the CMakeLists.txt file
Returns:
CMakeLists.txt file as a string
"""
model_name = model_name.replace("-", "_")
cmake = f"""cmake_minimum_required(VERSION 3.10)

Check warning on line 659 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L658-L659

Added lines #L658 - L659 were not covered by tests
list(APPEND CMAKE_PREFIX_PATH $ENV{{KIM_API_CMAKE_PREFIX_DIR}})
find_package(KIM-API-ITEMS 2.2 REQUIRED CONFIG)
kim_api_items_setup_before_project(ITEM_TYPE "portableModel")
project({model_name})
kim_api_items_setup_after_project(ITEM_TYPE "portableModel")
add_kim_api_model_library(
NAME ${{PROJECT_NAME}}
DRIVER_NAME "{driver_name}"
PARAMETER_FILES {" ".join(file_list)}
)
"""
return cmake

Check warning on line 674 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L674

Added line #L674 was not covered by tests


# Parallel processing for dataset loading #############################################
def _parallel_read(

Check warning on line 678 in kliff/trainer/base_trainer.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/base_trainer.py#L678

Added line #L678 was not covered by tests
Expand Down
Loading

0 comments on commit b7a157f

Please sign in to comment.