Skip to content

Commit 9d1e77e

Browse files
committed
link MPtrj dataset from /contribute page "Direct Download" section
update MACE readme for 16M MPtrj checkpoint from pbenner define formula_col to ensure consistency across code base
1 parent 4ce353b commit 9d1e77e

22 files changed

+90
-68
lines changed

data/wbm/compare_cse_vs_ce_mp_2020_corrections.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
1111
from tqdm import tqdm
1212

13-
from matbench_discovery import ROOT, id_col, today
13+
from matbench_discovery import ROOT, formula_col, id_col, today
1414
from matbench_discovery.data import DATA_FILES, df_wbm
1515
from matbench_discovery.energy import get_e_form_per_atom
1616
from matbench_discovery.plots import plt
@@ -68,7 +68,9 @@
6868

6969

7070
# %%
71-
df_wbm["chem_sys"] = df_wbm.formula.str.replace("[0-9]+", "", regex=True).str.split()
71+
df_wbm["chem_sys"] = (
72+
df_wbm[formula_col].str.replace("[0-9]+", "", regex=True).str.split()
73+
)
7274
df_wbm["anion"] = None
7375
df_wbm["anion"][df_wbm.chem_sys.astype(str).str.contains("'O'")] = "oxide"
7476
df_wbm["anion"][df_wbm.chem_sys.astype(str).str.contains("'S'")] = "sulfide"

data/wbm/eda.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,14 @@
1212
)
1313
from pymatviz.io import save_fig
1414

15-
from matbench_discovery import PDF_FIGS, ROOT, SITE_FIGS, STABILITY_THRESHOLD, id_col
15+
from matbench_discovery import (
16+
PDF_FIGS,
17+
ROOT,
18+
SITE_FIGS,
19+
STABILITY_THRESHOLD,
20+
formula_col,
21+
id_col,
22+
)
1623
from matbench_discovery import plots as plots
1724
from matbench_discovery.data import DATA_FILES, df_wbm
1825
from matbench_discovery.energy import mp_elem_reference_entries
@@ -35,8 +42,10 @@
3542

3643

3744
# %%
38-
wbm_occu_counts = count_elements(df_wbm.formula, count_mode="occurrence").astype(int)
39-
wbm_comp_counts = count_elements(df_wbm.formula, count_mode="composition")
45+
wbm_occu_counts = count_elements(df_wbm[formula_col], count_mode="occurrence").astype(
46+
int
47+
)
48+
wbm_comp_counts = count_elements(df_wbm[formula_col], count_mode="composition")
4049

4150
mp_occu_counts = count_elements(df_mp.formula_pretty, count_mode="occurrence").astype(
4251
int
@@ -60,16 +69,16 @@
6069
df_wbm["step"] = df_wbm.index.str.split("-").str[1].astype(int)
6170
assert df_wbm.step.between(1, 5).all()
6271
for batch in range(1, 6):
63-
count_elements(df_wbm[df_wbm.step == batch].formula).to_json(
72+
count_elements(df_wbm[df_wbm.step == batch][formula_col]).to_json(
6473
f"{data_page}/wbm-element-counts-{batch=}.json"
6574
)
6675

6776
# export element counts by arity (how many elements in the formula)
6877
comp_col = "composition"
69-
df_wbm[comp_col] = df_wbm.formula.map(Composition)
78+
df_wbm[comp_col] = df_wbm[formula_col].map(Composition)
7079

7180
for arity, df_mp in df_wbm.groupby(df_wbm[comp_col].map(len)):
72-
count_elements(df_mp.formula).to_json(
81+
count_elements(df_mp[formula_col]).to_json(
7382
f"{data_page}/wbm-element-counts-{arity=}.json"
7483
)
7584

@@ -206,7 +215,7 @@
206215
y="2d t-SNE 2",
207216
color=color_col,
208217
hover_name=id_col,
209-
hover_data=("formula", each_true_col),
218+
hover_data=(formula_col, each_true_col),
210219
range_color=(0, clr_range_max),
211220
)
212221
fig.show()
@@ -219,7 +228,7 @@
219228
y="3d t-SNE 2",
220229
z="3d t-SNE 3",
221230
color=color_col,
222-
custom_data=[id_col, "formula", each_true_col, color_col],
231+
custom_data=[id_col, formula_col, each_true_col, color_col],
223232
range_color=(0, clr_range_max),
224233
)
225234
fig.data[0].hovertemplate = (

data/wbm/fetch_process_wbm_dataset.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pymatviz.io import save_fig
1919
from tqdm import tqdm
2020

21-
from matbench_discovery import SITE_FIGS, id_col, today
21+
from matbench_discovery import SITE_FIGS, formula_col, id_col, today
2222
from matbench_discovery.data import DATA_FILES
2323
from matbench_discovery.energy import get_e_form_per_atom
2424
from matbench_discovery.plots import pio
@@ -289,7 +289,7 @@ def increment_wbm_material_id(wbm_id: str) -> str:
289289

290290
# %%
291291
col_map = {
292-
"# comp": "formula",
292+
"# comp": formula_col,
293293
"nsites": "n_sites",
294294
"vol": "volume",
295295
"e": "uncorrected_energy",
@@ -319,7 +319,7 @@ def increment_wbm_material_id(wbm_id: str) -> str:
319319

320320
assert sum(no_id_mask := df_summary.index.isna()) == 6, f"{sum(no_id_mask)=}"
321321
# the 'None' materials have 0 volume, energy, n_sites, bandgap, etc.
322-
assert all(df_summary[no_id_mask].drop(columns=["formula"]) == 0)
322+
assert all(df_summary[no_id_mask].drop(columns=[formula_col]) == 0)
323323
assert len(df_summary.query("volume > 0")) == len(df_wbm) + len(nan_init_structs_ids)
324324
# make sure dropping materials with 0 volume removes exactly 6 materials, the same ones
325325
# listed in bad_struct_ids above
@@ -378,13 +378,13 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
378378

379379
# sort formulas alphabetically
380380
df_summary["alph_formula"] = [
381-
Composition(x).alphabetical_formula for x in df_summary.formula
381+
Composition(x).alphabetical_formula for x in df_summary[formula_col]
382382
]
383383
# alphabetical formula and original formula differ due to spaces, number 1 after element
384384
# symbols (FeO vs Fe1 O1), and element order (FeO vs OFe)
385-
assert sum(df_summary.alph_formula != df_summary.formula) == 257_483
385+
assert sum(df_summary.alph_formula != df_summary[formula_col]) == 257_483
386386

387-
df_summary["formula"] = df_summary.pop("alph_formula")
387+
df_summary[formula_col] = df_summary.pop("alph_formula")
388388

389389

390390
# %% write initial structures and computed structure entries to compressed json
@@ -404,10 +404,10 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
404404
# df_summary and df_wbm formulas differ because summary formulas are reduced while
405405
# df_wbm formulas are not (e.g. Ac6 U2 vs Ac3 U1 in summary). unreduced is more
406406
# informative so we use it.
407-
assert sum(df_summary.formula != df_wbm.formula_from_cse) == 114_273
408-
assert sum(df_summary.formula == df_wbm.formula_from_cse) == 143_214
407+
assert sum(df_summary[formula_col] != df_wbm.formula_from_cse) == 114_273
408+
assert sum(df_summary[formula_col] == df_wbm.formula_from_cse) == 143_214
409409

410-
df_summary.formula = df_wbm.formula_from_cse
410+
df_summary[formula_col] = df_wbm.formula_from_cse
411411

412412

413413
# fix bad energy which is 0 in df_summary but a more realistic -63.68 in CSE

matbench_discovery/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@
3737
init_struct_col = "initial_structure"
3838
struct_col = "structure"
3939
e_form_col = "formation_energy_per_atom"
40+
formula_col = "formula"

models/chgnet/analyze_chgnet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pymatviz import density_scatter, plot_structure_2d, ptable_heatmap_plotly
1111
from pymatviz.io import save_fig
1212

13-
from matbench_discovery import PDF_FIGS, id_col
13+
from matbench_discovery import PDF_FIGS, formula_col, id_col
1414
from matbench_discovery import plots as plots
1515
from matbench_discovery.data import DATA_FILES, df_wbm
1616
from matbench_discovery.preds import PRED_FILES
@@ -26,7 +26,7 @@
2626
df_chgnet_v020 = pd.read_csv(
2727
f"{module_dir}/2023-03-06-chgnet-0.2.0-wbm-IS2RE.csv.gz", index_col=id_col
2828
)
29-
df_chgnet["formula"] = df_wbm.formula
29+
df_chgnet[formula_col] = df_wbm[formula_col]
3030

3131
e_form_2000 = "e_form_per_atom_chgnet_relax_steps_2000"
3232
e_form_500 = "e_form_per_atom_chgnet_relax_steps_500"
@@ -51,15 +51,15 @@
5151
x=e_form_500,
5252
y=e_form_2000,
5353
hover_name=id_col,
54-
hover_data=["formula"],
54+
hover_data=[formula_col],
5555
backend="plotly",
5656
title=f"{len(df_diff)} structures have > {min_e_diff} eV/atom energy diff after "
5757
"longer relaxation",
5858
)
5959

6060

6161
# %%
62-
fig = ptable_heatmap_plotly(df_bad.formula)
62+
fig = ptable_heatmap_plotly(df_bad[formula_col])
6363
title = "structures with larger error<br>after longer relaxation"
6464
fig.layout.title.update(text=f"{len(df_diff)} {title}", x=0.4, y=0.9)
6565
fig.show()

models/chgnet/ctk_structure_viewer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pandas as pd
44
from crystal_toolkit.helpers.utils import hook_up_fig_with_struct_viewer
55

6-
from matbench_discovery import id_col
6+
from matbench_discovery import formula_col, id_col
77
from matbench_discovery.preds import PRED_FILES
88

99
__author__ = "Janosh Riebesell"
@@ -47,7 +47,7 @@
4747
y=e_form_2000,
4848
backend="plotly",
4949
hover_name=id_col,
50-
hover_data=["formula"],
50+
hover_data=[formula_col],
5151
labels=plot_labels,
5252
size=e_form_abs_diff,
5353
color=e_form_abs_diff,

models/chgnet/join_chgnet_results.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from pymatviz import density_scatter
1414
from tqdm import tqdm
1515

16-
from matbench_discovery import id_col
16+
from matbench_discovery import formula_col, id_col
1717
from matbench_discovery.data import as_dict_handler
1818
from matbench_discovery.energy import get_e_form_per_atom
1919
from matbench_discovery.preds import df_preds, e_form_col
@@ -54,11 +54,11 @@
5454

5555
# %% compute corrected formation energies
5656
e_form_chgnet_col = "e_form_per_atom_chgnet"
57-
df_chgnet["formula"] = df_preds.formula
57+
df_chgnet[formula_col] = df_preds[formula_col]
5858
df_chgnet[e_form_chgnet_col] = [
5959
get_e_form_per_atom(dict(energy=ene, composition=formula))
6060
for formula, ene in tqdm(
61-
df_chgnet.set_index("formula").chgnet_energy.items(), total=len(df_chgnet)
61+
df_chgnet.set_index(formula_col).chgnet_energy.items(), total=len(df_chgnet)
6262
)
6363
]
6464
df_preds[e_form_chgnet_col] = df_chgnet[e_form_chgnet_col]

models/chgnet/test_chgnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pymatgen.core import Structure
2222
from tqdm import tqdm
2323

24-
from matbench_discovery import id_col, timestamp, today
24+
from matbench_discovery import formula_col, id_col, timestamp, today
2525
from matbench_discovery.data import DATA_FILES, as_dict_handler, df_wbm
2626
from matbench_discovery.plots import wandb_scatter
2727
from matbench_discovery.slurm import slurm_submit
@@ -125,7 +125,7 @@
125125
df_wbm[e_pred_col] = df_out[e_pred_col]
126126
table = wandb.Table(
127127
dataframe=df_wbm.dropna()[
128-
["uncorrected_energy", e_pred_col, "formula"]
128+
["uncorrected_energy", e_pred_col, formula_col]
129129
].reset_index()
130130
)
131131

models/mace/analyze_mace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pymatviz import density_scatter, ptable_heatmap_plotly, spacegroup_sunburst
99
from pymatviz.io import save_fig
1010

11-
from matbench_discovery import id_col
11+
from matbench_discovery import formula_col, id_col
1212
from matbench_discovery import plots as plots
1313
from matbench_discovery.data import df_wbm
1414
from matbench_discovery.preds import PRED_FILES
@@ -44,7 +44,7 @@
4444

4545

4646
# %%
47-
fig = ptable_heatmap_plotly(df_low.formula)
47+
fig = ptable_heatmap_plotly(df_low[formula_col])
4848
title = f"Elements in {len(df_low):,} MACE severe energy underpredictions"
4949
fig.layout.title.update(text=title, x=0.4, y=0.95)
5050
fig.show()

models/mace/join_mace_results.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pymatviz import density_scatter
1717
from tqdm import tqdm
1818

19-
from matbench_discovery import id_col
19+
from matbench_discovery import formula_col, id_col
2020
from matbench_discovery.data import DATA_FILES, as_dict_handler, df_wbm
2121
from matbench_discovery.energy import get_e_form_per_atom
2222
from matbench_discovery.preds import e_form_col
@@ -80,11 +80,11 @@
8080

8181
# %% compute corrected formation energies
8282
e_form_mace_col = "e_form_per_atom_mace"
83-
df_mace["formula"] = df_wbm.formula
83+
df_mace[formula_col] = df_wbm[formula_col]
8484
df_mace[e_form_mace_col] = [
8585
get_e_form_per_atom(dict(energy=cse.energy, composition=formula))
8686
for formula, cse in tqdm(
87-
df_mace.set_index("formula")[entry_col].items(), total=len(df_mace)
87+
df_mace.set_index(formula_col)[entry_col].items(), total=len(df_mace)
8888
)
8989
]
9090
df_wbm[e_form_mace_col] = df_mace[e_form_mace_col]
@@ -106,6 +106,6 @@
106106
df_bad[e_form_col] = df_wbm[e_form_col]
107107
df_bad.to_csv(f"{out_path}-bad.csv")
108108

109-
# in_path = f"{module_dir}/2023-08-14-mace-wbm-IS2RE-FIRE"
109+
# in_path = f"{module_dir}/2023-11-02-mace-wbm-IS2RE-FIRE"
110110
# df_mace = pd.read_csv(f"{in_path}.csv.gz").set_index(id_col)
111111
# df_mace = pd.read_json(f"{in_path}.json.gz").set_index(id_col)

models/mace/json_to_extxyz.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
"""This script converts the MPTrj relaxation trajectories from JSON to
2-
extended XYZ format. The JSON data was downloaded from
3-
https://figshare.com/articles/dataset/23713842.
1+
"""This script converts the MPTrj relaxation trajectories downloaded from
2+
https://figshare.com/articles/dataset/23713842 from JSON to extended XYZ format.
43
"""
54

65
import json

models/mace/readme.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
## MACE formation energy predictions on WBM test set
22

3-
This submission uses the [`2023-08-14-mace-yuan-trained-mptrj-04.model`](https://figshare.com/ndownloader/files/42374049) checkpoint trained by Yuan Chiang on the [MPtrj dataset](https://figshare.com/articles/dataset/23713842).
3+
The original MACE submission used the 2M parameter checkpoint [`2023-08-14-mace-yuan-trained-mptrj-04.model`](https://figshare.com/ndownloader/files/42374049) trained by Yuan Chiang on the [MPtrj dataset](https://figshare.com/articles/dataset/23713842).
44
We initially tested the `2023-07-14-mace-universal-2-big-128-6.model` checkpoint trained on the much smaller [original M3GNet training set](https://figshare.com/articles/dataset/MPF_2021_2_8/19470599) which we received directly from Ilyes Batatia. MPtrj-trained MACE performed better and was used for the Matbench Discovery v1 submission.
55

6+
In late October (received 2023-10-29), Philipp Benner trained a much larger 16M parameter MACE for over 100 epochs in MPtrj which achieved an (at the time SOTA) F1 score of 0.64 and DAF of 3.13.
7+
68
### Convergence criteria
79

810
MACE relaxed each test set structure until the maximum force in the training set dropped below 0.05 eV/Å or 500 optimization steps were reached, whichever occurred first.

models/mace/test_mace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pymatgen.io.ase import AseAtomsAdaptor
1919
from tqdm import tqdm
2020

21-
from matbench_discovery import ROOT, id_col, timestamp, today
21+
from matbench_discovery import ROOT, formula_col, id_col, timestamp, today
2222
from matbench_discovery.data import DATA_FILES, as_dict_handler, df_wbm
2323
from matbench_discovery.plots import wandb_scatter
2424
from matbench_discovery.slurm import slurm_submit
@@ -164,7 +164,7 @@
164164
df_wbm[e_pred_col] = df_out[e_pred_col]
165165
table = wandb.Table(
166166
dataframe=df_wbm.dropna()[
167-
["uncorrected_energy", e_pred_col, "formula"]
167+
["uncorrected_energy", e_pred_col, formula_col]
168168
].reset_index()
169169
)
170170

models/voronoi/train_test_voronoi_rf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from sklearn.metrics import r2_score
1414
from sklearn.pipeline import Pipeline
1515

16-
from matbench_discovery import ROOT, id_col, today
16+
from matbench_discovery import ROOT, formula_col, id_col, today
1717
from matbench_discovery.data import DATA_FILES, df_wbm, glob_to_df
1818
from matbench_discovery.plots import wandb_scatter
1919
from matbench_discovery.preds import e_form_col as test_e_form_col
@@ -123,7 +123,7 @@
123123
df_wbm[pred_col].round(4).to_csv(out_path)
124124

125125
table = wandb.Table(
126-
dataframe=df_wbm[["formula", test_e_form_col, pred_col]].reset_index()
126+
dataframe=df_wbm[[formula_col, test_e_form_col, pred_col]].reset_index()
127127
)
128128

129129
df_wbm[pred_col].isna().sum()

models/wrenformer/analyze_wrenformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pymatviz.ptable import ptable_heatmap_plotly
1111
from pymatviz.utils import add_identity_line, bin_df_cols
1212

13-
from matbench_discovery import PDF_FIGS, SITE_FIGS, id_col
13+
from matbench_discovery import PDF_FIGS, SITE_FIGS, formula_col, id_col
1414
from matbench_discovery.data import DATA_FILES, df_wbm
1515
from matbench_discovery.preds import df_each_pred, df_preds, each_true_col
1616

@@ -85,7 +85,7 @@
8585

8686

8787
# %%
88-
fig = ptable_heatmap_plotly(df_bad.formula)
88+
fig = ptable_heatmap_plotly(df_bad[formula_col])
8989
fig.layout.title = f"Elements in {title}"
9090
fig.layout.margin = dict(l=0, r=0, t=50, b=0)
9191
fig.show()

0 commit comments

Comments
 (0)