From 68cf6b4532e3fb6c79d5051f1edf0d9992ecdfff Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 27 Nov 2023 13:34:04 -0800 Subject: [PATCH] `PhononDosPlotter.plot_dos()` add support for existing `plt.Axes` (#3487) * breaking: PhononBSPlotter.save_plot remove redundant img_format keyword use filename extension to determine image format * formatting * Add support for existing plt axes in PhononDosPlotter.plot_dos() * pretty print DELTA as $\Delta$ in PhononBSPlotter.get_ticks --- .../connectivity/connected_components.py | 9 +------ .../chemenv_strategies.py | 10 +++---- pymatgen/analysis/pourbaix_diagram.py | 4 +-- pymatgen/phonon/plotter.py | 27 +++++++++---------- 4 files changed, 21 insertions(+), 29 deletions(-) diff --git a/pymatgen/analysis/chemenv/connectivity/connected_components.py b/pymatgen/analysis/chemenv/connectivity/connected_components.py index 4a5114b3821..34f740d706d 100644 --- a/pymatgen/analysis/chemenv/connectivity/connected_components.py +++ b/pymatgen/analysis/chemenv/connectivity/connected_components.py @@ -155,14 +155,7 @@ def make_supergraph(graph, multiplicity, periodicity_vectors): connecting_edges.append((n1, n2, key, new_data)) else: if not np.all(np.array(data["delta"]) == 0): - print( - "delta not equal to periodicity nor 0 ... : ", - n1, - n2, - key, - data["delta"], - data, - ) + print("delta not equal to periodicity nor 0 ... : ", n1, n2, key, data["delta"], data) input("Are we ok with this ?") other_edges.append((n1, n2, key, data)) diff --git a/pymatgen/analysis/chemenv/coordination_environments/chemenv_strategies.py b/pymatgen/analysis/chemenv/coordination_environments/chemenv_strategies.py index 4be3979b8ef..da293298897 100644 --- a/pymatgen/analysis/chemenv/coordination_environments/chemenv_strategies.py +++ b/pymatgen/analysis/chemenv/coordination_environments/chemenv_strategies.py @@ -1129,22 +1129,22 @@ def _get_map(self, isite): target_cns = [cg.coordination_number for cg in target_cgs] for ii in range(min([len(maps_and_surfaces), self.max_nabundant])): my_map_and_surface = maps_and_surfaces[order[ii]] - mymap = my_map_and_surface["map"] - cn = mymap[0] + my_map = my_map_and_surface["map"] + cn = my_map[0] if cn not in target_cns or cn > 12 or cn == 0: continue all_conditions = [params[2] for params in my_map_and_surface["parameters_indices"]] if self._additional_condition not in all_conditions: continue - cg, cgdict = self.structure_environments.ce_list[self.structure_environments.sites_map[isite]][mymap[0]][ - mymap[1] + cg, cgdict = self.structure_environments.ce_list[self.structure_environments.sites_map[isite]][my_map[0]][ + my_map[1] ].minimum_geometry(symmetry_measure_type=self._symmetry_measure_type) if ( cg in self.target_environments and cgdict["symmetry_measure"] <= self.max_csm and cgdict["symmetry_measure"] < current_target_env_csm ): - current_map = mymap + current_map = my_map current_target_env_csm = cgdict["symmetry_measure"] if current_map is not None: return current_map diff --git a/pymatgen/analysis/pourbaix_diagram.py b/pymatgen/analysis/pourbaix_diagram.py index 2bc62487029..bd776c5eb09 100644 --- a/pymatgen/analysis/pourbaix_diagram.py +++ b/pymatgen/analysis/pourbaix_diagram.py @@ -473,7 +473,7 @@ def __init__( for entry in ion_entries: ion_elts = list(set(entry.elements) - ELEMENTS_HO) # TODO: the logic here for ion concentration setting is in two - # places, in PourbaixEntry and here, should be consolidated + # places, in PourbaixEntry and here, should be consolidated if len(ion_elts) == 1: entry.concentration = conc_dict[ion_elts[0].symbol] * entry.normalization_factor elif len(ion_elts) > 1 and not entry.concentration: @@ -481,7 +481,7 @@ def __init__( self._unprocessed_entries = solid_entries + ion_entries - if not len(solid_entries + ion_entries) == len(entries): + if len(solid_entries + ion_entries) != len(entries): raise ValueError('All supplied entries must have a phase type of either "Solid" or "Ion"') if self.filter_solids: diff --git a/pymatgen/phonon/plotter.py b/pymatgen/phonon/plotter.py index 510139bea7c..5dc2221398f 100644 --- a/pymatgen/phonon/plotter.py +++ b/pymatgen/phonon/plotter.py @@ -95,7 +95,7 @@ def __init__(self, stack: bool = False, sigma: float | None = None) -> None: ) self.stack = stack self.sigma = sigma - self._doses: dict = {} + self._doses: dict[str, dict[Literal["frequencies", "densities"], np.ndarray]] = {} def add_dos(self, label: str, dos: PhononDos) -> None: """Adds a dos for plotting. @@ -138,6 +138,7 @@ def get_plot( ylim: float | None = None, units: Literal["thz", "ev", "mev", "ha", "cm-1", "cm^-1"] = "thz", legend: dict | None = None, + ax: Axes | None = None, ) -> Axes: """Get a matplotlib plot showing the DOS. @@ -149,6 +150,8 @@ def get_plot( legend: dict with legend options. For example, {"loc": "upper right"} will place the legend in the upper right corner. Defaults to {"fontsize": 30}. + ax (Axes): An existing axes object onto which the plot will be + added. If None, a new figure will be created. """ legend = legend or {"fontsize": 30} unit = freq_units(units) @@ -161,7 +164,7 @@ def get_plot( y = None all_densities = [] all_frequencies = [] - ax = pretty_plot(12, 8) + ax = pretty_plot(12, 8, ax=ax) # Note that this complicated processing of frequencies is to allow for # stacked plots in matplotlib. @@ -516,9 +519,8 @@ def show( """Show the plot using matplotlib. Args: - ylim: Specify the y-axis (frequency) limits; by default None let - the code choose. - units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1. + ylim (float): Specifies the y-axis limits. + units ("thz" | "ev" | "mev" | "ha" | "cm-1" | "cm^-1"): units for the frequencies. """ self.get_plot(ylim, units=units) plt.show() @@ -526,20 +528,18 @@ def show( def save_plot( self, filename: str | PathLike, - img_format: str = "eps", ylim: float | None = None, units: Literal["thz", "ev", "mev", "ha", "cm-1", "cm^-1"] = "thz", ) -> None: """Save matplotlib plot to a file. Args: - filename: Filename to write to. - img_format: Image format to use. Defaults to EPS. - ylim: Specifies the y-axis limits. - units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1. + filename (str | Path): Filename to write to. + ylim (float): Specifies the y-axis limits. + units ("thz" | "ev" | "mev" | "ha" | "cm-1" | "cm^-1"): units for the frequencies. """ self.get_plot(ylim=ylim, units=units) - plt.savefig(filename, format=img_format) + plt.savefig(filename) plt.close() def show_proj( @@ -598,9 +598,8 @@ def get_ticks(self) -> dict[str, list]: elif point.label.startswith("\\") or point.label.find("_") != -1: tick_labels.append(f"${point.label}$") else: - label = point.label - if label == "GAMMA": - label = r"$\Gamma$" + # map atomate2 all-upper-case point.labels to pretty LaTeX + label = dict(GAMMA=r"$\Gamma$", DELTA=r"$\Delta$").get(point.label, point.label) tick_labels.append(label) previous_label = point.label previous_branch = this_branch