Skip to content

Commit 5b8d8b3

Browse files
Merge pull request #210 from connorferster/features/iteration_recorder
HandcalcsCallRecorder
2 parents 4a1250d + cbaa06e commit 5b8d8b3

File tree

3 files changed

+135
-32
lines changed

3 files changed

+135
-32
lines changed

handcalcs/decorator.py

Lines changed: 99 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
__all__ = ["handcalc"]
22

3-
from typing import Optional
4-
from functools import wraps
3+
from typing import Optional, Callable
4+
from functools import wraps, update_wrapper
55
import inspect
66
import innerscope
77
from .handcalcs import LatexRenderer
@@ -13,43 +13,112 @@ def handcalc(
1313
left: str = "",
1414
right: str = "",
1515
scientific_notation: Optional[bool] = None,
16-
decimal_separator: str = ".",
1716
jupyter_display: bool = False,
17+
record: bool = False,
1818
):
1919
def handcalc_decorator(func):
20-
@wraps(func)
21-
def wrapper(*args, **kwargs):
22-
line_args = {
23-
"override": override,
24-
"precision": precision,
25-
"sci_not": scientific_notation,
26-
}
27-
func_source = inspect.getsource(func)
28-
cell_source = _func_source_to_cell(func_source)
29-
# use innerscope to get the values of locals, closures, and globals when calling func
30-
scope = innerscope.call(func, *args, **kwargs)
31-
LatexRenderer.dec_sep = decimal_separator
32-
renderer = LatexRenderer(cell_source, scope, line_args)
33-
latex_code = renderer.render()
34-
if jupyter_display:
35-
try:
36-
from IPython.display import Latex, display
37-
except ModuleNotFoundError:
38-
ModuleNotFoundError(
39-
"jupyter_display option requires IPython.display to be installed."
40-
)
41-
display(Latex(latex_code))
42-
return scope.return_value
20+
if record:
21+
decorated = HandcalcsCallRecorder(
22+
func,
23+
override,
24+
precision,
25+
left,
26+
right,
27+
scientific_notation,
28+
jupyter_display,
29+
)
30+
else:
4331

44-
# https://stackoverflow.com/questions/9943504/right-to-left-string-replace-in-python
45-
latex_code = "".join(latex_code.replace("\\[", "", 1).rsplit("\\]", 1))
46-
return (left + latex_code + right, scope.return_value)
32+
@wraps(func)
33+
def decorated(*args, **kwargs):
34+
line_args = {
35+
"override": override,
36+
"precision": precision,
37+
"sci_not": scientific_notation,
38+
}
39+
func_source = inspect.getsource(func)
40+
cell_source = _func_source_to_cell(func_source)
41+
# innerscope retrieves values of locals, closures, and globals
42+
scope = innerscope.call(func, *args, **kwargs)
43+
renderer = LatexRenderer(cell_source, scope, line_args)
44+
latex_code = renderer.render()
45+
raw_latex_code = "".join(
46+
latex_code.replace("\\[", "", 1).rsplit("\\]", 1)
47+
)
48+
if jupyter_display:
49+
try:
50+
from IPython.display import Latex, display
51+
except ModuleNotFoundError:
52+
ModuleNotFoundError(
53+
"jupyter_display option requires IPython.display to be installed."
54+
)
55+
display(Latex(latex_code))
56+
return scope.return_value
57+
return (left + raw_latex_code + right, scope.return_value)
4758

48-
return wrapper
59+
return decorated
4960

5061
return handcalc_decorator
5162

5263

