Skip to content

Commit e0ddde4

Browse files
authored
support removal of multiply operator (#182)
1 parent 37d47c9 commit e0ddde4

File tree

6 files changed

+212
-39
lines changed

6 files changed

+212
-39
lines changed

src/integration_tests/algorithmic_style_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def collatz(n):
6363
\If{$n \mathbin{\%} 2 = 0$}
6464
\State $n \gets \left\lfloor\frac{n}{2}\right\rfloor$
6565
\Else
66-
\State $n \gets 3 \cdot n + 1$
66+
\State $n \gets 3 n + 1$
6767
\EndIf
6868
\State $\mathrm{iterations} \gets \mathrm{iterations} + 1$
6969
\EndWhile
@@ -80,7 +80,7 @@ def collatz(n):
8080
r" \hspace{2em} \mathbf{if} \ n \mathbin{\%} 2 = 0 \\"
8181
r" \hspace{3em} n \gets \left\lfloor\frac{n}{2}\right\rfloor \\"
8282
r" \hspace{2em} \mathbf{else} \\"
83-
r" \hspace{3em} n \gets 3 \cdot n + 1 \\"
83+
r" \hspace{3em} n \gets 3 n + 1 \\"
8484
r" \hspace{2em} \mathbf{end \ if} \\"
8585
r" \hspace{2em}"
8686
r" \mathrm{iterations} \gets \mathrm{iterations} + 1 \\"

src/integration_tests/regression_test.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@ def test_quadratic_solution() -> None:
1111
def solve(a, b, c):
1212
return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a)
1313

14-
latex = (
15-
r"\mathrm{solve}(a, b, c) ="
16-
r" \frac{-b + \sqrt{ b^{2} - 4 \cdot a \cdot c }}{2 \cdot a}"
17-
)
14+
latex = r"\mathrm{solve}(a, b, c) =" r" \frac{-b + \sqrt{ b^{2} - 4 a c }}{2 a}"
1815
integration_utils.check_function(solve, latex)
1916

2017

@@ -47,7 +44,7 @@ def xtimesbeta(x, beta):
4744
xtimesbeta, latex_without_symbols, use_math_symbols=False
4845
)
4946

50-
latex_with_symbols = r"\mathrm{xtimesbeta}(x, \beta) = x \cdot \beta"
47+
latex_with_symbols = r"\mathrm{xtimesbeta}(x, \beta) = x \beta"
5148
integration_utils.check_function(
5249
xtimesbeta, latex_with_symbols, use_math_symbols=True
5350
)
@@ -145,7 +142,7 @@ def test_nested_function() -> None:
145142
def nested(x):
146143
return 3 * x
147144

148-
integration_utils.check_function(nested, r"\mathrm{nested}(x) = 3 \cdot x")
145+
integration_utils.check_function(nested, r"\mathrm{nested}(x) = 3 x")
149146

150147

151148
def test_double_nested_function() -> None:
@@ -155,7 +152,7 @@ def inner(y):
155152

156153
return inner
157154

158-
integration_utils.check_function(nested(3), r"\mathrm{inner}(y) = x \cdot y")
155+
integration_utils.check_function(nested(3), r"\mathrm{inner}(y) = x y")
159156

160157

161158
def test_reduce_assignments() -> None:
@@ -165,11 +162,11 @@ def f(x):
165162

166163
integration_utils.check_function(
167164
f,
168-
r"\begin{array}{l} a = x + x \\ f(x) = 3 \cdot a \end{array}",
165+
r"\begin{array}{l} a = x + x \\ f(x) = 3 a \end{array}",
169166
)
170167
integration_utils.check_function(
171168
f,
172-
r"f(x) = 3 \cdot \mathopen{}\left( x + x \mathclose{}\right)",
169+
r"f(x) = 3 \mathopen{}\left( x + x \mathclose{}\right)",
173170
reduce_assignments=True,
174171
)
175172

@@ -184,15 +181,15 @@ def f(x):
184181
r"\begin{array}{l}"
185182
r" a = x^{2} \\"
186183
r" b = a + a \\"
187-
r" f(x) = 3 \cdot b"
184+
r" f(x) = 3 b"
188185
r" \end{array}"
189186
)
190187

191188
integration_utils.check_function(f, latex_without_option)
192189
integration_utils.check_function(f, latex_without_option, reduce_assignments=False)
193190
integration_utils.check_function(
194191
f,
195-
r"f(x) = 3 \cdot \mathopen{}\left( x^{2} + x^{2} \mathclose{}\right)",
192+
r"f(x) = 3 \mathopen{}\left( x^{2} + x^{2} \mathclose{}\right)",
196193
reduce_assignments=True,
197194
)
198195

