Skip to content

Commit 428b1e1

Browse files
authored
Fix bug with function name replacement (#23)
1 parent bfe27bc commit 428b1e1

File tree

4 files changed

+36
-10
lines changed

4 files changed

+36
-10
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import torch
2+
from torch import ger
23
deprecated = torch.norm()
34
sinusoid_inp = torch.ger(pos_seq, inv_freq)
45
other = something.ger(pos_seq, inv_freq)
56
deprecated = torch.norm()
67
one_more = torch.ger(pos_seq, inv_freq)
8+
9+
just_name = ger(pos_seq, inv_freq)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import torch
2+
from torch import outer, ger
23
deprecated = torch.norm()
34
sinusoid_inp = torch.outer(pos_seq, inv_freq)
45
other = something.ger(pos_seq, inv_freq)
56
deprecated = torch.norm()
67
one_more = torch.outer(pos_seq, inv_freq)
8+
9+
just_name = outer(pos_seq, inv_freq)

torchfix/common.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import libcst as cst
44
from libcst.metadata import QualifiedNameProvider, WhitespaceInclusivePositionProvider
55
from libcst.codemod.visitors import ImportItem
6-
from typing import Optional, List, Set, Union
6+
from typing import Optional, List, Set, Tuple, Union
77
from abc import ABC
88

99
IS_TTY = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
@@ -83,19 +83,34 @@ def get_qualified_name_for_call(self, node: cst.Call) -> Optional[str]:
8383

8484
def call_with_name_changes(
8585
node: cst.Call, old_qualified_name: str, new_qualified_name: str
86-
) -> Optional[cst.Call]:
86+
) -> Optional[Tuple[cst.Call, Set[ImportItem]]]:
8787
"""
88-
Return new `Call` node with name changes.
88+
Return an optional tuple:
89+
new `Call` node with name changes
90+
and a set of newly needed imports.
8991
"""
9092
old_begin, _, old_last = old_qualified_name.rpartition(".")
9193
new_begin, _, new_last = new_qualified_name.rpartition(".")
94+
needed_imports: Set[ImportItem] = set()
9295

9396
# If the only difference is the last name part.
9497
if old_begin == new_begin:
95-
replacement = node.with_deep_changes(
96-
old_node=cst.ensure_type(node.func, cst.Attribute).attr,
97-
value=new_last,
98-
)
98+
if isinstance(node.func, cst.Attribute):
99+
replacement = node.with_deep_changes(
100+
old_node=node.func.attr,
101+
value=new_last,
102+
)
103+
elif isinstance(node.func, cst.Name):
104+
replacement = node.with_deep_changes(
105+
old_node=node.func,
106+
value=new_last,
107+
)
108+
needed_imports.add(
109+
ImportItem(
110+
module_name=new_begin,
111+
obj_name=new_last,
112+
)
113+
)
99114

100115
# If the last name part is the same and
101116
# originally called without a dot: don't change the call site,
@@ -106,7 +121,10 @@ def call_with_name_changes(
106121
# Replace with new_qualified_name.
107122
else:
108123
replacement = node.with_changes(func=cst.parse_expression(new_qualified_name))
109-
return replacement
124+
if replacement is None:
125+
return None
126+
else:
127+
return replacement, needed_imports
110128

111129

112130
def deep_multi_replace(tree, replacement_map):

torchfix/visitors/deprecated_symbols/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,12 @@ def _call_replacement(
4949
qualified_name, {}
5050
).get("replacement", "")
5151
if function_name_replacement:
52-
replacement = call_with_name_changes(
52+
replacement_and_imports = call_with_name_changes(
5353
node, qualified_name, function_name_replacement
5454
)
55-
55+
if replacement_and_imports is not None:
56+
replacement, imports = replacement_and_imports
57+
self.needed_imports.update(imports)
5658
return replacement
5759

5860
def visit_Call(self, node):

0 commit comments

Comments
 (0)