Skip to content

Commit

Permalink
fix bad merge behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Dec 12, 2024
1 parent ec33a39 commit def9886
Showing 1 changed file with 67 additions and 54 deletions.
121 changes: 67 additions & 54 deletions src/pymatgen/util/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from __future__ import annotations

import json
import pickle # use pickle, not cPickle so that we get the traceback in case of errors
import pickle # use pickle over cPickle to get traceback in case of errors
import string
from pathlib import Path
from typing import TYPE_CHECKING
Expand All @@ -19,12 +19,15 @@
from monty.json import MontyDecoder, MontyEncoder, MSONable
from monty.serialization import loadfn

from pymatgen.core import ROOT, SETTINGS, Structure
from pymatgen.core import ROOT, SETTINGS

if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Any, ClassVar

from pymatgen.core import Structure
from pymatgen.util.typing import PathLike

_MODULE_DIR: Path = Path(__file__).absolute().parent

STRUCTURES_DIR: Path = _MODULE_DIR / "structures"
Expand All @@ -33,10 +36,10 @@
VASP_IN_DIR: str = f"{TEST_FILES_DIR}/io/vasp/inputs"
VASP_OUT_DIR: str = f"{TEST_FILES_DIR}/io/vasp/outputs"

# fake POTCARs have original header information, meaning properties like number of electrons,
# Fake POTCARs have original header information, meaning properties like number of electrons,
# nuclear charge, core radii, etc. are unchanged (important for testing) while values of the and
# pseudopotential kinetic energy corrections are scrambled to avoid VASP copyright infringement
FAKE_POTCAR_DIR = f"{VASP_IN_DIR}/fake_potcars"
FAKE_POTCAR_DIR: str = f"{VASP_IN_DIR}/fake_potcars"


class MatSciTest:
Expand All @@ -50,36 +53,53 @@ class MatSciTest:
"""

# dict of lazily-loaded test structures (initialized to None)
TEST_STRUCTURES: ClassVar[dict[str | Path, Structure | None]] = dict.fromkeys(STRUCTURES_DIR.glob("*"))
TEST_STRUCTURES: ClassVar[dict[PathLike, Structure | None]] = dict.fromkeys(STRUCTURES_DIR.glob("*"))

@pytest.fixture(autouse=True) # make all tests run a in a temporary directory accessible via self.tmp_path
@pytest.fixture(autouse=True)
def _tmp_dir(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
# https://pytest.org/en/latest/how-to/unittest.html#using-autouse-fixtures-and-accessing-other-fixtures
monkeypatch.chdir(tmp_path) # change to pytest-provided temporary directory
self.tmp_path = tmp_path
"""Make all tests run a in a temporary directory accessible via self.tmp_path.
@classmethod
def get_structure(cls, name: str) -> Structure:
References:
https://docs.pytest.org/en/stable/how-to/tmp_path.html
"""
Load a structure from `pymatgen.util.structures`.
monkeypatch.chdir(tmp_path) # change to temporary directory
self.tmp_path = tmp_path

@staticmethod
def assert_msonable(obj: Any, test_is_subclass: bool = True) -> str:
"""Test if an object is MSONable and verify the contract is fulfilled,
and return the serialized object.
By default, the method tests whether obj is an instance of MSONable.
This check can be deactivated by setting `test_is_subclass` to False.
Args:
name (str): Name of the structure file, for example "LiFePO4".
obj (Any): The object to be checked.
test_is_subclass (bool): Check if object is an instance of MSONable
or its subclasses.
Returns:
Structure
str: Serialized object.
"""
try:
struct = cls.TEST_STRUCTURES.get(name) or loadfn(f"{STRUCTURES_DIR}/{name}.json")
except FileNotFoundError as exc:
raise FileNotFoundError(f"structure for {name} doesn't exist") from exc
obj_name = obj.__class__.__name__

cls.TEST_STRUCTURES[name] = struct
# Check if is an instance of MONable (or its subclasses)
if test_is_subclass and not isinstance(obj, MSONable):
raise TypeError(f"{obj_name} object is not MSONable")

return struct.copy()
# Check if the object can be accurately reconstructed from its dict representation
if obj.as_dict() != type(obj).from_dict(obj.as_dict()).as_dict():
raise ValueError(f"{obj_name} object could not be reconstructed accurately from its dict representation.")

# Verify that the deserialized object's class is a subclass of the original object's class
json_str = json.dumps(obj.as_dict(), cls=MontyEncoder)
round_trip = json.loads(json_str, cls=MontyDecoder)
if not issubclass(type(round_trip), type(obj)):
raise TypeError(f"The reconstructed {round_trip.__class__.__name__} object is not a subclass of {obj_name}")
return json_str

@staticmethod
def assert_str_content_equal(actual, expected):
def assert_str_content_equal(actual: str, expected: str) -> None:
"""Test if two strings are equal, ignoring whitespaces.
Args:
Expand All @@ -99,7 +119,32 @@ def assert_str_content_equal(actual, expected):
f"{expected}\n"
)

def serialize_with_pickle(self, objects: Any, protocols: Sequence[int] | None = None, test_eq: bool = True):
@classmethod
def get_structure(cls, name: str) -> Structure:
"""
Load a structure from `pymatgen.util.structures`.
Args:
name (str): Name of the structure file, for example "LiFePO4".
Returns:
Structure
"""
try:
struct = cls.TEST_STRUCTURES.get(name) or loadfn(f"{STRUCTURES_DIR}/{name}.json")
except FileNotFoundError as exc:
raise FileNotFoundError(f"structure for {name} doesn't exist") from exc

cls.TEST_STRUCTURES[name] = struct

return struct.copy()

def serialize_with_pickle(
self,
objects: Any,
protocols: Sequence[int] | None = None,
test_eq: bool = True,
) -> list:
"""Test whether the object(s) can be serialized and deserialized with
`pickle`. This method tries to serialize the objects with `pickle` and the
protocols specified in input. Then it deserializes the pickled format
Expand Down Expand Up @@ -163,38 +208,6 @@ def serialize_with_pickle(self, objects: Any, protocols: Sequence[int] | None =
return [o[0] for o in objects_by_protocol]
return objects_by_protocol

def assert_msonable(self, obj: MSONable, test_is_subclass: bool = True) -> str:
"""Test if an object is MSONable and verify the contract is fulfilled,
and return the serialized object.
By default, the method tests whether obj is an instance of MSONable.
This check can be deactivated by setting `test_is_subclass` to False.
Args:
obj (Any): The object to be checked.
test_is_subclass (bool): Check if object is an instance of MSONable
or its subclasses.
Returns:
str: Serialized object.
"""
obj_name = obj.__class__.__name__

# Check if is an instance of MONable (or its subclasses)
if test_is_subclass and not isinstance(obj, MSONable):
raise TypeError(f"{obj_name} object is not MSONable")

# Check if the object can be accurately reconstructed from its dict representation
if obj.as_dict() != type(obj).from_dict(obj.as_dict()).as_dict():
raise ValueError(f"{obj_name} object could not be reconstructed accurately from its dict representation.")

# Verify that the deserialized object's class is a subclass of the original object's class
json_str = json.dumps(obj.as_dict(), cls=MontyEncoder)
round_trip = json.loads(json_str, cls=MontyDecoder)
if not issubclass(type(round_trip), type(obj)):
raise TypeError(f"The reconstructed {round_trip.__class__.__name__} object is not a subclass of {obj_name}")
return json_str


@deprecated(MatSciTest, deadline=(2026, 1, 1))
class PymatgenTest(TestCase, MatSciTest):
Expand Down

0 comments on commit def9886

Please sign in to comment.