Skip to content

Commit

Permalink
add optional gruneisen parameter colorbar plot (#3908)
Browse files Browse the repository at this point in the history
  • Loading branch information
naik-aakash authored Jul 3, 2024
1 parent 1bddad2 commit 0f22690
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 10 deletions.
75 changes: 65 additions & 10 deletions src/pymatgen/phonon/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import matplotlib.pyplot as plt
import numpy as np
import scipy.constants as const
from matplotlib import colors
from matplotlib.collections import LineCollection
from matplotlib.colors import LinearSegmentedColormap
from monty.json import jsanitize
from pymatgen.electronic_structure.plotter import BSDOSPlotter, plot_brillouin_zone
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine
Expand Down Expand Up @@ -1052,35 +1054,75 @@ def bs_plot_data(self) -> dict[str, Any]:
"lattice": self._bs.lattice_rec.as_dict(),
}

def get_plot_gs(self, ylim: float | None = None, **kwargs) -> Axes:
def get_plot_gs(self, ylim: float | None = None, plot_ph_bs_with_gruneisen: bool = False, **kwargs) -> Axes:
"""Get a matplotlib object for the Gruneisen bandstructure plot.
Args:
ylim: Specify the y-axis (gruneisen) limits; by default None let
the code choose.
plot_ph_bs_with_gruneisen (bool): Plot phonon band-structure with bands coloured
as per Gruneisen parameter values on a logarithmic scale
**kwargs: additional keywords passed to ax.plot().
"""
u = freq_units(kwargs.get("units", "THz"))
ax = pretty_plot(12, 8)

# Create a colormap (default is red to blue)
cmap = LinearSegmentedColormap.from_list("cmap", kwargs.get("cmap", ["red", "blue"]))

kwargs.setdefault("linewidth", 2)
kwargs.setdefault("marker", "o")
kwargs.setdefault("markersize", 2)

data = self.bs_plot_data()
for dist_idx in range(len(data["distances"])):

# extract min and max Grüneisen parameter values
max_gruneisen = np.array(data["gruneisen"]).max()
min_gruneisen = np.array(data["gruneisen"]).min()

# LogNormalize colormap based on the min and max Grüneisen parameter values
norm = colors.SymLogNorm(
vmin=min_gruneisen,
vmax=max_gruneisen,
linthresh=1e-2,
linscale=1,
)

sc = None
for (dists_inx, dists), (_, freqs) in zip(enumerate(data["distances"]), enumerate(data["frequency"])):
for band_idx in range(self.n_bands):
ys = [data["gruneisen"][dist_idx][band_idx][idx] for idx in range(len(data["distances"][dist_idx]))]
if plot_ph_bs_with_gruneisen:
ys = [freqs[band_idx][j] * u.factor for j in range(len(dists))]
ys_gru = [
data["gruneisen"][dists_inx][band_idx][idx] for idx in range(len(data["distances"][dists_inx]))
]
sc = ax.scatter(dists, ys, c=ys_gru, cmap=cmap, norm=norm, marker="o", s=1)
else:
keys_to_remove = ("units", "cmap") # needs to be removed before passing to line-plot
for k in keys_to_remove:
kwargs.pop(k, None)
ys = [
data["gruneisen"][dists_inx][band_idx][idx] for idx in range(len(data["distances"][dists_inx]))
]

ax.plot(data["distances"][dist_idx], ys, "b-", **kwargs)
ax.plot(data["distances"][dists_inx], ys, "b-", **kwargs)

self._make_ticks(ax)

# plot y=0 line
ax.axhline(0, linewidth=1, color="black")

# Main X and Y Labels
ax.set_xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
ax.set_ylabel(r"$\mathrm{Grüneisen\ Parameter}$", fontsize=30)
if plot_ph_bs_with_gruneisen:
ax.set_xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
units = kwargs.get("units", "THz")
ax.set_ylabel(f"Frequencies ({units})", fontsize=30)

cbar = plt.colorbar(sc, ax=ax)
cbar.set_label(r"$\gamma \ \mathrm{(logarithmized)}$", fontsize=30)
else:
ax.set_xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
ax.set_ylabel(r"$\mathrm{Grüneisen\ Parameter}$", fontsize=30)

