From 783916a5e22a1c9a9ce1abf818647f7217907f03 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sun, 19 Jan 2025 17:32:45 -0500 Subject: [PATCH 01/16] add get_protostructure_label_from_moyopy function using myoypy for symmetry detection moyopy is faster and potentially safer since written in rust --- aviary/wren/utils.py | 137 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 135 insertions(+), 2 deletions(-) diff --git a/aviary/wren/utils.py b/aviary/wren/utils.py index 0c81d272..04c184ab 100644 --- a/aviary/wren/utils.py +++ b/aviary/wren/utils.py @@ -144,7 +144,7 @@ def get_protostructure_label_from_aflow( aflow_proto = json.loads(output.stdout) aflow_label = aflow_proto["aflow_prototype_label"] - chemsys = struct.composition.chemical_system + chemsys = struct.chemical_system # check that multiplicities satisfy original composition prototype_form, pearson_symbol, spg_num, *element_wyckoffs = aflow_label.split("_") @@ -237,7 +237,7 @@ def get_protostructure_label_from_spg_analyzer( pearson_symbol = f"{cry_sys_dict[cry_sys]}{centering}{num_sites_conventional}" prototype_form = get_prototype_formula_from_composition(sym_struct.composition) - chemsys = sym_struct.composition.chemical_system + chemsys = sym_struct.chemical_system all_wyckoffs = "_".join(element_wyckoffs) all_wyckoffs = canonicalize_element_wyckoffs(all_wyckoffs, spg_num) @@ -316,6 +316,139 @@ def get_protostructure_label_from_spglib( raise +def get_protostructure_label_from_moyopy( + struct: Structure, + raise_errors: bool = False, + init_symprec: float = 0.1, + fallback_symprec: float | None = 1e-5, +) -> str | None: + """Get AFLOW prototype label using Moyopy for symmetry detection. + + Args: + struct (Structure): pymatgen Structure object. + raise_errors (bool): Whether to raise errors or annotate them. Defaults to + False. + init_symprec (float): Initial symmetry precision for Moyopy. Defaults to 0.1. + fallback_symprec (float): Fallback symmetry precision if first symmetry detection + failed. Defaults to 1e-5. + + Returns: + str: protostructure_label which is constructed as `aflow_label:chemsys` or + explanation of failure if symmetry detection failed and `raise_errors` + is False. + """ + import moyopy + from moyopy.interface import MoyoAdapter + + attempt_to_recover = False + try: + # Convert pymatgen Structure to Moyo Cell + moyo_cell = MoyoAdapter.from_structure(struct) + + try: + # First attempt with initial symprec + moyo_data = moyopy.MoyoDataset(moyo_cell, symprec=init_symprec) + + # Get space group number and Wyckoff positions + spg_num = moyo_data.number + wyckoff_symbols = moyo_data.wyckoffs + + # Get crystal system and centering from Hall symbol entry + hall_entry = moyopy.HallSymbolEntry(hall_number=moyo_data.hall_number) + spg_sym = hall_entry.hm_short + + # Get crystal system from space group number instead of symbol + if spg_num <= 2: + cry_sys = "triclinic" + elif spg_num <= 15: + cry_sys = "monoclinic" + elif spg_num <= 74: + cry_sys = "orthorhombic" + elif spg_num <= 142: + cry_sys = "tetragonal" + elif spg_num <= 167: + cry_sys = "trigonal" + elif spg_num <= 194: + cry_sys = "hexagonal" + else: + cry_sys = "cubic" + + # Get centering from first letter of space group symbol + # Handle special case for C-centered + centering = spg_sym[0] + if centering in ("A", "B", "C", "S"): + centering = "C" + + # Get number of sites in conventional cell + num_sites_conventional = len(moyo_data.std_cell.numbers) + pearson_symbol = f"{cry_sys_dict[cry_sys]}{centering}{num_sites_conventional}" + + # Group Wyckoff positions by element + element_dict = {} + element_wyckoffs = [] + for element, sites in groupby( + zip(struct.species, wyckoff_symbols), key=lambda x: x[0].symbol + ): + sites_list = list(sites) + element_dict[element] = sum( + wyckoff_multiplicity_dict[str(spg_num)][s[1].translate(remove_digits)] + for s in sites_list + ) + element_wyckoffs.append( + "".join( + f"{len(list(w))}{wyk[0].translate(remove_digits)}" + for wyk, w in groupby( + sorted(sites_list, key=lambda x: x[1]), key=lambda x: x[1] + ) + ) + ) + + prototype_form = get_prototype_formula_from_composition(struct.composition) + chemsys = struct.chemical_system + + all_wyckoffs = "_".join(element_wyckoffs) + all_wyckoffs = canonicalize_element_wyckoffs(all_wyckoffs, spg_num) + + protostructure_label = ( + f"{prototype_form}_{pearson_symbol}_{spg_num}_{all_wyckoffs}:{chemsys}" + ) + + # Verify multiplicities match composition + observed_formula = Composition(element_dict).reduced_formula + expected_formula = struct.composition.reduced_formula + if observed_formula != expected_formula: + if fallback_symprec is not None: + attempt_to_recover = True + else: + err_msg = ( + f"Invalid WP multiplicities - {protostructure_label}, expected " + f"{observed_formula} to be {expected_formula}" + ) + if raise_errors: + raise ValueError(err_msg) + return err_msg + + return protostructure_label + + except Exception as exc: + if fallback_symprec is None: + raise exc + attempt_to_recover = True + + # Try again with fallback symprec if initial attempt failed + if attempt_to_recover: + return get_protostructure_label_from_moyopy( + struct, raise_errors=raise_errors, fallback_symprec=fallback_symprec + ) + + except Exception as exc: + if not raise_errors: + return str(exc) + raise + + return None + + def canonicalize_element_wyckoffs(element_wyckoffs: str, spg_num: int | str) -> str: """Given an element ordering, canonicalize the associated Wyckoff positions based on the alphabetical weight of equivalent choices of origin. From 1be475ac924fea486b98e6b4fac126bf01bb0e5a Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sun, 19 Jan 2025 17:36:44 -0500 Subject: [PATCH 02/16] wip tests get_protostructure_label_from_moyopy --- tests/test_wyckoff_ops.py | 255 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 255 insertions(+) diff --git a/tests/test_wyckoff_ops.py b/tests/test_wyckoff_ops.py index 5475199d..bd1ac7db 100644 --- a/tests/test_wyckoff_ops.py +++ b/tests/test_wyckoff_ops.py @@ -16,6 +16,7 @@ get_anonymous_formula_from_prototype_formula, get_formula_from_protostructure_label, get_protostructure_label_from_aflow, + get_protostructure_label_from_moyopy, get_protostructure_label_from_spg_analyzer, get_protostructure_label_from_spglib, get_protostructures_from_aflow_label_and_composition, @@ -314,3 +315,257 @@ def test_get_random_structure_for_protostructure_random(protostructure): assert s1.composition == s2.composition assert s1.lattice != s2.lattice + + +def test_get_protostructure_label_from_moyopy(): + """Check that moyopy gives correct protostructure label for esseneite""" + struct = Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif") + assert ( + get_protostructure_label_from_moyopy(struct) + == "ABC6D2_mC40_15_4e_4e_8f_24f:Ca-Fe-O-Si" + ) + + +def test_get_protostructure_label_from_moyopy_edge_case(): + """Check edge case where the symmetry precision is too low.""" + struct = Structure.from_file(f"{TEST_DIR}/data/U2Pa4Tc6.json") + + defaults = inspect.signature(get_protostructure_label_from_moyopy).parameters + + assert defaults["init_symprec"].default == 0.1 + + raises_str = ( + "Invalid WP multiplicities - A2B3C_hP6_191_2a_4c_6g:Pa-Tc-U, " + "expected UPa4Tc9 to be UPa2Tc3" + ) + + # Test that it gives invalid protostructure if fallback is None + with pytest.raises(ValueError, match=re.escape(raises_str)): + get_protostructure_label_from_moyopy( + struct, raise_errors=True, fallback_symprec=None + ) + + assert ( + get_protostructure_label_from_moyopy( + struct, raise_errors=False, fallback_symprec=None + ) + == raises_str + ) + + # Test that it recovers with fallback symprec + assert get_protostructure_label_from_moyopy(struct, raise_errors=True) == ( + "A2B3C_hP6_191_2a_4c_6g:Pa-Tc-U" + ) + + assert get_protostructure_label_from_moyopy(struct, raise_errors=False) == ( + "A2B3C_hP6_191_2a_4c_6g:Pa-Tc-U" + ) + + +@pytest.mark.parametrize( + "protostructure", + PROTOSTRUCTURE_SET, +) +def test_moyopy_spglib_consistency(protostructure): + """Check that moyopy and spglib give consistent results.""" + struct = get_random_structure_for_protostructure(protostructure) + + moyopy_label = get_protostructure_label_from_moyopy(struct) + spglib_label = get_protostructure_label_from_spglib(struct) + + assert moyopy_label == spglib_label + + +def test_moyopy_spglib_interchangeable(): + """Test that moyopy and spglib functions are drop-in replacements for each other.""" + # Test normal case + struct = Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif") + + # Both should handle raise_errors=True/False similarly + for raise_errors in (True, False): + moyopy_label = get_protostructure_label_from_moyopy( + struct, raise_errors=raise_errors + ) + spglib_label = get_protostructure_label_from_spglib( + struct, raise_errors=raise_errors + ) + + # Compare parts that should be identical + moyopy_parts = moyopy_label.split("_", 3) + spglib_parts = spglib_label.split("_", 3) + + assert moyopy_parts[:3] == spglib_parts[:3] # prototype, Pearson, space group + assert ( + moyopy_label.split(":")[-1] == spglib_label.split(":")[-1] + ) # chemical system + + # Test edge case with invalid structure + struct_invalid = Structure.from_file(f"{TEST_DIR}/data/U2Pa4Tc6.json") + + # Both should handle fallback_symprec=None similarly + for func in ( + get_protostructure_label_from_moyopy, + get_protostructure_label_from_spglib, + ): + # Should raise error when raise_errors=True + with pytest.raises(ValueError, match="Invalid WP multiplicities"): + func(struct_invalid, raise_errors=True, fallback_symprec=None) + + # Should return error message when raise_errors=False + result = func(struct_invalid, raise_errors=False, fallback_symprec=None) + assert "Invalid WP multiplicities" in result + assert "expected" in result + + # Both should recover with default fallback_symprec + moyopy_recovered = get_protostructure_label_from_moyopy(struct_invalid) + spglib_recovered = get_protostructure_label_from_spglib(struct_invalid) + + # Compare recovered results (ignoring Wyckoff position format differences) + moyopy_parts = moyopy_recovered.split("_", 3) + spglib_parts = spglib_recovered.split("_", 3) + + assert moyopy_parts[:3] == spglib_parts[:3] + assert moyopy_recovered.split(":")[-1] == spglib_recovered.split(":")[-1] + + +def test_moyopy_spglib_identical_results(): + """Test that moyopy and spglib give identical results for simple structures.""" + # Create simple test structures + test_structs = { + "cubic": Structure( + lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], + species=["Na", "Cl"], + coords=[[0, 0, 0], [0.5, 0.5, 0.5]], + ), + "tetragonal": Structure( + lattice=[[3, 0, 0], [0, 3, 0], [0, 0, 4]], + species=["Ti", "O", "O"], + coords=[[0, 0, 0], [0.3, 0.3, 0], [0.7, 0.7, 0]], + ), + "hexagonal": Structure( + lattice=[[3, 0, 0], [-1.5, 2.6, 0], [0, 0, 5]], + species=["Zn", "O"], + coords=[[1 / 3, 2 / 3, 0], [2 / 3, 1 / 3, 0.5]], + ), + } + + # Expected outputs for each structure + expected_outputs = { + "cubic": ( + "AB_cP2_221_a_b:Cl-Na", # moyopy + "AB_cF8_225_a_b:Na-Cl", # spglib + ), + "tetragonal": ( + "AB2_tP6_136_2a_4c:Ti-O", + "AB2_tP6_136_a_2c:Ti-O", + ), + "hexagonal": ( + "AB_hP4_194_2a_2b:Zn-O", + "AB_hP4_194_a_b:Zn-O", + ), + } + + for name, struct in test_structs.items(): + moyopy_label = get_protostructure_label_from_moyopy(struct) + spglib_label = get_protostructure_label_from_spglib(struct) + + moyopy_expected, spglib_expected = expected_outputs[name] + + assert moyopy_label == moyopy_expected, ( + f"Moyopy output mismatch for {name}:\n" + f"got: {moyopy_label}\n" + f"expected: {moyopy_expected}" + ) + + assert spglib_label == spglib_expected, ( + f"Spglib output mismatch for {name}:\n" + f"got: {spglib_label}\n" + f"expected: {spglib_expected}" + ) + + +def test_moyopy_spglib_equivalence(): + """Test that moyopy and spglib give equivalent results for various structures.""" + # Simple test structures with known symmetry + test_structs = { + "cubic": Structure( # NaCl structure + lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], + species=["Na", "Cl"], + coords=[[0, 0, 0], [0.5, 0.5, 0.5]], + ), + "tetragonal": Structure( # TiO2 structure + lattice=[[3, 0, 0], [0, 3, 0], [0, 0, 4]], + species=["Ti", "O", "O"], + coords=[[0, 0, 0], [0.3, 0.3, 0], [0.7, 0.7, 0]], + ), + "hexagonal": Structure( # ZnO structure + lattice=[[3, 0, 0], [-1.5, 2.6, 0], [0, 0, 5]], + species=["Zn", "O"], + coords=[[1 / 3, 2 / 3, 0], [2 / 3, 1 / 3, 0.5]], + ), + # Real structure from file + "esseneite": Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif"), + } + + # Expected outputs (moyopy, spglib) for each structure + expected_outputs = { + "cubic": ( + "AB_cF8_225_a_b:Na-Cl", + "AB_cF8_225_a_b:Na-Cl", + ), + "tetragonal": ( + "AB2_tP6_136_2a_4c:Ti-O", + "AB2_tP6_136_a_2c:Ti-O", + ), + "hexagonal": ( + "AB_hP4_194_2a_2b:Zn-O", + "AB_hP4_194_a_b:Zn-O", + ), + "esseneite": ( + "ABC6D2_mC40_15_4e_4e_8f_24f:Ca-Fe-O-Si", + "ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si", + ), + } + + # Test each structure + for name, struct in test_structs.items(): + moyopy_label = get_protostructure_label_from_moyopy(struct) + spglib_label = get_protostructure_label_from_spglib(struct) + + moyopy_expected, spglib_expected = expected_outputs[name] + + # Check full labels match expected output + assert moyopy_label == moyopy_expected, ( + f"Moyopy output mismatch for {name}:\n" + f"got: {moyopy_label}\n" + f"expected: {moyopy_expected}" + ) + assert spglib_label == spglib_expected, ( + f"Spglib output mismatch for {name}:\n" + f"got: {spglib_label}\n" + f"expected: {spglib_expected}" + ) + + # Check that both functions agree on key properties + moyopy_parts = moyopy_label.split("_", 3) + spglib_parts = spglib_label.split("_", 3) + + assert moyopy_parts[:3] == spglib_parts[:3], ( + f"Core properties mismatch for {name}:\n" + f"moyopy: {moyopy_parts[:3]}\n" + f"spglib: {spglib_parts[:3]}" + ) + + # Test random structures from PROTOSTRUCTURE_SET + for proto in PROTOSTRUCTURE_SET: + struct = get_random_structure_for_protostructure(proto) + moyopy_label = get_protostructure_label_from_moyopy(struct) + spglib_label = get_protostructure_label_from_spglib(struct) + + # Compare core properties (prototype, Pearson, space group) + moyopy_parts = moyopy_label.split("_", 3) + spglib_parts = spglib_label.split("_", 3) + assert moyopy_parts[:3] == spglib_parts[:3] + + # Compare chemical system + assert moyopy_label.split(":")[-1] == spglib_label.split(":")[-1] From 88c4599074b0d73dd72975c49893b16fd955b67b Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 19 Jan 2025 22:24:19 -0500 Subject: [PATCH 03/16] fix: use orbits to get correct multiplicities but the order/element match still seems wrong for moyopy. Remove fallback from moyopy as without spglib structure refinement it won't do anything. --- aviary/wren/utils.py | 355 +++++++++++++++++++++++--------------- tests/test_wyckoff_ops.py | 62 +------ 2 files changed, 216 insertions(+), 201 deletions(-) diff --git a/aviary/wren/utils.py b/aviary/wren/utils.py index 04c184ab..2180d99c 100644 --- a/aviary/wren/utils.py +++ b/aviary/wren/utils.py @@ -9,6 +9,7 @@ from os.path import abspath, dirname, join from shutil import which from string import ascii_uppercase, digits +from typing import Literal from monty.fractions import gcd from pymatgen.core import Composition, Structure @@ -16,8 +17,21 @@ try: from pyxtal import pyxtal + + has_pyxtal = True except ImportError: pyxtal = None + has_pyxtal = False + +try: + import moyopy + from moyopy.interface import MoyoAdapter + + has_moyopy = True +except ImportError: + moyopy = None + has_moyopy = False + module_dir = dirname(abspath(__file__)) @@ -96,10 +110,111 @@ def count_values_for_wyckoff( ) +def get_crystal_system(n: int) -> str: + """Get the crystal system for the structure, e.g. (triclinic, orthorhombic, + cubic, etc.). + + Mirrors method of SpacegroupAnalyzer.get_crystal_system(). + + Args: + n (int): Space group number + + Raises: + ValueError: on invalid space group numbers < 1 or > 230. + + Returns: + str: Crystal system for structure + """ + # Not using isinstance(n, int) to allow 0-decimal floats + if n != int(n) or not 0 < n < 231: + raise ValueError(f"Received invalid space group {n}") + + if 0 < n < 3: + return "triclinic" + if n < 16: + return "monoclinic" + if n < 75: + return "orthorhombic" + if n < 143: + return "tetragonal" + if n < 168: + return "trigonal" + if n < 195: + return "hexagonal" + return "cubic" + + +def get_centering(spg_sym: str) -> str: + """Get the centering for the structure, e.g. (A, B, C, S).""" + return "C" if spg_sym[0] in ("A", "B", "C", "S") else spg_sym[0] + + +def get_pearson_symbol_from_spg_analyzer(spg_analyzer: SpacegroupAnalyzer) -> str: + """Get the Pearson symbol for the structure.""" + cry_sys = spg_analyzer.get_crystal_system() + spg_sym = spg_analyzer.get_space_group_symbol() + centering = get_centering(spg_sym) + + num_sites_conventional = len(spg_analyzer.get_symmetry_dataset()["std_types"]) + return f"{cry_sys_dict[cry_sys]}{centering}{num_sites_conventional}" + + +def get_pearson_symbol_from_moyo_dataset(moyo_data: moyopy.MoyoDataset) -> str: + """Get the Pearson symbol for the structure from a MoyoDataset.""" + # Get space group number and Wyckoff positions + spg_num = moyo_data.number + + # Get crystal system and centering from Hall symbol entry + hall_entry = moyopy.HallSymbolEntry(hall_number=moyo_data.hall_number) + spg_sym = hall_entry.hm_short + centering = hall_entry.centering + + # Get crystal system from space group number instead of symbol + cry_sys = get_crystal_system(spg_num) + + # Get centering from first letter of space group symbol + # Handle special case for C-centered + centering = get_centering(spg_sym) + + # Get number of sites in conventional cell + num_sites_conventional = len(moyo_data.std_cell.numbers) + return f"{cry_sys_dict[cry_sys]}{centering}{num_sites_conventional}" + + +def get_protostructure_label( + struct: Structure, + method: Literal["aflow", "spglib", "moyopy"], + raise_errors: bool = False, + **kwargs, +) -> str: + """Get protostructure label for a pymatgen Structure. + + Args: + struct (Structure): pymatgen Structure + method (Literal["aflow", "spglib", "moyopy"]): Method to use for symmetry + detection + raise_errors (bool): Whether to raise errors or annotate them. Defaults to + False. + **kwargs: Additional arguments for the specific method + + Returns: + str: protostructure_label which is constructed as `aflow_label:chemsys` or + explanation of failure if symmetry detection failed and `raise_errors` + is False. + """ + if method == "aflow": + return get_protostructure_label_from_aflow(struct, raise_errors, **kwargs) + if method == "spglib": + return get_protostructure_label_from_spglib(struct, raise_errors, **kwargs) + if method == "moyopy": + return get_protostructure_label_from_moyopy(struct, raise_errors, **kwargs) + raise ValueError(f"Invalid method: {method}") + + def get_protostructure_label_from_aflow( struct: Structure, - aflow_executable: str | None = None, raise_errors: bool = False, + aflow_executable: str | None = None, ) -> str: """Get protostructure label for a pymatgen Structure. Make sure you're running a recent version of the aflow CLI as there's been several breaking changes. This code @@ -184,30 +299,10 @@ def get_protostructure_label_from_aflow( return protostructure_label -def get_protostructure_label_from_spg_analyzer( - spg_analyzer: SpacegroupAnalyzer, - raise_errors: bool = False, -) -> str: - """Get protostructure label for pymatgen SpacegroupAnalyzer. - - Args: - spg_analyzer (SpacegroupAnalyzer): pymatgen SpacegroupAnalyzer object. - raise_errors (bool): Whether to raise errors or annotate them. Defaults to - False. - - Returns: - str: protostructure_label which is constructed as `aflow_label:chemsys` or - explanation of failure if symmetry detection failed and `raise_errors` - is False. - """ - spg_num = spg_analyzer.get_space_group_number() - sym_struct = spg_analyzer.get_symmetrized_structure() - - equivalent_wyckoff_labels = [ - # tuple of (wp multiplicity, element, wyckoff letter) - (len(s), s[0].species_string, wyk_letter.translate(remove_digits)) - for s, wyk_letter in zip(sym_struct.equivalent_sites, sym_struct.wyckoff_symbols) - ] +def _get_all_wyckoffs_substring_and_element_dict( + equivalent_wyckoff_labels: list[tuple[int, str, str]], + spg_num: int | str, +): # Pre-sort by element and wyckoff letter to ensure continuous groups in groupby equivalent_wyckoff_labels = sorted( equivalent_wyckoff_labels, key=lambda x: (x[1], x[2]) @@ -228,19 +323,45 @@ def get_protostructure_label_from_spg_analyzer( for wyk, w in groupby(list_group, key=lambda x: x[2]) ) ) + all_wyckoffs = "_".join(element_wyckoffs) + all_wyckoffs = canonicalize_element_wyckoffs(all_wyckoffs, spg_num) - # get Pearson symbol - cry_sys = spg_analyzer.get_crystal_system() - spg_sym = spg_analyzer.get_space_group_symbol() - centering = "C" if spg_sym[0] in ("A", "B", "C", "S") else spg_sym[0] - num_sites_conventional = len(spg_analyzer.get_symmetry_dataset()["std_types"]) - pearson_symbol = f"{cry_sys_dict[cry_sys]}{centering}{num_sites_conventional}" + return all_wyckoffs, element_dict + +def get_protostructure_label_from_spg_analyzer( + spg_analyzer: SpacegroupAnalyzer, + raise_errors: bool = False, +) -> str: + """Get protostructure label for pymatgen SpacegroupAnalyzer. + + Args: + spg_analyzer (SpacegroupAnalyzer): pymatgen SpacegroupAnalyzer object. + raise_errors (bool): Whether to raise errors or annotate them. Defaults to + False. + + Returns: + str: protostructure_label which is constructed as `aflow_label:chemsys` or + explanation of failure if symmetry detection failed and `raise_errors` + is False. + """ + sym_struct = spg_analyzer.get_symmetrized_structure() + + spg_num = spg_analyzer.get_space_group_number() + pearson_symbol = get_pearson_symbol_from_spg_analyzer(spg_analyzer) prototype_form = get_prototype_formula_from_composition(sym_struct.composition) chemsys = sym_struct.chemical_system - all_wyckoffs = "_".join(element_wyckoffs) - all_wyckoffs = canonicalize_element_wyckoffs(all_wyckoffs, spg_num) + # get Wyckoff position substring + equivalent_wyckoff_labels = [ + # tuple of (wp multiplicity, element, wyckoff letter) + (len(s), s[0].species_string, wyk_letter.translate(remove_digits)) + for s, wyk_letter in zip(sym_struct.equivalent_sites, sym_struct.wyckoff_symbols) + ] + + all_wyckoffs, element_dict = _get_all_wyckoffs_substring_and_element_dict( + equivalent_wyckoff_labels, spg_num + ) protostructure_label = ( f"{prototype_form}_{pearson_symbol}_{spg_num}_{all_wyckoffs}:{chemsys}" @@ -266,7 +387,7 @@ def get_protostructure_label_from_spglib( raise_errors: bool = False, init_symprec: float = 0.1, fallback_symprec: float | None = 1e-5, -) -> str | None: +) -> str: """Get AFLOW prototype label for pymatgen Structure. Args: @@ -316,12 +437,59 @@ def get_protostructure_label_from_spglib( raise +def _get_protostructure_label_from_moyopy( + struct: Structure, + symprec: float, + raise_errors: bool = False, +) -> str: + moyo_cell = MoyoAdapter.from_structure(struct) + moyo_data = moyopy.MoyoDataset(moyo_cell, symprec=symprec) + + # Get space group number and Wyckoff positions + spg_num = moyo_data.number + pearson_symbol = get_pearson_symbol_from_moyo_dataset(moyo_data) + prototype_form = get_prototype_formula_from_composition(struct.composition) + chemsys = struct.chemical_system + + # Group Wyckoff positions by orbit and element + equivalent_wyckoff_labels = [] + for orbit_idx, group in groupby(sorted(moyo_data.orbits)): + equivalent_wyckoff_labels.append( + ( + len(list(group)), # multiplicity + struct.species[orbit_idx], # element + moyo_data.wyckoffs[orbit_idx], # wyckoff letter + ) + ) + + all_wyckoffs, element_dict = _get_all_wyckoffs_substring_and_element_dict( + equivalent_wyckoff_labels, spg_num + ) + + protostructure_label = ( + f"{prototype_form}_{pearson_symbol}_{spg_num}_{all_wyckoffs}:{chemsys}" + ) + + # Verify multiplicities match composition + observed_formula = Composition(element_dict).reduced_formula + expected_formula = struct.composition.reduced_formula + if observed_formula != expected_formula: + err_msg = ( + f"Invalid WP multiplicities - {protostructure_label}, expected " + f"{observed_formula} to be {expected_formula}" + ) + if raise_errors: + raise ValueError(err_msg) + return err_msg + + return protostructure_label + + def get_protostructure_label_from_moyopy( struct: Structure, raise_errors: bool = False, init_symprec: float = 0.1, - fallback_symprec: float | None = 1e-5, -) -> str | None: +) -> str: """Get AFLOW prototype label using Moyopy for symmetry detection. Args: @@ -329,124 +497,27 @@ def get_protostructure_label_from_moyopy( raise_errors (bool): Whether to raise errors or annotate them. Defaults to False. init_symprec (float): Initial symmetry precision for Moyopy. Defaults to 0.1. - fallback_symprec (float): Fallback symmetry precision if first symmetry detection - failed. Defaults to 1e-5. Returns: str: protostructure_label which is constructed as `aflow_label:chemsys` or explanation of failure if symmetry detection failed and `raise_errors` is False. """ - import moyopy - from moyopy.interface import MoyoAdapter + if not has_moyopy: + raise ImportError( + "moyopy is not installed, please install it with `pip install moyopy`" + ) - attempt_to_recover = False try: - # Convert pymatgen Structure to Moyo Cell - moyo_cell = MoyoAdapter.from_structure(struct) - - try: - # First attempt with initial symprec - moyo_data = moyopy.MoyoDataset(moyo_cell, symprec=init_symprec) - - # Get space group number and Wyckoff positions - spg_num = moyo_data.number - wyckoff_symbols = moyo_data.wyckoffs - - # Get crystal system and centering from Hall symbol entry - hall_entry = moyopy.HallSymbolEntry(hall_number=moyo_data.hall_number) - spg_sym = hall_entry.hm_short - - # Get crystal system from space group number instead of symbol - if spg_num <= 2: - cry_sys = "triclinic" - elif spg_num <= 15: - cry_sys = "monoclinic" - elif spg_num <= 74: - cry_sys = "orthorhombic" - elif spg_num <= 142: - cry_sys = "tetragonal" - elif spg_num <= 167: - cry_sys = "trigonal" - elif spg_num <= 194: - cry_sys = "hexagonal" - else: - cry_sys = "cubic" - - # Get centering from first letter of space group symbol - # Handle special case for C-centered - centering = spg_sym[0] - if centering in ("A", "B", "C", "S"): - centering = "C" - - # Get number of sites in conventional cell - num_sites_conventional = len(moyo_data.std_cell.numbers) - pearson_symbol = f"{cry_sys_dict[cry_sys]}{centering}{num_sites_conventional}" - - # Group Wyckoff positions by element - element_dict = {} - element_wyckoffs = [] - for element, sites in groupby( - zip(struct.species, wyckoff_symbols), key=lambda x: x[0].symbol - ): - sites_list = list(sites) - element_dict[element] = sum( - wyckoff_multiplicity_dict[str(spg_num)][s[1].translate(remove_digits)] - for s in sites_list - ) - element_wyckoffs.append( - "".join( - f"{len(list(w))}{wyk[0].translate(remove_digits)}" - for wyk, w in groupby( - sorted(sites_list, key=lambda x: x[1]), key=lambda x: x[1] - ) - ) - ) - - prototype_form = get_prototype_formula_from_composition(struct.composition) - chemsys = struct.chemical_system - - all_wyckoffs = "_".join(element_wyckoffs) - all_wyckoffs = canonicalize_element_wyckoffs(all_wyckoffs, spg_num) - - protostructure_label = ( - f"{prototype_form}_{pearson_symbol}_{spg_num}_{all_wyckoffs}:{chemsys}" - ) - - # Verify multiplicities match composition - observed_formula = Composition(element_dict).reduced_formula - expected_formula = struct.composition.reduced_formula - if observed_formula != expected_formula: - if fallback_symprec is not None: - attempt_to_recover = True - else: - err_msg = ( - f"Invalid WP multiplicities - {protostructure_label}, expected " - f"{observed_formula} to be {expected_formula}" - ) - if raise_errors: - raise ValueError(err_msg) - return err_msg - - return protostructure_label - - except Exception as exc: - if fallback_symprec is None: - raise exc - attempt_to_recover = True - - # Try again with fallback symprec if initial attempt failed - if attempt_to_recover: - return get_protostructure_label_from_moyopy( - struct, raise_errors=raise_errors, fallback_symprec=fallback_symprec - ) - + aflow_label_with_chemsys = _get_protostructure_label_from_moyopy( + struct, init_symprec, raise_errors + ) except Exception as exc: if not raise_errors: return str(exc) raise - return None + return aflow_label_with_chemsys def canonicalize_element_wyckoffs(element_wyckoffs: str, spg_num: int | str) -> str: @@ -830,7 +901,7 @@ def get_random_structure_for_protostructure( sorted chemical system. **kwargs: Keyword arguments to pass to pyxtal().from_random() """ - if pyxtal is None: + if not has_pyxtal: raise ImportError("pyxtal is required for this function") aflow_label, chemsys = protostructure_label.split(":") diff --git a/tests/test_wyckoff_ops.py b/tests/test_wyckoff_ops.py index bd1ac7db..e0e3f191 100644 --- a/tests/test_wyckoff_ops.py +++ b/tests/test_wyckoff_ops.py @@ -283,7 +283,7 @@ def test_get_protostructure_label_from_aflow(): """Check we extract correct protostructure label for esseneite using AFLOW CLI.""" struct = Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif") - out = get_protostructure_label_from_aflow(struct, which("aflow")) + out = get_protostructure_label_from_aflow(struct, aflow_executable=which("aflow")) expected = "ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si" assert out == expected @@ -428,62 +428,6 @@ def test_moyopy_spglib_interchangeable(): assert moyopy_recovered.split(":")[-1] == spglib_recovered.split(":")[-1] -def test_moyopy_spglib_identical_results(): - """Test that moyopy and spglib give identical results for simple structures.""" - # Create simple test structures - test_structs = { - "cubic": Structure( - lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], - species=["Na", "Cl"], - coords=[[0, 0, 0], [0.5, 0.5, 0.5]], - ), - "tetragonal": Structure( - lattice=[[3, 0, 0], [0, 3, 0], [0, 0, 4]], - species=["Ti", "O", "O"], - coords=[[0, 0, 0], [0.3, 0.3, 0], [0.7, 0.7, 0]], - ), - "hexagonal": Structure( - lattice=[[3, 0, 0], [-1.5, 2.6, 0], [0, 0, 5]], - species=["Zn", "O"], - coords=[[1 / 3, 2 / 3, 0], [2 / 3, 1 / 3, 0.5]], - ), - } - - # Expected outputs for each structure - expected_outputs = { - "cubic": ( - "AB_cP2_221_a_b:Cl-Na", # moyopy - "AB_cF8_225_a_b:Na-Cl", # spglib - ), - "tetragonal": ( - "AB2_tP6_136_2a_4c:Ti-O", - "AB2_tP6_136_a_2c:Ti-O", - ), - "hexagonal": ( - "AB_hP4_194_2a_2b:Zn-O", - "AB_hP4_194_a_b:Zn-O", - ), - } - - for name, struct in test_structs.items(): - moyopy_label = get_protostructure_label_from_moyopy(struct) - spglib_label = get_protostructure_label_from_spglib(struct) - - moyopy_expected, spglib_expected = expected_outputs[name] - - assert moyopy_label == moyopy_expected, ( - f"Moyopy output mismatch for {name}:\n" - f"got: {moyopy_label}\n" - f"expected: {moyopy_expected}" - ) - - assert spglib_label == spglib_expected, ( - f"Spglib output mismatch for {name}:\n" - f"got: {spglib_label}\n" - f"expected: {spglib_expected}" - ) - - def test_moyopy_spglib_equivalence(): """Test that moyopy and spglib give equivalent results for various structures.""" # Simple test structures with known symmetry @@ -510,8 +454,8 @@ def test_moyopy_spglib_equivalence(): # Expected outputs (moyopy, spglib) for each structure expected_outputs = { "cubic": ( - "AB_cF8_225_a_b:Na-Cl", - "AB_cF8_225_a_b:Na-Cl", + "AB_cF8_225_a_b:Cl-Na", + "AB_cF8_225_a_b:Cl-Na", ), "tetragonal": ( "AB2_tP6_136_2a_4c:Ti-O", From 1486cec286494f55f5ab84af69d378ea7cab29c4 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 19 Jan 2025 22:57:24 -0500 Subject: [PATCH 04/16] clean: remove confusing tests to focus on getting right answers --- aviary/wren/utils.py | 59 +++----- tests/test_wyckoff_ops.py | 283 ++++++++++---------------------------- 2 files changed, 90 insertions(+), 252 deletions(-) diff --git a/aviary/wren/utils.py b/aviary/wren/utils.py index 2180d99c..e237a22f 100644 --- a/aviary/wren/utils.py +++ b/aviary/wren/utils.py @@ -437,11 +437,29 @@ def get_protostructure_label_from_spglib( raise -def _get_protostructure_label_from_moyopy( +def get_protostructure_label_from_moyopy( struct: Structure, - symprec: float, raise_errors: bool = False, + symprec: float = 0.1, ) -> str: + """Get AFLOW prototype label using Moyopy for symmetry detection. + + Args: + struct (Structure): pymatgen Structure object. + raise_errors (bool): Whether to raise errors or annotate them. Defaults to + False. + symprec (float): Initial symmetry precision for Moyopy. Defaults to 0.1. + + Returns: + str: protostructure_label which is constructed as `aflow_label:chemsys` or + explanation of failure if symmetry detection failed and `raise_errors` + is False. + """ + if not has_moyopy: + raise ImportError( + "moyopy is not installed, please install it with `pip install moyopy`" + ) + moyo_cell = MoyoAdapter.from_structure(struct) moyo_data = moyopy.MoyoDataset(moyo_cell, symprec=symprec) @@ -453,7 +471,7 @@ def _get_protostructure_label_from_moyopy( # Group Wyckoff positions by orbit and element equivalent_wyckoff_labels = [] - for orbit_idx, group in groupby(sorted(moyo_data.orbits)): + for orbit_idx, group in groupby(moyo_data.orbits): equivalent_wyckoff_labels.append( ( len(list(group)), # multiplicity @@ -485,41 +503,6 @@ def _get_protostructure_label_from_moyopy( return protostructure_label -def get_protostructure_label_from_moyopy( - struct: Structure, - raise_errors: bool = False, - init_symprec: float = 0.1, -) -> str: - """Get AFLOW prototype label using Moyopy for symmetry detection. - - Args: - struct (Structure): pymatgen Structure object. - raise_errors (bool): Whether to raise errors or annotate them. Defaults to - False. - init_symprec (float): Initial symmetry precision for Moyopy. Defaults to 0.1. - - Returns: - str: protostructure_label which is constructed as `aflow_label:chemsys` or - explanation of failure if symmetry detection failed and `raise_errors` - is False. - """ - if not has_moyopy: - raise ImportError( - "moyopy is not installed, please install it with `pip install moyopy`" - ) - - try: - aflow_label_with_chemsys = _get_protostructure_label_from_moyopy( - struct, init_symprec, raise_errors - ) - except Exception as exc: - if not raise_errors: - return str(exc) - raise - - return aflow_label_with_chemsys - - def canonicalize_element_wyckoffs(element_wyckoffs: str, spg_num: int | str) -> str: """Given an element ordering, canonicalize the associated Wyckoff positions based on the alphabetical weight of equivalent choices of origin. diff --git a/tests/test_wyckoff_ops.py b/tests/test_wyckoff_ops.py index e0e3f191..1a79e65b 100644 --- a/tests/test_wyckoff_ops.py +++ b/tests/test_wyckoff_ops.py @@ -1,6 +1,6 @@ import inspect import re -from itertools import permutations +from itertools import permutations, product from shutil import which import pytest @@ -15,6 +15,7 @@ count_wyckoff_positions, get_anonymous_formula_from_prototype_formula, get_formula_from_protostructure_label, + get_protostructure_label, get_protostructure_label_from_aflow, get_protostructure_label_from_moyopy, get_protostructure_label_from_spg_analyzer, @@ -45,14 +46,37 @@ ("AB3C_cP5_221_a_c_b:Ba-O-Ti"), ] +TEST_STRUCTS = [ + Structure( # NaCl structure + lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], + species=["Na", "Cl"], + coords=[[0, 0, 0], [0.5, 0.5, 0.5]], + ), + Structure( # TiO2 structure + lattice=[[3, 0, 0], [0, 3, 0], [0, 0, 4]], + species=["Ti", "O", "O"], + coords=[[0, 0, 0], [0.3, 0.3, 0], [0.7, 0.7, 0]], + ), + Structure( # ZnO structure + lattice=[[3, 0, 0], [-1.5, 2.6, 0], [0, 0, 5]], + species=["Zn", "O"], + coords=[[1 / 3, 2 / 3, 0], [2 / 3, 1 / 3, 0.5]], + ), + Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif"), +] -def test_get_protostructure_label_from_spglib(): - """Check that spglib gives correct protostructure label for esseneite""" - struct = Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif") - assert ( - get_protostructure_label_from_spglib(struct) - == "ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si" - ) +TEST_PROTOSTRUCTURES = [ + "AB_cF8_225_a_b:Cl-Na", + "AB2_tP6_136_a_2c:Ti-O", + "AB_hP4_194_a_b:Zn-O", + "ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si", +] + + +@pytest.mark.parametrize("structure, expected", zip(TEST_STRUCTS, TEST_PROTOSTRUCTURES)) +def test_get_protostructure_label_from_spglib(structure, expected): + """Check that spglib gives correct protostructure label simple cases.""" + assert get_protostructure_label_from_spglib(structure) == expected def test_get_protostructure_label_from_spglib_edge_case(): @@ -279,13 +303,33 @@ def test_count_distinct_wyckoff_letters(protostructure_label, expected): @pytest.mark.skipif(which("aflow") is None, reason="AFLOW CLI not installed") -def test_get_protostructure_label_from_aflow(): - """Check we extract correct protostructure label for esseneite using AFLOW CLI.""" - struct = Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif") +@pytest.mark.parametrize("structure, expected", zip(TEST_STRUCTS, TEST_PROTOSTRUCTURES)) +def test_get_protostructure_label_from_aflow(structure, expected): + """Check that AFLOW CLI gives correct protostructure label simple cases.""" + assert ( + get_protostructure_label_from_aflow(structure, aflow_executable=which("aflow")) + == expected + ) + - out = get_protostructure_label_from_aflow(struct, aflow_executable=which("aflow")) - expected = "ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si" - assert out == expected +@pytest.mark.parametrize("structure, expected", zip(TEST_STRUCTS, TEST_PROTOSTRUCTURES)) +def test_get_protostructure_label_from_moyopy(structure, expected): + """Check that moyopy gives correct protostructure label simple cases.""" + assert get_protostructure_label_from_moyopy(structure) == expected + + +@pytest.mark.parametrize( + "protostructure", + PROTOSTRUCTURE_SET, +) +def test_moyopy_spglib_consistency(protostructure): + """Check that moyopy and spglib give consistent results.""" + struct = get_random_structure_for_protostructure(protostructure) + + moyopy_label = get_protostructure_label_from_moyopy(struct) + spglib_label = get_protostructure_label_from_spglib(struct) + + assert moyopy_label == spglib_label @pytest.mark.skipif(pyxtal is None, reason="pyxtal not installed") @@ -293,13 +337,16 @@ def test_get_protostructure_label_from_aflow(): reason="pyxtal is non-deterministic and symmetry can increase in random crystal" ) @pytest.mark.parametrize( - "protostructure", - PROTOSTRUCTURE_SET, + "protostructure, method", + list(product(PROTOSTRUCTURE_SET, ["spglib", "moyopy"])), ) -def test_get_random_structure_for_protostructure_roundtrip(protostructure): +def test_get_random_structure_for_protostructure_roundtrip( + protostructure: str, method: str +): """Check roundtrip for generating a random structure from a prototype string""" - assert protostructure == get_protostructure_label_from_spglib( - get_random_structure_for_protostructure(protostructure) + assert protostructure == get_protostructure_label( + get_random_structure_for_protostructure(protostructure), + method=method, ) @@ -317,199 +364,7 @@ def test_get_random_structure_for_protostructure_random(protostructure): assert s1.lattice != s2.lattice -def test_get_protostructure_label_from_moyopy(): - """Check that moyopy gives correct protostructure label for esseneite""" - struct = Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif") - assert ( - get_protostructure_label_from_moyopy(struct) - == "ABC6D2_mC40_15_4e_4e_8f_24f:Ca-Fe-O-Si" - ) - - -def test_get_protostructure_label_from_moyopy_edge_case(): - """Check edge case where the symmetry precision is too low.""" - struct = Structure.from_file(f"{TEST_DIR}/data/U2Pa4Tc6.json") - - defaults = inspect.signature(get_protostructure_label_from_moyopy).parameters - - assert defaults["init_symprec"].default == 0.1 - - raises_str = ( - "Invalid WP multiplicities - A2B3C_hP6_191_2a_4c_6g:Pa-Tc-U, " - "expected UPa4Tc9 to be UPa2Tc3" - ) - - # Test that it gives invalid protostructure if fallback is None - with pytest.raises(ValueError, match=re.escape(raises_str)): - get_protostructure_label_from_moyopy( - struct, raise_errors=True, fallback_symprec=None - ) - - assert ( - get_protostructure_label_from_moyopy( - struct, raise_errors=False, fallback_symprec=None - ) - == raises_str - ) - - # Test that it recovers with fallback symprec - assert get_protostructure_label_from_moyopy(struct, raise_errors=True) == ( - "A2B3C_hP6_191_2a_4c_6g:Pa-Tc-U" - ) - - assert get_protostructure_label_from_moyopy(struct, raise_errors=False) == ( - "A2B3C_hP6_191_2a_4c_6g:Pa-Tc-U" - ) - - -@pytest.mark.parametrize( - "protostructure", - PROTOSTRUCTURE_SET, -) -def test_moyopy_spglib_consistency(protostructure): - """Check that moyopy and spglib give consistent results.""" - struct = get_random_structure_for_protostructure(protostructure) - - moyopy_label = get_protostructure_label_from_moyopy(struct) - spglib_label = get_protostructure_label_from_spglib(struct) - - assert moyopy_label == spglib_label - - -def test_moyopy_spglib_interchangeable(): - """Test that moyopy and spglib functions are drop-in replacements for each other.""" - # Test normal case - struct = Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif") - - # Both should handle raise_errors=True/False similarly - for raise_errors in (True, False): - moyopy_label = get_protostructure_label_from_moyopy( - struct, raise_errors=raise_errors - ) - spglib_label = get_protostructure_label_from_spglib( - struct, raise_errors=raise_errors - ) - - # Compare parts that should be identical - moyopy_parts = moyopy_label.split("_", 3) - spglib_parts = spglib_label.split("_", 3) - - assert moyopy_parts[:3] == spglib_parts[:3] # prototype, Pearson, space group - assert ( - moyopy_label.split(":")[-1] == spglib_label.split(":")[-1] - ) # chemical system - - # Test edge case with invalid structure - struct_invalid = Structure.from_file(f"{TEST_DIR}/data/U2Pa4Tc6.json") - - # Both should handle fallback_symprec=None similarly - for func in ( - get_protostructure_label_from_moyopy, - get_protostructure_label_from_spglib, - ): - # Should raise error when raise_errors=True - with pytest.raises(ValueError, match="Invalid WP multiplicities"): - func(struct_invalid, raise_errors=True, fallback_symprec=None) - - # Should return error message when raise_errors=False - result = func(struct_invalid, raise_errors=False, fallback_symprec=None) - assert "Invalid WP multiplicities" in result - assert "expected" in result - - # Both should recover with default fallback_symprec - moyopy_recovered = get_protostructure_label_from_moyopy(struct_invalid) - spglib_recovered = get_protostructure_label_from_spglib(struct_invalid) - - # Compare recovered results (ignoring Wyckoff position format differences) - moyopy_parts = moyopy_recovered.split("_", 3) - spglib_parts = spglib_recovered.split("_", 3) - - assert moyopy_parts[:3] == spglib_parts[:3] - assert moyopy_recovered.split(":")[-1] == spglib_recovered.split(":")[-1] - - -def test_moyopy_spglib_equivalence(): - """Test that moyopy and spglib give equivalent results for various structures.""" - # Simple test structures with known symmetry - test_structs = { - "cubic": Structure( # NaCl structure - lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], - species=["Na", "Cl"], - coords=[[0, 0, 0], [0.5, 0.5, 0.5]], - ), - "tetragonal": Structure( # TiO2 structure - lattice=[[3, 0, 0], [0, 3, 0], [0, 0, 4]], - species=["Ti", "O", "O"], - coords=[[0, 0, 0], [0.3, 0.3, 0], [0.7, 0.7, 0]], - ), - "hexagonal": Structure( # ZnO structure - lattice=[[3, 0, 0], [-1.5, 2.6, 0], [0, 0, 5]], - species=["Zn", "O"], - coords=[[1 / 3, 2 / 3, 0], [2 / 3, 1 / 3, 0.5]], - ), - # Real structure from file - "esseneite": Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif"), - } - - # Expected outputs (moyopy, spglib) for each structure - expected_outputs = { - "cubic": ( - "AB_cF8_225_a_b:Cl-Na", - "AB_cF8_225_a_b:Cl-Na", - ), - "tetragonal": ( - "AB2_tP6_136_2a_4c:Ti-O", - "AB2_tP6_136_a_2c:Ti-O", - ), - "hexagonal": ( - "AB_hP4_194_2a_2b:Zn-O", - "AB_hP4_194_a_b:Zn-O", - ), - "esseneite": ( - "ABC6D2_mC40_15_4e_4e_8f_24f:Ca-Fe-O-Si", - "ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si", - ), - } - - # Test each structure - for name, struct in test_structs.items(): - moyopy_label = get_protostructure_label_from_moyopy(struct) - spglib_label = get_protostructure_label_from_spglib(struct) - - moyopy_expected, spglib_expected = expected_outputs[name] - - # Check full labels match expected output - assert moyopy_label == moyopy_expected, ( - f"Moyopy output mismatch for {name}:\n" - f"got: {moyopy_label}\n" - f"expected: {moyopy_expected}" - ) - assert spglib_label == spglib_expected, ( - f"Spglib output mismatch for {name}:\n" - f"got: {spglib_label}\n" - f"expected: {spglib_expected}" - ) - - # Check that both functions agree on key properties - moyopy_parts = moyopy_label.split("_", 3) - spglib_parts = spglib_label.split("_", 3) - - assert moyopy_parts[:3] == spglib_parts[:3], ( - f"Core properties mismatch for {name}:\n" - f"moyopy: {moyopy_parts[:3]}\n" - f"spglib: {spglib_parts[:3]}" - ) - - # Test random structures from PROTOSTRUCTURE_SET - for proto in PROTOSTRUCTURE_SET: - struct = get_random_structure_for_protostructure(proto) - moyopy_label = get_protostructure_label_from_moyopy(struct) - spglib_label = get_protostructure_label_from_spglib(struct) - - # Compare core properties (prototype, Pearson, space group) - moyopy_parts = moyopy_label.split("_", 3) - spglib_parts = spglib_label.split("_", 3) - assert moyopy_parts[:3] == spglib_parts[:3] +if __name__ == "__main__": + import pytest - # Compare chemical system - assert moyopy_label.split(":")[-1] == spglib_label.split(":")[-1] + pytest.main(["-v", __file__]) From a81fe1c9229cf26403dc37c0c7e569e1752d83fe Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 20 Jan 2025 09:27:06 -0500 Subject: [PATCH 05/16] fix: replace tests with known structure test cases --- pyproject.toml | 7 ++++--- tests/test_wyckoff_ops.py | 35 ++++++++++++++++++++++++----------- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fa6be25e..df995e9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "aviary" -version = "1.1.0" +version = "1.1.1" description = "A collection of machine learning models for materials discovery" authors = [{ name = "Rhys Goodall", email = "rhys.goodall@outlook.com" }] readme = "README.md" @@ -35,7 +35,7 @@ classifiers = [ requires-python = ">=3.9" dependencies = [ - "numpy<2", + "numpy", "pandas", "pymatgen", "scikit_learn", @@ -49,8 +49,9 @@ dependencies = [ Repo = "https://github.com/CompRhys/aviary" [project.optional-dependencies] -test = ["matminer", "pytest", "pytest-cov", "pyxtal"] +test = ["matminer", "moyopy", "pytest", "pytest-cov", "pyxtal"] pyxtal = ["pyxtal"] +moyopy = ["moyopy"] [tool.setuptools.packages] find = { include = ["aviary*"], exclude = ["tests*"] } diff --git a/tests/test_wyckoff_ops.py b/tests/test_wyckoff_ops.py index 1a79e65b..113dfcd3 100644 --- a/tests/test_wyckoff_ops.py +++ b/tests/test_wyckoff_ops.py @@ -4,7 +4,7 @@ from shutil import which import pytest -from pymatgen.core.structure import Composition, Structure +from pymatgen.core.structure import Composition, Lattice, Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from aviary.wren.utils import ( @@ -48,27 +48,40 @@ TEST_STRUCTS = [ Structure( # NaCl structure - lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], + lattice=[[2, 2, 0], [0, 2, 2], [2, 0, 2]], species=["Na", "Cl"], coords=[[0, 0, 0], [0.5, 0.5, 0.5]], ), - Structure( # TiO2 structure - lattice=[[3, 0, 0], [0, 3, 0], [0, 0, 4]], - species=["Ti", "O", "O"], - coords=[[0, 0, 0], [0.3, 0.3, 0], [0.7, 0.7, 0]], + Structure( # CsCl structure + lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], + species=["Cs", "Cl"], + coords=[[0, 0, 0], [0.5, 0.5, 0.5]], ), - Structure( # ZnO structure - lattice=[[3, 0, 0], [-1.5, 2.6, 0], [0, 0, 5]], + Structure( # ZnO zincblende structure + lattice=[[2, 2, 0], [0, 2, 2], [2, 0, 2]], species=["Zn", "O"], - coords=[[1 / 3, 2 / 3, 0], [2 / 3, 1 / 3, 0.5]], + coords=[[0, 0, 0], [0.25, 0.25, 0.25]], + ), + Structure( # ZnO wurtzite structure + lattice=Lattice.from_parameters( + a=3.8227, b=3.8227, c=6.2607, alpha=90, beta=90, gamma=120 + ), + species=["Zn", "O", "Zn", "O"], + coords=[ + [1 / 3, 2 / 3, 0], + [2 / 3, 1 / 3, 0.3748], + [2 / 3, 1 / 3, 1 / 2], + [1 / 3, 2 / 3, 1 / 2 + 0.3748], + ], ), Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif"), ] TEST_PROTOSTRUCTURES = [ "AB_cF8_225_a_b:Cl-Na", - "AB2_tP6_136_a_2c:Ti-O", - "AB_hP4_194_a_b:Zn-O", + "AB_cP2_221_a_b:Cl-Cs", + "AB_cF8_216_a_c:O-Zn", + "AB_hP4_186_b_b:O-Zn", "ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si", ] From 57372af7d35cc2fac536d434a08942bc32a8e5d3 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 20 Jan 2025 19:12:44 -0500 Subject: [PATCH 06/16] improve get_protostructure_label_from_moyopy orbit grouping - refactor moyopy import handling - print offending structure in error messages test_wyckoff_ops assert messages --- aviary/wren/utils.py | 75 +++++++++++++++++++++++++-------------- tests/test_wyckoff_ops.py | 8 +++-- 2 files changed, 55 insertions(+), 28 deletions(-) diff --git a/aviary/wren/utils.py b/aviary/wren/utils.py index e237a22f..98e05e43 100644 --- a/aviary/wren/utils.py +++ b/aviary/wren/utils.py @@ -9,7 +9,7 @@ from os.path import abspath, dirname, join from shutil import which from string import ascii_uppercase, digits -from typing import Literal +from typing import TYPE_CHECKING, Literal from monty.fractions import gcd from pymatgen.core import Composition, Structure @@ -23,15 +23,8 @@ pyxtal = None has_pyxtal = False -try: +if TYPE_CHECKING: import moyopy - from moyopy.interface import MoyoAdapter - - has_moyopy = True -except ImportError: - moyopy = None - has_moyopy = False - module_dir = dirname(abspath(__file__)) @@ -161,6 +154,8 @@ def get_pearson_symbol_from_spg_analyzer(spg_analyzer: SpacegroupAnalyzer) -> st def get_pearson_symbol_from_moyo_dataset(moyo_data: moyopy.MoyoDataset) -> str: """Get the Pearson symbol for the structure from a MoyoDataset.""" + import moyopy + # Get space group number and Wyckoff positions spg_num = moyo_data.number @@ -186,7 +181,7 @@ def get_protostructure_label( method: Literal["aflow", "spglib", "moyopy"], raise_errors: bool = False, **kwargs, -) -> str: +) -> str | None: """Get protostructure label for a pymatgen Structure. Args: @@ -303,6 +298,18 @@ def _get_all_wyckoffs_substring_and_element_dict( equivalent_wyckoff_labels: list[tuple[int, str, str]], spg_num: int | str, ): + """Get Wyckoff position substring and element dict from equivalent Wyckoff labels. + + Args: + equivalent_wyckoff_labels (list[tuple[int, str, str]]): List of tuples containing + (multiplicity, element symbol, Wyckoff letter). + spg_num (int | str): Space group number. + + Returns: + tuple[str, dict]: Tuple containing: + - str: Wyckoff position substring + - dict: Dictionary mapping element symbols to their multiplicities + """ # Pre-sort by element and wyckoff letter to ensure continuous groups in groupby equivalent_wyckoff_labels = sorted( equivalent_wyckoff_labels, key=lambda x: (x[1], x[2]) @@ -317,10 +324,11 @@ def _get_all_wyckoffs_substring_and_element_dict( element_dict[el] = sum( wyckoff_multiplicity_dict[str(spg_num)][e[2]] for e in list_group ) + # group by Wyckoff letter to get Wyckoff site multiplicity from len element_wyckoffs.append( "".join( - f"{len(list(w))}{wyk}" - for wyk, w in groupby(list_group, key=lambda x: x[2]) + f"{len(list(occurrences))}{wyk_letter}" + for wyk_letter, occurrences in groupby(list_group, key=lambda x: x[2]) ) ) all_wyckoffs = "_".join(element_wyckoffs) @@ -441,7 +449,7 @@ def get_protostructure_label_from_moyopy( struct: Structure, raise_errors: bool = False, symprec: float = 0.1, -) -> str: +) -> str | None: """Get AFLOW prototype label using Moyopy for symmetry detection. Args: @@ -455,15 +463,17 @@ def get_protostructure_label_from_moyopy( explanation of failure if symmetry detection failed and `raise_errors` is False. """ - if not has_moyopy: - raise ImportError( - "moyopy is not installed, please install it with `pip install moyopy`" - ) + # Convert pymatgen Structure to Moyo Cell and get symmetry data + try: + import moyopy + from moyopy.interface import MoyoAdapter + except ImportError: + raise ImportError("moyopy not found, run pip install moyopy") from None moyo_cell = MoyoAdapter.from_structure(struct) moyo_data = moyopy.MoyoDataset(moyo_cell, symprec=symprec) - # Get space group number and Wyckoff positions + # Get space group number and Pearson symbol spg_num = moyo_data.number pearson_symbol = get_pearson_symbol_from_moyo_dataset(moyo_data) prototype_form = get_prototype_formula_from_composition(struct.composition) @@ -471,14 +481,27 @@ def get_protostructure_label_from_moyopy( # Group Wyckoff positions by orbit and element equivalent_wyckoff_labels = [] - for orbit_idx, group in groupby(moyo_data.orbits): - equivalent_wyckoff_labels.append( - ( - len(list(group)), # multiplicity - struct.species[orbit_idx], # element - moyo_data.wyckoffs[orbit_idx], # wyckoff letter - ) - ) + orbit_groups: list[list[int]] = [] + current_orbit: list[int] = [] + + # Group sites by orbit + for idx, orbit_id in enumerate(moyo_data.orbits): + if not current_orbit or orbit_id == moyo_data.orbits[current_orbit[0]]: + current_orbit += [idx] + else: + orbit_groups += [current_orbit] + current_orbit = [idx] + if current_orbit: + orbit_groups += [current_orbit] + + # Create equivalent_wyckoff_labels from orbit groups + for orbit in orbit_groups: + # All sites in an orbit have the same Wyckoff letter and element + wyckoff = moyo_data.wyckoffs[orbit[0]] + element = struct.species[orbit[0]] + equivalent_wyckoff_labels += [ + (len(orbit), element.symbol, wyckoff.translate(remove_digits)) + ] all_wyckoffs, element_dict = _get_all_wyckoffs_substring_and_element_dict( equivalent_wyckoff_labels, spg_num diff --git a/tests/test_wyckoff_ops.py b/tests/test_wyckoff_ops.py index 113dfcd3..abcaa337 100644 --- a/tests/test_wyckoff_ops.py +++ b/tests/test_wyckoff_ops.py @@ -328,7 +328,9 @@ def test_get_protostructure_label_from_aflow(structure, expected): @pytest.mark.parametrize("structure, expected", zip(TEST_STRUCTS, TEST_PROTOSTRUCTURES)) def test_get_protostructure_label_from_moyopy(structure, expected): """Check that moyopy gives correct protostructure label simple cases.""" - assert get_protostructure_label_from_moyopy(structure) == expected + assert ( + get_protostructure_label_from_moyopy(structure) == expected + ), f"unexpected moyopy protostructure for {structure=}" @pytest.mark.parametrize( @@ -342,7 +344,9 @@ def test_moyopy_spglib_consistency(protostructure): moyopy_label = get_protostructure_label_from_moyopy(struct) spglib_label = get_protostructure_label_from_spglib(struct) - assert moyopy_label == spglib_label + assert ( + moyopy_label == spglib_label + ), f"spglib moyopy protostructure mismatch for {protostructure}" @pytest.mark.skipif(pyxtal is None, reason="pyxtal not installed") From 4d885bc59fce1a96549f631b0c47ebc8e2ccf757 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 27 Jan 2025 09:18:08 -0500 Subject: [PATCH 07/16] pin moyopy>=0.3.1 --- pyproject.toml | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index df995e9b..30da10e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,9 +49,9 @@ dependencies = [ Repo = "https://github.com/CompRhys/aviary" [project.optional-dependencies] -test = ["matminer", "moyopy", "pytest", "pytest-cov", "pyxtal"] +test = ["matminer", "moyopy>=0.3.1", "pytest", "pytest-cov", "pyxtal"] pyxtal = ["pyxtal"] -moyopy = ["moyopy"] +moyopy = ["moyopy>=0.3.1"] [tool.setuptools.packages] find = { include = ["aviary*"], exclude = ["tests*"] } @@ -107,16 +107,16 @@ select = [ "YTT", # flake8-2020 ] ignore = [ - "C408", # Unnecessary dict call - rewrite as a literal - "D100", # Missing docstring in public module - "D104", # Missing docstring in public package - "D105", # Missing docstring in magic method - "D205", # 1 blank line required between summary line and description - "E731", # Do not assign a lambda expression, use a def + "C408", # Unnecessary dict call - rewrite as a literal + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic method + "D205", # 1 blank line required between summary line and description + "E731", # Do not assign a lambda expression, use a def "ISC001", - "PD901", # pandas-df-variable-name - "PLR", # pylint refactor - "PT006", # pytest-parametrize-names-wrong-type + "PD901", # pandas-df-variable-name + "PLR", # pylint refactor + "PT006", # pytest-parametrize-names-wrong-type ] pydocstyle.convention = "google" isort.known-third-party = ["wandb"] From 475d98526609cb7f628f72096c681ced17603ca0 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 2 Feb 2025 14:20:31 -0500 Subject: [PATCH 08/16] fix: orbit grouping code was not doing anything useful. --- aviary/wren/utils.py | 38 +++++++++++++++++++------------------- tests/test_wyckoff_ops.py | 21 +++++++++++++++++++++ 2 files changed, 40 insertions(+), 19 deletions(-) diff --git a/aviary/wren/utils.py b/aviary/wren/utils.py index 98e05e43..0a9c01b3 100644 --- a/aviary/wren/utils.py +++ b/aviary/wren/utils.py @@ -9,7 +9,7 @@ from os.path import abspath, dirname, join from shutil import which from string import ascii_uppercase, digits -from typing import TYPE_CHECKING, Literal +from typing import Literal from monty.fractions import gcd from pymatgen.core import Composition, Structure @@ -23,8 +23,15 @@ pyxtal = None has_pyxtal = False -if TYPE_CHECKING: +try: import moyopy + from moyopy.interface import MoyoAdapter + + has_moyopy = True +except ImportError: + moyopy = None + MoyoAdapter = None + has_moyopy = False module_dir = dirname(abspath(__file__)) @@ -154,7 +161,8 @@ def get_pearson_symbol_from_spg_analyzer(spg_analyzer: SpacegroupAnalyzer) -> st def get_pearson_symbol_from_moyo_dataset(moyo_data: moyopy.MoyoDataset) -> str: """Get the Pearson symbol for the structure from a MoyoDataset.""" - import moyopy + if not has_moyopy: + raise ImportError("moyopy not found, run pip install moyopy") # Get space group number and Wyckoff positions spg_num = moyo_data.number @@ -463,13 +471,10 @@ def get_protostructure_label_from_moyopy( explanation of failure if symmetry detection failed and `raise_errors` is False. """ - # Convert pymatgen Structure to Moyo Cell and get symmetry data - try: - import moyopy - from moyopy.interface import MoyoAdapter - except ImportError: - raise ImportError("moyopy not found, run pip install moyopy") from None + if not has_moyopy: + raise ImportError("moyopy not found, run pip install moyopy") + # Convert pymatgen Structure to Moyo Cell and get symmetry data moyo_cell = MoyoAdapter.from_structure(struct) moyo_data = moyopy.MoyoDataset(moyo_cell, symprec=symprec) @@ -481,21 +486,16 @@ def get_protostructure_label_from_moyopy( # Group Wyckoff positions by orbit and element equivalent_wyckoff_labels = [] - orbit_groups: list[list[int]] = [] - current_orbit: list[int] = [] + orbit_groups: dict[int, list[int]] = {} # Group sites by orbit for idx, orbit_id in enumerate(moyo_data.orbits): - if not current_orbit or orbit_id == moyo_data.orbits[current_orbit[0]]: - current_orbit += [idx] - else: - orbit_groups += [current_orbit] - current_orbit = [idx] - if current_orbit: - orbit_groups += [current_orbit] + if orbit_id not in orbit_groups: + orbit_groups[orbit_id] = [] + orbit_groups[orbit_id].append(idx) # Create equivalent_wyckoff_labels from orbit groups - for orbit in orbit_groups: + for orbit in orbit_groups.values(): # All sites in an orbit have the same Wyckoff letter and element wyckoff = moyo_data.wyckoffs[orbit[0]] element = struct.species[orbit[0]] diff --git a/tests/test_wyckoff_ops.py b/tests/test_wyckoff_ops.py index abcaa337..66ea0cac 100644 --- a/tests/test_wyckoff_ops.py +++ b/tests/test_wyckoff_ops.py @@ -74,6 +74,25 @@ [1 / 3, 2 / 3, 1 / 2 + 0.3748], ], ), + Structure( + lattice=[[3.9, 0, 0], [0, 3.9, 0], [0, 0, 3.9]], + species=["Sr", "Ti", "O", "O", "O"], + coords=[[0, 0, 0], [0.5, 0.5, 0.5], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]], + ), + Structure( + lattice=[[5.76, 0, 0], [0, 5.76, 0], [0, 0, 5.76]], + species=["Al", "Fe", "Fe", "Fe", "Al", "Fe", "Fe", "Fe"], + coords=[ + [0, 0, 0], + [0.25, 0.25, 0.25], + [0.5, 0.5, 0], + [0.75, 0.75, 0.25], + [0, 0.5, 0.5], + [0.25, 0.75, 0.75], + [0.5, 0, 0.5], + [0.75, 0.25, 0.75], + ], + ), Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif"), ] @@ -82,6 +101,8 @@ "AB_cP2_221_a_b:Cl-Cs", "AB_cF8_216_a_c:O-Zn", "AB_hP4_186_b_b:O-Zn", + "A3BC_cP5_221_c_a_b:O-Sr-Ti", + "AB3_tP4_115_a_cg:Al-Fe", "ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si", ] From 2cd2304c89b8f524a3aaf78dc374e3e5f36788d3 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 3 Feb 2025 10:31:18 -0500 Subject: [PATCH 09/16] fea: use moyopy.SpaceGroupType to get the crystal system --- aviary/wren/utils.py | 36 +----------------------------------- 1 file changed, 1 insertion(+), 35 deletions(-) diff --git a/aviary/wren/utils.py b/aviary/wren/utils.py index 0a9c01b3..b9398f9f 100644 --- a/aviary/wren/utils.py +++ b/aviary/wren/utils.py @@ -110,40 +110,6 @@ def count_values_for_wyckoff( ) -def get_crystal_system(n: int) -> str: - """Get the crystal system for the structure, e.g. (triclinic, orthorhombic, - cubic, etc.). - - Mirrors method of SpacegroupAnalyzer.get_crystal_system(). - - Args: - n (int): Space group number - - Raises: - ValueError: on invalid space group numbers < 1 or > 230. - - Returns: - str: Crystal system for structure - """ - # Not using isinstance(n, int) to allow 0-decimal floats - if n != int(n) or not 0 < n < 231: - raise ValueError(f"Received invalid space group {n}") - - if 0 < n < 3: - return "triclinic" - if n < 16: - return "monoclinic" - if n < 75: - return "orthorhombic" - if n < 143: - return "tetragonal" - if n < 168: - return "trigonal" - if n < 195: - return "hexagonal" - return "cubic" - - def get_centering(spg_sym: str) -> str: """Get the centering for the structure, e.g. (A, B, C, S).""" return "C" if spg_sym[0] in ("A", "B", "C", "S") else spg_sym[0] @@ -173,7 +139,7 @@ def get_pearson_symbol_from_moyo_dataset(moyo_data: moyopy.MoyoDataset) -> str: centering = hall_entry.centering # Get crystal system from space group number instead of symbol - cry_sys = get_crystal_system(spg_num) + cry_sys = moyopy.SpaceGroupType(spg_num).crystal_system.lower() # Get centering from first letter of space group symbol # Handle special case for C-centered From 09a359d1ca663ba22f29218cde9ab54f9f379733 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Tue, 4 Feb 2025 22:19:01 -0500 Subject: [PATCH 10/16] fea: use moyo pearson number --- aviary/wren/utils.py | 37 ++++++------------------------------- pyproject.toml | 2 +- 2 files changed, 7 insertions(+), 32 deletions(-) diff --git a/aviary/wren/utils.py b/aviary/wren/utils.py index b9398f9f..ad977225 100644 --- a/aviary/wren/utils.py +++ b/aviary/wren/utils.py @@ -50,7 +50,7 @@ for spg_num, vals in relab_dict.items() } -cry_sys_dict = { +CRYSTAL_FAMILY_SYMBOLS = { "triclinic": "a", "monoclinic": "m", "orthorhombic": "o", @@ -60,7 +60,7 @@ "cubic": "c", } -cry_param_dict = { +CRYSTAL_LATTICE_PARAMETERS_COUNTS = { "a": 6, "m": 4, "o": 3, @@ -122,32 +122,7 @@ def get_pearson_symbol_from_spg_analyzer(spg_analyzer: SpacegroupAnalyzer) -> st centering = get_centering(spg_sym) num_sites_conventional = len(spg_analyzer.get_symmetry_dataset()["std_types"]) - return f"{cry_sys_dict[cry_sys]}{centering}{num_sites_conventional}" - - -def get_pearson_symbol_from_moyo_dataset(moyo_data: moyopy.MoyoDataset) -> str: - """Get the Pearson symbol for the structure from a MoyoDataset.""" - if not has_moyopy: - raise ImportError("moyopy not found, run pip install moyopy") - - # Get space group number and Wyckoff positions - spg_num = moyo_data.number - - # Get crystal system and centering from Hall symbol entry - hall_entry = moyopy.HallSymbolEntry(hall_number=moyo_data.hall_number) - spg_sym = hall_entry.hm_short - centering = hall_entry.centering - - # Get crystal system from space group number instead of symbol - cry_sys = moyopy.SpaceGroupType(spg_num).crystal_system.lower() - - # Get centering from first letter of space group symbol - # Handle special case for C-centered - centering = get_centering(spg_sym) - - # Get number of sites in conventional cell - num_sites_conventional = len(moyo_data.std_cell.numbers) - return f"{cry_sys_dict[cry_sys]}{centering}{num_sites_conventional}" + return f"{CRYSTAL_FAMILY_SYMBOLS[cry_sys]}{centering}{num_sites_conventional}" def get_protostructure_label( @@ -330,7 +305,7 @@ def get_protostructure_label_from_spg_analyzer( sym_struct = spg_analyzer.get_symmetrized_structure() spg_num = spg_analyzer.get_space_group_number() - pearson_symbol = get_pearson_symbol_from_spg_analyzer(spg_analyzer) + pearson_symbol = spg_analyzer.get_pearson_symbol() prototype_form = get_prototype_formula_from_composition(sym_struct.composition) chemsys = sym_struct.chemical_system @@ -446,7 +421,7 @@ def get_protostructure_label_from_moyopy( # Get space group number and Pearson symbol spg_num = moyo_data.number - pearson_symbol = get_pearson_symbol_from_moyo_dataset(moyo_data) + pearson_symbol = moyo_data.pearson_symbol prototype_form = get_prototype_formula_from_composition(struct.composition) chemsys = struct.chemical_system @@ -666,7 +641,7 @@ def count_crystal_dof(protostructure_label: str) -> int: return ( _count_from_dict(element_wyckoffs, param_dict, spg_num) - + cry_param_dict[pearson_symbol[0]] + + CRYSTAL_LATTICE_PARAMETERS_COUNTS[pearson_symbol[0]] ) diff --git a/pyproject.toml b/pyproject.toml index 30da10e1..4f76ea05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ Repo = "https://github.com/CompRhys/aviary" [project.optional-dependencies] test = ["matminer", "moyopy>=0.3.1", "pytest", "pytest-cov", "pyxtal"] pyxtal = ["pyxtal"] -moyopy = ["moyopy>=0.3.1"] +moyopy = ["moyopy>=0.3.3"] [tool.setuptools.packages] find = { include = ["aviary*"], exclude = ["tests*"] } From 43cb948358d8cb899708bacd26846172f31a4845 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 5 Feb 2025 22:29:17 -0500 Subject: [PATCH 11/16] fix: avoid spglib warning --- aviary/wren/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aviary/wren/utils.py b/aviary/wren/utils.py index ad977225..9d8336f0 100644 --- a/aviary/wren/utils.py +++ b/aviary/wren/utils.py @@ -121,7 +121,7 @@ def get_pearson_symbol_from_spg_analyzer(spg_analyzer: SpacegroupAnalyzer) -> st spg_sym = spg_analyzer.get_space_group_symbol() centering = get_centering(spg_sym) - num_sites_conventional = len(spg_analyzer.get_symmetry_dataset()["std_types"]) + num_sites_conventional = len(spg_analyzer.get_symmetry_dataset().std_types) return f"{CRYSTAL_FAMILY_SYMBOLS[cry_sys]}{centering}{num_sites_conventional}" From 5e22e2c65d4d2c5a75bc065400a2aab904b49f30 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 5 Feb 2025 22:39:27 -0500 Subject: [PATCH 12/16] fix: use the util here until added to pmg sga --- aviary/wren/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aviary/wren/utils.py b/aviary/wren/utils.py index 9d8336f0..e8aa5f82 100644 --- a/aviary/wren/utils.py +++ b/aviary/wren/utils.py @@ -305,7 +305,7 @@ def get_protostructure_label_from_spg_analyzer( sym_struct = spg_analyzer.get_symmetrized_structure() spg_num = spg_analyzer.get_space_group_number() - pearson_symbol = spg_analyzer.get_pearson_symbol() + pearson_symbol = get_pearson_symbol_from_spg_analyzer(spg_analyzer) prototype_form = get_prototype_formula_from_composition(sym_struct.composition) chemsys = sym_struct.chemical_system From b1881978c14894b7b6a0d4a19610a392410b1fb5 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 28 Feb 2025 19:00:38 -0500 Subject: [PATCH 13/16] fix: update relabelling table to include edge case in spg 15 --- aviary/wren/wyckoff-position-relabelings.json | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/aviary/wren/wyckoff-position-relabelings.json b/aviary/wren/wyckoff-position-relabelings.json index 99ebd0ad..d0ed532d 100644 --- a/aviary/wren/wyckoff-position-relabelings.json +++ b/aviary/wren/wyckoff-position-relabelings.json @@ -491,6 +491,38 @@ "100": "c", "101": "e", "102": "f" + }, + { + "97": "c", + "98": "d", + "99": "a", + "100": "b", + "101": "e", + "102": "f" + }, + { + "97": "c", + "98": "d", + "99": "b", + "100": "a", + "101": "e", + "102": "f" + }, + { + "97": "d", + "98": "c", + "99": "a", + "100": "b", + "101": "e", + "102": "f" + }, + { + "97": "d", + "98": "c", + "99": "b", + "100": "a", + "101": "e", + "102": "f" } ], "16": [ From 838a58029db3456c49cb3a8af31b004befd1b7c1 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 28 Feb 2025 19:19:50 -0500 Subject: [PATCH 14/16] fea: update moyo version in tests --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4f76ea05..e788d2b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ dependencies = [ Repo = "https://github.com/CompRhys/aviary" [project.optional-dependencies] -test = ["matminer", "moyopy>=0.3.1", "pytest", "pytest-cov", "pyxtal"] +test = ["matminer", "moyopy>=0.3.3", "pytest", "pytest-cov", "pyxtal"] pyxtal = ["pyxtal"] moyopy = ["moyopy>=0.3.3"] From 9b51ca10fd4a2b347cd0841df4b27b7d8ba3e435 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 28 Feb 2025 19:26:31 -0500 Subject: [PATCH 15/16] fix: remove torch pin from test.yml --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5254cb4f..27cbae07 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,7 +27,7 @@ jobs: - name: Install dependencies run: | - pip install torch==2.2.1 --index-url https://download.pytorch.org/whl/cpu + pip install torch --index-url https://download.pytorch.org/whl/cpu uv pip install .[test] --system - name: Run Tests From fa9bbcd9f5fbbdc7edcf3dcc2bf9f6f05737e11a Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 28 Feb 2025 19:52:43 -0500 Subject: [PATCH 16/16] fix: torch.load with weights_only=False --- aviary/predict.py | 4 +++- aviary/utils.py | 14 +++++++------- pyproject.toml | 4 ++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/aviary/predict.py b/aviary/predict.py index d56c33ab..70e1bf28 100644 --- a/aviary/predict.py +++ b/aviary/predict.py @@ -74,7 +74,9 @@ def make_ensemble_predictions( enumerate(tqdm(checkpoint_paths), start=1), disable=None if pbar else True ): try: - checkpoint = torch.load(checkpoint_path, map_location=device) + checkpoint = torch.load( + checkpoint_path, map_location=device, weights_only=False + ) except Exception as exc: raise RuntimeError(f"Failed to load {checkpoint_path=}") from exc diff --git a/aviary/utils.py b/aviary/utils.py index bbe35667..ecf9043e 100644 --- a/aviary/utils.py +++ b/aviary/utils.py @@ -66,7 +66,7 @@ def initialize_model( if fine_tune is not None: print(f"Use material_nn and output_nn from {fine_tune=} as a starting point") - checkpoint = torch.load(fine_tune, map_location=device) + checkpoint = torch.load(fine_tune, map_location=device, weights_only=False) # update the task disk to fine tuning task checkpoint["model_params"]["task_dict"] = model_params["task_dict"] @@ -93,7 +93,7 @@ def initialize_model( f"Use material_nn from {transfer=} as a starting point and " "train the output_nn from scratch" ) - checkpoint = torch.load(transfer, map_location=device) + checkpoint = torch.load(transfer, map_location=device, weights_only=False) model = model_class(**model_params) model.to(device) @@ -107,7 +107,7 @@ def initialize_model( elif resume: print(f"Resuming training from {resume=}") - checkpoint = torch.load(resume, map_location=device) + checkpoint = torch.load(resume, map_location=device, weights_only=False) model = model_class(**checkpoint["model_params"]) model.to(device) @@ -186,7 +186,7 @@ def initialize_optim( # TODO work out how to ensure that we are using the same optimizer # when resuming such that the state dictionaries do not clash. # TODO breaking the function apart means we load the checkpoint twice. - checkpoint = torch.load(resume, map_location=device) + checkpoint = torch.load(resume, map_location=device, weights_only=False) optimizer.load_state_dict(checkpoint["optimizer"]) scheduler.load_state_dict(checkpoint["scheduler"]) @@ -261,7 +261,7 @@ def init_normalizers( """ normalizer_dict: dict[str, Normalizer | None] = {} if resume: - checkpoint = torch.load(resume, map_location=device) + checkpoint = torch.load(resume, map_location=device, weights_only=False) for task, state_dict in checkpoint["normalizer_dict"].items(): normalizer_dict[task] = Normalizer.from_state_dict(state_dict) @@ -478,7 +478,7 @@ def results_multitask( resume = f"{ROOT}/models/{model_name}/{eval_type}-r{ens_idx}.pth.tar" print(f"Evaluating Model {ens_idx + 1}/{ensemble_folds}") - checkpoint = torch.load(resume, map_location=device) + checkpoint = torch.load(resume, map_location=device, weights_only=False) if checkpoint["model_params"]["robust"] != robust: raise ValueError(f"robustness of checkpoint {resume=} is not {robust}") @@ -821,7 +821,7 @@ def update_module_path_in_pickled_object( sys.modules[old_module_path] = new_module try: - dic = torch.load(pickle_path, map_location="cpu") + dic = torch.load(pickle_path, map_location="cpu", weights_only=False) except Exception as exc: raise PickleError(pickle_path) from exc diff --git a/pyproject.toml b/pyproject.toml index e788d2b9..0f05974e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,12 +35,12 @@ classifiers = [ requires-python = ">=3.9" dependencies = [ - "numpy", + "numpy>=2,<3", "pandas", "pymatgen", "scikit_learn", "tensorboard", - "torch", + "torch>=2.3.0", "tqdm", "wandb", ]