Skip to content

Commit

Permalink
More type annotations (#3800)
Browse files Browse the repository at this point in the history
* add Self return anno to copy() methods

* improve TestGrainBoundary.test_copy

TestGrainBoundary setUpClass -> setUp, helps language server infer types of class attributes

* add type annotations

* remove boilerplate inverse(self) -> None and is_one_to_many(self) -> False properties on AbstractTransformation subclasses

* is_one_to_many doc string to "Transform one structure to many." if returns True

* format """Returns: ... doc str by replacing with 'Returns:'->'Get' for methods and 'Returns:'->'' for properties
  • Loading branch information
janosh authored May 2, 2024
1 parent a5a4061 commit 87c92e6
Show file tree
Hide file tree
Showing 60 changed files with 279 additions and 573 deletions.
2 changes: 1 addition & 1 deletion dev_scripts/potcar_scrambler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, potcars: Potcar | PotcarSingle) -> None:

def _rand_float_from_str_with_prec(self, input_str: str, bloat: float = 1.5) -> float:
n_prec = len(input_str.split(".")[1])
bd = max(1, bloat * abs(float(input_str)))
bd = max(1, bloat * abs(float(input_str))) # ensure we don't get 0
return round(bd * np.random.rand(1)[0], n_prec)

def _read_fortran_str_and_scramble(self, input_str: str, bloat: float = 1.5):
Expand Down
6 changes: 2 additions & 4 deletions dev_scripts/update_pt_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""
Developer script to convert yaml periodic table to json format.
Created on Nov 15, 2011.
"""
"""Developer script to convert YAML periodic table to JSON format.
Created on 2011-11-15."""

from __future__ import annotations

Expand Down
6 changes: 3 additions & 3 deletions pymatgen/alchemy/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __repr__(self):
)

