Skip to content

Commit

Permalink
missing updated from master
Browse files Browse the repository at this point in the history
  • Loading branch information
benrich37 committed Jan 23, 2025
1 parent e0405b4 commit ea78047
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 13 deletions.
98 changes: 97 additions & 1 deletion src/pymatgen/io/jdftx/_output_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def _is_complex_bandfile_filepath(bandfile_filepath: str | Path) -> bool:
]


def _get_atom_orb_labels_dict(bandfile_filepath: Path) -> dict[str, list[str]]:
def _get_atom_orb_labels_ref_dict(bandfile_filepath: Path) -> dict[str, list[str]]:
"""
Return a dictionary mapping each atom symbol to all atomic orbital projection string representations.
Expand Down Expand Up @@ -546,3 +546,99 @@ def _get_atom_orb_labels_dict(bandfile_filepath: Path) -> dict[str, list[str]]:
else:
labels_dict[sym] += mls
return labels_dict


def _get_atom_count_list(bandfile_filepath: Path) -> list[tuple[str, int]]:
    """
    Return a list of tuples of atom symbols and counts.

    A list of tuples is used instead of a dictionary so that the order in which the atoms appear in the bandfile
    is maintained.

    Args:
        bandfile_filepath (str | Path): The path to the bandfile.

    Returns:
        list[tuple[str, int]]: A list of (symbol, count) tuples, one per parsed line, in file order.

    Raises:
        IndexError: If a parsed line has fewer than two whitespace-separated fields.
        ValueError: If the second field of a parsed line is not an integer.
    """
    bandfile = read_file(bandfile_filepath)
    atom_count_list = []
    for i, line in enumerate(bandfile):
        # The first two lines are never parsed as atom entries.
        if i <= 1:
            continue
        # The first subsequent line containing "#" marks the end of the atom entries.
        if "#" in line:
            break
        # str.split() with no arguments already discards surrounding whitespace.
        fields = line.split()
        atom_count_list.append((fields[0], int(fields[1])))
    return atom_count_list


def _get_orb_label_list_expected_len(labels_dict: dict[str, list[str]], atom_count_list: list[tuple[str, int]]) -> int:
"""
Return the expected length of the atomic orbital projection string representation list.
Return the expected length of the atomic orbital projection string representation list.
Args:
labels_dict (dict[str, list[str]]): A dictionary mapping each atom symbol to all atomic orbital projection
string representations.
atom_count_list (list[tuple[str, int]]): A list of tuples of atom symbols and counts.
Returns:
int: The expected length of the atomic orbital projection string representation list.
"""
expected_len = 0
for ion_tuple in atom_count_list:
ion = ion_tuple[0]
count = ion_tuple[1]
orbs = labels_dict[ion]
expected_len += count * len(orbs)
return expected_len


def _get_orb_label(ion: str, idx: int, orb: str) -> str:
"""
Return the string representation for an orbital projection.
Return the string representation for an orbital projection.
Args:
ion (str): The symbol of the atom.
idx (int): The index of the atom.
orb (str): The atomic orbital projection string representation.
Returns:
str: The atomic orbital projection string representation for the atom.
"""
return f"{ion}#{idx + 1}({orb})"


def _get_orb_label_list(bandfile_filepath: Path) -> tuple[str, ...]:
    """
    Return a tuple of all atomic orbital projection string representations.

    Labels are generated atom type by atom type in the order given by _get_atom_count_list; within an atom type
    they are generated atom by atom, and within an atom in the order of that type's orbital labels.

    Args:
        bandfile_filepath (str | Path): The path to the bandfile.

    Returns:
        tuple[str, ...]: A tuple of all atomic orbital projection string representations.

    Raises:
        RuntimeError: If the number of generated labels does not match the expected length.
    """
    labels_dict = _get_atom_orb_labels_ref_dict(bandfile_filepath)
    atom_count_list = _get_atom_count_list(bandfile_filepath)
    # NOTE: both helpers above are handed the file path themselves, so no additional read of the bandfile is
    # performed here.
    labels_list: list[str] = []
    for ion, count in atom_count_list:
        orbs = labels_dict[ion]
        for i in range(count):
            labels_list.extend(_get_orb_label(ion, i, orb) for orb in orbs)
    # This is most likely unnecessary, but it is a good check to have.
    if len(labels_list) != _get_orb_label_list_expected_len(labels_dict, atom_count_list):
        raise RuntimeError("Number of atomic orbital projections does not match expected length.")
    return tuple(labels_list)
1 change: 1 addition & 0 deletions src/pymatgen/io/jdftx/joutstructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def _from_text_slice(
instance = cls(
lattice=init_structure.lattice.matrix,
species=init_structure.species,
coords_are_cartesian=True,
coords=init_structure.cart_coords,
site_properties=init_structure.site_properties,
)
Expand Down
37 changes: 29 additions & 8 deletions src/pymatgen/io/jdftx/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class is written.
import numpy as np

from pymatgen.io.jdftx._output_utils import (
_get_atom_orb_labels_dict,
_get_nbands_from_bandfile_filepath,
_get_orb_label_list,
get_proj_tju_from_file,
read_outfile_slices,
)
Expand Down Expand Up @@ -90,19 +90,34 @@ class JDFTXOutputs:
(nspin, nkpt, nbands, nion, nionproj) to save on memory as nonionproj is different depending on the ion
type. This array may also be complex if specified in 'band-projections-params' in the JDFTx input, allowing
for pCOHP analysis.
eigenvals (np.ndarray): The eigenvalues. Stored in shape (nstates, nbands) where nstates is nspin*nkpts (nkpts
may not equal prod(kfolding) if symmetry reduction occurred) and nbands is the number of bands.
orb_label_list (tuple[str, ...]): A tuple of the orbital labels for the bandProjections file, where the i'th
element describes the i'th orbital. Orbital labels are formatted as "<ion>#<ion-number>(<orbital>)",
where <ion> is the element symbol of the ion, <ion-number> is the 1-based index of the ion-type in the
structure (ie C#2 would be the second carbon atom, but not necessarily the second ion in the structure),
and <orbital> is a string describing "l" and "ml" quantum numbers (ie "p_x" or "d_yz"). Note that while "z"
corresponds to the "z" axis, "x" and "y" are arbitrary and may not correspond to the actual x and y axes of
the structure. In the case where multiple shells of a given "l" are available within the projections, a
0-based index will appear mimicking a principal quantum number (ie "0px" for first shell and "1px" for
second shell). The actual principal quantum number is not stored in the JDFTx output files and must be
inferred by the user.
"""

