Skip to content

Commit

Permalink
[WIP] Download files to right locations to allow plotting. (#201)
Browse files Browse the repository at this point in the history
* clean: clearer about protostructures in enum for wyckoff_spglib allowing it to be removed from pymatviz given deprecated terminology (see #131)

* clean: remove Alex sAlex Duplication from dpa models

* WIP: attempt to write a download script that enables users to generate the plots themselves again. blocked as the file urls are private

* Model prediction file path get new auto-download mechanism

- add maybe_auto_download_file() used in Model.(discovery|geo_opt|kappa_103)_path
- introduce env var MBD_AUTO_DOWNLOAD_FILES and MBD_CACHE_DIR for download control
- Files enum add abstractmethods for url and label properties so child classes must implement
- test coverage for new download functionality
- remove scripts/download_model_preds_from_figshare.py

* fix test_model_enum

---------

Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
  • Loading branch information
CompRhys and janosh authored Feb 6, 2025
1 parent 991f41f commit 3516d23
Show file tree
Hide file tree
Showing 16 changed files with 297 additions and 112 deletions.
20 changes: 10 additions & 10 deletions data/mp/get_mp_energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from matbench_discovery import STABILITY_THRESHOLD, today
from matbench_discovery.data import DataFiles
from matbench_discovery.enums import MbdKey

__author__ = "Janosh Riebesell"
__date__ = "2023-01-10"
Expand All @@ -28,7 +29,7 @@
fields = {
Key.mat_id,
"formula_pretty",
Key.form_energy,
e_form_col := "formation_energy_per_atom",
"energy_per_atom",
"symmetry",
"energy_above_hull",
Expand All @@ -53,7 +54,7 @@
df_spg = pd.json_normalize(df_mp.pop("symmetry"))[["number", "symbol"]]
df_mp["spacegroup_symbol"] = df_spg.symbol.to_numpy()

df_mp.energy_type.value_counts().plot.pie(backend=pmv.utils.PLOTLY, autopct="%1.1f%%")
df_mp.energy_type.value_counts().plot.pie(autopct="%1.1f%%")
# GGA: 72.2%, GGA+U: 27.8%


Expand All @@ -65,15 +66,14 @@
df_cse[Key.structure] = [
Structure.from_dict(cse[Key.structure]) for cse in tqdm(df_cse.entry)
]
df_cse[Key.wyckoff] = [
get_protostructure_label_from_spglib(struct, errors="ignore")
for struct in tqdm(df_cse.structure)
df_cse[MbdKey.wyckoff_spglib] = [
get_protostructure_label_from_spglib(struct) for struct in tqdm(df_cse.structure)
]
# make sure symmetry detection succeeded for all structures
assert df_cse[Key.wyckoff].str.startswith("invalid").sum() == 0
df_mp[Key.wyckoff] = df_cse[Key.wyckoff]
assert df_cse[MbdKey.wyckoff_spglib].str.startswith("invalid").sum() == 0
df_mp[MbdKey.wyckoff_spglib] = df_cse[MbdKey.wyckoff_spglib]

spg_nums = df_mp[Key.wyckoff].str.split("_").str[2].astype(int)
spg_nums = df_mp[MbdKey.wyckoff_spglib].str.split("_").str[2].astype(int)
# make sure all our spacegroup numbers match MP's
assert (spg_nums.sort_index() == df_spg["number"].sort_index()).all()

Expand All @@ -83,7 +83,7 @@

# %% reproduce fig. 1b from https://arxiv.org/abs/2001.10591 (as data consistency check)
ax = df_mp.plot.scatter(
x=Key.form_energy,
x=e_form_col,
y="decomposition_enthalpy",
alpha=0.1,
xlim=[-5, 1],
Expand All @@ -109,7 +109,7 @@
x="decomposition_enthalpy",
y="energy_above_hull",
color=mask_above_line.map({True: "red", False: "blue"}),
hover_data=["index", Key.formula, Key.form_energy],
hover_data=["index", Key.formula, e_form_col],
)
# most points lie on line y=x for x > 0 and y = 0 for x < 0.
n_above_line = sum(mask_above_line)
Expand Down
18 changes: 10 additions & 8 deletions data/wbm/compile_wbm_test_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,19 +634,19 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:

# from initial structures
for idx in tqdm(df_wbm.index):
if not pd.isna(df_summary.loc[idx].get(MbdKey.init_wyckoff)):
if not pd.isna(df_summary.loc[idx].get(MbdKey.init_wyckoff_spglib)):
continue # Aflow label already computed
try:
struct = Structure.from_dict(df_wbm.loc[idx, Key.init_struct])
df_summary.loc[idx, MbdKey.init_wyckoff] = (
df_summary.loc[idx, MbdKey.init_wyckoff_spglib] = (
get_protostructure_label_from_spglib(struct)
)
except Exception as exc:
print(f"{idx=} {exc=}")

# from relaxed structures
for idx in tqdm(df_wbm.index):
if not pd.isna(df_summary.loc[idx].get(Key.wyckoff)):
if not pd.isna(df_summary.loc[idx].get(MbdKey.wyckoff_spglib)):
continue

try:
Expand All @@ -658,8 +658,8 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
except Exception as exc:
print(f"{idx=} {exc=}")

assert df_summary[MbdKey.init_wyckoff].isna().sum() == 0
assert df_summary[Key.wyckoff].isna().sum() == 0
assert df_summary[MbdKey.init_wyckoff_spglib].isna().sum() == 0
assert df_summary[MbdKey.wyckoff_spglib].isna().sum() == 0
except ImportError:
print("aviary not installed, skipping Wyckoff label generation")
except Exception as exception:
Expand All @@ -684,11 +684,13 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
df_mp = pd.read_csv(DataFiles.mp_energies.path, index_col=0)

# mask WBM materials with matching prototype in MP
mask_proto_in_mp = df_summary[MbdKey.init_wyckoff].isin(df_mp["wyckoff_spglib"])
mask_proto_in_mp = df_summary[MbdKey.init_wyckoff_spglib].isin(
df_mp[MbdKey.wyckoff_spglib]
)
# mask duplicate prototypes in WBM (keeping the lowest energy one)
mask_dupe_protos = df_summary.sort_values(
by=[MbdKey.init_wyckoff, MbdKey.each_wbm]
).duplicated(subset=MbdKey.init_wyckoff, keep="first")
by=[MbdKey.init_wyckoff_spglib, MbdKey.each_wbm]
).duplicated(subset=MbdKey.init_wyckoff_spglib, keep="first")
assert sum(mask_proto_in_mp) == 11_175, f"{sum(mask_proto_in_mp)=:_}"
assert sum(mask_dupe_protos) == 32_784, f"{sum(mask_dupe_protos)=:_}"

Expand Down
6 changes: 4 additions & 2 deletions data/wbm/eda_wbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,9 @@


# %%
df_wbm[Key.spg_num] = df_wbm[MbdKey.init_wyckoff].str.split("_").str[2].astype(int)
df_wbm[Key.spg_num] = (
df_wbm[MbdKey.init_wyckoff_spglib].str.split("_").str[2].astype(int)
)
df_mp[Key.spg_num] = df_mp[f"{Key.wyckoff}_spglib"].str.split("_").str[2].astype(int)


Expand Down Expand Up @@ -350,7 +352,7 @@

# %% find large structures that changed symmetry during relaxation
df_sym_change = (
df_wbm.query(f"{MbdKey.init_wyckoff} != {Key.wyckoff}_spglib")
df_wbm.query(f"{MbdKey.init_wyckoff_spglib} != {MbdKey.wyckoff_spglib}")
.filter(regex="wyckoff|sites")
.nlargest(10, Key.n_sites)
)
Expand Down
113 changes: 68 additions & 45 deletions matbench_discovery/data.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
"""Download, cache and hydrate data files from the Matbench Discovery Figshare article.
https://figshare.com/articles/dataset/22715158
Environment Variables:
MBD_AUTO_DOWNLOAD_FILES: Controls whether to auto-download missing data files.
Defaults to "true". Set to "false" to be prompted before downloading.
This affects both model prediction files and dataset files.
MBD_CACHE_DIR: Directory to cache downloaded data files.
Defaults to DATA_DIR if the full repo was cloned, otherwise ~/.cache/matbench-discovery.
"""

import abc
import builtins
import functools
import io
import os
import sys
import traceback
import zipfile
from collections import defaultdict
from collections.abc import Callable, Sequence
Expand Down Expand Up @@ -35,14 +44,13 @@
RAW_REPO_URL = "https://github.com/janosh/matbench-discovery/raw"
# directory to cache downloaded data files
DEFAULT_CACHE_DIR = os.getenv(
"MATBENCH_DISCOVERY_CACHE_DIR",
"MBD_CACHE_DIR",
DATA_DIR # use DATA_DIR to locally cache data files if full repo was cloned
if os.path.isdir(DATA_DIR)
# use ~/.cache if matbench-discovery was installed from PyPI
else os.path.expanduser("~/.cache/matbench-discovery"),
)


round_trip_yaml = YAML() # round-trippable YAML for updating model metadata files
round_trip_yaml.preserve_quotes = True
round_trip_yaml.width = 1000 # avoid changing line wrapping
Expand Down Expand Up @@ -205,7 +213,9 @@ def ase_atoms_to_zip(


def download_file(file_path: str, url: str) -> None:
"""Download the file from the given URL to the given file path."""
"""Download the file from the given URL to the given file path.
Prints rather than raises if the file cannot be downloaded.
"""
file_dir = os.path.dirname(file_path)
os.makedirs(file_dir, exist_ok=True)
try:
Expand All @@ -215,8 +225,31 @@ def download_file(file_path: str, url: str) -> None:

with open(file_path, "wb") as file:
file.write(response.content)
except requests.exceptions.RequestException as exc:
print(f"Error downloading {url=}\nto {file_path=}.\n{exc!s}")
except requests.exceptions.RequestException:
print(f"Error downloading {url=}\nto {file_path=}.\n{traceback.format_exc()}")


def maybe_auto_download_file(url: str, abs_path: str, label: str | None = None) -> None:
"""Download file if it doesn't exist and user confirms or auto-download is enabled."""
if os.path.isfile(abs_path):
return

# whether to auto-download model prediction files without prompting
auto_download_files = os.getenv("MBD_AUTO_DOWNLOAD_FILES", "true").lower() == "true"

is_ipython = hasattr(builtins, "__IPYTHON__")
# default to 'y' if auto-download is enabled or not in interactive session (TTY or iPython)
answer = (
"y"
if auto_download_files or not (is_ipython or sys.stdin.isatty())
else input(
f"{abs_path!r} associated with {label=} does not exist. Download it "
"now? This will cache the file for future use. [y/n] "
)
)
if answer.lower().strip() == "y":
print(f"Downloading {label!r} from {url!r} to {abs_path!r}")
download_file(abs_path, url)


class MetaFiles(EnumMeta):
Expand Down Expand Up @@ -248,46 +281,20 @@ def base_dir(cls) -> str:
class Files(StrEnum, metaclass=MetaFiles):
"""Enum of data files with associated file directories and URLs."""

def __new__(
cls, file_path: str, url: str | None = None, label: str | None = None
) -> Self:
"""Create a new member of the FileUrls enum with a given URL where to load the
file from and directory where to save it to.
"""
obj = str.__new__(cls)
obj._value_ = file_path.split("/")[-1] # use file name as enum value

obj._rel_path = file_path # type: ignore[attr-defined] # noqa: SLF001
obj._url = url # type: ignore[attr-defined] # noqa: SLF001
obj._label = label # type: ignore[attr-defined] # noqa: SLF001

return obj

def __str__(self) -> str:
"""File path associated with the file URL. Use str(DataFiles.some_key) if you
want the absolute file path without auto-downloading the file if it doesn't
exist yet, e.g. for use in script that generates the file in the first place.
"""
return f"{type(self).base_dir}/{self._rel_path}" # type: ignore[attr-defined]

def __repr__(self) -> str:
"""Return enum attribute's string representation."""
return f"{type(self).__name__}.{self.name}"

@property
@abc.abstractmethod
def url(self) -> str:
"""Url associated with the file URL."""
return self._url # type: ignore[attr-defined]
"""URL associated with the file."""

@property
def rel_path(self) -> str:
"""Relative path of the file associated with the file URL."""
return self._rel_path # type: ignore[attr-defined]
"""Path of the file relative to the repo's ROOT directory."""
return self.value

@property
@abc.abstractmethod
def label(self) -> str:
"""Label associated with the file URL."""
return self._label # type: ignore[attr-defined]
"""Label associated with the file."""

@classmethod
def from_label(cls, label: str) -> Self:
Expand Down Expand Up @@ -358,6 +365,11 @@ def url(self) -> str:
raise ValueError(f"{self.name!r} does not have a URL")
return url

@property
def label(self) -> str:
"""No pretty label for DataFiles, use name instead."""
return self.name

@property
def description(self) -> str:
"""Description associated with the file."""
Expand All @@ -368,7 +380,7 @@ def path(self) -> str:
"""File path associated with the file URL if it exists, otherwise
download the file first, then return the path.
"""
key, rel_path = self.name, self._rel_path # type: ignore[attr-defined]
key, rel_path = self.name, self.rel_path

if rel_path not in self.yaml[key]["path"]:
raise ValueError(f"{rel_path=} does not match {self.yaml[key]['path']}")
Expand All @@ -386,10 +398,7 @@ def path(self) -> str:
else "y"
)
if answer.lower().strip() == "y":
if not is_ipython:
print(
f"Downloading {key!r} from {self.url} to {abs_path} for caching"
)
print(f"Downloading {key!r} from {self.url} to {abs_path}")
download_file(abs_path, self.url)
return abs_path

Expand Down Expand Up @@ -489,6 +498,11 @@ def label(self) -> str:
"""Pretty label associated with the model."""
return self.metadata["model_name"]

@property
def url(self) -> str:
"""Pull request URL in which the model was originally added to the repo."""
return self.metadata["pr_url"]

@property
def key(self) -> str:
"""Key associated with the file URL."""
Expand All @@ -508,11 +522,14 @@ def yaml_path(self) -> str:
def discovery_path(self) -> str:
"""Prediction file path associated with the model."""
rel_path = self.metrics.get("discovery", {}).get("pred_file")
file_url = self.metrics.get("discovery", {}).get("pred_file_url")
if not rel_path:
raise ValueError(
f"metrics.discovery.pred_file not found in {self.rel_path!r}"
)
return f"{ROOT}/{rel_path}"
abs_path = f"{ROOT}/{rel_path}"
maybe_auto_download_file(file_url, abs_path, label=self.label)
return abs_path

@property
def geo_opt_path(self) -> str | None:
Expand All @@ -523,11 +540,14 @@ def geo_opt_path(self) -> str | None:
if geo_opt_metrics in ("not available", "not applicable"):
return None
rel_path = geo_opt_metrics.get("pred_file")
file_url = geo_opt_metrics.get("pred_file_url")
if not rel_path:
raise ValueError(
f"metrics.geo_opt.pred_file not found in {self.rel_path!r}"
)
return f"{ROOT}/{rel_path}"
abs_path = f"{ROOT}/{rel_path}"
maybe_auto_download_file(file_url, abs_path, label=self.label)
return abs_path

@property
def kappa_103_path(self) -> str | None:
Expand All @@ -538,11 +558,14 @@ def kappa_103_path(self) -> str | None:
if phonons_metrics in ("not available", "not applicable"):
return None
rel_path = phonons_metrics.get("kappa_103", {}).get("pred_file")
file_url = phonons_metrics.get("kappa_103", {}).get("pred_file_url")
if not rel_path:
raise ValueError(
f"metrics.phonons.kappa_103.pred_file not found in {self.rel_path!r}"
)
return f"{ROOT}/{rel_path}"
abs_path = f"{ROOT}/{rel_path}"
maybe_auto_download_file(file_url, abs_path, label=self.label)
return abs_path


# render model keys as labels in plotly axes and legends
Expand Down
9 changes: 6 additions & 3 deletions matbench_discovery/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,14 @@ class MbdKey(LabelEnum):
openness = "openness", "Openness"
e_above_hull_error = f"Error in E<sub>hull dist</sub> {eV_per_atom}"

init_wyckoff = (
init_wyckoff_spglib = (
"wyckoff_spglib_initial_structure",
"Aflow-Wyckoff Label Initial Structure",
"Protostructure Label for Initial Structure using spglib",
)
wyckoff_spglib = (
"wyckoff_spglib",
"Protostructure Label for Relaxed Structure using spglib",
)
wyckoff_spglib = "wyckoff_spglib", "Aflow-Wyckoff Label"
international_spg_name = "international_spg_name", "International space group name"
spg_num_diff = "spg_num_diff", "Difference in space group number"
n_sym_ops_diff = "n_sym_ops_diff", "Difference in number of symmetry operations"
Expand Down
2 changes: 1 addition & 1 deletion models/deepmd/dpa3-v1-openlam.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ requirements:
pymatgen: 2024.6.10
numpy: 1.26.4

training_set: [OMat24, MPtrj, sAlex, Alex] # need to update to OpenLAM
training_set: [OMat24, MPtrj, sAlex] # need to update to OpenLAM

notes:
Description: |
Expand Down
Loading

0 comments on commit 3516d23

Please sign in to comment.