Skip to content

Commit dcbaf06

Browse files
authored
ChemEnvSiteFingerprint.from_preset() removal of not-implemented CEs (#948)
* Removed ces that are not implemented in chemenv, added tests that previously would have failed (issue #945). * Fix linting errors.
1 parent ee5747d commit dcbaf06

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

matminer/featurizers/site/fingerprint.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -771,23 +771,18 @@ def from_preset(preset):
771771
"PA:10",
772772
"SBSA:10",
773773
"MI:10",
774-
"S:10",
775-
"H:10",
776774
"BS_1:10",
777775
"BS_2:10",
778776
"TBSA:10",
779777
"PCPA:11",
780778
"H:11",
781-
"SH:11",
782-
"CO:11",
783779
"DI:11",
784780
"I:12",
785781
"PBP:12",
786782
"TT:12",
787783
"C:12",
788784
"AC:12",
789785
"SC:12",
790-
"S:12",
791786
"HP:12",
792787
"HA:12",
793788
"SH:13",

matminer/featurizers/site/tests/test_fingerprint.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,11 @@ def test_crystal_nn_fingerprint(self):
277277

278278
def test_chemenv_site_fingerprint(self):
279279
cefp = ChemEnvSiteFingerprint.from_preset("multi_weights")
280+
implemented_cetypes = {gg.ce_symbol for gg in cefp.lgf.allcg.get_implemented_geometries()}
281+
assert set(cefp.cetypes).difference(implemented_cetypes) == set() # Added after issue #945
280282
l = cefp.feature_labels()
281283
cevals = cefp.featurize(self.sc, 0)
282-
self.assertEqual(len(cevals), 66)
284+
self.assertEqual(len(cevals), 61)
283285
self.assertAlmostEqual(cevals[l.index("O:6")], 1, places=7)
284286
self.assertAlmostEqual(cevals[l.index("C:8")], 0, places=7)
285287
cevals = cefp.featurize(self.cscl, 0)
@@ -288,12 +290,14 @@ def test_chemenv_site_fingerprint(self):
288290
cefp = ChemEnvSiteFingerprint.from_preset("simple")
289291
l = cefp.feature_labels()
290292
cevals = cefp.featurize(self.sc, 0)
291-
self.assertEqual(len(cevals), 66)
293+
self.assertEqual(len(cevals), 61)
292294
self.assertAlmostEqual(cevals[l.index("O:6")], 1, places=7)
293295
self.assertAlmostEqual(cevals[l.index("C:8")], 0, places=7)
294296
cevals = cefp.featurize(self.cscl, 0)
295297
self.assertAlmostEqual(cevals[l.index("C:8")], 0.9953721, places=7)
296298
self.assertAlmostEqual(cevals[l.index("O:6")], 0, places=7)
299+
cevals = cefp.featurize(self.ni3al, 0) # Added after issue #945
300+
self.assertAlmostEqual(cevals[l.index("I:12")], 0.3401699, places=7)
297301

298302
def test_voronoifingerprint(self):
299303
df_sc = pd.DataFrame({"struct": [self.sc], "site": [0]})

0 commit comments

Comments
 (0)