diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5254cb4f..27cbae07 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,7 +27,7 @@ jobs: - name: Install dependencies run: | - pip install torch==2.2.1 --index-url https://download.pytorch.org/whl/cpu + pip install torch --index-url https://download.pytorch.org/whl/cpu uv pip install .[test] --system - name: Run Tests diff --git a/aviary/predict.py b/aviary/predict.py index d56c33ab..70e1bf28 100644 --- a/aviary/predict.py +++ b/aviary/predict.py @@ -74,7 +74,9 @@ def make_ensemble_predictions( enumerate(tqdm(checkpoint_paths), start=1), disable=None if pbar else True ): try: - checkpoint = torch.load(checkpoint_path, map_location=device) + checkpoint = torch.load( + checkpoint_path, map_location=device, weights_only=False + ) except Exception as exc: raise RuntimeError(f"Failed to load {checkpoint_path=}") from exc diff --git a/aviary/utils.py b/aviary/utils.py index bbe35667..ecf9043e 100644 --- a/aviary/utils.py +++ b/aviary/utils.py @@ -66,7 +66,7 @@ def initialize_model( if fine_tune is not None: print(f"Use material_nn and output_nn from {fine_tune=} as a starting point") - checkpoint = torch.load(fine_tune, map_location=device) + checkpoint = torch.load(fine_tune, map_location=device, weights_only=False) # update the task disk to fine tuning task checkpoint["model_params"]["task_dict"] = model_params["task_dict"] @@ -93,7 +93,7 @@ def initialize_model( f"Use material_nn from {transfer=} as a starting point and " "train the output_nn from scratch" ) - checkpoint = torch.load(transfer, map_location=device) + checkpoint = torch.load(transfer, map_location=device, weights_only=False) model = model_class(**model_params) model.to(device) @@ -107,7 +107,7 @@ def initialize_model( elif resume: print(f"Resuming training from {resume=}") - checkpoint = torch.load(resume, map_location=device) + checkpoint = torch.load(resume, map_location=device, weights_only=False) model = model_class(**checkpoint["model_params"]) model.to(device) @@ -186,7 +186,7 @@ def initialize_optim( # TODO work out how to ensure that we are using the same optimizer # when resuming such that the state dictionaries do not clash. # TODO breaking the function apart means we load the checkpoint twice. - checkpoint = torch.load(resume, map_location=device) + checkpoint = torch.load(resume, map_location=device, weights_only=False) optimizer.load_state_dict(checkpoint["optimizer"]) scheduler.load_state_dict(checkpoint["scheduler"]) @@ -261,7 +261,7 @@ def init_normalizers( """ normalizer_dict: dict[str, Normalizer | None] = {} if resume: - checkpoint = torch.load(resume, map_location=device) + checkpoint = torch.load(resume, map_location=device, weights_only=False) for task, state_dict in checkpoint["normalizer_dict"].items(): normalizer_dict[task] = Normalizer.from_state_dict(state_dict) @@ -478,7 +478,7 @@ def results_multitask( resume = f"{ROOT}/models/{model_name}/{eval_type}-r{ens_idx}.pth.tar" print(f"Evaluating Model {ens_idx + 1}/{ensemble_folds}") - checkpoint = torch.load(resume, map_location=device) + checkpoint = torch.load(resume, map_location=device, weights_only=False) if checkpoint["model_params"]["robust"] != robust: raise ValueError(f"robustness of checkpoint {resume=} is not {robust}") @@ -821,7 +821,7 @@ def update_module_path_in_pickled_object( sys.modules[old_module_path] = new_module try: - dic = torch.load(pickle_path, map_location="cpu") + dic = torch.load(pickle_path, map_location="cpu", weights_only=False) except Exception as exc: raise PickleError(pickle_path) from exc diff --git a/aviary/wren/utils.py b/aviary/wren/utils.py index 0c81d272..e8aa5f82 100644 --- a/aviary/wren/utils.py +++ b/aviary/wren/utils.py @@ -9,6 +9,7 @@ from os.path import abspath, dirname, join from shutil import which from string import ascii_uppercase, digits +from typing import Literal from monty.fractions import gcd from pymatgen.core import Composition, Structure @@ -16,8 +17,21 @@ try: from pyxtal import pyxtal + + has_pyxtal = True except ImportError: pyxtal = None + has_pyxtal = False + +try: + import moyopy + from moyopy.interface import MoyoAdapter + + has_moyopy = True +except ImportError: + moyopy = None + MoyoAdapter = None + has_moyopy = False module_dir = dirname(abspath(__file__)) @@ -36,7 +50,7 @@ for spg_num, vals in relab_dict.items() } -cry_sys_dict = { +CRYSTAL_FAMILY_SYMBOLS = { "triclinic": "a", "monoclinic": "m", "orthorhombic": "o", @@ -46,7 +60,7 @@ "cubic": "c", } -cry_param_dict = { +CRYSTAL_LATTICE_PARAMETERS_COUNTS = { "a": 6, "m": 4, "o": 3, @@ -96,10 +110,55 @@ def count_values_for_wyckoff( ) +def get_centering(spg_sym: str) -> str: + """Get the centering for the structure, e.g. (A, B, C, S).""" + return "C" if spg_sym[0] in ("A", "B", "C", "S") else spg_sym[0] + + +def get_pearson_symbol_from_spg_analyzer(spg_analyzer: SpacegroupAnalyzer) -> str: + """Get the Pearson symbol for the structure.""" + cry_sys = spg_analyzer.get_crystal_system() + spg_sym = spg_analyzer.get_space_group_symbol() + centering = get_centering(spg_sym) + + num_sites_conventional = len(spg_analyzer.get_symmetry_dataset().std_types) + return f"{CRYSTAL_FAMILY_SYMBOLS[cry_sys]}{centering}{num_sites_conventional}" + + +def get_protostructure_label( + struct: Structure, + method: Literal["aflow", "spglib", "moyopy"], + raise_errors: bool = False, + **kwargs, +) -> str | None: + """Get protostructure label for a pymatgen Structure. + + Args: + struct (Structure): pymatgen Structure + method (Literal["aflow", "spglib", "moyopy"]): Method to use for symmetry + detection + raise_errors (bool): Whether to raise errors or annotate them. Defaults to + False. + **kwargs: Additional arguments for the specific method + + Returns: + str: protostructure_label which is constructed as `aflow_label:chemsys` or + explanation of failure if symmetry detection failed and `raise_errors` + is False. + """ + if method == "aflow": + return get_protostructure_label_from_aflow(struct, raise_errors, **kwargs) + if method == "spglib": + return get_protostructure_label_from_spglib(struct, raise_errors, **kwargs) + if method == "moyopy": + return get_protostructure_label_from_moyopy(struct, raise_errors, **kwargs) + raise ValueError(f"Invalid method: {method}") + + def get_protostructure_label_from_aflow( struct: Structure, - aflow_executable: str | None = None, raise_errors: bool = False, + aflow_executable: str | None = None, ) -> str: """Get protostructure label for a pymatgen Structure. Make sure you're running a recent version of the aflow CLI as there's been several breaking changes. This code @@ -144,7 +203,7 @@ def get_protostructure_label_from_aflow( aflow_proto = json.loads(output.stdout) aflow_label = aflow_proto["aflow_prototype_label"] - chemsys = struct.composition.chemical_system + chemsys = struct.chemical_system # check that multiplicities satisfy original composition prototype_form, pearson_symbol, spg_num, *element_wyckoffs = aflow_label.split("_") @@ -184,30 +243,22 @@ def get_protostructure_label_from_aflow( return protostructure_label -def get_protostructure_label_from_spg_analyzer( - spg_analyzer: SpacegroupAnalyzer, - raise_errors: bool = False, -) -> str: - """Get protostructure label for pymatgen SpacegroupAnalyzer. +def _get_all_wyckoffs_substring_and_element_dict( + equivalent_wyckoff_labels: list[tuple[int, str, str]], + spg_num: int | str, +): + """Get Wyckoff position substring and element dict from equivalent Wyckoff labels. Args: - spg_analyzer (SpacegroupAnalyzer): pymatgen SpacegroupAnalyzer object. - raise_errors (bool): Whether to raise errors or annotate them. Defaults to - False. + equivalent_wyckoff_labels (list[tuple[int, str, str]]): List of tuples containing + (multiplicity, element symbol, Wyckoff letter). + spg_num (int | str): Space group number. Returns: - str: protostructure_label which is constructed as `aflow_label:chemsys` or - explanation of failure if symmetry detection failed and `raise_errors` - is False. + tuple[str, dict]: Tuple containing: + - str: Wyckoff position substring + - dict: Dictionary mapping element symbols to their multiplicities """ - spg_num = spg_analyzer.get_space_group_number() - sym_struct = spg_analyzer.get_symmetrized_structure() - - equivalent_wyckoff_labels = [ - # tuple of (wp multiplicity, element, wyckoff letter) - (len(s), s[0].species_string, wyk_letter.translate(remove_digits)) - for s, wyk_letter in zip(sym_struct.equivalent_sites, sym_struct.wyckoff_symbols) - ] # Pre-sort by element and wyckoff letter to ensure continuous groups in groupby equivalent_wyckoff_labels = sorted( equivalent_wyckoff_labels, key=lambda x: (x[1], x[2]) @@ -222,25 +273,52 @@ def get_protostructure_label_from_spg_analyzer( element_dict[el] = sum( wyckoff_multiplicity_dict[str(spg_num)][e[2]] for e in list_group ) + # group by Wyckoff letter to get Wyckoff site multiplicity from len element_wyckoffs.append( "".join( - f"{len(list(w))}{wyk}" - for wyk, w in groupby(list_group, key=lambda x: x[2]) + f"{len(list(occurrences))}{wyk_letter}" + for wyk_letter, occurrences in groupby(list_group, key=lambda x: x[2]) ) ) + all_wyckoffs = "_".join(element_wyckoffs) + all_wyckoffs = canonicalize_element_wyckoffs(all_wyckoffs, spg_num) - # get Pearson symbol - cry_sys = spg_analyzer.get_crystal_system() - spg_sym = spg_analyzer.get_space_group_symbol() - centering = "C" if spg_sym[0] in ("A", "B", "C", "S") else spg_sym[0] - num_sites_conventional = len(spg_analyzer.get_symmetry_dataset()["std_types"]) - pearson_symbol = f"{cry_sys_dict[cry_sys]}{centering}{num_sites_conventional}" + return all_wyckoffs, element_dict + + +def get_protostructure_label_from_spg_analyzer( + spg_analyzer: SpacegroupAnalyzer, + raise_errors: bool = False, +) -> str: + """Get protostructure label for pymatgen SpacegroupAnalyzer. + Args: + spg_analyzer (SpacegroupAnalyzer): pymatgen SpacegroupAnalyzer object. + raise_errors (bool): Whether to raise errors or annotate them. Defaults to + False. + + Returns: + str: protostructure_label which is constructed as `aflow_label:chemsys` or + explanation of failure if symmetry detection failed and `raise_errors` + is False. + """ + sym_struct = spg_analyzer.get_symmetrized_structure() + + spg_num = spg_analyzer.get_space_group_number() + pearson_symbol = get_pearson_symbol_from_spg_analyzer(spg_analyzer) prototype_form = get_prototype_formula_from_composition(sym_struct.composition) - chemsys = sym_struct.composition.chemical_system + chemsys = sym_struct.chemical_system - all_wyckoffs = "_".join(element_wyckoffs) - all_wyckoffs = canonicalize_element_wyckoffs(all_wyckoffs, spg_num) + # get Wyckoff position substring + equivalent_wyckoff_labels = [ + # tuple of (wp multiplicity, element, wyckoff letter) + (len(s), s[0].species_string, wyk_letter.translate(remove_digits)) + for s, wyk_letter in zip(sym_struct.equivalent_sites, sym_struct.wyckoff_symbols) + ] + + all_wyckoffs, element_dict = _get_all_wyckoffs_substring_and_element_dict( + equivalent_wyckoff_labels, spg_num + ) protostructure_label = ( f"{prototype_form}_{pearson_symbol}_{spg_num}_{all_wyckoffs}:{chemsys}" @@ -266,7 +344,7 @@ def get_protostructure_label_from_spglib( raise_errors: bool = False, init_symprec: float = 0.1, fallback_symprec: float | None = 1e-5, -) -> str | None: +) -> str: """Get AFLOW prototype label for pymatgen Structure. Args: @@ -316,6 +394,79 @@ def get_protostructure_label_from_spglib( raise +def get_protostructure_label_from_moyopy( + struct: Structure, + raise_errors: bool = False, + symprec: float = 0.1, +) -> str | None: + """Get AFLOW prototype label using Moyopy for symmetry detection. + + Args: + struct (Structure): pymatgen Structure object. + raise_errors (bool): Whether to raise errors or annotate them. Defaults to + False. + symprec (float): Initial symmetry precision for Moyopy. Defaults to 0.1. + + Returns: + str: protostructure_label which is constructed as `aflow_label:chemsys` or + explanation of failure if symmetry detection failed and `raise_errors` + is False. + """ + if not has_moyopy: + raise ImportError("moyopy not found, run pip install moyopy") + + # Convert pymatgen Structure to Moyo Cell and get symmetry data + moyo_cell = MoyoAdapter.from_structure(struct) + moyo_data = moyopy.MoyoDataset(moyo_cell, symprec=symprec) + + # Get space group number and Pearson symbol + spg_num = moyo_data.number + pearson_symbol = moyo_data.pearson_symbol + prototype_form = get_prototype_formula_from_composition(struct.composition) + chemsys = struct.chemical_system + + # Group Wyckoff positions by orbit and element + equivalent_wyckoff_labels = [] + orbit_groups: dict[int, list[int]] = {} + + # Group sites by orbit + for idx, orbit_id in enumerate(moyo_data.orbits): + if orbit_id not in orbit_groups: + orbit_groups[orbit_id] = [] + orbit_groups[orbit_id].append(idx) + + # Create equivalent_wyckoff_labels from orbit groups + for orbit in orbit_groups.values(): + # All sites in an orbit have the same Wyckoff letter and element + wyckoff = moyo_data.wyckoffs[orbit[0]] + element = struct.species[orbit[0]] + equivalent_wyckoff_labels += [ + (len(orbit), element.symbol, wyckoff.translate(remove_digits)) + ] + + all_wyckoffs, element_dict = _get_all_wyckoffs_substring_and_element_dict( + equivalent_wyckoff_labels, spg_num + ) + + protostructure_label = ( + f"{prototype_form}_{pearson_symbol}_{spg_num}_{all_wyckoffs}:{chemsys}" + ) + + # Verify multiplicities match composition + observed_formula = Composition(element_dict).reduced_formula + expected_formula = struct.composition.reduced_formula + if observed_formula != expected_formula: + err_msg = ( + f"Invalid WP multiplicities - {protostructure_label}, expected " + f"{observed_formula} to be {expected_formula}" + ) + if raise_errors: + raise ValueError(err_msg) + return err_msg + + return protostructure_label + + def canonicalize_element_wyckoffs(element_wyckoffs: str, spg_num: int | str) -> str: """Given an element ordering, canonicalize the associated Wyckoff positions based on the alphabetical weight of equivalent choices of origin. @@ -490,7 +641,7 @@ def count_crystal_dof(protostructure_label: str) -> int: return ( _count_from_dict(element_wyckoffs, param_dict, spg_num) - + cry_param_dict[pearson_symbol[0]] + + CRYSTAL_LATTICE_PARAMETERS_COUNTS[pearson_symbol[0]] ) @@ -697,7 +848,7 @@ def get_random_structure_for_protostructure( sorted chemical system. **kwargs: Keyword arguments to pass to pyxtal().from_random() """ - if pyxtal is None: + if not has_pyxtal: raise ImportError("pyxtal is required for this function") aflow_label, chemsys = protostructure_label.split(":") diff --git a/aviary/wren/wyckoff-position-relabelings.json b/aviary/wren/wyckoff-position-relabelings.json index 99ebd0ad..d0ed532d 100644 --- a/aviary/wren/wyckoff-position-relabelings.json +++ b/aviary/wren/wyckoff-position-relabelings.json @@ -491,6 +491,38 @@ "100": "c", "101": "e", "102": "f" + }, + { + "97": "c", + "98": "d", + "99": "a", + "100": "b", + "101": "e", + "102": "f" + }, + { + "97": "c", + "98": "d", + "99": "b", + "100": "a", + "101": "e", + "102": "f" + }, + { + "97": "d", + "98": "c", + "99": "a", + "100": "b", + "101": "e", + "102": "f" + }, + { + "97": "d", + "98": "c", + "99": "b", + "100": "a", + "101": "e", + "102": "f" } ], "16": [ diff --git a/pyproject.toml b/pyproject.toml index fa6be25e..0f05974e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "aviary" -version = "1.1.0" +version = "1.1.1" description = "A collection of machine learning models for materials discovery" authors = [{ name = "Rhys Goodall", email = "rhys.goodall@outlook.com" }] readme = "README.md" @@ -35,12 +35,12 @@ classifiers = [ requires-python = ">=3.9" dependencies = [ - "numpy<2", + "numpy>=2,<3", "pandas", "pymatgen", "scikit_learn", "tensorboard", - "torch", + "torch>=2.3.0", "tqdm", "wandb", ] @@ -49,8 +49,9 @@ dependencies = [ Repo = "https://github.com/CompRhys/aviary" [project.optional-dependencies] -test = ["matminer", "pytest", "pytest-cov", "pyxtal"] +test = ["matminer", "moyopy>=0.3.3", "pytest", "pytest-cov", "pyxtal"] pyxtal = ["pyxtal"] +moyopy = ["moyopy>=0.3.3"] [tool.setuptools.packages] find = { include = ["aviary*"], exclude = ["tests*"] } @@ -106,16 +107,16 @@ select = [ "YTT", # flake8-2020 ] ignore = [ - "C408", # Unnecessary dict call - rewrite as a literal - "D100", # Missing docstring in public module - "D104", # Missing docstring in public package - "D105", # Missing docstring in magic method - "D205", # 1 blank line required between summary line and description - "E731", # Do not assign a lambda expression, use a def + "C408", # Unnecessary dict call - rewrite as a literal + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic method + "D205", # 1 blank line required between summary line and description + "E731", # Do not assign a lambda expression, use a def "ISC001", - "PD901", # pandas-df-variable-name - "PLR", # pylint refactor - "PT006", # pytest-parametrize-names-wrong-type + "PD901", # pandas-df-variable-name + "PLR", # pylint refactor + "PT006", # pytest-parametrize-names-wrong-type ] pydocstyle.convention = "google" isort.known-third-party = ["wandb"] diff --git a/tests/test_wyckoff_ops.py b/tests/test_wyckoff_ops.py index 5475199d..66ea0cac 100644 --- a/tests/test_wyckoff_ops.py +++ b/tests/test_wyckoff_ops.py @@ -1,10 +1,10 @@ import inspect import re -from itertools import permutations +from itertools import permutations, product from shutil import which import pytest -from pymatgen.core.structure import Composition, Structure +from pymatgen.core.structure import Composition, Lattice, Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from aviary.wren.utils import ( @@ -15,7 +15,9 @@ count_wyckoff_positions, get_anonymous_formula_from_prototype_formula, get_formula_from_protostructure_label, + get_protostructure_label, get_protostructure_label_from_aflow, + get_protostructure_label_from_moyopy, get_protostructure_label_from_spg_analyzer, get_protostructure_label_from_spglib, get_protostructures_from_aflow_label_and_composition, @@ -44,14 +46,71 @@ ("AB3C_cP5_221_a_c_b:Ba-O-Ti"), ] +TEST_STRUCTS = [ + Structure( # NaCl structure + lattice=[[2, 2, 0], [0, 2, 2], [2, 0, 2]], + species=["Na", "Cl"], + coords=[[0, 0, 0], [0.5, 0.5, 0.5]], + ), + Structure( # CsCl structure + lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], + species=["Cs", "Cl"], + coords=[[0, 0, 0], [0.5, 0.5, 0.5]], + ), + Structure( # ZnO zincblende structure + lattice=[[2, 2, 0], [0, 2, 2], [2, 0, 2]], + species=["Zn", "O"], + coords=[[0, 0, 0], [0.25, 0.25, 0.25]], + ), + Structure( # ZnO wurtzite structure + lattice=Lattice.from_parameters( + a=3.8227, b=3.8227, c=6.2607, alpha=90, beta=90, gamma=120 + ), + species=["Zn", "O", "Zn", "O"], + coords=[ + [1 / 3, 2 / 3, 0], + [2 / 3, 1 / 3, 0.3748], + [2 / 3, 1 / 3, 1 / 2], + [1 / 3, 2 / 3, 1 / 2 + 0.3748], + ], + ), + Structure( + lattice=[[3.9, 0, 0], [0, 3.9, 0], [0, 0, 3.9]], + species=["Sr", "Ti", "O", "O", "O"], + coords=[[0, 0, 0], [0.5, 0.5, 0.5], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]], + ), + Structure( + lattice=[[5.76, 0, 0], [0, 5.76, 0], [0, 0, 5.76]], + species=["Al", "Fe", "Fe", "Fe", "Al", "Fe", "Fe", "Fe"], + coords=[ + [0, 0, 0], + [0.25, 0.25, 0.25], + [0.5, 0.5, 0], + [0.75, 0.75, 0.25], + [0, 0.5, 0.5], + [0.25, 0.75, 0.75], + [0.5, 0, 0.5], + [0.75, 0.25, 0.75], + ], + ), + Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif"), +] -def test_get_protostructure_label_from_spglib(): - """Check that spglib gives correct protostructure label for esseneite""" - struct = Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif") - assert ( - get_protostructure_label_from_spglib(struct) - == "ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si" - ) +TEST_PROTOSTRUCTURES = [ + "AB_cF8_225_a_b:Cl-Na", + "AB_cP2_221_a_b:Cl-Cs", + "AB_cF8_216_a_c:O-Zn", + "AB_hP4_186_b_b:O-Zn", + "A3BC_cP5_221_c_a_b:O-Sr-Ti", + "AB3_tP4_115_a_cg:Al-Fe", + "ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si", +] + + +@pytest.mark.parametrize("structure, expected", zip(TEST_STRUCTS, TEST_PROTOSTRUCTURES)) +def test_get_protostructure_label_from_spglib(structure, expected): + """Check that spglib gives correct protostructure label simple cases.""" + assert get_protostructure_label_from_spglib(structure) == expected def test_get_protostructure_label_from_spglib_edge_case(): @@ -278,13 +337,37 @@ def test_count_distinct_wyckoff_letters(protostructure_label, expected): @pytest.mark.skipif(which("aflow") is None, reason="AFLOW CLI not installed") -def test_get_protostructure_label_from_aflow(): - """Check we extract correct protostructure label for esseneite using AFLOW CLI.""" - struct = Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif") +@pytest.mark.parametrize("structure, expected", zip(TEST_STRUCTS, TEST_PROTOSTRUCTURES)) +def test_get_protostructure_label_from_aflow(structure, expected): + """Check that AFLOW CLI gives correct protostructure label simple cases.""" + assert ( + get_protostructure_label_from_aflow(structure, aflow_executable=which("aflow")) + == expected + ) + + +@pytest.mark.parametrize("structure, expected", zip(TEST_STRUCTS, TEST_PROTOSTRUCTURES)) +def test_get_protostructure_label_from_moyopy(structure, expected): + """Check that moyopy gives correct protostructure label simple cases.""" + assert ( + get_protostructure_label_from_moyopy(structure) == expected + ), f"unexpected moyopy protostructure for {structure=}" + - out = get_protostructure_label_from_aflow(struct, which("aflow")) - expected = "ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si" - assert out == expected +@pytest.mark.parametrize( + "protostructure", + PROTOSTRUCTURE_SET, +) +def test_moyopy_spglib_consistency(protostructure): + """Check that moyopy and spglib give consistent results.""" + struct = get_random_structure_for_protostructure(protostructure) + + moyopy_label = get_protostructure_label_from_moyopy(struct) + spglib_label = get_protostructure_label_from_spglib(struct) + + assert ( + moyopy_label == spglib_label + ), f"spglib moyopy protostructure mismatch for {protostructure}" @pytest.mark.skipif(pyxtal is None, reason="pyxtal not installed") @@ -292,13 +375,16 @@ def test_get_protostructure_label_from_aflow(): reason="pyxtal is non-deterministic and symmetry can increase in random crystal" ) @pytest.mark.parametrize( - "protostructure", - PROTOSTRUCTURE_SET, + "protostructure, method", + list(product(PROTOSTRUCTURE_SET, ["spglib", "moyopy"])), ) -def test_get_random_structure_for_protostructure_roundtrip(protostructure): +def test_get_random_structure_for_protostructure_roundtrip( + protostructure: str, method: str +): """Check roundtrip for generating a random structure from a prototype string""" - assert protostructure == get_protostructure_label_from_spglib( - get_random_structure_for_protostructure(protostructure) + assert protostructure == get_protostructure_label( + get_random_structure_for_protostructure(protostructure), + method=method, ) @@ -314,3 +400,9 @@ def test_get_random_structure_for_protostructure_random(protostructure): assert s1.composition == s2.composition assert s1.lattice != s2.lattice + + +if __name__ == "__main__": + import pytest + + pytest.main(["-v", __file__])