Kliff master v1 lightning #182

Merged
merged 26 commits into from
Jun 24, 2024

Commits (26)
9108450
Trainer module
ipcamit Feb 15, 2024
2dc9be0
Merge branch 'kliff-master-v1' into kliff-trainer-v1
ipcamit Feb 20, 2024
731c5df
Trainer base class implemented
ipcamit Feb 23, 2024
86ab5d2
Trainer base class working
ipcamit Feb 28, 2024
2f57f56
First draft trainer framework
ipcamit Mar 4, 2024
01322ec
from config functionality in KIMModel
ipcamit Apr 6, 2024
2fa8bb8
DS and Model manifest initialization
ipcamit Apr 8, 2024
e1ef24e
Moved back from omegaconf to dict
ipcamit Apr 16, 2024
614c1b9
Working KIM trainer module
ipcamit Apr 17, 2024
d3050a6
Merged Eric's PR
ipcamit Apr 17, 2024
ccd5e29
Torch ml trainer added, to test
ipcamit Apr 19, 2024
a121895
working descriptor module
ipcamit Apr 23, 2024
18f9bc4
stress in loss function
ipcamit Apr 23, 2024
eee1459
Functional non NaN torch trainer
ipcamit May 13, 2024
0cbe35e
Functioning Lightning trainer, tested on nequip
ipcamit May 20, 2024
5dc5eee
Data resume capabilities added
ipcamit May 21, 2024
e8cb5fc
Dynamic loading in Lightning train
ipcamit May 26, 2024
26edfce
Added Lightning checkpoints for model save and loss traj + prelims fo…
ipcamit May 27, 2024
289a2a9
Modified dataset weights + tests to reflect
ipcamit Jun 3, 2024
1f4e6eb
Cleanup + indices
ipcamit Jun 5, 2024
533bf2f
implemented save kim model
ipcamit Jun 9, 2024
9f65157
Checked model export
ipcamit Jun 10, 2024
c494f5e
Added restart capabilities
ipcamit Jun 10, 2024
b7a157f
Lightning trainer updates: Added KIM-API model export capabilities + …
ipcamit Jun 10, 2024
a14718b
Lightning trainer and tests
ipcamit Jun 18, 2024
1a5812b
Lightning trainer Comments #1
ipcamit Jun 23, 2024
3 changes: 3 additions & 0 deletions .gitignore
@@ -31,6 +31,9 @@ tests/uq/*.pkl
tests/uq/*.json
tests/uq/kliff_saved_model

tests/trainer/test_run
tests/trainer/kliff.log

# dataset
Si_training_set_4_configs
Si_training_set
405 changes: 373 additions & 32 deletions kliff/dataset/dataset.py

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions kliff/dataset/weight.py
@@ -44,18 +44,37 @@ def compute_weight(self, config):
def config_weight(self):
return self._config_weight

@config_weight.setter
def config_weight(self, value):
self._config_weight = value

@property
def energy_weight(self):
return self._energy_weight

@energy_weight.setter
def energy_weight(self, value):
self._energy_weight = value

@property
def forces_weight(self):
return self._forces_weight

@forces_weight.setter
def forces_weight(self, value):
self._forces_weight = value

@property
def stress_weight(self):
return self._stress_weight

@stress_weight.setter
def stress_weight(self, value):
self._stress_weight = value

def __repr__(self):
return f"Weights: config={self.config_weight}, energy={self.energy_weight}, forces={self.forces_weight}, stress={self.stress_weight}"

def _check_compute_flag(self, config):
"""
Check whether compute flag correctly set when the corresponding weight in
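The new setters above make the four weights assignable after construction (previously they were read-only properties). A minimal usage sketch; it assumes `Weight`'s existing constructor defaults:

```python
from kliff.dataset.weight import Weight

w = Weight()            # assumes the existing constructor's default weights
w.energy_weight = 1.0   # setters added in this diff
w.forces_weight = 0.5
w.stress_weight = 0.0
print(w)                # the new __repr__ prints all four weights
```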
4 changes: 2 additions & 2 deletions kliff/descriptors/descriptor.py
@@ -334,11 +334,11 @@ def get_size(self):

def get_mean(self):
"""Return a list of the mean of the fingerprints."""
- return self.mean.copy()
+ return self.mean

def get_stdev(self):
"""Return a list of the standard deviation of the fingerprints."""
- return self.stdev.copy()
+ return self.stdev

def get_dtype(self):
"""Return the data type of the fingerprints."""
252 changes: 252 additions & 0 deletions kliff/models/kim.py
@@ -1,8 +1,12 @@
import importlib
import os
import subprocess
import tarfile
from collections import OrderedDict
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import kimpy
import numpy as np
from loguru import logger

@@ -12,6 +16,7 @@
from kliff.models.model import ComputeArguments, Model
from kliff.models.parameter import Parameter
from kliff.neighbor import assemble_forces, assemble_stress
from kliff.utils import install_kim_model, is_kim_model_installed

try:
import kimpy
@@ -21,6 +26,13 @@
except ImportError:
kimpy_avail = False

# List of model drivers that are not supported by this trainer,
# e.g. quip, torchml, etc.
# TODO: Get the complete list of unsupported model drivers.
UNSUPPORTED_MODEL_DRIVERS = [
"TorchML",
]


class KIMComputeArguments(ComputeArguments):
"""
@@ -88,6 +100,8 @@ def __init__(
self._update_neigh(influence_distance)
self._register_data(compute_energy, compute_forces)

self.model_trainable_via_kim_api = False

def _get_implemented_property(self):
"""
Get implemented property of model.
@@ -681,6 +695,244 @@ def __call__(

return kim_ca_instance.results

@staticmethod
def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None):
"""
Get the model from a manifest. If the manifest describes a valid KIM model, a
KIMModel object is returned. If it describes a TorchML model, a torch
ReverseScriptedModule object will be returned *in the future*; otherwise an error
is raised. If the model is provided as a tarball, it is extracted and installed first.

```{todo}
Get torchscript model from TorchML driver.
```

Example `model_manifest` (key names match what the code reads below):
```yaml
model:
    type: kim # kim or tar
    path: ./model.tar.gz # path to the model tarball (used when type is tar)
    name: SW_StillingerWeber_1985_Si__MO_405512056662_006 # KIM model name, installed if missing
    collection: "user"
```

Example `param_manifest`:
```yaml
parameter:
    - A # plain strings are parameters that are not transformed
    - B
    - sigma: # a dict entry means the parameter is transformed
        transform_name: LogParameterTransform
        value: 2.0
        bounds: [[1.0, 10.0]]
```

```{note}
The `parameter` block is usually defined as a child of the `transform` block
in the trainer configuration file.
```

Args:
model_manifest: model configuration dict
param_manifest: parameter transformation configuration dict

Returns:
Model object
"""
model_name: Union[None, str] = model_manifest.get("name", None)
model_type: Union[None, str] = model_manifest.get("type", None)
model_path: Union[None, str, Path] = model_manifest.get("path", None)
model_driver = KIMModel.get_model_driver_name(model_name)
model_collection = model_manifest.get("collection")

if model_driver in UNSUPPORTED_MODEL_DRIVERS:
logger.error(
"Model driver not supported for KIM-API based training. "
"Please use appropriate trainer for this model."
)
raise KIMModelError(
f"Model driver {model_driver} not supported for KIMModel training."
)

# is the model a tarball? (covers both .tar and .tar.gz, per the example above)
if model_path is not None:
model_path = Path(model_path)
if ".tar" in model_path.suffixes:
model_type = "tar"

# ensure model is installed
if model_type.lower() == "kim":
# is it a tar file?
is_model_installed = is_kim_model_installed(model_name)
if is_model_installed:
logger.info(f"Model {model_name} is already installed, continuing ...")
else:
logger.info(
f"Model {model_name} not installed on system, attempting to installing ..."
)
was_install_success = install_kim_model(model_name, model_collection)
if not was_install_success:
logger.error(
f"Model {model_name} not found in the KIM API collections. Please check the model name and try again."
)
raise KIMModelError(f"Model {model_name} not found.")
else:
logger.info(
f"Model {model_name} installed in {model_collection} collection."
)

elif model_type.lower() == "tar":
archive_content = tarfile.open(model_path + "/" + model_name)
model = archive_content.getnames()[0]
archive_content.extractall(model_path)
subprocess.run(
[
"kim-api-collections-management",
"install",
"--force",
model_collection,
model_path + "/" + model,
],
check=True,
)
logger.info(
f"Tarball Model {model} installed in {model_collection} collection."
)
else:
raise KIMModelError(f"Model type {model_type} not supported.")

model = KIMModel(model_name)

if param_manifest:
mutable_param_list = []
for param_to_transform in param_manifest.get("parameter", []):
if isinstance(param_to_transform, dict):
parameter_name = list(param_to_transform.keys())[0]
elif isinstance(param_to_transform, str):
parameter_name = param_to_transform
else:
raise KIMModelError(f"Parameter can be a str or dict")
mutable_param_list.append(parameter_name)

model.set_params_mutable(mutable_param_list)
model_param_list = model.parameters()

# apply transforms if needed
for model_params, input_params in zip(
model_param_list, param_manifest.get("parameter", [])
):
if isinstance(input_params, dict):
param_name = list(input_params.keys())[0]
if param_name != model_params.name:
raise KIMModelError(
f"Parameter name mismatch. Expected {model_params.name}, got {param_name}."
)

param_value_dict = input_params[param_name]
transform_name = param_value_dict.get("transform_name", None)
params_value = param_value_dict.get("value", None)
bounds = param_value_dict.get("bounds", None)

if transform_name is not None:
transform_class = getattr(
importlib.import_module(
"kliff.transforms.parameter_transforms"
),
transform_name,
)
model_params.add_transform(transform_class())

if params_value is not None:
model_params.copy_from_model_space(params_value)

if bounds is not None:
model_params.add_bounds_model_space(np.array(bounds))

elif isinstance(input_params, str):
if input_params != model_params.name:
raise KIMModelError(
f"Parameter name mismatch. Expected {model_params.name}, got {input_params}."
)
else:
raise KIMModelError(
f"Optimizable parameters must be string or value dict. Got {input_params} instead."
)

return model

@staticmethod
def get_model_driver_name(model_name: str) -> Union[str, None]:
"""
Get the model driver name from the model name. For an installed KIM API model,
it reads the driver string from the model's shared-object file. If the model is
not installed and the model name refers to a tarball, it extracts the driver name
from the archive's CMakeLists.txt. This is needed to exclude model drivers that
this trainer cannot handle, e.g. TorchML-driver based models, which are to be
trained using the TorchTrainer.

TODO: This is not a clean solution; KIMPY should have a better way to handle this.
Ask Mingjian/Yaser for comment.

Args:
model_name: name of the model.

Returns:
Model driver name.
"""
# check if model is tarball
if "tar" in model_name:
return KIMModel._get_model_driver_name_for_tarball(model_name)

collections = kimpy.collections.create()
try:
shared_obj_path, collection = (
collections.get_item_library_file_name_and_collection(
kimpy.collection_item_type.portableModel, model_name
)
)
except RuntimeError: # not a portable model
return None
shared_obj_content = Path(shared_obj_path).read_bytes()
md_start_idx = shared_obj_content.find(b"model-driver")

if md_start_idx == -1:
return None
else:
md_start_idx += 15 # len('model-driver" "') == 15; skip to the start of the driver name
md_end_idx = shared_obj_content.find(b'"', md_start_idx)
return shared_obj_content[md_start_idx:md_end_idx].decode("utf-8")

@staticmethod
def _get_model_driver_name_for_tarball(tarball: str) -> Union[str, None]:
"""
Get the model driver name from the tarball. It will extract the model driver
name from the CMakeLists.txt file in the tarball. This is needed to ensure that
it excludes the model drivers that it cannot handle. Example: TorchML driver based
models. These models are to be trained using the TorchTrainer.

Args:
tarball: path to the tarball.

Returns:
Model driver name.
"""
archive_content = tarfile.open(tarball)
cmake_file_path = archive_content.getnames()[0] + "/CMakeLists.txt"
cmake_file = archive_content.extractfile(cmake_file_path)
cmake_file_content = cmake_file.read().decode("utf-8")

md_start_idx = cmake_file_content.find("DRIVER_NAME")
if md_start_idx == -1:
return None
else:
# the driver name starts after the next double quote
md_start_idx = cmake_file_content.find('"', md_start_idx)
if md_start_idx == -1: # no opening quote found
return None
md_start_idx += 1
md_end_idx = cmake_file_content.find('"', md_start_idx)
return cmake_file_content[md_start_idx:md_end_idx]


class KIMModelError(Exception):
def __init__(self, msg):
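For reference, a minimal sketch of driving the new static method from Python. The manifest dicts mirror the YAML examples in the docstring (key names follow what the code reads: `name`, `type`, `path`, `collection`), and the example assumes the SW model can be installed through the KIM API:

```python
from kliff.models.kim import KIMModel

model_manifest = {
    "name": "SW_StillingerWeber_1985_Si__MO_405512056662_006",
    "type": "kim",    # installed from the KIM API collections if missing
    "path": None,
    "collection": "user",
}
param_manifest = {
    "parameter": [
        "A",  # optimized in its original space
        {
            "sigma": {  # transformed before optimization
                "transform_name": "LogParameterTransform",
                "value": 2.0,
                "bounds": [[1.0, 10.0]],
            }
        },
    ],
}

model = KIMModel.get_model_from_manifest(model_manifest, param_manifest)
print(model.parameters())  # A and sigma are now mutable parameters
```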
5 changes: 5 additions & 0 deletions kliff/trainer/__init__.py
@@ -0,0 +1,5 @@
from .base_trainer import Trainer
from .lightning_trainer import GNNLightningTrainer

# from .kim_trainer import KIMTrainer
# from .torch_trainer import DNNTrainer
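The new package exposes the base `Trainer` and the `GNNLightningTrainer` (the KIM and Torch trainers remain commented out for now). A minimal import sketch; constructor arguments are not part of this diff, so none are shown:

```python
from kliff.trainer import Trainer, GNNLightningTrainer
```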