Skip to content

Commit

Permalink
prepare for ZnTrack v0.8.0 release (#356)
Browse files Browse the repository at this point in the history
* prepare for ZnTrack v0.8.0 release

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use parameter property

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add type hint

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* replace run with repro

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* git and dvc init

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update zntrack, refactor `atoms` to `frames`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* do not depend on ipsuite

* adapt test

* remove `meta` ref

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update apax nodes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update test to use less data

* fix MD

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* format

* address comments

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Nov 28, 2024
1 parent 11b0fc7 commit c45b446
Show file tree
Hide file tree
Showing 13 changed files with 531 additions and 454 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,6 @@ md_config.yaml
# data
*.extxyz
*.traj
*.npz
*.npz

/models
3 changes: 2 additions & 1 deletion apax/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .md import ApaxJaxMD
from .model import Apax, ApaxCalibrate, ApaxEnsemble, ApaxImport
from .utils import AddData

__all__ = ["Apax", "ApaxEnsemble", "ApaxJaxMD", "ApaxImport", "ApaxCalibrate"]
__all__ = ["Apax", "ApaxEnsemble", "ApaxJaxMD", "ApaxImport", "ApaxCalibrate", "AddData"]

try:
from .analysis import ApaxBatchPrediction # noqa: F401
Expand Down
25 changes: 17 additions & 8 deletions apax/nodes/analysis.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import logging
import pathlib

import zntrack.utils
from ipsuite import base
import ase
import h5py
import znh5md
import zntrack

from .model import Apax
from apax.nodes.model import Apax

log = logging.getLogger(__name__)


class ApaxBatchPrediction(base.ProcessAtoms):
class ApaxBatchPrediction(zntrack.Node):
"""Create and Save the predictions from model on atoms.
Attributes
Expand All @@ -24,13 +27,19 @@ class ApaxBatchPrediction(base.ProcessAtoms):
predictions: list[Atoms] the atoms that have the predicted properties from model
"""

_module_ = "apax.nodes"
data: list[ase.Atoms] = zntrack.deps()

model: Apax = zntrack.deps()
batch_size: int = zntrack.params(1)
frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.h5")

def run(self):
self.atoms = []
calc = self.model.get_calculator()
data = self.get_data()
self.atoms = calc.batch_eval(data, self.batch_size)
frames = calc.batch_eval(self.data, self.batch_size)
znh5md.write(self.frames_path, frames)

@property
def frames(self) -> list[ase.Atoms]:
with self.state.fs.open(self.frames_path, "rb") as f:
with h5py.File(f, "r") as h5:
return znh5md.IO(file_handle=h5)[:]
46 changes: 22 additions & 24 deletions apax/nodes/md.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import functools
import logging
import pathlib
import typing
Expand All @@ -10,9 +9,8 @@
import zntrack.utils

from apax.md.simulate import run_md

from .model import ApaxBase
from .utils import check_duplicate_keys
from apax.nodes.model import ApaxBase
from apax.nodes.utils import check_duplicate_keys

log = logging.getLogger(__name__)

Expand All @@ -28,7 +26,7 @@ class ApaxJaxMD(zntrack.Node):
index of the configuration from the data list to use
model: ApaxModel
model to use for the simulation
repeat: float
repeat: None|int|tuple[int, int, int]
number of repeats
config: str
path to the MD simulation parameter file
Expand All @@ -38,7 +36,7 @@ class ApaxJaxMD(zntrack.Node):
data_id: int = zntrack.params(-1)

model: ApaxBase = zntrack.deps()
repeat = zntrack.params(None)
repeat: None | int | tuple[int, int, int] = zntrack.params(None)

config: str = zntrack.params_path(None)

Expand All @@ -47,35 +45,35 @@ class ApaxJaxMD(zntrack.Node):
zntrack.nwd / "initial_structure.extxyz"
)

_parameter: dict = None

def _post_load_(self) -> None:
self._handle_parameter_file()

def _handle_parameter_file(self):
with self.state.use_tmp_path():
self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())
@property
def parameter(self) -> dict:
with self.state.fs.open(self.config, "r") as f:
parameter = yaml.safe_load(f)

custom_parameters = {
"sim_dir": self.sim_dir.as_posix(),
"initial_structure": self.init_struc_dir.as_posix(),
}
check_duplicate_keys(custom_parameters, self._parameter, log)
self._parameter.update(custom_parameters)
check_duplicate_keys(custom_parameters, parameter, log)
parameter.update(custom_parameters)

return parameter

def _write_initial_structure(self):
atoms = self.data[self.data_id]
if self.repeat is not None:
atoms = atoms.repeat(self.repeat)
ase.io.write(self.init_struc_dir.as_posix(), atoms)

def run(self):
"""Primary method to run which executes all steps of the model training"""

if not self.state.restarted:
atoms = self.data[self.data_id]
if self.repeat is not None:
atoms = atoms.repeat(self.repeat)
ase.io.write(self.init_struc_dir.as_posix(), atoms)
self._write_initial_structure()

run_md(self.model._parameter, self._parameter, log_level="info")
run_md(self.model.parameter, self.parameter, log_level="info")

@functools.cached_property
def atoms(self) -> typing.List[ase.Atoms]:
@property
def frames(self) -> typing.List[ase.Atoms]:
with self.state.fs.open(self.sim_dir / "md.h5", "rb") as f:
with h5py.File(f) as file:
return znh5md.IO(file_handle=file)[:]
83 changes: 37 additions & 46 deletions apax/nodes/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@


