Skip to content

Commit 2114923

Browse files
authored
Add DocstringRemover (#197)
* Add DocstringRemover * fix test
1 parent 0cba4c9 commit 2114923

File tree

7 files changed

+116
-0
lines changed

7 files changed

+116
-0
lines changed

src/latexify/ast_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,21 @@ def is_constant(node: ast.AST) -> bool:
9696
return isinstance(node, ast.Constant)
9797

9898

99+
def is_str(node: ast.AST) -> bool:
100+
"""Checks if the node is a str constant.
101+
102+
Args:
103+
node: The node to examine.
104+
105+
Returns:
106+
True if the node is a str constant, False otherwise.
107+
"""
108+
if sys.version_info.minor < 8 and isinstance(node, ast.Str):
109+
return True
110+
111+
return isinstance(node, ast.Constant) and isinstance(node.value, str)
112+
113+
99114
def extract_int_or_none(node: ast.expr) -> int | None:
100115
"""Extracts int constant from the given Constant node.
101116

src/latexify/ast_utils_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,38 @@ def test_is_constant(value: ast.AST, expected: bool) -> None:
114114
assert ast_utils.is_constant(value) is expected
115115

116116

117+
@test_utils.require_at_most(7)
118+
@pytest.mark.parametrize(
119+
"value,expected",
120+
[
121+
(ast.Bytes(s=b"foo"), False),
122+
(ast.Constant("bar"), True),
123+
(ast.Ellipsis(), False),
124+
(ast.NameConstant(value=None), False),
125+
(ast.Num(n=123), False),
126+
(ast.Str(s="baz"), True),
127+
(ast.Expr(value=ast.Num(456)), False),
128+
(ast.Global(names=["qux"]), False),
129+
],
130+
)
131+
def test_is_str_legacy(value: ast.AST, expected: bool) -> None:
132+
assert ast_utils.is_str(value) is expected
133+
134+
135+
@test_utils.require_at_least(8)
136+
@pytest.mark.parametrize(
137+
"value,expected",
138+
[
139+
(ast.Constant(value=123), False),
140+
(ast.Constant(value="foo"), True),
141+
(ast.Expr(value=ast.Constant(value="foo")), False),
142+
(ast.Global(names=["foo"]), False),
143+
],
144+
)
145+
def test_is_str(value: ast.AST, expected: bool) -> None:
146+
assert ast_utils.is_str(value) is expected
147+
148+
117149
def test_extract_int_or_none() -> None:
118150
assert ast_utils.extract_int_or_none(ast_utils.make_constant(-123)) == -123
119151
assert ast_utils.extract_int_or_none(ast_utils.make_constant(0)) == 0

src/latexify/generate_latex.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def get_latex(
5656
if merged_config.identifiers is not None:
5757
tree = transformers.IdentifierReplacer(merged_config.identifiers).visit(tree)
5858
if merged_config.reduce_assignments:
59+
tree = transformers.DocstringRemover().visit(tree)
5960
tree = transformers.AssignmentReducer().visit(tree)
6061
if merged_config.expand_functions is not None:
6162
tree = transformers.FunctionExpander(merged_config.expand_functions).visit(tree)

src/latexify/generate_latex_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,20 @@ def f(x):
5454
assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag
5555

5656

57+
def test_get_latex_reduce_assignments_with_docstring() -> None:
58+
def f(x):
59+
"""DocstringRemover is required."""
60+
y = 3 * x
61+
return y
62+
63+
latex_without_flag = r"\begin{array}{l} y = 3 x \\ f(x) = y \end{array}"
64+
latex_with_flag = r"f(x) = 3 x"
65+
66+
assert generate_latex.get_latex(f) == latex_without_flag
67+
assert generate_latex.get_latex(f, reduce_assignments=False) == latex_without_flag
68+
assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag
69+
70+
5771
def test_get_latex_reduce_assignments_with_aug_assign() -> None:
5872
def f(x):
5973
y = 3

src/latexify/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
from latexify.transformers.assignment_reducer import AssignmentReducer
44
from latexify.transformers.aug_assign_replacer import AugAssignReplacer
5+
from latexify.transformers.docstring_remover import DocstringRemover
56
from latexify.transformers.function_expander import FunctionExpander
67
from latexify.transformers.identifier_replacer import IdentifierReplacer
78
from latexify.transformers.prefix_trimmer import PrefixTrimmer
89

910
__all__ = [
1011
"AssignmentReducer",
1112
"AugAssignReplacer",
13+
"DocstringRemover",
1214
"FunctionExpander",
1315
"IdentifierReplacer",
1416
"PrefixTrimmer",
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Transformer to remove all docstrings."""
2+
3+
from __future__ import annotations
4+
5+
import ast
6+
from typing import Union
7+
8+
from latexify import ast_utils
9+
10+
11+
class DocstringRemover(ast.NodeTransformer):
12+
"""NodeTransformer to remove all docstrings.
13+
14+
Docstrings here are detected as Expr nodes with a single string constant.
15+
"""
16+
17+
def visit_Expr(self, node: ast.Expr) -> Union[ast.Expr, None]:
18+
if ast_utils.is_str(node.value):
19+
return None
20+
return node
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""Tests for latexify.transformers.docstring_remover."""
2+
3+
import ast
4+
5+
from latexify import ast_utils, parser, test_utils
6+
from latexify.transformers.docstring_remover import DocstringRemover
7+
8+
9+
def test_remove_docstrings() -> None:
10+
def f():
11+
"""Test docstring."""
12+
x = 42
13+
f() # This Expr should not be removed.
14+
"""This string constant should also be removed."""
15+
return x
16+
17+
tree = parser.parse_function(f).body[0]
18+
assert isinstance(tree, ast.FunctionDef)
19+
20+
expected = ast.FunctionDef(
21+
name="f",
22+
body=[
23+
ast.Assign(
24+
targets=[ast.Name(id="x", ctx=ast.Store())],
25+
value=ast_utils.make_constant(42),
26+
),
27+
ast.Expr(value=ast.Call(func=ast.Name(id="f", ctx=ast.Load()))),
28+
ast.Return(value=ast.Name(id="x", ctx=ast.Load())),
29+
],
30+
)
31+
transformed = DocstringRemover().visit(tree)
32+
test_utils.assert_ast_equal(transformed, expected)

0 commit comments

Comments
 (0)