From 94fba8c0fe2058be1425ef2cd525188e9cd8d7e2 Mon Sep 17 00:00:00 2001 From: CompRhys Date: Fri, 19 Jul 2024 12:18:08 -0400 Subject: [PATCH] fix: raise="ignore" default led to invalid prototypes being returned. --- aviary/wren/utils.py | 175 ++++++++++++++++++-------------------- tests/data/U2Pa4Tc6.json | 107 +++++++++++++++++++++++ tests/test_wyckoff_ops.py | 48 ++++++++++- 3 files changed, 239 insertions(+), 91 deletions(-) create mode 100644 tests/data/U2Pa4Tc6.json diff --git a/aviary/wren/utils.py b/aviary/wren/utils.py index 88feefce..8fe4a21e 100644 --- a/aviary/wren/utils.py +++ b/aviary/wren/utils.py @@ -9,7 +9,6 @@ from os.path import abspath, dirname, join from shutil import which from string import ascii_uppercase, digits -from typing import Literal from monty.fractions import gcd from pymatgen.core import Composition, Structure @@ -97,7 +96,7 @@ def count_values_for_wyckoff( def get_aflow_label_from_aflow( struct: Structure, aflow_executable: str | None = None, - errors: Literal["raise", "annotate", "ignore"] = "raise", + raise_errors: bool = False, ) -> str: """Get Aflow prototype label for a pymatgen Structure. Make sure you're running a recent version of the aflow CLI as there's been several breaking changes. This code @@ -111,16 +110,12 @@ def get_aflow_label_from_aflow( Args: struct (Structure): pymatgen Structure aflow_executable (str): path to aflow executable. Defaults to which("aflow"). - errors ('raise' | 'annotate' | 'ignore']): How to handle errors. 'raise' and - 'ignore' are self-explanatory. 'annotate' prefixes problematic Aflow labels - with 'invalid : '. - - Raises: - ValueError: if errors='raise' and Wyckoff multiplicities do not add up to - expected composition. + raise_errors (bool): Whether to raise errors or annotate them. Defaults to + False. Returns: - str: Aflow prototype label + str: AFLOW prototype label or explanation of failure if symmetry detection + failed and raise_errors is False.
""" if aflow_executable is None: aflow_executable = which("aflow") @@ -144,12 +139,13 @@ def get_aflow_label_from_aflow( aflow_proto = json.loads(output.stdout) aflow_label = aflow_proto["aflow_prototype_label"] + chem_sys = struct.composition.chemical_system + full_label = f"{aflow_label}:{chem_sys}" # check that multiplicities satisfy original composition _, _, spg_num, *wyckoff_letters = aflow_label.split("_") - elements = sorted(el.symbol for el in struct.composition) elem_dict = {} - for elem, wyk_letters_per_elem in zip(elements, wyckoff_letters): + for elem, wyk_letters_per_elem in zip(chem_sys.split("-"), wyckoff_letters): # normalize Wyckoff letters to start with 1 if missing digit wyk_letters_normalized = re.sub( RE_WYCKOFF_NO_PREFIX, RE_SUBST_ONE_PREFIX, wyk_letters_per_elem @@ -162,91 +158,35 @@ def get_aflow_label_from_aflow( wyckoff_multiplicity_dict, ) - full_label = f"{aflow_label}:{'-'.join(elements)}" - observed_formula = Composition(elem_dict).reduced_formula expected_formula = struct.composition.reduced_formula if observed_formula != expected_formula: - if errors == "raise": - raise ValueError( - f"invalid WP multiplicities - {aflow_label}, expected " - f"{observed_formula} to be {expected_formula}" - ) - if errors == "annotate": - return f"invalid multiplicities: {full_label}" - - return full_label - - -def get_aflow_label_from_spglib( - struct: Structure, - errors: Literal["raise", "annotate", "ignore"] = "ignore", - init_symprec: float = 0.1, - fallback_symprec: float = 1e-5, -) -> str | None: - """Get AFLOW prototype label for pymatgen Structure. - - Args: - struct (Structure): pymatgen Structure object. - errors ('raise' | 'annotate' | 'ignore']): How to handle errors. 'raise' and - 'ignore' are self-explanatory. 'annotate' prefixes problematic Aflow labels - with 'invalid : '. - init_symprec (float): Initial symmetry precision for spglib. Defaults to 0.1. 
- fallback_symprec (float): Fallback symmetry precision for spglib if first - symmetry detection failed. Defaults to 1e-5. - - Returns: - str: AFLOW prototype label or None if errors='ignore' and symmetry detection - failed. - """ - try: - spg_analyzer = SpacegroupAnalyzer( - struct, symprec=init_symprec, angle_tolerance=5 - ) - aflow_label_with_chemsys = get_aflow_label_from_spg_analyzer( - spg_analyzer, errors + err_msg = ( + f"Invalid WP multiplicities - {full_label}, expected " + f"{observed_formula} to be {expected_formula}" ) + if raise_errors: + raise ValueError(err_msg) - # try again with refined structure if it initially fails - # NOTE structures with magmoms fail unless all have same magnetic moment - if "invalid" in aflow_label_with_chemsys: - spg_analyzer = SpacegroupAnalyzer( - spg_analyzer.get_refined_structure(), - symprec=fallback_symprec, - angle_tolerance=-1, - ) - aflow_label_with_chemsys = get_aflow_label_from_spg_analyzer( - spg_analyzer, errors - ) - return aflow_label_with_chemsys + return err_msg - except ValueError as exc: - if errors == "annotate": - return f"invalid spglib: {exc}" - raise # we only get here if errors == "raise" + return full_label def get_aflow_label_from_spg_analyzer( spg_analyzer: SpacegroupAnalyzer, - errors: Literal["raise", "annotate", "ignore"] = "raise", + raise_errors: bool = False, ) -> str: """Get AFLOW prototype label for pymatgen SpacegroupAnalyzer. Args: spg_analyzer (SpacegroupAnalyzer): pymatgen SpacegroupAnalyzer object. - errors ('raise' | 'annotate' | 'ignore']): How to handle errors. 'raise' and - 'ignore' are self-explanatory. 'annotate' prefixes problematic Aflow labels - with 'invalid : '. - - Raises: - ValueError: if errors='raise' and Wyckoff multiplicities do not add up to - expected composition. - - Raises: - ValueError: if Wyckoff multiplicities do not add up to expected composition. + raise_errors (bool): Whether to raise errors or annotate them. Defaults to + False. 
Returns: - str: AFLOW prototype labels + str: AFLOW prototype label or explanation of failure if symmetry detection + failed and raise_errors is False. """ spg_num = spg_analyzer.get_space_group_number() sym_struct = spg_analyzer.get_symmetrized_structure() @@ -288,22 +228,77 @@ def get_aflow_label_from_spg_analyzer( prototype_form = prototype_formula(sym_struct.composition) chem_sys = sym_struct.composition.chemical_system - aflow_label_with_chemsys = ( - f"{prototype_form}_{pearson_symbol}_{spg_num}_{canonical}:{chem_sys}" - ) + full_label = f"{prototype_form}_{pearson_symbol}_{spg_num}_{canonical}:{chem_sys}" observed_formula = Composition(elem_dict).reduced_formula expected_formula = sym_struct.composition.reduced_formula if observed_formula != expected_formula: - if errors == "raise": - raise ValueError( - f"Invalid WP multiplicities - {aflow_label_with_chemsys}, expected " - f"{observed_formula} to be {expected_formula}" + err_msg = ( + f"Invalid WP multiplicities - {full_label}, expected " + f"{observed_formula} to be {expected_formula}" + ) + if raise_errors: + raise ValueError(err_msg) + + return err_msg + + return full_label + + +def get_aflow_label_from_spglib( + struct: Structure, + raise_errors: bool = False, + init_symprec: float = 0.1, + fallback_symprec: float | None = 1e-5, +) -> str | None: + """Get AFLOW prototype label for pymatgen Structure. + + Args: + struct (Structure): pymatgen Structure object. + raise_errors (bool): Whether to raise errors or annotate them. Defaults to + False. + init_symprec (float): Initial symmetry precision for spglib. Defaults to 0.1. + fallback_symprec (float): Fallback symmetry precision for spglib if first + symmetry detection failed. Defaults to 1e-5. + + Returns: + str: AFLOW prototype label or explanation of failure if symmetry detection + failed and raise_errors is False. 
+ """ + attempt_to_recover = False + try: + spg_analyzer = SpacegroupAnalyzer( + struct, symprec=init_symprec, angle_tolerance=5 + ) + try: + aflow_label_with_chemsys = get_aflow_label_from_spg_analyzer( + spg_analyzer, raise_errors ) - if errors == "annotate": - return f"invalid multiplicities: {aflow_label_with_chemsys}" - return aflow_label_with_chemsys + if ("Invalid" in aflow_label_with_chemsys) and fallback_symprec is not None: + attempt_to_recover = True + except ValueError as exc: + if fallback_symprec is None: + raise exc + attempt_to_recover = True + + # try again with refined structure if it initially fails + # NOTE structures with magmoms fail unless all have same magnetic moment + if attempt_to_recover: + spg_analyzer = SpacegroupAnalyzer( + spg_analyzer.get_refined_structure(), + symprec=fallback_symprec, + angle_tolerance=-1, + ) + aflow_label_with_chemsys = get_aflow_label_from_spg_analyzer( + spg_analyzer, raise_errors + ) + return aflow_label_with_chemsys + + except ValueError as exc: + if not raise_errors: + return str(exc) + raise def canonicalize_elem_wyks(elem_wyks: str, spg_num: int | str) -> str: diff --git a/tests/data/U2Pa4Tc6.json b/tests/data/U2Pa4Tc6.json new file mode 100644 index 00000000..aa7ba3b8 --- /dev/null +++ b/tests/data/U2Pa4Tc6.json @@ -0,0 +1,107 @@ +{ + "@module": "pymatgen.core.structure", + "@class": "Structure", + "charge": 0, + "lattice": { + "matrix": [ + [5.989671, 0.00015953, 7.795e-05], + [0.00021958, 8.25008569, -0.03720131], + [-2.99487801, -4.14847393, 5.20632921] + ], + "pbc": [true, true, true], + "a": 5.9896710026317, + "b": 8.250169566622487, + "c": 7.299657121096198, + "alpha": 124.85712587044729, + "beta": 114.22257511338135, + "gamma": 89.99695241219993, + "volume": 256.349880211745 + }, + "properties": {}, + "sites": [ + { + "species": [{"element": "U", "occu": 1}], + "abc": [0.49997316, 0.25000333, 0.50000537], + "xyz": [1.4972745454666476, -0.011630586301441285, 2.5939310845153174], + "properties": {}, 
+ "label": "U" + }, + { + "species": [{"element": "U", "occu": 1}], + "abc": [0.50002684, 0.74999667, 0.49999463], + "xyz": [1.497738024533353, 4.113401876301443, 2.5752747654846826], + "properties": {}, + "label": "U" + }, + { + "species": [{"element": "Pa", "occu": 1}], + "abc": [0.16664895, 0.91662717, 0.83327202], + "xyz": [-1.4971743930568415, 4.105472032220753, 4.304201716381764], + "properties": {}, + "label": "Pa" + }, + { + "species": [{"element": "Pa", "occu": 1}], + "abc": [0.16662875, 0.41663202, 0.83327108], + "xyz": [-1.4974023581607492, -0.019526903520663105, 4.322797295431363], + "properties": {}, + "label": "Pa" + }, + { + "species": [{"element": "Pa", "occu": 1}], + "abc": [0.83337125, 0.58336798, 0.16672892], + "xyz": [4.492414928160749, 4.121298193520663, 0.8464085545686368], + "properties": {}, + "label": "Pa" + }, + { + "species": [{"element": "Pa", "occu": 1}], + "abc": [0.83335105, 0.08337283, 0.16672798], + "xyz": [4.4921869630568425, -0.003700742220752091, 0.8650041336182359], + "properties": {}, + "label": "Pa" + }, + { + "species": [{"element": "Tc", "occu": 1}], + "abc": [5.1e-07, 0.74999811, 0.99999918], + "xyz": [-2.994707814882828, 2.03907814666803, 5.178424028660278], + "properties": {}, + "label": "Tc" + }, + { + "species": [{"element": "Tc", "occu": 1}], + "abc": [0.99999949, 0.25000189, 8.2e-07], + "xyz": [5.989720384882829, 2.062693143331971, -0.0092181786602782], + "properties": {}, + "label": "Tc" + }, + { + "species": [{"element": "Tc", "occu": 1}], + "abc": [0.49999907, 0.74998284, 0.99999871], + "xyz": [0.00012046423061040414, 2.0390338824125678, 5.1784611246373045], + "properties": {}, + "label": "Tc" + }, + { + "species": [{"element": "Tc", "occu": 1}], + "abc": [0.0, 0.0, 0.5], + "xyz": [-1.497439005, -2.074236965, 2.603164605], + "properties": {}, + "label": "Tc" + }, + { + "species": [{"element": "Tc", "occu": 1}], + "abc": [0.0, 0.5, 0.5], + "xyz": [-1.497329215, 2.0508058800000004, 2.5845639499999997], + 
"properties": {}, + "label": "Tc" + }, + { + "species": [{"element": "Tc", "occu": 1}], + "abc": [0.50000093, 0.25001716, 1.29e-06], + "xyz": [2.9948921057693902, 2.062737407587434, -0.0092552746373052], + "properties": {}, + "label": "Tc" + } + ] +} diff --git a/tests/test_wyckoff_ops.py b/tests/test_wyckoff_ops.py index ed67b393..67ace39d 100644 --- a/tests/test_wyckoff_ops.py +++ b/tests/test_wyckoff_ops.py @@ -1,8 +1,11 @@ +import inspect +import re from itertools import permutations from shutil import which import pytest from pymatgen.core.structure import Composition, Structure +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from aviary.wren.utils import ( _find_translations, @@ -11,6 +14,7 @@ count_distinct_wyckoff_letters, count_wyckoff_positions, get_aflow_label_from_aflow, + get_aflow_label_from_spg_analyzer, get_aflow_label_from_spglib, get_aflow_strs_from_iso_and_composition, get_anom_formula_from_prototype_formula, @@ -40,12 +44,54 @@ def test_get_aflow_label_from_spglib(): - """Check that spglib gives correct Aflow label for esseneite""" + """Check that spglib gives correct Aflow label for esseneite.""" struct = Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif") assert get_aflow_label_from_spglib(struct) == "ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si" +def test_get_aflow_label_from_spglib_edge_case(): + """Check edge case where the symmetry precision is too low.""" + struct = Structure.from_file(f"{TEST_DIR}/data/U2Pa4Tc6.json") + + defaults = inspect.signature(get_aflow_label_from_spglib).parameters + + assert defaults["init_symprec"].default == 0.1 + + spg_analyzer = SpacegroupAnalyzer( + struct, symprec=defaults["init_symprec"].default, angle_tolerance=5 + ) + + raises_str = ( + "Invalid WP multiplicities - A2B3C_hP6_191_c_2g_a:Pa-Tc-U, " + "expected U(PaTc3)2 to be UPa2Tc3" + ) + with pytest.raises(ValueError, match=re.escape(raises_str)): + get_aflow_label_from_spg_analyzer(spg_analyzer, raise_errors=True) + + assert ( + 
get_aflow_label_from_spg_analyzer(spg_analyzer, raise_errors=False) + == raises_str + ) + + # Test that it gives invalid protostructure if fallback is None. + with pytest.raises(ValueError, match=re.escape(raises_str)): + get_aflow_label_from_spglib(struct, raise_errors=True, fallback_symprec=None) + + assert ( + get_aflow_label_from_spglib(struct, raise_errors=False, fallback_symprec=None) + == raises_str + ) + + assert get_aflow_label_from_spglib(struct, raise_errors=True) == ( + "A2B3C_hP6_191_c_g_a:Pa-Tc-U" + ) + + assert get_aflow_label_from_spglib(struct, raise_errors=False) == ( + "A2B3C_hP6_191_c_g_a:Pa-Tc-U" + ) + + @pytest.mark.parametrize( "aflow_label, expected", [