Skip to content

Commit

Permalink
fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed May 16, 2024
1 parent 21947c9 commit c42b899
Showing 1 changed file with 29 additions and 30 deletions.
59 changes: 29 additions & 30 deletions pymatgen/core/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import warnings
from fnmatch import fnmatch
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import numpy as np
from monty.io import zopen
Expand All @@ -20,10 +20,11 @@

if TYPE_CHECKING:
from collections.abc import Iterator
from typing import Any

from typing_extensions import Self

from pymatgen.util.typing import Matrix3D, SitePropsType, Vector3D
from pymatgen.util.typing import Matrix3D, PathLike, SitePropsType, Vector3D


__author__ = "Eric Sivonxay, Shyam Dwaraknath, Mingjian Wen, Evan Spotte-Smith"
Expand Down Expand Up @@ -167,7 +168,10 @@ def __len__(self) -> int:
"""Number of frames in the trajectory."""
return len(self.coords)

def __getitem__(self, frames: int | slice | list[int]) -> Molecule | Structure | Self:
def __getitem__(
self,
frames: int | slice | list[int],
) -> Molecule | Structure | Self:
"""Get a subset of the trajectory.
The output depends on the type of the input `frames`. If an int is given, return
Expand All @@ -189,29 +193,24 @@ def __getitem__(self, frames: int | slice | list[int]) -> Molecule | Structure |
raise IndexError(f"Frame index {frames} out of range.")

if self.lattice is None:
charge = 0
if self.charge is not None:
charge = int(self.charge)

spin = None
if self.spin_multiplicity is not None:
spin = int(self.spin_multiplicity)
charge = int(self.charge) if self.charge is not None else 0
spin = None if self.spin_multiplicity is None else int(self.spin_multiplicity)

return Molecule(
self.species,
self.coords[frames],
charge=charge,
spin_multiplicity=spin,
site_properties=self._get_site_props(frames), # type: ignore
site_properties=self._get_site_props(frames), # type: ignore[arg-type]
)

lattice = self.lattice if self.constant_lattice else self.lattice[frames] # type: ignore
lattice = self.lattice if self.constant_lattice else self.lattice[frames]

return Structure(
Lattice(lattice),
self.species,
self.coords[frames],
site_properties=self._get_site_props(frames), # type: ignore
site_properties=self._get_site_props(frames), # type: ignore[arg-type]
to_unit_cell=True,
)

Expand Down Expand Up @@ -247,7 +246,7 @@ def __getitem__(self, frames: int | slice | list[int]) -> Molecule | Structure |
base_positions=self.base_positions,
)

lattice = self.lattice if self.constant_lattice else self.lattice[selected] # type: ignore
lattice = self.lattice if self.constant_lattice else self.lattice[selected]

return type(self)(
species=self.species,
Expand Down Expand Up @@ -398,7 +397,7 @@ def extend(self, trajectory: Trajectory) -> None:

def write_Xdatcar(
self,
filename: str | Path = "XDATCAR",
filename: PathLike = "XDATCAR",
system: str | None = None,
significant_figures: int = 6,
) -> None:
Expand All @@ -408,7 +407,7 @@ def write_Xdatcar(
Xdatcar_from_structs.get_str method and are passed through directly.
Args:
filename: Name of file to write. It's prudent to end the filename with
filename: File to write. It's prudent to end the filename with
'XDATCAR', as most visualization and analysis software require this
for autodetection.
system: Description of system (e.g. 2D MoS2).
Expand All @@ -435,7 +434,7 @@ def write_Xdatcar(
if idx == 0 or not self.constant_lattice:
lines.extend([system, "1.0"])

_lattice = self.lattice if self.constant_lattice else self.lattice[idx] # type: ignore
_lattice = self.lattice if self.constant_lattice else self.lattice[idx]

for latt_vec in _lattice:
lines.append(f'{" ".join(map(str, latt_vec))}')
Expand Down Expand Up @@ -498,15 +497,15 @@ def from_structures(
else:
lattice = np.array([structure.lattice.matrix for structure in structures])

species = structures[0].species
species: list[Element | Species] = structures[0].species
coords = [structure.frac_coords for structure in structures]
site_properties = [structure.site_properties for structure in structures]

return cls(
species=species, # type: ignore
species=species, # type: ignore[arg-type]
coords=coords,
lattice=lattice,
site_properties=site_properties, # type: ignore
site_properties=site_properties,
constant_lattice=constant_lattice,
**kwargs,
)
Expand All @@ -529,11 +528,11 @@ def from_molecules(cls, molecules: list[Molecule], **kwargs) -> Self:
site_properties = [mol.site_properties for mol in molecules]

return cls(
species=species, # type: ignore
species=species, # type: ignore[arg-type]
coords=coords,
charge=int(molecules[0].charge),
spin_multiplicity=int(molecules[0].spin_multiplicity),
site_properties=site_properties, # type: ignore
site_properties=site_properties,
**kwargs,
)

Expand Down Expand Up @@ -571,7 +570,7 @@ def from_file(
from ase.io.trajectory import Trajectory as AseTrajectory

ase_traj = AseTrajectory(filename)
# periodic boundary conditions should be the same for all frames so just check the first
# Periodic boundary conditions should be the same for all frames so just check the first
pbc = ase_traj[0].pbc
if any(pbc):
structures = [AseAtomsAdaptor.get_structure(atoms) for atoms in ase_traj]
Expand Down Expand Up @@ -635,18 +634,18 @@ def _combine_site_props(
assert prop1 is None or isinstance(prop1, (list, dict))
assert prop2 is None or isinstance(prop2, (list, dict))

p1_candidates = {
p1_candidates: dict[str, Any] = {
"NoneType": [None] * len1,
"dict": [prop1] * len1,
"list": prop1,
}
p2_candidates = {
p2_candidates: dict[str, Any] = {
"NoneType": [None] * len2,
"dict": [prop2] * len2,
"list": prop2,
}
p1_selected: list = p1_candidates[type(prop1).__name__] # type: ignore
p2_selected: list = p2_candidates[type(prop2).__name__] # type: ignore
p1_selected: list = p1_candidates[type(prop1).__name__]
p2_selected: list = p2_candidates[type(prop2).__name__]

return p1_selected + p2_selected

Expand All @@ -661,10 +660,10 @@ def _combine_frame_props(
if prop1 is prop2 is None:
return None
if prop1 is None:
return [None] * len1 + list(prop2) # type: ignore
return [None] * len1 + list(cast(list[dict], prop2))
if prop2 is None:
return list(prop1) + [None] * len2 # type: ignore
return list(prop1) + list(prop2) # type:ignore
return list(prop1) + [None] * len2
return list(prop1) + list(prop2)

def _check_site_props(self, site_props: SitePropsType | None) -> None:
"""Check data shape of site properties.
Expand Down

0 comments on commit c42b899

Please sign in to comment.