def as_dict(self) -> dict:
"""Returns: MSONable dict."""
"""Get MSONable dict."""
return {
"@module": type(self).__module__,
"@class": type(self).__name__,
Expand Down Expand Up @@ -159,7 +159,7 @@ def test(self, structure: Structure):
return True

def as_dict(self):
"""Returns: MSONable dict."""
"""Get MSONable dict."""
return {
"@module": type(self).__module__,
"@class": type(self).__name__,
Expand Down Expand Up @@ -280,7 +280,7 @@ def get_sg(s):
return True

def as_dict(self):
"""Returns: MSONable dict."""
"""Get MSONable dict."""
return {
"@module": type(self).__module__,
"@class": type(self).__name__,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -682,39 +682,39 @@ def pauling_stability_ratio(self):
return self._pauling_stability_ratio

@property
def mp_symbol(self):
def mp_symbol(self) -> str:
"""Returns the MP symbol of this coordination geometry."""
return self._mp_symbol

@property
def ce_symbol(self):
"""Returns the symbol of this coordination geometry."""
def ce_symbol(self) -> str:
"""Returns the symbol of this coordination geometry. Same as the MP symbol."""
return self._mp_symbol

def get_coordination_number(self):
def get_coordination_number(self) -> int:
"""Returns the coordination number of this coordination geometry."""
return self.coordination

def is_implemented(self) -> bool:
"""Returns True if this coordination geometry is implemented."""
return bool(self.points)

def get_name(self):
def get_name(self) -> str:
"""Returns the name of this coordination geometry."""
return self.name

@property
def IUPAC_symbol(self):
def IUPAC_symbol(self) -> str:
"""Returns the IUPAC symbol of this coordination geometry."""
return self.IUPACsymbol

@property
def IUPAC_symbol_str(self):
def IUPAC_symbol_str(self) -> str:
"""Returns a string representation of the IUPAC symbol of this coordination geometry."""
return str(self.IUPACsymbol)

@property
def IUCr_symbol(self):
def IUCr_symbol(self) -> str:
"""Returns the IUCr symbol of this coordination geometry."""
return self.IUCrsymbol

Expand Down Expand Up @@ -848,7 +848,7 @@ def __init__(self, permutations_safe_override=False, only_symbols=None):
only_symbols: Whether to restrict the list of environments to be identified.
"""
dict.__init__(self)
self.cg_list = []
self.cg_list: list[CoordinationGeometry] = []
if only_symbols is None:
with open(f"{module_dir}/coordination_geometries_files/allcg.txt") as file:
data = file.readlines()
Expand Down Expand Up @@ -943,18 +943,18 @@ def get_geometries(self, coordination=None, returned="cg"):
"""
geom = []
if coordination is None:
for gg in self.cg_list:
for coord_geom in self.cg_list:
if returned == "cg":
geom.append(gg)
geom.append(coord_geom)
elif returned == "mp_symbol":
geom.append(gg.mp_symbol)
geom.append(coord_geom.mp_symbol)
else:
for gg in self.cg_list:
if gg.get_coordination_number() == coordination:
for coord_geom in self.cg_list:
if coord_geom.get_coordination_number() == coordination:
if returned == "cg":
geom.append(gg)
geom.append(coord_geom)
elif returned == "mp_symbol":
geom.append(gg.mp_symbol)
geom.append(coord_geom.mp_symbol)
return geom

def get_symbol_name_mapping(self, coordination=None):
Expand All @@ -969,12 +969,12 @@ def get_symbol_name_mapping(self, coordination=None):
"""
geom = {}
if coordination is None:
for gg in self.cg_list:
geom[gg.mp_symbol] = gg.name
for coord_geom in self.cg_list:
geom[coord_geom.mp_symbol] = coord_geom.name
else:
for gg in self.cg_list:
if gg.get_coordination_number() == coordination:
geom[gg.mp_symbol] = gg.name
for coord_geom in self.cg_list:
if coord_geom.get_coordination_number() == coordination:
geom[coord_geom.mp_symbol] = coord_geom.name
return geom

def get_symbol_cn_mapping(self, coordination=None):
Expand All @@ -989,12 +989,12 @@ def get_symbol_cn_mapping(self, coordination=None):
"""
geom = {}
if coordination is None:
for gg in self.cg_list:
geom[gg.mp_symbol] = gg.coordination_number
for coord_geom in self.cg_list:
geom[coord_geom.mp_symbol] = coord_geom.coordination_number
else:
for gg in self.cg_list:
if gg.get_coordination_number() == coordination:
geom[gg.mp_symbol] = gg.coordination_number
for coord_geom in self.cg_list:
if coord_geom.get_coordination_number() == coordination:
geom[coord_geom.mp_symbol] = coord_geom.coordination_number
return geom

def get_implemented_geometries(self, coordination=None, returned="cg", include_deactivated=False):
Expand All @@ -1008,23 +1008,23 @@ def get_implemented_geometries(self, coordination=None, returned="cg", include_d
"""
geom = []
if coordination is None:
for gg in self.cg_list:
if gg.points is not None and ((not gg.deactivate) or include_deactivated):
for coord_geom in self.cg_list:
if coord_geom.points is not None and ((not coord_geom.deactivate) or include_deactivated):
if returned == "cg":
geom.append(gg)
geom.append(coord_geom)
elif returned == "mp_symbol":
geom.append(gg.mp_symbol)
geom.append(coord_geom.mp_symbol)
else:
for gg in self.cg_list:
for coord_geom in self.cg_list:
if (
gg.get_coordination_number() == coordination
and gg.points is not None
and ((not gg.deactivate) or include_deactivated)
coord_geom.get_coordination_number() == coordination
and coord_geom.points is not None
and ((not coord_geom.deactivate) or include_deactivated)
):
if returned == "cg":
geom.append(gg)
geom.append(coord_geom)
elif returned == "mp_symbol":
geom.append(gg.mp_symbol)
geom.append(coord_geom.mp_symbol)
return geom

def get_not_implemented_geometries(self, coordination=None, returned="mp_symbol"):
Expand All @@ -1037,63 +1037,63 @@ def get_not_implemented_geometries(self, coordination=None, returned="mp_symbol"
"""
geom = []
if coordination is None:
for gg in self.cg_list:
if gg.points is None:
for coord_geom in self.cg_list:
if coord_geom.points is None:
if returned == "cg":
geom.append(gg)
geom.append(coord_geom)
elif returned == "mp_symbol":
geom.append(gg.mp_symbol)
geom.append(coord_geom.mp_symbol)
else:
for gg in self.cg_list:
if gg.get_coordination_number() == coordination and gg.points is None:
for coord_geom in self.cg_list:
if coord_geom.get_coordination_number() == coordination and coord_geom.points is None:
if returned == "cg":
geom.append(gg)
geom.append(coord_geom)
elif returned == "mp_symbol":
geom.append(gg.mp_symbol)
geom.append(coord_geom.mp_symbol)
return geom

def get_geometry_from_name(self, name):
def get_geometry_from_name(self, name: str) -> CoordinationGeometry:
"""Get the coordination geometry of the given name.
Args:
name: The name of the coordination geometry.
"""
for gg in self.cg_list:
if gg.name == name or name in gg.alternative_names:
return gg
for coord_geom in self.cg_list:
if coord_geom.name == name or name in coord_geom.alternative_names:
return coord_geom
raise LookupError(f"No coordination geometry found with name {name!r}")

def get_geometry_from_IUPAC_symbol(self, IUPAC_symbol):
def get_geometry_from_IUPAC_symbol(self, IUPAC_symbol: str) -> CoordinationGeometry:
"""Get the coordination geometry of the given IUPAC symbol.
Args:
IUPAC_symbol: The IUPAC symbol of the coordination geometry.
"""
for gg in self.cg_list:
if gg.IUPAC_symbol == IUPAC_symbol:
return gg
for coord_geom in self.cg_list:
if coord_geom.IUPAC_symbol == IUPAC_symbol:
return coord_geom
raise LookupError(f"No coordination geometry found with IUPAC symbol {IUPAC_symbol!r}")

def get_geometry_from_IUCr_symbol(self, IUCr_symbol):
def get_geometry_from_IUCr_symbol(self, IUCr_symbol: str) -> CoordinationGeometry:
"""Get the coordination geometry of the given IUCr symbol.
Args:
IUCr_symbol: The IUCr symbol of the coordination geometry.
"""
for gg in self.cg_list:
if gg.IUCr_symbol == IUCr_symbol:
return gg
for coord_geom in self.cg_list:
if coord_geom.IUCr_symbol == IUCr_symbol:
return coord_geom
raise LookupError(f"No coordination geometry found with IUCr symbol {IUCr_symbol!r}")

def get_geometry_from_mp_symbol(self, mp_symbol):
def get_geometry_from_mp_symbol(self, mp_symbol: str) -> CoordinationGeometry:
"""Get the coordination geometry of the given mp_symbol.
Args:
mp_symbol: The mp_symbol of the coordination geometry.
"""
for gg in self.cg_list:
if gg.mp_symbol == mp_symbol:
return gg
for coord_geom in self.cg_list:
if coord_geom.mp_symbol == mp_symbol:
return coord_geom
raise LookupError(f"No coordination geometry found with mp_symbol {mp_symbol!r}")

def is_a_valid_coordination_geometry(
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/analysis/chempot_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def _get_new_limits_from_padding(
elem_indices: list[int],
element_padding: float,
default_min_limit: float,
):
) -> list[float]:
"""Get new minimum limits for each element by subtracting specified padding
from the minimum for each axis found in any of the domains.
"""
Expand Down
11 changes: 5 additions & 6 deletions pymatgen/analysis/elasticity/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,15 +742,14 @@ def get_compliance_expansion(self):
ce_exp.append(temp)
return TensorCollection(ce_exp)

def get_strain_from_stress(self, stress):
"""Get the strain from a stress state according
to the compliance expansion corresponding to the
tensor expansion.
def get_strain_from_stress(self, stress) -> float:
"""Get the strain from a stress state according to the compliance
expansion corresponding to the tensor expansion.
"""
compl_exp = self.get_compliance_expansion()
strain = 0
for n, compl in enumerate(compl_exp, start=1):
strain += compl.einsum_sequence([stress] * (n)) / factorial(n)
for idx, compl in enumerate(compl_exp, start=1):
strain += compl.einsum_sequence([stress] * (idx)) / factorial(idx)
return strain

def get_effective_ecs(self, strain, order=2):
Expand Down
8 changes: 4 additions & 4 deletions pymatgen/analysis/ewald.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def _calc_real_and_point(self):

@property
def eta(self):
"""Returns: eta value used in Ewald summation."""
"""Eta value used in Ewald summation."""
return self._eta

def __str__(self):
Expand Down Expand Up @@ -694,17 +694,17 @@ def _recurse(self, matrix, m_list, indices, output_m_list=None):

@property
def best_m_list(self):
"""Returns: Best m_list found."""
"""The best manipulation list found."""
return self._best_m_list

@property
def minimized_sum(self):
"""Returns: Minimized sum."""
"""The minimized Ewald sum."""
return self._minimized_sum

@property
def output_lists(self):
"""Returns: output lists."""
"""Output lists."""
return self._output_lists


Expand Down
2 changes: 1 addition & 1 deletion pymatgen/analysis/ferroelectricity/polarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
__date__ = "April 15, 2017"


def zval_dict_from_potcar(potcar):
def zval_dict_from_potcar(potcar) -> dict[str, float]:
"""
Creates zval_dictionary for calculating the ionic polarization from
Potcar object.
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/analysis/molecule_structure_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _get_bonds(self, mol):
return [bond for bond, dist, cap in zip(all_pairs, pair_dists, max_length) if dist <= cap]

def as_dict(self):
"""Returns: MSONable dict."""
"""Get MSONable dict."""
return {
"version": __version__,
"@module": type(self).__module__,
Expand Down
Loading

0 comments on commit 87c92e6

Please sign in to comment.