33import libcst as cst
44from libcst .metadata import QualifiedNameProvider , WhitespaceInclusivePositionProvider
55from libcst .codemod .visitors import ImportItem
6- from typing import Optional , List , Set , Union
6+ from typing import Optional , List , Set , Tuple , Union
77from abc import ABC
88
99IS_TTY = hasattr (sys .stdout , "isatty" ) and sys .stdout .isatty ()
@@ -83,19 +83,34 @@ def get_qualified_name_for_call(self, node: cst.Call) -> Optional[str]:
8383
8484def call_with_name_changes (
8585 node : cst .Call , old_qualified_name : str , new_qualified_name : str
86- ) -> Optional [cst .Call ]:
86+ ) -> Optional [Tuple [ cst .Call , Set [ ImportItem ]] ]:
8787 """
88- Return new `Call` node with name changes.
88+ Return an optional tuple:
89+ new `Call` node with name changes
90+ and a set of newly needed imports.
8991 """
9092 old_begin , _ , old_last = old_qualified_name .rpartition ("." )
9193 new_begin , _ , new_last = new_qualified_name .rpartition ("." )
94+ needed_imports : Set [ImportItem ] = set ()
9295
9396 # If the only difference is the last name part.
9497 if old_begin == new_begin :
95- replacement = node .with_deep_changes (
96- old_node = cst .ensure_type (node .func , cst .Attribute ).attr ,
97- value = new_last ,
98- )
98+ if isinstance (node .func , cst .Attribute ):
99+ replacement = node .with_deep_changes (
100+ old_node = node .func .attr ,
101+ value = new_last ,
102+ )
103+ elif isinstance (node .func , cst .Name ):
104+ replacement = node .with_deep_changes (
105+ old_node = node .func ,
106+ value = new_last ,
107+ )
108+ needed_imports .add (
109+ ImportItem (
110+ module_name = new_begin ,
111+ obj_name = new_last ,
112+ )
113+ )
99114
100115 # If the last name part is the same and
101116 # originally called without a dot: don't change the call site,
@@ -106,7 +121,10 @@ def call_with_name_changes(
106121 # Replace with new_qualified_name.
107122 else :
108123 replacement = node .with_changes (func = cst .parse_expression (new_qualified_name ))
109- return replacement
124+ if replacement is None :
125+ return None
126+ else :
127+ return replacement , needed_imports
110128
111129
112130def deep_multi_replace (tree , replacement_map ):
0 commit comments