Skip to content

Commit

Permalink
fix: slicer bug in sympy
Browse files Browse the repository at this point in the history
  • Loading branch information
Naman-ntc committed Nov 26, 2024
1 parent 44f2416 commit b17c91b
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 11 deletions.
19 changes: 14 additions & 5 deletions src/r2e/pat/dependency_slicer/ast_statements.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,10 @@ def build_var_to_stmt_idxs(self) -> dict[str, list[int]]:
stmt = stmt_obj.stmt
if isinstance(stmt, ast.Assign):
for target in stmt.targets:
assigned_target = AstStatements.assigned_expr_name_str(target)
if assigned_target is not None:
var_to_stmt_idxs[assigned_target].append(idx)
assigned_targets = AstStatements.assigned_expr_name_str(target)
if assigned_targets is not None:
for assigned_target in assigned_targets:
var_to_stmt_idxs[assigned_target].append(idx)
elif isinstance(stmt, (ast.AugAssign, ast.AnnAssign)):
assigned_target = AstStatements.assigned_expr_name_str(stmt.target)
if assigned_target is not None:
Expand Down Expand Up @@ -211,13 +212,21 @@ def find_wildcard_imports(self):
return wildcard_stmt_idxs

@staticmethod
def assigned_expr_name_str(expr: ast.expr) -> str | None:
def assigned_expr_name_str(expr: ast.expr) -> list[str] | None:
if isinstance(expr, ast.Name):
return expr.id
return [expr.id]
elif isinstance(expr, ast.Attribute):
return AstStatements.assigned_expr_name_str(expr.value)
elif isinstance(expr, ast.Subscript):
return AstStatements.assigned_expr_name_str(expr.value)
elif isinstance(expr, (ast.Tuple, ast.List)):
assigned_names = []
for el in expr.elts:
assigned_name = AstStatements.assigned_expr_name_str(el)
if assigned_name is not None:
assigned_names.extend(assigned_name)
return assigned_names

else:
return None

Expand Down
79 changes: 73 additions & 6 deletions src/r2e/pat/dependency_slicer/slicer_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ def __init__(
ast_stmt_list: list[AstStatement],
file_ast_cache: dict[str, AstStatements],
depth: int = -1,
slice_imports: bool = True,
):
self.repo = repo
self.ast_stmt_list = ast_stmt_list
self.file_ast_cache = file_ast_cache
try:
self.callgraph_explorer = CallGraphExplorer(self.repo)
except Exception as e:
print(repr(e))
self.callgraph_explorer = None

self.recursion_stack: list[AstStatement] = []
Expand All @@ -57,9 +57,14 @@ def __init__(

self.depth = depth

self.slice_imports = slice_imports

@classmethod
def from_function_models(
cls, function_models: Function | list[Function], depth: int = -1
cls,
function_models: Function | list[Function],
depth: int = -1,
slice_imports: bool = True,
):
function_models = (
function_models if isinstance(function_models, list) else [function_models]
Expand Down Expand Up @@ -87,10 +92,15 @@ def from_function_models(
assert resolved_function is not None
ast_stmt_list.append(resolved_function)

return cls(repo, ast_stmt_list, file_ast_cache, depth)
return cls(repo, ast_stmt_list, file_ast_cache, depth, slice_imports)

@classmethod
def from_class_models(cls, class_models: Class | list[Class], depth: int = -1):
def from_class_models(
cls,
class_models: Class | list[Class],
depth: int = -1,
slice_imports: bool = True,
):
class_models = (
class_models if isinstance(class_models, list) else [class_models]
)
Expand All @@ -115,7 +125,60 @@ def from_class_models(cls, class_models: Class | list[Class], depth: int = -1):
assert resolved_class is not None
ast_stmt_list.append(resolved_class)

return cls(repo, ast_stmt_list, file_ast_cache, depth)
return cls(repo, ast_stmt_list, file_ast_cache, depth, slice_imports)

@classmethod
def from_funclass_models(
cls,
funclass_models: list[Function | Class],
depth: int = -1,
slice_imports: bool = True,
):
funclass_models = (
funclass_models if isinstance(funclass_models, list) else [funclass_models]
)
assert (
len(set([f.repo for f in funclass_models])) == 1
), f"{[f.repo for f in funclass_models]} are not the same repos"

repo = funclass_models[0].repo

file_ast_cache: dict[str, AstStatements] = {}

ast_stmt_list: list[AstStatement] = []
for funclass_model in funclass_models:
if isinstance(funclass_model, Function):
assert funclass_model.function_name is not None
if funclass_model.file_path in file_ast_cache:
ast_stmts = file_ast_cache[funclass_model.file_path]
else:
ast_stmts = AstStatements(funclass_model.file)
file_ast_cache[funclass_model.file_path] = ast_stmts

resolved_function = ast_stmts.find_function_stmt_with_name(
funclass_model.function_name
)
assert (
resolved_function is not None
), f"{funclass_model.function_name} {funclass_model.file_path}"
ast_stmt_list.append(resolved_function)
elif isinstance(funclass_model, Class):
assert funclass_model.class_name is not None
if funclass_model.file_path in file_ast_cache:
ast_stmts = file_ast_cache[funclass_model.file_path]
else:
ast_stmts = AstStatements(funclass_model.file)
file_ast_cache[funclass_model.file_path] = ast_stmts

resolved_class = ast_stmts.find_class_stmt_with_name(
funclass_model.class_name
)
assert (
resolved_class is not None
), f"{funclass_model.class_name} {funclass_model.file_path}"
ast_stmt_list.append(resolved_class)

return cls(repo, ast_stmt_list, file_ast_cache, depth, slice_imports)

def run(self):
for ast_stmt in self.ast_stmt_list:
Expand All @@ -133,11 +196,15 @@ def visit(
if depth == 0:
return

print(f"Visiting {stmt} with depth {depth}")
if stmt in self.visited_set or stmt in self.recursion_stack:
return

# print(f"Visiting {stmt.file_path} {stmt.stmt.lineno} looking for {search_key}")

for ast_type, handler in HandlersMapping.items():
if isinstance(stmt.stmt, ast_type):
if (not self.slice_imports) and issubclass(handler, ImportHandler):
return
handler_instance = handler(stmt, all_stmts, search_key, self, depth)
handler_instance.handle()
return
Expand Down
14 changes: 14 additions & 0 deletions tests/pat/slicers/test_globals_finder/test_globals_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,17 @@ def test_formattedvalue(self):

def test_namedexpr(self):
self.compare("(a := b)", ["b"])

def test_sympy_bug(self):
code = r"""
def test_issue_7117():
# See also issue #5031 (hence the evaluate=False in these).
e = Eq(x + 1, 2*x)
q = Mul(2, e, evaluate=False)
assert latex(q) == r"2 \left(x + 1 = 2 x\right)"
q = Add(6, e, evaluate=False)
assert latex(q) == r"6 + \left(x + 1 = 2 x\right)"
q = Pow(e, 2, evaluate=False)
assert latex(q) == r"\left(x + 1 = 2 x\right)^{2}"
"""
self.compare(code, ["x", "Eq", "Mul", "latex", "Add", "Pow"])

0 comments on commit b17c91b

Please sign in to comment.