Skip to content

Commit

Permalink
Adding version information to kwf files (#212)
Browse files Browse the repository at this point in the history
Currently `.kwf` files can be quite fickle because the format relies on the structure of the `Workflow` object itself, so whenever we change the `Workflow` class we lose the ability to read old `kwf` files.

Now, we have added explicitly the version number to the `Workflow` class. This will warn users if they are loading a `.kwf` file that was generated with a different version of `koopmans` than that which generated it, and, if the reading fails, they can always revert to the version of `koopmans` that generated the file in the first place.
  • Loading branch information
elinscott authored Oct 17, 2023
1 parent ba79a21 commit bbefdfc
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
docs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- uses: ammaraskar/sphinx-action@master
with:
docs-folder: "docs/"
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ jobs:
os: [ubuntu-latest, macOS-latest]
steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
submodules: true
- name: Set up Python 3.8
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install dependencies
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tutorials.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ jobs:
os: [ubuntu-latest, macOS-latest]
steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
submodules: true
- name: Set up Python 3.8
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install dependencies
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/typechecking.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ jobs:
max-parallel: 3
steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
submodules: false
- name: Set up Python 3.8
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install dependencies
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/update_citation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: Get branch name
id: branch-name
uses: tj-actions/branch-names@v6
- name: Set up Python 3.8
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install dependencies
Expand Down
2 changes: 1 addition & 1 deletion src/koopmans/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
else:
import importlib_metadata as metadata

__version__ = metadata.version('koopmans')
__version__: str = metadata.version('koopmans')
26 changes: 11 additions & 15 deletions src/koopmans/workflows/_ml.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""
Written by Yannick Schubert Jul 2022
"""
import copy
import os
from pathlib import Path
from typing import Any, Dict, Tuple, Optional, List
from typing import Any, Dict, List, Tuple

import numpy as np
from deepdiff import DeepDiff
from sklearn.metrics import mean_absolute_error as mae

from koopmans import calculators, ml, utils
Expand Down Expand Up @@ -75,7 +73,7 @@ def _run(self):

# Specify for which bands we want to compute the decomposition
self.bands_to_extract = self.bands.to_solve
self.num_bands_occ = [len([band for band in self.bands if (band.filled == True and band.spin == spin)])
self.num_bands_occ = [len([band for band in self.bands if (band.filled and band.spin == spin)])
for spin in [0, 1]]
self.num_bands_to_extract = [len([band for band in self.bands_to_extract if band.filled == filled])
for filled in [True, False]]
Expand All @@ -96,7 +94,8 @@ def _run(self):
self.extract_input_vector_from_self_hartrees()
else:
raise ValueError(
f"{self.ml.input_data_for_ml_model} is currently not implemented as a valid input for the ml model.")
f"{self.ml.input_data_for_ml_model} is currently not implemented as a valid input for the ml "
"model.")

def extract_input_vector_from_self_hartrees(self):
"""
Expand All @@ -109,7 +108,7 @@ def extract_input_vector_from_self_hartrees(self):
else:
filled_str = 'emp'
np.savetxt(self.dirs['SH'] / f"SH.orbital.{filled_str}.{band.index}.txt",
np.array([SH[band.spin][band.index-1]]))
np.array([SH[band.spin][band.index - 1]]))
self.input_vectors_for_ml[f"SH.orbital.{filled_str}.{band.index}"] = SH

return
Expand Down Expand Up @@ -158,11 +157,11 @@ def convert_bin2xml(self):

def check_if_bin2xml_is_complete(self) -> bool:
"""
Checks if the bin2xml conversion was already performed.
For this we first check if the next step is complete, i.e. the decomposition into spherical harmonics and
radial basis functions. If we haven't yet computed the decomposition, we check if this step, i.e. the
conversion to xml, was already performed. This way we can delete the (possibly large) binary and xml files
Checks if the bin2xml conversion was already performed.
For this we first check if the next step is complete, i.e. the decomposition into spherical harmonics and
radial basis functions. If we haven't yet computed the decomposition, we check if this step, i.e. the
conversion to xml, was already performed. This way we can delete the (possibly large) binary and xml files
and only keep the expansion coefficients and still be able to restart the calculation.
"""

Expand Down Expand Up @@ -373,13 +372,10 @@ def get_alpha_from_file_for_debugging(self, band: Band) -> Tuple[float, float]:
if self.parameters.spin_polarized:
raise NotImplementedError('Need to check implementation')
assert isinstance(band.index, int)
return alphas[0][band.index-1], 0.0
return alphas[0][band.index - 1], 0.0

@classmethod
def fromdict(cls, dct: Dict[str, Any], **kwargs) -> Workflow:
calc_that_produced_orbital_densities = dct.pop('calc_that_produced_orbital_densities')
return super(MLFittingWorkflow, cls).fromdict(
dct, calc_that_produced_orbital_densities=calc_that_produced_orbital_densities, **kwargs)

def __eq__(self, other):
return DeepDiff(self, other) == {}
40 changes: 31 additions & 9 deletions src/koopmans/workflows/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class Workflow(ABC):
pseudo_dir: Path
projections: ProjectionBlocks
parent: Optional[Workflow]
version: str

