Skip to content

Commit ef85e1f

Browse files
committed
refactor(enums): improve performance enum array (#1233)
1 parent 02c0576 commit ef85e1f

File tree

8 files changed

+142
-84
lines changed

8 files changed

+142
-84
lines changed

openfisca_core/indexed_enums/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Enumerations for variables with a limited set of possible values."""
22

33
from . import types
4+
from ._enum_type import EnumType
45
from ._errors import EnumEncodingError, EnumMemberNotFoundError
56
from .config import ENUM_ARRAY_DTYPE
67
from .enum import Enum
@@ -12,5 +13,6 @@
1213
"EnumArray",
1314
"EnumEncodingError",
1415
"EnumMemberNotFoundError",
16+
"EnumType",
1517
"types",
1618
]
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from __future__ import annotations
2+
3+
from typing import final
4+
5+
import numpy
6+
7+
from . import types as t
8+
9+
10+
@final
class EnumType(t.EnumType):
    """Metaclass that equips an indexed :class:`.Enum` with lookup arrays.

    When an enum class declaring at least one member is created, three
    class-level arrays are computed once and stored on it:

    * ``indices``: ``numpy.arange(len(cls))`` with dtype ``t.EnumDType``;
    * ``names``: the member names with dtype ``t.StrDType``;
    * ``enums``: the members themselves with dtype ``t.ObjDType``.

    Enum classes without members are returned untouched, hence they do not
    expose these attributes.

    Examples:
        >>> from openfisca_core import indexed_enums as enum

        >>> class Enum(enum.Enum, metaclass=enum.EnumType):
        ...     pass

        >>> Enum.items
        Traceback (most recent call last):
        AttributeError: ...

        >>> class Housing(Enum):
        ...     OWNER = "Owner"
        ...     TENANT = "Tenant"

        >>> Housing.indices
        array([0, 1], dtype=uint8)

        >>> Housing.names
        array(['OWNER', 'TENANT'], dtype='<U6')

        >>> Housing.enums
        array([Housing.OWNER, Housing.TENANT], dtype=object)

    """

    def __new__(
        metacls,
        name: str,
        bases: tuple[type, ...],
        classdict: t.EnumDict,
        **kwds: object,
    ) -> t.EnumType:
        """Build the enum class and attach its lookup arrays, if any."""
        enum_cls = super().__new__(metacls, name, bases, classdict, **kwds)

        # Only enums that actually declare members get the lookup arrays;
        # member-less classes are handed back exactly as the parent built them.
        if enum_cls.__members__:
            enum_cls.indices = numpy.arange(len(enum_cls), dtype=t.EnumDType)
            enum_cls.names = numpy.array(enum_cls._member_names_, dtype=t.StrDType)
            enum_cls.enums = numpy.array(enum_cls, dtype=t.ObjDType)

        return enum_cls

    def __dir__(cls) -> list[str]:
        """Return the sorted attribute names, lookup arrays included."""
        entries = set(super().__dir__())
        entries.update(("indices", "names", "enums"))
        return sorted(entries)
68+
69+
70+
__all__ = ["EnumType"]

openfisca_core/indexed_enums/_guards.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ def _is_enum_array(array: t.VarArray) -> TypeIs[t.ObjArray]:
5353
return array.dtype.type in objs
5454

5555

56-
def _is_enum_array_like(array: t.ArrayLike[object]) -> TypeIs[t.ArrayLike[t.Enum]]:
56+
def _is_enum_array_like(
57+
array: t.VarArray | t.ArrayLike[object],
58+
) -> TypeIs[t.ArrayLike[t.Enum]]:
5759
"""Narrow the type of a given array-like to a sequence of :class:`.Enum`.
5860
5961
Args:
@@ -109,7 +111,9 @@ def _is_int_array(array: t.VarArray) -> TypeIs[t.IndexArray]:
109111
return array.dtype.type in ints
110112

111113

112-
def _is_int_array_like(array: t.ArrayLike[object]) -> TypeIs[t.ArrayLike[int]]:
114+
def _is_int_array_like(
115+
array: t.VarArray | t.ArrayLike[object],
116+
) -> TypeIs[t.ArrayLike[int]]:
113117
"""Narrow the type of a given array-like to a sequence of :obj:`int`.
114118
115119
Args:
@@ -165,7 +169,9 @@ def _is_str_array(array: t.VarArray) -> TypeIs[t.StrArray]:
165169
return array.dtype.type in strs
166170

167171

168-
def _is_str_array_like(array: t.ArrayLike[object]) -> TypeIs[t.ArrayLike[str]]:
172+
def _is_str_array_like(
173+
array: t.VarArray | t.ArrayLike[object],
174+
) -> TypeIs[t.ArrayLike[str]]:
169175
"""Narrow the type of a given array-like to a sequence of :obj:`str`.
170176
171177
Args:

openfisca_core/indexed_enums/enum.py

Lines changed: 29 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy
66

77
from . import types as t
8+
from ._enum_type import EnumType
89
from ._errors import EnumEncodingError, EnumMemberNotFoundError
910
from ._guards import (
1011
_is_enum_array,
@@ -18,7 +19,7 @@
1819
from .enum_array import EnumArray
1920

2021

21-
class Enum(t.Enum):
22+
class Enum(t.Enum, metaclass=EnumType):
2223
"""Enum based on `enum34 <https://pypi.python.org/pypi/enum34/>`_.
2324
2425
Its items have an :class:`int` index, useful and performant when running
@@ -148,11 +149,6 @@ def encode(cls, array: t.VarArray | t.ArrayLike[object]) -> t.EnumArray:
148149
Returns:
149150
EnumArray: An :class:`.EnumArray` with the encoded input values.
150151
151-
Raises:
152-
EnumEncodingError: If ``array`` is of diffent :class:`.Enum` type.
153-
EnumMemberNotFoundError: If members are not found in :class:`.Enum`.
154-
NotImplementedError: If ``array`` is a scalar :class:`~numpy.ndarray`.
155-
156152
Examples:
157153
>>> import numpy
158154
@@ -201,70 +197,40 @@ def encode(cls, array: t.VarArray | t.ArrayLike[object]) -> t.EnumArray:
201197
:meth:`.EnumArray.decode` for decoding.
202198
203199
"""
204-
# Array of indices
205-
indices: t.IndexArray
206-
207200
if isinstance(array, EnumArray):
208201
return array
209-
210-
# Array-like
202+
if len(array) == 0:
203+
return EnumArray(numpy.asarray(array, t.EnumDType), cls)
211204
if isinstance(array, Sequence):
212-
if len(array) == 0:
213-
indices = numpy.array([], t.EnumDType)
214-
215-
elif _is_int_array_like(array):
216-
indices = _int_to_index(cls, array)
217-
218-
elif _is_str_array_like(array):
219-
indices = _str_to_index(cls, array)
220-
221-
elif _is_enum_array_like(array):
222-
indices = _enum_to_index(array)
223-
224-
else:
225-
raise EnumEncodingError(cls, array)
205+
return cls._encode_array_like(array)
206+
return cls._encode_array(array)
226207

208+
@classmethod
def _encode_array(cls, value: t.VarArray) -> t.EnumArray:
    """Encode a :class:`numpy.ndarray` into an :class:`.EnumArray`.

    Args:
        value: Array of indices, of member names, or of enum members.

    Returns:
        EnumArray: An :class:`.EnumArray` with the encoded input values.

    Raises:
        EnumEncodingError: If ``value`` is neither an int array, a str
            array, nor an enum array whose first item's class has the same
            name as ``cls``.
        EnumMemberNotFoundError: If the number of encoded indices differs
            from the length of ``value``.

    """
    if _is_int_array(value):
        indices = _int_to_index(cls, value)
    elif _is_str_array(value):  # type: ignore[unreachable]
        indices = _str_to_index(cls, value)
    # Ensure we are comparing the comparable. On entering this method "cls"
    # will generally come from variable.possible_values, while the array
    # values may come from directly importing a module containing an Enum
    # class. However, variables (and hence their possible_values) are loaded
    # by a call to load_module, which gives them a different identity from
    # the ones imported in the usual way.
    #
    # So, instead of relying on the "cls" passed in, only its name is used to
    # check that the values in the array are of the right type. The names are
    # compared by value (``==``): two equal strings are not guaranteed to be
    # the same object, so an identity check (``is``) could wrongly reject a
    # matching enum.
    elif _is_enum_array(value) and cls.__name__ == value[0].__class__.__name__:
        indices = _enum_to_index(value)
    else:
        raise EnumEncodingError(cls, value)
    # Fewer (or more) indices than inputs means some values did not map to
    # a member of the enum.
    if indices.size != len(value):
        raise EnumMemberNotFoundError(cls)
    return EnumArray(indices, cls)
267221

222+
@classmethod
def _encode_array_like(cls, value: t.ArrayLike[object]) -> t.EnumArray:
    """Encode an array-like (a sequence) into an :class:`.EnumArray`.

    Args:
        value: Sequence of indices, of member names, or of enum members.

    Returns:
        EnumArray: An :class:`.EnumArray` with the encoded input values.

    Raises:
        EnumEncodingError: If ``value`` is not recognised as a sequence of
            ints, strs, or enum members.
        EnumMemberNotFoundError: If the number of encoded indices differs
            from the length of ``value``.

    """
    # The guards are tried in a fixed order: ints, then strs, then enums.
    if _is_int_array_like(value):
        encoded = _int_to_index(cls, value)
    elif _is_str_array_like(value):  # type: ignore[unreachable]
        encoded = _str_to_index(cls, value)
    elif _is_enum_array_like(value):
        encoded = _enum_to_index(value)
    else:
        raise EnumEncodingError(cls, value)

    # Every input value must have produced exactly one index.
    if encoded.size == len(value):
        return EnumArray(encoded, cls)
    raise EnumMemberNotFoundError(cls)
269235

270236

openfisca_core/indexed_enums/enum_array.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class EnumArray(t.EnumArray):
7070
"""
7171

7272
#: Enum type of the array items.
73-
possible_values: None | type[t.Enum] = None
73+
possible_values: None | type[t.Enum]
7474

7575
def __new__(
7676
cls,
@@ -157,8 +157,12 @@ def __eq__(self, other: object) -> t.BoolArray: # type: ignore[override]
157157
isinstance(other, type(t.Enum))
158158
and other.__name__ is self.possible_values.__name__
159159
):
160-
index = numpy.array([enum.index for enum in self.possible_values])
161-
result = self.view(numpy.ndarray) == index[index <= max(self)]
160+
result = (
161+
self.view(numpy.ndarray)
162+
== self.possible_values.indices[
163+
self.possible_values.indices <= max(self)
164+
]
165+
)
162166
return result
163167
if (
164168
isinstance(other, t.Enum)
@@ -265,16 +269,16 @@ def decode(self) -> t.ObjArray:
265269
array([Housing.TENANT], dtype=object)
266270
267271
"""
272+
result: t.ObjArray
268273
if self.possible_values is None:
269274
msg = (
270275
f"The possible values of the {self.__class__.__name__} are "
271276
f"not defined."
272277
)
273278
raise TypeError(msg)
274-
return numpy.select(
275-
[self == item.index for item in self.possible_values],
276-
list(self.possible_values), # pyright: ignore[reportArgumentType]
277-
)
279+
array = self.reshape(1).astype(t.EnumDType) if self.ndim == 0 else self
280+
result = self.possible_values.enums[array]
281+
return result
278282

279283
def decode_to_str(self) -> t.StrArray:
280284
"""Decode itself to an array of strings.
@@ -300,16 +304,16 @@ def decode_to_str(self) -> t.StrArray:
300304
array(['TENANT'], dtype='<U6')
301305
302306
"""
307+
result: t.StrArray
303308
if self.possible_values is None:
304309
msg = (
305310
f"The possible values of the {self.__class__.__name__} are "
306311
f"not defined."
307312
)
308313
raise TypeError(msg)
309-
return numpy.select(
310-
[self == item.index for item in self.possible_values],
311-
[item.name for item in self.possible_values],
312-
)
314+
array = self.reshape(1).astype(t.EnumDType) if self.ndim == 0 else self
315+
result = self.possible_values.names[array]
316+
return result
313317

314318
def __repr__(self) -> str:
315319
return f"{self.__class__.__name__}({self.decode()!s})"

openfisca_core/indexed_enums/tests/test_enum.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_enum_encode_with_enum_sequence():
3636
def test_enum_encode_with_enum_scalar_array():
3737
"""Does not encode when called with an enum scalar array."""
3838
array = numpy.array(Animal.DOG)
39-
with pytest.raises(NotImplementedError):
39+
with pytest.raises(TypeError):
4040
Animal.encode(array)
4141

4242

@@ -67,7 +67,7 @@ def test_enum_encode_with_int_sequence():
6767
def test_enum_encode_with_int_scalar_array():
6868
"""Does not encode when called with an int scalar array."""
6969
array = numpy.array(1)
70-
with pytest.raises(NotImplementedError):
70+
with pytest.raises(TypeError):
7171
Animal.encode(array)
7272

7373

@@ -98,7 +98,7 @@ def test_enum_encode_with_str_sequence():
9898
def test_enum_encode_with_str_scalar_array():
9999
"""Does not encode when called with a str scalar array."""
100100
array = numpy.array("DOG")
101-
with pytest.raises(NotImplementedError):
101+
with pytest.raises(TypeError):
102102
Animal.encode(array)
103103

104104

@@ -124,7 +124,7 @@ def test_enum_encode_with_any_scalar_array():
124124
"""Does not encode when called with unsupported types."""
125125
value = 1.5
126126
array = numpy.array(value)
127-
with pytest.raises(NotImplementedError):
127+
with pytest.raises(TypeError):
128128
Animal.encode(array)
129129

130130

openfisca_core/indexed_enums/types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing_extensions import TypeAlias
22

3-
from openfisca_core.types import Array, ArrayLike, DTypeLike, Enum, EnumArray
3+
from openfisca_core.types import Array, ArrayLike, DTypeLike, Enum, EnumArray, EnumType
4+
5+
from enum import _EnumDict as EnumDict # noqa: PLC2701
46

57
from numpy import (
68
bool_ as BoolDType,
@@ -34,4 +36,6 @@
3436
"DTypeLike",
3537
"Enum",
3638
"EnumArray",
39+
"EnumDict",
40+
"EnumType",
3741
]

openfisca_core/types.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,13 @@ def plural(self, /) -> None | RolePlural: ...
108108
# Indexed enums
109109

110110

111-
class Enum(enum.Enum, metaclass=enum.EnumMeta):
111+
class EnumType(enum.EnumMeta):
112+
indices: Array[DTypeEnum]
113+
names: Array[DTypeStr]
114+
enums: Array[DTypeObject]
115+
116+
117+
class Enum(enum.Enum, metaclass=EnumType):
112118
index: int
113119
_member_names_: list[str]
114120

@@ -118,7 +124,7 @@ class EnumArray(Array[DTypeEnum], metaclass=abc.ABCMeta):
118124

119125
@abc.abstractmethod
120126
def __new__(
121-
cls, input_array: Array[DTypeEnum], possible_values: None | type[Enum] = ...
127+
cls, input_array: Array[DTypeEnum], possible_values: type[Enum]
122128
) -> Self: ...
123129

124130

0 commit comments

Comments
 (0)