From b17c91bc8cf7d575d343ada4699f83ba89015062 Mon Sep 17 00:00:00 2001 From: Naman Jain Date: Tue, 26 Nov 2024 07:21:34 +0000 Subject: [PATCH] fix: slicer bug in sympy --- .../pat/dependency_slicer/ast_statements.py | 19 +++-- src/r2e/pat/dependency_slicer/slicer_main.py | 79 +++++++++++++++++-- .../test_globals_finder.py | 14 ++++ 3 files changed, 101 insertions(+), 11 deletions(-) diff --git a/src/r2e/pat/dependency_slicer/ast_statements.py b/src/r2e/pat/dependency_slicer/ast_statements.py index 5a10963..1ac2181 100644 --- a/src/r2e/pat/dependency_slicer/ast_statements.py +++ b/src/r2e/pat/dependency_slicer/ast_statements.py @@ -179,9 +179,10 @@ def build_var_to_stmt_idxs(self) -> dict[str, list[int]]: stmt = stmt_obj.stmt if isinstance(stmt, ast.Assign): for target in stmt.targets: - assigned_target = AstStatements.assigned_expr_name_str(target) - if assigned_target is not None: - var_to_stmt_idxs[assigned_target].append(idx) + assigned_targets = AstStatements.assigned_expr_name_str(target) + if assigned_targets is not None: + for assigned_target in assigned_targets: + var_to_stmt_idxs[assigned_target].append(idx) elif isinstance(stmt, (ast.AugAssign, ast.AnnAssign)): assigned_target = AstStatements.assigned_expr_name_str(stmt.target) if assigned_target is not None: @@ -211,13 +212,21 @@ def find_wildcard_imports(self): return wildcard_stmt_idxs @staticmethod - def assigned_expr_name_str(expr: ast.expr) -> str | None: + def assigned_expr_name_str(expr: ast.expr) -> list[str] | None: if isinstance(expr, ast.Name): - return expr.id + return [expr.id] elif isinstance(expr, ast.Attribute): return AstStatements.assigned_expr_name_str(expr.value) elif isinstance(expr, ast.Subscript): return AstStatements.assigned_expr_name_str(expr.value) + elif isinstance(expr, (ast.Tuple, ast.List)): + assigned_names = [] + for el in expr.elts: + assigned_name = AstStatements.assigned_expr_name_str(el) + if assigned_name is not None: + assigned_names.extend(assigned_name) + return assigned_names + else: return None diff --git a/src/r2e/pat/dependency_slicer/slicer_main.py b/src/r2e/pat/dependency_slicer/slicer_main.py index 77b165f..903ce82 100644 --- a/src/r2e/pat/dependency_slicer/slicer_main.py +++ b/src/r2e/pat/dependency_slicer/slicer_main.py @@ -40,6 +40,7 @@ def __init__( ast_stmt_list: list[AstStatement], file_ast_cache: dict[str, AstStatements], depth: int = -1, + slice_imports: bool = True, ): self.repo = repo self.ast_stmt_list = ast_stmt_list @@ -47,7 +48,6 @@ def __init__( try: self.callgraph_explorer = CallGraphExplorer(self.repo) except Exception as e: - print(repr(e)) self.callgraph_explorer = None self.recursion_stack: list[AstStatement] = [] @@ -57,9 +57,14 @@ def __init__( self.depth = depth + self.slice_imports = slice_imports + @classmethod def from_function_models( - cls, function_models: Function | list[Function], depth: int = -1 + cls, + function_models: Function | list[Function], + depth: int = -1, + slice_imports: bool = True, ): function_models = ( function_models if isinstance(function_models, list) else [function_models] @@ -87,10 +92,15 @@ def from_function_models( assert resolved_function is not None ast_stmt_list.append(resolved_function) - return cls(repo, ast_stmt_list, file_ast_cache, depth) + return cls(repo, ast_stmt_list, file_ast_cache, depth, slice_imports) @classmethod - def from_class_models(cls, class_models: Class | list[Class], depth: int = -1): + def from_class_models( + cls, + class_models: Class | list[Class], + depth: int = -1, + slice_imports: bool = True, + ): class_models = ( class_models if isinstance(class_models, list) else [class_models] ) @@ -115,7 +125,60 @@ def from_class_models(cls, class_models: Class | list[Class], depth: int = -1): assert resolved_class is not None ast_stmt_list.append(resolved_class) - return cls(repo, ast_stmt_list, file_ast_cache, depth) + return cls(repo, ast_stmt_list, file_ast_cache, depth, slice_imports) + + @classmethod + def from_funclass_models( + cls, + funclass_models: list[Function | Class], + depth: int = -1, + slice_imports: bool = True, + ): + funclass_models = ( + funclass_models if isinstance(funclass_models, list) else [funclass_models] + ) + assert ( + len(set([f.repo for f in funclass_models])) == 1 + ), f"{[f.repo for f in funclass_models]} are not the same repos" + + repo = funclass_models[0].repo + + file_ast_cache: dict[str, AstStatements] = {} + + ast_stmt_list: list[AstStatement] = [] + for funclass_model in funclass_models: + if isinstance(funclass_model, Function): + assert funclass_model.function_name is not None + if funclass_model.file_path in file_ast_cache: + ast_stmts = file_ast_cache[funclass_model.file_path] + else: + ast_stmts = AstStatements(funclass_model.file) + file_ast_cache[funclass_model.file_path] = ast_stmts + + resolved_function = ast_stmts.find_function_stmt_with_name( + funclass_model.function_name + ) + assert ( + resolved_function is not None + ), f"{funclass_model.function_name} {funclass_model.file_path}" + ast_stmt_list.append(resolved_function) + elif isinstance(funclass_model, Class): + assert funclass_model.class_name is not None + if funclass_model.file_path in file_ast_cache: + ast_stmts = file_ast_cache[funclass_model.file_path] + else: + ast_stmts = AstStatements(funclass_model.file) + file_ast_cache[funclass_model.file_path] = ast_stmts + + resolved_class = ast_stmts.find_class_stmt_with_name( + funclass_model.class_name + ) + assert ( + resolved_class is not None + ), f"{funclass_model.class_name} {funclass_model.file_path}" + ast_stmt_list.append(resolved_class) + + return cls(repo, ast_stmt_list, file_ast_cache, depth, slice_imports) def run(self): for ast_stmt in self.ast_stmt_list: @@ -133,11 +196,15 @@ def visit( if depth == 0: return - print(f"Visiting {stmt} with depth {depth}") if stmt in self.visited_set or stmt in self.recursion_stack: return + + # print(f"Visiting {stmt.file_path} {stmt.stmt.lineno} looking for {search_key}") + for ast_type, handler in HandlersMapping.items(): if isinstance(stmt.stmt, ast_type): + if (not self.slice_imports) and issubclass(handler, ImportHandler): + return handler_instance = handler(stmt, all_stmts, search_key, self, depth) handler_instance.handle() return diff --git a/tests/pat/slicers/test_globals_finder/test_globals_finder.py b/tests/pat/slicers/test_globals_finder/test_globals_finder.py index 5bf1524..6aee7dc 100644 --- a/tests/pat/slicers/test_globals_finder/test_globals_finder.py +++ b/tests/pat/slicers/test_globals_finder/test_globals_finder.py @@ -343,3 +343,17 @@ def test_formattedvalue(self): def test_namedexpr(self): self.compare("(a := b)", ["b"]) + + def test_sympy_bug(self): + code = r""" +def test_issue_7117(): + # See also issue #5031 (hence the evaluate=False in these). + e = Eq(x + 1, 2*x) + q = Mul(2, e, evaluate=False) + assert latex(q) == r"2 \left(x + 1 = 2 x\right)" + q = Add(6, e, evaluate=False) + assert latex(q) == r"6 + \left(x + 1 = 2 x\right)" + q = Pow(e, 2, evaluate=False) + assert latex(q) == r"\left(x + 1 = 2 x\right)^{2}" +""" + self.compare(code, ["x", "Eq", "Mul", "latex", "Add", "Pow"])