Skip to content

Commit

Permalink
Equalize Phonon(Dos|BS)Plotter colors, allow custom plot settings p…
Browse files Browse the repository at this point in the history
…er-DOS (#3514)

* make temp optional to allow falling back to t if temp not passed

* allow passing arbitrary kwargs into PhononDosPlotter.add_dos for use in e.g. color customization

* change default line colors of PhononDosPlotter and PhononBSPlotter to tab:10

tab:blue and tab:orange in particular

* fix overlapping an non-symbol band struct x-labels

label.replace("GAMMA", "Γ").replace("DELTA", "Δ")

* change colors from tab10 back to regular red/blue

* plot_compare
add keyword other_kwargs to customize 2nd set of band lines
  • Loading branch information
janosh authored Dec 14, 2023
1 parent d860b0a commit 529eceb
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 58 deletions.
8 changes: 4 additions & 4 deletions pymatgen/phonon/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _positive_densities(self) -> np.ndarray:
"""Numpy array containing the list of densities corresponding to positive frequencies."""
return self.densities[self.ind_zero_freq :]

def cv(self, temp: float, structure: Structure | None = None, **kwargs) -> float:
def cv(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float:
"""Constant volume specific heat C_v at temperature T obtained from the integration of the DOS.
Only positive frequencies will be used.
Result in J/(K*mol-c). A mol-c is the abbreviation of a mole-cell, that is, the number
Expand Down Expand Up @@ -198,7 +198,7 @@ def csch2(x):

return cv

def entropy(self, temp: float, structure: Structure | None = None, **kwargs) -> float:
def entropy(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float:
"""Vibrational entropy at temperature T obtained from the integration of the DOS.
Only positive frequencies will be used.
Result in J/(K*mol-c). A mol-c is the abbreviation of a mole-cell, that is, the number
Expand Down Expand Up @@ -233,7 +233,7 @@ def entropy(self, temp: float, structure: Structure | None = None, **kwargs) ->

return entropy

def internal_energy(self, temp: float, structure: Structure | None = None, **kwargs) -> float:
def internal_energy(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float:
"""Phonon contribution to the internal energy at temperature T obtained from the integration of the DOS.
Only positive frequencies will be used.
Result in J/mol-c. A mol-c is the abbreviation of a mole-cell, that is, the number
Expand Down Expand Up @@ -268,7 +268,7 @@ def internal_energy(self, temp: float, structure: Structure | None = None, **kwa

return e_phonon

def helmholtz_free_energy(self, temp: float, structure: Structure | None = None, **kwargs) -> float:
def helmholtz_free_energy(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float:
"""Phonon contribution to the Helmholtz free energy at temperature T obtained from the integration of the DOS.
Only positive frequencies will be used.
Result in J/mol-c. A mol-c is the abbreviation of a mole-cell, that is, the number
Expand Down
94 changes: 40 additions & 54 deletions pymatgen/phonon/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import matplotlib.pyplot as plt
import numpy as np
import palettable
import scipy.constants as const
from matplotlib.collections import LineCollection
from monty.json import jsanitize
Expand Down Expand Up @@ -95,19 +94,18 @@ def __init__(self, stack: bool = False, sigma: float | None = None) -> None:
)
self.stack = stack
self.sigma = sigma
self._doses: dict[str, dict[Literal["frequencies", "densities"], np.ndarray]] = {}
self._doses: dict[str, dict[str, np.ndarray]] = {}

def add_dos(self, label: str, dos: PhononDos) -> None:
def add_dos(self, label: str, dos: PhononDos, **kwargs: Any) -> None:
"""Adds a dos for plotting.
Args:
label:
label for the DOS. Must be unique.
dos:
PhononDos object
label (str): label for the DOS. Must be unique.
dos (PhononDos): DOS object
**kwargs: kwargs supported by matplotlib.pyplot.plot
"""
densities = dos.get_smeared_densities(self.sigma) if self.sigma else dos.densities
self._doses[label] = {"frequencies": dos.frequencies, "densities": densities}
self._doses[label] = {"frequencies": dos.frequencies, "densities": densities, **kwargs}

def add_dos_dict(self, dos_dict: dict, key_sort_func=None) -> None:
"""Add a dictionary of doses, with an optional sorting function for the
Expand Down Expand Up @@ -160,8 +158,6 @@ def get_plot(
n_colors = max(3, len(self._doses))
n_colors = min(9, n_colors)

colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors

y = None
all_densities = []
all_frequencies = []
Expand All @@ -186,18 +182,14 @@ def get_plot(
all_densities.reverse()
all_frequencies.reverse()
all_pts = []
colors = ("blue", "red", "green", "orange", "purple", "brown", "pink", "gray", "olive")
for idx, (key, frequencies, densities) in enumerate(zip(keys, all_frequencies, all_densities)):
color = self._doses[key].get("color", colors[idx % n_colors])
all_pts.extend(list(zip(frequencies, densities)))
if self.stack:
ax.fill(frequencies, densities, color=colors[idx % n_colors], label=str(key))
ax.fill(frequencies, densities, color=color, label=str(key))
else:
ax.plot(
frequencies,
densities,
color=colors[idx % n_colors],
label=str(key),
linewidth=3,
)
ax.plot(frequencies, densities, color=color, label=str(key), linewidth=3)

if xlim:
ax.set_xlim(xlim)
Expand Down Expand Up @@ -297,13 +289,9 @@ def _make_ticks(self, ax: Axes) -> Axes:
ax.set_xticks(uniq_d)
ax.set_xticklabels(uniq_l)

for idx in range(len(ticks["label"])):
if ticks["label"][idx] is not None:
# don't print the same label twice
if idx != 0:
ax.axvline(ticks["distance"][idx], color="k")
else:
ax.axvline(ticks["distance"][idx], color="k")
for idx, label in enumerate(ticks["label"]):
if label is not None:
ax.axvline(ticks["distance"][idx], color="k")
return ax

def bs_plot_data(self) -> dict[str, Any]:
Expand Down Expand Up @@ -356,14 +344,11 @@ def get_plot(
ax = pretty_plot(12, 8)

data = self.bs_plot_data()
for d in range(len(data["distances"])):
kwargs.setdefault("color", "blue")
for dists, freqs in zip(data["distances"], data["frequency"]):
for idx in range(self._nb_bands):
ax.plot(
data["distances"][d],
[data["frequency"][d][idx][j] * u.factor for j in range(len(data["distances"][d]))],
"b-",
**kwargs,
)
ys = [freqs[idx][j] * u.factor for j in range(len(dists))]
ax.plot(dists, ys, **kwargs)

self._make_ticks(ax)

Expand Down Expand Up @@ -598,15 +583,15 @@ def get_ticks(self) -> dict[str, list]:
label0 = f"${label0}$"
tick_labels.pop()
tick_distance.pop()
tick_labels.append(f"{label0}$\\mid${label1}")
tick_labels.append(f"{label0}|{label1}")
elif point.label.startswith("\\") or point.label.find("_") != -1:
tick_labels.append(f"${point.label}$")
else:
# map atomate2 all-upper-case point.labels to pretty LaTeX
label = dict(GAMMA=r"$\Gamma$", DELTA=r"$\Delta$").get(point.label, point.label)
tick_labels.append(label)
tick_labels.append(point.label)
previous_label = point.label
previous_branch = this_branch
# map atomate2 all-upper-case labels like GAMMA/DELTA to pretty symbols
tick_labels = [label.replace("GAMMA", "Γ").replace("DELTA", "Δ").replace("SIGMA", "Σ") for label in tick_labels]
return {"distance": tick_distance, "label": tick_labels}

def plot_compare(
Expand All @@ -616,6 +601,7 @@ def plot_compare(
labels: tuple[str, str] | None = None,
legend_kwargs: dict | None = None,
on_incompatible: Literal["raise", "warn", "ignore"] = "raise",
other_kwargs: dict | None = None,
**kwargs,
) -> Axes:
"""Plot two band structure for comparison. One is in red the other in blue.
Expand All @@ -634,14 +620,16 @@ def plot_compare(
legend_kwargs: dict[str, Any]: kwargs passed to ax.legend().
on_incompatible ('raise' | 'warn' | 'ignore'): What to do if the two band structures are not compatible.
Defaults to 'raise'.
other_kwargs: dict[str, Any]: kwargs passed to other_plotter ax.plot().
**kwargs: passed to ax.plot().
Returns:
a matplotlib object with both band structures
"""
unit = freq_units(units)
legend_kwargs = legend_kwargs or {}
legend_kwargs.setdefault("fontsize", 22)
other_kwargs = other_kwargs or {}
legend_kwargs.setdefault("fontsize", 20)

data_orig = self.bs_plot_data()
data = other_plotter.bs_plot_data()
Expand All @@ -656,24 +644,22 @@ def plot_compare(
line_width = kwargs.setdefault("linewidth", 1)

ax = self.get_plot(units=units, **kwargs)
for band_idx in range(other_plotter._nb_bands):
for dist_idx in range(len(data_orig["distances"])):
ax.plot(
data_orig["distances"][dist_idx],
[
data["frequency"][dist_idx][band_idx][j] * unit.factor
for j in range(len(data_orig["distances"][dist_idx]))
],
"r-",
**kwargs,
)

# add legend showing which color correspond to which band structure
if labels is None and self._label and other_plotter._label:
labels = (self._label, other_plotter._label)
if labels:
ax.plot([], [], "b-", label=labels[0], linewidth=3 * line_width)
ax.plot([], [], "r-", label=labels[1], linewidth=3 * line_width)
kwargs.setdefault("color", "red") # don't move this line up! it would mess up self.get_plot color

for band_idx in range(other_plotter._nb_bands):
for dist_idx, dists in enumerate(data_orig["distances"]):
xs = dists
ys = [data["frequency"][dist_idx][band_idx][j] * unit.factor for j in range(len(dists))]
ax.plot(xs, ys, **(kwargs | other_kwargs))

# add legend showing which color corresponds to which band structure
if labels or (self._label and other_plotter._label):
color_self, color_other = ax.lines[0].get_color(), ax.lines[-1].get_color()
label_self, label_other = labels or (self._label, other_plotter._label)
ax.plot([], [], label=label_self, linewidth=2 * line_width, color=color_self)
linestyle = other_kwargs.get("linestyle", "-")
ax.plot([], [], label=label_other, linewidth=2 * line_width, color=color_other, linestyle=linestyle)
ax.legend(**legend_kwargs)

return ax
Expand Down

0 comments on commit 529eceb

Please sign in to comment.