Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid using full equality (==) to compare float, avoid assert_array_equal compare float array #4159

Open
wants to merge 60 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
588ceb8
replace some float equality check
DanielYang59 Nov 9, 2024
0b97cb0
explicit encoding
DanielYang59 Nov 9, 2024
82f3431
charge is also float
DanielYang59 Nov 9, 2024
389c59b
enhance types
DanielYang59 Nov 9, 2024
1d22fee
access gcd via math namespace as math is already imported
DanielYang59 Nov 9, 2024
84e3b70
put dunder method to top
DanielYang59 Nov 9, 2024
ea6089e
fix typo
DanielYang59 Nov 9, 2024
e264890
tweak _proj implementation
DanielYang59 Nov 9, 2024
95a6192
Merge branch 'master' into 4158-fix-eq-check
DanielYang59 Nov 9, 2024
e431882
support array like
DanielYang59 Nov 9, 2024
8f30f13
Merge branch '4158-fix-eq-check' of https://github.com/DanielYang59/p…
DanielYang59 Nov 9, 2024
e6ea809
add arg and return type
DanielYang59 Nov 9, 2024
bf0ff16
tweak type
DanielYang59 Nov 9, 2024
5c9992e
avoid more == for float comparison
DanielYang59 Nov 10, 2024
4920eb7
replace some == in test, more left to do
DanielYang59 Nov 10, 2024
f343503
replace more in core test
DanielYang59 Nov 10, 2024
808c495
replace more in test
DanielYang59 Nov 10, 2024
c0692dd
replace even more
DanielYang59 Nov 10, 2024
48e0ead
replace last batch
DanielYang59 Nov 10, 2024
cdff78d
clean up assert approx
DanielYang59 Nov 10, 2024
7eb7caa
replace pytest.approx with approx
DanielYang59 Nov 10, 2024
7745458
also fix membership check
DanielYang59 Nov 10, 2024
24edca5
replace some equality check of list
DanielYang59 Nov 10, 2024
089e3d2
replace some sequences
DanielYang59 Nov 10, 2024
20abc2a
fix test
DanielYang59 Nov 10, 2024
4871d1b
replace float comparison as dict
DanielYang59 Nov 10, 2024
7a8c148
fix test
DanielYang59 Nov 10, 2024
1137a72
replace more float compare, mostly for VASP
DanielYang59 Nov 10, 2024
88aad8b
fix test
DanielYang59 Nov 10, 2024
d12a07b
fix approx in condition block
DanielYang59 Nov 10, 2024
4552881
replace sci notation
DanielYang59 Nov 10, 2024
30e0f66
suppress buggy ruff sim300
DanielYang59 Nov 10, 2024
4f0ff82
number_of_permutations to int
DanielYang59 Nov 10, 2024
24e81d2
revert change for formula_double_format, in favor of another PR
DanielYang59 Nov 10, 2024
8ef27dc
c_indices seems to be int
DanielYang59 Nov 10, 2024
5ac947d
use sci notation for crazily large int
DanielYang59 Nov 10, 2024
2bde949
simplify numpy.testing usage
DanielYang59 Nov 10, 2024
16fa94d
set tol as pos arg
DanielYang59 Nov 10, 2024
300dc30
avoid array equal for list of str
DanielYang59 Nov 10, 2024
d4309b7
assert_array_equal should not be used on float array
DanielYang59 Nov 10, 2024
8cbfcfc
fix module level var name
DanielYang59 Nov 10, 2024
dbe8659
more assert_array_equal on complex number
DanielYang59 Nov 10, 2024
5ff4248
simplify approx on dict value
DanielYang59 Nov 10, 2024
1f01241
avoid module level var when it's used only 3 times
DanielYang59 Nov 10, 2024
32929d4
pytext.approx to approx
DanielYang59 Nov 10, 2024
16dbec3
fix approx on nested dict
DanielYang59 Nov 10, 2024
3df99ab
avoid unnecessary convert to np.array
DanielYang59 Nov 10, 2024
fd573cd
array_equal to all close for float array
DanielYang59 Nov 10, 2024
e46dbf9
assert all close for float array
DanielYang59 Nov 10, 2024
857581b
capital class attrib is treated as constant
DanielYang59 Nov 11, 2024
7787a25
Merge remote-tracking branch 'upstream/master' into 4158-fix-eq-check
DanielYang59 Nov 13, 2024
5700f3b
Merge remote-tracking branch 'upstream/master' into 4158-fix-eq-check
DanielYang59 Nov 14, 2024
79d3ffc
Merge branch 'master' into 4158-fix-eq-check
DanielYang59 Nov 16, 2024
1724f9e
Merge remote-tracking branch 'upstream/master' into 4158-fix-eq-check
DanielYang59 Nov 16, 2024
e7e5209
Merge branch 'master' into 4158-fix-eq-check
DanielYang59 Nov 18, 2024
626c3fb
Merge branch 'master' into 4158-fix-eq-check
DanielYang59 Nov 19, 2024
7c3b822
Merge branch 'master' into 4158-fix-eq-check
DanielYang59 Dec 11, 2024
95ec56a
Merge branch 'master' into 4158-fix-eq-check
DanielYang59 Jan 2, 2025
8284fdc
Merge remote-tracking branch 'upstream/master' into 4158-fix-eq-check
DanielYang59 Jan 10, 2025
17787db
Merge remote-tracking branch 'upstream/master' into 4158-fix-eq-check
DanielYang59 Jan 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 34 additions & 29 deletions src/pymatgen/transformations/advanced_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import warnings
from fractions import Fraction
from itertools import groupby, product
from math import gcd
from string import ascii_lowercase
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -49,6 +48,8 @@
from collections.abc import Callable, Iterable, Sequence
from typing import Any, Literal

from numpy.typing import NDArray


__author__ = "Shyue Ping Ong, Stephen Dacek, Anubhav Jain, Matthew Horton, Alex Ganose"

Expand All @@ -68,6 +69,9 @@ def __init__(self, charge_balance_sp):
"""
self.charge_balance_sp = str(charge_balance_sp)

def __repr__(self):
return f"Charge Balance Transformation : Species to remove = {self.charge_balance_sp}"

def apply_transformation(self, structure: Structure):
"""Apply the transformation.

Expand All @@ -87,9 +91,6 @@ def apply_transformation(self, structure: Structure):
trans = SubstitutionTransformation({self.charge_balance_sp: {self.charge_balance_sp: 1 - removal_fraction}})
return trans.apply_transformation(structure)

def __repr__(self):
return f"Charge Balance Transformation : Species to remove = {self.charge_balance_sp}"


class SuperTransformation(AbstractTransformation):
"""This is a transformation that is inherently one-to-many. It is constructed
Expand All @@ -111,6 +112,9 @@ def __init__(self, transformations, nstructures_per_trans=1):
self._transformations = transformations
self.nstructures_per_trans = nstructures_per_trans

def __repr__(self):
return f"Super Transformation : Transformations = {' '.join(map(str, self._transformations))}"

def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
"""Apply the transformation.

Expand Down Expand Up @@ -140,9 +144,6 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
)
return structures

def __repr__(self):
return f"Super Transformation : Transformations = {' '.join(map(str, self._transformations))}"

