Skip to content

Commit

Permalink
fix ZeroDivisionError in > self.logger.debug(f"Average time per combi…
Browse files Browse the repository at this point in the history
… = {(now - start_time) / idx} seconds")
  • Loading branch information
janosh committed Jan 20, 2024
1 parent 882452e commit ab9319b
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 49 deletions.
2 changes: 1 addition & 1 deletion pymatgen/transformations/site_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def _complete_ordering(self, structure: Structure, num_remove_dict):

all_combis = [list(itertools.combinations(ind, num)) for ind, num in num_remove_dict.items()]

for idx, all_indices in enumerate(itertools.product(*all_combis)):
for idx, all_indices in enumerate(itertools.product(*all_combis), 1):
sites_to_remove = []
indices_list = []
for indices in all_indices:
Expand Down
84 changes: 42 additions & 42 deletions tests/transformations/test_site_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,23 @@ def setUp(self):

def test_apply_transformation(self):
trafo = TranslateSitesTransformation([0, 1], [0.1, 0.2, 0.3])
s = trafo.apply_transformation(self.struct)
assert_allclose(s[0].frac_coords, [0.1, 0.2, 0.3])
assert_allclose(s[1].frac_coords, [0.475, 0.575, 0.675])
struct = trafo.apply_transformation(self.struct)
assert_allclose(struct[0].frac_coords, [0.1, 0.2, 0.3])
assert_allclose(struct[1].frac_coords, [0.475, 0.575, 0.675])
inv_t = trafo.inverse
s = inv_t.apply_transformation(s)
assert s[0].distance_and_image_from_frac_coords([0, 0, 0])[0] == 0
assert_allclose(s[1].frac_coords, [0.375, 0.375, 0.375])
struct = inv_t.apply_transformation(struct)
assert struct[0].distance_and_image_from_frac_coords([0, 0, 0])[0] == 0
assert_allclose(struct[1].frac_coords, [0.375, 0.375, 0.375])

def test_apply_transformation_site_by_site(self):
trafo = TranslateSitesTransformation([0, 1], [[0.1, 0.2, 0.3], [-0.075, -0.075, -0.075]])
s = trafo.apply_transformation(self.struct)
assert_allclose(s[0].frac_coords, [0.1, 0.2, 0.3])
assert_allclose(s[1].frac_coords, [0.3, 0.3, 0.3])
struct = trafo.apply_transformation(self.struct)
assert_allclose(struct[0].frac_coords, [0.1, 0.2, 0.3])
assert_allclose(struct[1].frac_coords, [0.3, 0.3, 0.3])
inv_t = trafo.inverse
s = inv_t.apply_transformation(s)
assert s[0].distance_and_image_from_frac_coords([0, 0, 0])[0] == 0
assert_allclose(s[1].frac_coords, [0.375, 0.375, 0.375])
struct = inv_t.apply_transformation(struct)
assert struct[0].distance_and_image_from_frac_coords([0, 0, 0])[0] == 0
assert_allclose(struct[1].frac_coords, [0.375, 0.375, 0.375])

def test_as_from_dict(self):
d1 = TranslateSitesTransformation([0], [0.1, 0.2, 0.3]).as_dict()
Expand Down Expand Up @@ -100,14 +100,14 @@ def setUp(self):

def test_apply_transformation(self):
trafo = ReplaceSiteSpeciesTransformation({0: "Na"})
s = trafo.apply_transformation(self.struct)
assert s.formula == "Na1 Li3 O4"
struct = trafo.apply_transformation(self.struct)
assert struct.formula == "Na1 Li3 O4"

def test_as_from_dict(self):
d = ReplaceSiteSpeciesTransformation({0: "Na"}).as_dict()
trafo = ReplaceSiteSpeciesTransformation.from_dict(d)
s = trafo.apply_transformation(self.struct)
assert s.formula == "Na1 Li3 O4"
struct = trafo.apply_transformation(self.struct)
assert struct.formula == "Na1 Li3 O4"


class TestRemoveSitesTransformation(unittest.TestCase):
Expand All @@ -132,14 +132,14 @@ def setUp(self):

def test_apply_transformation(self):
trafo = RemoveSitesTransformation(range(2))
s = trafo.apply_transformation(self.struct)
assert s.formula == "Li2 O4"
struct = trafo.apply_transformation(self.struct)
assert struct.formula == "Li2 O4"

def test_as_from_dict(self):
d = RemoveSitesTransformation(range(2)).as_dict()
trafo = RemoveSitesTransformation.from_dict(d)
s = trafo.apply_transformation(self.struct)
assert s.formula == "Li2 O4"
struct = trafo.apply_transformation(self.struct)
assert struct.formula == "Li2 O4"


class TestInsertSitesTransformation(unittest.TestCase):
Expand All @@ -164,8 +164,8 @@ def setUp(self):

def test_apply_transformation(self):
trafo = InsertSitesTransformation(["Fe", "Mn"], [[0.0, 0.5, 0], [0.5, 0.2, 0.2]])
s = trafo.apply_transformation(self.struct)
assert s.formula == "Li4 Mn1 Fe1 O4"
struct = trafo.apply_transformation(self.struct)
assert struct.formula == "Li4 Mn1 Fe1 O4"
trafo = InsertSitesTransformation(["Fe", "Mn"], [[0.001, 0, 0], [0.1, 0.2, 0.2]])

# Test validate proximity
Expand All @@ -175,8 +175,8 @@ def test_apply_transformation(self):
def test_as_from_dict(self):
d = InsertSitesTransformation(["Fe", "Mn"], [[0.5, 0, 0], [0.1, 0.5, 0.2]]).as_dict()
trafo = InsertSitesTransformation.from_dict(d)
s = trafo.apply_transformation(self.struct)
assert s.formula == "Li4 Mn1 Fe1 O4"
struct = trafo.apply_transformation(self.struct)
assert struct.formula == "Li4 Mn1 Fe1 O4"


class TestPartialRemoveSitesTransformation(unittest.TestCase):
Expand Down Expand Up @@ -205,10 +205,10 @@ def test_apply_transformation_complete(self):
[0.5, 0.5],
PartialRemoveSitesTransformation.ALGO_COMPLETE,
)
s = trafo.apply_transformation(self.struct)
assert s.formula == "Li2 O2"
s = trafo.apply_transformation(self.struct, 12)
assert len(s) == 12
struct = trafo.apply_transformation(self.struct)
assert struct.formula == "Li2 O2"
struct = trafo.apply_transformation(self.struct, 12)
assert len(struct) == 12

@unittest.skipIf(not enumlib_present, "enum_lib not present.")
def test_apply_transformation_enumerate(self):
Expand All @@ -217,38 +217,38 @@ def test_apply_transformation_enumerate(self):
[0.5, 0.5],
PartialRemoveSitesTransformation.ALGO_ENUMERATE,
)
s = trafo.apply_transformation(self.struct)
assert s.formula == "Li2 O2"
s = trafo.apply_transformation(self.struct, 12)
assert len(s) == 12
struct = trafo.apply_transformation(self.struct)
assert struct.formula == "Li2 O2"
struct = trafo.apply_transformation(self.struct, 12)
assert len(struct) == 12

def test_apply_transformation_best_first(self):
trafo = PartialRemoveSitesTransformation(
[tuple(range(4)), tuple(range(4, 8))],
[0.5, 0.5],
PartialRemoveSitesTransformation.ALGO_BEST_FIRST,
)
s = trafo.apply_transformation(self.struct)
assert s.formula == "Li2 O2"
struct = trafo.apply_transformation(self.struct)
assert struct.formula == "Li2 O2"

def test_apply_transformation_fast(self):
trafo = PartialRemoveSitesTransformation(
[tuple(range(4)), tuple(range(4, 8))],
[0.5, 0.5],
PartialRemoveSitesTransformation.ALGO_FAST,
)
s = trafo.apply_transformation(self.struct)
assert s.formula == "Li2 O2"
struct = trafo.apply_transformation(self.struct)
assert struct.formula == "Li2 O2"
trafo = PartialRemoveSitesTransformation([tuple(range(8))], [0.5], PartialRemoveSitesTransformation.ALGO_FAST)
s = trafo.apply_transformation(self.struct)
assert s.formula == "Li2 O2"
struct = trafo.apply_transformation(self.struct)
assert struct.formula == "Li2 O2"

def test_as_from_dict(self):
dct = PartialRemoveSitesTransformation([tuple(range(4))], [0.5]).as_dict()
assert {*dct} == {"@module", "@class", "@version", "algo", "indices", "fractions"}
trafo = PartialRemoveSitesTransformation.from_dict(dct)
s = trafo.apply_transformation(self.struct)
assert s.formula == "Li2 O4"
struct = trafo.apply_transformation(self.struct)
assert struct.formula == "Li2 O4"

def test_str(self):
trafo = PartialRemoveSitesTransformation([tuple(range(4))], [0.5])
Expand Down Expand Up @@ -339,6 +339,6 @@ def test(self):

def test_second_nn(self):
trafo = RadialSiteDistortionTransformation(0, 1, nn_only=False)
s = trafo.apply_transformation(self.molecule)
for c1, c2 in zip(self.molecule[7:], s[7:]):
struct = trafo.apply_transformation(self.molecule)
for c1, c2 in zip(self.molecule[7:], struct[7:]):
assert abs(round(sum(c2.coords - c1.coords), 2)) == 0.33
12 changes: 6 additions & 6 deletions tests/transformations/test_standard_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def test_apply_transformation(self):
[0.00, -2.2171384943, 3.1355090603],
]
struct = Structure(lattice, ["Li+", "Li+", "O2-", "O2-"], coords)
s = trafo.apply_transformation(struct)
assert s.composition.formula == "O2"
struct = trafo.apply_transformation(struct)
assert struct.composition.formula == "O2"

d = trafo.as_dict()
assert isinstance(RemoveSpeciesTransformation.from_dict(d), RemoveSpeciesTransformation)
Expand Down Expand Up @@ -298,12 +298,12 @@ def test_apply_transformation(self):
assert trafo.lowest_energy_structure == output[0]["structure"]

struct = Structure(lattice, [{"Si4+": 0.5}, {"Si4+": 0.5}, {"O2-": 0.5}, {"O2-": 0.5}], coords)
allstructs = trafo.apply_transformation(struct, 50)
assert len(allstructs) == 4
all_structs = trafo.apply_transformation(struct, 50)
assert len(all_structs) == 4

struct = Structure(lattice, [{"Si4+": 0.333}, {"Si4+": 0.333}, {"Si4+": 0.333}, "O2-"], coords)
allstructs = trafo.apply_transformation(struct, 50)
assert len(allstructs) == 3
all_structs = trafo.apply_transformation(struct, 50)
assert len(all_structs) == 3

d = trafo.as_dict()
assert isinstance(OrderDisorderedStructureTransformation.from_dict(d), OrderDisorderedStructureTransformation)
Expand Down

0 comments on commit ab9319b

Please sign in to comment.