Skip to content

Commit 809affe

Browse files
authored
docs: add types to tracers (1/3) (#1280)
2 parents 2bd629b + d859a62 commit 809affe

File tree

12 files changed

+475
-226
lines changed

12 files changed

+475
-226
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Changelog
22

3+
### 43.2.2 [#1280](https://github.com/openfisca/openfisca-core/pull/1280)
4+
5+
#### Documentation
6+
7+
- Add types to common tracers (`SimpleTracer`, `FlatTracer`, etc.)
8+
39
### 43.2.1 [#1283](https://github.com/openfisca/openfisca-core/pull/1283)
410

511
#### Technical changes

openfisca_core/populations/_errors.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,9 @@ def __init__(
5757
super().__init__(msg)
5858

5959

60-
__all__ = ["InvalidArraySizeError", "PeriodValidityError"]
60+
__all__ = [
61+
"IncompatibleOptionsError",
62+
"InvalidArraySizeError",
63+
"InvalidOptionError",
64+
"PeriodValidityError",
65+
]

openfisca_core/populations/types.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Iterable, MutableMapping, Sequence
44
from typing import NamedTuple, Union
5-
from typing_extensions import NewType, TypeAlias, TypedDict
5+
from typing_extensions import TypeAlias, TypedDict
66

77
from openfisca_core.types import (
88
Array,
@@ -14,6 +14,7 @@
1414
Holder,
1515
MemoryUsage,
1616
Period,
17+
PeriodInt,
1718
PeriodStr,
1819
Role,
1920
Simulation,
@@ -52,9 +53,6 @@
5253

5354
# Periods
5455

55-
#: New type for a period integer.
56-
PeriodInt = NewType("PeriodInt", int)
57-
5856
#: Type alias for a period-like object.
5957
PeriodLike: TypeAlias = Union[Period, PeriodStr, PeriodInt]
6058

openfisca_core/tracers/__init__.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,22 @@
2121
#
2222
# See: https://www.python.org/dev/peps/pep-0008/#imports
2323

24-
from .computation_log import ComputationLog # noqa: F401
25-
from .flat_trace import FlatTrace # noqa: F401
26-
from .full_tracer import FullTracer # noqa: F401
27-
from .performance_log import PerformanceLog # noqa: F401
28-
from .simple_tracer import SimpleTracer # noqa: F401
29-
from .trace_node import TraceNode # noqa: F401
30-
from .tracing_parameter_node_at_instant import ( # noqa: F401
31-
TracingParameterNodeAtInstant,
32-
)
24+
from . import types
25+
from .computation_log import ComputationLog
26+
from .flat_trace import FlatTrace
27+
from .full_tracer import FullTracer
28+
from .performance_log import PerformanceLog
29+
from .simple_tracer import SimpleTracer
30+
from .trace_node import TraceNode
31+
from .tracing_parameter_node_at_instant import TracingParameterNodeAtInstant
32+
33+
__all__ = [
34+
"ComputationLog",
35+
"FlatTrace",
36+
"FullTracer",
37+
"PerformanceLog",
38+
"SimpleTracer",
39+
"TraceNode",
40+
"TracingParameterNodeAtInstant",
41+
"types",
42+
]
Lines changed: 27 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,24 @@
11
from __future__ import annotations
22

3-
import typing
4-
from typing import Union
3+
import sys
54

65
import numpy
76

87
from openfisca_core.indexed_enums import EnumArray
98

10-
if typing.TYPE_CHECKING:
11-
from numpy.typing import ArrayLike
12-
13-
from openfisca_core import tracers
14-
15-
Array = Union[EnumArray, ArrayLike]
9+
from . import types as t
1610

1711

1812
class ComputationLog:
19-
_full_tracer: tracers.FullTracer
13+
_full_tracer: t.FullTracer
2014

21-
def __init__(self, full_tracer: tracers.FullTracer) -> None:
15+
def __init__(self, full_tracer: t.FullTracer) -> None:
2216
self._full_tracer = full_tracer
2317

24-
def display(
25-
self,
26-
value: Array | None,
27-
) -> str:
28-
if isinstance(value, EnumArray):
29-
value = value.decode_to_str()
30-
31-
return numpy.array2string(value, max_line_width=float("inf"))
32-
3318
def lines(
3419
self,
3520
aggregate: bool = False,
36-
max_depth: int | None = None,
21+
max_depth: int = sys.maxsize,
3722
) -> list[str]:
3823
depth = 1
3924

@@ -44,7 +29,7 @@ def lines(
4429

4530
return self._flatten(lines_by_tree)
4631

47-
def print_log(self, aggregate=False, max_depth=None) -> None:
32+
def print_log(self, aggregate: bool = False, max_depth: int = sys.maxsize) -> None:
4833
"""Print the computation log of a simulation.
4934
5035
If ``aggregate`` is ``False`` (default), print the value of each
@@ -60,20 +45,20 @@ def print_log(self, aggregate=False, max_depth=None) -> None:
6045
If ``max_depth`` is set, for example to ``3``, only print computed
6146
vectors up to a depth of ``max_depth``.
6247
"""
63-
for _line in self.lines(aggregate, max_depth):
48+
for _ in self.lines(aggregate, max_depth):
6449
pass
6550

6651
def _get_node_log(
6752
self,
68-
node: tracers.TraceNode,
53+
node: t.TraceNode,
6954
depth: int,
7055
aggregate: bool,
71-
max_depth: int | None,
56+
max_depth: int = sys.maxsize,
7257
) -> list[str]:
73-
if max_depth is not None and depth > max_depth:
58+
if depth > max_depth:
7459
return []
7560

76-
node_log = [self._print_line(depth, node, aggregate, max_depth)]
61+
node_log = [self._print_line(depth, node, aggregate)]
7762

7863
children_logs = [
7964
self._get_node_log(child, depth + 1, aggregate, max_depth)
@@ -82,13 +67,7 @@ def _get_node_log(
8267

8368
return node_log + self._flatten(children_logs)
8469

85-
def _print_line(
86-
self,
87-
depth: int,
88-
node: tracers.TraceNode,
89-
aggregate: bool,
90-
max_depth: int | None,
91-
) -> str:
70+
def _print_line(self, depth: int, node: t.TraceNode, aggregate: bool) -> str:
9271
indent = " " * depth
9372
value = node.value
9473

@@ -97,9 +76,11 @@ def _print_line(
9776

9877
elif aggregate:
9978
try:
100-
formatted_value = str(
79+
formatted_value = str( # pyright: ignore[reportCallIssue]
10180
{
102-
"avg": numpy.mean(value),
81+
"avg": numpy.mean(
82+
value
83+
), # pyright: ignore[reportArgumentType,reportCallIssue]
10384
"max": numpy.max(value),
10485
"min": numpy.min(value),
10586
},
@@ -113,8 +94,15 @@ def _print_line(
11394

11495
return f"{indent}{node.name}<{node.period}> >> {formatted_value}"
11596

116-
def _flatten(
117-
self,
118-
lists: list[list[str]],
119-
) -> list[str]:
97+
@staticmethod
98+
def display(value: t.VarArray, max_depth: int = sys.maxsize) -> str:
99+
if isinstance(value, EnumArray):
100+
value = value.decode_to_str()
101+
return numpy.array2string(value, max_line_width=max_depth)
102+
103+
@staticmethod
104+
def _flatten(lists: list[list[str]]) -> list[str]:
120105
return [item for list_ in lists for item in list_]
106+
107+
108+
__all__ = ["ComputationLog"]

openfisca_core/tracers/flat_trace.py

Lines changed: 39 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,20 @@
11
from __future__ import annotations
22

3-
import typing
4-
from typing import Union
5-
63
import numpy
74

85
from openfisca_core.indexed_enums import EnumArray
96

10-
if typing.TYPE_CHECKING:
11-
from numpy.typing import ArrayLike
12-
13-
from openfisca_core import tracers
14-
15-
Array = Union[EnumArray, ArrayLike]
16-
Trace = dict[str, dict]
7+
from . import types as t
178

189

1910
class FlatTrace:
20-
_full_tracer: tracers.FullTracer
11+
_full_tracer: t.FullTracer
2112

22-
def __init__(self, full_tracer: tracers.FullTracer) -> None:
13+
def __init__(self, full_tracer: t.FullTracer) -> None:
2314
self._full_tracer = full_tracer
2415

25-
def key(self, node: tracers.TraceNode) -> str:
26-
name = node.name
27-
period = node.period
28-
return f"{name}<{period}>"
29-
30-
def get_trace(self) -> dict:
31-
trace = {}
16+
def get_trace(self) -> t.FlatNodeMap:
17+
trace: t.FlatNodeMap = {}
3218

3319
for node in self._full_tracer.browse_trace():
3420
# We don't want cache read to overwrite data about the initial
@@ -45,34 +31,16 @@ def get_trace(self) -> dict:
4531

4632
return trace
4733

48-
def get_serialized_trace(self) -> dict:
34+
def get_serialized_trace(self) -> t.SerializedNodeMap:
4935
return {
5036
key: {**flat_trace, "value": self.serialize(flat_trace["value"])}
5137
for key, flat_trace in self.get_trace().items()
5238
}
5339

54-
def serialize(
55-
self,
56-
value: Array | None,
57-
) -> Array | None | list:
58-
if isinstance(value, EnumArray):
59-
value = value.decode_to_str()
60-
61-
if isinstance(value, numpy.ndarray) and numpy.issubdtype(
62-
value.dtype,
63-
numpy.dtype(bytes),
64-
):
65-
value = value.astype(numpy.dtype(str))
66-
67-
if isinstance(value, numpy.ndarray):
68-
value = value.tolist()
69-
70-
return value
71-
7240
def _get_flat_trace(
7341
self,
74-
node: tracers.TraceNode,
75-
) -> Trace:
42+
node: t.TraceNode,
43+
) -> t.FlatNodeMap:
7644
key = self.key(node)
7745

7846
return {
@@ -87,3 +55,34 @@ def _get_flat_trace(
8755
"formula_time": node.formula_time(),
8856
},
8957
}
58+
59+
@staticmethod
60+
def key(node: t.TraceNode) -> t.NodeKey:
61+
"""Return the key of a node."""
62+
name = node.name
63+
period = node.period
64+
return t.NodeKey(f"{name}<{period}>")
65+
66+
@staticmethod
67+
def serialize(
68+
value: None | t.VarArray | t.ArrayLike[object],
69+
) -> None | t.ArrayLike[object]:
70+
if value is None:
71+
return None
72+
73+
if isinstance(value, EnumArray):
74+
return value.decode_to_str().tolist()
75+
76+
if isinstance(value, numpy.ndarray) and numpy.issubdtype(
77+
value.dtype,
78+
numpy.dtype(bytes),
79+
):
80+
return value.astype(numpy.dtype(str)).tolist()
81+
82+
if isinstance(value, numpy.ndarray):
83+
return value.tolist()
84+
85+
return value
86+
87+
88+
__all__ = ["FlatTrace"]

0 commit comments

Comments
 (0)