@@ -228,7 +225,7 @@ def solve(a, b):
228225
r"\mathrm{solve}(a, b) ="
229226
r" \frac{a + b - b}{a - b} - \mathopen{}\left("
230227
r" a + b \mathclose{}\right) - \mathopen{}\left("
231-
r" a - b \mathclose{}\right) - a \cdot b"
228+
r" a - b \mathclose{}\right) - a b"
232229
)
233230
integration_utils.check_function(solve, latex)
234231

src/latexify/codegen/expression_codegen.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import ast
6+
import re
67

78
from latexify import analyzers, ast_utils, exceptions
89
from latexify.codegen import codegen_utils, expression_rules, identifier_converter
@@ -406,12 +407,94 @@ def _wrap_binop_operand(
406407

407408
return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)"
408409

410+
_l_bracket_pattern = re.compile(r"^\\mathopen.*")
411+
_r_bracket_pattern = re.compile(r".*\\mathclose[^ ]+$")
412+
_r_word_pattern = re.compile(r"\\mathrm\{[^ ]+\}$")
413+
414+
def _should_remove_multiply_op(
415+
self, l_latex: str, r_latex: str, l_expr: ast.expr, r_expr: ast.expr
416+
):
417+
"""Determine whether the multiply operator should be removed or not.
418+
419+
See also:
420+
https://github.com/google/latexify_py/issues/89#issuecomment-1344967636
421+
422+
This is an ad-hoc implementation.
423+
This function doesn't fully implements the above requirements, but only
424+
essential ones necessary to release v0.3.
425+
"""
426+
427+
# NOTE(odashi): For compatibility with Python 3.7, we compare the generated
428+
# caracter type directly to determine the "numeric" type.
429+
430+
if isinstance(l_expr, ast.Call):
431+
l_type = "f"
432+
elif self._r_bracket_pattern.match(l_latex):
433+
l_type = "b"
434+
elif self._r_word_pattern.match(l_latex):
435+
l_type = "w"
436+
elif l_latex[-1].isnumeric():
437+
l_type = "n"
438+
else:
439+
le = l_expr
440+
while True:
441+
if isinstance(le, ast.UnaryOp):
442+
le = le.operand
443+
elif isinstance(le, ast.BinOp):
444+
le = le.right
445+
elif isinstance(le, ast.Compare):
446+
le = le.comparators[-1]
447+
elif isinstance(le, ast.BoolOp):
448+
le = le.values[-1]
449+
else:
450+
break
451+
l_type = "a" if isinstance(le, ast.Name) and len(le.id) == 1 else "m"
452+
453+
if isinstance(r_expr, ast.Call):
454+
r_type = "f"
455+
elif self._l_bracket_pattern.match(r_latex):
456+
r_type = "b"
457+
elif r_latex.startswith("\\mathrm"):
458+
r_type = "w"
459+
elif r_latex[0].isnumeric():
460+
r_type = "n"
461+
else:
462+
re = r_expr
463+
while True:
464+
if isinstance(re, ast.UnaryOp):
465+
if isinstance(re.op, ast.USub):
466+
# NOTE(odashi): Unary "-" always require \cdot.
467+
return False
468+
re = re.operand
469+
elif isinstance(re, ast.BinOp):
470+
re = re.left
471+
elif isinstance(re, ast.Compare):
472+
re = re.left
473+
elif isinstance(re, ast.BoolOp):
474+
re = re.values[0]
475+
else:
476+
break
477+
r_type = "a" if isinstance(re, ast.Name) and len(re.id) == 1 else "m"
478+
479+
if r_type == "n":
480+
return False
481+
if l_type in "bn":
482+
return True
483+
if l_type in "am" and r_type in "am":
484+
return True
485+
return False
486+
409487
def visit_BinOp(self, node: ast.BinOp) -> str:
410488
"""Visit a BinOp node."""
411489
prec = expression_rules.get_precedence(node)
412490
rule = self._bin_op_rules[type(node.op)]
413491
lhs = self._wrap_binop_operand(node.left, prec, rule.operand_left)
414492
rhs = self._wrap_binop_operand(node.right, prec, rule.operand_right)
493+
494+
if type(node.op) in [ast.Mult, ast.MatMult]:
495+
if self._should_remove_multiply_op(lhs, rhs, node.left, node.right):
496+
return f"{rule.latex_left}{lhs} {rhs}{rule.latex_right}"
497+
415498
return f"{rule.latex_left}{lhs}{rule.latex_middle}{rhs}{rule.latex_right}"
416499

417500
def visit_UnaryOp(self, node: ast.UnaryOp) -> str:

0 commit comments

Comments
 (0)