Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prepare for ZnTrack v0.8.0 release #356

Merged
merged 32 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
5444eac
prepare for ZnTrack v0.8.0 release
PythonFZ Oct 21, 2024
9d194a6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2024
9413a8f
use parameter property
PythonFZ Oct 22, 2024
95e9875
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 22, 2024
ed10604
add type hint
PythonFZ Oct 23, 2024
fabb21e
Merge branch 'zntrack-v08' of https://github.com/apax-hub/apax into z…
PythonFZ Oct 23, 2024
522d03b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2024
3aed9a4
replace run with repro
PythonFZ Oct 23, 2024
917d549
Merge branch 'zntrack-v08' of https://github.com/apax-hub/apax into z…
PythonFZ Oct 23, 2024
9b7cce7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2024
7ccdf09
git and dvc init
PythonFZ Oct 23, 2024
e91c9cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2024
580ce80
Merge branch 'main' into zntrack-v08
PythonFZ Oct 23, 2024
c161660
Merge branch 'main' into zntrack-v08
PythonFZ Nov 20, 2024
9c49761
update zntrack, refactor `atoms` to `frames`
PythonFZ Nov 20, 2024
1179e78
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2024
c774885
Merge branch 'main' into zntrack-v08
PythonFZ Nov 28, 2024
8cec570
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 28, 2024
b49205e
do not depend on ipsuite
PythonFZ Nov 28, 2024
3684597
Merge branch 'zntrack-v08' of https://github.com/apax-hub/apax into z…
PythonFZ Nov 28, 2024
593da72
adapt test
PythonFZ Nov 28, 2024
e5dedea
remove `meta` ref
PythonFZ Nov 28, 2024
31f8ae4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 28, 2024
2381b37
update apax nodes
PythonFZ Nov 28, 2024
5f58681
Merge branch 'zntrack-v08' of https://github.com/apax-hub/apax into z…
PythonFZ Nov 28, 2024
75af2b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 28, 2024
226b0c8
update test to use less data
PythonFZ Nov 28, 2024
fdc7b86
fix MD
PythonFZ Nov 28, 2024
ee967a8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 28, 2024
46fb4da
format
PythonFZ Nov 28, 2024
882300e
address comments
PythonFZ Nov 28, 2024
265d484
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,6 @@ md_config.yaml
# data
*.extxyz
*.traj
*.npz
*.npz

/models
3 changes: 2 additions & 1 deletion apax/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .md import ApaxJaxMD
from .model import Apax, ApaxCalibrate, ApaxEnsemble, ApaxImport
from .utils import AddData

__all__ = ["Apax", "ApaxEnsemble", "ApaxJaxMD", "ApaxImport", "ApaxCalibrate"]
__all__ = ["Apax", "ApaxEnsemble", "ApaxJaxMD", "ApaxImport", "ApaxCalibrate", "AddData"]

try:
from .analysis import ApaxBatchPrediction # noqa: F401
Expand Down
25 changes: 17 additions & 8 deletions apax/nodes/analysis.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import logging
import pathlib

import zntrack.utils
from ipsuite import base
import ase
import h5py
import znh5md
import zntrack

from .model import Apax
from apax.nodes.model import Apax

log = logging.getLogger(__name__)


class ApaxBatchPrediction(base.ProcessAtoms):
class ApaxBatchPrediction(zntrack.Node):
"""Create and Save the predictions from model on atoms.

Attributes
Expand All @@ -24,13 +27,19 @@ class ApaxBatchPrediction(base.ProcessAtoms):
predictions: list[Atoms] the atoms that have the predicted properties from model
"""

_module_ = "apax.nodes"
data: list[ase.Atoms] = zntrack.deps()

model: Apax = zntrack.deps()
batch_size: int = zntrack.params(1)
frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.h5")

def run(self):
self.atoms = []
calc = self.model.get_calculator()
data = self.get_data()
self.atoms = calc.batch_eval(data, self.batch_size)
frames = calc.batch_eval(self.data, self.batch_size)
znh5md.write(self.frames_path, frames)

@property
def frames(self) -> list[ase.Atoms]:
with self.state.fs.open(self.frames_path, "rb") as f:
with h5py.File(f, "r") as h5:
return znh5md.IO(file_handle=h5)[:]
46 changes: 22 additions & 24 deletions apax/nodes/md.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import functools
import logging
import pathlib
import typing
Expand All @@ -10,9 +9,8 @@
import zntrack.utils

