Skip to content

Commit

Permalink
Allow TypedDict to inherit from Generic but mark generic variables as…
Browse files Browse the repository at this point in the history
… Any

PiperOrigin-RevId: 688141419
  • Loading branch information
oprypin authored and copybara-github committed Nov 26, 2024
1 parent 42144f9 commit 903ef8a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
19 changes: 16 additions & 3 deletions pytype/overlays/typed_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,18 @@ def optional(self):
def add(self, k, v, total):
"""Adds key and value."""
req = _is_required(v)
if req is None:
if isinstance(v, abstract.TypeParameter):
# TODO(b/328744430): Properly support generic TypedDicts.
# For now, mark all type parameters as Any to avoid erroring out entirely.
value = v.ctx.convert.unsolvable
elif req is None:
value = v
elif isinstance(v, abstract.ParameterizedClass):
value = v.formal_type_parameters[abstract_utils.T]
else:
value = v.ctx.convert.unsolvable
required = total if req is None else req
self.fields[k] = value # pylint: disable=unsupported-assignment-operation
self.fields[k] = value
if required:
self.required.add(k)

Expand Down Expand Up @@ -118,7 +122,16 @@ def _validate_bases(self, cls_name, bases):
"""Check that all base classes are valid."""
for base_var in bases:
for base in base_var.data:
if not isinstance(base, (TypedDictClass, TypedDictBuilder)):
# Allow inheriting only from TypedDict and from Generic, e.g.:
# `class Foo(TypedDict, Generic[T])`
if not (
isinstance(base, (TypedDictClass, TypedDictBuilder))
or (
isinstance(base, abstract.ParameterizedClass)
and isinstance(base.base_cls, (abstract.PyTDClass))
and base.base_cls.full_name == "typing.Generic"
)
):
details = (
f"TypedDict {cls_name} cannot inherit from a non-TypedDict class."
)
Expand Down
33 changes: 33 additions & 0 deletions pytype/tests/test_typed_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,17 @@ class X(TypedDict, total=False):
""",
)

self.Check("""
from typing import Any, Generic, TypedDict, TypeVar, is_typeddict
T = TypeVar('T')
class Foo(TypedDict, Generic[T]):
foo: str
bar: T
x: Foo
assert_type(x['foo'], str)
assert_type(x['bar'], Any) # TODO(b/328744430): Properly support generics.
""")

def test_ambiguous_field_type(self):
self.CheckWithErrors("""
from typing_extensions import TypedDict
Expand Down Expand Up @@ -1014,6 +1025,28 @@ class Y:
""",
)

def test_generic(self):
ty = self.Infer("""
from typing import Generic, TypedDict, TypeVar, is_typeddict
T = TypeVar('T')
class X(TypedDict, Generic[T]):
bar: T
if is_typeddict(X):
X_is_typeddict = True
else:
X_is_not_typeddict = True
""")
self.assertTypesMatchPytd(
ty,
"""
from typing import Any, TypeVar, TypedDict
T = TypeVar('T')
class X(TypedDict):
bar: Any # TODO(b/328744430): Properly support generic TypedDicts.
X_is_typeddict: bool
""",
)

def test_ambiguous(self):
ty = self.Infer("""
from typing_extensions import is_typeddict
Expand Down

0 comments on commit 903ef8a

Please sign in to comment.