calc_dir: str | Path = field(init=True)
outfile_name: str | Path | None = field(init=True)
store_vars: list[str] = field(default_factory=list, init=True)
paths: dict[str, Path] = field(init=False)
outfile: JDFTXOutfile = field(init=False)
bandProjections: np.ndarray | None = field(init=False)
eigenvals: np.ndarray | None = field(init=False)
# Misc metadata for interacting with the data
atom_orb_labels_dict: dict[int, str] | None = field(init=False)
orb_label_list: tuple[str, ...] | None = field(init=False)

@classmethod
def from_calc_dir(cls, calc_dir: str | Path, store_vars: list[str] | None = None) -> JDFTXOutputs:
def from_calc_dir(
cls, calc_dir: str | Path, store_vars: list[str] | None = None, outfile_name: str | Path | None = None
) -> JDFTXOutputs:
"""
Create a JDFTXOutputs object from a directory containing JDFTx out files.
Expand All @@ -113,12 +128,15 @@ def from_calc_dir(cls, calc_dir: str | Path, store_vars: list[str] | None = None
none_slice_on_error (bool): If True, will return None if an error occurs while parsing a slice instead of
halting the parsing process. This can be useful for parsing files with multiple slices where some slices
may be incomplete or corrupted.
outfile_name (str | Path): The name of the outfile to use. If None, will search for the outfile in the
calc_dir. If provided, will concatenate with calc_dir as the outfile path. Use this if the calc_dir
contains multiple files that may be mistaken for the outfile (ie multiple files with the '.out' suffix).
Returns:
JDFTXOutputs: The JDFTXOutputs object.
"""
if store_vars is None:
store_vars = []
return cls(calc_dir=Path(calc_dir), store_vars=store_vars)
return cls(calc_dir=Path(calc_dir), store_vars=store_vars, outfile_name=outfile_name)

