Skip to content

Commit

Permalink
use SymLogNorm for MP+WBM+MPtrj ptable element count heatmaps
Browse files (browse the repository at this point in the history)
  • Loading branch information
janosh committed Dec 6, 2023
1 parent 62a6458 commit f959ee8
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 32 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.5
rev: v0.1.7
hooks:
- id: ruff
args: [--fix]
Expand All @@ -30,7 +30,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.0
rev: v1.7.1
hooks:
- id: mypy
additional_dependencies: [types-pyyaml, types-requests]
Expand All @@ -45,7 +45,7 @@ repos:
args: [--ignore-words-list, "nd,te,fpr", --check-filenames]

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.3
rev: v4.0.0-alpha.3
hooks:
- id: prettier
args: [--write] # edit files in-place
Expand All @@ -56,7 +56,7 @@ repos:
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yaml|json)|changelog.md)$

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v8.53.0
rev: v8.55.0
hooks:
- id: eslint
types: [file]
Expand Down
22 changes: 11 additions & 11 deletions data/mp/eda_mp_trj.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import pandas as pd
import plotly.express as px
from matplotlib.colors import SymLogNorm
from pymatgen.core import Composition
from pymatviz import count_elements, ptable_heatmap, ptable_heatmap_ratio, ptable_hists
from pymatviz.io import save_fig
Expand Down Expand Up @@ -213,28 +214,27 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:


# %%
count_mode = "composition"
if "trj_elem_counts" not in locals():
trj_elem_counts = pd.read_json(
f"{data_page}/mp-trj-element-counts-by-{count_mode}.json",
typ="series",
)
count_mode = "occurrence"
trj_elem_counts = pd.read_json(
f"{data_page}/mp-trj-element-counts-by-{count_mode}.json", typ="series"
)

excl_elems = "He Ne Ar Kr Xe".split() if (excl_noble := False) else ()

ax_ptable = ptable_heatmap( # matplotlib version looks better for SI
trj_elem_counts,
fmt=lambda x, _: si_fmt(x, ".0f"),
cbar_fmt=lambda x, _: si_fmt(x, ".0f"),
zero_color="#efefef",
log=(log := True),
log=(log := SymLogNorm(linthresh=10_000)),
exclude_elements=excl_elems, # drop noble gases
cbar_range=None if excl_noble else (10_000, None),
# cbar_range=None if excl_noble else (10_000, None),
label_font_size=17,
value_font_size=14,
cbar_title="MPtrj Element Counts",
)

img_name = f"mp-trj-element-counts-by-{count_mode}{'-log' if log else ''}"
img_name = f"mp-trj-element-counts-by-{count_mode}"
if log:
img_name += "-symlog" if isinstance(log, SymLogNorm) else "-log"
if excl_noble:
img_name += "-excl-noble"
save_fig(ax_ptable, f"{PDF_FIGS}/{img_name}.pdf")
Expand Down
13 changes: 6 additions & 7 deletions data/wbm/eda_wbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pandas as pd
import plotly.express as px
from matplotlib.colors import SymLogNorm
from pymatgen.core import Composition
from pymatviz import (
count_elements,
Expand Down Expand Up @@ -64,13 +65,9 @@


# %%
log = True
for dataset, count_mode, elem_counts in all_counts:
filename = f"{dataset}-element-counts-by-{count_mode}"
if log:
filename += "-log"
else:
elem_counts.to_json(f"{data_page}/{filename}.json")
elem_counts.to_json(f"{data_page}/{filename}.json")

title = f"Number of {dataset.upper()} structures containing each element"
fig = ptable_heatmap_plotly(elem_counts, font_size=10)
Expand All @@ -85,9 +82,11 @@
label_font_size=17,
value_font_size=14,
cbar_title=f"{dataset.upper()} Element Count",
log=log,
cbar_range=(100, None),
log=(log := SymLogNorm(linthresh=100)),
# cbar_range=(100, None),
)
if log:
filename += "-symlog" if isinstance(log, SymLogNorm) else "-log"
save_fig(ax_mp_cnt, f"{PDF_FIGS}/{filename}.pdf")


Expand Down
7 changes: 4 additions & 3 deletions matbench_discovery/preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,10 @@ def load_df_wbm_with_preds(

else:
cols = list(df)
raise ValueError(
f"No pred col for {model_name=} ({model_key=}), available {cols=}"
)
msg = f"No pred col for {model_name=}, available {cols=}"
if model_name != model_key:
msg = msg.replace(", ", f" ({model_key=}), ")
raise ValueError(msg)

return df_out

Expand Down
5 changes: 3 additions & 2 deletions models/chgnet/join_chgnet_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,13 @@


# %% compute corrected formation energies
e_form_chgnet_col = "e_form_per_atom_chgnet"
e_pred_col = "chgnet_energy"
e_form_chgnet_col = f"e_form_per_atom_{e_pred_col.split('_energy')[0]}"
df_chgnet[formula_col] = df_preds[formula_col]
df_chgnet[e_form_chgnet_col] = [
get_e_form_per_atom(dict(energy=ene, composition=formula))
for formula, ene in tqdm(
df_chgnet.set_index(formula_col).chgnet_energy.items(), total=len(df_chgnet)
df_chgnet.set_index(formula_col)[e_pred_col].items(), total=len(df_chgnet)
)
]
df_preds[e_form_chgnet_col] = df_chgnet[e_form_chgnet_col]
Expand Down
11 changes: 6 additions & 5 deletions scripts/model_figs/make_metrics_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@
n_structs = MODEL_METADATA[model_name]["training_set"]["n_structures"]
n_materials = MODEL_METADATA[model_name]["training_set"].get("n_materials")

formatted = si_fmt(n_structs)
n_structs_fmt = si_fmt(n_structs)
if n_materials:
formatted += f" <small>({si_fmt(n_materials)})</small>"
n_structs_fmt += f" <small>({si_fmt(n_materials)})</small>"

df_metrics.loc[train_size_col, model] = formatted
df_metrics_10k.loc[train_size_col, model] = formatted
df_metrics.loc[train_size_col, model] = n_structs_fmt
df_metrics_10k.loc[train_size_col, model] = n_structs_fmt


# %% add dummy classifier results to df_metrics
Expand Down Expand Up @@ -157,6 +157,7 @@
cmap="viridis_r", subset=list(lower_is_better & {*df_filtered})
)
)
# add up/down arrows to indicate which metrics are better when higher/lower
arrow_suffix = dict.fromkeys(higher_is_better, " ↑") | dict.fromkeys(
lower_is_better, " ↓"
)
Expand All @@ -182,7 +183,7 @@
f"{SITE_FIGS}/metrics-table{label}.svelte",
inline_props="class='roomy'",
# draw dotted line between classification and regression metrics
styles=f"{col_selector} {{ border-left: 1px dotted white; }}{hide_scroll_bar}",
styles=f"{col_selector} {{ border-left: 2px dotted white; }}{hide_scroll_bar}",
)
try:
df_to_pdf(styler, f"{PDF_FIGS}/metrics-table{label}.pdf")
Expand Down

0 comments on commit f959ee8

Please sign in to comment.