diff --git a/pymatgen/phonon/dos.py b/pymatgen/phonon/dos.py index dcdfb13d645..c08e7f8aca7 100644 --- a/pymatgen/phonon/dos.py +++ b/pymatgen/phonon/dos.py @@ -161,7 +161,7 @@ def _positive_densities(self) -> np.ndarray: """Numpy array containing the list of densities corresponding to positive frequencies.""" return self.densities[self.ind_zero_freq :] - def cv(self, temp: float, structure: Structure | None = None, **kwargs) -> float: + def cv(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float: """Constant volume specific heat C_v at temperature T obtained from the integration of the DOS. Only positive frequencies will be used. Result in J/(K*mol-c). A mol-c is the abbreviation of a mole-cell, that is, the number @@ -198,7 +198,7 @@ def csch2(x): return cv - def entropy(self, temp: float, structure: Structure | None = None, **kwargs) -> float: + def entropy(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float: """Vibrational entropy at temperature T obtained from the integration of the DOS. Only positive frequencies will be used. Result in J/(K*mol-c). A mol-c is the abbreviation of a mole-cell, that is, the number @@ -233,7 +233,7 @@ def entropy(self, temp: float, structure: Structure | None = None, **kwargs) -> return entropy - def internal_energy(self, temp: float, structure: Structure | None = None, **kwargs) -> float: + def internal_energy(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float: """Phonon contribution to the internal energy at temperature T obtained from the integration of the DOS. Only positive frequencies will be used. Result in J/mol-c. A mol-c is the abbreviation of a mole-cell, that is, the number @@ -268,7 +268,7 @@ def internal_energy(self, temp: float, structure: Structure | None = None, **kwa return e_phonon - def helmholtz_free_energy(self, temp: float, structure: Structure | None = None, **kwargs) -> float: + def helmholtz_free_energy(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float: """Phonon contribution to the Helmholtz free energy at temperature T obtained from the integration of the DOS. Only positive frequencies will be used. Result in J/mol-c. A mol-c is the abbreviation of a mole-cell, that is, the number diff --git a/pymatgen/phonon/plotter.py b/pymatgen/phonon/plotter.py index 784ee8f866a..a89972fd24c 100644 --- a/pymatgen/phonon/plotter.py +++ b/pymatgen/phonon/plotter.py @@ -8,7 +8,6 @@ import matplotlib.pyplot as plt import numpy as np -import palettable import scipy.constants as const from matplotlib.collections import LineCollection from monty.json import jsanitize @@ -95,19 +94,18 @@ def __init__(self, stack: bool = False, sigma: float | None = None) -> None: ) self.stack = stack self.sigma = sigma - self._doses: dict[str, dict[Literal["frequencies", "densities"], np.ndarray]] = {} + self._doses: dict[str, dict[str, np.ndarray]] = {} - def add_dos(self, label: str, dos: PhononDos) -> None: + def add_dos(self, label: str, dos: PhononDos, **kwargs: Any) -> None: """Adds a dos for plotting. Args: - label: - label for the DOS. Must be unique. - dos: - PhononDos object + label (str): label for the DOS. Must be unique. + dos (PhononDos): DOS object + **kwargs: kwargs supported by matplotlib.pyplot.plot """ densities = dos.get_smeared_densities(self.sigma) if self.sigma else dos.densities - self._doses[label] = {"frequencies": dos.frequencies, "densities": densities} + self._doses[label] = {"frequencies": dos.frequencies, "densities": densities, **kwargs} def add_dos_dict(self, dos_dict: dict, key_sort_func=None) -> None: """Add a dictionary of doses, with an optional sorting function for the @@ -160,8 +158,6 @@ def get_plot( n_colors = max(3, len(self._doses)) n_colors = min(9, n_colors) - colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors - y = None all_densities = [] all_frequencies = [] @@ -186,18 +182,14 @@ def get_plot( all_densities.reverse() all_frequencies.reverse() all_pts = [] + colors = ("blue", "red", "green", "orange", "purple", "brown", "pink", "gray", "olive") for idx, (key, frequencies, densities) in enumerate(zip(keys, all_frequencies, all_densities)): + color = self._doses[key].get("color", colors[idx % n_colors]) all_pts.extend(list(zip(frequencies, densities))) if self.stack: - ax.fill(frequencies, densities, color=colors[idx % n_colors], label=str(key)) + ax.fill(frequencies, densities, color=color, label=str(key)) else: - ax.plot( - frequencies, - densities, - color=colors[idx % n_colors], - label=str(key), - linewidth=3, - ) + ax.plot(frequencies, densities, color=color, label=str(key), linewidth=3) if xlim: ax.set_xlim(xlim) @@ -297,13 +289,9 @@ def _make_ticks(self, ax: Axes) -> Axes: ax.set_xticks(uniq_d) ax.set_xticklabels(uniq_l) - for idx in range(len(ticks["label"])): - if ticks["label"][idx] is not None: - # don't print the same label twice - if idx != 0: - ax.axvline(ticks["distance"][idx], color="k") - else: - ax.axvline(ticks["distance"][idx], color="k") + for idx, label in enumerate(ticks["label"]): + if label is not None: + ax.axvline(ticks["distance"][idx], color="k") return ax def bs_plot_data(self) -> dict[str, Any]: @@ -356,14 +344,11 @@ def get_plot( ax = pretty_plot(12, 8) data = self.bs_plot_data() - for d in range(len(data["distances"])): + kwargs.setdefault("color", "blue") + for dists, freqs in zip(data["distances"], data["frequency"]): for idx in range(self._nb_bands): - ax.plot( - data["distances"][d], - [data["frequency"][d][idx][j] * u.factor for j in range(len(data["distances"][d]))], - "b-", - **kwargs, - ) + ys = [freqs[idx][j] * u.factor for j in range(len(dists))] + ax.plot(dists, ys, **kwargs) self._make_ticks(ax) @@ -598,15 +583,15 @@ def get_ticks(self) -> dict[str, list]: label0 = f"${label0}$" tick_labels.pop() tick_distance.pop() - tick_labels.append(f"{label0}$\\mid${label1}") + tick_labels.append(f"{label0}|{label1}") elif point.label.startswith("\\") or point.label.find("_") != -1: tick_labels.append(f"${point.label}$") else: - # map atomate2 all-upper-case point.labels to pretty LaTeX - label = dict(GAMMA=r"$\Gamma$", DELTA=r"$\Delta$").get(point.label, point.label) - tick_labels.append(label) + tick_labels.append(point.label) previous_label = point.label previous_branch = this_branch + # map atomate2 all-upper-case labels like GAMMA/DELTA to pretty symbols + tick_labels = [label.replace("GAMMA", "Γ").replace("DELTA", "Δ").replace("SIGMA", "Σ") for label in tick_labels] return {"distance": tick_distance, "label": tick_labels} def plot_compare( @@ -616,6 +601,7 @@ def plot_compare( labels: tuple[str, str] | None = None, legend_kwargs: dict | None = None, on_incompatible: Literal["raise", "warn", "ignore"] = "raise", + other_kwargs: dict | None = None, **kwargs, ) -> Axes: """Plot two band structure for comparison. One is in red the other in blue. @@ -634,6 +620,7 @@ def plot_compare( legend_kwargs: dict[str, Any]: kwargs passed to ax.legend(). on_incompatible ('raise' | 'warn' | 'ignore'): What to do if the two band structures are not compatible. Defaults to 'raise'. + other_kwargs: dict[str, Any]: kwargs passed to other_plotter ax.plot(). **kwargs: passed to ax.plot(). Returns: @@ -641,7 +628,8 @@ def plot_compare( """ unit = freq_units(units) legend_kwargs = legend_kwargs or {} - legend_kwargs.setdefault("fontsize", 22) + other_kwargs = other_kwargs or {} + legend_kwargs.setdefault("fontsize", 20) data_orig = self.bs_plot_data() data = other_plotter.bs_plot_data() @@ -656,24 +644,22 @@ def plot_compare( line_width = kwargs.setdefault("linewidth", 1) ax = self.get_plot(units=units, **kwargs) - for band_idx in range(other_plotter._nb_bands): - for dist_idx in range(len(data_orig["distances"])): - ax.plot( - data_orig["distances"][dist_idx], - [ - data["frequency"][dist_idx][band_idx][j] * unit.factor - for j in range(len(data_orig["distances"][dist_idx])) - ], - "r-", - **kwargs, - ) - # add legend showing which color correspond to which band structure - if labels is None and self._label and other_plotter._label: - labels = (self._label, other_plotter._label) - if labels: - ax.plot([], [], "b-", label=labels[0], linewidth=3 * line_width) - ax.plot([], [], "r-", label=labels[1], linewidth=3 * line_width) + kwargs.setdefault("color", "red") # don't move this line up! it would mess up self.get_plot color + + for band_idx in range(other_plotter._nb_bands): + for dist_idx, dists in enumerate(data_orig["distances"]): + xs = dists + ys = [data["frequency"][dist_idx][band_idx][j] * unit.factor for j in range(len(dists))] + ax.plot(xs, ys, **(kwargs | other_kwargs)) + + # add legend showing which color corresponds to which band structure + if labels or (self._label and other_plotter._label): + color_self, color_other = ax.lines[0].get_color(), ax.lines[-1].get_color() + label_self, label_other = labels or (self._label, other_plotter._label) + ax.plot([], [], label=label_self, linewidth=2 * line_width, color=color_self) + linestyle = other_kwargs.get("linestyle", "-") + ax.plot([], [], label=label_other, linewidth=2 * line_width, color=color_other, linestyle=linestyle) ax.legend(**legend_kwargs) return ax