@property
def is_one_to_many(self) -> bool:
"""Transform one structure to many."""
Expand Down Expand Up @@ -192,6 +193,9 @@ def __init__(
self.charge_balance_species = charge_balance_species
self.order = order

def __repr__(self):
return f"Multiple Substitution Transformation : Substitution on {self.sp_to_replace}"

def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
"""Apply the transformation.

Expand Down Expand Up @@ -235,9 +239,6 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
outputs.append({"structure": new_structure})
return outputs

def __repr__(self):
return f"Multiple Substitution Transformation : Substitution on {self.sp_to_replace}"

@property
def is_one_to_many(self) -> bool:
"""Transform one structure to many."""
Expand Down Expand Up @@ -324,6 +325,9 @@ def __init__(
if max_cell_size and max_disordered_sites:
raise ValueError("Cannot set both max_cell_size and max_disordered_sites!")

def __repr__(self):
return "EnumerateStructureTransformation"

def apply_transformation(
self, structure: Structure, return_ranked_list: bool | int = False
) -> Structure | list[dict]:
Expand Down Expand Up @@ -469,9 +473,6 @@ def sort_func(struct):
return self._all_structures[:num_to_return]
return self._all_structures[0]["structure"]

def __repr__(self):
return "EnumerateStructureTransformation"

@property
def is_one_to_many(self) -> bool:
"""Transform one structure to many."""
Expand All @@ -495,6 +496,9 @@ def __init__(self, threshold=1e-2, scale_volumes=True, **kwargs):
self.scale_volumes = scale_volumes
self._substitutor = SubstitutionPredictor(threshold=threshold, **kwargs)

def __repr__(self):
return "SubstitutionPredictorTransformation"

def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
"""Apply the transformation.

Expand Down Expand Up @@ -529,9 +533,6 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
outputs.append(output)
return outputs

def __repr__(self):
return "SubstitutionPredictorTransformation"

@property
def is_one_to_many(self) -> bool:
"""Transform one structure to many."""
Expand Down Expand Up @@ -660,7 +661,7 @@ def determine_min_cell(disordered_structure):

def lcm(n1, n2):
"""Find least common multiple of two numbers."""
return n1 * n2 / gcd(n1, n2)
return n1 * n2 / math.gcd(n1, n2)

# assumes all order parameters for a given species are the same
mag_species_order_parameter = {}
Expand All @@ -683,7 +684,7 @@ def lcm(n1, n2):
for sp, order_parameter in mag_species_order_parameter.items():
denom = Fraction(order_parameter).limit_denominator(100).denominator
num_atom_per_specie = mag_species_occurrences[sp]
n_gcd = gcd(denom, num_atom_per_specie)
n_gcd = math.gcd(denom, num_atom_per_specie)
smallest_n.append(lcm(int(n_gcd), denom) / n_gcd)

return max(smallest_n)
Expand Down Expand Up @@ -983,15 +984,19 @@ def __init__(
self.allowed_doping_species = allowed_doping_species
self.kwargs = kwargs

def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
def apply_transformation(
self,
structure: Structure,
return_ranked_list: bool | int = False,
) -> list[dict[Literal["structure", "energy"], Structure | float]] | Structure:
"""
Args:
structure (Structure): Input structure to dope
return_ranked_list (bool | int, optional): If return_ranked_list is int, that number of structures.
is returned. If False, only the single lowest energy structure is returned. Defaults to False.
structure (Structure): Input structure to dope.
return_ranked_list (bool | int, optional): If is int, that number of structures is returned.
If False, only the single lowest energy structure is returned. Defaults to False.

Returns:
list[dict] | Structure: each dict has shape {"structure": Structure, "energy": float}.
list[dict] | Structure: each dict as {"structure": Structure, "energy": float}.
"""
comp = structure.composition
logger.info(f"Composition: {comp}")
Expand Down Expand Up @@ -1124,7 +1129,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
return all_structures[0]["structure"]

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -1872,11 +1877,11 @@ def is_one_to_many(self) -> bool:
return True


def _proj(b, a):
def _proj(b: NDArray, a: NDArray) -> NDArray:
"""Get vector projection (np.ndarray) of vector b (np.ndarray)
onto vector a (np.ndarray).
"""
return (b.T @ (a / np.linalg.norm(a))) * (a / np.linalg.norm(a))
return (np.dot(b, a) / np.dot(a, a)) * a
Copy link
Contributor Author

@DanielYang59 DanielYang59 Nov 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new implementation is slightly more readable (personal taste) and gives ~4x speedup, reference (the following is a project to b):

image
Original Implementation Time: 420.86 ms
New Implementation Time: 101.28 ms

Test script (by GPT):

import numpy as np
from numpy.typing import NDArray
from time import perf_counter_ns


def _proj_original(b: NDArray, a: NDArray) -> NDArray:
    return (b.T @ (a / np.linalg.norm(a))) * (a / np.linalg.norm(a))

def _proj_new(b: NDArray, a: NDArray) -> NDArray:
    return (np.dot(b, a) / np.dot(a, a)) * a

def verify_projection():
    a = np.random.rand(3)
    b = np.random.rand(3)
    proj1 = _proj_original(b, a)
    proj2 = _proj_new(b, a)
    assert np.allclose(proj1, proj2)

def benchmark_projections(n_iter=100000):
    a = np.random.rand(3)
    b = np.random.rand(3)

    # Measure original implementation
    start_time = perf_counter_ns()
    for _ in range(n_iter):
        _proj_original(b, a)
    time_original = perf_counter_ns() - start_time

    # Measure new implementation
    start_time = perf_counter_ns()
    for _ in range(n_iter):
        _proj_new(b, a)
    time_new = perf_counter_ns() - start_time

    print(f"Original Implementation Time: {time_original / 1e6:.2f} ms")
    print(f"New Implementation Time: {time_new / 1e6:.2f} ms")

verify_projection()

print("Benchmarking both implementations...")
benchmark_projections()



class SQSTransformation(AbstractTransformation):
Expand Down Expand Up @@ -2194,6 +2199,9 @@ def __init__(self, rattle_std: float, min_distance: float, seed: int | None = No
self.random_state = np.random.RandomState(seed)
self.kwargs = kwargs

def __repr__(self):
return f"{__name__} : rattle_std = {self.rattle_std}"

def apply_transformation(self, structure: Structure) -> Structure:
"""Apply the transformation.

Expand All @@ -2215,6 +2223,3 @@ def apply_transformation(self, structure: Structure) -> Structure:
structure.cart_coords + displacements,
coords_are_cartesian=True,
)

def __repr__(self):
return f"{__name__} : rattle_std = {self.rattle_std}"
21 changes: 10 additions & 11 deletions tests/transformations/test_advanced_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pytest
from monty.serialization import loadfn
from numpy.testing import assert_allclose, assert_array_equal
from pytest import approx

from pymatgen.analysis.energy_models import IsingModel, SymmetryModel
from pymatgen.analysis.gb.grain import GrainBoundaryGenerator
Expand Down Expand Up @@ -54,7 +53,7 @@ def get_table():
default lambda table.
"""
json_path = f"{TEST_FILES_DIR}/analysis/struct_predictor/test_lambda.json"
with open(json_path) as file:
with open(json_path, encoding="utf-8") as file:
return json.load(file)


Expand Down Expand Up @@ -150,7 +149,7 @@ def test_apply_transformation(self):
struct = Structure(lattice, ["Li+", "Li+", "Li+", "Li+", "Li+", "Li+", "O2-", "O2-"], coords)
struct_trafo = trafo.apply_transformation(struct)

assert struct_trafo.charge == approx(0, abs=1e-5)
assert struct_trafo.charge == pytest.approx(0, abs=1e-5)


@pytest.mark.skipif(not enumlib_present, reason="enum_lib not present.")
Expand Down Expand Up @@ -242,7 +241,7 @@ def test_as_from_dict(self):
trans = EnumerateStructureTransformation()
dct = trans.as_dict()
trans = EnumerateStructureTransformation.from_dict(dct)
assert trans.symm_prec == 0.1
assert trans.symm_prec == pytest.approx(0.1)


class TestSubstitutionPredictorTransformation:
Expand Down Expand Up @@ -499,7 +498,7 @@ def test_apply_transformation(self):
ss = trafo.apply_transformation(structure, 1000)
assert len(ss) == n_structures
for d in ss:
assert d["structure"].charge == 0
assert d["structure"].charge == pytest.approx(0)

# Aliovalent doping with codopant
for dopant, n_structures in [("Al3+", 3), ("N3-", 37), ("Cl-", 37)]:
Expand All @@ -513,7 +512,7 @@ def test_apply_transformation(self):
ss = trafo.apply_transformation(structure, 1000)
assert len(ss) == n_structures
for d in ss:
assert d["structure"].charge == 0
assert d["structure"].charge == pytest.approx(0)

# Make sure compensation is done with lowest oxi state
structure = PymatgenTest.get_structure("SrTiO3")
Expand Down Expand Up @@ -796,8 +795,8 @@ def test_apply_transformation_orthorhombic_supercell(self):
supercell_generator_cubic.transformation_matrix,
supercell_generator_orthorhombic.transformation_matrix,
)
assert transformed_cubic.lattice.angles != transformed_orthorhombic.lattice.angles
assert transformed_orthorhombic.lattice.abc != transformed_cubic.lattice.abc
assert not np.allclose(transformed_cubic.lattice.angles, transformed_orthorhombic.lattice.angles)
assert not np.allclose(transformed_orthorhombic.lattice.abc, transformed_cubic.lattice.abc)

structure = self.get_structure("Si")
min_atoms = 100
Expand Down Expand Up @@ -835,9 +834,9 @@ def test_apply_transformation_orthorhombic_supercell(self):
supercell_generator_cubic.transformation_matrix,
supercell_generator_orthorhombic.transformation_matrix,
)
assert transformed_orthorhombic.lattice.abc != transformed_cubic.lattice.abc
# only angels are expected to be the same because of force_90_degrees = True
assert transformed_cubic.lattice.angles == transformed_orthorhombic.lattice.angles
assert not np.allclose(transformed_orthorhombic.lattice.abc, transformed_cubic.lattice.abc)
# only angels are expected to be the same because of `force_90_degrees = True`
assert_allclose(transformed_cubic.lattice.angles, transformed_orthorhombic.lattice.angles)


class TestAddAdsorbateTransformation(PymatgenTest):
Expand Down