From b8efde7d92ab812051cb46697e391399954a04b5 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Wed, 9 Oct 2024 19:51:14 +0200 Subject: [PATCH] test(enums): fix mypy errors (#1233) --- openfisca_core/indexed_enums/_enum_type.py | 17 ++++++---- openfisca_core/indexed_enums/enum.py | 2 +- openfisca_core/indexed_enums/enum_array.py | 20 +++++++++--- openfisca_core/indexed_enums/types.py | 24 +++++++------- openfisca_core/simulations/simulation.py | 6 +++- openfisca_core/types.py | 38 ++++++++++++++-------- 6 files changed, 68 insertions(+), 39 deletions(-) diff --git a/openfisca_core/indexed_enums/_enum_type.py b/openfisca_core/indexed_enums/_enum_type.py index 4208ab3ce..1af4b153c 100644 --- a/openfisca_core/indexed_enums/_enum_type.py +++ b/openfisca_core/indexed_enums/_enum_type.py @@ -7,15 +7,15 @@ from . import types as t -def _item_list(enum_class: type[t.Enum]) -> t.ItemList: +def _item_list(enum_class: t.EnumType) -> t.ItemList: """Return the non-vectorised list of enum items.""" - return [ + return [ # type: ignore[var-annotated] (index, name, value) for index, (name, value) in enumerate(enum_class.__members__.items()) ] -def _item_dtype(enum_class: type[t.Enum]) -> t.RecDType: +def _item_dtype(enum_class: t.EnumType) -> t.RecDType: """Return the dtype of the indexed enum's items.""" size = max(map(len, enum_class.__members__.keys())) return numpy.dtype( @@ -30,7 +30,7 @@ def _item_dtype(enum_class: type[t.Enum]) -> t.RecDType: ) -def _item_array(enum_class: type[t.Enum]) -> t.RecArray: +def _item_array(enum_class: t.EnumType) -> t.RecArray: """Return the indexed enum's items.""" items = _item_list(enum_class) dtype = _item_dtype(enum_class) @@ -76,17 +76,20 @@ class EnumType(t.EnumType): @property def indices(cls) -> t.IndexArray: """Return the indices of the indexed enum class.""" - return cls.items.index + indices: t.IndexArray = cls.items.index + return indices @property def names(cls) -> t.StrArray: """Return the names of the indexed enum class.""" - return cls.items.name + names: t.StrArray = cls.items.name + return names @property def enums(cls) -> t.ObjArray: """Return the members of the indexed enum class.""" - return cls.items.enum + enums: t.ObjArray = cls.items.enum + return enums def __new__( metacls, diff --git a/openfisca_core/indexed_enums/enum.py b/openfisca_core/indexed_enums/enum.py index a291acbd6..15bed5878 100644 --- a/openfisca_core/indexed_enums/enum.py +++ b/openfisca_core/indexed_enums/enum.py @@ -217,7 +217,7 @@ def encode( return EnumArray(indices, cls) # String array - if _is_str_array(array): + if _is_str_array(array): # type: ignore[unreachable] indices = cls.items[numpy.isin(cls.names, array)].index return EnumArray(indices, cls) diff --git a/openfisca_core/indexed_enums/enum_array.py b/openfisca_core/indexed_enums/enum_array.py index aa3db3f07..aa613315f 100644 --- a/openfisca_core/indexed_enums/enum_array.py +++ b/openfisca_core/indexed_enums/enum_array.py @@ -150,6 +150,8 @@ def __eq__(self, other: object) -> t.BoolArray: # type: ignore[override] https://en.wikipedia.org/wiki/Liskov_substitution_principle """ + result: t.BoolArray + if self.possible_values is None: return NotImplemented if other is None: @@ -158,12 +160,16 @@ def __eq__(self, other: object) -> t.BoolArray: # type: ignore[override] isinstance(other, type(t.Enum)) and other.__name__ is self.possible_values.__name__ ): - return self.view(numpy.ndarray) == other.indices[other.indices <= max(self)] + result = ( + self.view(numpy.ndarray) == other.indices[other.indices <= max(self)] + ) + return result if ( isinstance(other, t.Enum) and other.__class__.__name__ is self.possible_values.__name__ ): - return self.view(numpy.ndarray) == other.index + result = self.view(numpy.ndarray) == other.index + return result # For NumPy >=1.26.x. if isinstance(is_equal := self.view(numpy.ndarray) == other, numpy.ndarray): return is_equal @@ -263,6 +269,8 @@ def decode(self) -> t.ObjArray: array([Housing.TENANT], dtype=object) """ + result: t.ObjArray + if self.possible_values is None: msg = ( f"The possible values of the {self.__class__.__name__} are " @@ -271,7 +279,8 @@ def decode(self) -> t.ObjArray: raise TypeError(msg) arr = self.astype(t.EnumDType) arr = arr.reshape(1) if arr.ndim == 0 else arr - return self.possible_values.items[arr.astype(t.EnumDType)].enum + result = self.possible_values.items[arr.astype(t.EnumDType)].enum + return result def decode_to_str(self) -> t.StrArray: """Decode itself to an array of strings. @@ -297,6 +306,8 @@ def decode_to_str(self) -> t.StrArray: array(['TENANT'], dtype=' t.StrArray: raise TypeError(msg) arr = self.astype(t.EnumDType) arr = arr.reshape(1) if arr.ndim == 0 else arr - return self.possible_values.items[arr.astype(t.EnumDType)].name + result = self.possible_values.items[arr.astype(t.EnumDType)].name + return result def __repr__(self) -> str: items = ", ".join(str(item) for item in self.decode()) diff --git a/openfisca_core/indexed_enums/types.py b/openfisca_core/indexed_enums/types.py index 72703d825..b43cdc1e5 100644 --- a/openfisca_core/indexed_enums/types.py +++ b/openfisca_core/indexed_enums/types.py @@ -3,28 +3,27 @@ from openfisca_core.types import ( Array, ArrayLike, - DTypeBool as BoolDType, - DTypeEnum as EnumDType, - DTypeGeneric as AnyDType, - DTypeInt as IntDType, DTypeLike, - DTypeObject as ObjDType, - DTypeStr as StrDType, Enum, EnumArray, EnumType, RecArray, ) -import enum +from enum import _EnumDict as EnumDict # noqa: PLC2701 import numpy - -#: Type for enum dicts. -EnumDict: TypeAlias = enum._EnumDict # noqa: SLF001 +from numpy import ( + bool_ as BoolDType, + generic as AnyDType, + int16 as EnumDType, + int32 as IntDType, + object_ as ObjDType, + str_ as StrDType, +) #: Type for the non-vectorised list of enum items. -ItemList: TypeAlias = list[tuple[int, str, Enum]] +ItemList: TypeAlias = list[tuple[int, str, EnumType]] #: Type for record arrays data type. RecDType: TypeAlias = numpy.dtype[numpy.void] @@ -48,12 +47,11 @@ AnyArray: TypeAlias = Array[AnyDType] __all__ = [ - "Array", "ArrayLike", - "BoolDType", "DTypeLike", "Enum", "EnumArray", + "EnumDict", "EnumType", "RecArray", ] diff --git a/openfisca_core/simulations/simulation.py b/openfisca_core/simulations/simulation.py index c32fea22a..b7d20fa97 100644 --- a/openfisca_core/simulations/simulation.py +++ b/openfisca_core/simulations/simulation.py @@ -3,7 +3,11 @@ from collections.abc import Mapping from typing import NamedTuple -from openfisca_core.types import Population, TaxBenefitSystem, Variable +from openfisca_core.types import ( + CorePopulation as Population, + TaxBenefitSystem, + Variable, +) import tempfile import warnings diff --git a/openfisca_core/types.py b/openfisca_core/types.py index 02d012687..b1d2a2710 100644 --- a/openfisca_core/types.py +++ b/openfisca_core/types.py @@ -137,8 +137,8 @@ def __new__( class Holder(Protocol): - def clone(self, population: Any, /) -> Holder: ... - def get_memory_usage(self, /) -> Any: ... + def clone(self, population: CorePopulation, /) -> Holder: ... + def get_memory_usage(self, /) -> dict[str, object]: ... # Parameters @@ -198,27 +198,39 @@ def offset(self, offset: str | int, unit: None | DateUnit = None, /) -> Period: # Populations -class Population(Protocol): - entity: Any +class CorePopulation(Protocol): ... - def get_holder(self, variable_name: VariableName, /) -> Any: ... + +class SinglePopulation(CorePopulation, Protocol): + entity: SingleEntity + + def get_holder(self, variable_name: VariableName, /) -> Holder: ... + + +class GroupPopulation(CorePopulation, Protocol): ... # Simulations class Simulation(Protocol): - def calculate(self, variable_name: VariableName, period: Any, /) -> Any: ... - def calculate_add(self, variable_name: VariableName, period: Any, /) -> Any: ... - def calculate_divide(self, variable_name: VariableName, period: Any, /) -> Any: ... - def get_population(self, plural: None | str, /) -> Any: ... + def calculate( + self, variable_name: VariableName, period: Period, / + ) -> Array[DTypeGeneric]: ... + def calculate_add( + self, variable_name: VariableName, period: Period, / + ) -> Array[DTypeGeneric]: ... + def calculate_divide( + self, variable_name: VariableName, period: Period, / + ) -> Array[DTypeGeneric]: ... + def get_population(self, plural: None | str, /) -> CorePopulation: ... # Tax-Benefit systems class TaxBenefitSystem(Protocol): - person_entity: Any + person_entity: SingleEntity def get_variable( self, @@ -235,18 +247,18 @@ def get_variable( class Variable(Protocol): - entity: Any + entity: CoreEntity name: VariableName class Formula(Protocol): def __call__( self, - population: Population, + population: CorePopulation, instant: Instant, params: Params, /, - ) -> Array[Any]: ... + ) -> Array[DTypeGeneric]: ... class Params(Protocol):