From c42b8999a770c95722305b518c6a9a66df252d85 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel)" Date: Thu, 16 May 2024 20:54:51 +0800 Subject: [PATCH] fix mypy errors --- pymatgen/core/trajectory.py | 59 ++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/pymatgen/core/trajectory.py b/pymatgen/core/trajectory.py index 38f1f104bdb..1e365142285 100644 --- a/pymatgen/core/trajectory.py +++ b/pymatgen/core/trajectory.py @@ -8,7 +8,7 @@ import warnings from fnmatch import fnmatch from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import numpy as np from monty.io import zopen @@ -20,10 +20,11 @@ if TYPE_CHECKING: from collections.abc import Iterator + from typing import Any from typing_extensions import Self - from pymatgen.util.typing import Matrix3D, SitePropsType, Vector3D + from pymatgen.util.typing import Matrix3D, PathLike, SitePropsType, Vector3D __author__ = "Eric Sivonxay, Shyam Dwaraknath, Mingjian Wen, Evan Spotte-Smith" @@ -167,7 +168,10 @@ def __len__(self) -> int: """Number of frames in the trajectory.""" return len(self.coords) - def __getitem__(self, frames: int | slice | list[int]) -> Molecule | Structure | Self: + def __getitem__( + self, + frames: int | slice | list[int], + ) -> Molecule | Structure | Self: """Get a subset of the trajectory. The output depends on the type of the input `frames`. If an int is given, return @@ -189,29 +193,24 @@ def __getitem__(self, frames: int | slice | list[int]) -> Molecule | Structure | raise IndexError(f"Frame index {frames} out of range.") if self.lattice is None: - charge = 0 - if self.charge is not None: - charge = int(self.charge) - - spin = None - if self.spin_multiplicity is not None: - spin = int(self.spin_multiplicity) + charge = int(self.charge) if self.charge is not None else 0 + spin = None if self.spin_multiplicity is None else int(self.spin_multiplicity) return Molecule( self.species, self.coords[frames], charge=charge, spin_multiplicity=spin, - site_properties=self._get_site_props(frames), # type: ignore + site_properties=self._get_site_props(frames), # type: ignore[arg-type] ) - lattice = self.lattice if self.constant_lattice else self.lattice[frames] # type: ignore + lattice = self.lattice if self.constant_lattice else self.lattice[frames] return Structure( Lattice(lattice), self.species, self.coords[frames], - site_properties=self._get_site_props(frames), # type: ignore + site_properties=self._get_site_props(frames), # type: ignore[arg-type] to_unit_cell=True, ) @@ -247,7 +246,7 @@ def __getitem__(self, frames: int | slice | list[int]) -> Molecule | Structure | base_positions=self.base_positions, ) - lattice = self.lattice if self.constant_lattice else self.lattice[selected] # type: ignore + lattice = self.lattice if self.constant_lattice else self.lattice[selected] return type(self)( species=self.species, @@ -398,7 +397,7 @@ def extend(self, trajectory: Trajectory) -> None: def write_Xdatcar( self, - filename: str | Path = "XDATCAR", + filename: PathLike = "XDATCAR", system: str | None = None, significant_figures: int = 6, ) -> None: @@ -408,7 +407,7 @@ def write_Xdatcar( Xdatcar_from_structs.get_str method and are passed through directly. Args: - filename: Name of file to write. It's prudent to end the filename with + filename: File to write. It's prudent to end the filename with 'XDATCAR', as most visualization and analysis software require this for autodetection. system: Description of system (e.g. 2D MoS2). @@ -435,7 +434,7 @@ def write_Xdatcar( if idx == 0 or not self.constant_lattice: lines.extend([system, "1.0"]) - _lattice = self.lattice if self.constant_lattice else self.lattice[idx] # type: ignore + _lattice = self.lattice if self.constant_lattice else self.lattice[idx] for latt_vec in _lattice: lines.append(f'{" ".join(map(str, latt_vec))}') @@ -498,15 +497,15 @@ def from_structures( else: lattice = np.array([structure.lattice.matrix for structure in structures]) - species = structures[0].species + species: list[Element | Species] = structures[0].species coords = [structure.frac_coords for structure in structures] site_properties = [structure.site_properties for structure in structures] return cls( - species=species, # type: ignore + species=species, # type: ignore[arg-type] coords=coords, lattice=lattice, - site_properties=site_properties, # type: ignore + site_properties=site_properties, constant_lattice=constant_lattice, **kwargs, ) @@ -529,11 +528,11 @@ def from_molecules(cls, molecules: list[Molecule], **kwargs) -> Self: site_properties = [mol.site_properties for mol in molecules] return cls( - species=species, # type: ignore + species=species, # type: ignore[arg-type] coords=coords, charge=int(molecules[0].charge), spin_multiplicity=int(molecules[0].spin_multiplicity), - site_properties=site_properties, # type: ignore + site_properties=site_properties, **kwargs, ) @@ -571,7 +570,7 @@ def from_file( from ase.io.trajectory import Trajectory as AseTrajectory ase_traj = AseTrajectory(filename) - # periodic boundary conditions should be the same for all frames so just check the first + # Periodic boundary conditions should be the same for all frames so just check the first pbc = ase_traj[0].pbc if any(pbc): structures = [AseAtomsAdaptor.get_structure(atoms) for atoms in ase_traj] @@ -635,18 +634,18 @@ def _combine_site_props( assert prop1 is None or isinstance(prop1, (list, dict)) assert prop2 is None or isinstance(prop2, (list, dict)) - p1_candidates = { + p1_candidates: dict[str, Any] = { "NoneType": [None] * len1, "dict": [prop1] * len1, "list": prop1, } - p2_candidates = { + p2_candidates: dict[str, Any] = { "NoneType": [None] * len2, "dict": [prop2] * len2, "list": prop2, } - p1_selected: list = p1_candidates[type(prop1).__name__] # type: ignore - p2_selected: list = p2_candidates[type(prop2).__name__] # type: ignore + p1_selected: list = p1_candidates[type(prop1).__name__] + p2_selected: list = p2_candidates[type(prop2).__name__] return p1_selected + p2_selected @@ -661,10 +660,10 @@ def _combine_frame_props( if prop1 is prop2 is None: return None if prop1 is None: - return [None] * len1 + list(prop2) # type: ignore + return [None] * len1 + list(cast(list[dict], prop2)) if prop2 is None: - return list(prop1) + [None] * len2 # type: ignore - return list(prop1) + list(prop2) # type:ignore + return list(prop1) + [None] * len2 + return list(prop1) + list(prop2) def _check_site_props(self, site_props: SitePropsType | None) -> None: """Check data shape of site properties.