Skip to content

Commit

Permalink
Add cli.py, diatomics.py + refactor metrics/discovery.py (#203)
Browse files Browse the repository at this point in the history
* add diatomics module for generating atomic pair-repulsion curves

- diatomics.py has functions to generate diatomic molecules and calculate potential energy and forces
- test_diatomics.py tests molecule generation and energy/force results
- support both homo-nuclear and hetero-nuclear diatomics

* add cli.py with central argument parser for plotting scripts

- support command-line configuration for models, test subsets, energy types, and plot options
- add corresponding test suite in `test_cli.py`
- ensure compatibility with unknown sys.argv entries passed by a Jupyter kernel
- update DataFiles and Model enums to use `auto()` and tuple-based file path specification
- add test cases for MbdKey, Task, ModelType, Open, and TestSubset enums

* move all enums into enums.py: Files, DataFiles and Model enums were in data.py before

improves code org

* add metrics_df_from_yaml() function in metrics/__init__.py to extract metrics from model YAML files

- metrics_df_from_yaml used to be get_df_metrics() in scripts/evals/discovery.py
- move df_metrics, df_metrics_10k, df_metrics_uniq_protos from preds/discovery.py to metrics/discovery.py
- add test_metrics_init.py to cover metrics_df_from_yaml()
- add test_discovery_preds.py to cover df_preds, df_each_pred, df_each_err
- update plot scripts to use YAML-read metrics instead of expensive on-the-fly calculated df_metrics{_test_subset}

* fix mypy errors and plotting scripts that were still importing models from matbench_discovery.preds.discovery.models
  • Loading branch information
janosh authored Feb 8, 2025
1 parent 012ccfe commit 426ff2b
Show file tree
Hide file tree
Showing 91 changed files with 1,580 additions and 1,144 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
fail-fast: false
matrix:
script:
- scripts/metrics/eval_discovery.py
- scripts/evals/discovery.py
# TODO run_all.py was commented out during removal of WBM energy preds from version control. consider partially reinstating with on-the-fly downloaded or mock WBM energy preds
# - scripts/model_figs/run_all.py
steps:
Expand Down
2 changes: 1 addition & 1 deletion data/mp/build_phase_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from tqdm import tqdm

from matbench_discovery import MP_DIR, ROOT, today
from matbench_discovery.data import DataFiles
from matbench_discovery.energy import get_e_form_per_atom, get_elemental_ref_entries
from matbench_discovery.enums import DataFiles

module_dir = os.path.dirname(__file__)

Expand Down
4 changes: 2 additions & 2 deletions data/mp/eda_mp_trj.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from tqdm import tqdm

from matbench_discovery import MP_DIR, PDF_FIGS, ROOT, SITE_FIGS
from matbench_discovery.data import DataFiles, ase_atoms_from_zip, df_wbm
from matbench_discovery.data import ase_atoms_from_zip, df_wbm
from matbench_discovery.energy import get_e_form_per_atom
from matbench_discovery.enums import MbdKey
from matbench_discovery.enums import DataFiles, MbdKey

__author__ = "Janosh Riebesell"
__date__ = "2023-11-22"
Expand Down
3 changes: 1 addition & 2 deletions data/mp/get_mp_energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
from tqdm import tqdm

from matbench_discovery import STABILITY_THRESHOLD, today
from matbench_discovery.data import DataFiles
from matbench_discovery.enums import MbdKey
from matbench_discovery.enums import DataFiles, MbdKey

__author__ = "Janosh Riebesell"
__date__ = "2023-01-10"
Expand Down
2 changes: 1 addition & 1 deletion data/phonons/phonondb_103_pbe_eda.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import moyopy.interface
import pymatviz as pmv

from matbench_discovery.data import DataFiles
from matbench_discovery.enums import DataFiles

__date__ = "2025-01-14"

Expand Down
3 changes: 2 additions & 1 deletion data/pmg_structs_to_ase_extxyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from tqdm import tqdm

from matbench_discovery import MP_DIR, WBM_DIR, today
from matbench_discovery.data import DataFiles, ase_atoms_to_zip
from matbench_discovery.data import ase_atoms_to_zip
from matbench_discovery.enums import DataFiles

__author__ = "Yuan Chiang, Janosh Riebesell"
__date__ = "2023-08-10"
Expand Down
3 changes: 2 additions & 1 deletion data/wbm/compare_cse_vs_ce_mp_2020_corrections.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
from tqdm import tqdm

from matbench_discovery import ROOT, today
from matbench_discovery.data import DataFiles, df_wbm
from matbench_discovery.data import df_wbm
from matbench_discovery.energy import get_e_form_per_atom
from matbench_discovery.enums import DataFiles

wbm_cse_path = DataFiles.wbm_computed_structure_entries.path
df_cse = pd.read_json(wbm_cse_path).set_index(Key.mat_id)
Expand Down
3 changes: 1 addition & 2 deletions data/wbm/compile_wbm_test_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
from tqdm import tqdm

from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR, today
from matbench_discovery.data import DataFiles
from matbench_discovery.energy import calc_energy_from_e_refs, mp_elemental_ref_energies
from matbench_discovery.enums import MbdKey
from matbench_discovery.enums import DataFiles, MbdKey

try:
import gdown
Expand Down
4 changes: 2 additions & 2 deletions data/wbm/eda_wbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

from matbench_discovery import PDF_FIGS, ROOT, SITE_FIGS, STABILITY_THRESHOLD
from matbench_discovery import plots as plots
from matbench_discovery.data import DataFiles, df_wbm
from matbench_discovery.data import df_wbm
from matbench_discovery.energy import mp_elem_ref_entries
from matbench_discovery.enums import MbdKey
from matbench_discovery.enums import DataFiles, MbdKey
from matbench_discovery.preds.discovery import df_each_err

__author__ = "Janosh Riebesell"
Expand Down
15 changes: 10 additions & 5 deletions matbench_discovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@

import plotly.express as px
import plotly.io as pio
import pymatviz as pmv # needed for pymatviz_dark template

from matbench_discovery.enums import MbdKey
import pymatviz as pmv # needed for pymatviz_dark template # noqa: F401

PKG_NAME = "matbench-discovery"
__version__ = "1.3.1"
Expand All @@ -23,6 +21,15 @@
SCRIPTS = f"{ROOT}/scripts" # model and date analysis scripts
PDF_FIGS = f"{ROOT}/paper/figs" # directory for light-themed PDF figures

# Directory used to cache downloaded data files. A set MBD_CACHE_DIR
# environment variable takes precedence over the fallback chosen below.
if os.path.isdir(DATA_DIR):
    # full repo was cloned: cache data files locally in DATA_DIR
    _fallback_cache_dir = DATA_DIR
else:
    # matbench-discovery was installed from PyPI: fall back to ~/.cache
    _fallback_cache_dir = os.path.expanduser("~/.cache/matbench-discovery")
DEFAULT_CACHE_DIR = os.getenv("MBD_CACHE_DIR", _fallback_cache_dir)

for directory in (SITE_FIGS, SITE_LIB, PDF_FIGS):
os.makedirs(directory, exist_ok=True)

Expand Down Expand Up @@ -59,8 +66,6 @@


# --- start global plot settings
px.defaults.labels |= {key.name: key.label for key in (*MbdKey, *pmv.enums.Key)}

global_layout = dict(
paper_bgcolor="rgba(0,0,0,0)",
font_size=13,
Expand Down
55 changes: 55 additions & 0 deletions matbench_discovery/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Central argument parser for Matbench Discovery scripts."""

import argparse

from pymatviz.enums import Key

from matbench_discovery.enums import Model, TestSubset

# Shared parser instance; every option below lives in the "plot" argument group.
cli_parser = argparse.ArgumentParser(
    description="CLI flags for plotting and analysis scripts."
)
plot_group = cli_parser.add_argument_group(
    title="plot", description="Arguments for controlling figure generation"
)

# --models: zero or more Model members, falling back to every model when the
# flag is not given at all.
plot_group.add_argument(
    "--models",
    default=list(Model),
    choices=Model,
    type=Model,  # type: ignore[arg-type]
    nargs="*",
    help="Models to analyze. If none specified, analyzes all models.",
)

# --test-subset: which slice of the WBM test set to evaluate on.
_test_subset_help = (
    "Which subset of the WBM test set to use for evaluation. Default is to only "
    "use unique Aflow protostructures. Training sets like MPtrj, sAlex and Omat24 "
    "were filtered to remove protostructures overlap with WBM, resulting in a "
    "slightly more out-of-distribution test set."
)
plot_group.add_argument(
    "--test-subset",
    default=TestSubset.uniq_protos,
    choices=list(TestSubset),
    type=TestSubset,
    help=_test_subset_help,
)

# --energy-type: parsed as a plain string restricted to the two listed keys.
plot_group.add_argument(
    "--energy-type",
    default=Key.each,
    choices=[Key.e_form, Key.each],
    type=str,
    help="Whether to use formation energy or convex hull distance.",
)

# Boolean switches (all default to False), registered in declaration order so
# the generated --help output keeps the same ordering.
_boolean_flags = {
    "--show-non-compliant": "Whether to show non-compliant models.",
    "--use-full-rows": "Whether to drop models that don't fit in complete rows.",
    "--update-existing": "Whether to update figures whose file paths already exist.",
}
for _flag, _flag_help in _boolean_flags.items():
    plot_group.add_argument(_flag, action="store_true", help=_flag_help)

# Parsing happens at import time. parse_known_args() is used instead of
# parse_args() so unrecognized sys.argv entries are collected rather than
# raising an error.
cli_args, _ignore_unknown = cli_parser.parse_known_args()
Loading

0 comments on commit 426ff2b

Please sign in to comment.