diff --git a/src/pymatgen/phonon/plotter.py b/src/pymatgen/phonon/plotter.py index abbd219d530..d5dd63bc2ec 100644 --- a/src/pymatgen/phonon/plotter.py +++ b/src/pymatgen/phonon/plotter.py @@ -8,7 +8,9 @@ import matplotlib.pyplot as plt import numpy as np import scipy.constants as const +from matplotlib import colors from matplotlib.collections import LineCollection +from matplotlib.colors import LinearSegmentedColormap from monty.json import jsanitize from pymatgen.electronic_structure.plotter import BSDOSPlotter, plot_brillouin_zone from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine @@ -1052,26 +1054,58 @@ def bs_plot_data(self) -> dict[str, Any]: "lattice": self._bs.lattice_rec.as_dict(), } - def get_plot_gs(self, ylim: float | None = None, **kwargs) -> Axes: + def get_plot_gs(self, ylim: float | None = None, plot_ph_bs_with_gruneisen: bool = False, **kwargs) -> Axes: """Get a matplotlib object for the Gruneisen bandstructure plot. Args: ylim: Specify the y-axis (gruneisen) limits; by default None let the code choose. + plot_ph_bs_with_gruneisen (bool): Plot phonon band-structure with bands coloured + as per Gruneisen parameter values on a logarithmic scale **kwargs: additional keywords passed to ax.plot(). """ + u = freq_units(kwargs.get("units", "THz")) ax = pretty_plot(12, 8) + # Create a colormap (default is red to blue) + cmap = LinearSegmentedColormap.from_list("cmap", kwargs.get("cmap", ["red", "blue"])) + kwargs.setdefault("linewidth", 2) kwargs.setdefault("marker", "o") kwargs.setdefault("markersize", 2) data = self.bs_plot_data() - for dist_idx in range(len(data["distances"])): + + # extract min and max Grüneisen parameter values + max_gruneisen = np.array(data["gruneisen"]).max() + min_gruneisen = np.array(data["gruneisen"]).min() + + # LogNormalize colormap based on the min and max Grüneisen parameter values + norm = colors.SymLogNorm( + vmin=min_gruneisen, + vmax=max_gruneisen, + linthresh=1e-2, + linscale=1, + ) + + sc = None + for (dists_inx, dists), (_, freqs) in zip(enumerate(data["distances"]), enumerate(data["frequency"])): for band_idx in range(self.n_bands): - ys = [data["gruneisen"][dist_idx][band_idx][idx] for idx in range(len(data["distances"][dist_idx]))] + if plot_ph_bs_with_gruneisen: + ys = [freqs[band_idx][j] * u.factor for j in range(len(dists))] + ys_gru = [ + data["gruneisen"][dists_inx][band_idx][idx] for idx in range(len(data["distances"][dists_inx])) + ] + sc = ax.scatter(dists, ys, c=ys_gru, cmap=cmap, norm=norm, marker="o", s=1) + else: + keys_to_remove = ("units", "cmap") # needs to be removed before passing to line-plot + for k in keys_to_remove: + kwargs.pop(k, None) + ys = [ + data["gruneisen"][dists_inx][band_idx][idx] for idx in range(len(data["distances"][dists_inx])) + ] - ax.plot(data["distances"][dist_idx], ys, "b-", **kwargs) + ax.plot(data["distances"][dists_inx], ys, "b-", **kwargs) self._make_ticks(ax) @@ -1079,8 +1113,16 @@ def get_plot_gs(self, ylim: float | None = None, **kwargs) -> Axes: ax.axhline(0, linewidth=1, color="black") # Main X and Y Labels - ax.set_xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30) - ax.set_ylabel(r"$\mathrm{Grüneisen\ Parameter}$", fontsize=30) + if plot_ph_bs_with_gruneisen: + ax.set_xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30) + units = kwargs.get("units", "THz") + ax.set_ylabel(f"Frequencies ({units})", fontsize=30) + + cbar = plt.colorbar(sc, ax=ax) + cbar.set_label(r"$\gamma \ \mathrm{(logarithmized)}$", fontsize=30) + else: + ax.set_xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30) + ax.set_ylabel(r"$\mathrm{Grüneisen\ Parameter}$", fontsize=30) # X range (K) # last distance point @@ -1094,24 +1136,37 @@ def get_plot_gs(self, ylim: float | None = None, **kwargs) -> Axes: return ax - def show_gs(self, ylim: float | None = None) -> None: + def show_gs(self, ylim: float | None = None, plot_ph_bs_with_gruneisen: bool = False, **kwargs) -> None: """Show the plot using matplotlib. Args: ylim: Specifies the y-axis limits. + plot_ph_bs_with_gruneisen: Plot phonon band-structure with bands coloured + as per Gruneisen parameter values on a logarithmic scale + **kwargs: kwargs passed to get_plot_gs """ - self.get_plot_gs(ylim) + self.get_plot_gs(ylim=ylim, plot_ph_bs_with_gruneisen=plot_ph_bs_with_gruneisen, **kwargs) plt.show() - def save_plot_gs(self, filename: str | PathLike, img_format: str = "eps", ylim: float | None = None) -> None: + def save_plot_gs( + self, + filename: str | PathLike, + img_format: str = "eps", + ylim: float | None = None, + plot_ph_bs_with_gruneisen: bool = False, + **kwargs, + ) -> None: """Save matplotlib plot to a file. Args: filename: Filename to write to. img_format: Image format to use. Defaults to EPS. ylim: Specifies the y-axis limits. + plot_ph_bs_with_gruneisen: Plot phonon band-structure with bands coloured + as per Gruneisen parameter values on a logarithmic scale + **kwargs: kwargs passed to get_plot_gs """ - self.get_plot_gs(ylim=ylim) + self.get_plot_gs(ylim=ylim, plot_ph_bs_with_gruneisen=plot_ph_bs_with_gruneisen, **kwargs) plt.savefig(filename, format=img_format) plt.close() diff --git a/tests/phonon/test_gruneisen.py b/tests/phonon/test_gruneisen.py index 321702140af..3040db72cf3 100644 --- a/tests/phonon/test_gruneisen.py +++ b/tests/phonon/test_gruneisen.py @@ -1,7 +1,9 @@ from __future__ import annotations import matplotlib.pyplot as plt +import numpy as np import pytest +from matplotlib import colors from pymatgen.io.phonopy import get_gruneisen_ph_bs_symm_line, get_gruneisenparameter from pymatgen.phonon.gruneisen import GruneisenParameter from pymatgen.phonon.plotter import GruneisenPhononBandStructureSymmLine, GruneisenPhononBSPlotter, GruneisenPlotter @@ -31,6 +33,39 @@ def test_plot(self): ax = plotter.get_plot_gs() assert isinstance(ax, plt.Axes) + def test_ph_plot_w_gruneisen(self): + plotter = GruneisenPhononBSPlotter(bs=self.bs_symm_line) + ax = plotter.get_plot_gs(plot_ph_bs_with_gruneisen=True, units="THz", cmap=["red", "royalblue"]) + assert ax.get_ylabel() == "Frequencies (THz)" + assert ax.get_xlabel() == "$\\mathrm{Wave\\ Vector}$" + assert ax.get_figure()._localaxes[-1].get_ylabel() == "$\\gamma \\ \\mathrm{(logarithmized)}$" + assert len(ax._children) == plotter.n_bands + 1 # check for number of bands + # check for x and y data is really the band-structure data + for inx, band in enumerate(plotter._bs.bands): + xy_data = { + "x": [point[0] for point in ax._children[inx].get_offsets().data], + "y": [point[1] for point in ax._children[inx].get_offsets().data], + } + assert band == pytest.approx(xy_data["y"]) + assert plotter._bs.distance == pytest.approx(xy_data["x"]) + + # check if color bar max value matches maximum gruneisen parameter value + data = plotter.bs_plot_data() + + # get reference min and max Grüneisen parameter values + max_gruneisen = np.array(data["gruneisen"]).max() + min_gruneisen = np.array(data["gruneisen"]).min() + + norm = colors.SymLogNorm( + vmin=min_gruneisen, + vmax=max_gruneisen, + linthresh=1e-2, + linscale=1, + ) + + assert max(norm.inverse(ax.get_figure()._localaxes[-1].get_yticks())) == pytest.approx(max_gruneisen) + assert isinstance(ax, plt.Axes) + def test_as_dict_from_dict(self): new_dict = self.bs_symm_line.as_dict() self.new_bs_symm_line = GruneisenPhononBandStructureSymmLine.from_dict(new_dict)