Skip to content

Commit

Permalink
Fix savefig in pmg.cli.plot (#4109)
Browse files Browse the repository at this point in the history
* fix plot save fig and enhance type

* tweak decorator usage

* plt -> fig to avoid shadowing matplotlib.pyplot as plt

* fix type error for electronic_structure plotter

* pin 3.12 in ci

* fix sources path for src layout

* split tests for cli

* Revert "split tests for cli"

This reverts commit f159d5e.

* Revert "fix sources path for src layout"

This reverts commit 3dfe21c.

* recover no type check decorator, intend for another PR

* Revert "recover no type check decorator, intend for another PR"

This reverts commit f4dabf9.

* Reapply "split tests for cli"

This reverts commit b01deab.

* Reapply "fix sources path for src layout"

This reverts commit a7b7e25.

* pin mypy for now

* improve comment

* add unit test for plot

* also check incorrect usage
  • Loading branch information
DanielYang59 authored Oct 21, 2024
1 parent 8a4822a commit a28e1da
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 53 deletions.
35 changes: 18 additions & 17 deletions src/pymatgen/cli/pmg_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from pymatgen.util.plotting import pretty_plot


def get_dos_plot(args):
"""Plot DOS.
def get_dos_plot(args) -> plt.Axes:
"""Plot DOS from vasprun.xml file.
Args:
args (dict): Args from argparse.
Expand Down Expand Up @@ -46,8 +46,8 @@ def get_dos_plot(args):
return plotter.get_plot()


def get_chgint_plot(args, ax: plt.Axes = None) -> plt.Axes:
"""Plot integrated charge.
def get_chgint_plot(args, ax: plt.Axes | None = None) -> plt.Axes:
"""Plot integrated charge from CHGCAR file.
Args:
args (dict): args from argparse.
Expand Down Expand Up @@ -77,33 +77,34 @@ def get_chgint_plot(args, ax: plt.Axes = None) -> plt.Axes:
return ax


def get_xrd_plot(args):
"""Plot XRD.
def get_xrd_plot(args) -> plt.Axes:
"""Plot XRD from structure.
Args:
args (dict): Args from argparse
"""
struct = Structure.from_file(args.xrd_structure_file)
c = XRDCalculator()
return c.get_plot(struct)
calculator = XRDCalculator()
return calculator.get_plot(struct)


def plot(args):
"""Master control method calling other plot methods based on args.
def plot(args) -> None:
"""Master control function calling other plot functions based on args.
Args:
args (dict): Args from argparse.
"""
plt = None
if args.chgcar_file:
plt = get_chgint_plot(args)
fig: plt.Figure | None = get_chgint_plot(args).figure
elif args.xrd_structure_file:
plt = get_xrd_plot(args)
fig = get_xrd_plot(args).figure
elif args.dos_file:
plt = get_dos_plot(args)
fig = get_dos_plot(args).figure
else:
fig = None

if plt:
if fig is not None:
if args.out_file:
plt.savefig(args.out_file)
fig.savefig(args.out_file)
else:
plt.show()
fig.show()
15 changes: 8 additions & 7 deletions src/pymatgen/electronic_structure/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import itertools
import logging
import math
import typing
import warnings
from collections import Counter
from typing import TYPE_CHECKING, cast
Expand All @@ -20,6 +19,7 @@
from matplotlib.gridspec import GridSpec
from monty.dev import requires
from monty.json import jsanitize
from numpy.typing import ArrayLike

from pymatgen.core import Element
from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine
Expand All @@ -36,8 +36,6 @@
from collections.abc import Sequence
from typing import Literal

from numpy.typing import ArrayLike

from pymatgen.electronic_structure.dos import CompleteDos, Dos

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -130,7 +128,6 @@ def get_dos_dict(self):
"""
return jsanitize(self._doses)

@typing.no_type_check
def get_plot(
self,
xlim: tuple[float, float] | None = None,
Expand Down Expand Up @@ -163,8 +160,8 @@ def get_plot(
# Note that this complicated processing of energies is to allow for
# stacked plots in matplotlib.
for dos in self._doses.values():
energies = dos["energies"]
densities = dos["densities"]
energies = cast(ArrayLike, dos["energies"])
densities = cast(ArrayLike, dos["densities"])
if not ys:
ys = {
Spin.up: np.zeros(energies.shape),
Expand Down Expand Up @@ -211,10 +208,14 @@ def get_plot(
ax.set_ylim(ylim)
elif not invert_axes:
xlim = ax.get_xlim()
if xlim is None:
raise RuntimeError("xlim cannot be None.")
relevant_y = [p[1] for p in all_pts if xlim[0] < p[0] < xlim[1]]
ax.set_ylim((min(relevant_y), max(relevant_y)))
if not xlim and invert_axes:
ylim = ax.get_ylim()
if ylim is None:
raise RuntimeError("ylim cannot be None.")
relevant_y = [p[0] for p in all_pts if ylim[0] < p[1] < ylim[1]]
ax.set_xlim((min(relevant_y), max(relevant_y)))

Expand Down Expand Up @@ -3900,7 +3901,7 @@ def show(self, xlim=None, ylim=None) -> None:
plt.show()


@requires(mlab is not None, "MayAvi mlab not imported! Please install mayavi.")
@requires(mlab is not None, "MayAvi mlab not installed! Please install mayavi.")
def plot_fermi_surface(
data,
structure,
Expand Down
14 changes: 14 additions & 0 deletions tests/cli/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

if TYPE_CHECKING:
from pathlib import Path


@pytest.fixture
def cd_tmp_path(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.chdir(tmp_path)
return tmp_path
18 changes: 18 additions & 0 deletions tests/cli/test_pmg_analyze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations

import os
import subprocess
from typing import TYPE_CHECKING

from pymatgen.util.testing import TEST_FILES_DIR

if TYPE_CHECKING:
from pathlib import Path


def test_pmg_analyze(cd_tmp_path: Path):
subprocess.run(
["pmg", "analyze", f"{TEST_FILES_DIR}/io/vasp/fixtures/scan_relaxation"],
check=True,
)
assert os.path.isfile("vasp_data.gz")
16 changes: 16 additions & 0 deletions tests/cli/test_pmg_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from __future__ import annotations

import subprocess
from typing import TYPE_CHECKING

from pymatgen.util.testing import VASP_IN_DIR

if TYPE_CHECKING:
from pathlib import Path


def test_pmg_diff(cd_tmp_path: Path):
subprocess.run(
["pmg", "diff", "--incar", f"{VASP_IN_DIR}/INCAR", f"{VASP_IN_DIR}/INCAR_2"],
check=True,
)
51 changes: 51 additions & 0 deletions tests/cli/test_pmg_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

import os
import subprocess
from typing import TYPE_CHECKING

import pytest

from pymatgen.util.testing import VASP_IN_DIR, VASP_OUT_DIR

if TYPE_CHECKING:
from pathlib import Path


def test_plot_xrd(cd_tmp_path: Path):
subprocess.run(
["pmg", "plot", "--xrd", f"{VASP_IN_DIR}/POSCAR_Fe3O4", "--out_file", "xrd.png"],
check=True,
)
assert os.path.isfile("xrd.png")
assert os.path.getsize("xrd.png") > 1024


def test_plot_dos(cd_tmp_path: Path):
subprocess.run(
["pmg", "plot", "--dos", f"{VASP_OUT_DIR}/vasprun_Li_no_projected.xml.gz", "--out_file", "dos.png"],
check=True,
)
assert os.path.isfile("dos.png")
assert os.path.getsize("dos.png") > 1024


def test_plot_chgint(cd_tmp_path: Path):
subprocess.run(
["pmg", "plot", "--chgint", f"{VASP_OUT_DIR}/CHGCAR.Fe3O4.gz", "--out_file", "chg.png"],
check=True,
)
assert os.path.isfile("chg.png")
assert os.path.getsize("chg.png") > 1024


def test_plot_wrong_arg(cd_tmp_path: Path):
with pytest.raises(subprocess.CalledProcessError) as exc_info:
subprocess.run(
["pmg", "plot", "--wrong", f"{VASP_OUT_DIR}/CHGCAR.Fe3O4.gz"],
check=True,
capture_output=True,
)

assert exc_info.value.returncode == 2
assert "one of the arguments -d/--dos -c/--chgint -x/--xrd is required" in exc_info.value.stderr.decode("utf-8")
25 changes: 1 addition & 24 deletions tests/test_cli.py → tests/cli/test_pmg_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,12 @@
import subprocess
from typing import TYPE_CHECKING

import pytest

from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR
from pymatgen.util.testing import TEST_FILES_DIR

if TYPE_CHECKING:
from pathlib import Path


@pytest.fixture
def cd_tmp_path(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.chdir(tmp_path)
return tmp_path


def test_pmg_analyze(cd_tmp_path: Path):
subprocess.run(
["pmg", "analyze", f"{TEST_FILES_DIR}/io/vasp/fixtures/scan_relaxation"],
check=True,
)
assert os.path.isfile("vasp_data.gz")


def test_pmg_structure(cd_tmp_path: Path):
subprocess.run(
["pmg", "structure", "--convert", "--filenames", f"{TEST_FILES_DIR}/cif/Li2O.cif", "POSCAR_Li2O_test"],
Expand Down Expand Up @@ -53,10 +37,3 @@ def test_pmg_structure(cd_tmp_path: Path):
subprocess.run(
["pmg", "structure", "--localenv", "Li-O=3", "--filenames", f"{TEST_FILES_DIR}/cif/Li2O.cif"], check=True
)


def test_pmg_diff(cd_tmp_path: Path):
subprocess.run(
["pmg", "diff", "--incar", f"{VASP_IN_DIR}/INCAR", f"{VASP_IN_DIR}/INCAR_2"],
check=True,
)
11 changes: 6 additions & 5 deletions tests/test_pkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@

import pytest

src_txt_path = "pymatgen.egg-info/SOURCES.txt"
src_txt_missing = not os.path.isfile(src_txt_path)
SRC_TXT_PATH = "src/pymatgen.egg-info/SOURCES.txt"


@pytest.mark.skipif(src_txt_missing, reason=f"{src_txt_path} not found. Run `pip install .` to create")
@pytest.mark.skipif(
not os.path.isfile(SRC_TXT_PATH), reason=f"{SRC_TXT_PATH=} not found. Run `pip install .` to create"
)
def test_egg_sources_txt_is_complete():
"""Check that all source and data files in pymatgen/ are listed in pymatgen.egg-info/SOURCES.txt."""

with open(src_txt_path) as file:
with open(SRC_TXT_PATH, encoding="utf-8") as file:
sources = file.read()

# check that all files listed in SOURCES.txt exist
Expand All @@ -28,6 +29,6 @@ def test_egg_sources_txt_is_complete():
continue
if unix_path not in sources:
raise ValueError(
f"{unix_path} not found in {src_txt_path}. check setup.py package_data for "
f"{unix_path} not found in {SRC_TXT_PATH}. check setup.py package_data for "
"outdated inclusion rules."
)

0 comments on commit a28e1da

Please sign in to comment.