def __init__(self, atoms: Atoms,
pseudopotentials: Dict[str, str] = {},
Expand All @@ -114,6 +115,7 @@ def __init__(self, atoms: Atoms,
plotting: Union[Dict[str, Any], settings.PlotSettingsDict] = {},
ml: Union[Dict[str, Any], settings.MLSettingsDict] = {},
autogenerate_settings: bool = True,
version: Optional[str] = None,
**kwargs: Dict[str, Any]):

# Parsing parameters
Expand Down Expand Up @@ -308,7 +310,8 @@ def __init__(self, atoms: Atoms,
calc_params[key] = value
match = True
# if not a calculator, workflow, or plotting keyword, raise an error
if not match and not self.parameters.is_valid(key) and not self.plotting.is_valid(key) and not self.ml.is_valid(key):
if not match and not self.parameters.is_valid(key) and not self.plotting.is_valid(key) \
and not self.ml.is_valid(key):
raise ValueError(f'{key} is not a valid setting')

# Adding excluded_bands info to self.projections
Expand All @@ -319,6 +322,13 @@ def __init__(self, atoms: Atoms,
label += f'_{spin}'
self.projections.exclude_bands[spin] = self.calculator_parameters[label].get('exclude_bands', [])

# Version number (important when loading workflows from .kwf files)
from koopmans import __version__
self.version = version if version is not None else __version__
if self.version != __version__:
utils.warn(f'You are using version {__version__} of koopmans, but this workflow was generated with '
f'version {self.version}. Proceed with caution.')

def __eq__(self, other: Any):
if isinstance(other, Workflow):
return self.__dict__ == other.__dict__
Expand Down Expand Up @@ -517,7 +527,8 @@ def _run_sanity_checks(self):

# Make sanity checks for the ML model
if self.ml.use_ml:
utils.warn("Predicting screening parameters with machine-learning is an experimental feature; proceed with caution")
utils.warn("Predicting screening parameters with machine-learning is an experimental feature; proceed with "
"caution")
if self.parameters.task not in ['trajectory', 'convergence_ml']:
raise NotImplementedError(
f'Using the ML-prediction for the {self.parameter.task}-task has not yet been implemented.')
Expand All @@ -526,17 +537,20 @@ def _run_sanity_checks(self):
f"Using the ML-prediction for the {self.parameters.method}-method has not yet been implemented")
if self.parameters.functional != 'ki':
raise NotImplementedError(
f'Using the ML-prediction for the {self.parameters.functional}-functional has not yet been implemented.')
f'Using the ML-prediction for the {self.parameters.functional}-functional has not yet been '
'implemented.')
if self.parameters.init_orbitals != 'mlwfs':
raise NotImplementedError(
f'Using the ML-prediction for {self.parameters.init_orbitals}-init orbitals has not yet been implemented.')
f'Using the ML-prediction for {self.parameters.init_orbitals}-init orbitals has not yet been '
'implemented.')
if self.parameters.init_empty_orbitals != self.parameters.init_orbitals:
raise NotImplementedError(
f'Using the ML-prediction for using different init orbitals for empty states than for occupied states has not yet been implemented.')
'Using the ML-prediction for using different init orbitals for empty states than for occupied '
'states has not yet been implemented.')
if self.parameters.spin_polarized:
utils.warn(f'Using the ML-prediction for spin-polarised systems has not yet been extensively tested.')
utils.warn('Using the ML-prediction for spin-polarised systems has not yet been extensively tested.')
if not all(self.atoms.pbc):
utils.warn(f'Using the ML-prediction for non-periodic systems has not yet been extensively tested.')
utils.warn('Using the ML-prediction for non-periodic systems has not yet been extensively tested.')
if self.parameters.orbital_groups:
utils.warn('Using orbital_groups has not yet been extensively tested.')
if not np.all(self.atoms.cell.angles() == 90.0):
Expand Down Expand Up @@ -564,7 +578,8 @@ def convert_to_list(param, type):
r_mins = convert_to_list(self.ml.r_min, float)
r_maxs = convert_to_list(self.ml.r_max, float)

# check that each n_max, l_max, r_max and r_min are greater or equal to 0 and that r_min is smaller than r_max
# check that each n_max, l_max, r_max and r_min are greater or equal to 0 and that r_min is smaller than
# r_max
for n_max in n_maxs:
if not n_max > 0:
raise ValueError(f"n_max has to be larger than zero. The provided value is n_max={n_max}")
Expand All @@ -576,7 +591,8 @@ def convert_to_list(param, type):
raise ValueError(f"r_min has to be equal or larger than zero. The provided value is r_min={r_min}")
if r_min < 0.5:
utils.warn(
f"Small values of r_min (<0.5) can lead to problems in the construction of the radial basis. The provided value is r_min={r_min}.")
"Small values of r_min (<0.5) can lead to problems in the construction of the radial basis. "
f"The provided value is r_min={r_min}.")
for r_max in r_maxs:
if not any(r_min < r_max for r_min in r_mins):
raise ValueError(f"All provided values of r_min are larger or equal to r_max={r_max}.")
Expand Down Expand Up @@ -940,6 +956,12 @@ def todict(self):

@classmethod
def fromdict(cls, dct: Dict[str, Any], **kwargs) -> Workflow:

# Remove __koopmans_name/module__ if present (won't happen if the encoder was used, but will happen if
# todict and fromdict are used directly)
dct.pop('__koopmans_name__', None)
dct.pop('__koopmans_module__', None)

wf = cls(atoms=dct.pop('atoms'),
parameters=dct.pop('parameters'),
calculator_parameters=dct.pop('calculator_parameters'),
Expand Down

0 comments on commit bbefdfc

Please sign in to comment.