From d0ad785f9b4e9a57717ef2c56eef19ceb8e90c25 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel)" Date: Sat, 25 May 2024 15:54:32 +0800 Subject: [PATCH] Improve type annotations for `core.structure` (#3837) * remove a lot of ignore tags * first go: quick look and comment/type tweaks * remove ALL type: ignore[reportPossiblyUnboundVariable] * a quick look * tweak module docstring * fix/supress PossiblyUnboundVariable * pre-commit fix * add debug tag * fix some mypy errors * fix some mypy errors * suppress errors in `__setitem__` * fix typos * fix union * revert change to md * tweaks --------- Co-authored-by: Janosh Riebesell --- .github/code_of_conduct.md | 2 +- .../analysis/chemenv/utils/scripts_utils.py | 3 +- pymatgen/analysis/local_env.py | 2 +- pymatgen/analysis/solar/__init__.py | 2 +- pymatgen/analysis/topological/__init__.py | 2 +- pymatgen/core/interface.py | 2 +- pymatgen/core/structure.py | 706 ++++++++++-------- pymatgen/electronic_structure/boltztrap.py | 16 +- pymatgen/io/cif.py | 2 +- pymatgen/io/gaussian.py | 2 +- pymatgen/io/vasp/optics.py | 2 +- 11 files changed, 422 insertions(+), 319 deletions(-) diff --git a/.github/code_of_conduct.md b/.github/code_of_conduct.md index 32b7b6728de..80a380b8ba9 100644 --- a/.github/code_of_conduct.md +++ b/.github/code_of_conduct.md @@ -6,7 +6,7 @@ In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, -level of experience, education, socio-economic status, nationality, personal +level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation. ## Our Standards diff --git a/pymatgen/analysis/chemenv/utils/scripts_utils.py b/pymatgen/analysis/chemenv/utils/scripts_utils.py index 9ffbc644913..087a42b2402 100644 --- a/pymatgen/analysis/chemenv/utils/scripts_utils.py +++ b/pymatgen/analysis/chemenv/utils/scripts_utils.py @@ -213,6 +213,7 @@ def compute_environments(chemenv_configuration): default_strategy.setup_options(chemenv_configuration.package_options["default_strategy"]["strategy_options"]) max_dist_factor = chemenv_configuration.package_options["default_max_distance_factor"] first_time = True + test = None while True: if len(questions) > 1: found = False @@ -240,7 +241,7 @@ def compute_environments(chemenv_configuration): input_source = "" if found and len(questions) > 1: - input_source = test # type: ignore[reportPossiblyUnboundVariable] + input_source = test structure = None if source_type == "cif": diff --git a/pymatgen/analysis/local_env.py b/pymatgen/analysis/local_env.py index 42004e76300..f39da43ed31 100644 --- a/pymatgen/analysis/local_env.py +++ b/pymatgen/analysis/local_env.py @@ -828,7 +828,7 @@ def get_all_voronoi_polyhedra(self, structure: Structure): del indices # Save memory (tessellations can be costly) # Run the tessellation - qvoronoi_input = [s.coords for s in sites] + qvoronoi_input = [s.coords for s in sites if s is not None] voro = Voronoi(qvoronoi_input) # Get the information for each neighbor diff --git a/pymatgen/analysis/solar/__init__.py b/pymatgen/analysis/solar/__init__.py index 72fd29bfa40..cb951a7030f 100644 --- a/pymatgen/analysis/solar/__init__.py +++ b/pymatgen/analysis/solar/__init__.py @@ -1 +1 @@ -"""Modules for prediciting theoretical solar-cell efficiency.""" +"""Module for predicting theoretical solar-cell efficiency.""" diff --git a/pymatgen/analysis/topological/__init__.py b/pymatgen/analysis/topological/__init__.py index 579107d2c1b..f424015e1f6 100644 --- a/pymatgen/analysis/topological/__init__.py +++ b/pymatgen/analysis/topological/__init__.py @@ -1 +1 @@ -"""Modules for prediciting topological properties.""" +"""Module for predicting topological properties.""" diff --git a/pymatgen/core/interface.py b/pymatgen/core/interface.py index 84decc4f108..35e3cfa94be 100644 --- a/pymatgen/core/interface.py +++ b/pymatgen/core/interface.py @@ -761,7 +761,7 @@ def gb_from_parameters( sites_away_gb.append(site) if len(sites_near_gb) >= 1: s_near_gb = Structure.from_sites(sites_near_gb) - s_near_gb.merge_sites(tol=bond_length * rm_ratio, mode="d") + s_near_gb.merge_sites(tol=bond_length * rm_ratio, mode="delete") all_sites = sites_away_gb + s_near_gb.sites # type: ignore gb_with_vac = Structure.from_sites(all_sites) diff --git a/pymatgen/core/structure.py b/pymatgen/core/structure.py index bbd40a8b2f8..7d49026d8cd 100644 --- a/pymatgen/core/structure.py +++ b/pymatgen/core/structure.py @@ -1,5 +1,6 @@ -"""This module provides classes used to define a non-periodic molecule and a -periodic structure. +"""This module provides classes to define non-periodic Molecule +and periodic Structure, along with their immutable counterparts +IMolecule and IStructure. """ from __future__ import annotations @@ -19,10 +20,10 @@ import warnings from abc import ABC, abstractmethod from collections import defaultdict +from collections.abc import MutableSequence from fnmatch import fnmatch -from inspect import isclass from io import StringIO -from typing import TYPE_CHECKING, Literal, cast, get_args +from typing import TYPE_CHECKING, Literal, Union, cast, get_args import numpy as np from monty.dev import deprecated @@ -35,6 +36,7 @@ from scipy.linalg import expm, polar from scipy.spatial.distance import squareform from tabulate import tabulate +from typing_extensions import Self from pymatgen.core.bonds import CovalentBond, get_bond_length from pymatgen.core.composition import Composition @@ -49,18 +51,17 @@ if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Sequence - from pathlib import Path from typing import Any, Callable, SupportsIndex + import pandas as pd from ase import Atoms from ase.calculators.calculator import Calculator from ase.io.trajectory import Trajectory from ase.optimize.optimize import Optimizer from matgl.ext.ase import TrajectoryObserver - from numpy.typing import ArrayLike - from typing_extensions import Self + from numpy.typing import ArrayLike, NDArray - from pymatgen.util.typing import CompositionLike, SpeciesLike + from pymatgen.util.typing import CompositionLike, MillerIndex, PathLike, PbcLike, SpeciesLike FileFormats = Literal["cif", "poscar", "cssr", "json", "yaml", "yml", "xsf", "mcsqs", "res", "pwmat", ""] @@ -77,7 +78,7 @@ class Neighbor(Site): def __init__( self, species: Composition, - coords: np.ndarray, + coords: NDArray, properties: dict | None = None, nn_distance: float = 0.0, index: int = 0, @@ -92,11 +93,11 @@ def __init__( index: Index within structure. label: Label for the site. Defaults to None. """ - self.coords = coords - self._species = species - self.properties = properties or {} - self.nn_distance = nn_distance - self.index = index + self._species: Composition = species + self.coords: NDArray = coords + self.properties: dict = properties or {} + self.nn_distance: float = nn_distance + self.index: int = index self._label = label def __len__(self) -> Literal[3]: @@ -112,7 +113,7 @@ def as_dict(self) -> dict: return super(Site, self).as_dict() @classmethod - def from_dict(cls, dct: dict) -> Self: + def from_dict(cls, dct: dict) -> Site: """Get a Neighbor from a dict. Args: @@ -127,18 +128,17 @@ def from_dict(cls, dct: dict) -> Self: class PeriodicNeighbor(PeriodicSite): """Simple PeriodicSite subclass to contain a neighboring atom that skips all the unnecessary checks for speed. Can be used as a fixed-length tuple of - size 4 to retain backwards compatibility with past use cases. - + size 4 to retain backwards compatibility with past use cases: (site, distance, index, image). - In future, usage should be to call attributes, e.g. PeriodicNeighbor.index, + Should access attributes in the future, e.g. PeriodicNeighbor.index, PeriodicNeighbor.distance, etc. """ def __init__( self, species: Composition, - coords: np.ndarray, + coords: NDArray, lattice: Lattice, properties: dict | None = None, nn_distance: float = 0.0, @@ -166,7 +166,7 @@ def __init__( self.image = image self._label = label - def __len__(self) -> int: + def __len__(self) -> Literal[4]: """Make neighbor Tuple-like to retain backwards compatibility.""" return 4 @@ -174,8 +174,8 @@ def __getitem__(self, idx: int | slice): # type: ignore[override] """Make neighbor Tuple-like to retain backwards compatibility.""" return (self, self.nn_distance, self.index, self.image)[idx] - @property # type: ignore - def coords(self) -> np.ndarray: # type: ignore + @property # type: ignore[misc] + def coords(self) -> NDArray: """Cartesian coords.""" return self._lattice.get_cartesian_coords(self._frac_coords) @@ -203,38 +203,38 @@ class SiteCollection(collections.abc.Sequence, ABC): periodicity). Not meant to be instantiated directly. """ - # Tolerance in Angstrom for determining if sites are too close. + # Tolerance in Angstrom for determining if sites are too close DISTANCE_TOLERANCE = 0.5 _properties: dict def __contains__(self, site: object) -> bool: return site in self.sites - def __iter__(self) -> Iterator[Site]: + def __iter__(self) -> Iterator[PeriodicSite]: return iter(self.sites) - # TODO return type needs fixing (can be list[Site] but raises lots of mypy errors) - def __getitem__(self, ind: int | slice) -> Site: + # TODO return type needs fixing (can be Sequence[PeriodicSite] but raises lots of mypy errors) + def __getitem__(self, ind: int | slice) -> PeriodicSite: return self.sites[ind] # type: ignore[return-value] def __len__(self) -> int: return len(self.sites) def __hash__(self) -> int: - # for now, just use the composition hash code. + """Use the composition hash for now.""" return hash(self.composition) @property - def sites(self) -> list[Site]: - """An iterator for the sites in the Structure.""" - return self._sites # type: ignore[has-type] + def sites(self) -> list[PeriodicSite] | tuple[PeriodicSite, ...]: + """The sites in the Structure.""" + return self._sites @sites.setter def sites(self, sites: Sequence[PeriodicSite]) -> None: """Set the sites in the Structure.""" - # if self is mutable Structure or Molecule, set _sites as list - is_mutable = isinstance(self._sites, list) # type: ignore[has-type] - self._sites = list(sites) if is_mutable else tuple(sites) + # If self is mutable Structure or Molecule, set _sites as list + is_mutable = isinstance(self._sites, MutableSequence) + self._sites: list[PeriodicSite] | tuple[PeriodicSite, ...] = list(sites) if is_mutable else tuple(sites) @abstractmethod def copy(self) -> Self: @@ -284,7 +284,7 @@ def species_and_occu(self) -> list[Composition]: return [site.species for site in self] @property - @deprecated(message="Use n_type_sp instead.") + @deprecated(message="Use n_type_sp instead") def ntypesp(self) -> int: """Number of types of atoms.""" return len(self.types_of_species) @@ -295,19 +295,20 @@ def n_elems(self) -> int: return len(self.types_of_species) @property - def types_of_species(self) -> tuple[Element | Species | DummySpecies]: - """List of types of specie.""" - # Cannot use set since we want a deterministic algorithm. + def types_of_species(self) -> tuple[Element | Species | DummySpecies, ...]: + """Tuple of types of species.""" types: list[Element | Species | DummySpecies] = [] for site in self: for sp, amt in site.species.items(): if amt != 0: types.append(sp) - return tuple(sorted(set(types))) # type: ignore + + # Cannot use set since we want a deterministic algorithm + return cast(tuple[Union[Element, Species, DummySpecies], ...], tuple(sorted(set(types)))) @property - def types_of_specie(self) -> tuple[Element | Species | DummySpecies]: - """Specie->Species rename. Maintained for backwards compatibility.""" + def types_of_specie(self) -> tuple[Element | Species | DummySpecies, ...]: + """Specie -> Species rename, to maintain backwards compatibility.""" return self.types_of_species def group_by_types(self) -> Iterator[Site | PeriodicSite]: @@ -332,7 +333,7 @@ def symbol_set(self) -> tuple[str, ...]: @property def atomic_numbers(self) -> tuple[int, ...]: - """List of atomic numbers.""" + """Tuple of atomic numbers.""" try: return tuple(site.specie.Z for site in self) except AttributeError: @@ -350,7 +351,7 @@ def site_properties(self) -> dict[str, Sequence]: return {key: [site.properties.get(key) for site in self] for key in prop_keys} @property - def labels(self) -> list[str]: + def labels(self) -> list[str | None]: """Site labels as a list.""" return [site.label for site in self] @@ -503,16 +504,16 @@ def to_file(self, filename: str = "", fmt: FileFormats = "") -> str | None: @classmethod @abstractmethod def from_str(cls, input_string: str, fmt: Any) -> None: - """Reads in SiteCollection from a string.""" + """Read in SiteCollection from a string.""" raise NotImplementedError @classmethod @abstractmethod def from_file(cls, filename: str) -> None: - """Reads in SiteCollection from a filename.""" + """Read in SiteCollection from a filename.""" raise NotImplementedError - def add_site_property(self, property_name: str, values: Sequence | np.ndarray) -> SiteCollection: + def add_site_property(self, property_name: str, values: Sequence | np.ndarray) -> Self: """Add a property to a site. Note: This is the preferred method for adding magnetic moments, selective dynamics, and related site-specific properties to a structure/molecule object. @@ -539,7 +540,7 @@ def add_site_property(self, property_name: str, values: Sequence | np.ndarray) - return self - def remove_site_property(self, property_name: str) -> SiteCollection: + def remove_site_property(self, property_name: str) -> Self: """Removes a property to a site. Args: @@ -554,11 +555,13 @@ def remove_site_property(self, property_name: str) -> SiteCollection: return self def replace_species( - self, species_mapping: dict[SpeciesLike, SpeciesLike | dict[SpeciesLike, float]], in_place: bool = True - ) -> SiteCollection: - """Swap species. + self, + species_mapping: dict[SpeciesLike, SpeciesLike | dict[SpeciesLike, float]], + in_place: bool = True, + ) -> Self: + """Replace species. - Note that this clears the label of any affected site. + Note that this resets the label of any affected site to species_string. Args: species_mapping (dict): Species to swap. Species can be elements too. e.g. @@ -595,7 +598,7 @@ def replace_species( return site_coll - def add_oxidation_state_by_element(self, oxidation_states: dict[str, float]) -> SiteCollection: + def add_oxidation_state_by_element(self, oxidation_states: dict[str, float]) -> Self: """Add oxidation states. Args: @@ -618,7 +621,7 @@ def add_oxidation_state_by_element(self, oxidation_states: dict[str, float]) -> return self - def add_oxidation_state_by_site(self, oxidation_states: list[float]) -> SiteCollection: + def add_oxidation_state_by_site(self, oxidation_states: list[float]) -> Self: """Add oxidation states to a structure by site. Args: @@ -645,7 +648,7 @@ def add_oxidation_state_by_site(self, oxidation_states: list[float]) -> SiteColl return self - def remove_oxidation_states(self) -> SiteCollection: + def remove_oxidation_states(self) -> Self: """Removes oxidation states from a structure.""" for site in self: new_sp: dict[Element, float] = defaultdict(float) @@ -656,7 +659,7 @@ def remove_oxidation_states(self) -> SiteCollection: return self - def add_oxidation_state_by_guess(self, **kwargs) -> SiteCollection: + def add_oxidation_state_by_guess(self, **kwargs) -> Self: """Decorates the structure with oxidation state, guessing using Composition.oxi_state_guesses(). If multiple guesses are found we take the first one. @@ -670,7 +673,7 @@ def add_oxidation_state_by_guess(self, **kwargs) -> SiteCollection: return self - def add_spin_by_element(self, spins: dict[str, float]) -> SiteCollection: + def add_spin_by_element(self, spins: dict[str, float]) -> Self: """Add spin states to structure. Args: @@ -688,7 +691,7 @@ def add_spin_by_element(self, spins: dict[str, float]) -> SiteCollection: return self - def add_spin_by_site(self, spins: Sequence[float]) -> SiteCollection: + def add_spin_by_site(self, spins: Sequence[float]) -> Self: """Add spin states to structure by site. Args: @@ -707,7 +710,7 @@ def add_spin_by_site(self, spins: Sequence[float]) -> SiteCollection: return self - def remove_spin(self) -> SiteCollection: + def remove_spin(self) -> Self: """Remove spin states from structure.""" for site in self: new_sp: dict[Element, float] = defaultdict(float) @@ -719,7 +722,7 @@ def remove_spin(self) -> SiteCollection: return self def extract_cluster(self, target_sites: list[Site], **kwargs) -> list[Site]: - """Extracts a cluster of atoms based on bond lengths. + """Extract a cluster of atoms based on bond lengths. Args: target_sites (list[Site]): Initial sites from which to nucleate cluster. @@ -760,7 +763,7 @@ def _calculate(self, calculator: str | Calculator, verbose: bool = False) -> Cal from pymatgen.io.ase import AseAtomsAdaptor if isinstance(self, Molecule) and isinstance(calculator, str) and calculator.lower() in ("chgnet", "m3gnet"): - raise ValueError(f"Can't use {calculator=} for a Molecule.") + raise ValueError(f"Can't use {calculator=} for a Molecule") calculator = self._prep_calculator(calculator) # Get Atoms object @@ -792,8 +795,8 @@ def _relax( """Perform a structure relaxation using an ASE calculator. Args: - calculator (str | ase.Calculator): An ASE Calculator or a string from the following options: "M3GNet", - "gfn2-xtb". + calculator (str | ase.Calculator): An ASE Calculator or a string + from the following options: "M3GNet", "gfn2-xtb". relax_cell (bool): whether to relax the lattice cell. Defaults to True. optimizer (str): name of the ASE optimizer class to use steps (int): max number of steps for relaxation. Defaults to 500. @@ -824,9 +827,9 @@ def _relax( calc_params = {} if is_molecule else dict(stress_weight=stress_weight) calculator = self._prep_calculator(calculator, **calc_params) - # check str is valid optimizer key + # Check str is valid optimizer key def is_ase_optimizer(key): - return isclass(obj := getattr(optimize, key)) and issubclass(obj, Optimizer) + return inspect.isclass(obj := getattr(optimize, key)) and issubclass(obj, Optimizer) valid_keys = [key for key in dir(optimize) if is_ase_optimizer(key)] if isinstance(optimizer, str): @@ -840,7 +843,8 @@ def is_ase_optimizer(key): adaptor = AseAtomsAdaptor() atoms = adaptor.get_atoms(self) - # Use a TrajectoryObserver if running M3GNet or CHGNet; otherwise, write a .traj file + # Use a TrajectoryObserver if running M3GNet or CHGNet. + # Otherwise, write a .traj file if return_trajectory: if run_uip: from matgl.ext.ase import TrajectoryObserver @@ -857,7 +861,7 @@ def is_ase_optimizer(key): with contextlib.redirect_stdout(stream): if relax_cell: if is_molecule: - raise ValueError("Can't relax cell for a Molecule.") + raise ValueError("Can't relax cell for a Molecule") ecf = ExpCellFilter(atoms) dyn = opt_class(ecf, **opt_kwargs) else: @@ -886,8 +890,7 @@ def _prep_calculator(self, calculator: Literal["m3gnet", "gfn2-xtb"] | Calculato """Convert string name of special ASE calculators into ASE calculator objects. Args: - calculator: An ASE Calculator or a string from the following options: "m3gnet", - "gfn2-xtb". + calculator: An ASE Calculator or a string from the following options: "m3gnet", "gfn2-xtb". **params: Parameters for the calculator. Returns: @@ -1037,7 +1040,7 @@ def __init__( self._lattice, to_unit_cell, coords_are_cartesian=coords_are_cartesian, - properties=prop, # type: ignore + properties=prop, label=label, ) sites.append(site) @@ -1050,10 +1053,11 @@ def __init__( def __eq__(self, other: object) -> bool: needed_attrs = ("lattice", "sites", "properties") + # Return NotImplemented as in https://docs.python.org/3/library/functools.html#functools.total_ordering if not all(hasattr(other, attr) for attr in needed_attrs): - # return NotImplemented as in https://docs.python.org/3/library/functools.html#functools.total_ordering return NotImplemented + # TODO (DanielYang59): fix below type other = cast(Structure, other) # make mypy happy if other is self: @@ -1067,11 +1071,11 @@ def __eq__(self, other: object) -> bool: return all(site in other for site in self) def __hash__(self) -> int: - # For now, just use the composition hash code. + """Use the composition hash for now.""" return hash(self.composition) def __mul__(self, scaling_matrix: int | Sequence[int] | Sequence[Sequence[int]]) -> Structure: - """Make a supercell. Allowing to have sites outside the unit cell. + """Make a supercell. Allow sites outside the unit cell. Args: scaling_matrix: A scaling matrix for transforming the lattice @@ -1129,8 +1133,7 @@ def __repr__(self) -> str: outs = ["Structure Summary", repr(self.lattice)] if self._charge: outs.append(f"Overall Charge: {self._charge:+}") - for site in self: - outs.append(repr(site)) + outs.extend(map(repr, self)) return "\n".join(outs) def __str__(self) -> str: @@ -1173,8 +1176,8 @@ def from_sites( validate_proximity: bool = False, to_unit_cell: bool = False, properties: dict | None = None, - ) -> IStructure: - """Convenience constructor to make a Structure from a list of sites. + ) -> Self: + """Convenience constructor to make a IStructure from a list of sites. Args: sites: Sequence of PeriodicSites. Sites must have the same @@ -1234,7 +1237,7 @@ def from_spacegroup( coords_are_cartesian: bool = False, tol: float = 1e-5, labels: Sequence[str | None] | None = None, - ) -> IStructure | Structure: + ) -> Self: """Generate a structure using a spacegroup. Note that only symmetrically distinct species and coords should be provided. All equivalent sites are generated from the spacegroup operations. @@ -1283,7 +1286,7 @@ def from_spacegroup( num = int(sg) spg = SpaceGroup.from_int_number(num) except ValueError: - spg = SpaceGroup(sg) # type: ignore + spg = SpaceGroup(str(sg)) lattice = lattice if isinstance(lattice, Lattice) else Lattice(lattice) @@ -1309,13 +1312,19 @@ def from_spacegroup( for idx, (sp, c) in enumerate(zip(species, frac_coords)): cc = spg.get_orbit(c, tol=tol) all_sp.extend([sp] * len(cc)) - all_coords.extend(cc) # type: ignore + all_coords.extend(cc) label = labels[idx] if labels else None all_labels.extend([label] * len(cc)) for k, v in props.items(): all_site_properties[k].extend([v[idx]] * len(cc)) - return cls(lattice, all_sp, all_coords, site_properties=all_site_properties, labels=all_labels) + return cls( + lattice, + all_sp, + all_coords, + site_properties=all_site_properties, + labels=all_labels, + ) @classmethod def from_magnetic_spacegroup( @@ -1328,7 +1337,7 @@ def from_magnetic_spacegroup( coords_are_cartesian: bool = False, tol: float = 1e-5, labels: Sequence[str | None] | None = None, - ) -> IStructure | Structure: + ) -> Self: """Generate a structure using a magnetic spacegroup. Note that only symmetrically distinct species, coords and magmoms should be provided.] All equivalent sites are generated from the spacegroup operations. @@ -1377,7 +1386,7 @@ def from_magnetic_spacegroup( None for no labels. Returns: - Structure | IStructure + IStructure """ if "magmom" not in site_properties: raise ValueError("Magnetic moments have to be defined.") @@ -1406,9 +1415,9 @@ def from_magnetic_spacegroup( all_magmoms: list[float] = [] all_site_properties: dict[str, list] = defaultdict(list) all_labels: list[str | None] = [] - for idx, (sp, c, m) in enumerate(zip(species, frac_coords, magmoms)): # type: ignore - cc, mm = msg.get_orbit(c, m, tol=tol) - all_sp.extend([sp] * len(cc)) + for idx, (spec, f_coord, magmom) in enumerate(zip(species, frac_coords, magmoms)): + cc, mm = msg.get_orbit(f_coord, magmom, tol=tol) + all_sp.extend([spec] * len(cc)) all_coords.extend(cc) all_magmoms.extend(mm) label = labels[idx] if labels else None @@ -1477,17 +1486,21 @@ def density(self) -> float: return mass.to("g") / (self.volume * Length(1, "ang").to("cm") ** 3) @property - def pbc(self) -> tuple[bool, bool, bool]: + def pbc(self) -> PbcLike: """The periodicity of the structure.""" return self._lattice.pbc @property def is_3d_periodic(self) -> bool: - """True if the Lattice is periodic in all directions.""" + """Whether the Lattice is periodic in all directions.""" return self._lattice.is_3d_periodic - def get_space_group_info(self, symprec: float = 1e-2, angle_tolerance: float = 5.0) -> tuple[str, int]: - """Convenience method to quickly get the spacegroup of a structure. + def get_space_group_info( + self, + symprec: float = 1e-2, + angle_tolerance: float = 5.0, + ) -> tuple[str, int]: + """Get the spacegroup of a structure. Args: symprec (float): Same definition as in SpacegroupAnalyzer. @@ -1498,13 +1511,18 @@ def get_space_group_info(self, symprec: float = 1e-2, angle_tolerance: float = 5 Returns: spacegroup_symbol, international_number """ - # Import within method needed to avoid cyclic dependency. + # Avoid circular import from pymatgen.symmetry.analyzer import SpacegroupAnalyzer spg_analyzer = SpacegroupAnalyzer(self, symprec=symprec, angle_tolerance=angle_tolerance) return spg_analyzer.get_space_group_symbol(), spg_analyzer.get_space_group_number() - def matches(self, other: IStructure | Structure, anonymous: bool = False, **kwargs) -> bool: + def matches( + self, + other: Self | Structure, + anonymous: bool = False, + **kwargs, + ) -> bool: """Check whether this structure is similar to another structure. Basically a convenience method to call structure matching. @@ -1596,7 +1614,7 @@ def get_sites_in_sphere( self._lattice, properties=self[idx].properties, nn_distance=dist, - image=img, # type: ignore + image=img, index=idx, label=self[idx].label, ) @@ -1651,9 +1669,8 @@ def _get_neighbor_list_py( sites: list[PeriodicSite] | None = None, numerical_tol: float = 1e-8, exclude_self: bool = True, - ) -> tuple[np.ndarray, ...]: - """A python version of getting neighbor_list. The returned values are a tuple of - numpy arrays (center_indices, points_indices, offset_vectors, distances). + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """A python version of getting neighbor_list. Atom `center_indices[i]` has neighbor atom `points_indices[i]` that is translated by `offset_vectors[i]` lattice vectors, and the distance is `distances[i]`. @@ -1676,7 +1693,11 @@ def _get_neighbor_list_py( tuple: (center_indices, points_indices, offset_vectors, distances) """ neighbors = self.get_all_neighbors_py( - r=r, include_index=True, include_image=True, sites=sites, numerical_tol=1e-8 + r=r, + include_index=True, + include_image=True, + sites=sites, + numerical_tol=1e-8, ) center_indices = [] points_indices = [] @@ -1699,7 +1720,7 @@ def get_neighbor_list( sites: Sequence[PeriodicSite] | None = None, numerical_tol: float = 1e-8, exclude_self: bool = True, - ) -> tuple[np.ndarray, ...]: + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Get neighbor lists using numpy array representations without constructing Neighbor objects. If the cython extension is installed, this method will be orders of magnitude faster than `get_all_neighbors_old` and 2-3x faster @@ -1730,7 +1751,10 @@ def get_neighbor_list( try: from pymatgen.optimization.neighbors import find_points_in_spheres except ImportError: - return self._get_neighbor_list_py(r, sites, exclude_self=exclude_self) # type: ignore + if sites is None: + return self._get_neighbor_list_py(r, None, exclude_self=exclude_self) + return self._get_neighbor_list_py(r, list(sites), exclude_self=exclude_self) + else: if sites is None: sites = self.sites @@ -1741,7 +1765,7 @@ def get_neighbor_list( center_indices, points_indices, images, distances = find_points_in_spheres( cart_coords, site_coords, - r=float(r), + r=r, pbc=pbc, lattice=lattice_matrix, tol=numerical_tol, @@ -1750,7 +1774,7 @@ def get_neighbor_list( if exclude_self: self_pair = (center_indices == points_indices) & (distances <= numerical_tol) cond = ~self_pair - return (center_indices[cond], points_indices[cond], images[cond], distances[cond]) + return center_indices[cond], points_indices[cond], images[cond], distances[cond] def get_symmetric_neighbor_list( self, @@ -1815,12 +1839,12 @@ def get_symmetric_neighbor_list( f"supplied spacegroup {sgp.symbol}!" ) - # get a list of neighbors up to distance r + # Get a list of neighbors up to distance r bonds = self.get_neighbor_list(r) if unique: redundant = [] - # compare all neighbors pairwise to find the pairs that connect the same + # Compare all neighbors pairwise to find the pairs that connect the same # two sites, but with an inverted vector (R=-R) that connects the two and add # one of each pair to the redundant list. for idx, (i, j, R, d) in enumerate(zip(*bonds)): @@ -1834,12 +1858,12 @@ def get_symmetric_neighbor_list( if bool1 and bool2 and bool3 and bool4: redundant.append(jdx) - # delete the redundant neighbors + # Delete the redundant neighbors m = ~np.in1d(np.arange(len(bonds[0])), redundant) idcs_dist = np.argsort(bonds[3][m]) bonds = (bonds[0][m][idcs_dist], bonds[1][m][idcs_dist], bonds[2][m][idcs_dist], bonds[3][m][idcs_dist]) - # expand the output tuple by symmetry_indices and symmetry_ops. + # Expand the output tuple by symmetry_indices and symmetry_ops. n_bonds = len(bonds[0]) symmetry_indices = np.empty(n_bonds) symmetry_indices[:] = np.nan @@ -1873,7 +1897,7 @@ def get_symmetric_neighbor_list( if are_related and not is_reversed: symmetry_indices[jdx] = symmetry_index symmetry_ops[jdx] = op - elif are_related and is_reversed: + elif are_related: symmetry_indices[jdx] = symmetry_index symmetry_ops[jdx] = op bonds[0][jdx], bonds[1][jdx] = bonds[1][jdx], bonds[0][jdx] @@ -1881,13 +1905,13 @@ def get_symmetric_neighbor_list( symmetry_index += 1 - # the bonds are ordered by their symmetry index + # The bonds are ordered by their symmetry index idcs_symid = np.argsort(symmetry_indices) bonds = (bonds[0][idcs_symid], bonds[1][idcs_symid], bonds[2][idcs_symid], bonds[3][idcs_symid]) symmetry_indices = symmetry_indices[idcs_symid] symmetry_ops = symmetry_ops[idcs_symid] - # the groups of neighbors with the same symmetry index are ordered such that neighbors + # The groups of neighbors with the same symmetry index are ordered such that neighbors # that are the first occurrence of a new symmetry index in the ordered output are the ones # that are assigned the Identity as a symmetry operation. idcs_symop = np.arange(n_bonds) @@ -1915,8 +1939,7 @@ def get_all_neighbors( sites: Sequence[PeriodicSite] | None = None, numerical_tol: float = 1e-8, ) -> list[list[PeriodicNeighbor]]: - """Get neighbors for each atom in the unit cell, out to a distance r - Returns a list of list of neighbors for each site in structure. + """Get neighbors for each atom in the unit cell, out to a distance r. Use this method if you are planning on looping over all sites in the crystal. If you only want neighbors for a particular site, use the method get_neighbors as it may not have to build such a large supercell @@ -1949,7 +1972,8 @@ def get_all_neighbors( ok in most instances. Returns: - [[pymatgen.core.structure.PeriodicNeighbor], ..] + [[pymatgen.core.structure.PeriodicNeighbor], ...]: a list of + list of neighbors for each site in structure. """ if sites is None: sites = self.sites @@ -2003,8 +2027,7 @@ def get_all_neighbors_py( sites: Sequence[PeriodicSite] | None = None, numerical_tol: float = 1e-8, ) -> list[list[PeriodicNeighbor]]: - """Get neighbors for each atom in the unit cell, out to a distance r - Returns a list of list of neighbors for each site in structure. + """Get neighbors for each atom in the unit cell, out to a distance r. Use this method if you are planning on looping over all sites in the crystal. If you only want neighbors for a particular site, use the method get_neighbors as it may not have to build such a large supercell @@ -2037,7 +2060,7 @@ def get_all_neighbors_py( ok in most instances. Returns: - list[list[PeriodicNeighbor]] + list[list[PeriodicNeighbor]]: Neighbors for each site in structure. """ if sites is None: sites = self.sites @@ -2073,9 +2096,14 @@ def get_all_neighbors_py( return neighbors @deprecated(get_all_neighbors, "This is retained purely for checking purposes.") - def get_all_neighbors_old(self, r, include_index=False, include_image=False, include_site=True): - """Get neighbors for each atom in the unit cell, out to a distance r - Returns a list of list of neighbors for each site in structure. + def get_all_neighbors_old( + self, + r: float, + include_index: bool = False, + include_image: bool = False, + include_site: bool = True, + ): + """Get neighbors for each atom in the unit cell, out to a distance r. Use this method if you are planning on looping over all sites in the crystal. If you only want neighbors for a particular site, use the method get_neighbors as it may not have to build such a large supercell @@ -2101,7 +2129,7 @@ def get_all_neighbors_old(self, r, include_index=False, include_image=False, inc data. Defaults to True. Returns: - PeriodicNeighbor + list[list[PeriodicNeighbor]]: Neighbors for each site in structure. """ # Use same algorithm as get_sites_in_sphere to determine supercell but # loop over all atoms in crystal @@ -2113,7 +2141,7 @@ def get_all_neighbors_old(self, r, include_index=False, include_image=False, inc all_ranges = list(itertools.starmap(np.arange, zip(nmin, nmax))) lattice = self._lattice matrix = lattice.matrix - neighbors = [[] for _ in range(len(self))] + neighbors: list[list] = [[] for _ in range(len(self))] all_fcoords = np.mod(self.frac_coords, 1) coords_in_cell = np.dot(all_fcoords, matrix) site_coords = self.cart_coords @@ -2137,7 +2165,7 @@ def get_all_neighbors_old(self, r, include_index=False, include_image=False, inc ) for i in indices[within_r]: - item = [] + item: list[Any] = [] if include_site: item.append(nnsite) # type: ignore[reportPossiblyUnboundVariable] item.append(d[i]) @@ -2146,11 +2174,17 @@ def get_all_neighbors_old(self, r, include_index=False, include_image=False, inc # Add the image, if requested if include_image: item.append(image) + neighbors[i].append(item) return neighbors def get_neighbors_in_shell( - self, origin: ArrayLike, r: float, dr: float, include_index: bool = False, include_image: bool = False + self, + origin: ArrayLike, + r: float, + dr: float, + include_index: bool = False, + include_image: bool = False, ) -> list[PeriodicNeighbor]: """Get all sites in a shell centered on origin (coords) between radii r-dr and r+dr. @@ -2208,7 +2242,7 @@ def get_reduced_structure(self, reduction_algo: Literal["niggli", "LLL"] = "nigg return type(self)( reduced_latt, self.species_and_occu, - self.cart_coords, # type: ignore + self.cart_coords, coords_are_cartesian=True, to_unit_cell=True, site_properties=self.site_properties, @@ -2287,7 +2321,7 @@ def interpolate( pbc: bool = True, autosort_tol: float = 0, end_amplitude: float = 1, - ) -> list[IStructure | Structure]: + ) -> list[Self]: """Interpolate between this structure and end_structure. Useful for construction of NEB inputs. To obtain useful results, the cell setting and order of sites must consistent across the start and end structures. @@ -2365,7 +2399,7 @@ def interpolate( if len(unmapped_start_ind) == 1: idx = unmapped_start_ind[0] - j = next(iter(set(range(len(start_coords))) - set(matched))) # type: ignore + j = next(iter(set(range(len(start_coords))) - set(matched))) # type: ignore[arg-type] sorted_end_coords[idx] = end_coords[j] end_coords = sorted_end_coords @@ -2377,7 +2411,7 @@ def interpolate( structs = [] if interpolate_lattices: - # interpolate lattice matrices using polar decomposition + # Interpolate lattice matrices using polar decomposition # u is a unitary rotation, p is stretch _u, p = polar(np.dot(end_structure.lattice.matrix.T, np.linalg.inv(self.lattice.matrix.T))) lvec = end_amplitude * (p - np.identity(3)) @@ -2395,7 +2429,12 @@ def interpolate( ) return structs - def get_miller_index_from_site_indexes(self, site_ids, round_dp=4, verbose=True): + def get_miller_index_from_site_indexes( + self, + site_ids: list[int], + round_dp: int = 4, + verbose: bool = True, + ) -> MillerIndex: """Get the Miller index of a plane from a set of sites indexes. A minimum of 3 sites are required. If more than 3 sites are given @@ -2403,7 +2442,7 @@ def get_miller_index_from_site_indexes(self, site_ids, round_dp=4, verbose=True) calculated. Args: - site_ids (list of int): A list of site indexes to consider. A + site_ids (list[int]): A list of site indexes to consider. A minimum of three site indexes are required. If more than three sites are provided, the best plane that minimises the distance to all sites will be calculated. @@ -2412,7 +2451,7 @@ def get_miller_index_from_site_indexes(self, site_ids, round_dp=4, verbose=True) verbose (bool, optional): Whether to print warnings. Returns: - tuple: The Miller index. + MillerIndex: The Miller index. """ return self.lattice.get_miller_index_from_coords( self.frac_coords[site_ids], @@ -2422,13 +2461,16 @@ def get_miller_index_from_site_indexes(self, site_ids, round_dp=4, verbose=True) ) def get_primitive_structure( - self, tolerance: float = 0.25, use_site_props: bool = False, constrain_latt: list | dict | None = None - ): - """This finds a smaller unit cell than the input. Sometimes it doesn"t + self, + tolerance: float = 0.25, + use_site_props: bool = False, + constrain_latt: list | dict | None = None, + ) -> Self: + """Find a smaller unit cell than the input. Sometimes it doesn't find the smallest possible one, so this method is recursively called until it is unable to find a smaller cell. - NOTE: if the tolerance is greater than 1/2 the minimum inter-site + NOTE: If the tolerance is greater than 1/2 of the minimum inter-site distance in the primitive cell, the algorithm will reject this lattice. Args: @@ -2457,7 +2499,7 @@ def site_label(site): parts.append(f"{key}={site.properties[key]}") return ", ".join(parts) - # group sites by species string + # Group sites by species string sites = sorted(self._sites, key=site_label) grouped_sites = [list(a[1]) for a in itertools.groupby(sites, key=site_label)] @@ -2469,7 +2511,7 @@ def site_label(site): min_fcoords = min(grouped_frac_coords, key=len) min_vecs = min_fcoords - min_fcoords[0] - # fractional tolerance in the supercell + # Fractional tolerance in the supercell super_ftol = np.divide(tolerance, self.lattice.abc) super_ftol_2 = super_ftol * 2 @@ -2481,7 +2523,7 @@ def pbc_coord_intersection(fc1, fc2, tol): dist -= np.round(dist) return fc1[np.any(np.all(dist < tol, axis=-1), axis=-1)] - # here we reduce the number of min_vecs by enforcing that every + # Here we reduce the number of min_vecs by enforcing that every # vector in min_vecs approximately maps each site onto a similar site. # The subsequent processing is O(fu^3 * min_vecs) = O(n^4) if we do no # reduction. @@ -2491,7 +2533,7 @@ def pbc_coord_intersection(fc1, fc2, tol): for frac_coords in group: min_vecs = pbc_coord_intersection(min_vecs, group - frac_coords, super_ftol_2) - def get_hnf(fu): + def get_hnf(form_units): """Get all possible distinct supercell matrices given a number of formula units in the supercell. Batches the matrices by the values in the diagonal (for less numpy overhead). @@ -2505,30 +2547,29 @@ def factors(n: int): if n % idx == 0: yield idx - for det in factors(fu): + for det in factors(form_units): if det == 1: continue for a in factors(det): for e in factors(det // a): g = det // a // e - yield ( - det, - np.array( - [ - [[a, b, c], [0, e, f], [0, 0, g]] - for b, c, f in itertools.product(range(a), range(a), range(e)) - ] - ), + supercell_matrices = np.array( + [ + [[a, b, c], [0, e, f], [0, 0, g]] + for b, c, f in itertools.product(range(a), range(a), range(e)) + ] ) - # we can't let sites match to their neighbors in the supercell + yield det, supercell_matrices + + # We can't let sites match to their neighbors in the supercell grouped_non_nbrs = [] for gf_coords in grouped_frac_coords: fdist = gf_coords[None, :, :] - gf_coords[:, None, :] fdist -= np.round(fdist) np.abs(fdist, fdist) non_nbrs = np.any(fdist > 2 * super_ftol[None, None, :], axis=-1) - # since we want sites to match to themselves + # Since we want sites to match to themselves np.fill_diagonal(non_nbrs, val=True) grouped_non_nbrs.append(non_nbrs) @@ -2536,7 +2577,7 @@ def factors(n: int): for size, ms in get_hnf(num_fu): inv_ms = np.linalg.inv(ms) - # find sets of lattice vectors that are present in min_vecs + # Find sets of lattice vectors that are present in min_vecs dist = inv_ms[:, :, None, :] - min_vecs[None, None, :, :] dist -= np.round(dist) np.abs(dist, dist) @@ -2556,19 +2597,19 @@ def factors(n: int): for gsites, gf_coords, non_nbrs in zip(grouped_sites, grouped_frac_coords, grouped_non_nbrs): all_frac = np.dot(gf_coords, latt_mat) - # calculate grouping of equivalent sites, represented by + # Calculate grouping of equivalent sites, represented by # adjacency matrix fdist = all_frac[None, :, :] - all_frac[:, None, :] fdist = np.abs(fdist - np.round(fdist)) close_in_prim = np.all(fdist < ftol[None, None, :], axis=-1) groups = np.logical_and(close_in_prim, non_nbrs) - # check that groups are correct + # Check that groups are correct if not np.all(np.sum(groups, axis=0) == size): valid = False break - # check that groups are all cliques + # Check that groups are all cliques for group in groups: if not np.all(groups[group][:, group]): valid = False @@ -2576,7 +2617,7 @@ def factors(n: int): if not valid: break - # add the new sites, averaging positions + # Add the new sites, averaging positions added = np.zeros(len(gsites)) new_fcoords = all_frac % 1 for grp_idx, group in enumerate(groups): @@ -2625,7 +2666,11 @@ def factors(n: int): return self.copy() - def get_orderings(self, mode: Literal["enum", "sqs"] = "enum", **kwargs) -> list[Structure]: + def get_orderings( + self, + mode: Literal["enum", "sqs"] = "enum", + **kwargs, + ) -> list[Structure]: """Get list of orderings for a disordered structure. If structure does not contain disorder, the default structure is returned. @@ -2669,7 +2714,12 @@ def get_orderings(self, mode: Literal["enum", "sqs"] = "enum", **kwargs) -> list return [run_mcsqs(self, **kwargs).bestsqs] raise ValueError("Invalid mode!") - def as_dict(self, verbosity=1, fmt=None, **kwargs) -> dict[str, Any]: + def as_dict( + self, + verbosity: int = 1, + fmt: Literal["abivars"] | None = None, + **kwargs, + ) -> dict[str, Any]: """Dict representation of Structure. Args: @@ -2679,17 +2729,17 @@ def as_dict(self, verbosity=1, fmt=None, **kwargs) -> dict[str, Any]: database. Set to 0 for an extremely lightweight version that only includes sufficient information to reconstruct the object. - fmt (str): Specifies a format for the dict. Defaults to None, - which is the default format used in pymatgen. Other options - include "abivars". + fmt ("abivars" | None): Specifies a format for the dict. + Defaults to None, which is the default format used + in pymatgen. Or "abivars". **kwargs: Allow passing of other kwargs needed for certain formats, e.g. "abivars". Returns: JSON-serializable dict representation. """ + # Return a dictionary with the ABINIT variables if fmt == "abivars": - # Returns a dictionary with the ABINIT variables from pymatgen.io.abinit.abiobjects import structure_to_abivars return structure_to_abivars(self, **kwargs) @@ -2706,7 +2756,7 @@ def as_dict(self, verbosity=1, fmt=None, **kwargs) -> dict[str, Any]: "properties": self.properties, } for site in self: - site_dict = site.as_dict(verbosity=verbosity) # type: ignore[call-arg] + site_dict = site.as_dict(verbosity=verbosity) del site_dict["lattice"] del site_dict["@module"] del site_dict["@class"] @@ -2714,15 +2764,16 @@ def as_dict(self, verbosity=1, fmt=None, **kwargs) -> dict[str, Any]: dct["sites"] = sites return dct - def as_dataframe(self): - """Create a Pandas dataframe of the sites. Structure-level attributes are stored in DataFrame.attrs. + def as_dataframe(self) -> pd.DataFrame: + """Create a Pandas DataFrame of the sites. + Structure-level attributes are stored in DataFrame.attrs. Example: - Species a b c x y z magmom - 0 (Si) 0.0 0.0 0.000000e+00 0.0 0.000000e+00 0.000000e+00 5 - 1 (Si) 0.0 0.0 1.000000e-7 0.0 -2.217138e-7 3.135509e-7 -5 + Species a b c x y z magmom + 0 (Si) 0.0 0.0 0.0 0.0 0.0 0.0 5 + 1 (Si) 0.0 0.0 0.0 0.0 0.0 0.0 -5 """ - # pandas lazy imported for speed (https://github.com/materialsproject/pymatgen/issues/3563) + # pandas lazy imported for speed (#3563) import pandas as pd data: list[list[str | float]] = [] @@ -2740,8 +2791,12 @@ def as_dataframe(self): return df @classmethod - def from_dict(cls, dct: dict[str, Any], fmt: Literal["abivars"] | None = None) -> Self: - """Reconstitute a Structure object from a dict representation of Structure + def from_dict( + cls, + dct: dict[str, Any], + fmt: Literal["abivars"] | None = None, + ) -> Self: + """Reconstitute a Structure from a dict representation of Structure created using as_dict(). Args: @@ -2761,11 +2816,11 @@ def from_dict(cls, dct: dict[str, Any], fmt: Literal["abivars"] | None = None) - charge = dct.get("charge") return cls.from_sites(sites, charge=charge, properties=dct.get("properties")) - def to(self, filename: str | Path = "", fmt: FileFormats = "", **kwargs) -> str: - """Outputs the structure to a file or string. + def to(self, filename: PathLike = "", fmt: FileFormats = "", **kwargs) -> str: + """Output the structure to a file or string. Args: - filename (str): If provided, output will be written to a file. If + filename (PathLike): If provided, output will be written to a file. If fmt is not specified, the format is determined from the filename. Defaults is None, i.e. string output. fmt (str): Format to output to. Defaults to JSON unless filename @@ -2786,7 +2841,7 @@ def to(self, filename: str | Path = "", fmt: FileFormats = "", **kwargs) -> str: if fmt == "cif" or fnmatch(filename.lower(), "*.cif*"): from pymatgen.io.cif import CifWriter - writer = CifWriter(self, **kwargs) + writer: Any = CifWriter(self, **kwargs) elif fmt == "mcif" or fnmatch(filename.lower(), "*.mcif*"): from pymatgen.io.cif import CifWriter @@ -2798,7 +2853,7 @@ def to(self, filename: str | Path = "", fmt: FileFormats = "", **kwargs) -> str: elif fmt == "cssr" or fnmatch(filename.lower(), "*.cssr*"): from pymatgen.io.cssr import Cssr - writer = Cssr(self) # type: ignore + writer = Cssr(self) elif fmt == "json" or fnmatch(filename.lower(), "*.json*"): json_str = json.dumps(self.as_dict()) if filename: @@ -2875,7 +2930,7 @@ def from_str( # type: ignore[override] merge_tol: float = 0.0, **kwargs, ) -> Structure | IStructure: - """Reads a structure from a string. + """Read a structure from a string. Args: input_string (str): String to parse. @@ -2951,15 +3006,20 @@ def from_str( # type: ignore[override] @classmethod def from_file( # type: ignore[override] - cls, filename: str | Path, primitive: bool = False, sort: bool = False, merge_tol: float = 0.0, **kwargs + cls, + filename: PathLike, + primitive: bool = False, + sort: bool = False, + merge_tol: float = 0.0, + **kwargs, ) -> Structure | IStructure: - """Reads a structure from a file. For example, anything ending in + """Read a structure from a file. For example, anything ending in a "cif" is assumed to be a Crystallographic Information Format file. Supported formats include CIF, POSCAR/CONTCAR, CHGCAR, LOCPOT, vasprun.xml, CSSR, Netcdf and pymatgen's JSON-serialized structures. Args: - filename (str): The filename to read from. + filename (PathLike): The file to read. primitive (bool): Whether to convert to a primitive cell. Defaults to False. sort (bool): Whether to sort sites. Default to False. merge_tol (float): If this is some positive number, sites that are within merge_tol from each other will be @@ -2980,10 +3040,6 @@ def from_file( # type: ignore[override] struct = struct.get_sorted_structure() return struct - from pymatgen.io.exciting import ExcitingInput - from pymatgen.io.lmto import LMTOCtrl - from pymatgen.io.vasp import Chgcar, Vasprun - fname = os.path.basename(filename) with zopen(filename, mode="rt", errors="replace") as file: contents = file.read() @@ -2993,8 +3049,12 @@ def from_file( # type: ignore[override] struct = cls.from_str(contents, fmt="poscar", primitive=primitive, sort=sort, merge_tol=merge_tol, **kwargs) elif fnmatch(fname, "CHGCAR*") or fnmatch(fname, "LOCPOT*"): + from pymatgen.io.vasp import Chgcar + struct = Chgcar.from_file(filename, **kwargs).structure elif fnmatch(fname, "vasprun*.xml*"): + from pymatgen.io.vasp import Vasprun + struct = Vasprun(filename, **kwargs).final_structure elif fnmatch(fname.lower(), "*.cssr*"): return cls.from_str(contents, fmt="cssr", primitive=primitive, sort=sort, merge_tol=merge_tol, **kwargs) @@ -3005,10 +3065,14 @@ def from_file( # type: ignore[override] elif fnmatch(fname, "*.xsf"): return cls.from_str(contents, fmt="xsf", primitive=primitive, sort=sort, merge_tol=merge_tol, **kwargs) elif fnmatch(fname, "input*.xml"): + from pymatgen.io.exciting import ExcitingInput + return ExcitingInput.from_file(fname, **kwargs).structure elif fnmatch(fname, "*rndstr.in*") or fnmatch(fname, "*lat.in*") or fnmatch(fname, "*bestsqs*"): return cls.from_str(contents, fmt="mcsqs", primitive=primitive, sort=sort, merge_tol=merge_tol, **kwargs) elif fnmatch(fname, "CTRL*"): + from pymatgen.io.lmto import LMTOCtrl + return LMTOCtrl.from_file(filename=filename, **kwargs).structure elif fnmatch(fname, "inp*.xml") or fnmatch(fname, "*.in*") or fnmatch(fname, "inp_*"): from pymatgen.io.fleur import FleurInput @@ -3104,7 +3168,7 @@ def __init__( charge_spin_check: bool = True, properties: dict | None = None, ) -> None: - """Create a Molecule. + """Create a IMolecule. Args: species: list of atomic species. Possible kinds of input include a @@ -3172,7 +3236,7 @@ def __eq__(self, other: object) -> bool: if not all(hasattr(other, attr) for attr in needed_attrs): return NotImplemented - other = cast(IMolecule, other) + other = cast(IMolecule, other) # TODO @DanielYang59: fix type if len(self) != len(other): return False @@ -3185,7 +3249,7 @@ def __eq__(self, other: object) -> bool: return all(site in other for site in self) def __hash__(self) -> int: - # For now, just use the composition hash code. + """Use the composition hash for now.""" return hash(self.composition) def __repr__(self) -> str: @@ -3194,7 +3258,7 @@ def __repr__(self) -> str: def __str__(self) -> str: outs = [ f"Full Formula ({self.composition.formula})", - "Reduced Formula: " + self.composition.reduced_formula, + f"Reduced Formula: {self.composition.reduced_formula}", f"Charge = {self._charge}, Spin Mult = {self._spin_multiplicity}", f"Sites ({len(self)})", ] @@ -3224,7 +3288,7 @@ def nelectrons(self) -> float: return n_electrons @property - def center_of_mass(self) -> np.ndarray: + def center_of_mass(self) -> NDArray: """Center of mass of molecule.""" center = np.zeros(3) total_weight: float = 0 @@ -3238,7 +3302,7 @@ def copy(self) -> Self: """Convenience method to get a copy of the molecule. Returns: - IMolecule | Molecule + IMolecule """ return type(self).from_sites(self, properties=self.properties) @@ -3251,7 +3315,7 @@ def from_sites( validate_proximity: bool = False, charge_spin_check: bool = True, properties: dict | None = None, - ) -> IMolecule | Molecule: + ) -> Self: """Convenience constructor to make a Molecule from a list of sites. Args: @@ -3271,7 +3335,7 @@ def from_sites( ValueError: If sites is empty Returns: - Molecule + IMolecule """ if len(sites) < 1: raise ValueError(f"You need at least 1 site to make a {cls.__name__}") @@ -3292,9 +3356,9 @@ def from_sites( properties=properties, ) - def break_bond(self, ind1: int, ind2: int, tol: float = 0.2) -> tuple[IMolecule | Molecule, ...]: - """Get two molecules based on breaking the bond between atoms at index - ind1 and ind2. + def break_bond(self, ind1: int, ind2: int, tol: float = 0.2) -> tuple[Self, Self]: + """Get two molecules based on breaking the bond between atoms + at index ind1 and ind2. Args: ind1 (int): 1st site index @@ -3305,10 +3369,10 @@ def break_bond(self, ind1: int, ind2: int, tol: float = 0.2) -> tuple[IMolecule 20% longer. Returns: - Two Molecule objects representing the two clusters formed from + Two IMolecule representing the clusters formed from breaking the bond. """ - clusters = [[self[ind1]], [self[ind2]]] + clusters = ([self[ind1]], [self[ind2]]) sites = [site for idx, site in enumerate(self) if idx not in (ind1, ind2)] @@ -3316,7 +3380,7 @@ def belongs_to_cluster(site, cluster): return any(CovalentBond.is_bonded(site, test_site, tol=tol) for test_site in cluster) while len(sites) > 0: - unmatched = [] + unmatched: list[PeriodicSite] = [] for site in sites: for cluster in clusters: if belongs_to_cluster(site, cluster): @@ -3329,7 +3393,7 @@ def belongs_to_cluster(site, cluster): raise ValueError("Not all sites are matched!") sites = unmatched - return tuple(type(self).from_sites(cluster) for cluster in clusters) + return cast(tuple[Self, Self], tuple(map(type(self).from_sites, clusters))) def get_covalent_bonds(self, tol: float = 0.2) -> list[CovalentBond]: """Determine the covalent bonds in a molecule. @@ -3347,9 +3411,15 @@ def get_covalent_bonds(self, tol: float = 0.2) -> list[CovalentBond]: bonds.append(CovalentBond(site1, site2)) return bonds - def get_zmatrix(self): + def get_zmatrix(self) -> str: """Get a z-matrix representation of the molecule.""" + # TODO: allow more z-matrix conventions for element/site description + def find_nn_pos_before_site(site_idx: int): + """Get index of nearest neighbor atoms.""" + all_dist = [(self.get_distance(site_idx, idx), idx) for idx in range(site_idx)] + all_dist = sorted(all_dist, key=lambda x: x[0]) + return [d[1] for d in all_dist] output = [] output_var = [] @@ -3357,18 +3427,18 @@ def get_zmatrix(self): if idx == 0: output.append(f"{site.specie}") elif idx == 1: - nn = self._find_nn_pos_before_site(idx) + nn = find_nn_pos_before_site(idx) bond_length = self.get_distance(idx, nn[0]) output.append(f"{self[idx].specie} {nn[0] + 1} B{idx}") output_var.append(f"B{idx}={bond_length:.6f}") elif idx == 2: - nn = self._find_nn_pos_before_site(idx) + nn = find_nn_pos_before_site(idx) bond_length = self.get_distance(idx, nn[0]) angle = self.get_angle(idx, nn[0], nn[1]) output.append(f"{self[idx].specie} {nn[0] + 1} B{idx} {nn[1] + 1} A{idx}") output_var.extend((f"B{idx}={bond_length:.6f}", f"A{idx}={angle:.6f}")) else: - nn = self._find_nn_pos_before_site(idx) + nn = find_nn_pos_before_site(idx) bond_length = self.get_distance(idx, nn[0]) angle = self.get_angle(idx, nn[0], nn[1]) dih = self.get_dihedral(idx, nn[0], nn[1], nn[2]) @@ -3376,13 +3446,7 @@ def get_zmatrix(self): output_var.extend((f"B{idx}={bond_length:.6f}", f"A{idx}={angle:.6f}", f"D{idx}={dih:.6f}")) return "\n".join(output) + "\n\n" + "\n".join(output_var) - def _find_nn_pos_before_site(self, site_idx): - """Get index of nearest neighbor atoms.""" - all_dist = [(self.get_distance(site_idx, idx), idx) for idx in range(site_idx)] - all_dist = sorted(all_dist, key=lambda x: x[0]) - return [d[1] for d in all_dist] - - def as_dict(self): + def as_dict(self) -> dict: """JSON-serializable dict representation of Molecule.""" dct = { "@module": type(self).__module__, @@ -3396,18 +3460,18 @@ def as_dict(self): site_dict = site.as_dict() del site_dict["@module"] del site_dict["@class"] - dct["sites"].append(site_dict) + cast(list, dct["sites"]).append(site_dict) return dct @classmethod - def from_dict(cls, dct: dict) -> IMolecule | Molecule: + def from_dict(cls, dct: dict) -> Self: """Reconstitute a Molecule object from a dict representation created using as_dict(). Args: dct (dict): dict representation of Molecule. Returns: - Molecule + IMolecule """ sites = [Site.from_dict(sd) for sd in dct["sites"]] charge = dct.get("charge", 0) @@ -3526,8 +3590,8 @@ def get_boxed_structure( if a <= x_range or b <= y_range or c <= z_range: raise ValueError("Box is not big enough to contain Molecule.") - lattice = Lattice.from_parameters(a * images[0], b * images[1], c * images[2], 90, 90, 90) # type: ignore - nimages: int = images[0] * images[1] * images[2] # type: ignore + lattice = Lattice.from_parameters(a * images[0], b * images[1], c * images[2], 90, 90, 90) + nimages: int = images[0] * images[1] * images[2] all_coords: list[ArrayLike] = [] centered_coords = self.cart_coords - self.center_of_mass + offset @@ -3535,7 +3599,7 @@ def get_boxed_structure( for i, j, k in itertools.product( list(range(images[0])), list(range(images[1])), - list(range(images[2])), # type: ignore + list(range(images[2])), ): box_center = [(i + 0.5) * a, (j + 0.5) * b, (k + 0.5) * c] if random_rotation: @@ -3570,7 +3634,8 @@ def get_boxed_structure( if x_max > a or x_min < 0 or y_max > b or y_min < 0 or z_max > c or z_min < 0: raise ValueError("Molecule crosses boundary of box") all_coords.extend(new_coords) - sprops = {k: v * nimages for k, v in self.site_properties.items()} # type: ignore + + site_props = {key: sequence * nimages for key, sequence in self.site_properties.items()} # type: ignore[operator] if cls is None: cls = Structure @@ -3581,7 +3646,7 @@ def get_boxed_structure( self.species * nimages, all_coords, coords_are_cartesian=True, - site_properties=sprops, + site_properties=site_props, labels=self.labels * nimages, ).get_sorted_structure() @@ -3590,15 +3655,15 @@ def get_boxed_structure( self.species * nimages, coords, coords_are_cartesian=True, - site_properties=sprops, + site_properties=site_props, labels=self.labels * nimages, ) - def get_centered_molecule(self) -> IMolecule | Molecule: + def get_centered_molecule(self) -> Self: """Get a Molecule centered at the center of mass. Returns: - Molecule centered with center of mass at origin. + IMolecule centered with center of mass at origin. """ center = self.center_of_mass new_coords = np.array(self.cart_coords) - center @@ -3630,15 +3695,15 @@ def to(self, filename: str = "", fmt: str = "") -> str | None: str: String representation of molecule in given format. If a filename is provided, the same string is written to the file. """ - from pymatgen.io.babel import BabelMolAdaptor - from pymatgen.io.gaussian import GaussianInput - from pymatgen.io.xyz import XYZ - fmt = fmt.lower() writer: Any if fmt == "xyz" or fnmatch(filename.lower(), "*.xyz*"): + from pymatgen.io.xyz import XYZ + writer = XYZ(self) elif any(fmt == ext or fnmatch(filename.lower(), f"*.{ext}*") for ext in ("gjf", "g03", "g09", "com", "inp")): + from pymatgen.io.gaussian import GaussianInput + writer = GaussianInput(self) elif fmt == "json" or fnmatch(filename, "*.json*") or fnmatch(filename, "*.mson*"): json_str = json.dumps(self.as_dict()) @@ -3656,9 +3721,11 @@ def to(self, filename: str = "", fmt: str = "") -> str | None: file.write(yaml_str) return yaml_str else: + from pymatgen.io.babel import BabelMolAdaptor + match = re.search(r"\.(pdb|mol|mdl|sdf|sd|ml2|sy2|mol2|cml|mrv)", filename.lower()) if not fmt and match: - fmt = match.group(1) + fmt = match[1] writer = BabelMolAdaptor(self) return writer.write_file(filename, file_format=fmt) @@ -3668,8 +3735,10 @@ def to(self, filename: str = "", fmt: str = "") -> str | None: @classmethod def from_str( # type: ignore[override] - cls, input_string: str, fmt: Literal["xyz", "gjf", "g03", "g09", "com", "inp", "json", "yaml"] - ) -> IMolecule | Molecule: + cls, + input_string: str, + fmt: Literal["xyz", "gjf", "g03", "g09", "com", "inp", "json", "yaml"], + ) -> Self | Molecule: """Reads the molecule from a string. Args: @@ -3683,52 +3752,60 @@ def from_str( # type: ignore[override] Returns: IMolecule or Molecule. """ - from pymatgen.io.gaussian import GaussianInput - from pymatgen.io.xyz import XYZ + fmt = cast(Literal["xyz", "gjf", "g03", "g09", "com", "inp", "json", "yaml"], fmt.lower()) - fmt = fmt.lower() # type: ignore[assignment] if fmt == "xyz": + from pymatgen.io.xyz import XYZ + mol = XYZ.from_str(input_string).molecule - elif fmt in ["gjf", "g03", "g09", "com", "inp"]: + + elif fmt in {"gjf", "g03", "g09", "com", "inp"}: + from pymatgen.io.gaussian import GaussianInput + mol = GaussianInput.from_str(input_string).molecule + elif fmt == "json": dct = json.loads(input_string) return cls.from_dict(dct) - elif fmt in ("yaml", "yml"): + + elif fmt in {"yaml", "yml"}: yaml = YAML() dct = yaml.load(input_string) return cls.from_dict(dct) + else: from pymatgen.io.babel import BabelMolAdaptor mol = BabelMolAdaptor.from_str(input_string, file_format=fmt).pymatgen_mol + return cls.from_sites(mol, properties=mol.properties) @classmethod - def from_file(cls, filename: str | Path) -> Self | None: # type: ignore[override] - """Reads a molecule from a file. Supported formats include xyz, + def from_file(cls, filename: PathLike) -> Self | None: # type: ignore[override] + """Read a molecule from a file. Supported formats include xyz, gaussian input (gjf|g03|g09|com|inp), Gaussian output (.out|and pymatgen's JSON-serialized molecules. Using openbabel, many more extensions are supported but requires openbabel to be installed. Args: - filename (str | Path): The filename to read from. + filename (PathLike): The file to read. Returns: Molecule """ filename = str(filename) - from pymatgen.io.gaussian import GaussianOutput with zopen(filename) as file: contents = file.read() fname = filename.lower() if fnmatch(fname, "*.xyz*"): return cls.from_str(contents, fmt="xyz") - if any(fnmatch(fname.lower(), f"*.{r}*") for r in ["gjf", "g03", "g09", "com", "inp"]): + if any(fnmatch(fname.lower(), f"*.{r}*") for r in ("gjf", "g03", "g09", "com", "inp")): return cls.from_str(contents, fmt="g09") - if any(fnmatch(fname.lower(), f"*.{r}*") for r in ["out", "lis", "log"]): + if any(fnmatch(fname.lower(), f"*.{r}*") for r in ("out", "lis", "log")): + from pymatgen.io.gaussian import GaussianOutput + return GaussianOutput(filename).final_structure if fnmatch(fname, "*.json*") or fnmatch(fname, "*.mson*"): return cls.from_str(contents, fmt="json") @@ -3737,7 +3814,7 @@ def from_file(cls, filename: str | Path) -> Self | None: # type: ignore[overrid from pymatgen.io.babel import BabelMolAdaptor if match := re.search(r"\.(pdb|mol|mdl|sdf|sd|ml2|sy2|mol2|cml|mrv)", filename.lower()): - new = BabelMolAdaptor.from_file(filename, match.group(1)).pymatgen_mol + new = BabelMolAdaptor.from_file(filename, match[1]).pymatgen_mol new.__class__ = cls return new raise ValueError("Cannot determine file type.") @@ -3814,9 +3891,9 @@ def __init__( properties=properties, ) - self._sites: list[PeriodicSite] = list(self._sites) # type: ignore + self._sites: list[PeriodicSite] = list(self._sites) # type: ignore[assignment] - def __setitem__( # type: ignore + def __setitem__( # type: ignore[override] self, idx: int | slice | Sequence[int] | SpeciesLike, site: SpeciesLike | PeriodicSite | Sequence | dict[SpeciesLike, float], @@ -3833,33 +3910,33 @@ def __setitem__( # type: ignore or a tuple of up to length 3. Examples: - s[0] = "Fe" - s[0] = Element("Fe") + structure[0] = "Fe" + structure[0] = Element("Fe") both replaces the species only. - s[0] = "Fe", [0.5, 0.5, 0.5] + structure[0] = "Fe", [0.5, 0.5, 0.5] Replaces site and *fractional* coordinates. Any properties are inherited from current site. - s[0] = "Fe", [0.5, 0.5, 0.5], spin=2 + structure[0] = "Fe", [0.5, 0.5, 0.5], spin=2 Replaces site and *fractional* coordinates and properties. - s[(0, 2, 3)] = "Fe" + structure[(0, 2, 3)] = "Fe" Replaces sites 0, 2 and 3 with Fe. - s[0::2] = "Fe" + structure[0::2] = "Fe" Replaces all even index sites with Fe. - s["Mn"] = "Fe" + structure["Mn"] = "Fe" Replaces all Mn in the structure with Fe. This is a short form for the more complex replace_species. - s["Mn"] = "Fe0.5Co0.5" + structure["Mn"] = "Fe0.5Co0.5" Replaces all Mn in the structure with Fe: 0.5, Co: 0.5, i.e., creates a disordered structure! """ if isinstance(idx, int): indices = [idx] elif isinstance(idx, (str, Element, Species)): - self.replace_species({idx: site}) # type: ignore + self.replace_species({idx: site}) # type: ignore[dict-item] return elif isinstance(idx, slice): to_mod = self[idx] @@ -3874,17 +3951,19 @@ def __setitem__( # type: ignore if len(indices) != 1: raise ValueError("Site assignments makes sense only for single int indices!") self._sites[ii] = site + elif isinstance(site, str) or (not isinstance(site, collections.abc.Sequence)): - self._sites[ii].species = site # type: ignore + self._sites[ii].species = site # type: ignore[assignment] + else: - self._sites[ii].species = site[0] # type: ignore + self._sites[ii].species = site[0] # type: ignore[assignment, index] if len(site) > 1: - self._sites[ii].frac_coords = site[1] # type: ignore + self._sites[ii].frac_coords = site[1] # type: ignore[index] if len(site) > 2: - self._sites[ii].properties = site[2] # type: ignore + self._sites[ii].properties = site[2] # type: ignore[assignment, index] def __delitem__(self, idx: SupportsIndex | slice) -> None: - """Deletes a site from the Structure.""" + """Delete a site from the Structure.""" self._sites.__delitem__(idx) @property @@ -3900,14 +3979,14 @@ def lattice(self, lattice: ArrayLike | Lattice) -> None: for site in self: site.lattice = lattice - def append( # type: ignore + def append( # type: ignore[override] self, species: CompositionLike, coords: ArrayLike, coords_are_cartesian: bool = False, validate_proximity: bool = False, properties: dict | None = None, - ): + ) -> Self: """Append a site to the structure. Args: @@ -3931,7 +4010,7 @@ def append( # type: ignore properties=properties, ) - def insert( # type: ignore + def insert( # type: ignore[override] self, idx: int, species: CompositionLike, @@ -3940,7 +4019,7 @@ def insert( # type: ignore validate_proximity: bool = False, properties: dict | None = None, label: str | None = None, - ): + ) -> Self: """Insert a site to the structure. Args: @@ -3965,7 +4044,7 @@ def insert( # type: ignore if site.distance(new_site) < self.DISTANCE_TOLERANCE: raise ValueError("New site is too close to an existing site!") - self.sites.insert(idx, new_site) + cast(list[PeriodicSite], self.sites).insert(idx, new_site) return self @@ -3999,14 +4078,19 @@ def replace( elif coords_are_cartesian: frac_coords = self._lattice.get_fractional_coords(coords) else: - frac_coords = coords # type: ignore + frac_coords = coords new_site = PeriodicSite(species, frac_coords, self._lattice, properties=properties, label=label) - self.sites[idx] = new_site + cast(list[PeriodicSite], self.sites)[idx] = new_site return self - def substitute(self, index: int, func_group: IMolecule | Molecule | str, bond_order: int = 1) -> Self: + def substitute( + self, + index: int, + func_group: IMolecule | Molecule | str, + bond_order: int = 1, + ) -> Self: """Substitute atom at index with a functional group. Args: @@ -4041,8 +4125,8 @@ def substitute(self, index: int, func_group: IMolecule | Molecule | str, bond_or all_non_terminal_nn.append((nn, dist)) break - if len(all_non_terminal_nn) == 0: - raise RuntimeError("Can't find a non-terminal neighbor to attach functional group to.") + if not all_non_terminal_nn: + raise RuntimeError("Can't find a non-terminal neighbor to attach functional group to") non_terminal_nn = min(all_non_terminal_nn, key=lambda d: d[1])[0] @@ -4067,7 +4151,7 @@ def substitute(self, index: int, func_group: IMolecule | Molecule | str, bond_or # bond length is equal to the bond length. try: bl = get_bond_length(non_terminal_nn.specie, fgroup[1].specie, bond_order=bond_order) - # Catches for case of incompatibility between Element(s) and Species(s) + # Catch for case of incompatibility between Element(s) and Species(s) except TypeError: bl = None @@ -4195,7 +4279,7 @@ def operate_site(site): return self - def apply_strain(self, strain: ArrayLike, inplace: bool = True) -> Structure: + def apply_strain(self, strain: ArrayLike, inplace: bool = True) -> Self: """Apply a strain to the lattice. Args: @@ -4237,7 +4321,11 @@ def sort(self, key: Callable | None = None, reverse: bool = False) -> Self: return self def translate_sites( - self, indices: int | Sequence[int], vector: ArrayLike, frac_coords: bool = True, to_unit_cell: bool = True + self, + indices: int | Sequence[int], + vector: ArrayLike, + frac_coords: bool = True, + to_unit_cell: bool = True, ) -> Self: """Translate specific sites by some vector, keeping the sites within the unit cell. Modifies the structure in place. @@ -4339,7 +4427,7 @@ def perturb(self, distance: float, min_distance: float | None = None) -> Self: """ def get_rand_vec(): - # deals with zero vectors. + # Deal with zero vectors vector = np.random.randn(3) vnorm = np.linalg.norm(vector) dist = distance @@ -4352,7 +4440,12 @@ def get_rand_vec(): return self - def make_supercell(self, scaling_matrix: ArrayLike, to_unit_cell: bool = True, in_place: bool = True) -> Structure: + def make_supercell( + self, + scaling_matrix: ArrayLike, + to_unit_cell: bool = True, + in_place: bool = True, + ) -> Self: """Create a supercell. Args: @@ -4389,7 +4482,7 @@ def make_supercell(self, scaling_matrix: ArrayLike, to_unit_cell: bool = True, i return struct def scale_lattice(self, volume: float) -> Self: - """Perform a scaling of the lattice vectors so that length proportions + """Perform scaling of the lattice vectors so that length proportions and angles are preserved. Args: @@ -4408,7 +4501,7 @@ def merge_sites(self, tol: float = 0.01, mode: Literal["sum", "delete", "average Args: tol (float): Tolerance for distance to merge sites. - mode ('sum' | 'delete' | 'average'): "delete" means duplicate sites are + mode ("sum" | "delete" | "average"): "delete" means duplicate sites are deleted. "sum" means the occupancies are summed for the sites. "average" means that the site is deleted but the properties are averaged Only first letter is considered. @@ -4420,8 +4513,8 @@ def merge_sites(self, tol: float = 0.01, mode: Literal["sum", "delete", "average np.fill_diagonal(dist_mat, 0) clusters = fcluster(linkage(squareform((dist_mat + dist_mat.T) / 2)), tol, "distance") sites = [] - for c in np.unique(clusters): - inds = np.where(clusters == c)[0] + for cluster in np.unique(clusters): + inds = np.where(clusters == cluster)[0] species = self[inds[0]].species coords = self[inds[0]].frac_coords props = self[inds[0]].properties @@ -4503,7 +4596,11 @@ def relax( verbose=verbose, ) - def calculate(self, calculator: str | Calculator = "m3gnet", verbose: bool = False) -> Calculator: + def calculate( + self, + calculator: str | Calculator = "m3gnet", + verbose: bool = False, + ) -> Calculator: """Perform an ASE calculation. Args: @@ -4572,8 +4669,8 @@ def from_prototype(cls, prototype: str, species: Sequence, **kwargs) -> Self: class Molecule(IMolecule, collections.abc.MutableSequence): - """Mutable Molecule. It has all the methods in IMolecule, but in addition, - it allows a user to perform edits on the molecule. + """Mutable Molecule. It has all the methods in IMolecule, + and allows a user to perform edits on the molecule. """ __hash__ = None # type: ignore[assignment] @@ -4590,7 +4687,7 @@ def __init__( charge_spin_check: bool = True, properties: dict | None = None, ) -> None: - """Create a MutableMolecule. + """Create a mutable Molecule. Args: species: list of atomic species. Possible kinds of input include a @@ -4630,10 +4727,12 @@ def __init__( charge_spin_check=charge_spin_check, properties=properties, ) - self._sites: list[Site] = list(self._sites) # type: ignore + self._sites: list[Site] = list(self._sites) - def __setitem__( # type: ignore - self, idx: int | slice | Sequence[int] | SpeciesLike, site: SpeciesLike | Site | Sequence + def __setitem__( # type: ignore[override] + self, + idx: int | slice | Sequence[int] | SpeciesLike, + site: SpeciesLike | Site | Sequence, ) -> None: """Modify a site in the molecule. @@ -4648,12 +4747,15 @@ def __setitem__( # type: ignore """ if isinstance(idx, int): indices = [idx] + elif isinstance(idx, (str, Element, Species)): - self.replace_species({idx: site}) # type: ignore + self.replace_species({idx: site}) # type: ignore[dict-item] return + elif isinstance(idx, slice): to_mod = self[idx] indices = [idx for idx, site in enumerate(self._sites) if site in to_mod] + else: indices = list(idx) @@ -4661,25 +4763,25 @@ def __setitem__( # type: ignore if isinstance(site, Site): self._sites[ii] = site elif isinstance(site, str) or not isinstance(site, collections.abc.Sequence): - self._sites[ii].species = site # type: ignore + self._sites[ii].species = site # type: ignore[assignment] else: - self._sites[ii].species = site[0] # type: ignore + self._sites[ii].species = site[0] # type: ignore[assignment, index] if len(site) > 1: - self._sites[ii].coords = site[1] # type: ignore + self._sites[ii].coords = site[1] # type: ignore[assignment, index] if len(site) > 2: - self._sites[ii].properties = site[2] # type: ignore + self._sites[ii].properties = site[2] # type: ignore[assignment, index] def __delitem__(self, idx: SupportsIndex | slice) -> None: """Deletes a site from the Structure.""" self._sites.__delitem__(idx) - def append( # type: ignore + def append( # type: ignore[override] self, species: CompositionLike, coords: ArrayLike, validate_proximity: bool = False, properties: dict | None = None, - ) -> Molecule: + ) -> Self: """Append a site to the molecule. Args: @@ -4700,7 +4802,7 @@ def append( # type: ignore properties=properties, ) - def set_charge_and_spin(self, charge: float, spin_multiplicity: int | None = None) -> Molecule: + def set_charge_and_spin(self, charge: float, spin_multiplicity: int | None = None) -> Self: """Set the charge and spin multiplicity. Args: @@ -4733,7 +4835,7 @@ def set_charge_and_spin(self, charge: float, spin_multiplicity: int | None = Non return self - def insert( # type: ignore + def insert( # type: ignore[override] self, idx: int, species: CompositionLike, @@ -4741,7 +4843,7 @@ def insert( # type: ignore validate_proximity: bool = False, properties: dict | None = None, label: str | None = None, - ) -> Molecule: + ) -> Self: """Insert a site to the molecule. Args: @@ -4761,11 +4863,11 @@ def insert( # type: ignore for site in self: if site.distance(new_site) < self.DISTANCE_TOLERANCE: raise ValueError("New site is too close to an existing site!") - self.sites.insert(idx, new_site) + cast(list[PeriodicSite], self.sites).insert(idx, new_site) return self - def remove_species(self, species: Sequence[SpeciesLike]) -> Molecule: + def remove_species(self, species: Sequence[SpeciesLike]) -> Self: """Remove all occurrences of a species from a molecule. Args: @@ -4783,7 +4885,7 @@ def remove_species(self, species: Sequence[SpeciesLike]) -> Molecule: self.sites = new_sites return self - def remove_sites(self, indices: Sequence[int]) -> Molecule: + def remove_sites(self, indices: Sequence[int]) -> Self: """Delete sites with at indices. Args: @@ -4795,7 +4897,7 @@ def remove_sites(self, indices: Sequence[int]) -> Molecule: self.sites = [self[idx] for idx in range(len(self)) if idx not in indices] return self - def translate_sites(self, indices: Sequence[int] | None = None, vector: ArrayLike | None = None) -> Molecule: + def translate_sites(self, indices: Sequence[int] | None = None, vector: ArrayLike | None = None) -> Self: """Translate specific sites by some vector, keeping the sites within the unit cell. @@ -4823,7 +4925,7 @@ def rotate_sites( theta: float = 0.0, axis: ArrayLike | None = None, anchor: ArrayLike | None = None, - ) -> Molecule: + ) -> Self: """Rotate specific sites by some angle around vector at anchor. Args: @@ -4860,7 +4962,7 @@ def rotate_sites( return self - def perturb(self, distance: float) -> Molecule: + def perturb(self, distance: float) -> Self: """Perform a random perturbation of the sites in a structure to break symmetries. @@ -4872,7 +4974,7 @@ def perturb(self, distance: float) -> Molecule: """ def get_rand_vec(): - # deals with zero vectors. + # Deal with zero vectors vector = np.random.randn(3) vnorm = np.linalg.norm(vector) return vector / vnorm * distance if vnorm != 0 else get_rand_vec() @@ -4882,7 +4984,7 @@ def get_rand_vec(): return self - def apply_operation(self, symm_op: SymmOp) -> Molecule: + def apply_operation(self, symm_op: SymmOp) -> Self: """Apply a symmetry operation to the molecule. Args: @@ -4900,7 +5002,12 @@ def operate_site(site): return self - def substitute(self, index: int, func_group: IMolecule | Molecule | str, bond_order: int = 1) -> Molecule: + def substitute( + self, + index: int, + func_group: IMolecule | Molecule | str, + bond_order: int = 1, + ) -> Self: """Substitute atom at index with a functional group. Args: @@ -4935,8 +5042,8 @@ def substitute(self, index: int, func_group: IMolecule | Molecule | str, bond_or all_non_terminal_nn.append(nn) break - if len(all_non_terminal_nn) == 0: - raise RuntimeError("Can't find a non-terminal neighbor to attach functional group to.") + if not all_non_terminal_nn: + raise RuntimeError("Can't find a non-terminal neighbor to attach functional group to") non_terminal_nn = min(all_non_terminal_nn, key=lambda nn: nn.nn_distance) @@ -5030,12 +5137,11 @@ def relax( verbose=verbose, ) - def calculate(self, calculator: str | Calculator = "gfn2-xtb", verbose: bool = False) -> Calculator: + def calculate(self, calculator: Literal["gfn2-xtb"] | Calculator = "gfn2-xtb", verbose: bool = False) -> Calculator: """Perform an ASE calculation. Args: - calculator: An ASE Calculator or a string from the following options: "gfn2-xtb". - Defaults to 'gfn2-xtb'. + calculator: An ASE Calculator or "gfn2-xtb". Defaults to 'gfn2-xtb'. verbose (bool): whether to print stdout. Defaults to False. Returns: diff --git a/pymatgen/electronic_structure/boltztrap.py b/pymatgen/electronic_structure/boltztrap.py index 9788d1249d1..a0de391a7d5 100644 --- a/pymatgen/electronic_structure/boltztrap.py +++ b/pymatgen/electronic_structure/boltztrap.py @@ -313,13 +313,11 @@ def write_struct(self, output_file) -> None: """ if self._symprec is not None: sym = SpacegroupAnalyzer(self._bs.structure, symprec=self._symprec) - elif self._symprec is None: - pass - with open(output_file, mode="w") as file: + with open(output_file, mode="w", encoding="utf-8") as file: if self._symprec is not None: file.write(f"{self._bs.structure.formula} {sym.get_space_group_symbol()}\n") # type: ignore[reportPossiblyUnboundVariable] - elif self._symprec is None: + else: file.write(f"{self._bs.structure.formula} symmetries disabled\n") file.write( @@ -330,13 +328,11 @@ def write_struct(self, output_file) -> None: + "\n" ) - if self._symprec is not None: - ops = sym.get_symmetry_dataset()["rotations"] # type: ignore[reportPossiblyUnboundVariable] - elif self._symprec is None: - ops = [np.eye(3)] - file.write(f"{len(ops)}\n") # type: ignore[reportPossiblyUnboundVariable] + ops = [np.eye(3)] if self._symprec is None else sym.get_symmetry_dataset()["rotations"] # type: ignore[reportPossiblyUnboundVariable] + + file.write(f"{len(ops)}\n") - for op in ops: # type: ignore[reportPossiblyUnboundVariable] + for op in ops: for row in op: file.write(f"{' '.join(map(str, row))}\n") diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index 10fddb06fa2..c94c114d6d5 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -1630,7 +1630,7 @@ def __init__( count += 1 else: - # The following just presents a deterministic ordering. + # The following just presents a deterministic ordering unique_sites = [ (min(sites, key=lambda site: tuple(abs(x) for x in site.frac_coords)), len(sites)) for sites in spg_analyzer.get_symmetrized_structure().equivalent_sites # type: ignore[reportPossiblyUnboundVariable] diff --git a/pymatgen/io/gaussian.py b/pymatgen/io/gaussian.py index 64900152145..de4be8f1bb3 100644 --- a/pymatgen/io/gaussian.py +++ b/pymatgen/io/gaussian.py @@ -1219,7 +1219,7 @@ def get_spectre_plot(self, sigma=0.05, step=0.01): Returns: A dict: {"energies": values, "lambda": values, "xas": values} - where values are lists of abscissa (energies, lamba) and + where values are lists of abscissa (energies, lambda) and the sum of gaussian functions (xas). A matplotlib plot. """ diff --git a/pymatgen/io/vasp/optics.py b/pymatgen/io/vasp/optics.py index 2be4d9c6df4..412872fa751 100644 --- a/pymatgen/io/vasp/optics.py +++ b/pymatgen/io/vasp/optics.py @@ -215,7 +215,7 @@ def plot_weighted_transition_data( Since the computation of the final spectrum (especially the smearing part) is still fairly expensive. This function can be used to check the values of some portion of the spectrum (defined by the mask). - In a sense, we are lookin at the imaginary part of the dielectric function + In a sense, we are looking at the imaginary part of the dielectric function before the smearing is applied. Args: