Skip to content

Commit

Permalink
refactor from matplotlib to plotly in EDA scripts
Browse files Browse the repository at this point in the history
- eda_mp_trj.py replace matplotlib with Plotly for element count heatmaps and ratios
- plot_structure_perturbation.py utilize Plotly for histogram and structure visualizations
- wbm_umap_projection.py use Plotly for UMAP scatter plots, improving interactivity and aesthetics
  • Loading branch information
janosh committed Dec 21, 2024
1 parent 99f1646 commit 45fd5b0
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 43 deletions.
38 changes: 17 additions & 21 deletions data/mp/eda_mp_trj.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import pandas as pd
import plotly.express as px
import pymatviz as pmv
from matplotlib.colors import SymLogNorm
from pymatgen.core import Composition
from pymatgen.core.tensors import Tensor
from pymatviz.enums import Key
Expand Down Expand Up @@ -243,42 +242,39 @@ def info_dict_to_id(info: dict[str, int | str]) -> str:

excl_elems = "He Ne Ar Kr Xe".split() if (excl_noble := False) else ()

ax_ptable = pmv.ptable_heatmap( # matplotlib version looks better for SI
fig = pmv.ptable_heatmap_plotly(
trj_elem_counts,
# zero_color="#efefef",
log=(log := SymLogNorm(linthresh=10_000)),
exclude_elements=excl_elems, # drop noble gases
# cbar_range=None if excl_noble else (10_000, None),
show_values=(show_vals := True),
# label_font_size=17 if show_vals else 25,
# value_font_size=14,
cbar_title="MPtrj Element Counts",
log=(log := True),
colorbar=dict(title="MPtrj Element Counts"),
)

img_name = f"mp-trj-element-counts-by-{count_mode}"
if log:
img_name += "-symlog" if isinstance(log, SymLogNorm) else "-log"
if excl_noble:
img_name += "-excl-noble"
pmv.save_fig(ax_ptable, f"{PDF_FIGS}/{img_name}.pdf")
if log:
img_name += "-log"
pmv.save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf")
fig.show()


# %%
normalized = True
ax_ptable = pmv.ptable_heatmap_ratio(
trj_elem_counts / (len(df_mp_trj) if normalized else 1),
mp_occu_counts / (len(df_mp) if normalized else 1),
zero_color="#efefef",
fmt=".2f",
not_in_denominator=None,
not_in_numerator=None,
not_in_either=None,
fig = pmv.ptable_heatmap_plotly(
{
elem: (trj_count / 1_580_395) / (mp_count / len(df_mp))
for elem, trj_count in trj_elem_counts.items()
if elem in mp_occu_counts
for mp_count in [mp_occu_counts[elem]] # clever way to get mp_count in scope
},
colorbar=dict(title="MPtrj/MP Element Count Ratio"),
)

img_name = "mp-trj-mp-ratio-element-counts-by-occurrence"
if normalized:
img_name += "-normalized"
pmv.save_fig(ax_ptable, f"{PDF_FIGS}/{img_name}.pdf")
pmv.save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf")
fig.show()


# %% plot formation energy per atom distribution
Expand Down
23 changes: 12 additions & 11 deletions models/cgcnn/plot_structure_perturbation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# %%
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymatviz as pmv
Expand All @@ -11,28 +10,30 @@
__date__ = "2022-12-02"

rng = np.random.default_rng(0)
pmv.set_plotly_template("pymatviz_dark")


# %%
ax = pd.Series(rng.weibull(1.5, 100_000)).hist(bins=100)
fig = pd.Series(rng.weibull(1.5, 100_000)).hist(bins=100, backend="plotly")
title = "Distribution of perturbation magnitudes"
ax.set(xlabel="magnitude of perturbation", ylabel="count", title=title)
fig.layout.update(xaxis_title="Perturbation Magnitude", title=title)
fig.show()


# %%
struct = Structure(
lattice=Lattice.cubic(5),
species=("Fe", "O"),
species=("Fe", "Fe"),
coords=((0, 0, 0), (0.5, 0.5, 0.5)),
)

ax = pmv.structure_2d(struct)
ax.set(title=f"Original structure: {struct.formula}")
ax.set_aspect("equal")
fig = pmv.structure_2d_plotly(struct)
fig.layout.update(title=f"Original structure: {struct.formula}")
fig.show()


# %%
fig, axs = plt.subplots(3, 4, figsize=(12, 10))
for idx, ax in enumerate(axs.flat, start=1):
pmv.structure_2d(perturb_structure(struct), ax=ax)
ax.set(title=f"perturbation {idx}")
pmv.structure_2d_plotly(
[perturb_structure(struct) for _ in range(12)],
subplot_title=lambda _struct, idx: f"perturbation {idx}",
)
23 changes: 12 additions & 11 deletions scripts/wbm_umap_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# %%
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymatviz as pmv
Expand Down Expand Up @@ -189,18 +188,20 @@ def features_to_drop(df_in: pd.DataFrame, threshold: float = 0.95) -> list[str]:
umap_cols = list(df_umap)
if umap_cols != ["UMAP 1", "UMAP 2"]:
raise ValueError(f"Unexpected {umap_cols=}")
min_step, max_step = df_umap.index.min(), df_umap.index.max()
ax = df_umap.plot.scatter(
*umap_cols, c=df_umap.index, cmap="Spectral", s=5, figsize=(6, 4), colorbar=False

fig = df_umap.plot.scatter(
x="UMAP 1",
y="UMAP 2",
color=df_umap.index.astype(str).str.replace("0", "MP original structures"),
backend="plotly",
template="pymatviz_white",
)
cbar = ax.figure.colorbar(
ax.collections[0],
boundaries=np.arange(min_step, max_step + 2) - 0.5,
ticks=range(min_step, max_step + 1),
fig.layout.legend.update(
title="WBM step:", orientation="h", y=1.1, itemsizing="constant"
)
cbar.ax.set_title("WBM step (0 = MP)", rotation=90, y=0.5, x=3, va="center")
fig.show()


# %%
plt.tight_layout()
pmv.save_fig(ax, f"{PDF_FIGS}/wbm-final-struct-matminer-features-2d-umap.png", dpi=300)
img_path = f"{PDF_FIGS}/wbm-final-struct-matminer-features-2d-umap.png"
pmv.save_fig(fig, img_path, scale=3)

0 comments on commit 45fd5b0

Please sign in to comment.