diff --git a/openfisca_core/indexed_enums/enum_array.py b/openfisca_core/indexed_enums/enum_array.py index 06fc1fbc9..2e9ebf148 100644 --- a/openfisca_core/indexed_enums/enum_array.py +++ b/openfisca_core/indexed_enums/enum_array.py @@ -141,8 +141,10 @@ def __eq__(self, other: object) -> bool: """ if other.__class__.__name__ is self.possible_values.__name__: return self.view(numpy.ndarray) == other.index - - return self.view(numpy.ndarray) == other + is_eq = self.view(numpy.ndarray) == other + if isinstance(is_eq, numpy.ndarray): + return is_eq + return numpy.array([is_eq], dtype=t.BoolDType) def __ne__(self, other: object) -> bool: """Inequality. diff --git a/openfisca_core/indexed_enums/types.py b/openfisca_core/indexed_enums/types.py index a16b03750..ffc2cc9f2 100644 --- a/openfisca_core/indexed_enums/types.py +++ b/openfisca_core/indexed_enums/types.py @@ -4,6 +4,7 @@ from openfisca_core.types import ( Array, ArrayLike, + DTypeBool as BoolDType, DTypeEnum as EnumDType, DTypeGeneric as AnyDType, DTypeInt as IntDType, @@ -49,6 +50,7 @@ __all__ = [ "Array", "ArrayLike", + "BoolDType", "DTypeLike", "Enum", "EnumArray",