diff --git a/libcst/codemod/tests/test_codemod_cli.py b/libcst/codemod/tests/test_codemod_cli.py index 0fa3dbef..8773cd77 100644 --- a/libcst/codemod/tests/test_codemod_cli.py +++ b/libcst/codemod/tests/test_codemod_cli.py @@ -93,3 +93,34 @@ def baz() -> str: "- 3 warnings were generated.", output.stderr, ) + + def test_matcher_decorators_multiprocessing(self) -> None: + file_count = 5 + code = """ + def baz(): # type: int + return 5 + """ + with tempfile.TemporaryDirectory() as tmpdir: + p = Path(tmpdir) + # Using more than chunksize=4 files to trigger multiprocessing + for i in range(file_count): + (p / f"mod{i}.py").write_text(CodemodTest.make_fixture_data(code)) + output = subprocess.run( + [ + sys.executable, + "-m", + "libcst.tool", + "codemod", + # Good candidate since it uses matcher decorators + "convert_type_comments.ConvertTypeComments", + str(p), + "--jobs", + str(file_count), + ], + encoding="utf-8", + stderr=subprocess.PIPE, + ) + self.assertIn( + f"Transformed {file_count} files successfully.", + output.stderr, + ) diff --git a/libcst/matchers/_visitors.py b/libcst/matchers/_visitors.py index 9349c5b5..b9252173 100644 --- a/libcst/matchers/_visitors.py +++ b/libcst/matchers/_visitors.py @@ -60,6 +60,11 @@ class UnionType: } +def is_property(obj: object, attr_name: str) -> bool: + """Check if obj.attr is a property without evaluating it.""" + return isinstance(getattr(type(obj), attr_name, None), property) + + # pyre-ignore We don't care about Any here, its not exposed. def _match_decorator_unpickler(kwargs: Any) -> "MatchDecoratorMismatch": return MatchDecoratorMismatch(**kwargs) @@ -265,20 +270,22 @@ def _check_types( ) -def _gather_matchers(obj: object) -> Set[BaseMatcherNode]: - visit_matchers: Set[BaseMatcherNode] = set() +def _gather_matchers(obj: object) -> Dict[BaseMatcherNode, Optional[cst.CSTNode]]: + """ + Set of gating matchers that we need to track and evaluate. We use these + in conjunction with the call_if_inside and call_if_not_inside decorators + to determine whether to call a visit/leave function. + """ - for func in dir(obj): - try: - for matcher in getattr(getattr(obj, func), VISIT_POSITIVE_MATCHER_ATTR, []): - visit_matchers.add(cast(BaseMatcherNode, matcher)) - for matcher in getattr(getattr(obj, func), VISIT_NEGATIVE_MATCHER_ATTR, []): - visit_matchers.add(cast(BaseMatcherNode, matcher)) - except Exception: - # This could be a caculated property, and calling getattr() evaluates it. - # We have no control over the implementation detail, so if it raises, we - # should not crash. - pass + visit_matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]] = {} + + for attr_name in dir(obj): + if not is_property(obj, attr_name): + func = getattr(obj, attr_name) + for matcher in getattr(func, VISIT_POSITIVE_MATCHER_ATTR, []): + visit_matchers[cast(BaseMatcherNode, matcher)] = None + for matcher in getattr(func, VISIT_NEGATIVE_MATCHER_ATTR, []): + visit_matchers[cast(BaseMatcherNode, matcher)] = None return visit_matchers @@ -302,16 +309,12 @@ def _gather_constructed_visit_funcs( ] = {} for funcname in dir(obj): - try: - possible_func = getattr(obj, funcname) - if not ismethod(possible_func): - continue - func = cast(Callable[[cst.CSTNode], None], possible_func) - except Exception: - # This could be a caculated property, and calling getattr() evaluates it. - # We have no control over the implementation detail, so if it raises, we - # should not crash. + if is_property(obj, funcname): + continue + possible_func = getattr(obj, funcname) + if not ismethod(possible_func): continue + func = cast(Callable[[cst.CSTNode], None], possible_func) matchers = getattr(func, CONSTRUCTED_VISIT_MATCHER_ATTR, []) if matchers: # Make sure that we aren't accidentally putting a @visit on a visit_Node. @@ -337,16 +340,12 @@ def _gather_constructed_leave_funcs( ] = {} for funcname in dir(obj): - try: - possible_func = getattr(obj, funcname) - if not ismethod(possible_func): - continue - func = cast(Callable[[cst.CSTNode], None], possible_func) - except Exception: - # This could be a caculated property, and calling getattr() evaluates it. - # We have no control over the implementation detail, so if it raises, we - # should not crash. + if is_property(obj, funcname): + continue + possible_func = getattr(obj, funcname) + if not ismethod(possible_func): continue + func = cast(Callable[[cst.CSTNode], None], possible_func) matchers = getattr(func, CONSTRUCTED_LEAVE_MATCHER_ATTR, []) if matchers: # Make sure that we aren't accidentally putting a @leave on a leave_Node. @@ -448,12 +447,7 @@ class MatcherDecoratableTransformer(CSTTransformer): def __init__(self) -> None: CSTTransformer.__init__(self) - # List of gating matchers that we need to track and evaluate. We use these - # in conjuction with the call_if_inside and call_if_not_inside decorators - # to determine whether or not to call a visit/leave function. - self._matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]] = { - m: None for m in _gather_matchers(self) - } + self.__matchers: Optional[Dict[BaseMatcherNode, Optional[cst.CSTNode]]] = None # Mapping of matchers to functions. If in the course of visiting the tree, # a node matches one of these matchers, the corresponding function will be # called as if it was a visit_* method. @@ -486,6 +480,16 @@ def __init__(self) -> None: expected_none_return=False, ) + @property + def _matchers(self) -> Dict[BaseMatcherNode, Optional[cst.CSTNode]]: + if self.__matchers is None: + self.__matchers = _gather_matchers(self) + return self.__matchers + + @_matchers.setter + def _matchers(self, value: Dict[BaseMatcherNode, Optional[cst.CSTNode]]) -> None: + self.__matchers = value + def on_visit(self, node: cst.CSTNode) -> bool: # First, evaluate any matchers that we have which we are not inside already. self._matchers = _visit_matchers(self._matchers, node, self) @@ -660,12 +664,7 @@ class MatcherDecoratableVisitor(CSTVisitor): def __init__(self) -> None: CSTVisitor.__init__(self) - # List of gating matchers that we need to track and evaluate. We use these - # in conjuction with the call_if_inside and call_if_not_inside decorators - # to determine whether or not to call a visit/leave function. - self._matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]] = { - m: None for m in _gather_matchers(self) - } + self.__matchers: Optional[Dict[BaseMatcherNode, Optional[cst.CSTNode]]] = None # Mapping of matchers to functions. If in the course of visiting the tree, # a node matches one of these matchers, the corresponding function will be # called as if it was a visit_* method. @@ -693,6 +692,16 @@ def __init__(self) -> None: expected_none_return=True, ) + @property + def _matchers(self) -> Dict[BaseMatcherNode, Optional[cst.CSTNode]]: + if self.__matchers is None: + self.__matchers = _gather_matchers(self) + return self.__matchers + + @_matchers.setter + def _matchers(self, value: Dict[BaseMatcherNode, Optional[cst.CSTNode]]) -> None: + self.__matchers = value + def on_visit(self, node: cst.CSTNode) -> bool: # First, evaluate any matchers that we have which we are not inside already. self._matchers = _visit_matchers(self._matchers, node, self)