from apax.md.simulate import run_md

from .model import ApaxBase
from .utils import check_duplicate_keys
from apax.nodes.model import ApaxBase
from apax.nodes.utils import check_duplicate_keys

log = logging.getLogger(__name__)

Expand All @@ -28,7 +26,7 @@ class ApaxJaxMD(zntrack.Node):
index of the configuration from the data list to use
model: ApaxModel
model to use for the simulation
repeat: float
repeat: None|int|tuple[int, int, int]
number of repeats
config: str
path to the MD simulation parameter file
Expand All @@ -38,7 +36,7 @@ class ApaxJaxMD(zntrack.Node):
data_id: int = zntrack.params(-1)

model: ApaxBase = zntrack.deps()
repeat = zntrack.params(None)
repeat: None | int | tuple[int, int, int] = zntrack.params(None)

config: str = zntrack.params_path(None)

Expand All @@ -47,35 +45,35 @@ class ApaxJaxMD(zntrack.Node):
zntrack.nwd / "initial_structure.extxyz"
)

_parameter: dict = None

def _post_load_(self) -> None:
self._handle_parameter_file()

def _handle_parameter_file(self):
with self.state.use_tmp_path():
self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())
@property
PythonFZ marked this conversation as resolved.
Show resolved Hide resolved
def parameter(self) -> dict:
with self.state.fs.open(self.config, "r") as f:
parameter = yaml.safe_load(f)

custom_parameters = {
"sim_dir": self.sim_dir.as_posix(),
"initial_structure": self.init_struc_dir.as_posix(),
}
check_duplicate_keys(custom_parameters, self._parameter, log)
self._parameter.update(custom_parameters)
check_duplicate_keys(custom_parameters, parameter, log)
parameter.update(custom_parameters)

return parameter

def _write_initial_structure(self):
atoms = self.data[self.data_id]
if self.repeat is not None:
atoms = atoms.repeat(self.repeat)
ase.io.write(self.init_struc_dir.as_posix(), atoms)

def run(self):
"""Primary method to run which executes all steps of the model training"""

if not self.state.restarted:
atoms = self.data[self.data_id]
if self.repeat is not None:
atoms = atoms.repeat(self.repeat)
ase.io.write(self.init_struc_dir.as_posix(), atoms)
self._write_initial_structure()

run_md(self.model._parameter, self._parameter, log_level="info")
run_md(self.model.parameter, self.parameter, log_level="info")

@functools.cached_property
def atoms(self) -> typing.List[ase.Atoms]:
@property
def frames(self) -> typing.List[ase.Atoms]:
with self.state.fs.open(self.sim_dir / "md.h5", "rb") as f:
with h5py.File(f) as file:
return znh5md.IO(file_handle=file)[:]
83 changes: 37 additions & 46 deletions apax/nodes/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@


class ApaxBase(zntrack.Node):
pass
def get_calculator(self, **kwargs):
raise NotImplementedError


