From a28e1dae228d0b8926918a691520c52894684ed2 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel) YANG" Date: Tue, 22 Oct 2024 04:52:55 +0800 Subject: [PATCH] Fix `savefig` in `pmg.cli.plot` (#4109) * fix plot save fig and enhance type * tweak decorator usage * plt -> fig to avoid shadowing matplotlib.pyplot as plt * fix type error for electronic_structure plotter * pin 3.12 in ci * fix sources path for src layout * split tests for cli * Revert "split tests for cli" This reverts commit f159d5ee9cf7521b63311356a5661b9b95d62479. * Revert "fix sources path for src layout" This reverts commit 3dfe21c020465468bd89b2f43698d8681c8fda81. * recover no type check decorator, intend for another PR * Revert "recover no type check decorator, intend for another PR" This reverts commit f4dabf93acd9efee1a1162896a52a3d7ad23fae6. * Reapply "split tests for cli" This reverts commit b01deabd1c839f8060aef58bc19a3abc130afd76. * Reapply "fix sources path for src layout" This reverts commit a7b7e259ce10bf0c181ed4109d3bcec92420910b. * pin mypy for now * improve comment * add unit test for plot * also check incorrect usage --- src/pymatgen/cli/pmg_plot.py | 35 ++++++------- src/pymatgen/electronic_structure/plotter.py | 15 +++--- tests/cli/conftest.py | 14 +++++ tests/cli/test_pmg_analyze.py | 18 +++++++ tests/cli/test_pmg_diff.py | 16 ++++++ tests/cli/test_pmg_plot.py | 51 +++++++++++++++++++ .../test_pmg_structure.py} | 25 +-------- tests/test_pkg.py | 11 ++-- 8 files changed, 132 insertions(+), 53 deletions(-) create mode 100644 tests/cli/conftest.py create mode 100644 tests/cli/test_pmg_analyze.py create mode 100644 tests/cli/test_pmg_diff.py create mode 100644 tests/cli/test_pmg_plot.py rename tests/{test_cli.py => cli/test_pmg_structure.py} (62%) diff --git a/src/pymatgen/cli/pmg_plot.py b/src/pymatgen/cli/pmg_plot.py index 89d0fa269ae..d5c032f7f71 100755 --- a/src/pymatgen/cli/pmg_plot.py +++ b/src/pymatgen/cli/pmg_plot.py @@ -14,8 +14,8 @@ from pymatgen.util.plotting import pretty_plot -def get_dos_plot(args): - """Plot DOS. +def get_dos_plot(args) -> plt.Axes: + """Plot DOS from vasprun.xml file. Args: args (dict): Args from argparse. @@ -46,8 +46,8 @@ def get_dos_plot(args): return plotter.get_plot() -def get_chgint_plot(args, ax: plt.Axes = None) -> plt.Axes: - """Plot integrated charge. +def get_chgint_plot(args, ax: plt.Axes | None = None) -> plt.Axes: + """Plot integrated charge from CHGCAR file. Args: args (dict): args from argparse. @@ -77,33 +77,34 @@ def get_chgint_plot(args, ax: plt.Axes = None) -> plt.Axes: return ax -def get_xrd_plot(args): - """Plot XRD. +def get_xrd_plot(args) -> plt.Axes: + """Plot XRD from structure. Args: args (dict): Args from argparse """ struct = Structure.from_file(args.xrd_structure_file) - c = XRDCalculator() - return c.get_plot(struct) + calculator = XRDCalculator() + return calculator.get_plot(struct) -def plot(args): - """Master control method calling other plot methods based on args. +def plot(args) -> None: + """Master control function calling other plot functions based on args. Args: args (dict): Args from argparse. """ - plt = None if args.chgcar_file: - plt = get_chgint_plot(args) + fig: plt.Figure | None = get_chgint_plot(args).figure elif args.xrd_structure_file: - plt = get_xrd_plot(args) + fig = get_xrd_plot(args).figure elif args.dos_file: - plt = get_dos_plot(args) + fig = get_dos_plot(args).figure + else: + fig = None - if plt: + if fig is not None: if args.out_file: - plt.savefig(args.out_file) + fig.savefig(args.out_file) else: - plt.show() + fig.show() diff --git a/src/pymatgen/electronic_structure/plotter.py b/src/pymatgen/electronic_structure/plotter.py index 3532995766c..d12c9334e9b 100644 --- a/src/pymatgen/electronic_structure/plotter.py +++ b/src/pymatgen/electronic_structure/plotter.py @@ -6,7 +6,6 @@ import itertools import logging import math -import typing import warnings from collections import Counter from typing import TYPE_CHECKING, cast @@ -20,6 +19,7 @@ from matplotlib.gridspec import GridSpec from monty.dev import requires from monty.json import jsanitize +from numpy.typing import ArrayLike from pymatgen.core import Element from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine @@ -36,8 +36,6 @@ from collections.abc import Sequence from typing import Literal - from numpy.typing import ArrayLike - from pymatgen.electronic_structure.dos import CompleteDos, Dos logger = logging.getLogger(__name__) @@ -130,7 +128,6 @@ def get_dos_dict(self): """ return jsanitize(self._doses) - @typing.no_type_check def get_plot( self, xlim: tuple[float, float] | None = None, @@ -163,8 +160,8 @@ def get_plot( # Note that this complicated processing of energies is to allow for # stacked plots in matplotlib. for dos in self._doses.values(): - energies = dos["energies"] - densities = dos["densities"] + energies = cast(ArrayLike, dos["energies"]) + densities = cast(ArrayLike, dos["densities"]) if not ys: ys = { Spin.up: np.zeros(energies.shape), @@ -211,10 +208,14 @@ def get_plot( ax.set_ylim(ylim) elif not invert_axes: xlim = ax.get_xlim() + if xlim is None: + raise RuntimeError("xlim cannot be None.") relevant_y = [p[1] for p in all_pts if xlim[0] < p[0] < xlim[1]] ax.set_ylim((min(relevant_y), max(relevant_y))) if not xlim and invert_axes: ylim = ax.get_ylim() + if ylim is None: + raise RuntimeError("ylim cannot be None.") relevant_y = [p[0] for p in all_pts if ylim[0] < p[1] < ylim[1]] ax.set_xlim((min(relevant_y), max(relevant_y))) @@ -3900,7 +3901,7 @@ def show(self, xlim=None, ylim=None) -> None: plt.show() -@requires(mlab is not None, "MayAvi mlab not imported! Please install mayavi.") +@requires(mlab is not None, "MayAvi mlab not installed! Please install mayavi.") def plot_fermi_surface( data, structure, diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py new file mode 100644 index 00000000000..44154c88172 --- /dev/null +++ b/tests/cli/conftest.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from pathlib import Path + + +@pytest.fixture +def cd_tmp_path(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.chdir(tmp_path) + return tmp_path diff --git a/tests/cli/test_pmg_analyze.py b/tests/cli/test_pmg_analyze.py new file mode 100644 index 00000000000..2f2a1d3dcd2 --- /dev/null +++ b/tests/cli/test_pmg_analyze.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +import os +import subprocess +from typing import TYPE_CHECKING + +from pymatgen.util.testing import TEST_FILES_DIR + +if TYPE_CHECKING: + from pathlib import Path + + +def test_pmg_analyze(cd_tmp_path: Path): + subprocess.run( + ["pmg", "analyze", f"{TEST_FILES_DIR}/io/vasp/fixtures/scan_relaxation"], + check=True, + ) + assert os.path.isfile("vasp_data.gz") diff --git a/tests/cli/test_pmg_diff.py b/tests/cli/test_pmg_diff.py new file mode 100644 index 00000000000..6a478c37bc0 --- /dev/null +++ b/tests/cli/test_pmg_diff.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import subprocess +from typing import TYPE_CHECKING + +from pymatgen.util.testing import VASP_IN_DIR + +if TYPE_CHECKING: + from pathlib import Path + + +def test_pmg_diff(cd_tmp_path: Path): + subprocess.run( + ["pmg", "diff", "--incar", f"{VASP_IN_DIR}/INCAR", f"{VASP_IN_DIR}/INCAR_2"], + check=True, + ) diff --git a/tests/cli/test_pmg_plot.py b/tests/cli/test_pmg_plot.py new file mode 100644 index 00000000000..635b934ae26 --- /dev/null +++ b/tests/cli/test_pmg_plot.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import os +import subprocess +from typing import TYPE_CHECKING + +import pytest + +from pymatgen.util.testing import VASP_IN_DIR, VASP_OUT_DIR + +if TYPE_CHECKING: + from pathlib import Path + + +def test_plot_xrd(cd_tmp_path: Path): + subprocess.run( + ["pmg", "plot", "--xrd", f"{VASP_IN_DIR}/POSCAR_Fe3O4", "--out_file", "xrd.png"], + check=True, + ) + assert os.path.isfile("xrd.png") + assert os.path.getsize("xrd.png") > 1024 + + +def test_plot_dos(cd_tmp_path: Path): + subprocess.run( + ["pmg", "plot", "--dos", f"{VASP_OUT_DIR}/vasprun_Li_no_projected.xml.gz", "--out_file", "dos.png"], + check=True, + ) + assert os.path.isfile("dos.png") + assert os.path.getsize("dos.png") > 1024 + + +def test_plot_chgint(cd_tmp_path: Path): + subprocess.run( + ["pmg", "plot", "--chgint", f"{VASP_OUT_DIR}/CHGCAR.Fe3O4.gz", "--out_file", "chg.png"], + check=True, + ) + assert os.path.isfile("chg.png") + assert os.path.getsize("chg.png") > 1024 + + +def test_plot_wrong_arg(cd_tmp_path: Path): + with pytest.raises(subprocess.CalledProcessError) as exc_info: + subprocess.run( + ["pmg", "plot", "--wrong", f"{VASP_OUT_DIR}/CHGCAR.Fe3O4.gz"], + check=True, + capture_output=True, + ) + + assert exc_info.value.returncode == 2 + assert "one of the arguments -d/--dos -c/--chgint -x/--xrd is required" in exc_info.value.stderr.decode("utf-8") diff --git a/tests/test_cli.py b/tests/cli/test_pmg_structure.py similarity index 62% rename from tests/test_cli.py rename to tests/cli/test_pmg_structure.py index 3d0741732b5..7bdaf67cc52 100644 --- a/tests/test_cli.py +++ b/tests/cli/test_pmg_structure.py @@ -4,28 +4,12 @@ import subprocess from typing import TYPE_CHECKING -import pytest - -from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR +from pymatgen.util.testing import TEST_FILES_DIR if TYPE_CHECKING: from pathlib import Path -@pytest.fixture -def cd_tmp_path(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): - monkeypatch.chdir(tmp_path) - return tmp_path - - -def test_pmg_analyze(cd_tmp_path: Path): - subprocess.run( - ["pmg", "analyze", f"{TEST_FILES_DIR}/io/vasp/fixtures/scan_relaxation"], - check=True, - ) - assert os.path.isfile("vasp_data.gz") - - def test_pmg_structure(cd_tmp_path: Path): subprocess.run( ["pmg", "structure", "--convert", "--filenames", f"{TEST_FILES_DIR}/cif/Li2O.cif", "POSCAR_Li2O_test"], @@ -53,10 +37,3 @@ def test_pmg_structure(cd_tmp_path: Path): subprocess.run( ["pmg", "structure", "--localenv", "Li-O=3", "--filenames", f"{TEST_FILES_DIR}/cif/Li2O.cif"], check=True ) - - -def test_pmg_diff(cd_tmp_path: Path): - subprocess.run( - ["pmg", "diff", "--incar", f"{VASP_IN_DIR}/INCAR", f"{VASP_IN_DIR}/INCAR_2"], - check=True, - ) diff --git a/tests/test_pkg.py b/tests/test_pkg.py index 4748bced01b..3d277fe8948 100644 --- a/tests/test_pkg.py +++ b/tests/test_pkg.py @@ -5,15 +5,16 @@ import pytest -src_txt_path = "pymatgen.egg-info/SOURCES.txt" -src_txt_missing = not os.path.isfile(src_txt_path) +SRC_TXT_PATH = "src/pymatgen.egg-info/SOURCES.txt" -@pytest.mark.skipif(src_txt_missing, reason=f"{src_txt_path} not found. Run `pip install .` to create") +@pytest.mark.skipif( + not os.path.isfile(SRC_TXT_PATH), reason=f"{SRC_TXT_PATH=} not found. Run `pip install .` to create" +) def test_egg_sources_txt_is_complete(): """Check that all source and data files in pymatgen/ are listed in pymatgen.egg-info/SOURCES.txt.""" - with open(src_txt_path) as file: + with open(SRC_TXT_PATH, encoding="utf-8") as file: sources = file.read() # check that all files listed in SOURCES.txt exist @@ -28,6 +29,6 @@ def test_egg_sources_txt_is_complete(): continue if unix_path not in sources: raise ValueError( - f"{unix_path} not found in {src_txt_path}. check setup.py package_data for " + f"{unix_path} not found in {SRC_TXT_PATH}. check setup.py package_data for " "outdated inclusion rules." )