Skip to content

Commit c9dc8ee

Browse files
authored
Add pyright pre-commit hook and fix possibly unbound variables (#93)
* add pyright pre-commit hook * mv scripts/model_figs/(make_metrics_tables->metrics_tables).py mv scripts/model_figs/(make_hull_dist_box_plot->hull_dist_box_plot).py * fix most pyright PossiblyUnboundVariable * remove __init__.py convenience re-exports of enums * LabelEnum add new to_dict methods val_desc_dict + label_desc_dict renamed key_val_dict, val_label_dict
1 parent e596392 commit c9dc8ee

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+210
-145
lines changed

.github/workflows/test-scripts.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
fail-fast: false
1515
matrix:
1616
script:
17-
- scripts/model_figs/make_metrics_tables.py
17+
- scripts/model_figs/metrics_tables.py
1818
- scripts/model_figs/rolling_mae_vs_hull_dist_models.py
1919
steps:
2020
- name: Check out repository

.pre-commit-config.yaml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ repos:
5656
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$
5757

5858
- repo: https://github.com/pre-commit/mirrors-eslint
59-
rev: v9.0.0-beta.0
59+
rev: v9.0.0-beta.1
6060
hooks:
6161
- id: eslint
6262
types: [file]
@@ -78,3 +78,9 @@ repos:
7878
files: ^models/(.+)/\1.*\.yml$
7979
args: [--schemafile, tests/model-schema.yml]
8080
- id: check-github-actions
81+
82+
- repo: https://github.com/RobertCraigie/pyright-python
83+
rev: v1.1.351
84+
hooks:
85+
- id: pyright
86+
args: [--level, error]

data/mp/build_phase_diagram.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
from pymatgen.ext.matproj import MPRester
1818
from tqdm import tqdm
1919

20-
from matbench_discovery import MP_DIR, ROOT, Key, today
20+
from matbench_discovery import MP_DIR, ROOT, today
2121
from matbench_discovery.data import DATA_FILES
2222
from matbench_discovery.energy import get_e_form_per_atom, get_elemental_ref_entries
23+
from matbench_discovery.enums import Key
2324

2425
module_dir = os.path.dirname(__file__)
2526

data/mp/eda_mp_trj.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
from pymatviz.utils import si_fmt
2222
from tqdm import tqdm
2323

24-
from matbench_discovery import MP_DIR, PDF_FIGS, ROOT, SITE_FIGS, Key
24+
from matbench_discovery import MP_DIR, PDF_FIGS, ROOT, SITE_FIGS
2525
from matbench_discovery.data import DATA_FILES, df_wbm
26+
from matbench_discovery.enums import Key
2627

2728
__author__ = "Janosh Riebesell"
2829
__date__ = "2023-11-22"
@@ -108,10 +109,11 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
108109

109110
# %% plot per-element magmom histograms
110111
ptable_magmom_hist_path = f"{MP_DIR}/mp-trj-2022-09-elem-magmoms.json.bz2"
112+
srs_mp_trj_elem_magmoms = locals().get("srs_mp_trj_elem_magmoms")
111113

112114
if os.path.isfile(ptable_magmom_hist_path):
113115
srs_mp_trj_elem_magmoms = pd.read_json(ptable_magmom_hist_path, typ="series")
114-
elif "srs_mp_trj_elem_magmoms" not in locals():
116+
if srs_mp_trj_elem_magmoms is None:
115117
# project magmoms onto symbols in dict
116118
df_mp_trj_elem_magmom = pd.DataFrame(
117119
[
@@ -151,10 +153,11 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
151153

152154
# %% plot per-element force histograms
153155
ptable_force_hist_path = f"{MP_DIR}/mp-trj-2022-09-elem-forces.json.bz2"
156+
srs_mp_trj_elem_forces = locals().get("srs_mp_trj_elem_forces")
154157

155158
if os.path.isfile(ptable_force_hist_path):
156159
srs_mp_trj_elem_forces = pd.read_json(ptable_force_hist_path, typ="series")
157-
elif "srs_mp_trj_elem_forces" not in locals():
160+
if srs_mp_trj_elem_forces is None:
158161
df_mp_trj_elem_forces = pd.DataFrame(
159162
[
160163
dict(zip(elems, np.abs(forces).mean(axis=1)))
@@ -193,10 +196,11 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
193196

194197
# %% plot histogram of number of sites per element
195198
ptable_n_sites_hist_path = f"{MP_DIR}/mp-trj-2022-09-elem-n-sites.json.bz2"
199+
srs_mp_trj_elem_n_sites = locals().get("srs_mp_trj_elem_n_sites")
196200

197201
if os.path.isfile(ptable_n_sites_hist_path):
198202
srs_mp_trj_elem_n_sites = pd.read_json(ptable_n_sites_hist_path, typ="series")
199-
elif "mp_trj_elem_n_sites" not in locals():
203+
elif srs_mp_trj_elem_n_sites is None:
200204
# construct a series of lists of site numbers per element (i.e. how often each
201205
# element appears in a structure with n sites)
202206
# create all df cols as int dtype
@@ -320,8 +324,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
320324
pdf_kwds = dict(width=500, height=300)
321325

322326
x_col, y_col = "E<sub>form</sub> (eV/atom)", count_col
327+
df_e_form = locals().get("df_e_form")
323328

324-
if "df_e_form" not in locals(): # only compute once for speed
329+
if df_e_form is None: # only compute once for speed
325330
e_form_hist = np.histogram(df_mp_trj[Key.e_form], bins=300)
326331
df_e_form = pd.DataFrame(e_form_hist, index=[y_col, x_col]).T.round(3)
327332

@@ -340,8 +345,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
340345
# %% plot forces distribution
341346
# use numpy to pre-compute histogram
342347
x_col, y_col = "|Forces| (eV/Å)", count_col
348+
df_forces = locals().get("df_forces")
343349

344-
if "df_forces" not in locals(): # only compute once for speed
350+
if df_forces is None: # only compute once for speed
345351
forces_hist = np.histogram(
346352
df_mp_trj[Key.forces].explode().explode().abs(), bins=300
347353
)
@@ -361,8 +367,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
361367

362368
# %% plot hydrostatic stress distribution
363369
x_col, y_col = "1/3 Tr(σ) (eV/ų)", count_col # noqa: RUF001
370+
df_stresses = locals().get("df_stresses")
364371

365-
if "df_stresses" not in locals(): # only compute once for speed
372+
if df_stresses is None: # only compute once for speed
366373
stresses_hist = np.histogram(df_mp_trj[Key.stress_trace], bins=300)
367374
df_stresses = pd.DataFrame(stresses_hist, index=[y_col, x_col]).T.round(3)
368375

@@ -381,8 +388,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
381388

382389
# %% plot magmoms distribution
383390
x_col, y_col = "Magmoms (μ<sub>B</sub>)", count_col
391+
df_magmoms = locals().get("df_magmoms")
384392

385-
if "df_magmoms" not in locals(): # only compute once for speed
393+
if df_magmoms is None: # only compute once for speed
386394
magmoms_hist = np.histogram(df_mp_trj[Key.magmoms].dropna().explode(), bins=300)
387395
df_magmoms = pd.DataFrame(magmoms_hist, index=[y_col, x_col]).T.round(3)
388396

data/mp/get_mp_energies.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
from pymatviz.utils import annotate_metrics
99
from tqdm import tqdm
1010

11-
from matbench_discovery import STABILITY_THRESHOLD, Key, today
11+
from matbench_discovery import STABILITY_THRESHOLD, today
1212
from matbench_discovery.data import DATA_FILES
13+
from matbench_discovery.enums import Key
1314

1415
"""
1516
Download all MP formation and above hull energies on 2023-01-10.

data/mp/get_mp_traj.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from pymongo.database import Database
1818
from tqdm import tqdm, trange
1919

20-
from matbench_discovery import ROOT, Key, today
20+
from matbench_discovery import ROOT, today
21+
from matbench_discovery.enums import Key
2122

2223
__author__ = "Janosh Riebesell"
2324
__date__ = "2023-03-15"

data/wbm/compare_cse_vs_ce_mp_2020_corrections.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
1111
from tqdm import tqdm
1212

13-
from matbench_discovery import ROOT, Key, today
13+
from matbench_discovery import ROOT, today
1414
from matbench_discovery.data import DATA_FILES, df_wbm
1515
from matbench_discovery.energy import get_e_form_per_atom
16+
from matbench_discovery.enums import Key
1617
from matbench_discovery.plots import plt
1718

1819
"""

data/wbm/compile_wbm_test_set.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,19 @@
1919
from pymatviz.io import save_fig
2020
from tqdm import tqdm
2121

22-
from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR, Key, today
22+
from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR, today
2323
from matbench_discovery.data import DATA_FILES
2424
from matbench_discovery.energy import get_e_form_per_atom
25+
from matbench_discovery.enums import Key
2526

2627
try:
2728
import gdown
28-
except ImportError:
29-
print(
29+
except ImportError as exc:
30+
exc.add_note(
3031
"gdown not installed. Needed for downloading WBM initial + relaxed structures "
3132
"from Google Drive."
3233
)
34+
raise
3335

3436
"""
3537
Dataset generated with DFT and published in Jan 2021 as
@@ -90,8 +92,8 @@
9092
18198704957443186264,
9193
)
9294

93-
if "dfs_wbm_structs" not in locals():
94-
dfs_wbm_structs = {}
95+
dfs_wbm_structs = locals().get("dfs_wbm_structs", {})
96+
9597
for json_path in json_paths:
9698
step = int(json_path.split(".json.bz2")[0][-1])
9799
assert step in range(1, 6)
@@ -179,8 +181,8 @@ def increment_wbm_material_id(wbm_id: str) -> str:
179181
print(f"{file_path} already exists, skipping")
180182
continue
181183

184+
url = f"{mat_cloud_url}&filename={filename}"
182185
try:
183-
url = f"{mat_cloud_url}&filename={filename}"
184186
urllib.request.urlretrieve(url, file_path)
185187
except urllib.error.HTTPError as exc:
186188
print(f"failed to download {url=}: {exc}")

matbench_discovery/__init__.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,7 @@
1313
import plotly.io as pio
1414
import pymatviz # noqa: F401
1515

16-
from matbench_discovery.enums import ( # noqa: F401
17-
Key,
18-
Model,
19-
ModelType,
20-
Open,
21-
Quantity,
22-
Targets,
23-
Task,
24-
)
16+
from matbench_discovery.enums import Model, Quantity
2517

2618
PKG_NAME = "matbench-discovery"
2719
direct_url = Distribution.from_name(PKG_NAME).read_text("direct_url.json") or "{}"
@@ -69,17 +61,17 @@
6961
FIGSHARE_URLS = json.load(file)
7062

7163
# --- start global plot settings
72-
px.defaults.labels = Quantity.val_dict() | Model.val_dict()
64+
px.defaults.labels = Quantity.key_val_dict() | Model.key_val_dict()
7365

7466
global_layout = dict(
7567
paper_bgcolor="rgba(0,0,0,0)",
7668
font_size=13,
7769
# increase legend marker size and make background transparent
7870
legend=dict(itemsizing="constant", bgcolor="rgba(0, 0, 0, 0)"),
7971
)
80-
pio.templates["global"] = dict(layout=global_layout)
81-
pio.templates.default = "pymatviz_dark+global"
82-
px.defaults.template = "pymatviz_dark+global"
72+
pio.templates["mbd_global"] = dict(layout=global_layout)
73+
pio.templates.default = "pymatviz_dark+mbd_global"
74+
px.defaults.template = "pymatviz_dark+mbd_global"
8375

8476
# https://github.com/plotly/Kaleido/issues/122#issuecomment-994906924
8577
# when seeing MathJax "loading" message in exported PDFs,

matbench_discovery/data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
from monty.json import MontyDecoder
1515
from tqdm import tqdm
1616

17-
from matbench_discovery import FIGSHARE_DIR, Key
17+
from matbench_discovery import FIGSHARE_DIR
18+
from matbench_discovery.enums import Key
1819

1920
if TYPE_CHECKING:
2021
from pathlib import Path

matbench_discovery/enums.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,24 @@ def description(self) -> str:
2929
return self.__dict__["desc"]
3030

3131
@classmethod
32-
def val_dict(cls) -> dict[str, str]:
33-
"""Return the Enum as dictionary."""
32+
def key_val_dict(cls) -> dict[str, str]:
33+
"""Map of keys to values."""
3434
return {key: str(val) for key, val in cls.__members__.items()}
3535

3636
@classmethod
37-
def label_dict(cls) -> dict[str, str]:
38-
"""Return the Enum as dictionary."""
39-
return {str(val): val.label for key, val in cls.__members__.items()}
37+
def val_label_dict(cls) -> dict[str, str | None]:
38+
"""Map of values to labels."""
39+
return {str(val): val.label for val in cls.__members__.values()}
40+
41+
@classmethod
42+
def val_desc_dict(cls) -> dict[str, str | None]:
43+
"""Map of values to descriptions."""
44+
return {str(val): val.description for val in cls.__members__.values()}
45+
46+
@classmethod
47+
def label_desc_dict(cls) -> dict[str | None, str | None]:
48+
"""Map of labels to descriptions."""
49+
return {str(val.label): val.description for val in cls.__members__.values()}
4050

4151

4252
@unique

matbench_discovery/plots.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,10 @@ def hist_classified_stable_vs_hull_dist(
103103
fixed in Inkscape or similar by merging regions by color.
104104
"""
105105
x_col = dict(true=each_true_col, pred=each_pred_col)[which_energy]
106+
clf_col, value_name = "classified", "count"
106107

107108
df_plot = pd.DataFrame()
109+
108110
for facet, df_group in (
109111
df.groupby(kwargs["facet_col"]) if "facet_col" in kwargs else [(None, df)]
110112
):
@@ -113,16 +115,16 @@ def hist_classified_stable_vs_hull_dist(
113115
)
114116

115117
# switch between hist of DFT-computed and model-predicted convex hull distance
116-
e_above_hull = df_group[x_col]
117-
each_true_pos = e_above_hull[true_pos]
118-
each_true_neg = e_above_hull[true_neg]
119-
each_false_neg = e_above_hull[false_neg]
120-
each_false_pos = e_above_hull[false_pos]
118+
srs_each = df_group[x_col]
119+
each_true_pos = srs_each[true_pos]
120+
each_true_neg = srs_each[true_neg]
121+
each_false_neg = srs_each[false_neg]
122+
each_false_pos = srs_each[false_pos]
121123
# n_true_pos, n_false_pos, n_true_neg, n_false_neg = map(
122124
# sum, (true_pos, false_pos, true_neg, false_neg)
123125
# )
124126

125-
df_group[(clf_col := "classified")] = np.array(clf_labels)[
127+
df_group[clf_col] = np.array(clf_labels)[
126128
true_pos * 0 + false_neg * 1 + false_pos * 2 + true_neg * 3
127129
]
128130

@@ -144,7 +146,6 @@ def hist_classified_stable_vs_hull_dist(
144146
index=clf_labels,
145147
).T
146148
df_hist[x_col] = bin_edges[:-1]
147-
value_name = "count"
148149
df_melt = df_hist.melt(
149150
id_vars=x_col,
150151
value_vars=clf_labels,
@@ -714,6 +715,7 @@ def cumulative_metrics(
714715
)
715716
df = dfs[metric]
716717
ax.set(ylim=(0, 1), xlim=(0, None), ylabel=metric)
718+
bbox = dict(facecolor="white", alpha=0.5, edgecolor="none")
717719
for model in df_preds:
718720
# TODO is this really necessary?
719721
if len(df[model].dropna()) == 0:
@@ -722,7 +724,6 @@ def cumulative_metrics(
722724
y_end = df[model].dropna().iloc[-1]
723725
# add some visual guidelines to the plot
724726
intersect_kwargs = dict(linestyle=":", alpha=0.4, linewidth=2)
725-
bbox = dict(facecolor="white", alpha=0.5, edgecolor="none")
726727
# place model name at the end of every line
727728
ax.text(x_end, y_end, model, va="bottom", rotation=30, bbox=bbox)
728729
if "x" in project_end_point:

matbench_discovery/preds.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
import pandas as pd
66
from tqdm import tqdm
77

8-
from matbench_discovery import ROOT, STABILITY_THRESHOLD, Key, Model
8+
from matbench_discovery import ROOT, STABILITY_THRESHOLD, Model
99
from matbench_discovery.data import Files, df_wbm, glob_to_df
10+
from matbench_discovery.enums import Key
1011
from matbench_discovery.metrics import stable_metrics
1112
from matbench_discovery.plots import plotly_colors, plotly_line_styles, plotly_markers
1213

@@ -65,7 +66,7 @@ class PredFiles(Files):
6566

6667

6768
# key_map maps model keys to pretty labels
68-
PRED_FILES = PredFiles(root=f"{ROOT}/models", key_map=Model.val_dict())
69+
PRED_FILES = PredFiles(root=f"{ROOT}/models", key_map=Model.key_val_dict())
6970

7071

7172
def load_df_wbm_with_preds(
@@ -101,15 +102,14 @@ def load_df_wbm_with_preds(
101102
)
102103

103104
dfs: dict[str, pd.DataFrame] = {}
104-
105105
try:
106106
for model_name in (bar := tqdm(models, disable=not pbar, desc="Loading preds")):
107107
bar.set_postfix_str(model_name)
108108
df = glob_to_df(PRED_FILES[model_name], pbar=False, **kwargs)
109109
df = df.set_index(id_col)
110110
dfs[model_name] = df
111111
except Exception as exc:
112-
raise RuntimeError(f"Failed to load {model_name=}") from exc
112+
raise RuntimeError(f"Failed to load {locals().get('model_name')=}") from exc
113113

114114
from matbench_discovery.data import df_wbm
115115

models/alignn/test_alignn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
from sklearn.metrics import r2_score
1818
from tqdm import tqdm
1919

20-
from matbench_discovery import Key, Task, today
20+
from matbench_discovery import today
2121
from matbench_discovery.data import DATA_FILES, df_wbm
22+
from matbench_discovery.enums import Key, Task
2223
from matbench_discovery.plots import wandb_scatter
2324
from matbench_discovery.slurm import slurm_submit
2425

0 commit comments

Comments
 (0)