class Apax(ApaxBase):
Expand All @@ -40,13 +41,13 @@ class Apax(ApaxBase):
verbosity of logging during training
"""

data: list = zntrack.deps()
data: list[ase.Atoms] = zntrack.deps()
config: str = zntrack.params_path()
validation_data = zntrack.deps()
model: t.Optional[t.Any] = zntrack.deps(None)
validation_data: list[ase.Atoms] = zntrack.deps()
model: t.Optional[ApaxBase] = zntrack.deps(None)
nl_skin: float = zntrack.params(0.5)
transformations: t.Optional[list[dict[str, dict]]] = zntrack.params(None)
log_level: str = zntrack.meta.Text("info")
log_level: str = zntrack.params("info")

model_directory: pathlib.Path = zntrack.outs_path(zntrack.nwd / "apax_model")

Expand All @@ -55,38 +56,34 @@ class Apax(ApaxBase):
zntrack.nwd / "val_atoms.extxyz"
)

metrics = zntrack.metrics()
metrics: dict = zntrack.metrics()

_parameter: dict = None
@property
def parameter(self) -> dict:
parameter = yaml.safe_load(self.state.fs.read_text(self.config))

def _post_load_(self) -> None:
self._handle_parameter_file()
custom_parameters = {
"directory": self.model_directory.as_posix(),
"experiment": "",
"train_data_path": self.train_data_file.as_posix(),
"val_data_path": self.validation_data_file.as_posix(),
}

def _handle_parameter_file(self):
self._parameter = yaml.safe_load(self.state.fs.read_text(self.config))
if self.model is not None:
param_files = self.model.parameter["data"]["directory"]
base_path = {"base_model_checkpoint": param_files}
try:
parameter["checkpoints"].update(base_path)
except KeyError:
parameter["checkpoints"] = base_path

with self.state.use_tmp_path():
custom_parameters = {
"directory": self.model_directory.as_posix(),
"experiment": "",
"train_data_path": self.train_data_file.as_posix(),
"val_data_path": self.validation_data_file.as_posix(),
}

if self.model is not None:
param_files = self.model._parameter["data"]["directory"]
base_path = {"base_model_checkpoint": param_files}
try:
self._parameter["checkpoints"].update(base_path)
except KeyError:
self._parameter["checkpoints"] = base_path

check_duplicate_keys(custom_parameters, self._parameter["data"], log)
self._parameter["data"].update(custom_parameters)
check_duplicate_keys(custom_parameters, parameter["data"], log)
parameter["data"].update(custom_parameters)
return parameter

def train_model(self):
"""Train the model using `apax.train.run`"""
apax_run(self._parameter, log_level=self.log_level)
apax_run(self.parameter, log_level=self.log_level)

def get_metrics(self):
"""In addition to the plots write a model metric"""
Expand All @@ -104,7 +101,7 @@ def run(self):
if self.state.restarted and csv_path.is_file():
metrics_df = pd.read_csv(self.model_directory / "log.csv")

if metrics_df["epoch"].iloc[-1] >= self._parameter["n_epochs"] - 1:
if metrics_df["epoch"].iloc[-1] >= self.parameter["n_epochs"] - 1:
return

self.train_model()
Expand Down Expand Up @@ -156,8 +153,7 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
calc:
ase calculator object
"""

param_files = [m._parameter["data"]["directory"] for m in self.models]
param_files = [m.parameter["data"]["directory"] for m in self.models]

transformations = []
if self.transformations:
Expand Down Expand Up @@ -192,14 +188,9 @@ class ApaxImport(zntrack.Node):
nl_skin: float = zntrack.params(0.5)
transformations: t.Optional[list[dict[str, dict]]] = zntrack.params(None)

_parameter: dict = None

def _post_load_(self) -> None:
self._handle_parameter_file()

def _handle_parameter_file(self):
with self.state.use_tmp_path():
self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())
@property
def parameter(self) -> dict:
return yaml.safe_load(self.state.fs.read_text(self.config))

def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
"""Property to return a model specific ase calculator object.
Expand All @@ -210,8 +201,8 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
ase calculator object
"""

directory = self._parameter["data"]["directory"]
exp = self._parameter["data"]["experiment"]
directory = self.parameter["data"]["directory"]
exp = self.parameter["data"]["experiment"]
model_dir = directory + "/" + exp

transformations = []
Expand Down Expand Up @@ -251,7 +242,7 @@ class ApaxCalibrate(ApaxBase):
See the apax documentation for available methods.
"""

model: t.Any = zntrack.deps()
model: ApaxBase = zntrack.deps()
validation_data: list[Atoms] = zntrack.deps()
batch_size: int = zntrack.params(32)
criterion: str = zntrack.params("ma_cal")
Expand All @@ -262,7 +253,7 @@ class ApaxCalibrate(ApaxBase):

nl_skin: float = zntrack.params(0.5)

metrics = zntrack.metrics()
metrics: dict = zntrack.metrics()

def run(self):
"""Primary method to run which executes all steps of the model training"""
Expand Down Expand Up @@ -294,7 +285,7 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
e_factor = self.metrics["e_factor"]
f_factor = self.metrics["f_factor"]

config_file = self.model._parameter["data"]["directory"]
config_file = self.model.parameter["data"]["directory"]

calibration = GlobalCalibration(
energy_factor=e_factor,
Expand Down
Loading