Skip to content

Commit

Permalink
fix unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Jun 19, 2024
1 parent acc9a1f commit d08f0c2
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions pymatgen/io/lobster/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import re
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import numpy as np
from monty.io import zopen
Expand All @@ -30,6 +30,7 @@
from pymatgen.io.vasp.inputs import Kpoints
from pymatgen.io.vasp.outputs import Vasprun, VolumetricData
from pymatgen.util.due import Doi, due
from pymatgen.util.typing import PathLike

if TYPE_CHECKING:
from typing import Any, ClassVar, Literal
Expand All @@ -38,7 +39,7 @@

from pymatgen.core.structure import IStructure
from pymatgen.electronic_structure.cohp import IcohpCollection
from pymatgen.util.typing import PathLike, Tuple3Ints, Vector3D
from pymatgen.util.typing import Tuple3Ints, Vector3D

__author__ = "Janine George, Marco Esters"
__copyright__ = "Copyright 2017, The Materials Project"
Expand Down Expand Up @@ -531,7 +532,7 @@ def icohplist(self) -> dict[Any, dict[str, Any]]:
return icohp_dict

@property
def icohpcollection(self) -> IcohpCollection:
def icohpcollection(self) -> IcohpCollection | None:
"""The IcohpCollection object."""
return self._icohpcollection

Expand Down Expand Up @@ -1079,7 +1080,7 @@ def _has_fatband(data: list[str]) -> bool:
return False

@staticmethod
def _get_dft_program(data: list[str]) -> str:
def _get_dft_program(data: list[str]) -> str | None:
for row in data:
splitrow = row.split()
if len(splitrow) > 4 and splitrow[3] == "program...":
Expand Down Expand Up @@ -1230,15 +1231,15 @@ class Fatband:

def __init__(
self,
filenames: str | list = ".",
filenames: PathLike | list[PathLike] = ".",
kpoints_file: PathLike = "KPOINTS",
vasprun_file: PathLike | None = "vasprun.xml",
structure: Structure | IStructure | None = None,
efermi: float | None = None,
) -> None:
"""
Args:
filenames (list or string): can be a list of file names or a path to a
filenames ( PathLike | list[PathLike]): File names or path to a
folder from which all "FATBAND_*" files will be read.
kpoints_file (PathLike): KPOINTS file for bandstructure calculation, typically "KPOINTS".
vasprun_file (PathLike): Corresponding vasprun.xml file. Instead, the
Expand Down Expand Up @@ -1272,7 +1273,7 @@ def __init__(
self.efermi = efermi
kpoints_object = Kpoints.from_file(kpoints_file)

atom_type = [] # TODO: DanielYang: this is not used?
# atom_type = [] # TODO: DanielYang: this is not used?
atom_names = []
orbital_names = []
parameters = []
Expand All @@ -1285,15 +1286,17 @@ def __init__(
if fnmatch.fnmatch(name, "FATBAND_*.lobster"):
filenames_new += [os.path.join(filenames, name)]
filenames = filenames_new
if len(filenames) == 0:

if len(cast(list[PathLike], filenames)) == 0:
raise ValueError("No FATBAND files in folder or given")

for name in filenames:
with zopen(name, mode="rt") as file:
contents = file.read().split("\n")

atom_names += [os.path.split(name)[1].split("_")[1].capitalize()]
parameters = contents[0].split()
atom_type += [re.split(r"[0-9]+", parameters[3])[0].capitalize()]
# atom_type += [re.split(r"[0-9]+", parameters[3])[0].capitalize()]
orbital_names += [parameters[4]]

# Get atomtype orbital dict
Expand Down Expand Up @@ -1451,8 +1454,7 @@ class Bandoverlaps(MSONable):
def __init__(
self,
filename: PathLike = "bandOverlaps.lobster",
band_overlaps_dict: dict[Any, dict]
| None = None, # Any is spin number 1 or -1 # TODO: DanielYang: use Literal type
band_overlaps_dict: dict[Spin, dict] | None = None,
max_deviation: list[float] | None = None,
) -> None:
"""
Expand Down Expand Up @@ -1481,12 +1483,13 @@ def __init__(
self._filename = filename
self._read(contents, spin_numbers)

def _read(self, contents: list[str], spin_numbers: list) -> None:
def _read(self, contents: list[str], spin_numbers: list[int]) -> None:
"""Read all contents of the file.
Args:
contents (list[str]): Contents of the file.
spin_numbers: list of spin numbers depending on LOBSTER version. # TODO: DanielYang: type for spin_numbers
# TODO: DanielYang: double-check spin_numbers type
spin_numbers (list[int]): Spin numbers depending on LOBSTER version.
"""
spin: Spin = Spin.up
kpoint_array: list = []
Expand Down Expand Up @@ -1698,6 +1701,7 @@ def _parse_file(
reals = []
imaginaries = []
splitline = lines[0].split()
grid: Tuple3Ints = (int(splitline[7]), int(splitline[8]), int(splitline[9]))

for line in lines[1:]:
splitline = line.split()
Expand All @@ -1708,7 +1712,6 @@ def _parse_file(
reals.append(float(splitline[4]))
imaginaries.append(float(splitline[5]))

grid: Tuple3Ints = (int(splitline[7]), int(splitline[8]), int(splitline[9]))
if len(reals) != grid[0] * grid[1] * grid[2] or len(imaginaries) != grid[0] * grid[1] * grid[2]:
raise ValueError("Something went wrong while reading the file")

Expand Down

0 comments on commit d08f0c2

Please sign in to comment.