Skip to content

Commit

Permalink
more type clarify
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Jun 12, 2024
1 parent c188030 commit e21c5ed
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions pymatgen/electronic_structure/cohp.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,6 @@ def as_dict(self) -> dict[str, Any]:
dct["COHP"] |= {label: {str(spin): pops.tolist() for spin, pops in self.all_cohps[label].cohp.items()}}
icohp = self.all_cohps[label].icohp
if icohp is not None:
# TODO: DanielYang59: merge two condition branches with "|=" operator?
if "ICOHP" not in dct:
dct["ICOHP"] = {label: {str(spin): pops.tolist() for spin, pops in icohp.items()}}
else:
Expand Down Expand Up @@ -467,15 +466,15 @@ def get_summed_cohp_by_label_list(
def get_summed_cohp_by_label_and_orbital_list(
self,
label_list: list[str],
orbital_list: list, # TODO (DanielYang): what is its type? Add custom type for it?
orbital_list: list[str],
divisor: float = 1,
summed_spin_channels: bool = False,
) -> Cohp:
"""Get a Cohp object that includes a summed COHP divided by divisor.
Args:
label_list (list[str]): Labels for the COHP that should be included.
orbital_list (list): Orbitals for the COHPs that should be included
orbital_list (list[str]): Orbitals for the COHPs that should be included
(same order as label_list).
divisor (float): The summed COHP will be divided by this divisor.
summed_spin_channels (bool): Sum the spin channels and return the sum in Spin.up.
Expand Down Expand Up @@ -1048,10 +1047,14 @@ def icohpvalue(self, spin: Spin = Spin.up) -> float:

return self._icohp[spin]

def icohpvalue_orbital(self, orbitals: list[Orbital] | str, spin: Spin = Spin.up) -> float:
def icohpvalue_orbital(
self,
orbitals: tuple[Orbital, Orbital] | str,
spin: Spin = Spin.up,
) -> float:
"""
Args:
orbitals (list[Orbitals]): List of Orbitals or "str(Orbital1)-str(Orbital2)".
orbitals: tuple[Orbital, Orbital] or "str(Orbital0)-str(Orbital1)".
spin (Spin): Spin.up or Spin.down.
Returns:
Expand All @@ -1060,7 +1063,7 @@ def icohpvalue_orbital(self, orbitals: list[Orbital] | str, spin: Spin = Spin.up
if not self.is_spin_polarized and spin == Spin.down:
raise ValueError("The calculation was not performed with spin polarization")

if isinstance(orbitals, list): # TODO: DanielYang: use tuple of 2
if isinstance(orbitals, (tuple, list)):
orbitals = f"{orbitals[0]}-{orbitals[1]}"

assert self._orbitals is not None
Expand Down Expand Up @@ -1177,7 +1180,7 @@ def get_icohp_by_label(
label: str,
summed_spin_channels: bool = True,
spin: Spin = Spin.up,
orbitals: str | list[Orbital] | None = None, # TODO: DanielYang: use tuple of 2
orbitals: str | tuple[Orbital, Orbital] | None = None,
) -> float:
"""Get an ICOHP value for a certain bond indicated by the label.
Expand All @@ -1198,7 +1201,7 @@ def get_icohp_by_label(
if orbitals is None:
return icohp.summed_icohp if summed_spin_channels else icohp.icohpvalue(spin)

if isinstance(orbitals, list):
if isinstance(orbitals, (tuple, list)):
orbitals = f"{orbitals[0]}-{orbitals[1]}"

if summed_spin_channels:
Expand Down

0 comments on commit e21c5ed

Please sign in to comment.