Skip to content

Commit 1cfb8e7

Browse files
authored
Added remaining Numpy NDArray single function expressions (#183)
* Added support for single expressions involving the following functions: numpy.linalg.{matrix_power, qr, svd, det, matrix_rank, inv, pinv}. * Fixes for CI tests. * Fixed issues with line lengths and import order. * Refactored code.
1 parent d11334e commit 1cfb8e7

File tree

2 files changed

+297
-2
lines changed

2 files changed

+297
-2
lines changed

src/latexify/codegen/expression_codegen.py

Lines changed: 136 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def __init__(
2323
"""Initializer.
2424
2525
Args:
26-
use_math_symbols: Whether to convert identifiers with a math symbol surface
27-
(e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
26+
use_math_symbols: Whether to convert identifiers with a math symbol
27+
surface (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
2828
use_set_symbols: Whether to use set symbols or not.
2929
"""
3030
self._identifier_converter = identifier_converter.IdentifierConverter(
@@ -240,6 +240,130 @@ def _generate_transpose(self, node: ast.Call) -> str | None:
240240
else:
241241
return None
242242

243+
def _generate_determinant(self, node: ast.Call) -> str | None:
244+
"""Generates LaTeX for numpy.linalg.det.
245+
Args:
246+
node: ast.Call node containing the appropriate method invocation.
247+
Returns:
248+
Generated LaTeX, or None if the node has unsupported syntax.
249+
Raises:
250+
LatexifyError: Unsupported argument type given.
251+
"""
252+
name = ast_utils.extract_function_name_or_none(node)
253+
assert name == "det"
254+
255+
if len(node.args) != 1:
256+
return None
257+
258+
func_arg = node.args[0]
259+
if isinstance(func_arg, ast.Name):
260+
arg_id = rf"\mathbf{{{func_arg.id}}}"
261+
return rf"\det \mathopen{{}}\left( {arg_id} \mathclose{{}}\right)"
262+
elif isinstance(func_arg, ast.List):
263+
matrix = self._generate_matrix(node)
264+
return rf"\det \mathopen{{}}\left( {matrix} \mathclose{{}}\right)"
265+
266+
return None
267+
268+
def _generate_matrix_rank(self, node: ast.Call) -> str | None:
269+
"""Generates LaTeX for numpy.linalg.matrix_rank.
270+
Args:
271+
node: ast.Call node containing the appropriate method invocation.
272+
Returns:
273+
Generated LaTeX, or None if the node has unsupported syntax.
274+
Raises:
275+
LatexifyError: Unsupported argument type given.
276+
"""
277+
name = ast_utils.extract_function_name_or_none(node)
278+
assert name == "matrix_rank"
279+
280+
if len(node.args) != 1:
281+
return None
282+
283+
func_arg = node.args[0]
284+
if isinstance(func_arg, ast.Name):
285+
arg_id = rf"\mathbf{{{func_arg.id}}}"
286+
return (
287+
rf"\mathrm{{rank}} \mathopen{{}}\left( {arg_id} \mathclose{{}}\right)"
288+
)
289+
elif isinstance(func_arg, ast.List):
290+
matrix = self._generate_matrix(node)
291+
return (
292+
rf"\mathrm{{rank}} \mathopen{{}}\left( {matrix} \mathclose{{}}\right)"
293+
)
294+
295+
return None
296+
297+
def _generate_matrix_power(self, node: ast.Call) -> str | None:
298+
"""Generates LaTeX for numpy.linalg.matrix_power.
299+
Args:
300+
node: ast.Call node containing the appropriate method invocation.
301+
Returns:
302+
Generated LaTeX, or None if the node has unsupported syntax.
303+
Raises:
304+
LatexifyError: Unsupported argument type given.
305+
"""
306+
name = ast_utils.extract_function_name_or_none(node)
307+
assert name == "matrix_power"
308+
309+
if len(node.args) != 2:
310+
return None
311+
312+
func_arg = node.args[0]
313+
power_arg = node.args[1]
314+
if isinstance(power_arg, ast.Num):
315+
if isinstance(func_arg, ast.Name):
316+
return rf"\mathbf{{{func_arg.id}}}^{{{power_arg.n}}}"
317+
elif isinstance(func_arg, ast.List):
318+
matrix = self._generate_matrix(node)
319+
if matrix is not None:
320+
return rf"{matrix}^{{{power_arg.n}}}"
321+
return None
322+
323+
def _generate_inv(self, node: ast.Call) -> str | None:
324+
"""Generates LaTeX for numpy.linalg.inv.
325+
Args:
326+
node: ast.Call node containing the appropriate method invocation.
327+
Returns:
328+
Generated LaTeX, or None if the node has unsupported syntax.
329+
Raises:
330+
LatexifyError: Unsupported argument type given.
331+
"""
332+
name = ast_utils.extract_function_name_or_none(node)
333+
assert name == "inv"
334+
335+
if len(node.args) != 1:
336+
return None
337+
338+
func_arg = node.args[0]
339+
if isinstance(func_arg, ast.Name):
340+
return rf"\mathbf{{{func_arg.id}}}^{{-1}}"
341+
elif isinstance(func_arg, ast.List):
342+
return rf"{self._generate_matrix(node)}^{{-1}}"
343+
return None
344+
345+
def _generate_pinv(self, node: ast.Call) -> str | None:
346+
"""Generates LaTeX for numpy.linalg.pinv.
347+
Args:
348+
node: ast.Call node containing the appropriate method invocation.
349+
Returns:
350+
Generated LaTeX, or None if the node has unsupported syntax.
351+
Raises:
352+
LatexifyError: Unsupported argument type given.
353+
"""
354+
name = ast_utils.extract_function_name_or_none(node)
355+
assert name == "pinv"
356+
357+
if len(node.args) != 1:
358+
return None
359+
360+
func_arg = node.args[0]
361+
if isinstance(func_arg, ast.Name):
362+
return rf"\mathbf{{{func_arg.id}}}^{{+}}"
363+
elif isinstance(func_arg, ast.List):
364+
return rf"{self._generate_matrix(node)}^{{+}}"
365+
return None
366+
243367
def visit_Call(self, node: ast.Call) -> str:
244368
"""Visit a Call node."""
245369
func_name = ast_utils.extract_function_name_or_none(node)
@@ -256,6 +380,16 @@ def visit_Call(self, node: ast.Call) -> str:
256380
special_latex = self._generate_identity(node)
257381
elif func_name == "transpose":
258382
special_latex = self._generate_transpose(node)
383+
elif func_name == "det":
384+
special_latex = self._generate_determinant(node)
385+
elif func_name == "matrix_rank":
386+
special_latex = self._generate_matrix_rank(node)
387+
elif func_name == "matrix_power":
388+
special_latex = self._generate_matrix_power(node)
389+
elif func_name == "inv":
390+
special_latex = self._generate_inv(node)
391+
elif func_name == "pinv":
392+
special_latex = self._generate_pinv(node)
259393
else:
260394
special_latex = None
261395

src/latexify/codegen/expression_codegen_test.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,167 @@ def test_transpose(code: str, latex: str) -> None:
995995
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
996996

997997

998+
@pytest.mark.parametrize(
999+
"code,latex",
1000+
[
1001+
("det(A)", r"\det \mathopen{}\left( \mathbf{A} \mathclose{}\right)"),
1002+
("det(b)", r"\det \mathopen{}\left( \mathbf{b} \mathclose{}\right)"),
1003+
(
1004+
"det([[1, 2], [3, 4]])",
1005+
r"\det \mathopen{}\left( \begin{bmatrix} 1 & 2 \\"
1006+
r" 3 & 4 \end{bmatrix} \mathclose{}\right)",
1007+
),
1008+
(
1009+
"det([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
1010+
r"\det \mathopen{}\left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\"
1011+
r" 7 & 8 & 9 \end{bmatrix} \mathclose{}\right)",
1012+
),
1013+
# Unsupported
1014+
("det()", r"\mathrm{det} \mathopen{}\left( \mathclose{}\right)"),
1015+
("det(2)", r"\mathrm{det} \mathopen{}\left( 2 \mathclose{}\right)"),
1016+
(
1017+
"det(a, (1, 0))",
1018+
r"\mathrm{det} \mathopen{}\left( a, "
1019+
r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
1020+
),
1021+
],
1022+
)
1023+
def test_determinant(code: str, latex: str) -> None:
1024+
tree = ast_utils.parse_expr(code)
1025+
assert isinstance(tree, ast.Call)
1026+
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
1027+
1028+
1029+
@pytest.mark.parametrize(
1030+
"code,latex",
1031+
[
1032+
(
1033+
"matrix_rank(A)",
1034+
r"\mathrm{rank} \mathopen{}\left( \mathbf{A} \mathclose{}\right)",
1035+
),
1036+
(
1037+
"matrix_rank(b)",
1038+
r"\mathrm{rank} \mathopen{}\left( \mathbf{b} \mathclose{}\right)",
1039+
),
1040+
(
1041+
"matrix_rank([[1, 2], [3, 4]])",
1042+
r"\mathrm{rank} \mathopen{}\left( \begin{bmatrix} 1 & 2 \\"
1043+
r" 3 & 4 \end{bmatrix} \mathclose{}\right)",
1044+
),
1045+
(
1046+
"matrix_rank([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
1047+
r"\mathrm{rank} \mathopen{}\left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\"
1048+
r" 7 & 8 & 9 \end{bmatrix} \mathclose{}\right)",
1049+
),
1050+
# Unsupported
1051+
(
1052+
"matrix_rank()",
1053+
r"\mathrm{matrix\_rank} \mathopen{}\left( \mathclose{}\right)",
1054+
),
1055+
(
1056+
"matrix_rank(2)",
1057+
r"\mathrm{matrix\_rank} \mathopen{}\left( 2 \mathclose{}\right)",
1058+
),
1059+
(
1060+
"matrix_rank(a, (1, 0))",
1061+
r"\mathrm{matrix\_rank} \mathopen{}\left( a, "
1062+
r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
1063+
),
1064+
],
1065+
)
1066+
def test_matrix_rank(code: str, latex: str) -> None:
1067+
tree = ast_utils.parse_expr(code)
1068+
assert isinstance(tree, ast.Call)
1069+
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
1070+
1071+
1072+
@pytest.mark.parametrize(
1073+
"code,latex",
1074+
[
1075+
("matrix_power(A, 2)", r"\mathbf{A}^{2}"),
1076+
("matrix_power(b, 2)", r"\mathbf{b}^{2}"),
1077+
(
1078+
"matrix_power([[1, 2], [3, 4]], 2)",
1079+
r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{2}",
1080+
),
1081+
(
1082+
"matrix_power([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 42)",
1083+
r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{42}",
1084+
),
1085+
# Unsupported
1086+
(
1087+
"matrix_power()",
1088+
r"\mathrm{matrix\_power} \mathopen{}\left( \mathclose{}\right)",
1089+
),
1090+
(
1091+
"matrix_power(2)",
1092+
r"\mathrm{matrix\_power} \mathopen{}\left( 2 \mathclose{}\right)",
1093+
),
1094+
(
1095+
"matrix_power(a, (1, 0))",
1096+
r"\mathrm{matrix\_power} \mathopen{}\left( a, "
1097+
r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
1098+
),
1099+
],
1100+
)
1101+
def test_matrix_power(code: str, latex: str) -> None:
1102+
tree = ast_utils.parse_expr(code)
1103+
assert isinstance(tree, ast.Call)
1104+
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
1105+
1106+
1107+
@pytest.mark.parametrize(
1108+
"code,latex",
1109+
[
1110+
("inv(A)", r"\mathbf{A}^{-1}"),
1111+
("inv(b)", r"\mathbf{b}^{-1}"),
1112+
("inv([[1, 2], [3, 4]])", r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{-1}"),
1113+
(
1114+
"inv([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
1115+
r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{-1}",
1116+
),
1117+
# Unsupported
1118+
("inv()", r"\mathrm{inv} \mathopen{}\left( \mathclose{}\right)"),
1119+
("inv(2)", r"\mathrm{inv} \mathopen{}\left( 2 \mathclose{}\right)"),
1120+
(
1121+
"inv(a, (1, 0))",
1122+
r"\mathrm{inv} \mathopen{}\left( a, "
1123+
r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
1124+
),
1125+
],
1126+
)
1127+
def test_inv(code: str, latex: str) -> None:
1128+
tree = ast_utils.parse_expr(code)
1129+
assert isinstance(tree, ast.Call)
1130+
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
1131+
1132+
1133+
@pytest.mark.parametrize(
1134+
"code,latex",
1135+
[
1136+
("pinv(A)", r"\mathbf{A}^{+}"),
1137+
("pinv(b)", r"\mathbf{b}^{+}"),
1138+
("pinv([[1, 2], [3, 4]])", r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{+}"),
1139+
(
1140+
"pinv([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
1141+
r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{+}",
1142+
),
1143+
# Unsupported
1144+
("pinv()", r"\mathrm{pinv} \mathopen{}\left( \mathclose{}\right)"),
1145+
("pinv(2)", r"\mathrm{pinv} \mathopen{}\left( 2 \mathclose{}\right)"),
1146+
(
1147+
"pinv(a, (1, 0))",
1148+
r"\mathrm{pinv} \mathopen{}\left( a, "
1149+
r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
1150+
),
1151+
],
1152+
)
1153+
def test_pinv(code: str, latex: str) -> None:
1154+
tree = ast_utils.parse_expr(code)
1155+
assert isinstance(tree, ast.Call)
1156+
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
1157+
1158+
9981159
# Check list for #89.
9991160
# https://github.com/google/latexify_py/issues/89#issuecomment-1344967636
10001161
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)