Skip to content

Commit

Permalink
fix bad merge behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Dec 12, 2024
1 parent ec33a39 commit ee603fa
Showing 1 changed file with 53 additions and 49 deletions.
102 changes: 53 additions & 49 deletions src/pymatgen/util/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from __future__ import annotations

import json
import pickle # use pickle, not cPickle so that we get the traceback in case of errors
import pickle # use pickle over cPickle to get traceback in case of errors
import string
from pathlib import Path
from typing import TYPE_CHECKING
Expand All @@ -33,10 +33,10 @@
VASP_IN_DIR: str = f"{TEST_FILES_DIR}/io/vasp/inputs"
VASP_OUT_DIR: str = f"{TEST_FILES_DIR}/io/vasp/outputs"

# fake POTCARs have original header information, meaning properties like number of electrons,
# Fake POTCARs have original header information, meaning properties like number of electrons,
# nuclear charge, core radii, etc. are unchanged (important for testing) while values of the and
# pseudopotential kinetic energy corrections are scrambled to avoid VASP copyright infringement
FAKE_POTCAR_DIR = f"{VASP_IN_DIR}/fake_potcars"
FAKE_POTCAR_DIR: str = f"{VASP_IN_DIR}/fake_potcars"


class MatSciTest:
Expand All @@ -52,31 +52,47 @@ class MatSciTest:
# dict of lazily-loaded test structures (initialized to None)
TEST_STRUCTURES: ClassVar[dict[str | Path, Structure | None]] = dict.fromkeys(STRUCTURES_DIR.glob("*"))

@pytest.fixture(autouse=True) # make all tests run a in a temporary directory accessible via self.tmp_path
@pytest.fixture(autouse=True)
def _tmp_dir(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
# https://pytest.org/en/latest/how-to/unittest.html#using-autouse-fixtures-and-accessing-other-fixtures
"""Make all tests run a in a temporary directory accessible via self.tmp_path.
References:
https://docs.pytest.org/en/stable/how-to/tmp_path.html
"""
monkeypatch.chdir(tmp_path) # change to pytest-provided temporary directory
self.tmp_path = tmp_path

@classmethod
def get_structure(cls, name: str) -> Structure:
"""
Load a structure from `pymatgen.util.structures`.
def assert_msonable(self, obj: MSONable, test_is_subclass: bool = True) -> str:
"""Test if an object is MSONable and verify the contract is fulfilled,
and return the serialized object.
By default, the method tests whether obj is an instance of MSONable.
This check can be deactivated by setting `test_is_subclass` to False.
Args:
name (str): Name of the structure file, for example "LiFePO4".
obj (Any): The object to be checked.
test_is_subclass (bool): Check if object is an instance of MSONable
or its subclasses.
Returns:
Structure
str: Serialized object.
"""
try:
struct = cls.TEST_STRUCTURES.get(name) or loadfn(f"{STRUCTURES_DIR}/{name}.json")
except FileNotFoundError as exc:
raise FileNotFoundError(f"structure for {name} doesn't exist") from exc
obj_name = obj.__class__.__name__

cls.TEST_STRUCTURES[name] = struct
# Check if is an instance of MONable (or its subclasses)
if test_is_subclass and not isinstance(obj, MSONable):
raise TypeError(f"{obj_name} object is not MSONable")

return struct.copy()
# Check if the object can be accurately reconstructed from its dict representation
if obj.as_dict() != type(obj).from_dict(obj.as_dict()).as_dict():
raise ValueError(f"{obj_name} object could not be reconstructed accurately from its dict representation.")

# Verify that the deserialized object's class is a subclass of the original object's class
json_str = json.dumps(obj.as_dict(), cls=MontyEncoder)
round_trip = json.loads(json_str, cls=MontyDecoder)
if not issubclass(type(round_trip), type(obj)):
raise TypeError(f"The reconstructed {round_trip.__class__.__name__} object is not a subclass of {obj_name}")
return json_str

@staticmethod
def assert_str_content_equal(actual, expected):
Expand All @@ -99,6 +115,26 @@ def assert_str_content_equal(actual, expected):
f"{expected}\n"
)

@classmethod
def get_structure(cls, name: str) -> Structure:
"""
Load a structure from `pymatgen.util.structures`.
Args:
name (str): Name of the structure file, for example "LiFePO4".
Returns:
Structure
"""
try:
struct = cls.TEST_STRUCTURES.get(name) or loadfn(f"{STRUCTURES_DIR}/{name}.json")
except FileNotFoundError as exc:
raise FileNotFoundError(f"structure for {name} doesn't exist") from exc

cls.TEST_STRUCTURES[name] = struct

return struct.copy()

def serialize_with_pickle(self, objects: Any, protocols: Sequence[int] | None = None, test_eq: bool = True):
"""Test whether the object(s) can be serialized and deserialized with
`pickle`. This method tries to serialize the objects with `pickle` and the
Expand Down Expand Up @@ -163,38 +199,6 @@ def serialize_with_pickle(self, objects: Any, protocols: Sequence[int] | None =
return [o[0] for o in objects_by_protocol]
return objects_by_protocol

def assert_msonable(self, obj: MSONable, test_is_subclass: bool = True) -> str:
"""Test if an object is MSONable and verify the contract is fulfilled,
and return the serialized object.
By default, the method tests whether obj is an instance of MSONable.
This check can be deactivated by setting `test_is_subclass` to False.
Args:
obj (Any): The object to be checked.
test_is_subclass (bool): Check if object is an instance of MSONable
or its subclasses.
Returns:
str: Serialized object.
"""
obj_name = obj.__class__.__name__

# Check if is an instance of MONable (or its subclasses)
if test_is_subclass and not isinstance(obj, MSONable):
raise TypeError(f"{obj_name} object is not MSONable")

# Check if the object can be accurately reconstructed from its dict representation
if obj.as_dict() != type(obj).from_dict(obj.as_dict()).as_dict():
raise ValueError(f"{obj_name} object could not be reconstructed accurately from its dict representation.")

# Verify that the deserialized object's class is a subclass of the original object's class
json_str = json.dumps(obj.as_dict(), cls=MontyEncoder)
round_trip = json.loads(json_str, cls=MontyDecoder)
if not issubclass(type(round_trip), type(obj)):
raise TypeError(f"The reconstructed {round_trip.__class__.__name__} object is not a subclass of {obj_name}")
return json_str


@deprecated(MatSciTest, deadline=(2026, 1, 1))
class PymatgenTest(TestCase, MatSciTest):
Expand Down

0 comments on commit ee603fa

Please sign in to comment.