64+
class HandcalcsCallRecorder:
65+
"""
66+
Records function calls for the func stored in .callable
67+
"""
68+
69+
def __init__(
70+
self,
71+
func: Callable,
72+
_override: str = "",
73+
_precision: int = 3,
74+
_left: str = "",
75+
_right: str = "",
76+
_scientific_notation: Optional[bool] = None,
77+
_jupyter_display: bool = False,
78+
):
79+
self.callable = func
80+
self.history = list()
81+
self._override = _override
82+
self._precision = _precision
83+
self._left = _left
84+
self._right = _right
85+
self._scientific_notation = _scientific_notation
86+
self._jupyter_display = _jupyter_display
87+
update_wrapper(self, func)
88+
89+
def __repr__(self):
90+
return f"{self.__class__.__name__}({self.callable.__name__}, num_of_calls: {len(self.history)})"
91+
92+
@property
93+
def calls(self):
94+
return len(self.history)
95+
96+
def __call__(self, *args, **kwargs):
97+
line_args = {
98+
"override": self._override,
99+
"precision": self._precision,
100+
"sci_not": self._scientific_notation,
101+
}
102+
func_source = inspect.getsource(self.callable)
103+
cell_source = _func_source_to_cell(func_source)
104+
# innerscope retrieves values of locals, closures, and globals
105+
scope = innerscope.call(self.callable, *args, **kwargs)
106+
renderer = LatexRenderer(cell_source, scope, line_args)
107+
latex_code = renderer.render()
108+
raw_latex_code = "".join(latex_code.replace("\\[", "", 1).rsplit("\\]", 1))
109+
self.history.append({"return": scope.return_value, "latex": raw_latex_code})
110+
if self._jupyter_display:
111+
try:
112+
from IPython.display import Latex, display
113+
except ModuleNotFoundError:
114+
ModuleNotFoundError(
115+
"jupyter_display option requires IPython.display to be installed."
116+
)
117+
display(Latex(latex_code))
118+
return scope.return_value
119+
return (self._left + raw_latex_code + self._right, scope.return_value)
120+
121+
53122
def _func_source_to_cell(source: str):
54123
"""
55124
Returns a string that represents `source` but with no signature, doc string,

handcalcs/handcalcs.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,6 @@ def dict_get(d: dict, item: Any) -> Any:
173173

174174
# The renderer class ("output" class)
175175
class LatexRenderer:
176-
# dec_sep = "."
177-
178176
def __init__(self, python_code_str: str, results: dict, line_args: dict):
179177
self.source = python_code_str
180178
self.results = results

test_handcalcs/test_decorator_file.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from handcalcs.decorator import HandcalcsCallRecorder, handcalc
2+
import pytest
3+
4+
# Define a simple arithmetic function for testing
5+
def simple_func(a: float, b: float) -> float:
6+
c = a + b
7+
return c
8+
9+
@pytest.fixture
10+
def recorder():
11+
return HandcalcsCallRecorder(simple_func)
12+
13+
def test_simple_arithmetic(recorder):
14+
result = recorder(1, 2)
15+
assert result[1] == 3
16+
assert result[0] == '\n\\begin{aligned}\nc &= a + b = 1 + 2 &= 3 \n\\end{aligned}\n'
17+
18+
def test_call_recording(recorder):
19+
recorder(1.0, 2.0)
20+
recorder(3.5, 4.5)
21+
assert recorder.calls == 2 # There should be two recorded calls.
22+
assert recorder.history[0]['return'] == 3.0
23+
assert recorder.history[1]['return'] == 8.0
24+
25+
def test_decorator_with_recording():
26+
decorated_func = handcalc(record=True)(simple_func)
27+
result = decorated_func(1.0, 2.0)
28+
assert result[1] == 3.0
29+
assert decorated_func.calls == 1
30+
assert decorated_func.history[0]['return'] == 3.0
31+
assert decorated_func.history[0]['latex'] == '\n\\begin{aligned}\nc &= a + b = 1.000 + 2.000 &= 3.000 \n\\end{aligned}\n'
32+
33+
decorated_func = handcalc(record=False)(simple_func)
34+
latex, result = decorated_func(1.0, 2.0)
35+
assert result == 3.0
36+
assert latex == '\n\\begin{aligned}\nc &= a + b = 1.000 + 2.000 &= 3.000 \n\\end{aligned}\n'

0 commit comments

Comments
 (0)