Skip to content

Commit b5b7968

Browse files
committed
fix(enums): do actual indexing (#1267)
1 parent 0195451 commit b5b7968

File tree

9 files changed

+279
-68
lines changed

9 files changed

+279
-68
lines changed

openfisca_core/indexed_enums/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Enumerations for variables with a limited set of possible values."""
22

33
from . import types
4+
from ._enum_type import EnumType
45
from .config import ENUM_ARRAY_DTYPE
56
from .enum import Enum
67
from .enum_array import EnumArray
@@ -9,5 +10,6 @@
910
"ENUM_ARRAY_DTYPE",
1011
"Enum",
1112
"EnumArray",
13+
"EnumType",
1214
"types",
1315
]
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
from __future__ import annotations
2+
3+
from typing import final
4+
5+
import numpy
6+
7+
from . import types as t
8+
9+
10+
def _item_list(enum_class: type[t.Enum]) -> t.ItemList:
    """Build the plain-Python ``(index, name, member)`` triples of an enum.

    Args:
        enum_class: The enum class whose members are listed.

    Returns:
        One tuple per entry of ``enum_class.__members__``, in iteration
        order, holding the entry's zero-based position, its name, and the
        member itself.

    """
    members = enum_class.__members__
    # Iterating the mapping yields the names; the position is the index.
    return [(position, key, members[key]) for position, key in enumerate(members)]
16+
17+
18+
def _item_dtype(enum_class: type[t.Enum]) -> t.RecDType:
    """Build the structured dtype describing one indexed enum item.

    The dtype has three fields: ``index`` (``t.EnumDType``), ``name`` (a
    unicode string wide enough for the longest member name) and ``enum``
    (declared with ``enum_class`` itself as its type).

    Args:
        enum_class: The enum class to describe. It needs at least one member,
            because the longest member name is measured with ``max``.

    Returns:
        The record dtype of the enum's items.

    """
    # Width, in characters, of the longest member name.
    longest = max(len(name) for name in enum_class.__members__)

    # Byte offsets of the fields. ``name`` starts at byte 2 -- presumably the
    # width of ``t.EnumDType`` (TODO confirm) -- and ``enum`` starts after the
    # name, counting 4 bytes per character of the unicode field.
    name_offset = 2
    enum_offset = name_offset + longest * 4

    fields = {
        "index": (t.EnumDType, 0),
        "name": (f"U{longest}", name_offset),
        "enum": (enum_class, enum_offset),
    }
    return numpy.dtype((numpy.generic, fields))
31+
32+
33+
def _item_array(enum_class: type[t.Enum]) -> t.RecArray:
    """Pack every item of an indexed enum into a record array.

    Args:
        enum_class: The enum class whose items are packed.

    Returns:
        The items produced by ``_item_list``, stored with the dtype produced
        by ``_item_dtype`` and viewed as a :class:`numpy.recarray`.

    """
    records = numpy.array(_item_list(enum_class), dtype=_item_dtype(enum_class))
    # The recarray view is what lets callers read fields as attributes.
    return records.view(numpy.recarray)
39+
40+
41+
@final
class EnumType(t.EnumType):
    """Metaclass building the indexed items of an :class:`.Enum`.

    Every class created through this metaclass that has at least one member
    receives an ``items`` record array; ``indices``, ``names`` and ``enums``
    are read from the fields of that array.

    Examples:
        >>> from openfisca_core import indexed_enums as enum

        >>> class Enum(enum.Enum, metaclass=enum.EnumType):
        ...     pass

        >>> Enum.items
        Traceback (most recent call last):
        AttributeError: type object 'Enum' has no attribute 'items'

        >>> class Housing(Enum):
        ...     OWNER = "Owner"
        ...     TENANT = "Tenant"

        >>> Housing.items
        rec.array([(0, 'OWNER', <Housing.OWNER: 'Owner'>), ...])

        >>> Housing.indices
        array([0, 1], dtype=int16)

        >>> Housing.names
        array(['OWNER', 'TENANT'], dtype='<U6')

        >>> Housing.enums
        array([<Housing.OWNER: 'Owner'>, <Housing.TENANT: 'Tenant'>], dtype...)

    """

    #: Record array of the enum's items; only set on classes with members.
    items: t.RecArray

    def __new__(
        metacls,
        cls: str,
        bases: tuple[type, ...],
        classdict: t.EnumDict,
        **kwds: object,
    ) -> t.EnumType:
        """Create the enum class and index its members, if it has any."""
        created = super().__new__(metacls, cls, bases, classdict, **kwds)

        # Classes without members (like the empty ``Enum`` base in the
        # examples) are returned untouched, hence without ``items``.
        if created.__members__:
            created.items = _item_array(created)

        return created

    def __dir__(cls) -> list[str]:
        """List the inherited attribute names plus the indexed-enum ones."""
        extra = {"items", "indices", "names", "enums"}
        return sorted(extra.union(super().__dir__()))

    @property
    def indices(cls) -> t.IndexArray:
        """Return the ``index`` field of the enum class's items."""
        return cls.items.index

    @property
    def names(cls) -> t.StrArray:
        """Return the ``name`` field of the enum class's items."""
        return cls.items.name

    @property
    def enums(cls) -> t.ObjArray:
        """Return the ``enum`` field (the members) of the enum class's items."""
        return cls.items.enum
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from __future__ import annotations
2+
3+
from typing_extensions import TypeIs
4+
5+
import numpy
6+
7+
from . import types as t
8+
9+
10+
def _is_int_array(array: t.AnyArray) -> TypeIs[t.IndexArray | t.IntArray]:
    """Tell whether an array's dtype is a kind of :obj:`numpy.integer`.

    Used as a type guard, so that type checkers narrow ``array`` to an
    integer array when the check succeeds.

    Args:
        array: Array to check.

    Returns:
        bool: True if the dtype of ``array`` is a sub-dtype of
        :obj:`numpy.integer`, False otherwise.

    Examples:
        >>> import numpy

        >>> array = numpy.array([1], dtype=numpy.int16)
        >>> _is_int_array(array)
        True

        >>> array = numpy.array([1], dtype=numpy.int32)
        >>> _is_int_array(array)
        True

        >>> array = numpy.array([1.0])
        >>> _is_int_array(array)
        False

    """
    dtype = array.dtype
    return numpy.issubdtype(dtype, numpy.integer)
36+
37+
38+
def _is_str_array(array: t.AnyArray) -> TypeIs[t.StrArray]:
    """Tell whether an array's dtype is a kind of :obj:`numpy.str_`.

    Used as a type guard, so that type checkers narrow ``array`` to a string
    array when the check succeeds.

    Args:
        array: Array to check.

    Returns:
        bool: True if the dtype of ``array`` is a sub-dtype of
        :obj:`numpy.str_`, False otherwise.

    Examples:
        >>> import numpy

        >>> from openfisca_core import indexed_enums as enum

        >>> class Housing(enum.Enum):
        ...     OWNER = "owner"
        ...     TENANT = "tenant"

        >>> array = numpy.array([Housing.OWNER])
        >>> _is_str_array(array)
        False

        >>> array = numpy.array(["owner"])
        >>> _is_str_array(array)
        True

    """
    dtype = array.dtype
    # ``numpy.str_`` is the scalar type the builtin ``str`` resolves to.
    return numpy.issubdtype(dtype, numpy.str_)
66+
67+
68+
# Names exported by a star import of this module. Both guards are listed
# explicitly because underscore-prefixed names are otherwise skipped.
__all__ = ["_is_int_array", "_is_str_array"]

openfisca_core/indexed_enums/enum.py

Lines changed: 39 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import numpy
44

55
from . import types as t
6-
from .config import ENUM_ARRAY_DTYPE
6+
from ._enum_type import EnumType
7+
from ._type_guards import _is_int_array, _is_str_array
78
from .enum_array import EnumArray
89

910

10-
class Enum(t.Enum):
11+
class Enum(t.Enum, metaclass=EnumType):
1112
"""Enum based on `enum34 <https://pypi.python.org/pypi/enum34/>`_.
1213
1314
Its items have an :class:`int` index, useful and performant when running
@@ -115,20 +116,19 @@ def __ne__(self, other: object) -> bool:
115116
return NotImplemented
116117
return self.index != other.index
117118

118-
#: :meth:`.__hash__` must also be defined so as to stay hashable.
119-
__hash__ = object.__hash__
119+
def __hash__(self) -> int:
120+
return hash(self.index)
120121

121122
@classmethod
122123
def encode(
123124
cls,
124125
array: (
125126
EnumArray
126-
| t.Array[t.DTypeStr]
127-
| t.Array[t.DTypeInt]
128-
| t.Array[t.DTypeEnum]
129-
| t.Array[t.DTypeObject]
130-
| t.ArrayLike[str]
127+
| t.IntArray
128+
| t.StrArray
129+
| t.ObjArray
131130
| t.ArrayLike[int]
131+
| t.ArrayLike[str]
132132
| t.ArrayLike[t.Enum]
133133
),
134134
) -> EnumArray:
@@ -143,7 +143,6 @@ def encode(
143143
Raises:
144144
TypeError: If ``array`` is a scalar :class:`~numpy.ndarray`.
145145
TypeError: If ``array`` is of a different :class:`.Enum` type.
146-
NotImplementedError: If ``array`` is of an unsupported type.
147146
148147
Examples:
149148
>>> import numpy
@@ -187,7 +186,7 @@ def encode(
187186
>>> array = numpy.array([b"TENANT"])
188187
>>> enum_array = Housing.encode(array)
189188
Traceback (most recent call last):
190-
NotImplementedError: Unsupported encoding: bytes48.
189+
TypeError: Failed to encode "[b'TENANT']" of type 'bytes_', as i...
191190
192191
.. seealso::
193192
:meth:`.EnumArray.decode` for decoding.
@@ -200,7 +199,7 @@ def encode(
200199
return cls.encode(numpy.array(array))
201200

202201
if array.size == 0:
203-
return EnumArray(array, cls)
202+
return EnumArray(numpy.array([]), cls)
204203

205204
if array.ndim == 0:
206205
msg = (
@@ -209,49 +208,37 @@ def encode(
209208
)
210209
raise TypeError(msg)
211210

212-
# Enum data type array
213-
if numpy.issubdtype(array.dtype, t.DTypeEnum):
214-
indexes = numpy.array([item.index for item in cls], t.DTypeEnum)
215-
return EnumArray(indexes[array[array < indexes.size]], cls)
216-
217211
# Integer array
218-
if numpy.issubdtype(array.dtype, int):
219-
array = numpy.array(array, dtype=t.DTypeEnum)
220-
return cls.encode(array)
212+
if _is_int_array(array):
213+
indices = numpy.array(array[array < len(cls.items)], dtype=t.EnumDType)
214+
return EnumArray(indices, cls)
221215

222216
# String array
223-
if numpy.issubdtype(array.dtype, t.DTypeStr):
224-
enums = [cls.__members__[key] for key in array if key in cls.__members__]
225-
return cls.encode(enums)
226-
227-
# Enum items arrays
228-
if numpy.issubdtype(array.dtype, t.DTypeObject):
229-
# Ensure we are comparing the comparable. The problem this fixes:
230-
# On entering this method "cls" will generally come from
231-
# variable.possible_values, while the array values may come from
232-
# directly importing a module containing an Enum class. However,
233-
# variables (and hence their possible_values) are loaded by a call
234-
# to load_module, which gives them a different identity from the
235-
# ones imported in the usual way.
236-
#
237-
# So, instead of relying on the "cls" passed in, we use only its
238-
# name to check that the values in the array, if non-empty, are of
239-
# the right type.
240-
if cls.__name__ is array[0].__class__.__name__:
241-
array = numpy.select(
242-
[array == item for item in array[0].__class__],
243-
[item.index for item in array[0].__class__],
244-
).astype(ENUM_ARRAY_DTYPE)
245-
return EnumArray(array, cls)
246-
247-
msg = (
248-
f"Diverging enum types are not supported: expected {cls.__name__}, "
249-
f"but got {array[0].__class__.__name__} instead."
250-
)
251-
raise TypeError(msg)
252-
253-
msg = f"Unsupported encoding: {array.dtype.name}."
254-
raise NotImplementedError(msg)
217+
if _is_str_array(array):
218+
indices = cls.items[numpy.isin(cls.names, array)].index
219+
return EnumArray(indices, cls)
220+
221+
# Ensure we are comparing the comparable. The problem this fixes:
222+
# On entering this method "cls" will generally come from
223+
# variable.possible_values, while the array values may come from
224+
# directly importing a module containing an Enum class. However,
225+
# variables (and hence their possible_values) are loaded by a call
226+
# to load_module, which gives them a different identity from the
227+
# ones imported in the usual way.
228+
#
229+
# So, instead of relying on the "cls" passed in, we use only its
230+
# name to check that the values in the array, if non-empty, are of
231+
# the right type.
232+
if cls.__name__ is array[0].__class__.__name__:
233+
indices = cls.items[numpy.isin(cls.enums, array)].index
234+
return EnumArray(indices, cls)
235+
236+
msg = (
237+
f"Failed to encode \"{array}\" of type '{array[0].__class__.__name__}', "
238+
"as it is not supported. Please, try again with an array of "
239+
f"'{int.__name__}', '{str.__name__}', or '{cls.__name__}'."
240+
)
241+
raise TypeError(msg)
255242

256243

257244
__all__ = ["Enum"]

openfisca_core/indexed_enums/enum_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class EnumArray(t.EnumArray):
7474

7575
def __new__(
7676
cls,
77-
input_array: t.Array[t.DTypeEnum],
77+
input_array: t.IndexArray,
7878
possible_values: None | type[t.Enum] = None,
7979
) -> Self:
8080
"""See comment above."""

openfisca_core/indexed_enums/tests/test_enum.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_enum_encode_with_array_of_enum():
2727

2828
def test_enum_encode_with_enum_sequence():
2929
"""Does encode when called with an enum sequence."""
30-
sequence = list(Animal)
30+
sequence = list(Animal) + list(Colour)
3131
enum_array = Animal.encode(sequence)
3232
assert Animal.DOG in enum_array
3333

@@ -89,7 +89,7 @@ def test_enum_encode_with_array_of_string():
8989

9090
def test_enum_encode_with_str_sequence():
9191
"""Does encode when called with a str sequence."""
92-
sequence = ("DOG",)
92+
sequence = ("DOG", "JAIBA")
9393
enum_array = Animal.encode(sequence)
9494
assert Animal.DOG in enum_array
9595

@@ -130,5 +130,5 @@ def test_enum_encode_with_any_scalar_array():
130130
def test_enum_encode_with_any_sequence():
131131
"""Does not encode when called with unsupported types."""
132132
sequence = memoryview(b"DOG")
133-
with pytest.raises(NotImplementedError):
134-
Animal.encode(sequence)
133+
enum_array = Animal.encode(sequence)
134+
assert len(enum_array) == 0

0 commit comments

Comments
 (0)