Skip to content

Commit

Permalink
test(enums): fix mypy errors (#1233)
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Oct 9, 2024
1 parent 803680c commit b8efde7
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 39 deletions.
17 changes: 10 additions & 7 deletions openfisca_core/indexed_enums/_enum_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from . import types as t


def _item_list(enum_class: type[t.Enum]) -> t.ItemList:
def _item_list(enum_class: t.EnumType) -> t.ItemList:
"""Return the non-vectorised list of enum items."""
return [
return [ # type: ignore[var-annotated]
(index, name, value)
for index, (name, value) in enumerate(enum_class.__members__.items())
]


def _item_dtype(enum_class: type[t.Enum]) -> t.RecDType:
def _item_dtype(enum_class: t.EnumType) -> t.RecDType:
"""Return the dtype of the indexed enum's items."""
size = max(map(len, enum_class.__members__.keys()))
return numpy.dtype(
Expand All @@ -30,7 +30,7 @@ def _item_dtype(enum_class: type[t.Enum]) -> t.RecDType:
)


def _item_array(enum_class: type[t.Enum]) -> t.RecArray:
def _item_array(enum_class: t.EnumType) -> t.RecArray:
"""Return the indexed enum's items."""
items = _item_list(enum_class)
dtype = _item_dtype(enum_class)
Expand Down Expand Up @@ -76,17 +76,20 @@ class EnumType(t.EnumType):
@property
def indices(cls) -> t.IndexArray:
"""Return the indices of the indexed enum class."""
return cls.items.index
indices: t.IndexArray = cls.items.index
return indices

@property
def names(cls) -> t.StrArray:
"""Return the names of the indexed enum class."""
return cls.items.name
names: t.StrArray = cls.items.name
return names

@property
def enums(cls) -> t.ObjArray:
"""Return the members of the indexed enum class."""
return cls.items.enum
enums: t.ObjArray = cls.items.enum
return enums

def __new__(
metacls,
Expand Down
2 changes: 1 addition & 1 deletion openfisca_core/indexed_enums/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def encode(
return EnumArray(indices, cls)

# String array
if _is_str_array(array):
if _is_str_array(array): # type: ignore[unreachable]
indices = cls.items[numpy.isin(cls.names, array)].index
return EnumArray(indices, cls)

Expand Down
20 changes: 16 additions & 4 deletions openfisca_core/indexed_enums/enum_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def __eq__(self, other: object) -> t.BoolArray: # type: ignore[override]
https://en.wikipedia.org/wiki/Liskov_substitution_principle
"""
result: t.BoolArray

if self.possible_values is None:
return NotImplemented
if other is None:
Expand All @@ -158,12 +160,16 @@ def __eq__(self, other: object) -> t.BoolArray: # type: ignore[override]
isinstance(other, type(t.Enum))
and other.__name__ is self.possible_values.__name__
):
return self.view(numpy.ndarray) == other.indices[other.indices <= max(self)]
result = (
self.view(numpy.ndarray) == other.indices[other.indices <= max(self)]
)
return result
if (
isinstance(other, t.Enum)
and other.__class__.__name__ is self.possible_values.__name__
):
return self.view(numpy.ndarray) == other.index
result = self.view(numpy.ndarray) == other.index
return result
# For NumPy >=1.26.x.
if isinstance(is_equal := self.view(numpy.ndarray) == other, numpy.ndarray):
return is_equal
Expand Down Expand Up @@ -263,6 +269,8 @@ def decode(self) -> t.ObjArray:
array([Housing.TENANT], dtype=object)
"""
result: t.ObjArray

if self.possible_values is None:
msg = (
f"The possible values of the {self.__class__.__name__} are "
Expand All @@ -271,7 +279,8 @@ def decode(self) -> t.ObjArray:
raise TypeError(msg)
arr = self.astype(t.EnumDType)
arr = arr.reshape(1) if arr.ndim == 0 else arr
return self.possible_values.items[arr.astype(t.EnumDType)].enum
result = self.possible_values.items[arr.astype(t.EnumDType)].enum
return result

def decode_to_str(self) -> t.StrArray:
"""Decode itself to an array of strings.
Expand All @@ -297,6 +306,8 @@ def decode_to_str(self) -> t.StrArray:
array(['TENANT'], dtype='<U6')
"""
result: t.StrArray

if self.possible_values is None:
msg = (
f"The possible values of the {self.__class__.__name__} are "
Expand All @@ -305,7 +316,8 @@ def decode_to_str(self) -> t.StrArray:
raise TypeError(msg)
arr = self.astype(t.EnumDType)
arr = arr.reshape(1) if arr.ndim == 0 else arr
return self.possible_values.items[arr.astype(t.EnumDType)].name
result = self.possible_values.items[arr.astype(t.EnumDType)].name
return result

def __repr__(self) -> str:
items = ", ".join(str(item) for item in self.decode())
Expand Down
24 changes: 11 additions & 13 deletions openfisca_core/indexed_enums/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,27 @@
from openfisca_core.types import (
Array,
ArrayLike,
DTypeBool as BoolDType,
DTypeEnum as EnumDType,
DTypeGeneric as AnyDType,
DTypeInt as IntDType,
DTypeLike,
DTypeObject as ObjDType,
DTypeStr as StrDType,
Enum,
EnumArray,
EnumType,
RecArray,
)

import enum
from enum import _EnumDict as EnumDict # noqa: PLC2701

import numpy

#: Type for enum dicts.
EnumDict: TypeAlias = enum._EnumDict # noqa: SLF001
from numpy import (
bool_ as BoolDType,
generic as AnyDType,
int16 as EnumDType,
int32 as IntDType,
object_ as ObjDType,
str_ as StrDType,
)

#: Type for the non-vectorised list of enum items.
ItemList: TypeAlias = list[tuple[int, str, Enum]]
ItemList: TypeAlias = list[tuple[int, str, EnumType]]

#: Type for record arrays data type.
RecDType: TypeAlias = numpy.dtype[numpy.void]
Expand All @@ -48,12 +47,11 @@
AnyArray: TypeAlias = Array[AnyDType]

__all__ = [
"Array",
"ArrayLike",
"BoolDType",
"DTypeLike",
"Enum",
"EnumArray",
"EnumDict",
"EnumType",
"RecArray",
]
6 changes: 5 additions & 1 deletion openfisca_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from collections.abc import Mapping
from typing import NamedTuple

from openfisca_core.types import Population, TaxBenefitSystem, Variable
from openfisca_core.types import (
CorePopulation as Population,
TaxBenefitSystem,
Variable,
)

import tempfile
import warnings
Expand Down
38 changes: 25 additions & 13 deletions openfisca_core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def __new__(


class Holder(Protocol):
def clone(self, population: Any, /) -> Holder: ...
def get_memory_usage(self, /) -> Any: ...
def clone(self, population: CorePopulation, /) -> Holder: ...
def get_memory_usage(self, /) -> dict[str, object]: ...


# Parameters
Expand Down Expand Up @@ -198,27 +198,39 @@ def offset(self, offset: str | int, unit: None | DateUnit = None, /) -> Period:
# Populations


class Population(Protocol):
entity: Any
class CorePopulation(Protocol): ...

def get_holder(self, variable_name: VariableName, /) -> Any: ...

class SinglePopulation(CorePopulation, Protocol):
entity: SingleEntity

def get_holder(self, variable_name: VariableName, /) -> Holder: ...


class GroupPopulation(CorePopulation, Protocol): ...


# Simulations


class Simulation(Protocol):
def calculate(self, variable_name: VariableName, period: Any, /) -> Any: ...
def calculate_add(self, variable_name: VariableName, period: Any, /) -> Any: ...
def calculate_divide(self, variable_name: VariableName, period: Any, /) -> Any: ...
def get_population(self, plural: None | str, /) -> Any: ...
def calculate(
self, variable_name: VariableName, period: Period, /
) -> Array[DTypeGeneric]: ...
def calculate_add(
self, variable_name: VariableName, period: Period, /
) -> Array[DTypeGeneric]: ...
def calculate_divide(
self, variable_name: VariableName, period: Period, /
) -> Array[DTypeGeneric]: ...
def get_population(self, plural: None | str, /) -> CorePopulation: ...


# Tax-Benefit systems


class TaxBenefitSystem(Protocol):
person_entity: Any
person_entity: SingleEntity

def get_variable(
self,
Expand All @@ -235,18 +247,18 @@ def get_variable(


class Variable(Protocol):
entity: Any
entity: CoreEntity
name: VariableName


class Formula(Protocol):
def __call__(
self,
population: Population,
population: CorePopulation,
instant: Instant,
params: Params,
/,
) -> Array[Any]: ...
) -> Array[DTypeGeneric]: ...


class Params(Protocol):
Expand Down

0 comments on commit b8efde7

Please sign in to comment.