class ApaxBase(zntrack.Node):
pass
def get_calculator(self, **kwargs):
raise NotImplementedError


class Apax(ApaxBase):
Expand All @@ -40,13 +41,13 @@ class Apax(ApaxBase):
verbosity of logging during training
"""

data: list = zntrack.deps()
data: list[ase.Atoms] = zntrack.deps()
config: str = zntrack.params_path()
validation_data = zntrack.deps()
model: t.Optional[t.Any] = zntrack.deps(None)
validation_data: list[ase.Atoms] = zntrack.deps()
model: t.Optional[ApaxBase] = zntrack.deps(None)
nl_skin: float = zntrack.params(0.5)
transformations: t.Optional[list[dict[str, dict]]] = zntrack.params(None)
log_level: str = zntrack.meta.Text("info")
log_level: str = zntrack.params("info")

model_directory: pathlib.Path = zntrack.outs_path(zntrack.nwd / "apax_model")

Expand All @@ -55,38 +56,34 @@ class Apax(ApaxBase):
zntrack.nwd / "val_atoms.extxyz"
)

metrics = zntrack.metrics()
metrics: dict = zntrack.metrics()

_parameter: dict = None
@property
def parameter(self) -> dict:
parameter = yaml.safe_load(self.state.fs.read_text(self.config))

def _post_load_(self) -> None:
self._handle_parameter_file()
custom_parameters = {
"directory": self.model_directory.as_posix(),
"experiment": "",
"train_data_path": self.train_data_file.as_posix(),
"val_data_path": self.validation_data_file.as_posix(),
}

def _handle_parameter_file(self):
self._parameter = yaml.safe_load(self.state.fs.read_text(self.config))
if self.model is not None:
param_files = self.model.parameter["data"]["directory"]
base_path = {"base_model_checkpoint": param_files}
try:
parameter["checkpoints"].update(base_path)
except KeyError:
parameter["checkpoints"] = base_path

with self.state.use_tmp_path():
custom_parameters = {
"directory": self.model_directory.as_posix(),
"experiment": "",
"train_data_path": self.train_data_file.as_posix(),
"val_data_path": self.validation_data_file.as_posix(),
}

if self.model is not None:
param_files = self.model._parameter["data"]["directory"]
base_path = {"base_model_checkpoint": param_files}
try:
self._parameter["checkpoints"].update(base_path)
except KeyError:
self._parameter["checkpoints"] = base_path

check_duplicate_keys(custom_parameters, self._parameter["data"], log)
self._parameter["data"].update(custom_parameters)
check_duplicate_keys(custom_parameters, parameter["data"], log)
parameter["data"].update(custom_parameters)
return parameter

def train_model(self):
"""Train the model using `apax.train.run`"""
apax_run(self._parameter, log_level=self.log_level)
apax_run(self.parameter, log_level=self.log_level)

def get_metrics(self):
"""In addition to the plots write a model metric"""
Expand All @@ -104,7 +101,7 @@ def run(self):
if self.state.restarted and csv_path.is_file():
metrics_df = pd.read_csv(self.model_directory / "log.csv")

if metrics_df["epoch"].iloc[-1] >= self._parameter["n_epochs"] - 1:
if metrics_df["epoch"].iloc[-1] >= self.parameter["n_epochs"] - 1:
return

self.train_model()
Expand Down Expand Up @@ -156,8 +153,7 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
calc:
ase calculator object
"""

param_files = [m._parameter["data"]["directory"] for m in self.models]
param_files = [m.parameter["data"]["directory"] for m in self.models]

transformations = []
if self.transformations:
Expand Down Expand Up @@ -192,14 +188,9 @@ class ApaxImport(zntrack.Node):
nl_skin: float = zntrack.params(0.5)
transformations: t.Optional[list[dict[str, dict]]] = zntrack.params(None)

_parameter: dict = None

def _post_load_(self) -> None:
self._handle_parameter_file()

def _handle_parameter_file(self):
with self.state.use_tmp_path():
self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())
@property
def parameter(self) -> dict:
return yaml.safe_load(self.state.fs.read_text(self.config))

def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
"""Property to return a model specific ase calculator object.
Expand All @@ -210,8 +201,8 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
ase calculator object
"""

directory = self._parameter["data"]["directory"]
exp = self._parameter["data"]["experiment"]
directory = self.parameter["data"]["directory"]
exp = self.parameter["data"]["experiment"]
model_dir = directory + "/" + exp

transformations = []
Expand Down Expand Up @@ -251,7 +242,7 @@ class ApaxCalibrate(ApaxBase):
See the apax documentation for available methods.
"""

model: t.Any = zntrack.deps()
model: ApaxBase = zntrack.deps()
validation_data: list[Atoms] = zntrack.deps()
batch_size: int = zntrack.params(32)
criterion: str = zntrack.params("ma_cal")
Expand All @@ -262,7 +253,7 @@ class ApaxCalibrate(ApaxBase):

nl_skin: float = zntrack.params(0.5)

metrics = zntrack.metrics()
metrics: dict = zntrack.metrics()

def run(self):
"""Primary method to run which executes all steps of the model training"""
Expand Down Expand Up @@ -294,7 +285,7 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
e_factor = self.metrics["e_factor"]
f_factor = self.metrics["f_factor"]

config_file = self.model._parameter["data"]["directory"]
config_file = self.model.parameter["data"]["directory"]

calibration = GlobalCalibration(
energy_factor=e_factor,
Expand Down
Loading

0 comments on commit c45b446

Please sign in to comment.