# X range (K)
# last distance point
Expand All @@ -1094,24 +1136,37 @@ def get_plot_gs(self, ylim: float | None = None, **kwargs) -> Axes:

return ax

def show_gs(self, ylim: float | None = None) -> None:
def show_gs(self, ylim: float | None = None, plot_ph_bs_with_gruneisen: bool = False, **kwargs) -> None:
"""Show the plot using matplotlib.
Args:
ylim: Specifies the y-axis limits.
plot_ph_bs_with_gruneisen: Plot phonon band-structure with bands coloured
as per Gruneisen parameter values on a logarithmic scale
**kwargs: kwargs passed to get_plot_gs
"""
self.get_plot_gs(ylim)
self.get_plot_gs(ylim=ylim, plot_ph_bs_with_gruneisen=plot_ph_bs_with_gruneisen, **kwargs)
plt.show()

def save_plot_gs(self, filename: str | PathLike, img_format: str = "eps", ylim: float | None = None) -> None:
def save_plot_gs(
self,
filename: str | PathLike,
img_format: str = "eps",
ylim: float | None = None,
plot_ph_bs_with_gruneisen: bool = False,
**kwargs,
) -> None:
"""Save matplotlib plot to a file.
Args:
filename: Filename to write to.
img_format: Image format to use. Defaults to EPS.
ylim: Specifies the y-axis limits.
plot_ph_bs_with_gruneisen: Plot phonon band-structure with bands coloured
as per Gruneisen parameter values on a logarithmic scale
**kwargs: kwargs passed to get_plot_gs
"""
self.get_plot_gs(ylim=ylim)
self.get_plot_gs(ylim=ylim, plot_ph_bs_with_gruneisen=plot_ph_bs_with_gruneisen, **kwargs)
plt.savefig(filename, format=img_format)
plt.close()

Expand Down
35 changes: 35 additions & 0 deletions tests/phonon/test_gruneisen.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
import pytest
from matplotlib import colors
from pymatgen.io.phonopy import get_gruneisen_ph_bs_symm_line, get_gruneisenparameter
from pymatgen.phonon.gruneisen import GruneisenParameter
from pymatgen.phonon.plotter import GruneisenPhononBandStructureSymmLine, GruneisenPhononBSPlotter, GruneisenPlotter
Expand Down Expand Up @@ -31,6 +33,39 @@ def test_plot(self):
ax = plotter.get_plot_gs()
assert isinstance(ax, plt.Axes)

def test_ph_plot_w_gruneisen(self):
plotter = GruneisenPhononBSPlotter(bs=self.bs_symm_line)
ax = plotter.get_plot_gs(plot_ph_bs_with_gruneisen=True, units="THz", cmap=["red", "royalblue"])
assert ax.get_ylabel() == "Frequencies (THz)"
assert ax.get_xlabel() == "$\\mathrm{Wave\\ Vector}$"
assert ax.get_figure()._localaxes[-1].get_ylabel() == "$\\gamma \\ \\mathrm{(logarithmized)}$"
assert len(ax._children) == plotter.n_bands + 1 # check for number of bands
# check for x and y data is really the band-structure data
for inx, band in enumerate(plotter._bs.bands):
xy_data = {
"x": [point[0] for point in ax._children[inx].get_offsets().data],
"y": [point[1] for point in ax._children[inx].get_offsets().data],
}
assert band == pytest.approx(xy_data["y"])
assert plotter._bs.distance == pytest.approx(xy_data["x"])

# check if color bar max value matches maximum gruneisen parameter value
data = plotter.bs_plot_data()

# get reference min and max Grüneisen parameter values
max_gruneisen = np.array(data["gruneisen"]).max()
min_gruneisen = np.array(data["gruneisen"]).min()

norm = colors.SymLogNorm(
vmin=min_gruneisen,
vmax=max_gruneisen,
linthresh=1e-2,
linscale=1,
)

assert max(norm.inverse(ax.get_figure()._localaxes[-1].get_yticks())) == pytest.approx(max_gruneisen)
assert isinstance(ax, plt.Axes)

def test_as_dict_from_dict(self):
new_dict = self.bs_symm_line.as_dict()
self.new_bs_symm_line = GruneisenPhononBandStructureSymmLine.from_dict(new_dict)
Expand Down

0 comments on commit 0f22690

Please sign in to comment.