From b7655b5d614ad08cfedff9fea127b40b4b89a07e Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 22 Jul 2024 14:46:32 -0400 Subject: [PATCH] Adopt protostructure naming (#84) * doc: adopt protostructure naming * lint: sort imports in notebooks * clean: rename last functions * fix outdated import pymatviz.(utils->powerups).add_identity_line * ruff auto-fixes * fix ruff aviary/utils.py:732:5: PLC0206 Extracting value from dictionary without calling `.items()` and aviary/roost/data.py:116:13: PLR1704 Redefining argument with the local name `idx` * fix save_results_dict doc string * fea: bump version for breaking change * fea: rename parse function used in wren data --------- Co-authored-by: Janosh Riebesell --- .pre-commit-config.yaml | 4 +- aviary/roost/data.py | 4 +- aviary/segments.py | 8 +- aviary/train.py | 13 +- aviary/utils.py | 54 +-- aviary/wren/data.py | 22 +- aviary/wren/utils.py | 451 ++++++++++-------- aviary/wrenformer/data.py | 12 +- examples/cgcnn-example.py | 33 +- examples/inputs/poscar_to_df.py | 7 +- examples/notebooks/Roost.ipynb | 9 +- examples/notebooks/Wren.ipynb | 7 +- examples/roost-example.py | 33 +- examples/wren-example.py | 31 +- .../compare_spglib_vs_aflow_wyckoff_labels.py | 22 +- examples/wrenformer/mat_bench/make_plots.py | 2 +- .../mat_bench/plotting_functions.py | 2 +- .../mat_bench/save_matbench_aflow_labels.py | 4 +- pyproject.toml | 4 +- tests/conftest.py | 6 +- tests/test_wyckoff_ops.py | 133 ++++-- 21 files changed, 467 insertions(+), 394 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8c3d5656..6cda0b14 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.0 + rev: v0.5.3 hooks: - id: ruff args: [--fix] @@ -30,7 +30,7 @@ repos: args: [--check-filenames] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.10.1 + rev: v1.11.0 hooks: - id: mypy exclude: (tests|examples)/ diff --git a/aviary/roost/data.py b/aviary/roost/data.py index 02b54d65..07af42c7 100644 --- a/aviary/roost/data.py +++ b/aviary/roost/data.py @@ -113,8 +113,8 @@ def __getitem__(self, idx: int): n_elems = len(elements) self_idx = [] nbr_idx = [] - for idx in range(n_elems): - self_idx += [idx] * n_elems + for elem_idx in range(n_elems): + self_idx += [elem_idx] * n_elems nbr_idx += list(range(n_elems)) # convert all data to tensors diff --git a/aviary/segments.py b/aviary/segments.py index e6eb5593..66c78190 100644 --- a/aviary/segments.py +++ b/aviary/segments.py @@ -38,9 +38,9 @@ def forward(self, x: Tensor, index: Tensor) -> Tensor: """ gate = self.gate_nn(x) - gate = gate - scatter_max(gate, index, dim=0)[0][index] + gate -= scatter_max(gate, index, dim=0)[0][index] gate = gate.exp() - gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-10) + gate /= scatter_add(gate, index, dim=0)[index] + 1e-10 x = self.message_nn(x) return scatter_add(gate * x, index, dim=0) @@ -78,9 +78,9 @@ def forward(self, x: Tensor, index: Tensor, weights: Tensor) -> Tensor: """ gate = self.gate_nn(x) - gate = gate - scatter_max(gate, index, dim=0)[0][index] + gate -= scatter_max(gate, index, dim=0)[0][index] gate = (weights**self.pow) * gate.exp() - gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-10) + gate /= scatter_add(gate, index, dim=0)[index] + 1e-10 x = self.message_nn(x) return scatter_add(gate * x, index, dim=0) diff --git a/aviary/train.py b/aviary/train.py index 7500b6d8..2b6423e4 100644 --- a/aviary/train.py +++ b/aviary/train.py @@ -246,14 +246,13 @@ def train_model( print("Starting stochastic weight averaging...") swa_model.update_parameters(model) swa_scheduler.step() + elif scheduler_name == "ReduceLROnPlateau": + val_metric = val_metrics[target_col][ + "MAE" if task_type == reg_key else "Accuracy" + ] + lr_scheduler.step(val_metric) else: - if scheduler_name == "ReduceLROnPlateau": - val_metric = val_metrics[target_col][ - "MAE" if task_type == reg_key else "Accuracy" - ] - lr_scheduler.step(val_metric) - else: - lr_scheduler.step() + lr_scheduler.step() model.epoch += 1 diff --git a/aviary/utils.py b/aviary/utils.py index 4b91fdb0..4c1e22e6 100644 --- a/aviary/utils.py +++ b/aviary/utils.py @@ -237,15 +237,12 @@ def initialize_losses( raise NameError( "Only L1 or L2 losses are allowed for robust regression tasks" ) + elif loss_name_dict[name] == "L1": + loss_func_dict[name] = (task, L1Loss()) + elif loss_name_dict[name] == "L2": + loss_func_dict[name] = (task, MSELoss()) else: - if loss_name_dict[name] == "L1": - loss_func_dict[name] = (task, L1Loss()) - elif loss_name_dict[name] == "L2": - loss_func_dict[name] = (task, MSELoss()) - else: - raise NameError( - "Only L1 or L2 losses are allowed for regression tasks" - ) + raise NameError("Only L1 or L2 losses are allowed for regression tasks") return loss_func_dict @@ -723,46 +720,35 @@ def save_results_dict( """Save the results to a file after model evaluation. Args: - ids (dict[str, list[str | int]]): ): Each key is the name of an identifier + ids (dict[str, list[str | int]]): Each key is the name of an identifier (e.g. material ID, composition, ...) and its value a list of IDs. - results_dict (dict[str, Any]): ): nested dictionary of results - {name: {col: data}} - model_name (str): ): The name given the model via the --model-name flag. - run_id (str): ): The run ID given to the model via the --run-id flag. + results_dict (dict[str, Any]): nested dictionary of results {name: {col: data}} + model_name (str): The name given the model via the --model-name flag. + run_id (str): The run ID given to the model via the --run-id flag. """ - results = {} + results: dict[str, np.ndarray] = {} - for target_name in results_dict: - for col, data in results_dict[target_name].items(): + for target_name, target_data in results_dict.items(): + for col, data in target_data.items(): # NOTE we save pre_logits rather than logits due to fact # that with the heteroskedastic setup we want to be able to # sample from the Gaussian distributed pre_logits we parameterize. if "pre-logits" in col: for n_ens, y_pre_logit in enumerate(data): - results.update( - { - f"{target_name}_{col}_c{lab}_n{n_ens}": val.ravel() - for lab, val in enumerate(y_pre_logit.T) - } - ) + results |= { + f"{target_name}_{col}_c{lab}_n{n_ens}": val.ravel() + for lab, val in enumerate(y_pre_logit.T) + } - elif "pred" in col: - preds = { + elif "pred" in col or "ale" in col: + # elif so that pre-logit-ale doesn't trigger + results |= { f"{target_name}_{col}_n{n_ens}": val.ravel() for (n_ens, val) in enumerate(data) } - results.update(preds) - - elif "ale" in col: # elif so that pre-logit-ale doesn't trigger - results.update( - { - f"{target_name}_{col}_n{n_ens}": val.ravel() - for (n_ens, val) in enumerate(data) - } - ) elif col == "target": - results.update({f"{target_name}_target": data}) + results |= {f"{target_name}_target": data} df = pd.DataFrame({**ids, **results}) diff --git a/aviary/wren/data.py b/aviary/wren/data.py index 8cc629ca..5d1b4b5f 100644 --- a/aviary/wren/data.py +++ b/aviary/wren/data.py @@ -108,10 +108,10 @@ def __getitem__(self, idx: int): - list[str | int]: identifiers like material_id, composition """ row = self.df.iloc[idx] - wyckoff_str = row[self.inputs] + protostructure_label = row[self.inputs] material_ids = row[self.identifiers].to_list() - parsed_output = parse_aflow_wyckoff_str(wyckoff_str) + parsed_output = parse_protostructure_label(protostructure_label) spg_num, wyk_site_multiplcities, elements, augmented_wyks = parsed_output wyk_site_multiplcities = np.atleast_2d(wyk_site_multiplcities).T / np.sum( @@ -256,21 +256,29 @@ def collate_batch( ) -def parse_aflow_wyckoff_str( - aflow_label: str, +def parse_protostructure_label( + protostructure_label: str, ) -> tuple[str, list[float], list[str], list[tuple[str, ...]]]: """Parse the Wren AFLOW-like Wyckoff encoding. Args: - aflow_label (str): AFLOW-style prototype string with appended chemical system + protostructure_label (str): label constructed as `aflow_label:chemsys` where + aflow_label is an AFLOW-style prototype label chemsys is the alphabetically + sorted chemical system. Returns: tuple[str, list[float], list[str], list[str]]: spacegroup number, Wyckoff site multiplicities, elements symbols and equivalent wyckoff sets """ - proto, chemsys = aflow_label.split(":") + aflow_label, chemsys = protostructure_label.split(":") elems = chemsys.split("-") - _, _, spg_num, *wyckoff_letters = proto.split("_") + _, _, spg_num, *wyckoff_letters = aflow_label.split("_") + + if len(elems) != len(wyckoff_letters): + raise ValueError( + f"Chemical system {chemsys} does not match Wyckoff letters " + f"{wyckoff_letters}" + ) wyckoff_site_multiplicities = [] elements = [] diff --git a/aviary/wren/utils.py b/aviary/wren/utils.py index 8fe4a21e..4baae435 100644 --- a/aviary/wren/utils.py +++ b/aviary/wren/utils.py @@ -32,8 +32,8 @@ relab_dict = json.load(file) relab_dict = { - spg: [{int(key): line for key, line in val.items()} for val in vals] - for spg, vals in relab_dict.items() + spg_num: [{int(key): line for key, line in val.items()} for val in vals] + for spg_num, vals in relab_dict.items() } cry_sys_dict = { @@ -84,23 +84,27 @@ def split_alpha_numeric(s: str) -> dict[str, list[str]]: def count_values_for_wyckoff( - wyckoff: list[str], - multiplicity: list[str], - spg: str, + element_wyckoffs: list[str], + counts: list[str], + spg_num: str, lookup_dict: dict[str, dict[str, int]], ): """Count values from a lookup table and scale by wyckoff multiplicities.""" - return sum(int(n) * lookup_dict[spg][k] for n, k in zip(multiplicity, wyckoff)) + return sum( + int(count) * lookup_dict[spg_num][wyckoff_letter] + for count, wyckoff_letter in zip(counts, element_wyckoffs) + ) -def get_aflow_label_from_aflow( +def get_protostructure_label_from_aflow( struct: Structure, aflow_executable: str | None = None, raise_errors: bool = False, ) -> str: - """Get Aflow prototype label for a pymatgen Structure. Make sure you're running a + """Get protostructure label for a pymatgen Structure. Make sure you're running a recent version of the aflow CLI as there's been several breaking changes. This code - was tested under v3.2.12. + was tested under v3.2.12. The protostructure label is constructed as + `aflow_label:chemsys`. Install guide: https://aflow.org/install-aflow/#install_aflow http://aflow.org/install-aflow/install-aflow.sh -o install-aflow.sh @@ -114,8 +118,9 @@ def get_aflow_label_from_aflow( False. Returns: - str: AFLOW prototype label or explanation of failure if symmetry detection - failed and raise_errors is False. + str: protostructure_label which is constructed as `aflow_label:chemsys` or + explanation of failure if symmetry detection failed and `raise_errors` + is False. """ if aflow_executable is None: aflow_executable = which("aflow") @@ -139,30 +144,36 @@ def get_aflow_label_from_aflow( aflow_proto = json.loads(output.stdout) aflow_label = aflow_proto["aflow_prototype_label"] - chem_sys = struct.composition.chemical_system - full_label = f"{aflow_label}:{chem_sys}" - + chemsys = struct.composition.chemical_system # check that multiplicities satisfy original composition - _, _, spg_num, *wyckoff_letters = aflow_label.split("_") - elem_dict = {} - for elem, wyk_letters_per_elem in zip(chem_sys.split("-"), wyckoff_letters): + prototype_form, pearson_symbol, spg_num, *element_wyckoffs = aflow_label.split("_") + + element_dict = {} + for elem, wyk_letters_per_elem in zip(chemsys.split("-"), element_wyckoffs): # normalize Wyckoff letters to start with 1 if missing digit wyk_letters_normalized = re.sub( RE_WYCKOFF_NO_PREFIX, RE_SUBST_ONE_PREFIX, wyk_letters_per_elem ) sep_el_wyks = split_alpha_numeric(wyk_letters_normalized) - elem_dict[elem] = count_values_for_wyckoff( + element_dict[elem] = count_values_for_wyckoff( sep_el_wyks["alpha"], sep_el_wyks["numeric"], spg_num, wyckoff_multiplicity_dict, ) - observed_formula = Composition(elem_dict).reduced_formula + element_wyckoffs = "_".join(element_wyckoffs) + element_wyckoffs = canonicalize_element_wyckoffs(element_wyckoffs, spg_num) + + protostructure_label = ( + f"{prototype_form}_{pearson_symbol}_{spg_num}_{element_wyckoffs}:{chemsys}" + ) + + observed_formula = Composition(element_dict).reduced_formula expected_formula = struct.composition.reduced_formula if observed_formula != expected_formula: err_msg = ( - f"Invalid WP multiplicities - {full_label}, expected " + f"Invalid WP multiplicities - {protostructure_label}, expected " f"{observed_formula} to be {expected_formula}" ) if raise_errors: @@ -170,14 +181,14 @@ def get_aflow_label_from_aflow( return err_msg - return full_label + return protostructure_label -def get_aflow_label_from_spg_analyzer( +def get_protostructure_label_from_spg_analyzer( spg_analyzer: SpacegroupAnalyzer, raise_errors: bool = False, ) -> str: - """Get AFLOW prototype label for pymatgen SpacegroupAnalyzer. + """Get protostructure label for pymatgen SpacegroupAnalyzer. Args: spg_analyzer (SpacegroupAnalyzer): pymatgen SpacegroupAnalyzer object. @@ -185,38 +196,40 @@ def get_aflow_label_from_spg_analyzer( False. Returns: - str: AFLOW prototype label or explanation of failure if symmetry detection - failed and raise_errors is False. + str: protostructure_label which is constructed as `aflow_label:chemsys` or + explanation of failure if symmetry detection failed and `raise_errors` + is False. """ spg_num = spg_analyzer.get_space_group_number() sym_struct = spg_analyzer.get_symmetrized_structure() equivalent_wyckoff_labels = [ + # tuple of (wp multiplicity, element, wyckoff letter) (len(s), s[0].species_string, wyk_letter.translate(remove_digits)) for s, wyk_letter in zip( sym_struct.equivalent_sites, sym_struct.wyckoff_symbols ) ] + # Pre-sort by element and wyckoff letter to ensure continuous groups in groupby equivalent_wyckoff_labels = sorted( equivalent_wyckoff_labels, key=lambda x: (x[1], x[2]) ) # check that multiplicities satisfy original composition - elem_dict = {} - elem_wyks = [] - for el, g in groupby( - equivalent_wyckoff_labels, key=lambda x: x[1] - ): # sort alphabetically by element - lg = list(g) # NOTE create a list from the iterator so that we can reuse it - elem_dict[el] = sum(wyckoff_multiplicity_dict[str(spg_num)][e[2]] for e in lg) - wyks = "" - # sort groups alphabetically by wyckoff letter - for wyk, w in groupby(lg, key=lambda x: x[2]): - wyks += f"{len(list(w))}{wyk}" - elem_wyks.append(wyks) - - # canonicalize the possible wyckoff letter sequences - canonical = canonicalize_elem_wyks("_".join(elem_wyks), spg_num) + element_dict = {} + element_wyckoffs = [] + for el, group in groupby(equivalent_wyckoff_labels, key=lambda x: x[1]): + # NOTE create a list from the iterator so that we can use it without exhausting + list_group = list(group) + element_dict[el] = sum( + wyckoff_multiplicity_dict[str(spg_num)][e[2]] for e in list_group + ) + element_wyckoffs.append( + "".join( + f"{len(list(w))}{wyk}" + for wyk, w in groupby(list_group, key=lambda x: x[2]) + ) + ) # get Pearson symbol cry_sys = spg_analyzer.get_crystal_system() @@ -225,16 +238,21 @@ def get_aflow_label_from_spg_analyzer( num_sites_conventional = len(spg_analyzer.get_symmetry_dataset()["std_types"]) pearson_symbol = f"{cry_sys_dict[cry_sys]}{centering}{num_sites_conventional}" - prototype_form = prototype_formula(sym_struct.composition) + prototype_form = get_prototype_formula_from_composition(sym_struct.composition) + chemsys = sym_struct.composition.chemical_system + + all_wyckoffs = "_".join(element_wyckoffs) + all_wyckoffs = canonicalize_element_wyckoffs(all_wyckoffs, spg_num) - chem_sys = sym_struct.composition.chemical_system - full_label = f"{prototype_form}_{pearson_symbol}_{spg_num}_{canonical}:{chem_sys}" + protostructure_label = ( + f"{prototype_form}_{pearson_symbol}_{spg_num}_{all_wyckoffs}:{chemsys}" + ) - observed_formula = Composition(elem_dict).reduced_formula + observed_formula = Composition(element_dict).reduced_formula expected_formula = sym_struct.composition.reduced_formula if observed_formula != expected_formula: err_msg = ( - f"Invalid WP multiplicities - {full_label}, expected " + f"Invalid WP multiplicities - {protostructure_label}, expected " f"{observed_formula} to be {expected_formula}" ) if raise_errors: @@ -242,10 +260,10 @@ def get_aflow_label_from_spg_analyzer( return err_msg - return full_label + return protostructure_label -def get_aflow_label_from_spglib( +def get_protostructure_label_from_spglib( struct: Structure, raise_errors: bool = False, init_symprec: float = 0.1, @@ -262,8 +280,9 @@ def get_aflow_label_from_spglib( symmetry detection failed. Defaults to 1e-5. Returns: - str: AFLOW prototype label or explanation of failure if symmetry detection - failed and raise_errors is False. + str: protostructure_label which is constructed as `aflow_label:chemsys` or + explanation of failure if symmetry detection failed and `raise_errors` + is False. """ attempt_to_recover = False try: @@ -271,7 +290,7 @@ def get_aflow_label_from_spglib( struct, symprec=init_symprec, angle_tolerance=5 ) try: - aflow_label_with_chemsys = get_aflow_label_from_spg_analyzer( + aflow_label_with_chemsys = get_protostructure_label_from_spg_analyzer( spg_analyzer, raise_errors ) @@ -290,7 +309,7 @@ def get_aflow_label_from_spglib( symprec=fallback_symprec, angle_tolerance=-1, ) - aflow_label_with_chemsys = get_aflow_label_from_spg_analyzer( + aflow_label_with_chemsys = get_protostructure_label_from_spg_analyzer( spg_analyzer, raise_errors ) return aflow_label_with_chemsys @@ -301,40 +320,39 @@ def get_aflow_label_from_spglib( raise -def canonicalize_elem_wyks(elem_wyks: str, spg_num: int | str) -> str: +def canonicalize_element_wyckoffs(element_wyckoffs: str, spg_num: int | str) -> str: """Given an element ordering, canonicalize the associated Wyckoff positions based on the alphabetical weight of equivalent choices of origin. Args: - elem_wyks (str): Wren Wyckoff string encoding element types at Wyckoff positions + element_wyckoffs (str): wyckoff substring section from aflow_label with the + wyckoff letters for different elements separated by underscores. spg_num (int | str): International space group number. Returns: - str: Canonicalized Wren Wyckoff encoding. + str: element_wyckoff string with canonical ordering of the wyckoff letters. """ - isopointal = [] - - for trans in relab_dict[str(spg_num)]: - t = str.maketrans(trans) - isopointal.append(elem_wyks.translate(t)) - - isopointal = list(set(isopointal)) + isopointal_element_wyckoffs = list( + { + element_wyckoffs.translate(str.maketrans(trans)) + for trans in relab_dict[str(spg_num)] + } + ) - scores = [] - sorted_iso = [] - for wyks in isopointal: - sorted_el_wyks, score = sort_and_score_wyks(wyks) - scores.append(score) - sorted_iso.append(sorted_el_wyks) + scored_element_wyckoffs = [ + sort_and_score_element_wyckoffs(element_wyckoffs) + for element_wyckoffs in isopointal_element_wyckoffs + ] - return sorted(zip(scores, sorted_iso), key=lambda x: (x[0], x[1]))[0][1] + return min(scored_element_wyckoffs, key=lambda x: (x[1], x[0]))[0] -def sort_and_score_wyks(wyks: str) -> tuple[str, int]: - """Determines the order or Wyckoff positions when canonicalizing Aflow labels. +def sort_and_score_element_wyckoffs(element_wyckoffs: str) -> tuple[str, int]: + """Determines the order or Wyckoff positions when canonicalizing AFLOW labels. Args: - wyks (str): Wyckoff position substring from AFLOW-style prototype label + element_wyckoffs (str): wyckoff substring section from aflow_label with the + wyckoff letters for different elements separated by underscores. Returns: tuple: containing @@ -342,26 +360,29 @@ def sort_and_score_wyks(wyks: str) -> tuple[str, int]: - int: integer score to rank order when canonicalizing """ score = 0 - sorted_el_wyks = [] - for el_wyks in wyks.split("_"): - sep_el_wyks = split_alpha_numeric(el_wyks) - sorted_el_wyks.append( + sorted_element_wyckoffs = [] + for el_wyks in element_wyckoffs.split("_"): + wp_counts = split_alpha_numeric(el_wyks) + sorted_element_wyckoffs.append( "".join( [ - f"{mult}{wyk}" if mult != "1" else wyk - for mult, wyk in sorted( - zip(sep_el_wyks["numeric"], sep_el_wyks["alpha"]), + f"{count}{wyckoff_letter}" if count != "1" else wyckoff_letter + for count, wyckoff_letter in sorted( + zip(wp_counts["numeric"], wp_counts["alpha"]), key=lambda x: x[1], ) ] ) ) - score += sum(0 if el == "A" else ord(el) - 96 for el in sep_el_wyks["alpha"]) + score += sum( + 0 if wyckoff_letter == "A" else ord(wyckoff_letter) - 96 + for wyckoff_letter in wp_counts["alpha"] + ) - return "_".join(sorted_el_wyks), score + return "_".join(sorted_element_wyckoffs), score -def prototype_formula(composition: Composition) -> str: +def get_prototype_formula_from_composition(composition: Composition) -> str: """An anonymized formula. Unique species are arranged in alphabetical order and assigned ascending alphabets. This format is used in the aflow structure prototype labelling scheme. @@ -390,7 +411,7 @@ def prototype_formula(composition: Composition) -> str: return anon -def get_anom_formula_from_prototype_formula(prototype_formula: str) -> str: +def get_anonymous_formula_from_prototype_formula(prototype_formula: str) -> str: """Get an anonymous formula from a prototype formula.""" prototype_formula = re.sub( RE_ELEMENT_NO_SUFFIX, RE_SUBST_ONE_SUFFIX, prototype_formula @@ -399,25 +420,44 @@ def get_anom_formula_from_prototype_formula(prototype_formula: str) -> str: return "".join( [ - f"{el}{num}" if num != "1" else el + f"{el}{num}" if num != 1 else el for el, num in zip( anom_list["alpha"], - sorted(anom_list["numeric"]), + sorted(map(int, anom_list["numeric"])), ) ] ) -def count_wyckoff_positions(aflow_label: str) -> int: - """Count number of Wyckoff positions in Wyckoff representation. +def count_distinct_wyckoff_letters(protostructure_label: str) -> int: + """Count number of distinct Wyckoff letters in protostructure_label. Args: - aflow_label (str): AFLOW-style prototype label with appended chemical system + protostructure_label (str): label constructed as `aflow_label:chemsys` where + aflow_label is an AFLOW-style prototype label chemsys is the alphabetically + sorted chemical system. Returns: - int: number of distinct Wyckoff positions + int: number of distinct Wyckoff letters in protostructure_label """ - aflow_label, _ = aflow_label.split(":") # remove chemical system + aflow_label, _ = protostructure_label.split(":") + _, _, _, element_wyckoffs = aflow_label.split("_", 3) + element_wyckoffs = element_wyckoffs.translate(remove_digits).replace("_", "") + return len(set(element_wyckoffs)) # number of distinct Wyckoff letters + + +def count_wyckoff_positions(protostructure_label: str) -> int: + """Count number of Wyckoff positions in protostructure_label. + + Args: + protostructure_label (str): label constructed as `aflow_label:chemsys` where + aflow_label is an AFLOW-style prototype label chemsys is the alphabetically + sorted chemical system. + + Returns: + int: number of distinct Wyckoff positions in protostructure_label + """ + aflow_label, _ = protostructure_label.split(":") # remove chemical system # discard prototype formula and spg symbol and spg number wyk_letters = aflow_label.split("_", maxsplit=3)[-1] # throw Wyckoff positions for all elements together @@ -428,136 +468,131 @@ def count_wyckoff_positions(aflow_label: str) -> int: return sum(1 if len(x) == 0 else int(x) for x in wyk_list) -def count_crystal_dof(aflow_label: str) -> int: - """Count number of free parameters in coarse-grained Wyckoff representation: how - many degrees of freedom would remain to optimize during a crystal structure - relaxation. +def count_crystal_dof(protostructure_label: str) -> int: + """Count number of free parameters in coarse-grained protostructure_label + representation: how many degrees of freedom would remain to optimize during + a crystal structure relaxation. Args: - aflow_label (str): AFLOW-style prototype label with appended chemical system + protostructure_label (str): label constructed as `aflow_label:chemsys` where + aflow_label is an AFLOW-style prototype label chemsys is the alphabetically + sorted chemical system. Returns: int: Number of free-parameters in given prototype """ - n_params = 0 + aflow_label, _ = protostructure_label.split(":") # chop off chemical system + _, pearson_symbol, spg_num, *element_wyckoffs = aflow_label.split("_") - aflow_label, _ = aflow_label.split(":") # chop off chemical system - _, pearson, spg, *wyks = aflow_label.split("_") - - n_params += cry_param_dict[pearson[0]] - - for wyk_letters_per_elem in wyks: - # normalize Wyckoff letters to start with 1 if missing digit - wyk_letters_normalized = re.sub( - RE_WYCKOFF_NO_PREFIX, RE_SUBST_ONE_PREFIX, wyk_letters_per_elem - ) - sep_el_wyks = split_alpha_numeric(wyk_letters_normalized) - n_params += count_values_for_wyckoff( - sep_el_wyks["alpha"], - sep_el_wyks["numeric"], - spg, - param_dict, - ) - - return n_params + return ( + _count_from_dict(element_wyckoffs, param_dict, spg_num) + + cry_param_dict[pearson_symbol[0]] + ) -def count_crystal_sites(aflow_label: str) -> int: - """Count number of sites from Wyckoff representation. +def count_crystal_sites(protostructure_label: str) -> int: + """Count number of sites from protostructure_label. Args: - aflow_label (str): AFLOW-style prototype label with appended chemical system + protostructure_label (str): label constructed as `aflow_label:chemsys` where + aflow_label is an AFLOW-style prototype label chemsys is the alphabetically + sorted chemical system. Returns: int: Number of free-parameters in given prototype """ - n_params = 0 + aflow_label, _ = protostructure_label.split(":") # chop off chemical system + _, _, spg_num, *element_wyckoffs = aflow_label.split("_") + + return _count_from_dict(element_wyckoffs, wyckoff_multiplicity_dict, spg_num) + - aflow_label, _ = aflow_label.split(":") # chop off chemical system - _, pearson, spg, *wyks = aflow_label.split("_") +def _count_from_dict( + element_wyckoffs: list[str], lookup_dict: dict, spg_num: str +) -> int: + """Count number of sites from protostructure_label.""" + n_params = 0 - for wyk_letters_per_elem in wyks: + for wyckoffs in element_wyckoffs: # normalize Wyckoff letters to start with 1 if missing digit - wyk_letters_normalized = re.sub( - RE_WYCKOFF_NO_PREFIX, RE_SUBST_ONE_PREFIX, wyk_letters_per_elem + sep_el_wyks = split_alpha_numeric( + re.sub(RE_WYCKOFF_NO_PREFIX, RE_SUBST_ONE_PREFIX, wyckoffs) ) - sep_el_wyks = split_alpha_numeric(wyk_letters_normalized) n_params += count_values_for_wyckoff( sep_el_wyks["alpha"], sep_el_wyks["numeric"], - spg, - wyckoff_multiplicity_dict, + spg_num, + lookup_dict, ) return int(n_params) -def get_isopointal_proto_from_aflow(aflow_label: str) -> str: - """Get a canonicalized string for the prototype. +def get_prototype_from_protostructure(protostructure_label: str) -> str: + """Get a canonicalized string for the prototype. This prototype should be + the same for all isopointal protostructures. Args: - aflow_label (str): AFLOW-style prototype label with appended chemical system + protostructure_label (str): label constructed as `aflow_label:chemsys` where + aflow_label is an AFLOW-style prototype label chemsys is the alphabetically + sorted chemical system. Returns: - str: Canonicalized AFLOW-style prototype label with appended chemical system + str: Canonicalized AFLOW-style prototype label """ - aflow_label, _ = aflow_label.split(":") - anonymous_formula, pearson, spg, *wyckoffs = aflow_label.split("_") - - anonymous_formula = re.sub( - RE_ELEMENT_NO_SUFFIX, RE_SUBST_ONE_SUFFIX, anonymous_formula - ) - anom_list = split_alpha_numeric(anonymous_formula) - counts = [int(x) for x in anom_list["numeric"]] - dummy_els = anom_list["alpha"] - - s_counts, s_wyks_tup = list(zip(*sorted(zip(counts, wyckoffs)))) - s_wyks = re.sub(RE_WYCKOFF_NO_PREFIX, RE_SUBST_ONE_PREFIX, "_".join(s_wyks_tup)) - c_anom = "".join( - [f"{el}{num}" if num != 1 else el for el, num in zip(dummy_els, s_counts)] + aflow_label, _ = protostructure_label.split(":") + prototype_formula, pearson_symbol, spg_num, *element_wyckoffs = aflow_label.split( + "_" ) - if len(s_counts) == len(set(s_counts)): - cs_wyks = canonicalize_elem_wyks(s_wyks, int(spg)) - return f"{c_anom}_{pearson}_{spg}_{cs_wyks}" + anonymous_formula = get_anonymous_formula_from_prototype_formula(prototype_formula) + counts = [ + int(x) + for x in split_alpha_numeric( + re.sub(RE_ELEMENT_NO_SUFFIX, RE_SUBST_ONE_SUFFIX, prototype_formula) + )["numeric"] + ] + + # map to list to avoid mypy error, zip returns tuples. + counts, element_wyckoffs = map(list, zip(*sorted(zip(counts, element_wyckoffs)))) + all_wyckoffs = "_".join(element_wyckoffs) + all_wyckoffs = re.sub(RE_WYCKOFF_NO_PREFIX, RE_SUBST_ONE_PREFIX, all_wyckoffs) + if len(counts) == len(set(counts)): + all_wyckoffs = canonicalize_element_wyckoffs(all_wyckoffs, int(spg_num)) + return f"{anonymous_formula}_{pearson_symbol}_{spg_num}_{all_wyckoffs}" # credit Stef: https://stackoverflow.com/a/70126643/5517459 - valid_permutations = [ - list(map(itemgetter(1), chain.from_iterable(p))) + all_wyckoffs_permutations = [ + "_".join(list(map(itemgetter(1), chain.from_iterable(p)))) for p in product( *[ permutations(g) for _, g in groupby( - sorted(zip(s_counts, s_wyks.split("_"))), key=lambda x: x[0] + sorted(zip(counts, all_wyckoffs.split("_"))), key=lambda x: x[0] ) ] ) ] - isopointal: list[str] = [] - - for wyks_list in valid_permutations: - for trans in relab_dict[spg]: - t = str.maketrans(trans) - isopointal.append("_".join(wyks_list).translate(t)) - - isopointal = list(set(isopointal)) + isopointal_all_wyckoffs = list( + { + all_wyckoffs.translate(str.maketrans(trans)) + for all_wyckoffs in all_wyckoffs_permutations + for trans in relab_dict[spg_num] + } + ) - scores = [] - sorted_iso = [] - for wyks in isopointal: - sorted_el_wyks, score = sort_and_score_wyks(wyks) - scores.append(score) - sorted_iso.append(sorted_el_wyks) + scored_all_wyckoffs = [ + sort_and_score_element_wyckoffs(element_wyckoffs) + for element_wyckoffs in isopointal_all_wyckoffs + ] - canonical = sorted(zip(scores, sorted_iso), key=lambda x: (x[0], x[1])) + all_wyckoffs = min(scored_all_wyckoffs, key=lambda x: (x[1], x[0]))[0] - # TODO: how to tie break when the scores are the same? - # currently done by alphabetical - return "_".join((c_anom, pearson, spg, canonical[0][1])) + return f"{anonymous_formula}_{pearson_symbol}_{spg_num}_{all_wyckoffs}" -def _get_anom_formula_dict(anonymous_formula: str) -> dict: +def _get_anonymous_formula_dict(anonymous_formula: str) -> dict: """Get a dictionary of element to count from an anonymous formula.""" result: defaultdict = defaultdict(int) element = "" @@ -609,88 +644,88 @@ def backtrack(translation, index): return backtrack({}, 0) -def get_aflow_strs_from_iso_and_composition( - isopointal_proto: str, composition: Composition +def get_protostructures_from_aflow_label_and_composition( + aflow_label: str, composition: Composition ) -> list[str]: """Get a canonicalized string for the prototype. Args: - isopointal_proto (str): AFLOW-style Canonicalized prototype label + aflow_label (str): AFLOW-style prototype label composition (Composition): pymatgen Composition object Returns: - list[str]: List of possible AFLOW-style prototype labels with appended - chemical systems that can be generated from combinations of the - input isopointal_proto and composition. + list[str]: List of possible protostructure labels that can be generated + from combinations of the input aflow_label and composition. """ - if not isinstance(isopointal_proto, str): - raise TypeError( - f"Invalid isopointal_proto: {isopointal_proto} ({type(isopointal_proto)})" - ) - - anonymous_formula, pearson, spg, *wyckoffs = isopointal_proto.split("_") + anonymous_formula, pearson_symbol, spg_num, *element_wyckoffs = aflow_label.split( + "_" + ) ele_amt_dict = composition.get_el_amt_dict() - proto_formula = prototype_formula(composition) - anom_amt_dict = _get_anom_formula_dict(anonymous_formula) + proto_formula = get_prototype_formula_from_composition(composition) + anom_amt_dict = _get_anonymous_formula_dict(anonymous_formula) translations = _find_translations(ele_amt_dict, anom_amt_dict) - anom_ele_to_wyk = dict(zip(anom_amt_dict.keys(), wyckoffs)) + anom_ele_to_wyk = dict(zip(anom_amt_dict.keys(), element_wyckoffs)) anonymous_formula = RE_ANONYMOUS.sub(RE_SUBST_ONE_PREFIX, anonymous_formula) - result = set() + protostructures = set() for t in translations: wyckoff_part = "_".join( RE_WYCKOFF.sub(RE_SUBST_ONE_PREFIX, anom_ele_to_wyk[t[elem]]) for elem in sorted(t.keys()) ) - canonicalized_wyckoff = canonicalize_elem_wyks(wyckoff_part, spg) + canonicalized_wyckoff = canonicalize_element_wyckoffs(wyckoff_part, spg_num) chemical_system = "-".join(sorted(t.keys())) - aflow_str = ( - f"{proto_formula}_{pearson}_{spg}_{canonicalized_wyckoff}:{chemical_system}" + protostructures.add( + f"{proto_formula}_{pearson_symbol}_{spg_num}_{canonicalized_wyckoff}:{chemical_system}" ) - result.add(aflow_str) - - return list(result) - -def count_distinct_wyckoff_letters(aflow_str: str) -> int: - """Count number of distinct Wyckoff letters in Wyckoff representation.""" - aflow_str, _ = aflow_str.split(":") # drop chemical system - _, _, _, wyckoff_letters = aflow_str.split("_", 3) # drop prototype, Pearson, spg - wyckoff_letters = wyckoff_letters.translate(remove_digits).replace("_", "") - return len(set(wyckoff_letters)) # number of distinct Wyckoff letters + return list(protostructures) -def get_random_structure_for_protostructure(protostructure: str, **kwargs) -> Structure: +def get_random_structure_for_protostructure( + protostructure_label: str, **kwargs +) -> Structure: """Generate a random structure for a given prototype structure. NOTE that due to the random nature of the generation, the output structure may be higher symmetry than the requested prototype structure. + + Args: + protostructure_label (str): label constructed as `aflow_label:chemsys` where + aflow_label is an AFLOW-style prototype label chemsys is the alphabetically + sorted chemical system. + **kwargs: Keyword arguments to pass to pyxtal().from_random() """ if pyxtal is None: raise ImportError("pyxtal is required for this function") - aflow_label, chemsys = protostructure.split(":") - _, _, spg, *wyckoffs = aflow_label.split("_") + aflow_label, chemsys = protostructure_label.split(":") + _, _, spg_num, *element_wyckoffs = aflow_label.split("_") - wyckoffs = [re.sub(RE_WYCKOFF_NO_PREFIX, RE_SUBST_ONE_PREFIX, w) for w in wyckoffs] - sep_el_wyks = [split_alpha_numeric(w) for w in wyckoffs] + sep_el_wyks = [ + split_alpha_numeric(re.sub(RE_WYCKOFF_NO_PREFIX, RE_SUBST_ONE_PREFIX, w)) + for w in element_wyckoffs + ] species_sites = [ [ site - for m, w in zip(d["numeric"], d["alpha"]) - for site in [f"{wyckoff_multiplicity_dict[spg][w]}{w}"] * int(m) + for count, wyckoff_letter in zip(d["numeric"], d["alpha"]) + for site in [ + f"{wyckoff_multiplicity_dict[spg_num][wyckoff_letter]}{wyckoff_letter}" + ] + * int(count) ] for d in sep_el_wyks ] species_counts = [ sum( - wyckoff_multiplicity_dict[spg][w] * int(m) - for m, w in zip(d["numeric"], d["alpha"]) + wyckoff_multiplicity_dict[spg_num][wyckoff_letter] * int(count) + for count, wyckoff_letter in zip(d["numeric"], d["alpha"]) ) for d in sep_el_wyks ] @@ -698,7 +733,7 @@ def get_random_structure_for_protostructure(protostructure: str, **kwargs) -> St p = pyxtal() p.from_random( dim=3, - group=int(spg), + group=int(spg_num), species=chemsys.split("-"), numIons=species_counts, sites=species_sites, diff --git a/aviary/wrenformer/data.py b/aviary/wrenformer/data.py index f32ff796..685c1e98 100644 --- a/aviary/wrenformer/data.py +++ b/aviary/wrenformer/data.py @@ -11,7 +11,7 @@ from aviary import PKG_DIR from aviary.data import InMemoryDataLoader -from aviary.wren.data import parse_aflow_wyckoff_str +from aviary.wren.data import parse_protostructure_label if TYPE_CHECKING: import pandas as pd @@ -84,18 +84,20 @@ def get_wyckoff_features( ) -def wyckoff_embedding_from_aflow_str(wyckoff_str: str) -> Tensor: +def wyckoff_embedding_from_protostructure_label(protostructure_label: str) -> Tensor: """Concatenate Matscholar element embeddings with Wyckoff set embeddings and handle augmentation of equivalent Wyckoff sets. Args: - wyckoff_str (str): Aflow-style Wyckoff string. + protostructure_label (str): label constructed as `aflow_label:chemsys` where + aflow_label is an AFLOW-style prototype label chemsys is the alphabetically + sorted chemical system. Returns: Tensor: Shape (n_equiv_wyksets, n_wyckoff_sites, n_features) where n_features = 200 + 444 for Matscholar and Wyckoff embeddings respectively. """ - parsed_output = parse_aflow_wyckoff_str(wyckoff_str) + parsed_output = parse_protostructure_label(protostructure_label) spg_num, wyckoff_site_multiplicities, elements, augmented_wyckoffs = parsed_output symmetry_features = np.stack( @@ -185,7 +187,7 @@ def df_to_in_mem_dataloader( raise ValueError(f"{embedding_type = } must be 'wyckoff' or 'composition'") initial_embeddings = df[input_col].map( - wyckoff_embedding_from_aflow_str + wyckoff_embedding_from_protostructure_label if embedding_type == "wyckoff" else get_composition_embedding ) diff --git a/examples/cgcnn-example.py b/examples/cgcnn-example.py index f02b1e88..7d398c33 100644 --- a/examples/cgcnn-example.py +++ b/examples/cgcnn-example.py @@ -142,24 +142,23 @@ def main( df=df, elem_embedding=elem_embedding, task_dict=task_dict, **dist_dict ) val_set = torch.utils.data.Subset(val_set, range(len(val_set))) + elif val_size == 0 and evaluate: + print("No validation set used, using test set for evaluation purposes") + # NOTE that when using this option care must be taken not to + # peak at the test-set. The only valid model to use is the one + # obtained after the final epoch where the epoch count is + # decided in advance of the experiment. + val_set = test_set + elif val_size == 0: + val_set = None else: - if val_size == 0 and evaluate: - print("No validation set used, using test set for evaluation purposes") - # NOTE that when using this option care must be taken not to - # peak at the test-set. The only valid model to use is the one - # obtained after the final epoch where the epoch count is - # decided in advance of the experiment. - val_set = test_set - elif val_size == 0: - val_set = None - else: - print(f"using {val_size} of training set as validation set") - train_idx, val_idx = split( - train_idx, - random_state=data_seed, - test_size=val_size / (1 - test_size), - ) - val_set = torch.utils.data.Subset(dataset, val_idx) + print(f"using {val_size} of training set as validation set") + train_idx, val_idx = split( + train_idx, + random_state=data_seed, + test_size=val_size / (1 - test_size), + ) + val_set = torch.utils.data.Subset(dataset, val_idx) train_set = torch.utils.data.Subset(dataset, train_idx[0::sample]) diff --git a/examples/inputs/poscar_to_df.py b/examples/inputs/poscar_to_df.py index 4b0a7979..bd496c49 100644 --- a/examples/inputs/poscar_to_df.py +++ b/examples/inputs/poscar_to_df.py @@ -8,7 +8,10 @@ from pymatgen.core import Composition, Structure from tqdm import tqdm -from aviary.wren.utils import count_wyckoff_positions, get_aflow_label_from_spglib +from aviary.wren.utils import ( + count_wyckoff_positions, + get_protostructure_label_from_spglib, +) tqdm.pandas() # prime progress_map functionality @@ -62,7 +65,7 @@ print(f"Number of points in dataset: {len(df)}") # takes ~ 15mins -df["wyckoff"] = df.final_structure.progress_map(get_aflow_label_from_spglib) +df["wyckoff"] = df.final_structure.progress_map(get_protostructure_label_from_spglib) # lattice, sites = zip(*df.final_structure.progress_map(get_cgcnn_input)) diff --git a/examples/notebooks/Roost.ipynb b/examples/notebooks/Roost.ipynb index cadbc5f9..4b3941c6 100644 --- a/examples/notebooks/Roost.ipynb +++ b/examples/notebooks/Roost.ipynb @@ -36,7 +36,10 @@ "from aviary.roost.data import collate_batch as roost_cb\n", "from aviary.roost.model import Roost\n", "from aviary.utils import results_multitask, train_ensemble\n", - "from aviary.wren.utils import count_wyckoff_positions, get_aflow_label_from_spglib" + "from aviary.wren.utils import (\n", + " count_wyckoff_positions,\n", + " get_protostructure_label_from_spglib,\n", + ")" ] }, { @@ -58,7 +61,7 @@ "\n", "df[\"composition\"] = [x.composition.reduced_formula for x in df.final_structure]\n", "df[\"volume_per_atom\"] = [x.volume / len(x) for x in df.final_structure]\n", - "df[\"wyckoff\"] = df[\"final_structure\"].map(get_aflow_label_from_spglib)\n", + "df[\"wyckoff\"] = df[\"final_structure\"].map(get_protostructure_label_from_spglib)\n", "\n", "df = df[df.wyckoff.map(count_wyckoff_positions) < 16]\n", "df[\"n_sites\"] = df.final_structure.map(len)\n", @@ -237,7 +240,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.12.2" }, "vscode": { "interpreter": { diff --git a/examples/notebooks/Wren.ipynb b/examples/notebooks/Wren.ipynb index 4149870e..43900b53 100644 --- a/examples/notebooks/Wren.ipynb +++ b/examples/notebooks/Wren.ipynb @@ -36,7 +36,10 @@ "from aviary.wren.data import WyckoffData\n", "from aviary.wren.data import collate_batch as wren_cb\n", "from aviary.wren.model import Wren\n", - "from aviary.wren.utils import count_wyckoff_positions, get_aflow_label_from_spglib" + "from aviary.wren.utils import (\n", + " count_wyckoff_positions,\n", + " get_protostructure_label_from_spglib,\n", + ")" ] }, { @@ -58,7 +61,7 @@ "\n", "df[\"composition\"] = [x.composition.reduced_formula for x in df.final_structure]\n", "df[\"volume_per_atom\"] = [x.volume / len(x) for x in df.final_structure]\n", - "df[\"wyckoff\"] = df[\"final_structure\"].map(get_aflow_label_from_spglib)\n", + "df[\"wyckoff\"] = df[\"final_structure\"].map(get_protostructure_label_from_spglib)\n", "\n", "df = df[df.wyckoff.map(count_wyckoff_positions) < 16]\n", "df[\"n_sites\"] = df.final_structure.map(len)\n", diff --git a/examples/roost-example.py b/examples/roost-example.py index 1c5532b0..b69347b5 100644 --- a/examples/roost-example.py +++ b/examples/roost-example.py @@ -122,24 +122,23 @@ def main( df=df, elem_embedding=elem_embedding, task_dict=task_dict ) val_set = torch.utils.data.Subset(val_set, range(len(val_set))) + elif val_size == 0 and evaluate: + print("No validation set used, using test set for evaluation purposes") + # NOTE that when using this option care must be taken not to + # peak at the test-set. The only valid model to use is the one + # obtained after the final epoch where the epoch count is + # decided in advance of the experiment. + val_set = test_set + elif val_size == 0: + val_set = None else: - if val_size == 0 and evaluate: - print("No validation set used, using test set for evaluation purposes") - # NOTE that when using this option care must be taken not to - # peak at the test-set. The only valid model to use is the one - # obtained after the final epoch where the epoch count is - # decided in advance of the experiment. - val_set = test_set - elif val_size == 0: - val_set = None - else: - print(f"using {val_size} of training set as validation set") - train_idx, val_idx = split( - train_idx, - random_state=data_seed, - test_size=val_size / (1 - test_size), - ) - val_set = torch.utils.data.Subset(dataset, val_idx) + print(f"using {val_size} of training set as validation set") + train_idx, val_idx = split( + train_idx, + random_state=data_seed, + test_size=val_size / (1 - test_size), + ) + val_set = torch.utils.data.Subset(dataset, val_idx) train_set = torch.utils.data.Subset(dataset, train_idx[0::sample]) diff --git a/examples/wren-example.py b/examples/wren-example.py index 6a68b401..666fe3b1 100644 --- a/examples/wren-example.py +++ b/examples/wren-example.py @@ -135,23 +135,22 @@ def main( task_dict=task_dict, ) val_set = torch.utils.data.Subset(val_set, range(len(val_set))) + elif val_size == 0 and evaluate: + print("No validation set used, using test set for evaluation purposes") + # NOTE that when using this option care must be taken not to + # peak at the test-set. The only valid model to use is the one + # obtained after the final epoch where the epoch count is + # decided in advance of the experiment. + val_set = test_set + elif val_size == 0: + val_set = None else: - if val_size == 0 and evaluate: - print("No validation set used, using test set for evaluation purposes") - # NOTE that when using this option care must be taken not to - # peak at the test-set. The only valid model to use is the one - # obtained after the final epoch where the epoch count is - # decided in advance of the experiment. - val_set = test_set - elif val_size == 0: - val_set = None - else: - print(f"using {val_size} of training set as validation set") - test_size = val_size / (1 - test_size) - train_idx, val_idx = split( - train_idx, random_state=data_seed, test_size=test_size - ) - val_set = torch.utils.data.Subset(dataset, val_idx) + print(f"using {val_size} of training set as validation set") + test_size = val_size / (1 - test_size) + train_idx, val_idx = split( + train_idx, random_state=data_seed, test_size=test_size + ) + val_set = torch.utils.data.Subset(dataset, val_idx) train_set = torch.utils.data.Subset(dataset, train_idx[0::sample]) diff --git a/examples/wrenformer/mat_bench/compare_spglib_vs_aflow_wyckoff_labels.py b/examples/wrenformer/mat_bench/compare_spglib_vs_aflow_wyckoff_labels.py index 7c2fdbd8..6382d24e 100644 --- a/examples/wrenformer/mat_bench/compare_spglib_vs_aflow_wyckoff_labels.py +++ b/examples/wrenformer/mat_bench/compare_spglib_vs_aflow_wyckoff_labels.py @@ -10,7 +10,10 @@ from tqdm import tqdm from aviary import ROOT -from aviary.wren.utils import get_aflow_label_from_aflow, get_aflow_label_from_spglib +from aviary.wren.utils import ( + get_protostructure_label_from_aflow, + get_protostructure_label_from_spglib, +) from examples.wrenformer.mat_bench import DATA_PATHS __author__ = "Janosh Riebesell" @@ -39,7 +42,7 @@ # takes ~6h (when running uninterrupted) for idx, struct in tqdm(df_perovskites.structure.items(), total=len(df_perovskites)): if pd.isna(df_perovskites.aflow_wyckoff[idx]): - df_perovskites.loc[idx, "aflow_wyckoff"] = get_aflow_label_from_aflow( + df_perovskites.loc[idx, "aflow_wyckoff"] = get_protostructure_label_from_aflow( struct, "/Users/janosh/bin/aflow" ) @@ -47,7 +50,7 @@ # %% # takes ~30 sec for struct in tqdm(df_perovskites.structure, total=len(df_perovskites)): - get_aflow_label_from_spglib(struct) + get_protostructure_label_from_spglib(struct) # %% @@ -62,16 +65,15 @@ # %% -# df_perovskites.drop("structure", axis=1).to_csv( -# f"{ROOT}/datasets/matbench_perovskites_aflow_labels.csv" -# ) -df_perovskites = pd.read_csv( - f"{ROOT}/datasets/matbench_perovskites_aflow_labels.csv" -).set_index("mbid") +df_perovskites.drop("structure", axis=1).to_csv( + f"{ROOT}/datasets/matbench_perovskites_protostructure_labels.csv" +) # %% -f"{ROOT}/datasets/matbench_perovskites_aflow_labels.csv" +df_perovskites = pd.read_csv( + f"{ROOT}/datasets/matbench_perovskites_protostructure_labels.csv" +).set_index("mbid") # %% diff --git a/examples/wrenformer/mat_bench/make_plots.py b/examples/wrenformer/mat_bench/make_plots.py index d42b3da7..55012837 100644 --- a/examples/wrenformer/mat_bench/make_plots.py +++ b/examples/wrenformer/mat_bench/make_plots.py @@ -13,7 +13,7 @@ from matbench import MatbenchBenchmark from matbench.constants import CLF_KEY, REG_KEY from matbench.metadata import mbv01_metadata as matbench_metadata -from pymatviz.utils import add_identity_line +from pymatviz.powerups import add_identity_line from sklearn.metrics import r2_score, roc_auc_score from examples.wrenformer.mat_bench import DATA_PATHS diff --git a/examples/wrenformer/mat_bench/plotting_functions.py b/examples/wrenformer/mat_bench/plotting_functions.py index d8148246..26e83764 100644 --- a/examples/wrenformer/mat_bench/plotting_functions.py +++ b/examples/wrenformer/mat_bench/plotting_functions.py @@ -35,7 +35,7 @@ def scale_regr_task(series: pd.Series, mad: float) -> pd.Series: # scale regression problems by mad/mae mask = series > 0 mask_iix = np.where(mask) - series.iloc[mask_iix] = series.iloc[mask_iix] / mad + series.iloc[mask_iix] /= mad series.loc[~mask] = np.nan return series diff --git a/examples/wrenformer/mat_bench/save_matbench_aflow_labels.py b/examples/wrenformer/mat_bench/save_matbench_aflow_labels.py index ff1268df..41a72126 100644 --- a/examples/wrenformer/mat_bench/save_matbench_aflow_labels.py +++ b/examples/wrenformer/mat_bench/save_matbench_aflow_labels.py @@ -3,7 +3,7 @@ from tqdm import tqdm from aviary import ROOT -from aviary.wren.utils import get_aflow_label_from_spglib +from aviary.wren.utils import get_protostructure_label_from_spglib __author__ = "Janosh Riebesell" __date__ = "2022-04-11" @@ -25,7 +25,7 @@ if "structure" in df: df["composition"] = [struct.formula for struct in df.structure] df["wyckoff"] = [ - get_aflow_label_from_spglib(struct) + get_protostructure_label_from_spglib(struct) for struct in tqdm(df.structure, desc="Getting Aflow Wyckoff labels") ] elif "composition" in df: diff --git a/pyproject.toml b/pyproject.toml index e59d02f8..5a589cd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "aviary" -version = "0.1.2" +version = "1.0.0" description = "A collection of machine learning models for materials discovery" authors = [{ name = "Rhys Goodall", email = "rhys.goodall@outlook.com" }] readme = "README.md" @@ -72,7 +72,7 @@ no_implicit_optional = false [tool.ruff] target-version = "py39" -include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"] +extend-include = ["*.ipynb"] lint.select = [ "B", # flake8-bugbear "C4", # flake8-comprehensions diff --git a/tests/conftest.py b/tests/conftest.py index 4874d239..23923102 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ import torch from matminer.datasets import load_dataset -from aviary.wren.utils import get_aflow_label_from_spglib +from aviary.wren.utils import get_protostructure_label_from_spglib __author__ = "Janosh Riebesell" __date__ = "2022-04-09" @@ -37,7 +37,7 @@ def df_matbench_jdft2d(): df = df.set_index("material_id", drop=False) df["composition"] = [x.composition.formula.replace(" ", "") for x in df.structure] - df["wyckoff"] = df.structure.map(get_aflow_label_from_spglib) + df["wyckoff"] = df.structure.map(get_protostructure_label_from_spglib) return df @@ -48,7 +48,7 @@ def df_matbench_phonons_wyckoff(df_matbench_phonons): paying for it unless requested. """ df_matbench_phonons["wyckoff"] = df_matbench_phonons.structure.map( - get_aflow_label_from_spglib + get_protostructure_label_from_spglib ) return df_matbench_phonons diff --git a/tests/test_wyckoff_ops.py b/tests/test_wyckoff_ops.py index 67ace39d..aa86b8a7 100644 --- a/tests/test_wyckoff_ops.py +++ b/tests/test_wyckoff_ops.py @@ -13,14 +13,15 @@ count_crystal_sites, count_distinct_wyckoff_letters, count_wyckoff_positions, - get_aflow_label_from_aflow, - get_aflow_label_from_spg_analyzer, - get_aflow_label_from_spglib, - get_aflow_strs_from_iso_and_composition, - get_anom_formula_from_prototype_formula, - get_isopointal_proto_from_aflow, + get_anonymous_formula_from_prototype_formula, + get_protostructure_label_from_aflow, + get_protostructure_label_from_spg_analyzer, + get_protostructure_label_from_spglib, + get_protostructures_from_aflow_label_and_composition, + get_prototype_formula_from_composition, + get_prototype_from_protostructure, get_random_structure_for_protostructure, - prototype_formula, + relab_dict, ) from .conftest import TEST_DIR @@ -43,18 +44,20 @@ ] -def test_get_aflow_label_from_spglib(): - """Check that spglib gives correct Aflow label for esseneite.""" +def test_get_protostructure_label_from_spglib(): + """Check that spglib gives correct protostructure label for esseneite""" struct = Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif") - - assert get_aflow_label_from_spglib(struct) == "ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si" + assert ( + get_protostructure_label_from_spglib(struct) + == "ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si" + ) -def test_get_aflow_label_from_spglib_edge_case(): +def test_get_protostructure_label_from_spglib_edge_case(): """Check edge case where the symmetry precision is too low.""" struct = Structure.from_file(f"{TEST_DIR}/data/U2Pa4Tc6.json") - defaults = inspect.signature(get_aflow_label_from_spglib).parameters + defaults = inspect.signature(get_protostructure_label_from_spglib).parameters assert defaults["init_symprec"].default == 0.1 @@ -67,41 +70,45 @@ def test_get_aflow_label_from_spglib_edge_case(): "expected U(PaTc3)2 to be UPa2Tc3" ) with pytest.raises(ValueError, match=re.escape(raises_str)): - get_aflow_label_from_spg_analyzer(spg_analyzer, raise_errors=True) + get_protostructure_label_from_spg_analyzer(spg_analyzer, raise_errors=True) assert ( - get_aflow_label_from_spg_analyzer(spg_analyzer, raise_errors=False) + get_protostructure_label_from_spg_analyzer(spg_analyzer, raise_errors=False) == raises_str ) # Test that it gives invalid protostructure if fallback is None. with pytest.raises(ValueError, match=re.escape(raises_str)): - get_aflow_label_from_spglib(struct, raise_errors=True, fallback_symprec=None) + get_protostructure_label_from_spglib( + struct, raise_errors=True, fallback_symprec=None + ) assert ( - get_aflow_label_from_spglib(struct, raise_errors=False, fallback_symprec=None) + get_protostructure_label_from_spglib( + struct, raise_errors=False, fallback_symprec=None + ) == raises_str ) - assert get_aflow_label_from_spglib(struct, raise_errors=True) == ( + assert get_protostructure_label_from_spglib(struct, raise_errors=True) == ( "A2B3C_hP6_191_c_g_a:Pa-Tc-U" ) - assert get_aflow_label_from_spglib(struct, raise_errors=False) == ( + assert get_protostructure_label_from_spglib(struct, raise_errors=False) == ( "A2B3C_hP6_191_c_g_a:Pa-Tc-U" ) @pytest.mark.parametrize( - "aflow_label, expected", + "protostructure_label, expected", [ ("ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si", 6), # esseneite ("A6B11CD7_aP50_2_6i_ac10i_i_7i:C-H-N-O", 26), ("foo_bar_47_abc_A_b:X-Y-Z", 5), ], ) -def test_count_wyckoff_positions(aflow_label, expected): - count = count_wyckoff_positions(aflow_label) +def test_count_wyckoff_positions(protostructure_label, expected): + count = count_wyckoff_positions(protostructure_label) assert isinstance(count, int) assert count == expected @@ -121,7 +128,7 @@ def test_count_crystal_sites(): @pytest.mark.parametrize( - "aflow_label, expected", + "protostructure_label, expected", [ ("ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si", "ABC2D6_mC40_15_e_e_f_3f"), ("ABC6D2_mC40_15_e_a_3f_f:Ca-Fe-O-Si", "ABC2D6_mC40_15_a_e_f_3f"), @@ -131,13 +138,38 @@ def test_count_crystal_sites(): ("A4BC20D2_oC108_41_2b_a_10b_b:B-Ca-H-N", "AB2C4D20_oC108_41_a_b_2b_10b"), ], ) -def test_get_isopointal_proto(aflow_label, expected): +def test_get_prototype_from_protostructure(protostructure_label, expected): """Get a recanonicalized prototype string without chemical system""" - assert get_isopointal_proto_from_aflow(aflow_label) == expected + aflow_label, chemsys = protostructure_label.split(":") + prototype_formula, pearson_symbol, spg_num, *wyckoffs = aflow_label.split("_") + + element_wyckoff = "_".join(wyckoffs) + + isopointal_element_wyckoffs = list( + { + element_wyckoff.translate(str.maketrans(trans)) + for trans in relab_dict[spg_num] + } + ) + + protostructure_labels = [ + f"{prototype_formula}_{pearson_symbol}_{spg_num}_{element_wyckoff}:{chemsys}" + for element_wyckoff in isopointal_element_wyckoffs + ] + + print(protostructure_label) + print(protostructure_labels) + print(get_prototype_from_protostructure(protostructure_label)) + print(expected) + + assert all( + get_prototype_from_protostructure(protostructure_label) == expected + for protostructure_label in protostructure_labels + ) @pytest.mark.parametrize( - "isopointal_proto, composition, expected", + "aflow_label, composition, expected", [ ( "AB2C3D4_tP10_115_a_g_bg_cdg", @@ -152,17 +184,18 @@ def test_get_isopointal_proto(aflow_label, expected): ), ], ) -def test_get_aflow_strs_from_iso_and_composition( - isopointal_proto, composition, expected +def test_get_protostructures_from_aflow_label_and_composition( + aflow_label, composition, expected ): - aflows = get_aflow_strs_from_iso_and_composition( - isopointal_proto, Composition(composition) + protostructures = get_protostructures_from_aflow_label_and_composition( + aflow_label, Composition(composition) ) - assert set(aflows) == set(expected.split(" ")) + assert set(protostructures) == set(expected.split(" ")) # check the round trip assert all( - get_isopointal_proto_from_aflow(aflow) == isopointal_proto for aflow in aflows + get_prototype_from_protostructure(protostructure) == aflow_label + for protostructure in protostructures ) @@ -216,20 +249,25 @@ def test_find_translations_performance(): "composition, expected", [("Ce2Al3GaPd4", "A3B2CD4"), ("YbNiO3", "AB3C"), ("K2NaAlF6", "AB6C2D")], ) -def test_prototype_formula(composition: str, expected: str): - assert prototype_formula(Composition(composition)) == expected +def test_get_prototype_formula_from_composition(composition: str, expected: str): + assert get_prototype_formula_from_composition(Composition(composition)) == expected @pytest.mark.parametrize( - "composition, expected", - [("Ce2Al3GaPd4", "AB2C3D4"), ("YbNiO3", "ABC3"), ("K2NaAlF6", "ABC2D6")], + "anonymous_formula, prototype_formula", + [("AB", "AB"), ("A2B", "AB2"), ("A3B2CD4", "AB2C3D4")], ) -def test_get_anom_formula_from_prototype_formula(composition: str, expected: str): - assert get_anom_formula_from_prototype_formula("A3B2CD4") == "AB2C3D4" +def test_get_anonymous_formula_from_prototype_formula( + anonymous_formula: str, prototype_formula: str +): + assert ( + get_anonymous_formula_from_prototype_formula(anonymous_formula) + == prototype_formula + ) @pytest.mark.parametrize( - "aflow_label, expected", + "protostructure_label, expected", [ ("A20BC14D8E5F2_oP800_61_40c_2c_28c_16c_10c_4c:C-Cd-H-N-O-S", 1), ("ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si", 2), @@ -237,19 +275,16 @@ def test_get_anom_formula_from_prototype_formula(composition: str, expected: str ("A6B11CD7_aP50_2_6i_ac10i_i_7i:C-H-N-O", 3), ], ) -def test_count_distinct_wyckoff_letters(aflow_label, expected): - assert count_distinct_wyckoff_letters(aflow_label) == expected - - -aflow_cli = which("aflow") +def test_count_distinct_wyckoff_letters(protostructure_label, expected): + assert count_distinct_wyckoff_letters(protostructure_label) == expected -@pytest.mark.skipif(aflow_cli is None, reason="aflow CLI not installed") -def test_get_aflow_label_from_aflow(): - """Check we extract corred correct aflow label for esseneite from Aflow CLI""" +@pytest.mark.skipif(which("aflow") is None, reason="AFLOW CLI not installed") +def test_get_protostructure_label_from_aflow(): + """Check we extract correct protostructure label for esseneite using AFLOW CLI.""" struct = Structure.from_file(f"{TEST_DIR}/data/ABC6D2_mC40_15_e_e_3f_f.cif") - out = get_aflow_label_from_aflow(struct, aflow_cli) + out = get_protostructure_label_from_aflow(struct, which("aflow")) expected = "ABC6D2_mC40_15_e_e_3f_f:Ca-Fe-O-Si" assert out == expected @@ -264,7 +299,7 @@ def test_get_aflow_label_from_aflow(): ) def test_get_random_structure_for_protostructure_roundtrip(protostructure): """Check roundtrip for generating a random structure from a prototype string""" - assert protostructure == get_aflow_label_from_spglib( + assert protostructure == get_protostructure_label_from_spglib( get_random_structure_for_protostructure(protostructure) )