How to deal with overlapping protocols. #1491
-
I'm having trouble annotating a class that can take a map-like object or a sequence-like object and provides random iteration over the keys or integer indices. The main issue is that the protocols overlap, and I can't seem to get the annotations right. How do I need to modify the code below?
# 3.11
import random
from typing import (
    Any,
    Generic,
    Iterable,
    Iterator,
    Protocol,
    SupportsIndex,
    TypeAlias,
    TypeVar,
    assert_type,
    overload,
    runtime_checkable,
)

import pytest
# Invariant type variables: dataset key type and dataset value type.
K = TypeVar("K")
V = TypeVar("V")

# Covariant value type, for protocols that only ever return values.
V_co = TypeVar("V_co", covariant=True)

# Contravariant key types, for protocols that only ever accept keys.
Key_contra = TypeVar("Key_contra", contravariant=True)
Int_contra = TypeVar("Int_contra", bound=int, contravariant=True)
@runtime_checkable
class SupportsLenAndGetItem(Protocol[Key_contra, V_co]):
    """Structural type for objects that are sized and subscriptable.

    ``Key_contra`` is the key type accepted by ``__getitem__`` and ``V_co``
    is the type of the value it returns.
    """

    def __len__(self) -> int:
        ...

    def __getitem__(self, key: Key_contra, /) -> V_co:
        ...
@runtime_checkable
class MapDataset(Protocol[K, V_co]):
    """Structural type for a dataset that is indexed by keys.

    Besides being sized and subscriptable, the object exposes ``keys()``,
    which provides the keys that may be passed to ``__getitem__``.
    """

    def __len__(self) -> int:
        ...

    def __getitem__(self, key: K, /) -> V_co:
        ...

    def keys(self) -> Iterable[K]:
        ...
@runtime_checkable
class IterableDataset(Protocol[Int_contra, V_co]):
    """Structural type for a sized, subscriptable, iterable dataset.

    The dataset is assumed to be indexed by the integers ``0 ... n-1``.
    """

    def __len__(self) -> int:
        ...

    def __getitem__(self, key: Int_contra, /) -> V_co:
        ...

    def __iter__(self) -> Iterator[V_co]:
        ...
# Either supported dataset shape, generic in key type K and value type V.
# NOTE(review): IterableDataset's first parameter is bound to ``int`` while K
# is unbounded here -- a type checker may flag this; confirm.
Dataset: TypeAlias = MapDataset[K, V] | IterableDataset[K, V]
class Sampler(Generic[K]):
"""Sample random indices that can be used to sample from a dataset."""
data: SupportsLenAndGetItem[K, Any]
index: list[K]
def __init__(self, data_source: Dataset[K, Any]) -> None:
self.data = data_source
match data_source:
case MapDataset() as map_data: # in this case, K given by the Mapping
self.index = list(map_data.keys())
case IterableDataset() as seq_data: # can we forcibly bind K to int?
self.index = list(range(len(seq_data)))
case _:
raise TypeError
def __iter__(self) -> Iterator[K]:
random.shuffle(self.index)
yield from self.index
def test_map_data_a() -> None:
    """Check the sampler type for a ``MapDataset[str, str]`` source."""
    str_keyed: MapDataset[str, str] = {"x": "foo", "y": "bar"}
    assert_type(Sampler(str_keyed), Sampler[str])
def test_map_data_b() -> None:
    """Check the sampler type for a ``MapDataset[int, str]`` source."""
    int_keyed: MapDataset[int, str] = {10: "foo", 11: "bar"}
    assert_type(Sampler(int_keyed), Sampler[int])
def test_raw_map_data() -> None:
    """Check the sampler type for an unannotated ``dict`` with ``int`` keys."""
    plain_mapping = {10: "foo", 11: "bar"}
    assert_type(Sampler(plain_mapping), Sampler[int])
def test_seq_data() -> None:
    """Check the sampler type for an ``IterableDataset[int, str]`` source."""
    annotated_seq: IterableDataset[int, str] = ["foo", "bar"]
    # Possibly Sampler[SupportsIndex]
    assert_type(Sampler(annotated_seq), Sampler[int])
def test_raw_seq_data() -> None:
    """Check the sampler type for an unannotated ``list`` source."""
    plain_seq = ["foo", "bar"]
    # Possibly Sampler[SupportsIndex]
    assert_type(Sampler(plain_seq), Sampler[int])
if __name__ == "__main__":
    # Run pytest when executed as a script and propagate its exit code, so a
    # failing run is visible to the calling shell.
    raise SystemExit(pytest.main())
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 3 replies
-
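I would probably use overloads on `__init__` for this:

    @overload
    def __init__(self: Sampler[int], data_source: IterableDataset[int, Any]) -> None: ...
    @overload
    def __init__(self: Sampler[K], data_source: MapDataset[K, Any]) -> None: ...
    def __init__(self, data_source: Dataset[Any, Any]) -> None:
        ...

Using `Any` in the function annotation instead of `K` avoids type errors in the implementation of `__init__`, but you could also just use `type: ignore` comments.

You could also probably simplify things by getting rid of `Int_contra` in `IterableDataset`. I don't think you get anything valuable out of it.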
Beta Was this translation helpful? Give feedback.
I would probably use overloads on `__init__` for this. Using `Any` in the function annotation instead of `K` avoids type errors in the implementation of `__init__`, but you could also just use `type: ignore` comments.
You could also probably simplify things by getting rid of `Int_contra` in `IterableDataset`. I don't think you get anything valuable out of it.