From 882300e9c9f57e41e85fb0b6f73445251efceb71 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 28 Nov 2024 16:48:52 +0100 Subject: [PATCH] address comments --- apax/nodes/md.py | 4 ++-- apax/nodes/model.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/apax/nodes/md.py b/apax/nodes/md.py index 8719eef5..58a052bb 100644 --- a/apax/nodes/md.py +++ b/apax/nodes/md.py @@ -26,7 +26,7 @@ class ApaxJaxMD(zntrack.Node): index of the configuration from the data list to use model: ApaxModel model to use for the simulation - repeat: float + repeat: None|int|tuple[int, int, int] number of repeats config: str path to the MD simulation parameter file @@ -36,7 +36,7 @@ class ApaxJaxMD(zntrack.Node): data_id: int = zntrack.params(-1) model: ApaxBase = zntrack.deps() - repeat: typing.Optional[bool] = zntrack.params(None) + repeat: None|int|tuple[int, int, int] = zntrack.params(None) config: str = zntrack.params_path(None) diff --git a/apax/nodes/model.py b/apax/nodes/model.py index 80b36baf..90d6f49f 100644 --- a/apax/nodes/model.py +++ b/apax/nodes/model.py @@ -41,7 +41,7 @@ class Apax(ApaxBase): verbosity of logging during training """ - data: list = zntrack.deps() + data: list[ase.Atoms] = zntrack.deps() config: str = zntrack.params_path() validation_data: list[ase.Atoms] = zntrack.deps() model: t.Optional[ApaxBase] = zntrack.deps(None)