Skip to content

Commit 769ade3

Browse files
authored
ENH: implement FrozenDict with frozendict (#310)
* DX: implement hash test for `FrozenDict` and `ReactionInfo` * ENH: inherit `FrozenDict` from `frozendict` * MAINT: install `frozendict` as direct dependency
1 parent 63fff63 commit 769ade3

File tree

6 files changed

+97
-68
lines changed

6 files changed

+97
-68
lines changed

docs/conf.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from _extend_docstrings import extend_docstrings # noqa: PLC2701
1919

2020

21-
def pick_newtype_attrs(some_type: type) -> list:
21+
def __get_newtypes(some_type: type) -> list:
2222
return [
2323
attr
2424
for attr in dir(some_type)
@@ -278,25 +278,20 @@ def pick_newtype_attrs(some_type: type) -> list:
278278
nb_execution_show_tb = True
279279
nb_execution_timeout = -1
280280
nb_output_stderr = "remove"
281-
282-
283-
nitpick_temp_names = [
284-
*pick_newtype_attrs(EdgeQuantumNumbers),
285-
*pick_newtype_attrs(NodeQuantumNumbers),
286-
]
287-
nitpick_temp_patterns = [
288-
(r"py:(class|obj)", r"qrules\.quantum_numbers\." + name)
289-
for name in nitpick_temp_names
290-
]
291281
nitpick_ignore_regex = [
292282
(r"py:(class|obj)", "json.encoder.JSONEncoder"),
283+
(r"py:(class|obj)", r"frozendict(\.frozendict)?"),
293284
(r"py:(class|obj)", r"qrules\.topology\.EdgeType"),
294285
(r"py:(class|obj)", r"qrules\.topology\.KT"),
295286
(r"py:(class|obj)", r"qrules\.topology\.NewEdgeType"),
296287
(r"py:(class|obj)", r"qrules\.topology\.NewNodeType"),
297288
(r"py:(class|obj)", r"qrules\.topology\.NodeType"),
298289
(r"py:(class|obj)", r"qrules\.topology\.VT"),
299-
*nitpick_temp_patterns,
290+
*[
291+
(r"py:(class|obj)", r"qrules\.quantum_numbers\." + name)
292+
for name in __get_newtypes(EdgeQuantumNumbers)
293+
+ __get_newtypes(NodeQuantumNumbers)
294+
],
300295
]
301296
nitpicky = True
302297
primary_domain = "py"

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ classifiers = [
2727
dependencies = [
2828
"PyYAML",
2929
"attrs >=20.1.0", # on_setattr and https://www.attrs.org/en/stable/api.html#next-gen
30+
"frozendict",
3031
"jsonschema",
3132
"particle",
3233
"python-constraint",

src/qrules/topology.py

Lines changed: 4 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,12 @@
2727
import attrs
2828
from attrs import define, field, frozen
2929
from attrs.validators import deep_iterable, deep_mapping, instance_of
30+
from frozendict import frozendict
3031

3132
from qrules._implementers import implement_pretty_repr
3233

3334
if TYPE_CHECKING:
34-
from collections.abc import (
35-
ItemsView,
36-
Iterable,
37-
Iterator,
38-
KeysView,
39-
Mapping,
40-
Sequence,
41-
ValuesView,
42-
)
35+
from collections.abc import Iterable, Mapping, Sequence
4336

4437
from IPython.lib.pretty import PrettyPrinter
4538

@@ -56,31 +49,8 @@ def __lt__(self, other: Any) -> bool: ...
5649

5750

5851
@total_ordering
59-
class FrozenDict(abc.Hashable, abc.Mapping, Generic[KT, VT]):
60-
"""An **immutable** and **hashable** version of a `dict`.
61-
62-
`FrozenDict` makes it possible to make classes hashable if they are decorated with
63-
:func:`attr.frozen` and contain `~typing.Mapping`-like attributes. If these
64-
attributes were to be implemented with a normal `dict`, the instance is strictly
65-
speaking still mutable (even if those attributes are a `property`) and the class is
66-
therefore not safely hashable.
67-
68-
.. warning:: The keys have to be comparable, that is, they need to have a
69-
:meth:`~object.__lt__` method.
70-
"""
71-
72-
def __init__(self, mapping: Mapping | None = None) -> None:
73-
self.__mapping: dict[KT, VT] = {}
74-
if mapping is not None:
75-
self.__mapping = dict(mapping)
76-
self.__hash = hash(None)
77-
if len(self.__mapping) != 0:
78-
self.__hash = 0
79-
for key_value_pair in self.items():
80-
self.__hash ^= hash(key_value_pair)
81-
82-
def __repr__(self) -> str:
83-
return f"{type(self).__name__}({self.__mapping})"
52+
class FrozenDict(frozendict, Generic[KT, VT]):
53+
"""A sortable version of :code:`frozendict`."""
8454

8555
def _repr_pretty_(self, p: PrettyPrinter, cycle: bool) -> None:
8656
class_name = type(self).__name__
@@ -96,15 +66,6 @@ def _repr_pretty_(self, p: PrettyPrinter, cycle: bool) -> None:
9666
p.breakable()
9767
p.text("})")
9868

99-
def __iter__(self) -> Iterator[KT]:
100-
return iter(self.__mapping)
101-
102-
def __len__(self) -> int:
103-
return len(self.__mapping)
104-
105-
def __getitem__(self, key: KT) -> VT:
106-
return self.__mapping[key]
107-
10869
def __gt__(self, other: Any) -> bool:
10970
if isinstance(other, abc.Mapping):
11071
sorted_self = _convert_mapping_to_sorted_tuple(self)
@@ -117,18 +78,6 @@ def __gt__(self, other: Any) -> bool:
11778
)
11879
raise NotImplementedError(msg)
11980

120-
def __hash__(self) -> int:
121-
return self.__hash
122-
123-
def keys(self) -> KeysView[KT]:
124-
return self.__mapping.keys()
125-
126-
def items(self) -> ItemsView[KT, VT]:
127-
return self.__mapping.items()
128-
129-
def values(self) -> ValuesView[VT]:
130-
return self.__mapping.values()
131-
13281

13382
def _convert_mapping_to_sorted_tuple(
13483
mapping: Mapping[KT, VT],

tests/unit/test_topology.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import hashlib
2+
import pickle # noqa: S403
13
import typing
24

35
import pytest
@@ -6,7 +8,7 @@
68

79
from qrules.topology import (
810
Edge,
9-
FrozenDict, # noqa: F401 # pyright: ignore[reportUnusedImport]
11+
FrozenDict, # pyright: ignore[reportUnusedImport]
1012
InteractionNode,
1113
MutableTopology,
1214
SimpleStateTransitionTopologyBuilder,
@@ -39,6 +41,23 @@ def test_immutability(self):
3941
edge.ending_node_id += 1
4042

4143

44+
class TestFrozenDict:
45+
def test_hash(self):
46+
obj: FrozenDict = FrozenDict({})
47+
assert _compute_hash(obj) == "067705e70d037311d05daae1e32e1fce"
48+
49+
obj = FrozenDict({"key1": "value1"})
50+
assert _compute_hash(obj) == "56b0520e2a3af550c0f488cd5de2d474"
51+
52+
obj = FrozenDict({
53+
"key1": "value1",
54+
"key2": 2,
55+
"key3": (1, 2, 3),
56+
"key4": FrozenDict({"nested_key": "nested_value"}),
57+
})
58+
assert _compute_hash(obj) == "8568f73c07fce099311f010061f070c6"
59+
60+
4261
class TestInteractionNode:
4362
def test_constructor_exceptions(self):
4463
with pytest.raises(TypeError):
@@ -188,6 +207,9 @@ def test_constructor_exceptions(self, nodes, edges):
188207
):
189208
assert Topology(nodes, edges)
190209

210+
def test_hash(self, two_to_three_decay: Topology):
211+
assert _compute_hash(two_to_three_decay) == "cbaea5d94038a3ad30888014a7b3ae20"
212+
191213
@pytest.mark.parametrize("repr_method", [repr, pretty])
192214
def test_repr_and_eq(self, repr_method, two_to_three_decay: Topology):
193215
topology = eval(repr_method(two_to_three_decay))
@@ -299,3 +321,15 @@ def test_create_n_body_topology(n_initial: int, n_final: int, exception):
299321
assert len(topology.outgoing_edge_ids) == n_final
300322
assert len(topology.intermediate_edge_ids) == 0
301323
assert len(topology.nodes) == 1
324+
325+
326+
def _compute_hash(obj) -> str:
327+
b = _to_bytes(obj)
328+
h = hashlib.md5(b) # noqa: S324
329+
return h.hexdigest()
330+
331+
332+
def _to_bytes(obj) -> bytes:
333+
if isinstance(obj, bytes | bytearray):
334+
return obj
335+
return pickle.dumps(obj)

tests/unit/test_transition.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# pyright: reportUnusedImport=false
2+
import hashlib
3+
import pickle # noqa: S403
24
from copy import deepcopy
35
from fractions import Fraction
46

@@ -44,6 +46,13 @@ def test_repr(self, repr_method, reaction: ReactionInfo):
4446
def test_hash(self, reaction: ReactionInfo):
4547
assert hash(deepcopy(reaction)) == hash(reaction)
4648

49+
def test_hash_value(self, reaction: ReactionInfo):
50+
expected_hash = {
51+
"canonical-helicity": "65106a44301f9340e633d09f66ad7d17",
52+
"helicity": "9646d3ee5c5e8534deb8019435161f2e",
53+
}[reaction.formalism]
54+
assert _compute_hash(reaction) == expected_hash
55+
4756

4857
class TestState:
4958
@pytest.mark.parametrize(
@@ -106,3 +115,15 @@ def test_regex_pattern(self):
106115
"Delta(1900)++",
107116
"Delta(1920)++",
108117
]
118+
119+
120+
def _compute_hash(obj) -> str:
121+
b = _to_bytes(obj)
122+
h = hashlib.md5(b) # noqa: S324
123+
return h.hexdigest()
124+
125+
126+
def _to_bytes(obj) -> bytes:
127+
if isinstance(obj, bytes | bytearray):
128+
return obj
129+
return pickle.dumps(obj)

uv.lock

Lines changed: 29 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)