Skip to content

Commit 448d963

Browse files
committed
docs(tracers): add types to full tracer (#1245)
1 parent 12f8d0c commit 448d963

File tree

4 files changed

+161
-24
lines changed

4 files changed

+161
-24
lines changed

openfisca_core/tracers/flat_trace.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,27 +37,6 @@ def get_serialized_trace(self) -> t.SerializedNodeMap:
3737
for key, flat_trace in self.get_trace().items()
3838
}
3939

40-
def serialize(
41-
self,
42-
value: None | t.VarArray | t.ArrayLike[object],
43-
) -> None | t.ArrayLike[object]:
44-
if value is None:
45-
return None
46-
47-
if isinstance(value, EnumArray):
48-
return value.decode_to_str().tolist()
49-
50-
if isinstance(value, numpy.ndarray) and numpy.issubdtype(
51-
value.dtype,
52-
numpy.dtype(bytes),
53-
):
54-
return value.astype(numpy.dtype(str)).tolist()
55-
56-
if isinstance(value, numpy.ndarray):
57-
return value.tolist()
58-
59-
return value
60-
6140
def _get_flat_trace(
6241
self,
6342
node: t.TraceNode,
@@ -83,3 +62,27 @@ def key(node: t.TraceNode) -> t.NodeKey:
8362
name = node.name
8463
period = node.period
8564
return t.NodeKey(f"{name}<{period}>")
65+
66+
@staticmethod
67+
def serialize(
68+
value: None | t.VarArray | t.ArrayLike[object],
69+
) -> None | t.ArrayLike[object]:
70+
if value is None:
71+
return None
72+
73+
if isinstance(value, EnumArray):
74+
return value.decode_to_str().tolist()
75+
76+
if isinstance(value, numpy.ndarray) and numpy.issubdtype(
77+
value.dtype,
78+
numpy.dtype(bytes),
79+
):
80+
return value.astype(numpy.dtype(str)).tolist()
81+
82+
if isinstance(value, numpy.ndarray):
83+
return value.tolist()
84+
85+
return value
86+
87+
88+
__all__ = ["FlatTrace"]

openfisca_core/tracers/full_tracer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,10 @@ def generate_performance_tables(self, dir_path: str) -> None:
9191
def get_nb_requests(self, variable: str) -> int:
9292
return sum(self._get_nb_requests(tree, variable) for tree in self.trees)
9393

94-
def get_flat_trace(self) -> dict:
94+
def get_flat_trace(self) -> t.FlatNodeMap:
9595
return self.flat_trace.get_trace()
9696

97-
def get_serialized_flat_trace(self) -> dict:
97+
def get_serialized_flat_trace(self) -> t.SerializedNodeMap:
9898
return self.flat_trace.get_serialized_trace()
9999

100100
def browse_trace(self) -> Iterator[t.TraceNode]:
@@ -161,3 +161,6 @@ def _get_nb_requests(self, tree: t.TraceNode, variable: str) -> int:
161161
@staticmethod
162162
def _get_time_in_sec() -> t.Time:
163163
return time.time_ns() / (10**9)
164+
165+
166+
__all__ = ["FullTracer"]

openfisca_core/tracers/types.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Iterator
4+
from typing import NewType, Protocol
5+
from typing_extensions import TypeAlias, TypedDict
6+
7+
from openfisca_core.types import (
8+
Array,
9+
ArrayLike,
10+
ParameterNode,
11+
ParameterNodeChild,
12+
Period,
13+
PeriodInt,
14+
VariableName,
15+
)
16+
17+
from numpy import generic as VarDType
18+
19+
#: A type of a generic array.
20+
VarArray: TypeAlias = Array[VarDType]
21+
22+
#: A type representing a unit time.
23+
Time: TypeAlias = float
24+
25+
#: A type representing a mapping of flat traces.
26+
FlatNodeMap: TypeAlias = dict["NodeKey", "FlatTraceMap"]
27+
28+
#: A type representing a mapping of serialized traces.
29+
SerializedNodeMap: TypeAlias = dict["NodeKey", "SerializedTraceMap"]
30+
31+
#: A stack of simple traces.
32+
SimpleStack: TypeAlias = list["SimpleTraceMap"]
33+
34+
#: Key of a trace.
35+
NodeKey = NewType("NodeKey", str)
36+
37+
38+
class FlatTraceMap(TypedDict, total=True):
39+
dependencies: list[NodeKey]
40+
parameters: dict[NodeKey, None | ArrayLike[object]]
41+
value: None | VarArray
42+
calculation_time: Time
43+
formula_time: Time
44+
45+
46+
class SerializedTraceMap(TypedDict, total=True):
47+
dependencies: list[NodeKey]
48+
parameters: dict[NodeKey, None | ArrayLike[object]]
49+
value: None | ArrayLike[object]
50+
calculation_time: Time
51+
formula_time: Time
52+
53+
54+
class SimpleTraceMap(TypedDict, total=True):
55+
name: VariableName
56+
period: int | Period
57+
58+
59+
class ComputationLog(Protocol):
60+
def print_log(self, aggregate: bool = ..., max_depth: int = ..., /) -> None: ...
61+
62+
63+
class FlatTrace(Protocol):
64+
def get_trace(self, /) -> FlatNodeMap: ...
65+
def get_serialized_trace(self, /) -> SerializedNodeMap: ...
66+
67+
68+
class FullTracer(Protocol):
69+
@property
70+
def trees(self, /) -> list[TraceNode]: ...
71+
def browse_trace(self, /) -> Iterator[TraceNode]: ...
72+
73+
74+
class PerformanceLog(Protocol):
75+
def generate_graph(self, dir_path: str, /) -> None: ...
76+
def generate_performance_tables(self, dir_path: str, /) -> None: ...
77+
78+
79+
class SimpleTracer(Protocol):
80+
@property
81+
def stack(self, /) -> SimpleStack: ...
82+
def record_calculation_start(
83+
self, variable: VariableName, period: PeriodInt | Period, /
84+
) -> None: ...
85+
def record_calculation_end(self, /) -> None: ...
86+
87+
88+
class TraceNode(Protocol):
89+
children: list[TraceNode]
90+
end: Time
91+
name: str
92+
parameters: list[TraceNode]
93+
parent: None | TraceNode
94+
period: PeriodInt | Period
95+
start: Time
96+
value: None | VarArray
97+
98+
def calculation_time(self, *, round_: bool = ...) -> Time: ...
99+
def formula_time(self, /) -> Time: ...
100+
def append_child(self, node: TraceNode, /) -> None: ...
101+
102+
103+
__all__ = [
104+
"ArrayLike",
105+
"ParameterNode",
106+
"ParameterNodeChild",
107+
"PeriodInt",
108+
]

openfisca_core/types.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,31 @@ class MemoryUsage(TypedDict, total=False):
148148

149149
# Parameters
150150

151+
#: A type representing a node of parameters.
152+
ParameterNode: TypeAlias = Union[
153+
"ParameterNodeAtInstant", "VectorialParameterNodeAtInstant"
154+
]
151155

152-
class ParameterNodeAtInstant(Protocol): ...
156+
#: A type representing a ???
157+
ParameterNodeChild: TypeAlias = Union[ParameterNode, ArrayLike[object]]
158+
159+
160+
class ParameterNodeAtInstant(Protocol):
161+
_instant_str: InstantStr
162+
163+
def __contains__(self, __item: object, /) -> bool: ...
164+
def __getitem__(
165+
self, __index: str | Array[DTypeGeneric], /
166+
) -> ParameterNodeChild: ...
167+
168+
169+
class VectorialParameterNodeAtInstant(Protocol):
170+
_instant_str: InstantStr
171+
172+
def __contains__(self, item: object, /) -> bool: ...
173+
def __getitem__(
174+
self, __index: str | Array[DTypeGeneric], /
175+
) -> ParameterNodeChild: ...
153176

154177

155178
# Periods

0 commit comments

Comments
 (0)