From e21c5ed367e343e7ca5c111d71e82c8dbf3aaa57 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel)" Date: Wed, 12 Jun 2024 22:01:05 +0800 Subject: [PATCH] more type clarify --- pymatgen/electronic_structure/cohp.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/pymatgen/electronic_structure/cohp.py b/pymatgen/electronic_structure/cohp.py index 4904ccfc4b9..559a022224b 100644 --- a/pymatgen/electronic_structure/cohp.py +++ b/pymatgen/electronic_structure/cohp.py @@ -331,7 +331,6 @@ def as_dict(self) -> dict[str, Any]: dct["COHP"] |= {label: {str(spin): pops.tolist() for spin, pops in self.all_cohps[label].cohp.items()}} icohp = self.all_cohps[label].icohp if icohp is not None: - # TODO: DanielYang59: merge two condition branches with "|=" operator? if "ICOHP" not in dct: dct["ICOHP"] = {label: {str(spin): pops.tolist() for spin, pops in icohp.items()}} else: @@ -467,7 +466,7 @@ def get_summed_cohp_by_label_list( def get_summed_cohp_by_label_and_orbital_list( self, label_list: list[str], - orbital_list: list, # TODO (DanielYang): what is its type? Add custom type for it? + orbital_list: list[str], divisor: float = 1, summed_spin_channels: bool = False, ) -> Cohp: @@ -475,7 +474,7 @@ def get_summed_cohp_by_label_and_orbital_list( Args: label_list (list[str]): Labels for the COHP that should be included. - orbital_list (list): Orbitals for the COHPs that should be included + orbital_list (list[str]): Orbitals for the COHPs that should be included (same order as label_list). divisor (float): The summed COHP will be divided by this divisor. summed_spin_channels (bool): Sum the spin channels and return the sum in Spin.up. @@ -1048,10 +1047,14 @@ def icohpvalue(self, spin: Spin = Spin.up) -> float: return self._icohp[spin] - def icohpvalue_orbital(self, orbitals: list[Orbital] | str, spin: Spin = Spin.up) -> float: + def icohpvalue_orbital( + self, + orbitals: tuple[Orbital, Orbital] | str, + spin: Spin = Spin.up, + ) -> float: """ Args: - orbitals (list[Orbitals]): List of Orbitals or "str(Orbital1)-str(Orbital2)". + orbitals: tuple[Orbital, Orbital] or "str(Orbital0)-str(Orbital1)". spin (Spin): Spin.up or Spin.down. Returns: @@ -1060,7 +1063,7 @@ def icohpvalue_orbital(self, orbitals: list[Orbital] | str, spin: Spin = Spin.up if not self.is_spin_polarized and spin == Spin.down: raise ValueError("The calculation was not performed with spin polarization") - if isinstance(orbitals, list): # TODO: DanielYang: use tuple of 2 + if isinstance(orbitals, (tuple, list)): orbitals = f"{orbitals[0]}-{orbitals[1]}" assert self._orbitals is not None @@ -1177,7 +1180,7 @@ def get_icohp_by_label( label: str, summed_spin_channels: bool = True, spin: Spin = Spin.up, - orbitals: str | list[Orbital] | None = None, # TODO: DanielYang: use tuple of 2 + orbitals: str | tuple[Orbital, Orbital] | None = None, ) -> float: """Get an ICOHP value for a certain bond indicated by the label. @@ -1198,7 +1201,7 @@ def get_icohp_by_label( if orbitals is None: return icohp.summed_icohp if summed_spin_channels else icohp.icohpvalue(spin) - if isinstance(orbitals, list): + if isinstance(orbitals, (tuple, list)): orbitals = f"{orbitals[0]}-{orbitals[1]}" if summed_spin_channels: