Skip to content

Commit

Permalink
Add types annotations for core.interface (#3822)
Browse files Browse the repository at this point in the history
* tweak type and docstring

move dunder methods to the top

add more types and tweaks

relocate more dunder methods to top

more types and format tweaks

fix type error

add types for composition

help fix #3792 (comment)

reverse compare order for readability

Revert "reverse compare order for readability"

This reverts commit 05ea23a.

Revert "help fix #3792 (comment)"

This reverts commit cae7aed.

add types for `core.bonds`

finish `core.ion`

add some types

revert non-interface changes

* use math.gcd over gcd

* use more specific types for ClassVar

* more: use more specific types for ClassVar

* fix some type errors and comment tweaks

* fix mypy errors

* enable types: more type errors to fix

* fix type errors

* fix type errors

* fix mypy errors doesn't show locally

* revert change in test and convert rotation_axis and plane to tuple

* cast plane type to tuple

* remove `del` of var name

* add and update new type `Tuple3Ints = tuple[int, int, int]`

* relocate `Tuple4Ints` to `core.interface`

* relocate `Tuple4Ints`

* use `Tuple3Floats`

* revert usage of assert_allclose

* use more meaningful types

* fix replacement

* Revert "fix replacement"

This reverts commit 6b9589a.

* revert type aliases in docstring
  • Loading branch information
DanielYang59 authored Jun 6, 2024
1 parent 5c7889c commit 65d5379
Show file tree
Hide file tree
Showing 32 changed files with 796 additions and 667 deletions.
57 changes: 26 additions & 31 deletions pymatgen/analysis/diffraction/tem.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pymatgen.analysis.diffraction.core import AbstractDiffractionPatternCalculator
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.util.string import latexify_spacegroup, unicodeify_spacegroup
from pymatgen.util.typing import Tuple3Ints

if TYPE_CHECKING:
from numpy.typing import NDArray
Expand Down Expand Up @@ -47,7 +48,7 @@ def __init__(
self,
symprec: float | None = None,
voltage: float = 200,
beam_direction: tuple[int, int, int] = (0, 0, 1),
beam_direction: Tuple3Ints = (0, 0, 1),
camera_length: int = 160,
debye_waller_factors: dict[str, float] | None = None,
cs: float = 1,
Expand Down Expand Up @@ -104,9 +105,7 @@ def generate_points(coord_left: int = -10, coord_right: int = 10) -> np.ndarray:
points_matrix = (np.ravel(points[i]) for i in range(3))
return np.vstack(list(points_matrix)).transpose()

def zone_axis_filter(
self, points: list[tuple[int, int, int]] | np.ndarray, laue_zone: int = 0
) -> list[tuple[int, int, int]]:
def zone_axis_filter(self, points: list[Tuple3Ints] | np.ndarray, laue_zone: int = 0) -> list[Tuple3Ints]:
"""Filter out all points that exist within the specified Laue zone according to the zone axis rule.
Args:
Expand All @@ -122,11 +121,11 @@ def zone_axis_filter(
return []
filtered = np.where(np.dot(np.array(self.beam_direction), np.transpose(points)) == laue_zone)
result = points[filtered] # type: ignore
return cast(list[tuple[int, int, int]], [tuple(x) for x in result.tolist()])
return cast(list[Tuple3Ints], [tuple(x) for x in result.tolist()])

def get_interplanar_spacings(
self, structure: Structure, points: list[tuple[int, int, int]] | np.ndarray
) -> dict[tuple[int, int, int], float]:
self, structure: Structure, points: list[Tuple3Ints] | np.ndarray
) -> dict[Tuple3Ints, float]:
"""
Args:
structure (Structure): the input structure.
Expand All @@ -142,9 +141,7 @@ def get_interplanar_spacings(
interplanar_spacings_val = np.array([structure.lattice.d_hkl(x) for x in points_filtered])
return dict(zip(points_filtered, interplanar_spacings_val))

def bragg_angles(
self, interplanar_spacings: dict[tuple[int, int, int], float]
) -> dict[tuple[int, int, int], float]:
def bragg_angles(self, interplanar_spacings: dict[Tuple3Ints, float]) -> dict[Tuple3Ints, float]:
"""Get the Bragg angles for every hkl point passed in (where n = 1).
Args:
Expand All @@ -158,7 +155,7 @@ def bragg_angles(
bragg_angles_val = np.arcsin(self.wavelength_rel() / (2 * interplanar_spacings_val))
return dict(zip(plane, bragg_angles_val))

def get_s2(self, bragg_angles: dict[tuple[int, int, int], float]) -> dict[tuple[int, int, int], float]:
def get_s2(self, bragg_angles: dict[Tuple3Ints, float]) -> dict[Tuple3Ints, float]:
"""
Calculates the s squared parameter (= square of sin theta over lambda) for each hkl plane.
Expand All @@ -175,8 +172,8 @@ def get_s2(self, bragg_angles: dict[tuple[int, int, int], float]) -> dict[tuple[
return dict(zip(plane, s2_val))

def x_ray_factors(
self, structure: Structure, bragg_angles: dict[tuple[int, int, int], float]
) -> dict[str, dict[tuple[int, int, int], float]]:
self, structure: Structure, bragg_angles: dict[Tuple3Ints, float]
) -> dict[str, dict[Tuple3Ints, float]]:
"""
Calculates x-ray factors, which are required to calculate atomic scattering factors. Method partially inspired
by the equivalent process in the xrd module.
Expand Down Expand Up @@ -205,8 +202,8 @@ def x_ray_factors(
return x_ray_factors

def electron_scattering_factors(
self, structure: Structure, bragg_angles: dict[tuple[int, int, int], float]
) -> dict[str, dict[tuple[int, int, int], float]]:
self, structure: Structure, bragg_angles: dict[Tuple3Ints, float]
) -> dict[str, dict[Tuple3Ints, float]]:
"""
Calculates atomic scattering factors for electrons using the Mott-Bethe formula (1st order Born approximation).
Expand All @@ -232,8 +229,8 @@ def electron_scattering_factors(
return electron_scattering_factors

def cell_scattering_factors(
self, structure: Structure, bragg_angles: dict[tuple[int, int, int], float]
) -> dict[tuple[int, int, int], int]:
self, structure: Structure, bragg_angles: dict[Tuple3Ints, float]
) -> dict[Tuple3Ints, int]:
"""
Calculates the scattering factor for the whole cell.
Expand All @@ -258,9 +255,7 @@ def cell_scattering_factors(
scattering_factor_curr = 0
return cell_scattering_factors

def cell_intensity(
self, structure: Structure, bragg_angles: dict[tuple[int, int, int], float]
) -> dict[tuple[int, int, int], float]:
def cell_intensity(self, structure: Structure, bragg_angles: dict[Tuple3Ints, float]) -> dict[Tuple3Ints, float]:
"""
Calculates cell intensity for each hkl plane. For simplicity's sake, take I = |F|**2.
Expand Down Expand Up @@ -317,8 +312,8 @@ def get_pattern(
return pd.DataFrame(rows, columns=field_names)

def normalized_cell_intensity(
self, structure: Structure, bragg_angles: dict[tuple[int, int, int], float]
) -> dict[tuple[int, int, int], float]:
self, structure: Structure, bragg_angles: dict[Tuple3Ints, float]
) -> dict[Tuple3Ints, float]:
"""
Normalizes the cell_intensity dict to 1, for use in plotting.
Expand All @@ -340,8 +335,8 @@ def normalized_cell_intensity(
def is_parallel(
self,
structure: Structure,
plane: tuple[int, int, int],
other_plane: tuple[int, int, int],
plane: Tuple3Ints,
other_plane: Tuple3Ints,
) -> bool:
"""
Checks if two hkl planes are parallel in reciprocal space.
Expand All @@ -357,7 +352,7 @@ def is_parallel(
phi = self.get_interplanar_angle(structure, plane, other_plane)
return phi in (180, 0) or np.isnan(phi)

def get_first_point(self, structure: Structure, points: list) -> dict[tuple[int, int, int], float]:
def get_first_point(self, structure: Structure, points: list) -> dict[Tuple3Ints, float]:
"""Get the first point to be plotted in the 2D DP, corresponding to maximum d/minimum R.
Args:
Expand All @@ -378,7 +373,7 @@ def get_first_point(self, structure: Structure, points: list) -> dict[tuple[int,
return {max_d_plane: max_d}

@staticmethod
def get_interplanar_angle(structure: Structure, p1: tuple[int, int, int], p2: tuple[int, int, int]) -> float:
def get_interplanar_angle(structure: Structure, p1: Tuple3Ints, p2: Tuple3Ints) -> float:
"""Get the interplanar angle (in degrees) between the normal of two crystal planes.
Formulas from International Tables for Crystallography Volume C pp. 2-9.
Expand Down Expand Up @@ -432,9 +427,9 @@ def get_interplanar_angle(structure: Structure, p1: tuple[int, int, int], p2: tu

@staticmethod
def get_plot_coeffs(
p1: tuple[int, int, int],
p2: tuple[int, int, int],
p3: tuple[int, int, int],
p1: Tuple3Ints,
p2: Tuple3Ints,
p3: Tuple3Ints,
) -> np.ndarray:
"""
Calculates coefficients of the vector addition required to generate positions for each DP point
Expand All @@ -454,7 +449,7 @@ def get_plot_coeffs(
x = np.dot(a_pinv, b)
return np.ravel(x)

def get_positions(self, structure: Structure, points: list) -> dict[tuple[int, int, int], np.ndarray]:
def get_positions(self, structure: Structure, points: list) -> dict[Tuple3Ints, np.ndarray]:
"""
Calculates all the positions of each hkl point in the 2D diffraction pattern by vector addition.
Distance in centimeters.
Expand Down Expand Up @@ -524,7 +519,7 @@ def tem_dots(self, structure: Structure, points) -> list:

class dot(NamedTuple):
position: NDArray
hkl: tuple[int, int, int]
hkl: Tuple3Ints
intensity: float
film_radius: float
d_spacing: float
Expand Down
9 changes: 5 additions & 4 deletions pymatgen/analysis/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

from pymatgen.analysis.local_env import NearNeighbors
from pymatgen.core import Species
from pymatgen.util.typing import Tuple3Ints


logger = logging.getLogger(__name__)
Expand All @@ -58,7 +59,7 @@

class ConnectedSite(NamedTuple):
site: PeriodicSite
jimage: tuple[int, int, int]
jimage: Tuple3Ints
index: Any # TODO: use more specific type
weight: float
dist: float
Expand Down Expand Up @@ -338,8 +339,8 @@ def add_edge(
self,
from_index: int,
to_index: int,
from_jimage: tuple[int, int, int] = (0, 0, 0),
to_jimage: tuple[int, int, int] | None = None,
from_jimage: Tuple3Ints = (0, 0, 0),
to_jimage: Tuple3Ints | None = None,
weight: float | None = None,
warn_duplicates: bool = True,
edge_properties: dict | None = None,
Expand Down Expand Up @@ -756,7 +757,7 @@ def map_indices(grp: Molecule) -> dict[int, int]:
warn_duplicates=False,
)

def get_connected_sites(self, n: int, jimage: tuple[int, int, int] = (0, 0, 0)) -> list[ConnectedSite]:
def get_connected_sites(self, n: int, jimage: Tuple3Ints = (0, 0, 0)) -> list[ConnectedSite]:
"""Get a named tuple of neighbors of site n:
periodic_site, jimage, index, weight.
Index is the index of the corresponding site
Expand Down
5 changes: 3 additions & 2 deletions pymatgen/analysis/interfaces/coherent_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from collections.abc import Iterator, Sequence

from pymatgen.core import Structure
from pymatgen.util.typing import Tuple3Ints


class CoherentInterfaceBuilder:
Expand All @@ -30,8 +31,8 @@ def __init__(
self,
substrate_structure: Structure,
film_structure: Structure,
film_miller: tuple[int, int, int],
substrate_miller: tuple[int, int, int],
film_miller: Tuple3Ints,
substrate_miller: Tuple3Ints,
zslgen: ZSLGenerator | None = None,
):
"""
Expand Down
5 changes: 3 additions & 2 deletions pymatgen/analysis/interfaces/substrate_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing_extensions import Self

from pymatgen.core import Structure
from pymatgen.util.typing import Tuple3Ints


@dataclass
Expand All @@ -24,8 +25,8 @@ class SubstrateMatch(ZSLMatch):
energy if provided, and the elastic energy.
"""

film_miller: tuple[int, int, int]
substrate_miller: tuple[int, int, int]
film_miller: Tuple3Ints
substrate_miller: Tuple3Ints
strain: Strain
von_mises_strain: float
ground_state_energy: float
Expand Down
3 changes: 2 additions & 1 deletion pymatgen/analysis/local_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from typing_extensions import Self

from pymatgen.core.composition import SpeciesLike
from pymatgen.util.typing import Tuple3Ints


__author__ = "Shyue Ping Ong, Geoffroy Hautier, Sai Jayaraman, "
Expand Down Expand Up @@ -540,7 +541,7 @@ def _get_nn_shell_info(
return list(all_sites.values())

@staticmethod
def _get_image(structure: Structure, site: Site) -> tuple[int, int, int]:
def _get_image(structure: Structure, site: Site) -> Tuple3Ints:
"""Private convenience method for get_nn_info,
gives lattice image from provided PeriodicSite and Structure.
Expand Down
4 changes: 3 additions & 1 deletion pymatgen/analysis/surface_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
if TYPE_CHECKING:
from typing_extensions import Self

from pymatgen.util.typing import Tuple3Ints

EV_PER_ANG2_TO_JOULES_PER_M2 = 16.0217656

__author__ = "Richard Tran"
Expand Down Expand Up @@ -566,7 +568,7 @@ def area_frac_vs_chempot_plot(
all_chempots = np.linspace(min(chempot_range), max(chempot_range), increments)

# initialize a dictionary of lists of fractional areas for each hkl
hkl_area_dict: dict[tuple[int, int, int], list[float]] = {}
hkl_area_dict: dict[Tuple3Ints, list[float]] = {}
for hkl in self.all_slab_entries:
hkl_area_dict[hkl] = []

Expand Down
2 changes: 1 addition & 1 deletion pymatgen/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class Composition(collections.abc.Hashable, collections.abc.Mapping, MSONable, S

# Special formula handling for peroxides and certain elements. This is so
# that formula output does not write LiO instead of Li2O2 for example.
special_formulas: ClassVar = dict(
special_formulas: ClassVar[dict[str, str]] = dict(
LiO="Li2O2",
NaO="Na2O2",
KO="K2O2",
Expand Down
Loading

0 comments on commit 65d5379

Please sign in to comment.