def __post_init__(self):
self._init_paths()
Expand All @@ -128,7 +146,12 @@ def _init_paths(self):
self.paths = {}
if self.calc_dir is None:
raise ValueError("calc_dir must be set as not None before initializing.")
outfile_path = _find_jdftx_out_file(self.calc_dir)
if self.outfile_name is None:
outfile_path = _find_jdftx_out_file(self.calc_dir)
else:
outfile_path = self.calc_dir / self.outfile_name
if not outfile_path.exists():
raise FileNotFoundError(f"Provided outfile path {outfile_path} does not exist.")
self.outfile = JDFTXOutfile.from_file(outfile_path)
prefix = self.outfile.prefix
for fname in dump_file_names:
Expand Down Expand Up @@ -167,15 +190,13 @@ def _check_bandProjections(self):
def _store_bandProjections(self):
if "bandProjections" in self.paths:
self.bandProjections = get_proj_tju_from_file(self.paths["bandProjections"])
self.atom_orb_labels_dict = _get_atom_orb_labels_dict(self.paths["bandProjections"])
self.orb_label_list = _get_orb_label_list(self.paths["bandProjections"])

def _check_eigenvals(self):
"""Check for misaligned data within eigenvals file."""
if "eigenvals" in self.paths:
if not self.paths["eigenvals"].exists():
raise RuntimeError("Allocated path for eigenvals does not exist.")
# TODO: We should not have to load the entire file to find its length - replace with something more
# efficient once Claude lets me create an account.
tj = len(np.fromfile(self.paths["eigenvals"]))
nstates_float = tj / self.outfile.nbands
if not np.isclose(nstates_float, int(nstates_float)):
Expand Down
8 changes: 4 additions & 4 deletions tests/io/jdftx/outputs_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,9 @@ def jdftxoutfile_matches_known(joutfile: JDFTXOutfile, known: dict):
"eigenvals": n2_ex_calc_dir / Path("eigenvals"),
}
n2_ex_calc_dir_bandprojections_metadata = {
"atom_orb_labels_dict": {
"N": ["s", "px", "py", "pz"],
},
"orb_label_list": ["N#1(s)", "N#1(px)", "N#1(py)", "N#1(pz)", "N#2(s)", "N#2(px)", "N#2(py)", "N#2(pz)"],
"shape": (54, 15, 8),
"first val": -0.1331527 + 0.5655596j,
}


Expand All @@ -150,8 +149,9 @@ def jdftxoutfile_matches_known(joutfile: JDFTXOutfile, known: dict):
"eigenvals": nh3_ex_calc_dir / Path("eigenvals"),
}
nh3_ex_calc_dir_bandprojections_metadata = {
"atom_orb_labels_dict": {"N": ["s", "px", "py", "pz"], "H": ["s"]},
"orb_label_list": ["N#1(s)", "N#1(px)", "N#1(py)", "N#1(pz)", "H#1(s)", "H#2(s)", "H#3(s)"],
"shape": (16, 14, 7),
"first val": -0.0688767 + 0.9503786j,
}

example_sp_outfile_path = ex_out_files_dir / Path("example_sp.out")
Expand Down

0 comments on commit ea78047

Please sign in to comment.