Skip to content

Commit

Permalink
add color bar labels to MP/WBM/MPtrj ptable element occurrence heatmaps
Browse files Browse the repository at this point in the history
tweak plot scripts
update site deps, esp. elementari to fix black text on ptable element tiles missing data
  • Loading branch information
janosh committed Jan 15, 2024
1 parent 12e8477 commit e7d57ee
Show file tree
Hide file tree
Showing 36 changed files with 139 additions and 86 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.9
rev: v0.1.13
hooks:
- id: ruff
args: [--fix]
Expand Down Expand Up @@ -56,7 +56,7 @@ repos:
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yaml|json)|changelog.md)$

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v8.56.0
rev: v9.0.0-alpha.0
hooks:
- id: eslint
types: [file]
Expand Down
47 changes: 36 additions & 11 deletions data/wbm/eda_wbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,21 @@
spacegroup_sunburst,
)
from pymatviz.io import save_fig
from pymatviz.utils import si_fmt
from pymatviz.utils import si_fmt, si_fmt_int

from matbench_discovery import (
PDF_FIGS,
ROOT,
SITE_FIGS,
STABILITY_THRESHOLD,
e_form_raw_col,
formula_col,
id_col,
)
from matbench_discovery import plots as plots
from matbench_discovery.data import DATA_FILES, df_wbm
from matbench_discovery.energy import mp_elem_reference_entries
from matbench_discovery.preds import df_each_err, each_true_col
from matbench_discovery.preds import df_each_err, e_form_col, each_true_col

__author__ = "Janosh Riebesell"
__date__ = "2023-03-30"
Expand Down Expand Up @@ -141,8 +142,8 @@

# %% histogram of energy distance to MP convex hull for WBM
e_col = each_true_col # or e_form_col
e_col = "e_form_per_atom_uncorrected"
e_col = "e_form_per_atom_mp2020_corrected"
# e_col = e_form_raw_col
# e_col = e_form_col
mean, std = df_wbm[e_col].mean(), df_wbm[e_col].std()

range_x = (mean - 2 * std, mean + 2 * std)
Expand Down Expand Up @@ -170,9 +171,28 @@
dummy_mae = (df_wbm[e_col] - df_wbm[e_col].mean()).abs().mean()

title = (
f"{len(df_wbm.dropna()):,} structures with {n_stable:,} stable + {n_unstable:,}"
f"{si_fmt_int(len(df_wbm.dropna()))} structures with {si_fmt_int(n_stable)} "
f"stable + {si_fmt_int(n_unstable)} unstable (stable rate="
f"{n_stable / len(df_wbm):.1%})"
)
fig.layout.title = dict(text=title, x=0.5)
fig.layout.title = dict(text=title, x=0.5, font_size=16, y=0.95)

# add red/blue annotations to left and right of mean saying stable/unstable
for idx, (label, x_pos) in enumerate(
(("stable", mean - std), ("unstable", mean + std))
):
fig.add_annotation(
x=x_pos,
y=0.5,
text=label,
showarrow=False,
font_size=18,
font_color=px.colors.qualitative.Plotly[idx],
yref="paper",
xanchor="right",
xshift=-40,
)


fig.layout.margin = dict(l=0, r=0, b=0, t=40)
fig.update_layout(showlegend=False)
Expand All @@ -183,14 +203,19 @@
(mean + std, f"{mean + std = :.2f}"),
):
anno = dict(text=label, yshift=-10, xshift=-5, xanchor="right")
line_width = 1 if x_pos == mean else 0.5
line_width = 3 if x_pos == mean else 2
fig.add_vline(x=x_pos, line=dict(width=line_width, dash="dash"), annotation=anno)

fig.show()

save_fig(fig, f"{SITE_FIGS}/hist-wbm-hull-dist.svelte")
# save_fig(fig, "./figs/hist-wbm-hull-dist.svg", width=1000, height=500)
save_fig(fig, f"{PDF_FIGS}/hist-wbm-hull-dist.pdf")
suffix = {
each_true_col: "hull-dist",
e_form_col: "e-form",
e_form_raw_col: "e-form-uncorrected",
}[e_col]
img_name = f"hist-wbm-{suffix}"
save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
# save_fig(fig, f"./figs/{img_name}.svg", width=800, height=500)
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=600, height=300)


# %%
Expand Down
31 changes: 18 additions & 13 deletions data/wbm/fetch_process_wbm_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@
from pymatviz.io import save_fig
from tqdm import tqdm

from matbench_discovery import PDF_FIGS, SITE_FIGS, formula_col, id_col, today
from matbench_discovery import (
PDF_FIGS,
SITE_FIGS,
e_form_raw_col,
formula_col,
id_col,
today,
)
from matbench_discovery.data import DATA_FILES
from matbench_discovery.energy import get_e_form_per_atom

Expand All @@ -39,7 +46,7 @@


module_dir = os.path.dirname(__file__)
e_form_col = "e_form_per_atom_wbm"
e_form_wbm_col = "e_form_per_atom_wbm"


# %% links to google drive files received via email from 1st author Hai-Chen Wang
Expand Down Expand Up @@ -296,7 +303,7 @@ def increment_wbm_material_id(wbm_id: str) -> str:
"nsites": "n_sites",
"vol": "volume",
"e": "uncorrected_energy",
"e_form": e_form_col,
"e_form": e_form_wbm_col,
"e_hull": "e_above_hull_wbm",
"gap": "bandgap_pbe",
"id": id_col,
Expand Down Expand Up @@ -440,15 +447,15 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:

# %% remove suspicious formation energy outliers
e_form_cutoff = 5
n_too_stable = sum(df_summary[e_form_col] < -e_form_cutoff)
n_too_stable = sum(df_summary[e_form_wbm_col] < -e_form_cutoff)
print(f"{n_too_stable = }") # n_too_stable = 502
n_too_unstable = sum(df_summary[e_form_col] > e_form_cutoff)
n_too_unstable = sum(df_summary[e_form_wbm_col] > e_form_cutoff)
print(f"{n_too_unstable = }") # n_too_unstable = 22

e_form_hist, e_form_bins = np.histogram(
df_summary[e_form_col], bins=300, range=(-5.5, 5.5)
df_summary[e_form_wbm_col], bins=300, range=(-5.5, 5.5)
)
x_label = {e_form_col: "WBM uncorrected formation energy (eV/atom)"}[e_form_col]
x_label = {e_form_wbm_col: "WBM uncorrected formation energy (eV/atom)"}[e_form_wbm_col]
fig = px.bar(
x=e_form_bins[:-1], # [:-1] to drop last bin edge which is not needed
y=e_form_hist,
Expand Down Expand Up @@ -485,7 +492,7 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
# %%
assert len(df_summary) == len(df_wbm) == 257_487

query_str = f"{-e_form_cutoff} < {e_form_col} < {e_form_cutoff}"
query_str = f"{-e_form_cutoff} < {e_form_wbm_col} < {e_form_cutoff}"
dropped_ids = sorted(set(df_summary.index) - set(df_summary.query(query_str).index))
assert len(dropped_ids) == 502 + 22
assert dropped_ids[:3] == "wbm-1-12142 wbm-1-12143 wbm-1-12144".split()
Expand Down Expand Up @@ -569,8 +576,6 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
# first make sure source and target dfs have matching indices
assert sum(df_wbm.index != df_summary.index) == 0

e_form_col = "e_form_per_atom_uncorrected"

for row in tqdm(df_wbm.itertuples(), total=len(df_wbm)):
mat_id, cse, formula = row.Index, row.cse, row.formula_from_cse
assert mat_id == cse.entry_id, f"{mat_id=} != {cse.entry_id=}"
Expand All @@ -585,11 +590,11 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
assert (
abs(e_form - e_form_ppd) < 1e-4
), f"{mat_id}: {e_form=:.3} != {e_form_ppd=:.3} (diff={e_form - e_form_ppd:.3}))"
df_summary.loc[cse.entry_id, e_form_col] = e_form
df_summary.loc[cse.entry_id, e_form_raw_col] = e_form


df_summary[e_form_col.replace("uncorrected", "mp2020_corrected")] = (
df_summary[e_form_col] + df_summary["e_correction_per_atom_mp2020"]
df_summary[e_form_raw_col.replace("uncorrected", "mp2020_corrected")] = (
df_summary[e_form_raw_col] + df_summary["e_correction_per_atom_mp2020"]
)


Expand Down
2 changes: 1 addition & 1 deletion data/wbm/figs/hist-wbm-hull-dist.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 3 additions & 3 deletions data/wbm/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,18 @@ The number of stable materials (according to the MP convex hull which is spanned

The WBM test set and even more so the MP training set are heavily oxide dominated. The WBM test set is about 75% larger than the MP training set and also more chemically diverse, containing a higher fraction of transition metals, post-transition metals and metalloids. Our goal in picking such a large diverse test set is future-proofing. Ideally, this data will provide a challenging materials discovery test bed even for large foundational ML models in the future.

Below: Element counts for WBM test set consisting of 256,963 WBM `ComputedStructureEntries`

<slot name="wbm-elements-heatmap">
<img src="./figs/wbm-elements.svg" alt="Periodic table log heatmap of WBM elements">
</slot>

Below: Element counts for MP training set consisting of 154,719 `ComputedStructureEntries`
The WBM test set consists of 256,963 WBM `ComputedStructureEntries`

<slot name="mp-elements-heatmap">
<img src="./figs/mp-elements.svg" alt="Periodic table log heatmap of MP elements">
</slot>

The MP training set consists of 154,719 `ComputedStructureEntries`

<slot name="mp-trj-elements-heatmap" />

## 📊 &thinsp; Symmetry Statistics
Expand Down
1 change: 1 addition & 0 deletions matbench_discovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
init_struct_col = "initial_structure"
struct_col = "structure"
e_form_col = "formation_energy_per_atom"
e_form_raw_col = "e_form_per_atom_uncorrected"
formula_col = "formula"
stress_col = "stress"
stress_trace_col = "stress_trace"
Expand Down
8 changes: 3 additions & 5 deletions matbench_discovery/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,8 @@ def rolling_mae_vs_hull_dist(
y=(1, dft_acc, dft_acc, 1) if show_dft_acc else (1, 0, 1),
name=triangle_anno,
fillcolor="red",
# remove triangle border
line=dict(color="rgba(0,0,0,0)"),
**scatter_kwds,
)
fig.add_annotation(
Expand Down Expand Up @@ -535,14 +537,10 @@ def rolling_mae_vs_hull_dist(

from matbench_discovery.preds import model_styles

for idx, trace in enumerate(fig.data):
for trace in fig.data:
if style := model_styles.get(trace.name):
ls, _marker, color = style
trace.line = dict(color=color, dash=ls, width=2)
else:
trace.line = dict(
color=plotly_colors[idx], dash=plotly_line_styles[idx], width=3
)
# marker_spacing = 2
# fig.add_scatter(
# x=trace.x[::marker_spacing],
Expand Down
6 changes: 3 additions & 3 deletions matbench_discovery/preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
each_pred_col = "e_above_hull_pred"
model_mean_each_col = "Mean prediction all models"
model_mean_err_col = "Mean error all models"
model_std_col = "Std. dev. over models"
model_std_each_col = "Std. dev. over models"

for col in (model_mean_each_col, model_mean_err_col, model_std_col):
for col in (model_mean_each_col, model_mean_err_col, model_std_each_col):
quantity_labels[col] = f"{col} {ev_per_atom}"


Expand Down Expand Up @@ -211,7 +211,7 @@ def load_df_wbm_with_preds(
)

# important: do df_each_pred.std(axis=1) before inserting model_mean_each_col into df
df_preds[model_std_col] = df_each_pred.std(axis=1)
df_preds[model_std_each_col] = df_each_pred.std(axis=1)
df_each_pred[model_mean_each_col] = df_preds[model_mean_each_col] = df_each_pred.mean(
axis=1
)
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
name = "matbench-discovery"
version = "1.0.0"
description = "A benchmark for machine learning energy models on inorganic crystal stability prediction from unrelaxed structures"
authors = [{ name = "Janosh Riebesell", email = "janosh@lbl.gov" }]
authors = [{ name = "Janosh Riebesell", email = "janosh.riebesell@gmail.com" }]
readme = "readme.md"
license = { file = "license" }
keywords = [
Expand Down Expand Up @@ -60,7 +60,7 @@ running-models = [

"aviary@git+https://github.com/CompRhys/aviary",
"m3gnet",
"mace@git+https://github.com/ACEsuit/mace",
"mace-torch",
"maml",
"megnet",
]
Expand Down
2 changes: 1 addition & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<h4 align="center" class="toc-exclude">

[![arXiv](https://img.shields.io/badge/arXiv-2308.14920-blue)](https://arxiv.org/abs/2308.14920)
[![arXiv](https://img.shields.io/badge/arXiv-2308.14920-blue?logo=arxiv&logoColor=white)](https://arxiv.org/abs/2308.14920)
[![Tests](https://github.com/janosh/matbench-discovery/actions/workflows/test.yml/badge.svg)](https://github.com/janosh/matbench-discovery/actions/workflows/test.yml)
[![GitHub Pages](https://github.com/janosh/matbench-discovery/actions/workflows/gh-pages.yml/badge.svg)](https://github.com/janosh/matbench-discovery/actions/workflows/gh-pages.yml)
[![Requires Python 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)
Expand Down
12 changes: 7 additions & 5 deletions scripts/model_figs/analyze_model_disagreement.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
each_true_col,
model_mean_each_col,
model_mean_err_col,
model_std_col,
model_std_each_col,
)

__author__ = "Janosh Riebesell"
Expand Down Expand Up @@ -56,8 +56,7 @@
fig = df_plot.plot.scatter(
x=each_true_col,
y=model_mean_each_col,
color=model_std_col,
size="n_sites",
color=model_std_each_col,
backend="plotly",
hover_name=id_col,
hover_data=[formula_col],
Expand All @@ -71,11 +70,14 @@
fig.layout.coloraxis.colorbar.update(title_side="right", thickness=14)
fig.layout.margin.update(l=0, r=30, b=0, t=60)
add_identity_line(fig)
label = {"all": "structures"}.get(material_cls, material_cls)
fig.layout.title.update(
text=f"{n_structs} largest {material_cls} model errors: Predicted vs.<br>"
"DFT hull distance colored by model disagreement",
text=f"{n_structs} {material_cls} with largest hull distance errors<br>"
"colored by model disagreement, sized by number of sites",
x=0.5,
)
# size markers by structure
fig.data[0].marker.size = df_plot["n_sites"] ** 0.5 * 3
# tried setting error_y=model_std_col but looks bad
# fig.update_traces(
# error_y=dict(color="rgba(255,255,255,0.2)", width=3, thickness=2)
Expand Down
19 changes: 10 additions & 9 deletions scripts/model_figs/parity_energy_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
df_melt,
bin_by_cols=[e_true_col, e_pred_col],
group_by_cols=[facet_col],
n_bins=200,
n_bins=300,
bin_counts_col=(bin_cnt_col := "bin counts"),
)
df_bin = df_bin.reset_index()
Expand Down Expand Up @@ -162,7 +162,8 @@
# pick from https://plotly.com/python/builtin-colorscales
color_continuous_scale="agsunset",
)

# decrease marker size
fig.update_traces(marker=dict(size=2))
# manually set colorbar ticks and labels (needed after log1p transform)
tick_vals = [1, 10, 100, 1000, 10_000]
fig.layout.coloraxis.colorbar.update(
Expand All @@ -181,9 +182,8 @@
assert model in df_preds, f"Unexpected {model=} not in {list(df_preds)=}"
# add MAE and R2 to subplot titles
MAE, R2 = df_metrics[model][["MAE", "R2"]]
fig.layout.annotations[
idx - 1
].text = f"{model} · {MAE=:.2f} · R<sup>2</sup>={R2:.2f}"
sub_title = f"{model} · {MAE=:.2f} · R<sup>2</sup>={R2:.2f}"
fig.layout.annotations[idx - 1].text = sub_title

# remove subplot x and y axis titles
fig.layout[f"xaxis{idx}"].title.text = ""
Expand Down Expand Up @@ -222,7 +222,7 @@
yshift=-15 * sign_y,
text=label,
showarrow=False,
font=dict(size=16, color=color),
font=dict(size=14, color=color),
row="all",
col="all",
)
Expand All @@ -245,9 +245,10 @@
# fig.update_layout(yaxis=dict(scaleanchor="x", scaleratio=1))

axis_titles = dict(xref="paper", yref="paper", showarrow=False)
portrait = n_rows > n_cols
fig.add_annotation( # x-axis title
x=0.5,
y=-0.06,
y=-0.06 if portrait else -0.18,
text=x_title,
**axis_titles,
)
Expand All @@ -259,10 +260,10 @@
**axis_titles,
)

fig.layout.height = 230 * n_rows
fig.layout.update(height=230 * n_rows, width=180 * n_cols)
fig.layout.coloraxis.colorbar.update(orientation="h", thickness=9, len=0.5, y=1.05)
# fig.layout.width = 1100
fig.layout.margin.update(l=40, r=10, t=30, b=60)
fig.layout.margin.update(l=40, r=10, t=30 if portrait else 10, b=60 if portrait else 10)
fig.update_xaxes(matches=None)
fig.update_yaxes(matches=None)
fig.show()
Expand Down
Loading

0 comments on commit e7d57ee

Please sign in to comment.