Skip to content

Commit

Permalink
Bug fix for ComputedEntry from_dict (#3897)
Browse files Browse the repository at this point in the history
* fix bug in ComputedEntry.from_dict when kwargs are null, add test

* linting

* slight refactor

* clearer legacy case
  • Loading branch information
esoteric-ephemera authored Jun 26, 2024
1 parent f10ace8 commit 5cb59e4
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
29 changes: 14 additions & 15 deletions pymatgen/entries/computed_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,27 +482,26 @@ def from_dict(cls, dct: dict) -> Self:
Returns:
ComputedEntry
"""
# the first block here is for legacy ComputedEntry that were
# serialized before we had the energy_adjustments attribute.
if dct["correction"] != 0 and not dct.get("energy_adjustments"):
return cls(
dct["composition"],
dct["energy"],
dct["correction"],
parameters={k: MontyDecoder().process_decoded(v) for k, v in dct.get("parameters", {}).items()},
data={k: MontyDecoder().process_decoded(v) for k, v in dct.get("data", {}).items()},
entry_id=dct.get("entry_id"),
)
# Must handle cases where some kwargs exist in `dct` but are None
# include extra logic to ensure these get properly treated
energy_adj = [MontyDecoder().process_decoded(e) for e in (dct.get("energy_adjustments", []) or [])]

# this is the preferred / modern way of instantiating ComputedEntry
# we don't pass correction explicitly because it will be calculated
# on the fly from energy_adjustments
correction = 0
if dct["correction"] != 0 and len(energy_adj) == 0:
# this block is for legacy ComputedEntry that were
# serialized before we had the energy_adjustments attribute.
correction = dct["correction"]

return cls(
dct["composition"],
dct["energy"],
correction=0,
energy_adjustments=[MontyDecoder().process_decoded(e) for e in dct.get("energy_adjustments", {})],
parameters={k: MontyDecoder().process_decoded(v) for k, v in dct.get("parameters", {}).items()},
data={k: MontyDecoder().process_decoded(v) for k, v in dct.get("data", {}).items()},
correction=correction,
energy_adjustments=energy_adj,
parameters={k: MontyDecoder().process_decoded(v) for k, v in (dct.get("parameters", {}) or {}).items()},
data={k: MontyDecoder().process_decoded(v) for k, v in (dct.get("data", {}) or {}).items()},
entry_id=dct.get("entry_id"),
)

Expand Down
13 changes: 13 additions & 0 deletions tests/entries/test_computed_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,19 @@ def test_copy(self):
assert entry == copy
assert str(entry) == str(copy)

def test_from_dict_null_fields(self):
ce_dict = self.entry.as_dict()
for k in (
"energy_adjustments",
"parameters",
"data",
):
ce = ce_dict.copy()
ce[k] = None
new_ce = ComputedEntry.from_dict(ce)
assert new_ce == self.entry
assert getattr(new_ce, k, None) is not None


class TestComputedStructureEntry(TestCase):
def setUp(self):
Expand Down

0 comments on commit 5cb59e4

Please sign in to comment.