From 99e35b40ec8f3d5280fa106955e194e4c511e325 Mon Sep 17 00:00:00 2001 From: Hana Joo Date: Thu, 28 Nov 2024 10:54:46 -0800 Subject: [PATCH] Cache the result of `sub_one_annotation` which happens during the construction of "BadType". Also update types on the signatures along the way. When a bad match of type happens, it doesn't necessarily mean that there would be a type error. It means that out of many possible signature matches (or type matches) between different types including generics, there is something that doesn't match and we cannot use that result. This result will be used later when producing diagnostics to indicate which pair of types mismatch, in the case where there is no single type match out of those combinations. The problem is that when the combinations (e.g. multiple signatures, generics, recursive types) that needs to be verified need to be type checked, it tries out so many different combinations of types, and generate tons of the same "BadMatch" which is just redundant. Making it even worse, it's not only the computation but also the memory pressure caused by this because it seems to construct something which is memory intensive, thus GC frequently showing up in profile in these particular cases. There is still some risk that this might increase the peak memory consumption because once put in the cache, the objects will have the same lifetime as the type checker, but the drastic performance improvement in these particular cases seem worth it. A better solution might be actually making the type checker not call into generating these seemingly redundant type checks on these crazy combinations, but at a glance this seemed like a fundamental design issue that lives in the complex nature of the type graph, and without being able to understand / changing it this seems to be the shortest path without hurting code health too much. PiperOrigin-RevId: 701062879 --- pytype/abstract/abstract_utils.py | 6 ++-- pytype/annotation_utils.py | 49 ++++++++++++++++++++++++++----- 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/pytype/abstract/abstract_utils.py b/pytype/abstract/abstract_utils.py index 2aa6b0672..3073ae59b 100644 --- a/pytype/abstract/abstract_utils.py +++ b/pytype/abstract/abstract_utils.py @@ -841,9 +841,9 @@ def is_generic_protocol(val: "_base.BaseValue") -> bool: def combine_substs( - substs1: Collection[dict[str, cfg.Variable]] | None, - substs2: Collection[dict[str, cfg.Variable]] | None, -) -> Collection[dict[str, cfg.Variable]]: + substs1: Sequence[Mapping[str, cfg.Variable]] | None, + substs2: Sequence[Mapping[str, cfg.Variable]] | None, +) -> Sequence[dict[str, cfg.Variable]]: """Combines the two collections of type parameter substitutions.""" if substs1 and substs2: return tuple({**sub1, **sub2} for sub1 in substs1 for sub2 in substs2) # pylint: disable=g-complex-comprehension diff --git a/pytype/annotation_utils.py b/pytype/annotation_utils.py index 298eca77f..743f6cfaf 100644 --- a/pytype/annotation_utils.py +++ b/pytype/annotation_utils.py @@ -1,7 +1,7 @@ """Utilities for inline type annotations.""" import collections -from collections.abc import Sequence +from collections.abc import Callable, Mapping, Sequence import dataclasses import itertools from typing import Any @@ -27,6 +27,20 @@ class AnnotatedValue: class AnnotationUtils(utils.ContextWeakrefMixin): """Utility class for inline type annotations.""" + def __init__(self, ctx): + super().__init__(ctx) + # calling sub_one_annotation is costly, due to calling multiple of chained + # constructors (via annot.replace) and generating complex data structure. + # And in some corner cases which includes recursive generic types with + # overloads, it causes massive call to construction of bad match which calls + # sub_one_annotations. + # A better solution might be not to make those seemingly redundant request + # from the type checker, but for now this is a comprimise to gain + # performance in those weird corner cases. + self.annotation_sub_cache: dict[ + tuple[cfg.CFGNode, abstract.BaseValue], abstract.BaseValue + ] = dict() + def sub_annotations(self, node, annotations, substs, instantiate_unbound): """Apply type parameter substitutions to a dictionary of annotations.""" if substs and all(substs): @@ -42,7 +56,7 @@ def _get_type_parameter_subst( self, node: cfg.CFGNode, annot: abstract.TypeParameter, - substs: Sequence[dict[str, cfg.Variable]], + substs: Sequence[Mapping[str, cfg.Variable]], instantiate_unbound: bool, ) -> abstract.BaseValue: """Helper for sub_one_annotation.""" @@ -68,16 +82,37 @@ def _get_type_parameter_subst( vals = [annot] return self.ctx.convert.merge_classes(vals) - def sub_one_annotation(self, node, annot, substs, instantiate_unbound=True): + def sub_one_annotation( + self, + node: cfg.CFGNode, + annot: abstract.BaseValue, + substs: Sequence[Mapping[str, cfg.Variable]], + instantiate_unbound: bool = True, + ): def get_type_parameter_subst(annotation): return self._get_type_parameter_subst( node, annotation, substs, instantiate_unbound ) - return self._do_sub_one_annotation(node, annot, get_type_parameter_subst) + if not substs or (len(substs) == 1 and not substs[0]): + res = self.annotation_sub_cache.get((node, annot), None) + if res: + return res - def _do_sub_one_annotation(self, node, annot, get_type_parameter_subst_fn): + res = self._do_sub_one_annotation(node, annot, get_type_parameter_subst) + if not substs or not substs[0]: + self.annotation_sub_cache[(node, annot)] = res + return res + + def _do_sub_one_annotation( + self, + node: cfg.CFGNode, + annot: abstract.BaseValue, + get_type_parameter_subst_fn: Callable[ + [abstract.BaseValue], abstract.BaseValue + ], + ): """Apply type parameter substitutions to an annotation.""" # We push annotations onto 'stack' and move them to the 'done' stack as they # are processed. For each annotation, we also track an 'inner_type_keys' @@ -92,7 +127,7 @@ def _do_sub_one_annotation(self, node, annot, get_type_parameter_subst_fn): done = [] while stack: cur, inner_type_keys = stack.pop() - if not cur.formal: + if not cur.formal: # pytype: disable=attribute-error done.append(cur) elif isinstance(cur, mixin.NestedAnnotation): if cur.is_late_annotation() and any(t[0] == cur for t in stack): @@ -411,7 +446,7 @@ def _sub_and_instantiate(self, node, name, typ, substs): class_substs = abstract_utils.combine_substs( substs, [{"typing.Self": self.ctx.vm.frame.first_arg}] ) - type_for_value = self.sub_one_annotation( + type_for_value = self.sub_one_annotation( # pytype: disable=wrong-arg-types node, typ, class_substs, instantiate_unbound=False ) else: